-
[PyTorch] pack_padded_sequence, pad_packed_sequence 설명Deep Learning 2022. 2. 5. 20:24반응형
Referece
https://gist.github.com/HarshTrivedi/f4e7293e941b17d19058f6fb90ab0fec
https://simonjisu.github.io/nlp/2018/07/05/packedsequence.html
Start
Seq2seq 구현 코드를 보던 중 pack_padded_sequence 과 pad_packed_sequence 라는 함수를 접하게 되었다. 어떻게 사용되는지 정리하기 위해 예제와 함께 살펴본다.
먼저 다음과 같은 시퀀스가 있다.
sequences = ['I like a banana and an apple .', 'I want to eat a banana .', 'There is an apple .']
시퀀스에 대한 vocab은 다음과 같다. <pad>, <go>, <eos>는 special 토큰으로 자세한 설명은 여기에서 확인할 수 있다.
['<pad>', '<go>', '<eos>', '.', 'a', 'an', 'and', 'apple', 'banana', 'eat', 'i', 'is', 'like', 'there', 'to', 'want']
컴퓨터는 자연어를 이해할 수 없으므로 생성된 vocab을 활용해 단어를 index로 변환한다.
[[1, 10, 12, 4, 8, 6, 5, 7, 3, 2], [1, 10, 15, 14, 9, 4, 8, 3, 2], [1, 13, 11, 5, 7, 3, 2]]
index화 된 문장을 보면 길이가 제각각인 것을 알 수 있다. 위의 데이터를 신경망의 입력으로 사용하려면 데이터의 크기가 모두 동일해야 한다. 최대 길이(여기서는 10)에 맞게 빈자리를 0으로 채우는 padding을 적용하면 다음과 같다. 아래 텐서를 seq_tensor라 부르고 텐서의 크기는 (3, 10) 이다.
tensor([[ 1, 10, 12, 4, 8, 6, 5, 7, 3, 2], [ 1, 10, 15, 14, 9, 4, 8, 3, 2, 0], [ 1, 13, 11, 5, 7, 3, 2, 0, 0, 0]])
입력 데이터는 준비가 되었고 파라미터를 아래와 같이 정의한다. vocab size는 16으로 정해져 있고 E, H는 임의로 정한 크기이다. 물론 B도 임의로 정할 수 있다. 여기서는 시퀀스가 3개이므로 3으로 정했다.
V (vocab size) = 16
E (embedding dim) = 12
H (hidden dim) = 8
B (batch size) = 3
max_len = 10 (문장 최대 길이)
모델을 생성한다.
embed = Embedding(V, E) lstm = LSTM(E, H, batch_first=True)
seq_tensor를 embed에 공급하여 크기가 (3, 10, 12)인 embedded_seq_tensor를 얻는다.
# Embedding seq_tensor.shape # torch.Size([3, 10]) embedded_seq_tensor = embed(seq_tensor) embedded_seq_tensor.shape # torch.Size([3, 10, 12])
seq_tensor를 보면 padding 때문에 <pad> 토큰이 존재한다. 이 상태로 연산을 하게 되면 아무 의미 없는 <pad> 토큰까지 연산을 하게 되므로 비효율적이다. 따라서 <pad> 토큰을 계산하지 않고 병렬로 처리하기 위해 pack_padded_sequence를 적용한다. (이 때 문장은 길이가 긴 순서대로 정렬이 되어야 한다.)
packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths.cpu().numpy(), batch_first=True) packed_input.data.shape # torch.Size([26, 12]) packed_input.batch_sizes # tensor([3, 3, 3, 3, 3, 3, 3, 2, 2, 1]) sum(packed_input.batch_sizes) # tensor(26)
packed_input.batch_sizes가 [3, 3, 3, 3, 3, 3, 3, 2, 2, 1] 이러한 형태인데 시각화하면 위의 그림과 같다. (좀 더 자세한 내용은 여기를 참고) 그리고 packed_input.data의 shape이 (26, 12) 인 것을 알 수 있다. 여기서 26은 batch_sizes를 모두 합한 크기이고 12는 위에서 정의한 embedding의 크기이다.
packed한 후 LSTM에 공급한다.
packed_output, (ht, ct) = lstm(packed_input) packed_output.data.shape # torch.Size([26, 8]) ht.shape # torch.Size([1, 3, 8]) ct.shape # torch.Size([1, 3, 8])
packed_output.data를 살펴보면 다음과 같다.
더보기tensor([[-5.6794e-01, 2.7689e-01, 1.0287e+00, 9.4165e-01, 6.9081e-02, 4.6665e-01, 2.6159e-01, -1.4951e+00, 3.5094e-01, -1.3359e+00, -3.6471e-01, -1.0625e+00], [-5.6794e-01, 2.7689e-01, 1.0287e+00, 9.4165e-01, 6.9081e-02, 4.6665e-01, 2.6159e-01, -1.4951e+00, 3.5094e-01, -1.3359e+00, -3.6471e-01, -1.0625e+00], [-5.6794e-01, 2.7689e-01, 1.0287e+00, 9.4165e-01, 6.9081e-02, 4.6665e-01, 2.6159e-01, -1.4951e+00, 3.5094e-01, -1.3359e+00, -3.6471e-01, -1.0625e+00], [-2.6378e-01, 1.6530e+00, 2.0868e+00, -1.8516e+00, -2.6107e-01, 7.5944e-01, -6.9303e-01, -6.4306e-02, -8.0219e-02, 1.3855e+00, 1.0216e+00, -1.0879e+00], [-2.6378e-01, 1.6530e+00, 2.0868e+00, -1.8516e+00, -2.6107e-01, 7.5944e-01, -6.9303e-01, -6.4306e-02, -8.0219e-02, 1.3855e+00, 1.0216e+00, -1.0879e+00], [ 7.8942e-01, -1.8114e+00, -5.0175e-01, 1.4189e+00, 4.4967e-01, 7.9247e-01, -8.0421e-02, 1.3570e+00, -2.4875e-01, 5.2472e-01, 1.3979e-01, -4.2160e-01], [ 6.1198e-01, 1.2688e+00, 1.3549e+00, 5.2502e-02, -9.2254e-01, -9.4588e-01, -5.7345e-01, -9.6508e-01, -1.0735e+00, 9.9476e-01, 1.6168e+00, -1.3874e+00], [ 4.9105e-01, 1.4185e+00, -3.3907e-01, -4.8275e-01, 7.5624e-01, 2.3500e+00, 7.8906e-01, -4.9209e-01, 1.2486e+00, -1.0072e+00, -2.0285e-01, -8.4952e-01], [-1.1193e+00, -4.5886e-01, -1.0642e+00, 5.0395e-01, -1.0741e-01, 7.7054e-01, -9.3702e-01, 1.1860e-01, -1.3007e+00, -1.3174e+00, 2.1728e-01, -8.4409e-01], [ 9.0825e-01, -3.2593e-01, 8.5675e-01, 1.0214e-01, 1.7477e-01, -4.1466e-01, -3.4677e-01, 5.9111e-01, 2.6824e-02, -4.4657e-01, -1.7561e+00, 2.1294e-01], [ 1.1037e+00, -2.3495e-01, -7.8878e-02, -1.0911e-01, -9.2271e-02, 1.2789e+00, -1.2131e+00, 4.4591e-01, 1.7414e+00, 7.9310e-01, -5.9679e-01, 2.1264e-01], [-8.5093e-01, 1.2810e+00, -8.5232e-01, 6.1296e-01, -9.0446e-01, 1.2744e+00, 9.0589e-01, 3.8584e-02, -1.9473e-01, -8.3713e-01, 4.2939e-01, 5.7211e-04], [-7.8793e-01, -5.6019e-01, 1.1485e+00, 1.5109e+00, 9.5753e-01, 6.5259e-01, 7.5140e-01, 3.3920e-01, -2.4991e-02, 1.3870e+00, -1.1999e+00, -3.7376e-02], [ 1.3433e+00, -6.5546e-01, 8.6841e-01, 2.1329e+00, 9.5268e-01, -2.2266e-01, 1.6290e+00, 4.9487e-01, -1.2753e+00, -1.0202e+00, -1.5787e+00, 1.7448e+00], [-2.5664e-01, -9.0030e-01, 7.2246e-01, -1.8204e-01, -4.9761e-01, -1.1750e+00, -1.0461e-01, -3.8936e-01, -1.2106e+00, 1.7963e+00, -1.5432e+00, -9.9805e-01], [ 1.5903e+00, -8.0831e-01, 1.0260e+00, -1.3571e+00, 3.5445e-01, 9.5258e-01, -1.2618e+00, -6.0503e-01, -1.7413e-01, 1.5377e+00, -1.3901e+00, 4.4937e-01], [ 9.0825e-01, -3.2593e-01, 8.5675e-01, 1.0214e-01, 1.7477e-01, -4.1466e-01, -3.4677e-01, 5.9111e-01, 2.6824e-02, -4.4657e-01, -1.7561e+00, 2.1294e-01], [-1.3962e+00, 2.6571e-02, 3.1700e-01, -6.4004e-01, 2.6758e+00, -1.2146e+00, -3.5804e-01, 3.3823e-01, -6.7654e-01, 1.1049e+00, 9.9425e-01, 2.0593e-01], [-8.5093e-01, 1.2810e+00, -8.5232e-01, 6.1296e-01, -9.0446e-01, 1.2744e+00, 9.0589e-01, 3.8584e-02, -1.9473e-01, -8.3713e-01, 4.2939e-01, 5.7211e-04], [-7.8793e-01, -5.6019e-01, 1.1485e+00, 1.5109e+00, 9.5753e-01, 6.5259e-01, 7.5140e-01, 3.3920e-01, -2.4991e-02, 1.3870e+00, -1.1999e+00, -3.7376e-02], [-1.2253e-01, -1.6157e+00, 4.8746e-01, -9.7390e-01, -1.0417e+00, 6.4194e-01, 6.1655e-01, 9.0440e-01, 5.9711e-01, 1.7920e+00, 8.7472e-01, -6.1827e-02], [-2.5664e-01, -9.0030e-01, 7.2246e-01, -1.8204e-01, -4.9761e-01, -1.1750e+00, -1.0461e-01, -3.8936e-01, -1.2106e+00, 1.7963e+00, -1.5432e+00, -9.9805e-01], [-1.3962e+00, 2.6571e-02, 3.1700e-01, -6.4004e-01, 2.6758e+00, -1.2146e+00, -3.5804e-01, 3.3823e-01, -6.7654e-01, 1.1049e+00, 9.9425e-01, 2.0593e-01], [-1.3962e+00, 2.6571e-02, 3.1700e-01, -6.4004e-01, 2.6758e+00, -1.2146e+00, -3.5804e-01, 3.3823e-01, -6.7654e-01, 1.1049e+00, 9.9425e-01, 2.0593e-01], [-1.2253e-01, -1.6157e+00, 4.8746e-01, -9.7390e-01, -1.0417e+00, 6.4194e-01, 6.1655e-01, 9.0440e-01, 5.9711e-01, 1.7920e+00, 8.7472e-01, -6.1827e-02], [-1.2253e-01, -1.6157e+00, 4.8746e-01, -9.7390e-01, -1.0417e+00, 6.4194e-01, 6.1655e-01, 9.0440e-01, 5.9711e-01, 1.7920e+00, 8.7472e-01, -6.1827e-02]], grad_fn=<PackPaddedSequenceBackward>)
packed 된 텐서를 pad_packed_sequence를 사용해 unpack하는 코드는 아래와 같다.
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True) output.shape # torch.Size([3, 10, 8]) input_sizes # tensor([10, 9, 7])
output을 살펴보면 다음과 같다. 0으로 채워진 <pad>가 다시 있는 것을 확인할 수 있다.
더보기tensor([[[ 0.0903, -0.2255, 0.0056, 0.0609, -0.2936, 0.0688, 0.1118, -0.0255], [-0.0762, -0.0881, -0.1027, 0.1277, 0.0267, 0.1366, 0.0778, -0.0532], [-0.0832, 0.2314, -0.0445, -0.0232, 0.0473, 0.2478, 0.2611, -0.1795], [-0.1828, 0.1067, -0.1392, -0.0563, 0.1719, 0.0160, 0.0251, -0.3677], [-0.1303, 0.0484, -0.2470, 0.0176, 0.0958, 0.0701, -0.1115, -0.1241], [-0.0611, -0.0657, -0.1954, -0.0832, 0.1803, -0.2150, -0.1865, 0.0780], [ 0.1397, -0.0753, -0.1830, 0.1046, 0.1810, -0.1877, -0.1006, 0.0671], [ 0.0502, 0.1151, -0.2515, -0.0405, 0.0377, 0.1292, 0.0359, -0.0316], [-0.1013, -0.0158, -0.1363, 0.0169, 0.2653, -0.2760, -0.0626, -0.0788], [-0.1739, 0.0112, -0.1262, 0.1003, 0.0847, 0.0521, 0.0714, -0.2660]], [[ 0.0903, -0.2255, 0.0056, 0.0609, -0.2936, 0.0688, 0.1118, -0.0255], [-0.0762, -0.0881, -0.1027, 0.1277, 0.0267, 0.1366, 0.0778, -0.0532], [ 0.0823, -0.2139, -0.2096, 0.2825, -0.0600, -0.1222, -0.0851, 0.0552], [-0.0395, -0.0171, -0.2184, 0.0893, 0.0684, 0.2072, 0.0669, 0.0112], [-0.0098, 0.0296, -0.2921, 0.1573, 0.0842, 0.2472, 0.0062, 0.0194], [-0.0714, 0.0801, -0.3365, 0.0572, 0.1408, 0.0459, -0.0214, -0.2152], [-0.0779, 0.0676, -0.3487, 0.0538, 0.0828, 0.0727, -0.1083, -0.0593], [-0.1687, -0.0256, -0.1152, 0.1912, 0.2776, -0.3015, -0.0994, -0.1095], [-0.1899, 0.0035, -0.1219, 0.1380, 0.0990, 0.0473, 0.0658, -0.3102], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], [[ 0.0903, -0.2255, 0.0056, 0.0609, -0.2936, 0.0688, 0.1118, -0.0255], [ 0.0928, -0.4141, 0.0050, 0.1396, -0.3024, -0.0420, -0.0327, -0.0121], [-0.0738, -0.1471, 0.0233, 0.0067, -0.0592, -0.0411, -0.0543, 0.0554], [ 0.1575, -0.0936, -0.1218, 0.1490, 0.0208, -0.1644, -0.0661, 0.0589], [ 0.0563, 0.1035, -0.2255, -0.0374, -0.0823, 0.1467, 0.0467, -0.0416], [-0.0983, -0.0192, -0.1206, 0.0456, 0.2277, -0.2669, -0.0616, -0.0866], [-0.1743, 0.0072, -0.1156, 0.1077, 0.0611, 0.0531, 0.0713, -0.2749], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]], grad_fn=<TransposeBackward0>)
Summary
(B X max_len) --> Embedding(V, E) --> (B X max_len X E)
(B X max_len X E) --> Pack --> (batch_sum_seq_len X E)
(batch_sum_seq_len X E) --> LSTM --> (batch_sum_seq_len X H)
(batch_sum_seq_len X H) --> UnPack --> (B X max_len X H)
반응형'Deep Learning' 카테고리의 다른 글
batch size vs epoch vs iteration (0) 2022.02.04