ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [PyTorch] pack_padded_sequence, pad_packed_sequence 설명
    Deep Learning 2022. 2. 5. 20:24
    반응형

    Referece

    https://gist.github.com/HarshTrivedi/f4e7293e941b17d19058f6fb90ab0fec

     

    Minimal tutorial on packing (pack_padded_sequence) and unpacking (pad_packed_sequence) sequences in pytorch.

    Minimal tutorial on packing (pack_padded_sequence) and unpacking (pad_packed_sequence) sequences in pytorch. - pad_packed_demo.py

    gist.github.com

    https://simonjisu.github.io/nlp/2018/07/05/packedsequence.html

     

    Pytorch 의 PackedSequence object 알아보기

    PackedSequence 란? 아래의 일련의 과정을 PackedSequence 라고 할 수 있다. NLP 에서 매 배치(batch)마다 고정된 문장의 길이로 만들어주기 위해서 토큰을 넣어야 한다. 아래 그림의 파란색 영역은 토큰이다

    simonjisu.github.io

     

    Start

    Seq2seq 구현 코드를 보던 중 pack_padded_sequence 과 pad_packed_sequence 라는 함수를 접하게 되었다. 어떻게 사용되는지 정리하기 위해 예제와 함께 살펴본다.

     

    먼저 다음과 같은 시퀀스가 있다.

    sequences = ['I like a banana and an apple .', 'I want to eat a banana .', 'There is an apple .']

    시퀀스에 대한 vocab은 다음과 같다. <pad>, <go>, <eos>는 special 토큰으로 자세한 설명은 여기에서 확인할 수 있다.

    ['<pad>', '<go>', '<eos>', '.', 'a', 'an', 'and', 'apple', 'banana', 'eat', 'i', 'is', 'like', 'there', 'to', 'want']

     

    컴퓨터는 자연어를 이해할 수 없으므로 생성된 vocab을 활용해 단어를 index로 변환한다.

    [[1, 10, 12, 4, 8, 6, 5, 7, 3, 2],
     [1, 10, 15, 14, 9, 4, 8, 3, 2],
     [1, 13, 11, 5, 7, 3, 2]]

     

    index화 된 문장을 보면 길이가 제각각인 것을 알 수 있다. 위의 데이터를 신경망의 입력으로 사용하려면 데이터의 크기가 모두 동일해야 한다. 최대 길이(여기서는 10)에 맞게 빈자리를 0으로 채우는 padding을 적용하면 다음과 같다. 아래 텐서를 seq_tensor라 부르고 텐서의 크기는 (3, 10) 이다.

    tensor([[ 1, 10, 12,  4,  8,  6,  5,  7,  3,  2],
            [ 1, 10, 15, 14,  9,  4,  8,  3,  2,  0],
            [ 1, 13, 11,  5,  7,  3,  2,  0,  0,  0]])

     

    입력 데이터는 준비가 되었고 파라미터를 아래와 같이 정의한다. vocab size는 16으로 정해져 있고 E, H는 임의로 정한 크기이다. 물론 B도 임의로 정할 수 있다. 여기서는 시퀀스가 3개이므로 3으로 정했다.

    V (vocab size) = 16

    E (embedding dim) = 12

    H (hidden dim) = 8

    B (batch size) = 3

    max_len = 10 (문장 최대 길이)

     

    모델을 생성한다.

    embed = Embedding(V, E)
    lstm = LSTM(E, H, batch_first=True)

     

    seq_tensor를 embed에 공급하여 크기가 (3, 10, 12)인 embedded_seq_tensor를 얻는다.

    # Embedding
    seq_tensor.shape # torch.Size([3, 10])
    embedded_seq_tensor = embed(seq_tensor)
    embedded_seq_tensor.shape # torch.Size([3, 10, 12])

     

    seq_tensor를 보면 padding 때문에 <pad> 토큰이 존재한다. 이 상태로 연산을 하게 되면 아무 의미 없는 <pad> 토큰까지 연산을 하게 되므로 비효율적이다. 따라서 <pad> 토큰을 계산하지 않고 병렬로 처리하기 위해 pack_padded_sequence를 적용한다. (이 때 문장은 길이가 긴 순서대로 정렬이 되어야 한다.)

    packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths.cpu().numpy(), batch_first=True)
    packed_input.data.shape # torch.Size([26, 12])
    packed_input.batch_sizes # tensor([3, 3, 3, 3, 3, 3, 3, 2, 2, 1])
    sum(packed_input.batch_sizes) # tensor(26)

    packed_input.batch_sizes가 [3, 3, 3, 3, 3, 3, 3, 2, 2, 1] 이러한 형태인데 시각화하면 위의 그림과 같다. (좀 더 자세한 내용은 여기를 참고) 그리고 packed_input.data의 shape이 (26, 12) 인 것을 알 수 있다. 여기서 26은 batch_sizes를 모두 합한 크기이고 12는 위에서 정의한 embedding의 크기이다.

     

    packed한 후 LSTM에 공급한다.

    packed_output, (ht, ct) = lstm(packed_input)
    packed_output.data.shape # torch.Size([26, 8])
    ht.shape # torch.Size([1, 3, 8])
    ct.shape # torch.Size([1, 3, 8])

    packed_output.data를 살펴보면 다음과 같다.

    더보기
    tensor([[-5.6794e-01,  2.7689e-01,  1.0287e+00,  9.4165e-01,  6.9081e-02,
              4.6665e-01,  2.6159e-01, -1.4951e+00,  3.5094e-01, -1.3359e+00,
             -3.6471e-01, -1.0625e+00],
            [-5.6794e-01,  2.7689e-01,  1.0287e+00,  9.4165e-01,  6.9081e-02,
              4.6665e-01,  2.6159e-01, -1.4951e+00,  3.5094e-01, -1.3359e+00,
             -3.6471e-01, -1.0625e+00],
            [-5.6794e-01,  2.7689e-01,  1.0287e+00,  9.4165e-01,  6.9081e-02,
              4.6665e-01,  2.6159e-01, -1.4951e+00,  3.5094e-01, -1.3359e+00,
             -3.6471e-01, -1.0625e+00],
            [-2.6378e-01,  1.6530e+00,  2.0868e+00, -1.8516e+00, -2.6107e-01,
              7.5944e-01, -6.9303e-01, -6.4306e-02, -8.0219e-02,  1.3855e+00,
              1.0216e+00, -1.0879e+00],
            [-2.6378e-01,  1.6530e+00,  2.0868e+00, -1.8516e+00, -2.6107e-01,
              7.5944e-01, -6.9303e-01, -6.4306e-02, -8.0219e-02,  1.3855e+00,
              1.0216e+00, -1.0879e+00],
            [ 7.8942e-01, -1.8114e+00, -5.0175e-01,  1.4189e+00,  4.4967e-01,
              7.9247e-01, -8.0421e-02,  1.3570e+00, -2.4875e-01,  5.2472e-01,
              1.3979e-01, -4.2160e-01],
            [ 6.1198e-01,  1.2688e+00,  1.3549e+00,  5.2502e-02, -9.2254e-01,
             -9.4588e-01, -5.7345e-01, -9.6508e-01, -1.0735e+00,  9.9476e-01,
              1.6168e+00, -1.3874e+00],
            [ 4.9105e-01,  1.4185e+00, -3.3907e-01, -4.8275e-01,  7.5624e-01,
              2.3500e+00,  7.8906e-01, -4.9209e-01,  1.2486e+00, -1.0072e+00,
             -2.0285e-01, -8.4952e-01],
            [-1.1193e+00, -4.5886e-01, -1.0642e+00,  5.0395e-01, -1.0741e-01,
              7.7054e-01, -9.3702e-01,  1.1860e-01, -1.3007e+00, -1.3174e+00,
              2.1728e-01, -8.4409e-01],
            [ 9.0825e-01, -3.2593e-01,  8.5675e-01,  1.0214e-01,  1.7477e-01,
             -4.1466e-01, -3.4677e-01,  5.9111e-01,  2.6824e-02, -4.4657e-01,
             -1.7561e+00,  2.1294e-01],
            [ 1.1037e+00, -2.3495e-01, -7.8878e-02, -1.0911e-01, -9.2271e-02,
              1.2789e+00, -1.2131e+00,  4.4591e-01,  1.7414e+00,  7.9310e-01,
             -5.9679e-01,  2.1264e-01],
            [-8.5093e-01,  1.2810e+00, -8.5232e-01,  6.1296e-01, -9.0446e-01,
              1.2744e+00,  9.0589e-01,  3.8584e-02, -1.9473e-01, -8.3713e-01,
              4.2939e-01,  5.7211e-04],
            [-7.8793e-01, -5.6019e-01,  1.1485e+00,  1.5109e+00,  9.5753e-01,
              6.5259e-01,  7.5140e-01,  3.3920e-01, -2.4991e-02,  1.3870e+00,
             -1.1999e+00, -3.7376e-02],
            [ 1.3433e+00, -6.5546e-01,  8.6841e-01,  2.1329e+00,  9.5268e-01,
             -2.2266e-01,  1.6290e+00,  4.9487e-01, -1.2753e+00, -1.0202e+00,
             -1.5787e+00,  1.7448e+00],
            [-2.5664e-01, -9.0030e-01,  7.2246e-01, -1.8204e-01, -4.9761e-01,
             -1.1750e+00, -1.0461e-01, -3.8936e-01, -1.2106e+00,  1.7963e+00,
             -1.5432e+00, -9.9805e-01],
            [ 1.5903e+00, -8.0831e-01,  1.0260e+00, -1.3571e+00,  3.5445e-01,
              9.5258e-01, -1.2618e+00, -6.0503e-01, -1.7413e-01,  1.5377e+00,
             -1.3901e+00,  4.4937e-01],
            [ 9.0825e-01, -3.2593e-01,  8.5675e-01,  1.0214e-01,  1.7477e-01,
             -4.1466e-01, -3.4677e-01,  5.9111e-01,  2.6824e-02, -4.4657e-01,
             -1.7561e+00,  2.1294e-01],
            [-1.3962e+00,  2.6571e-02,  3.1700e-01, -6.4004e-01,  2.6758e+00,
             -1.2146e+00, -3.5804e-01,  3.3823e-01, -6.7654e-01,  1.1049e+00,
              9.9425e-01,  2.0593e-01],
            [-8.5093e-01,  1.2810e+00, -8.5232e-01,  6.1296e-01, -9.0446e-01,
              1.2744e+00,  9.0589e-01,  3.8584e-02, -1.9473e-01, -8.3713e-01,
              4.2939e-01,  5.7211e-04],
            [-7.8793e-01, -5.6019e-01,  1.1485e+00,  1.5109e+00,  9.5753e-01,
              6.5259e-01,  7.5140e-01,  3.3920e-01, -2.4991e-02,  1.3870e+00,
             -1.1999e+00, -3.7376e-02],
            [-1.2253e-01, -1.6157e+00,  4.8746e-01, -9.7390e-01, -1.0417e+00,
              6.4194e-01,  6.1655e-01,  9.0440e-01,  5.9711e-01,  1.7920e+00,
              8.7472e-01, -6.1827e-02],
            [-2.5664e-01, -9.0030e-01,  7.2246e-01, -1.8204e-01, -4.9761e-01,
             -1.1750e+00, -1.0461e-01, -3.8936e-01, -1.2106e+00,  1.7963e+00,
             -1.5432e+00, -9.9805e-01],
            [-1.3962e+00,  2.6571e-02,  3.1700e-01, -6.4004e-01,  2.6758e+00,
             -1.2146e+00, -3.5804e-01,  3.3823e-01, -6.7654e-01,  1.1049e+00,
              9.9425e-01,  2.0593e-01],
            [-1.3962e+00,  2.6571e-02,  3.1700e-01, -6.4004e-01,  2.6758e+00,
             -1.2146e+00, -3.5804e-01,  3.3823e-01, -6.7654e-01,  1.1049e+00,
              9.9425e-01,  2.0593e-01],
            [-1.2253e-01, -1.6157e+00,  4.8746e-01, -9.7390e-01, -1.0417e+00,
              6.4194e-01,  6.1655e-01,  9.0440e-01,  5.9711e-01,  1.7920e+00,
              8.7472e-01, -6.1827e-02],
            [-1.2253e-01, -1.6157e+00,  4.8746e-01, -9.7390e-01, -1.0417e+00,
              6.4194e-01,  6.1655e-01,  9.0440e-01,  5.9711e-01,  1.7920e+00,
              8.7472e-01, -6.1827e-02]], grad_fn=<PackPaddedSequenceBackward>)

     

    packed 된 텐서를 pad_packed_sequence를 사용해 unpack하는 코드는 아래와 같다.

    output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)
    output.shape # torch.Size([3, 10, 8])
    input_sizes # tensor([10,  9,  7])

    output을 살펴보면 다음과 같다. 0으로 채워진 <pad>가 다시 있는 것을 확인할 수 있다.

    더보기
    tensor([[[ 0.0903, -0.2255,  0.0056,  0.0609, -0.2936,  0.0688,  0.1118,
              -0.0255],
             [-0.0762, -0.0881, -0.1027,  0.1277,  0.0267,  0.1366,  0.0778,
              -0.0532],
             [-0.0832,  0.2314, -0.0445, -0.0232,  0.0473,  0.2478,  0.2611,
              -0.1795],
             [-0.1828,  0.1067, -0.1392, -0.0563,  0.1719,  0.0160,  0.0251,
              -0.3677],
             [-0.1303,  0.0484, -0.2470,  0.0176,  0.0958,  0.0701, -0.1115,
              -0.1241],
             [-0.0611, -0.0657, -0.1954, -0.0832,  0.1803, -0.2150, -0.1865,
               0.0780],
             [ 0.1397, -0.0753, -0.1830,  0.1046,  0.1810, -0.1877, -0.1006,
               0.0671],
             [ 0.0502,  0.1151, -0.2515, -0.0405,  0.0377,  0.1292,  0.0359,
              -0.0316],
             [-0.1013, -0.0158, -0.1363,  0.0169,  0.2653, -0.2760, -0.0626,
              -0.0788],
             [-0.1739,  0.0112, -0.1262,  0.1003,  0.0847,  0.0521,  0.0714,
              -0.2660]],
    
            [[ 0.0903, -0.2255,  0.0056,  0.0609, -0.2936,  0.0688,  0.1118,
              -0.0255],
             [-0.0762, -0.0881, -0.1027,  0.1277,  0.0267,  0.1366,  0.0778,
              -0.0532],
             [ 0.0823, -0.2139, -0.2096,  0.2825, -0.0600, -0.1222, -0.0851,
               0.0552],
             [-0.0395, -0.0171, -0.2184,  0.0893,  0.0684,  0.2072,  0.0669,
               0.0112],
             [-0.0098,  0.0296, -0.2921,  0.1573,  0.0842,  0.2472,  0.0062,
               0.0194],
             [-0.0714,  0.0801, -0.3365,  0.0572,  0.1408,  0.0459, -0.0214,
              -0.2152],
             [-0.0779,  0.0676, -0.3487,  0.0538,  0.0828,  0.0727, -0.1083,
              -0.0593],
             [-0.1687, -0.0256, -0.1152,  0.1912,  0.2776, -0.3015, -0.0994,
              -0.1095],
             [-0.1899,  0.0035, -0.1219,  0.1380,  0.0990,  0.0473,  0.0658,
              -0.3102],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000]],
    
            [[ 0.0903, -0.2255,  0.0056,  0.0609, -0.2936,  0.0688,  0.1118,
              -0.0255],
             [ 0.0928, -0.4141,  0.0050,  0.1396, -0.3024, -0.0420, -0.0327,
              -0.0121],
             [-0.0738, -0.1471,  0.0233,  0.0067, -0.0592, -0.0411, -0.0543,
               0.0554],
             [ 0.1575, -0.0936, -0.1218,  0.1490,  0.0208, -0.1644, -0.0661,
               0.0589],
             [ 0.0563,  0.1035, -0.2255, -0.0374, -0.0823,  0.1467,  0.0467,
              -0.0416],
             [-0.0983, -0.0192, -0.1206,  0.0456,  0.2277, -0.2669, -0.0616,
              -0.0866],
             [-0.1743,  0.0072, -0.1156,  0.1077,  0.0611,  0.0531,  0.0713,
              -0.2749],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000]]], grad_fn=<TransposeBackward0>)

     

    Summary

    (B X max_len)                 -->  Embedding(V, E)  -->    (B X max_len X E)

    (B X max_len X E)            -->          Pack         -->    (batch_sum_seq_len X E)

    (batch_sum_seq_len X E)    -->         LSTM        -->    (batch_sum_seq_len X H)

    (batch_sum_seq_len X H)   -->        UnPack       -->    (B X max_len X H)

    반응형

    'Deep Learning' 카테고리의 다른 글

    batch size vs epoch vs iteration  (0) 2022.02.04

    댓글

Designed by Tistory.