【python】【PyTorch】torch中,张量的维度和表示,并详细解释代码
目录
【python】【PyTorch】torch中,张量的维度和表示,并详细解释代码
1. 张量的维度和形状概念
2. PyTorch 中张量的维度表示
代码示例 1:0维张量(标量)
代码示例 2:1维张量(向量)
代码示例 3:2维张量(矩阵)
代码示例 4:3维张量(例如 RGB 图像)
代码示例 5:4维张量(批次图像数据)
3. PyTorch 中常用的张量维度操作
1. unsqueeze(dim)
2. squeeze(dim)
3. view() 和 reshape()
4. 总结
【python】【PyTorch】torch中,张量的维度和表示,并详细解释代码
在 PyTorch 中,张量(tensor)是一个多维数组,用于表示数据。
张量的维度(dimension)和形状(shape)是理解如何在深度学习中组织数据的基础。
接下来,我会详细解释张量的维度、形状以及如何使用相关函数来操作张量,并提供一些具体的代码示例和详细解读。
1. 张量的维度和形状概念
维度(Dimension):张量的维度是其轴(axis)的数量。每个维度代表了数据的一个方向(例如,行或列)。
形状(Shape):张量的形状表示每个维度上的大小,通常是一个包含维度大小的元组。例如,形状 (3, 224, 224)
表示一个有 3 个通道(RGB),大小为 224×224 的图像。
2. PyTorch 中张量的维度表示
我们用 PyTorch 的张量(
torch.Tensor
)来表示数据,并通过ndimension()
或dim()
方法获取张量的维度。
形状可以通过
shape
属性获取,它返回一个元组,表示每个维度的大小。
代码示例 1:0维张量(标量)
import torch
# 创建一个0维张量(标量)
scalar = torch.tensor(5)
print(f"Scalar shape: {scalar.shape}") # 输出: torch.Size([])
print(f"Scalar dimension: {scalar.ndimension()}") # 输出: 0
解释:
()
,没有任何维度。scalar.ndimension()
返回 0,表示它是一个 0 维张量。代码示例 2:1维张量(向量)
# 创建一个1维张量(向量)
vector = torch.tensor([1, 2, 3, 4])
print(f"Vector shape: {vector.shape}") # 输出: torch.Size([4])
print(f"Vector dimension: {vector.ndimension()}") # 输出: 1
解释:
vector.shape
输出的是 torch.Size([4])
,表示它是一个包含 4 个元素的一维张量。vector.ndimension()
返回 1,表示它有 1 个维度。代码示例 3:2维张量(矩阵)
# 创建一个2维张量(矩阵)
matrix = torch.tensor([[1, 2], [3, 4], [5, 6]])
print(f"Matrix shape: {matrix.shape}") # 输出: torch.Size([3, 2])
print(f"Matrix dimension: {matrix.ndimension()}") # 输出: 2
解释:
matrix.shape
输出的是 torch.Size([3, 2])
,表示这是一个 3×2 的矩阵。matrix.ndimension()
返回 2,表示它有 2 个维度。代码示例 4:3维张量(例如 RGB 图像)
# 创建一个3维张量(RGB 图像)
tensor_3d = torch.randn(3, 224, 224) # 3 个通道,224x224 像素
print(f"3D tensor shape: {tensor_3d.shape}") # 输出: torch.Size([3, 224, 224])
print(f"3D tensor dimension: {tensor_3d.ndimension()}") # 输出: 3
解释:
tensor_3d.shape
输出的是 torch.Size([3, 224, 224])
,表示这个张量有 3 个通道,每个通道的尺寸是 224×224。tensor_3d.ndimension()
返回 3,表示它有 3 个维度。代码示例 5:4维张量(批次图像数据)
# 创建一个4维张量(图像批次)
batch_tensor = torch.randn(8, 3, 224, 224) # 8 张 RGB 图像,每张 224x224 像素
print(f"Batch tensor shape: {batch_tensor.shape}") # 输出: torch.Size([8, 3, 224, 224])
print(f"Batch tensor dimension: {batch_tensor.ndimension()}") # 输出: 4
解释:
batch_tensor.shape
输出的是 torch.Size([8, 3, 224, 224])
,表示这是一个 8 张图像的批次。batch_tensor.ndimension()
返回 4,表示它有 4 个维度。3. PyTorch 中常用的张量维度操作
1. unsqueeze(dim)
unsqueeze(dim)
在指定的维度 dim
上插入一个大小为 1 的新维度,通常用于添加批次维度(例如,将一张图像扩展为一个批次)。
# 创建一个形状为 (3, 224, 224) 的张量
image = torch.randn(3, 224, 224)
print(f"Original shape: {image.shape}") # 输出: torch.Size([3, 224, 224])
# 使用 unsqueeze(0) 在第 0 维插入一个新的维度
batch_image = image.unsqueeze(0)
print(f"New shape after unsqueeze(0): {batch_image.shape}") # 输出: torch.Size([1, 3, 224, 224])
解释:
unsqueeze(0)
将第 0 维(即最前面)插入一个大小为 1 的维度,通常用于将单张图像转换为批次大小为 1 的图像。image.unsqueeze(0)
后,形状变为 (1, 3, 224, 224)
,表示这是一个包含 1 张图像的批次。2. squeeze(dim)
squeeze(dim)
移除张量中所有大小为 1 的维度。
如果指定 dim
,它只会移除该维度为 1 的轴。
# 创建一个形状为 (1, 3, 224, 224) 的张量
tensor = torch.randn(1, 3, 224, 224)
print(f"Original shape: {tensor.shape}") # 输出: torch.Size([1, 3, 224, 224])
# 使用 squeeze(0) 移除第 0 维(批次维度)
squeezed_tensor = tensor.squeeze(0)
print(f"New shape after squeeze(0): {squeezed_tensor.shape}") # 输出: torch.Size([3, 224, 224])
解释:
squeeze(0)
会移除大小为 1 的第 0 维,通常用于去掉批次维度。squeeze(0)
操作后,形状变为 (3, 224, 224)
,表示这是一张单独的图像。3. view()
和 reshape()
view()
和 reshape()
用来调整张量的形状。它们的功能相似,但 reshape()
在内存布局不同的情况下更为灵活。
# 创建一个形状为 (3, 224, 224) 的张量
tensor = torch.randn(3, 224, 224)
print(f"Original shape: {tensor.shape}") # 输出: torch.Size([3, 224, 224])
# 使用 view() 将其展平为 3x(224*224) 的张量
flattened_tensor = tensor.view(3, -1) # -1 表示自动计算第二维的大小
print(f"Flattened shape: {flattened_tensor.shape}") # 输出: torch.Size([3, 50176])
解释:
view(3, -1)
将张量的形状转换为 3
行和 224*224=50176
列的形状,-1
表示让 PyTorch 自动计算第二维的大小。4. 总结
dim
):张量的阶数,表示张量有多少个轴。shape
):张量每个维度的大小,表示张量的结构。unsqueeze()
、squeeze()
、view()
和 reshape()
。作者:资源存储库