【python】【PyTorch】详细中文解释unsqueeze,代码和代码解读

目录

【python】【PyTorch】详细中文解释unsqueeze,代码和代码解读 

unsqueeze() 函数的作用:

语法:

unsqueeze() 操作示例:

示例 1:将一个一维张量转换为二维张量

示例 2:在最后一维插入一个新维度

示例 3:负索引插入维度

示例 4:将二维张量转为三维张量

总结:


【python】【PyTorch】详细中文解释unsqueeze,代码和代码解读 

在 PyTorch 中,unsqueeze() 是一个非常实用的函数,用于在张量的指定位置插入一个维度。

简而言之,unsqueeze() 通过增加一个长度为1的维度来扩展张量的维度。

unsqueeze() 函数的作用:

unsqueeze() 函数将一个张量的维度增加 1

这个函数常用于调整张量的形状,特别是在需要将一个二维或一维张量转换为更高维度的张量时。

语法:

torch.unsqueeze(input, dim)
  • input:输入张量。
  • dim:指定要插入新维度的位置。dim 是一个整数,表示新维度的位置,取值范围是 [-input.dim() - 1, input.dim()]。如果 dim 为负数,它表示从最后一个维度开始计数。
  • unsqueeze() 操作示例:

    示例 1:将一个一维张量转换为二维张量

    假设我们有一个一维张量 [1, 2, 3],我们希望通过 unsqueeze() 将其转换为一个二维张量,并在第 0 维度(最前面)插入一个新的维度。

    import torch
    
    # 创建一个一维张量
    x = torch.tensor([1, 2, 3])
    
    # 在第0维插入一个新的维度
    y = torch.unsqueeze(x, 0)
    
    print("Original shape:", x.shape)  # 原始张量形状
    print("New shape:", y.shape)  # 新张量形状
    print(y)
    

    输出:

    Original shape: torch.Size([3])
    New shape: torch.Size([1, 3])
    tensor([[1, 2, 3]])
    
  • 解释
  • 原始张量 x 的形状是 (3),表示这是一个包含 3 个元素的一维张量。
  • 使用 torch.unsqueeze(x, 0) 后,在张量的第 0 维插入了一个新的维度。结果是一个形状为 (1, 3) 的二维张量。
  • unsqueeze(0) 会在第一个维度(最前面)插入新的维度,表示这个张量现在有 1 行,3 列。
  • 示例 2:在最后一维插入一个新维度

    假设我们希望将张量 [1, 2, 3] 变成形状为 (3, 1) 的二维张量,我们可以在第 1 维(最后一维)插入一个新的维度。

    # 在第1维插入一个新的维度
    z = torch.unsqueeze(x, 1)
    
    print("Original shape:", x.shape)
    print("New shape:", z.shape)
    print(z)
    

    输出:

    Original shape: torch.Size([3])
    New shape: torch.Size([3, 1])
    tensor([[1],
            [2],
            [3]])
    
  • 解释
  • 原始张量 x 的形状是 (3),是一个一维张量。
  • 使用 torch.unsqueeze(x, 1) 后,在第 1 维(即最后一个维度)插入了一个新的维度。结果是一个形状为 (3, 1) 的二维张量,表示这个张量现在有 3 行,1 列。
  • 示例 3:负索引插入维度

    我们可以使用负数索引来指定维度的位置。负数表示从最后一个维度开始计数。

    # 在倒数第一维(最后一维)插入一个新的维度
    w = torch.unsqueeze(x, -1)
    
    print("Original shape:", x.shape)
    print("New shape:", w.shape)
    print(w)
    

    输出:

    Original shape: torch.Size([3])
    New shape: torch.Size([3, 1])
    tensor([[1],
            [2],
            [3]])
    
  • 解释
  • 使用 torch.unsqueeze(x, -1) 等同于使用 torch.unsqueeze(x, 1),在张量的最后一个维度插入了一个新的维度。
  • 结果是一个形状为 (3, 1) 的二维张量,表示张量现在有 3 行,1 列。
  • 示例 4:将二维张量转为三维张量

    如果我们有一个形状为 (2, 3) 的二维张量,并希望将其转换为三维张量(例如,插入一个维度表示批次大小),我们可以使用 unsqueeze()

    # 创建一个二维张量
    a = torch.tensor([[1, 2, 3], [4, 5, 6]])
    
    # 在第0维插入新维度
    b = torch.unsqueeze(a, 0)
    
    print("Original shape:", a.shape)
    print("New shape:", b.shape)
    print(b)
    

    输出:

    Original shape: torch.Size([2, 3])
    New shape: torch.Size([1, 2, 3])
    tensor([[[1, 2, 3],
             [4, 5, 6]]])
    
  • 解释
  • 原始张量 a 的形状是 (2, 3),表示它有 2 行,3 列。
  • 使用 torch.unsqueeze(a, 0) 后,在第 0 维(最前面)插入了一个新的维度,结果是一个形状为 (1, 2, 3) 的三维张量,表示这个张量现在有 1 个批次,2 行,3 列。
  • 总结:

  • unsqueeze() 函数用于增加张量的维度,可以通过指定维度位置插入一个新的维度(长度为 1)。
  • 常见应用:增加批次维度(例如将一维张量转换为二维张量)或调整张量形状以满足模型输入的要求。
  • 这个函数特别有用,在需要将张量的维度对齐时,或者在深度学习框架中调整数据形状时非常常见。
  • 作者:资源存储库

    物联沃分享整理
    物联沃-IOTWORD物联网 » 【python】【PyTorch】详细中文解释unsqueeze,代码和代码解读

    发表回复