精彩呈现:基于深度学习的图像自动上色技术

文章目录

  • 1. 前言
  • 2.图像格式(RGB,HSV,Lab)
  • 2.1 RGB
  • 2.2 hsv
  • 2.3 Lab
  • 3. 生成对抗网络(GAN)
  • 3.1 生成网络(Unet)
  • 3.2 判别网络(resnet18)
  • 4. 数据集
  • 5. 模型训练与预测流程图
  • 5.1 训练流程图
  • 5.2 预测流程图
  • 6. 模型预测效果
  • 7. GUI界面制作
  • 8.代码下载
  • 如果有不懂的,欢迎下方评论,你还在为毕设课设烦恼吗?注意下方图片右下角水印,解决一切问题,欢迎咨询。
    右下角水印qq

    1. 前言

    本文基于pytorch和opencv使用生成对抗网络对灰度图像自动上色,然后可以对上色后的图片手动调节亮度对比度等信息,最后可以保存上色后的图像,闲话少说,先看一下效果,文章最后附有全部代码及数据集下载链接。

    灰度图自动上色

    b站视频地址:b站视频地址

    2.图像格式(RGB,HSV,Lab)

    2.1 RGB

    想要对灰度图片上色,首先要了解图像的格式,对于一副普通的图像通常为RGB格式的,即红、绿、蓝三个通道,可以使用opencv分离图像的三个通道,代码如下所示:

    import cv2
    
    img=cv2.imread('pic/7.jpg')
    B,G,R=cv2.split(img)
    cv2.imshow('img',img)
    cv2.imshow('B',B)
    cv2.imshow('G',G)
    cv2.imshow('R',R)
    cv2.waitKey(0)
    

    代码运行结果如下所示。

    2.2 hsv

    hsv是图像的另一种格式,其中h代表图像的色调,s代表饱和度,v代表图像亮度,可以通过调节h、s、v的值来改变图像的色调、饱和度、亮度等信息。
    同样可以使用opencv将图像从RGB格式转换成hsv格式。然后可以分离h、s、v三个通道并显示图像代码如下所示:

    import cv2
    
    img=cv2.imread('pic/7.jpg')
    hsv=cv2.cvtColor(img,cv2.COLOR_BGR2HSV)
    h,s,v=cv2.split(hsv)
    cv2.imshow('hsv',hsv)
    cv2.imshow('h',h)
    cv2.imshow('s',s)
    cv2.imshow('v',v)
    cv2.waitKey(0)
    

    运行结果如下所示:

    2.3 Lab

    Lab是图像的另一种格式,也是本文使用的格式,其中L代表灰度图像,a、b代表颜色通道,本文使用L通道灰度图作为输入,ab两个颜色通道作为输出,训练生成对抗网络,将图像由RGB格式转换成Lab格式的代码如下所示:

    import cv2
    
    img=cv2.imread('pic/7.jpg')
    Lab=cv2.cvtColor(img,cv2.COLOR_BGR2Lab)
    L,a,b=cv2.split(Lab)
    cv2.imshow('Lab',Lab)
    cv2.imshow('L',L)
    cv2.imshow('a',a)
    cv2.imshow('b',b)
    cv2.waitKey(0)
    

    3. 生成对抗网络(GAN)

    生成对抗网络主要包含两部分,分别是生成网络和判别网络。
    生成网络负责生成图像,判别网络负责鉴定生成图像的好坏,二者相辅相成,相互博弈。
    本文使用U-net作为生成网络,使用ResNet18作为判别网络。U-net网络的结构图如下所示:
    

    3.1 生成网络(Unet)

    pytorch构建unet网络的代码如下所示:

    class DownsampleLayer(nn.Module):
        def __init__(self,in_ch,out_ch):
            super(DownsampleLayer, self).__init__()
            self.Conv_BN_ReLU_2=nn.Sequential(
                nn.Conv2d(in_channels=in_ch,out_channels=out_ch,kernel_size=3,stride=1,padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1,padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU()
            )
            self.downsample=nn.Sequential(
                nn.Conv2d(in_channels=out_ch,out_channels=out_ch,kernel_size=3,stride=2,padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU()
            )
    
        def forward(self,x):
            """
            :param x:
            :return: out输出到深层,out_2输入到下一层,
            """
            out=self.Conv_BN_ReLU_2(x)
            out_2=self.downsample(out)
            return out,out_2
    class UpSampleLayer(nn.Module):
    	def __init__(self,in_ch,out_ch):
    	   # 512-1024-512
    	   # 1024-512-256
    	   # 512-256-128
    	   # 256-128-64
    	   super(UpSampleLayer, self).__init__()
       self.Conv_BN_ReLU_2 = nn.Sequential(
           nn.Conv2d(in_channels=in_ch, out_channels=out_ch*2, kernel_size=3, stride=1,padding=1),
           nn.BatchNorm2d(out_ch*2),
           nn.ReLU(),
           nn.Conv2d(in_channels=out_ch*2, out_channels=out_ch*2, kernel_size=3, stride=1,padding=1),
           nn.BatchNorm2d(out_ch*2),
           nn.ReLU()
       )
       self.upsample=nn.Sequential(
           nn.ConvTranspose2d(in_channels=out_ch*2,out_channels=out_ch,kernel_size=3,stride=2,padding=1,output_padding=1),
           nn.BatchNorm2d(out_ch),
           nn.ReLU()
       )
    
    	def forward(self,x,out):
    	   '''
    	   :param x: 输入卷积层
    	   :param out:与上采样层进行cat
    	   :return:
    	   '''
    	   x_out=self.Conv_BN_ReLU_2(x)
    	   x_out=self.upsample(x_out)
    	   cat_out=torch.cat((x_out,out),dim=1)
    	   return cat_out
    class UNet(nn.Module):
        def __init__(self):
            super(UNet, self).__init__()
            out_channels=[2**(i+6) for i in range(5)] #[64, 128, 256, 512, 1024]
            #下采样
            self.d1=DownsampleLayer(3,out_channels[0])#3-64
            self.d2=DownsampleLayer(out_channels[0],out_channels[1])#64-128
            self.d3=DownsampleLayer(out_channels[1],out_channels[2])#128-256
            self.d4=DownsampleLayer(out_channels[2],out_channels[3])#256-512
            #上采样
            self.u1=UpSampleLayer(out_channels[3],out_channels[3])#512-1024-512
            self.u2=UpSampleLayer(out_channels[4],out_channels[2])#1024-512-256
            self.u3=UpSampleLayer(out_channels[3],out_channels[1])#512-256-128
            self.u4=UpSampleLayer(out_channels[2],out_channels[0])#256-128-64
            #输出
            self.o=nn.Sequential(
                nn.Conv2d(out_channels[1],out_channels[0],kernel_size=3,stride=1,padding=1),
                nn.BatchNorm2d(out_channels[0]),
                nn.ReLU(),
                nn.Conv2d(out_channels[0], out_channels[0], kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels[0]),
                nn.ReLU(),
                nn.Conv2d(out_channels[0],3,3,1,1),
                nn.Sigmoid(),
                # BCELoss
            )
        def forward(self,x):
            out_1,out1=self.d1(x)
            out_2,out2=self.d2(out1)
            out_3,out3=self.d3(out2)
            out_4,out4=self.d4(out3)
            out5=self.u1(out4,out_4)
            out6=self.u2(out5,out_3)
            out7=self.u3(out6,out_2)
            out8=self.u4(out7,out_1)
            out=self.o(out8)
            return out
    
    
    

    3.2 判别网络(resnet18)

    resnet18的结构图如下所示:

    在pytorch内部自带resnet18模型,只需一行代码即可构建resnet18模型,然后还需要去除网络最后的全连接层,代码如下所示:

    from torchvision import models
    
    resnet18=models.resnet18(pretrained=False)
    del resnet18.fc
    
    print(resnet18)
    
    

    4. 数据集

    本文使用的是自然风景类的数据图片,在网站上爬取了大概1000多张数据图片,部分图片如下所示

    5. 模型训练与预测流程图

    5.1 训练流程图

    如下图所示,首先将RGB图像转换成Lab图像,然后将L通道作为生成网络输入,生成网络的输出为新的ab两通道,然后将图像原始的ab通道,与生成网络生成的ab通道输入判别网络中。

    5.2 预测流程图

    下图为模型的预测过程,在预测过程中判别网络已经没有作用了,首先将RGB图像转换成,Lab图像,接着将L灰度图输入生成网络可以得到新的ab通道图像,接着将L通道图像与生成的ab通道图像进行拼接(concate),拼接以后可以得到一张新的Lab图像,然后再将其转换成RGB格式,此时图像即为上色以后的图像。

    6. 模型预测效果

    下图为模型的预测效果。左侧的为灰度图像,中间的为原始的彩色图像,右侧的是模型上色以后的图像。整体上看,网络的上色效果还不错。
    




    7. GUI界面制作

    为了更加方便使用模型,本文使用pyqt5制作操作界面,其界面如下图所示:首先可以从电脑中加载图像,还可以切换上一张或者下一张,可以将图像灰度化显示。可以对其上色,然后可以调整上色后图像的H、S、V信息,最后支持图像导出,可以将上色后的图像保存到本地中。

    8.代码下载

    链接中包含了训练代码,测试代码,以及界面代码。此外还包含1000多张数据集,直接运行main.py程序即可弹出操作界面。
    代码及数据集下载链接

    物联沃分享整理
    物联沃-IOTWORD物联网 » 精彩呈现:基于深度学习的图像自动上色技术

    发表回复