使用Python GPU与VGG19模型实现图像风格转换的实践指南
一.图像风格转换技术介绍
图像风格转换的目的是将一幅图像的风格特征转换迁移到另一幅图像上,同时保留后者的内容信息.其核心目标是实现内容与风格的有效解耦和重组,使源图能够获得目标风格特征,创造出具有独特视觉效果的新图像。就是利用一种技术,比如要让毕加索将我画的画按照他的画画风格去画,最后得出来的图像就是对我画的图进行了一个风格转换的结果。
实现这种功能的方法有很多,大部分采用的关键技术都有特征提取与表示,损失函数设计和优化与生成。在特征提取与表示中,会利用预训练的卷积神经网络提取图像的特征。在损失函数设计中,会计算其内容损失、风格损失和总变差损失。在优化与生成中,会使用优化算法(如L-BFGS或Adam)对目标图像进行迭代优化,生成的图像会在每次迭代后逐渐接近目标风格,同时保持内容图像的语义信息。
二.VGG19模型
VGG19是一个深度卷积神经网络(CNN),由牛津大学的Visual Geometry Group(VGG)提出,因此得名VGG19。它在2014年的ImageNet Large Scale Visual Recognition Challenge(ILSVRC)中表现出色,成为当时最先进的图像识别模型之一。VGG19的设计目标是通过使用更小的卷积核和更深的网络结构来提高图像识别的准确性。
VGG19模型由五个卷积块和三个全连接层组成,它的优缺点如下:
优点:简单有效:VGG19的架构简单,易于理解和实现,同时在图像识别任务中表现出色。
强大的特征提取能力:由于其深度结构,VGG19能够捕捉图像中的复杂特征,使其在迁移学习中非常有用。
预训练模型:VGG19在ImageNet数据集上预训练的权重可以作为强大的特征提取器,用于各种计算机视觉任务,如目标检测、图像分割和风格迁移。
缺点:计算需求高:VGG19由于其深度和小卷积核的使用,需要大量的内存和计算资源,适合在硬件条件较好的环境中使用。
参数数量多:VGG19模型有超过1亿个参数,这使得模型的存储和训练成本较高。
VGG19在迁移学习中得到了广泛应用,其预训练模型在大型数据集(如ImageNet)上训练得到的特征提取能力,使其成为各种计算机视觉任务(如目标检测、图像分割和风格迁移)的强大基础。在工业领域,VGG19的预训练权重被用作强大的特征提取器,应用于从医学成像到自动驾驶汽车等多个领域。
三.主要代码分析
def load_image(img_path, max_size=400, shape=None):
image = Image.open(img_path).convert('RGB')
if max(image.size) > max_size:
size = max_size
else:
size = max(image.size)
if shape is not None:
size = shape
in_transform = transforms.Compose([
transforms.Resize(size),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225))])
image = in_transform(image)[:3, :, :].unsqueeze(0)
return image.to(device)
加载和预处理图像:使用PIL库打开图像文件,并将其转换为RGB格式。将图像的尺寸进行限制,以防过大造成计算量大时间长的局面,也可以指定尺寸大小不进行限制。对图像进行限制后,将图像转换为PyTorch张量,并将像素值从0,255的范围缩放到0.0,1.0的范围使其数据保持一致性和稳定性。然后对图像进行标准化处理,使其符合预训练模型的输入要求,参数是基于图像网络数据集的均值和标准差。最后将图像张量的维度调整为适合神经网络输入的格式并移动到指定的设备(GPU或CPU)上并返回。
def im_convert(tensor):
image = tensor.to("cpu").clone().detach()
image = image.numpy().squeeze()
image = image.transpose(1, 2, 0)
image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
image = image.clip(0, 1)
return image
将图像张量转换为可显示的图像格式:将图像张量从当前设备移动到CPU上,并进行克隆和分离操作,以便进行后续的修改,然后转换为NumPy数组,并去掉批次维度。转换一下图像数据的维度顺序为(高度、宽度,通道数),以便于显示.接着对图像进行反标准化处理,将标准化后的像素值还原为原始的范围并使用clip函数将像素值限制在[0,1]范围内,确保图像数据的有效性。最后返回转换后的图像数组,即可以正常显示风格转换后的图像。
def get_features(image, model, layers=None):
if layers is None:
layers = {'0': 'conv1_1',
'5': 'conv2_1',
'10': 'conv3_1',
'19': 'conv4_1',
'21': 'conv4_2',
'28': 'conv5_1'}
features = {}
x = image
for name, layer in model._modules.items():
x = layer(x)
if name in layers:
features[layers[name]] = x
return features
定义模型和损失函数:如果没有提供提取特征的层名,则使用默认的(默认为VGG19模型中的一些常用层)层名映射。遍历模型的每一层,将输入图像通过该层进行处理。如果当前层的名称在layers字典中,则将改层的输出特征图保存到features字典中,使用layers字典中对于的键作为关键字,最后返回包含特定层特征图的字典。该函数是整个风格转换过程中最关键的部分,通过使用预训练的卷积神经网络,能够获取图像在不同层次上的特征表示。提取出的内容特征和风格特征分别用于计算内容损失和风格损失。内容特征确保目标图像在视觉内容上与内容图像相似,而风格特征则使目标图像在风格上接近风格图像。
def gram_matrix(tensor):
_, d, h, w = tensor.size()
tensor = tensor.view(d, h * w)
gram = torch.mm(tensor, tensor.t())
return gram
计算输入张量的格拉姆矩阵,用于表示风格特征。先获取张量的维度信息,包括通道数d、高度h和宽度w。再将特征图张量展平为二维矩阵,然后计算展平后的矩阵与其转置矩阵的矩阵乘积从而得到格拉姆矩阵,最后返回得到的格拉姆矩阵。通过格拉姆矩阵可以计算出风格损失,模型能够在不改变图像内容的情况下,将风格图像的风格迁移到目标图像上,实现风格转换的目标。
from torchvision.models import vgg19, VGG19_Weights
vgg = vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features.to(device)
for param in vgg.parameters():
param.requires_grad_(False)
加载预训练的VGG模型:将其特征提取部分移动到指定的设备上。同时,将模型的所有参数设置为不需要梯度计算(在该程序中,不需要对预训练模型的参数进行更新,只需利用其特征提取能力)。
optimizer = optim.Adam([target], lr=0.003)
style_weight = 1e6
content_weight = 1
for i in range(1, 3001):
target_features = get_features(target, vgg)
content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2']) ** 2)
style_loss = 0
for layer in style_grams:
target_feature = target_features[layer]
target_gram = gram_matrix(target_feature)
style_gram = style_grams[layer]
layer_style_loss = torch.mean((target_gram - style_gram) ** 2)
style_loss += layer_style_loss / style_grams[layer].numel()
total_loss = content_weight * content_loss + style_weight * style_loss
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
if i % 500 == 0:
print(f'Iteration: {i}, Total loss: {total_loss.item()}')
优化目标图像:使用Adam优化器对目标图像进行优化,对风格损失和内容损失进行权重分配,在优化过程中更注重风格的迁移,内容仅需保留即可。在迭代过程中,首先提取目标图像的特征,这些特征将用来计算内容损失和风格损失。然后通过目标图像和内容图像在conv4_2层的特征图的差异来计算内容损失,对于每一层的风格特征,计算目标图像和风格图像的格拉姆矩阵的差异以得到该层的风格损失,将各层的风格损失累加,并除以相应的格拉姆矩阵的元素数量,以得到最终的风格损失,最后根据权重计算总损失。清除之前的梯度信息,反向传播计算目标图像对于总损失的梯度,根据计算得到的梯度更新目标图像的像素值,使其朝着降低总损失的方向变化。
四.运行部分结果展示
目标图像:
风格图像:
运行结果:
五.注意事项
改代码在对损失进行计算的时候用的是GPU,需要在有显卡的环境使用。迭代了3000次需要2分钟左右的时间,如果用CPU的话就需要十分钟左右。配置使用GPU的环境可以参考一下其他的文章,下面会给出用CPU和GPU的两个完整代码。
GPU:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
# 加载和预处理图像
def load_image(img_path, max_size=400, shape=None):
image = Image.open(img_path).convert('RGB')
if max(image.size) > max_size:
size = max_size
else:
size = max(image.size)
if shape is not None:
size = shape
in_transform = transforms.Compose([
transforms.Resize(size),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225))])
image = in_transform(image)[:3, :, :].unsqueeze(0)
return image.to(device)
# 将图像张量转换为可显示的图像格式
def im_convert(tensor):
image = tensor.to("cpu").clone().detach()
image = image.numpy().squeeze()
image = image.transpose(1, 2, 0)
image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
image = image.clip(0, 1)
return image
# 加载内容图片和风格图片
content = load_image('C:/Users/Be fearless/Desktop/398f4912b45260cca24eb3ec9b37e711.jpg')
style = load_image('C:/Users/Be fearless/Desktop/R-C.jpg', shape=content.shape[-2:])
# 定义模型和损失函数
def get_features(image, model, layers=None):
if layers is None:
layers = {'0': 'conv1_1',
'5': 'conv2_1',
'10': 'conv3_1',
'19': 'conv4_1',
'21': 'conv4_2',
'28': 'conv5_1'}
features = {}
x = image
for name, layer in model._modules.items():
x = layer(x)
if name in layers:
features[layers[name]] = x
return features
def gram_matrix(tensor):
_, d, h, w = tensor.size()
tensor = tensor.view(d, h * w)
gram = torch.mm(tensor, tensor.t())
return gram
# 加载预训练的VGG模型
from torchvision.models import vgg19, VGG19_Weights
vgg = vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features.to(device)
for param in vgg.parameters():
param.requires_grad_(False)
# 获取内容和风格特征
content_features = get_features(content, vgg)
style_features = get_features(style, vgg)
style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}
# 初始化目标图像
target = content.clone().requires_grad_(True).to(device)
# 优化目标图像
optimizer = optim.Adam([target], lr=0.003)
style_weight = 1e6
content_weight = 1
for i in range(1, 3001):
target_features = get_features(target, vgg)
content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2']) ** 2)
style_loss = 0
for layer in style_grams:
target_feature = target_features[layer]
target_gram = gram_matrix(target_feature)
style_gram = style_grams[layer]
layer_style_loss = torch.mean((target_gram - style_gram) ** 2)
style_loss += layer_style_loss / style_grams[layer].numel()
total_loss = content_weight * content_loss + style_weight * style_loss
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
if i % 500 == 0:
print(f'Iteration: {i}, Total loss: {total_loss.item()}')
# 显示结果
target = target.cpu()
plt.imshow(im_convert(target))
plt.show()
CPU:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from PIL import Image
import numpy as np
from torchvision.models import vgg19, VGG19_Weights
import matplotlib.pyplot as plt # 添加这行代码.
def load_image(img_path, max_size=400, shape=None):
image = Image.open(img_path).convert('RGB')
if max(image.size) > max_size:
size = max_size
else:
size = max(image.size)
if shape is not None:
size = shape
in_transform = transforms.Compose([
transforms.Resize(size),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225))])
image = in_transform(image)[:3, :, :].unsqueeze(0)
return image
content = load_image('C:/Users/Be fearless/Desktop/398f4912b45260cca24eb3ec9b37e711.jpg')
style = load_image('C:/Users/Be fearless/Desktop/R-C.jpg', shape=content.shape[-2:])
def get_features(image, model, layers=None):
if layers is None:
layers = {'0': 'conv1_1',
'5': 'conv2_1',
'10': 'conv3_1',
'19': 'conv4_1',
'21': 'conv4_2', # content representation
'28': 'conv5_1'}
features = {}
x = image
for name, layer in model._modules.items():
x = layer(x)
if name in layers:
features[layers[name]] = x
return features
def gram_matrix(tensor):
_, d, h, w = tensor.size()
tensor = tensor.view(d, h * w)
gram = torch.mm(tensor, tensor.t())
return gram
vgg = vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features
for param in vgg.parameters():
param.requires_grad_(False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg.to(device)
content_features = get_features(content, vgg)
style_features = get_features(style, vgg)
style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}
target = content.clone().requires_grad_(True).to(device)
optimizer = optim.Adam([target], lr=0.003)
style_weight = 1e6
content_weight = 1
def im_convert(tensor):
image = tensor.to("cpu").clone().detach()
image = image.numpy().squeeze()
image = image.transpose(1, 2, 0)
image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
image = image.clip(0, 1)
return image
for i in range(1, 3001):
target_features = get_features(target, vgg)
content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2']) ** 2)
style_loss = 0
for layer in style_grams:
target_feature = target_features[layer]
target_gram = gram_matrix(target_feature)
style_gram = style_grams[layer]
layer_style_loss = torch.mean((target_gram - style_gram) ** 2)
style_loss += layer_style_loss / style_grams[layer].numel()
total_loss = content_weight * content_loss + style_weight * style_loss
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
if i % 500 == 0:
print('Iteration: {}, Total loss: {}'.format(i, total_loss.item()))
plt.imshow(im_convert(target))
plt.show()
plt.imshow(im_convert(target))
plt.show()
作者:VVYY要成为大神