pfcfuse/componets/DEConv.py
zjut 927d049f13 feat(components): 添加多种特征增强模块
- 新增 DEConv、DynamicFilter、SCSA、SEBlock、SMFA、TIAM 和 UFFC 模块
- 这些模块提供了不同的特征增强功能,如卷积差分、频域滤波、注意力机制等
- 可以根据需求选择合适的模块来提升模型性能
2024-11-15 09:18:34 +08:00

172 lines
9.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# https://github.com/cecret3350/DEA-Net/blob/main/code/model/modules/deconv.py
import math
import torch
from torch import nn
from einops.layers.torch import Rearrange
"""
在深度学习和图像处理领域,"vanilla""difference" 卷积是两种不同的卷积操作它们各自有不同的特性和应用场景。DEConv细节增强卷积的设计思想是结合这两种卷积的特性来增强模型的性能尤其是在图像去雾等任务中。
Vanilla Convolution标准卷积
"Vanilla" 卷积是最基本的卷积类型通常仅称为卷积。它是卷积神经网络CNN中最常用的组件用于提取输入数据如图像的特征。
标准卷积通过在输入数据上滑动小的、可学习的过滤器(或称为核),并计算过滤器与数据的局部区域之间的点乘,来工作。通过这种方式,它能够捕获输入数据的局部模式和特征。
Difference Convolution差分卷积
差分卷积是一种特殊类型的卷积,它专注于捕捉输入数据中的局部差异信息,例如边缘或纹理的变化。
它通过修改标准卷积核的权重或者通过特殊的操作来实现,使得网络更加关注于图像的高频信息,即图像中的细节和纹理变化。在图像处理任务中,如图像去雾、图像增强、边缘检测等,捕获这种高频信息非常重要,因为它们往往包含了关于物体边界和结构的关键信息。
重参数化技术
重参数化技术是一种参数转换方法它允许模型在不增加额外参数和计算代价的情况下实现更复杂的功能或改善性能。在DEConv的上下文中重参数化技术使得将vanilla卷积和difference卷积结合起来的操作可以等价地转换成一个标准的卷积操作。
这意味着DEConv可以在不增加额外参数和计算成本的情况下通过巧妙地设计卷积核权重同时利用标准卷积和差分卷积的优势从而增强模型处理图像的能力。
具体来说,通过重参数化,可以将差分卷积的效果整合到一个卷积核中,使得这个卷积核既能捕获图像的基本特征(通过标准卷积部分),也能强调图像的细节和差异信息(通过差分卷积部分)。
这种方法特别适用于那些需要同时考虑全局内容和局部细节信息的任务,如图像去雾,其中既需要理解图像的整体结构,也需要恢复由于雾导致的细节丢失。
重参数化技术的关键优势在于,它允许模型在维持参数数量和计算复杂度不变的前提下,实现更为复杂或更为精细的功能。
"""
class Conv2d_cd(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
super(Conv2d_cd, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
self.theta = theta
def get_weight(self):
conv_weight = self.conv.weight
conv_shape = conv_weight.shape
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
# conv_weight_cd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_cd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_cd[:, :, :] = conv_weight[:, :, :]
conv_weight_cd[:, :, 4] = conv_weight[:, :, 4] - conv_weight[:, :, :].sum(2)
conv_weight_cd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[3])(
conv_weight_cd)
return conv_weight_cd, self.conv.bias
class Conv2d_ad(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
super(Conv2d_ad, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
self.theta = theta
def get_weight(self):
conv_weight = self.conv.weight
conv_shape = conv_weight.shape
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
conv_weight_ad = conv_weight - self.theta * conv_weight[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]
conv_weight_ad = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[3])(
conv_weight_ad)
return conv_weight_ad, self.conv.bias
class Conv2d_rd(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=2, dilation=1, groups=1, bias=False, theta=1.0):
super(Conv2d_rd, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
self.theta = theta
def forward(self, x):
if math.fabs(self.theta - 0.0) < 1e-8:
out_normal = self.conv(x)
return out_normal
else:
conv_weight = self.conv.weight
conv_shape = conv_weight.shape
if conv_weight.is_cuda:
conv_weight_rd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 5 * 5).fill_(0)
else:
conv_weight_rd = torch.zeros(conv_shape[0], conv_shape[1], 5 * 5)
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
conv_weight_rd[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = conv_weight[:, :, 1:]
conv_weight_rd[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -conv_weight[:, :, 1:] * self.theta
conv_weight_rd[:, :, 12] = conv_weight[:, :, 0] * (1 - self.theta)
conv_weight_rd = conv_weight_rd.view(conv_shape[0], conv_shape[1], 5, 5)
out_diff = nn.functional.conv2d(input=x, weight=conv_weight_rd, bias=self.conv.bias,
stride=self.conv.stride, padding=self.conv.padding, groups=self.conv.groups)
return out_diff
class Conv2d_hd(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
super(Conv2d_hd, self).__init__()
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
def get_weight(self):
conv_weight = self.conv.weight
conv_shape = conv_weight.shape
# conv_weight_hd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_hd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_hd[:, :, [0, 3, 6]] = conv_weight[:, :, :]
conv_weight_hd[:, :, [2, 5, 8]] = -conv_weight[:, :, :]
conv_weight_hd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[2])(
conv_weight_hd)
return conv_weight_hd, self.conv.bias
class Conv2d_vd(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=1, dilation=1, groups=1, bias=False):
super(Conv2d_vd, self).__init__()
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
def get_weight(self):
conv_weight = self.conv.weight
conv_shape = conv_weight.shape
# conv_weight_vd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_vd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_vd[:, :, [0, 1, 2]] = conv_weight[:, :, :]
conv_weight_vd[:, :, [6, 7, 8]] = -conv_weight[:, :, :]
conv_weight_vd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[2])(
conv_weight_vd)
return conv_weight_vd, self.conv.bias
class DEConv(nn.Module):
def __init__(self, dim):
super(DEConv, self).__init__()
self.conv1_1 = Conv2d_cd(dim, dim, 3, bias=True)
self.conv1_2 = Conv2d_hd(dim, dim, 3, bias=True)
self.conv1_3 = Conv2d_vd(dim, dim, 3, bias=True)
self.conv1_4 = Conv2d_ad(dim, dim, 3, bias=True)
self.conv1_5 = nn.Conv2d(dim, dim, 3, padding=1, bias=True)
def forward(self, x):
w1, b1 = self.conv1_1.get_weight()
w2, b2 = self.conv1_2.get_weight()
w3, b3 = self.conv1_3.get_weight()
w4, b4 = self.conv1_4.get_weight()
w5, b5 = self.conv1_5.weight, self.conv1_5.bias
w = w1 + w2 + w3 + w4 + w5
b = b1 + b2 + b3 + b4 + b5
res = nn.functional.conv2d(input=x, weight=w, bias=b, stride=1, padding=1, groups=1)
return res
if __name__ == '__main__':
# 初始化DEConv模块dim为输入和输出的通道数
block = DEConv(dim=16)
# 创建一个随机输入张量,假设输入尺寸为(batch_size, channels, height, width)
input_tensor = torch.rand(4, 16, 64, 64)
# 将输入传递给DEConv模块
output_tensor = block(input_tensor)
# 打印输入和输出张量的尺寸
print("输入尺寸:", input_tensor.size())
print("输出尺寸:", output_tensor.size())