torch.take는 input tensor들을 1차원으로 생각해서 처리한다. 따라서 batch 단위 tensor를 처리할 때는 torch.gather를 사용하는 것이 좋다. 

Returns a new tensor with the elements of input at the given indices. The input tensor is treated as if it were viewed as a 1-D tensor. The result takes the same shape as the indices.

다음 예시를 통해 torch.take와 torch.gather의 차이점을 확인할 수 있다.

 

예시는 batch 유저별 label과 negative sample(100개)의 logits을 구하기 위한 과정이다.

전체 예측(preds)는 유저별 모든 아이템에 대한 예측 정보를 가지고 있다.

따라서 전체 예측(preds) 중에서 label과 negative sample에 해당하는 index만 추출해야 한다.

 

이때, preds는 임의로 rand으로 생성했지만 logits으로 생각하면 된다.

candidates도 설명을 위해 dim 마다 모두 같은 값을 입력했지만 실제론 다른 값이 들어갈 것이다.

preds = torch.rand(256, 3708)
candidates = torch.arange(101, dtype=torch.long).unsqueeze(0).expand(256, 101)

preds
tensor([[0.6327, 0.4644, 0.6167,  ..., 0.7969, 0.1830, 0.3973],
        [0.7434, 0.8146, 0.8771,  ..., 0.1170, 0.5679, 0.8292],
        [0.0805, 0.9033, 0.5985,  ..., 0.5601, 0.1804, 0.3152],
        ...,
        [0.4856, 0.8060, 0.7188,  ..., 0.6322, 0.5385, 0.6622],
        [0.7099, 0.4325, 0.1462,  ..., 0.8371, 0.6800, 0.2073],
        [0.3820, 0.7253, 0.5862,  ..., 0.6846, 0.5780, 0.6198]])

candidates
tensor([[  0,   1,   2,  ...,  98,  99, 100],
        [  0,   1,   2,  ...,  98,  99, 100],
        [  0,   1,   2,  ...,  98,  99, 100],
        ...,
        [  0,   1,   2,  ...,  98,  99, 100],
        [  0,   1,   2,  ...,  98,  99, 100],
        [  0,   1,   2,  ...,  98,  99, 100]])

 

torch.take를 사용하게 된다면 preds를 1차원으로 인식하여 첫번째 유저의 예측 정보만 추출돼 원하는 결과를 얻을 수 없다.

torch.take(preds, candidates)
tensor([[0.6327, 0.4644, 0.6167,  ..., 0.2276, 0.5263, 0.1950],
        [0.6327, 0.4644, 0.6167,  ..., 0.2276, 0.5263, 0.1950],
        [0.6327, 0.4644, 0.6167,  ..., 0.2276, 0.5263, 0.1950],
        ...,
        [0.6327, 0.4644, 0.6167,  ..., 0.2276, 0.5263, 0.1950],
        [0.6327, 0.4644, 0.6167,  ..., 0.2276, 0.5263, 0.1950],
        [0.6327, 0.4644, 0.6167,  ..., 0.2276, 0.5263, 0.1950]])

 

이를 방지하기 위해선 torch.gather를 사용해야 한다.

torch.gather(preds, 1, candidates)
tensor([[0.6327, 0.4644, 0.6167,  ..., 0.2276, 0.5263, 0.1950],
        [0.7434, 0.8146, 0.8771,  ..., 0.2775, 0.0377, 0.5715],
        [0.0805, 0.9033, 0.5985,  ..., 0.8943, 0.5950, 0.4170],
        ...,
        [0.4856, 0.8060, 0.7188,  ..., 0.1387, 0.0982, 0.2764],
        [0.7099, 0.4325, 0.1462,  ..., 0.6371, 0.6319, 0.1468],
        [0.3820, 0.7253, 0.5862,  ..., 0.3460, 0.2899, 0.3546]])
 

torch.take — PyTorch 1.11.0 documentation

Shortcuts

pytorch.org

 

torch.gather — PyTorch 1.11.0 documentation

Shortcuts

pytorch.org