首页 > 科技 >

🔥torch.cat()用法✨

发布时间:2025-03-23 07:09:07来源:

在PyTorch中,`torch.cat()` 是一个非常实用的函数,用于将多个张量按指定维度拼接在一起。简单来说,它就像是把不同的积木块拼成一个更大的结构!🤔

首先,确保你导入了PyTorch:`import torch` 。然后,假设你有两个形状相同的张量 `tensor1 = torch.tensor([[1, 2], [3, 4]])` 和 `tensor2 = torch.tensor([[5, 6], [7, 8]])`。如果你想沿行方向(即第0维)拼接它们,只需执行:

```python

result = torch.cat((tensor1, tensor2), dim=0)

```

结果会是:

```

tensor([[1, 2],

[3, 4],

[5, 6],

[7, 8]])

```

如果想沿列方向(即第1维)拼接,则设置 `dim=1`:

```python

result = torch.cat((tensor1, tensor2), dim=1)

```

输出变为:

```

tensor([[1, 2, 5, 6],

[3, 4, 7, 8]])

```

需要注意的是,所有张量的形状必须在非拼接维度上保持一致哦!💼

掌握这个小技巧,处理多维数据时会更加得心应手!🚀

免责声明:本答案或内容为用户上传,不代表本网观点。其原创性以及文中陈述文字和内容未经本站证实,对本文以及其中全部或者部分内容、文字的真实性、完整性、及时性本站不作任何保证或承诺,请读者仅作参考,并请自行核实相关内容。 如遇侵权请及时联系本站删除。