1. repeat_interleave(self: Tensor, repeats: _int, dim: Optional[_int]=None) 參數說明: self: 傳入的數據為tensor repeats: 複製的份數 dim: 要複製的維度,可設定為0/1/2..... 2. 例子 2 ...
1. repeat_interleave(self: Tensor, repeats: _int, dim: Optional[_int]=None)
參數說明:
self: 傳入的數據為tensor
repeats: 複製的份數
dim: 要複製的維度,可設定為0/1/2.....
2. 例子
2.1 Code
此處定義了一個4維tensor,要對第2個維度複製,由原來的1變為3,即將設定dim=1。
1 import torch 2 3 4 def function(): 5 data1 = torch.rand([2, 1, 3, 3]) 6 print("data1_shape: ", data1.shape) 7 print("data1: ", data1) 8 9 data2 = torch.repeat_interleave(data1, repeats=3, dim=1) 10 print("data2_shape: ", data2.shape) 11 print("data2: ", data2) 12 13 14 if __name__ == '__main__': 15 function()View Code
2.2 輸出顯示
即可看到輸入tensor形狀為[2, 1, 3, 3],經過repeat後,tensor變為[2, 3, 3, 3],併在第二維度上保持相同的數據。