feat(components): add multiple feature enhancement modules

- Add the DEConv, DynamicFilter, SCSA, SEBlock, SMFA, TIAM and UFFC modules
- These modules provide different feature-enhancement capabilities, such as difference convolution, frequency-domain filtering and attention mechanisms
- Choose the module that fits your needs to improve model performance
zjut 2024-11-15 09:18:34 +08:00
parent c1eed72f24
commit 927d049f13
9 changed files with 778 additions and 0 deletions

171
componets/DEConv.py Normal file

@@ -0,0 +1,171 @@
# https://github.com/cecret3350/DEA-Net/blob/main/code/model/modules/deconv.py
import math
import torch
from torch import nn
from einops.layers.torch import Rearrange
"""
In deep learning and image processing, "vanilla" convolution and "difference" convolution are two different convolution operations, each with its own characteristics and use cases. DEConv (detail-enhanced convolution) is designed to combine the properties of both to strengthen the model, especially in tasks such as image dehazing.
Vanilla convolution (standard convolution):
"Vanilla" convolution is the most basic kind of convolution, usually simply called convolution. It is the most common component of convolutional neural networks (CNNs) and is used to extract features from input data such as images.
A standard convolution slides small learnable filters (kernels) over the input and computes the dot product between the filter and each local region, which lets it capture local patterns and features in the data.
Difference convolution:
Difference convolution is a special type of convolution that focuses on capturing local difference information in the input, such as edges or texture changes.
It does this by modifying the weights of a standard kernel, or through dedicated operations, so that the network pays more attention to high-frequency information, i.e. the details and texture variations in the image. In image processing tasks such as dehazing, enhancement and edge detection, capturing this high-frequency information matters because it often carries the key cues about object boundaries and structure.
Re-parameterization:
Re-parameterization is a parameter-transformation technique that lets a model realise more complex behaviour or better performance without extra parameters or computation. In the context of DEConv, it allows the combination of vanilla and difference convolutions to be converted, equivalently, into a single standard convolution.
This means DEConv can exploit the strengths of both standard and difference convolutions by carefully constructing the kernel weights, without adding parameters or computational cost.
Concretely, re-parameterization folds the effect of the difference convolutions into one kernel, so that this single kernel captures the basic image features (via the standard-convolution part) while also emphasising details and difference information (via the difference-convolution part).
This approach suits tasks that must consider both global content and local detail, such as image dehazing, where the network has to understand the overall structure of the scene and also recover the details lost to haze.
The key advantage of re-parameterization is that it achieves this richer, finer-grained behaviour while keeping the number of parameters and the computational complexity unchanged.
"""
class Conv2d_cd(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
super(Conv2d_cd, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
self.theta = theta
def get_weight(self):
conv_weight = self.conv.weight
conv_shape = conv_weight.shape
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
# conv_weight_cd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_cd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
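        # central difference (CD): copy the vanilla weights, then subtract the sum of all
        # kernel taps from the centre tap, so the kernel responds to local intensity differences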
conv_weight_cd[:, :, :] = conv_weight[:, :, :]
conv_weight_cd[:, :, 4] = conv_weight[:, :, 4] - conv_weight[:, :, :].sum(2)
conv_weight_cd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[3])(
conv_weight_cd)
return conv_weight_cd, self.conv.bias
class Conv2d_ad(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
super(Conv2d_ad, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
self.theta = theta
def get_weight(self):
conv_weight = self.conv.weight
conv_shape = conv_weight.shape
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
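        # angular difference (AD): subtract a copy of the kernel whose eight outer taps are
        # rotated one step around the centre, scaled by theta, to emphasise directional changes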
conv_weight_ad = conv_weight - self.theta * conv_weight[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]
conv_weight_ad = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[3])(
conv_weight_ad)
return conv_weight_ad, self.conv.bias
class Conv2d_rd(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=2, dilation=1, groups=1, bias=False, theta=1.0):
super(Conv2d_rd, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
self.theta = theta
def forward(self, x):
if math.fabs(self.theta - 0.0) < 1e-8:
out_normal = self.conv(x)
return out_normal
else:
conv_weight = self.conv.weight
conv_shape = conv_weight.shape
if conv_weight.is_cuda:
conv_weight_rd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 5 * 5).fill_(0)
else:
conv_weight_rd = torch.zeros(conv_shape[0], conv_shape[1], 5 * 5)
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
conv_weight_rd[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = conv_weight[:, :, 1:]
conv_weight_rd[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -conv_weight[:, :, 1:] * self.theta
conv_weight_rd[:, :, 12] = conv_weight[:, :, 0] * (1 - self.theta)
conv_weight_rd = conv_weight_rd.view(conv_shape[0], conv_shape[1], 5, 5)
out_diff = nn.functional.conv2d(input=x, weight=conv_weight_rd, bias=self.conv.bias,
stride=self.conv.stride, padding=self.conv.padding, groups=self.conv.groups)
return out_diff
class Conv2d_hd(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
super(Conv2d_hd, self).__init__()
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
def get_weight(self):
conv_weight = self.conv.weight
conv_shape = conv_weight.shape
# conv_weight_hd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_hd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
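        # horizontal difference (HD): place the 1-D kernel in the left column and its negation
        # in the right column of a 3x3 kernel, approximating a horizontal gradient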
conv_weight_hd[:, :, [0, 3, 6]] = conv_weight[:, :, :]
conv_weight_hd[:, :, [2, 5, 8]] = -conv_weight[:, :, :]
conv_weight_hd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[2])(
conv_weight_hd)
return conv_weight_hd, self.conv.bias
class Conv2d_vd(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=1, dilation=1, groups=1, bias=False):
super(Conv2d_vd, self).__init__()
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
def get_weight(self):
conv_weight = self.conv.weight
conv_shape = conv_weight.shape
# conv_weight_vd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_vd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_vd[:, :, [0, 1, 2]] = conv_weight[:, :, :]
conv_weight_vd[:, :, [6, 7, 8]] = -conv_weight[:, :, :]
conv_weight_vd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[2])(
conv_weight_vd)
return conv_weight_vd, self.conv.bias
class DEConv(nn.Module):
def __init__(self, dim):
super(DEConv, self).__init__()
self.conv1_1 = Conv2d_cd(dim, dim, 3, bias=True)
self.conv1_2 = Conv2d_hd(dim, dim, 3, bias=True)
self.conv1_3 = Conv2d_vd(dim, dim, 3, bias=True)
self.conv1_4 = Conv2d_ad(dim, dim, 3, bias=True)
self.conv1_5 = nn.Conv2d(dim, dim, 3, padding=1, bias=True)
def forward(self, x):
w1, b1 = self.conv1_1.get_weight()
w2, b2 = self.conv1_2.get_weight()
w3, b3 = self.conv1_3.get_weight()
w4, b4 = self.conv1_4.get_weight()
w5, b5 = self.conv1_5.weight, self.conv1_5.bias
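        # re-parameterization: the five parallel branches (CD, HD, VD, AD and vanilla conv)
        # are folded into a single 3x3 kernel and bias, so inference costs one convolution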
w = w1 + w2 + w3 + w4 + w5
b = b1 + b2 + b3 + b4 + b5
res = nn.functional.conv2d(input=x, weight=w, bias=b, stride=1, padding=1, groups=1)
return res
if __name__ == '__main__':
    # instantiate DEConv; dim is the number of input and output channels
    block = DEConv(dim=16)
    # create a random input tensor of shape (batch_size, channels, height, width)
    input_tensor = torch.rand(4, 16, 64, 64)
    # run the input through the DEConv module
    output_tensor = block(input_tensor)
    # print the input and output tensor sizes
    print("Input size:", input_tensor.size())
    print("Output size:", output_tensor.size())

116
componets/DynamicFilter.py Normal file

@@ -0,0 +1,116 @@
import torch
import torch.nn as nn
from timm.layers.helpers import to_2tuple
"""
Models equipped with multi-head self-attention (MHSA) have achieved remarkable performance in computer vision, but their computational complexity is quadratic in the number of pixels of the input feature map, which makes them slow, especially on high-resolution images.
To circumvent this problem, novel token mixers have been proposed as alternatives to MHSA: FFT-based token mixers perform global operations similar to MHSA but at lower computational cost.
Here, a new token mixer named the dynamic filter is proposed to close the remaining gap.
The DynamicFilter module performs sophisticated image enhancement and processing through frequency-domain filtering with dynamically reweighted filter weights.
"""
class StarReLU(nn.Module):
"""
StarReLU: s * relu(x) ** 2 + b
"""
def __init__(self, scale_value=1.0, bias_value=0.0,
scale_learnable=True, bias_learnable=True,
mode=None, inplace=False):
super().__init__()
self.inplace = inplace
self.relu = nn.ReLU(inplace=inplace)
self.scale = nn.Parameter(scale_value * torch.ones(1),
requires_grad=scale_learnable)
self.bias = nn.Parameter(bias_value * torch.ones(1),
requires_grad=bias_learnable)
def forward(self, x):
return self.scale * self.relu(x) ** 2 + self.bias
class Mlp(nn.Module):
""" MLP as used in MetaFormer models, eg Transformer, MLP-Mixer, PoolFormer, MetaFormer baslines and related networks.
Mostly copied from timm.
"""
def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0.,
bias=False, **kwargs):
super().__init__()
in_features = dim
out_features = out_features or in_features
hidden_features = int(mlp_ratio * in_features)
drop_probs = to_2tuple(drop)
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
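# NOTE: `resize_complex_weight` is referenced below (when weight_resize=True) but is not
# defined in this file. The sketch below is an assumption, modelled on the upstream
# MetaFormer/DFFormer helper (bicubic resize of the complex filter bank); verify it against
# the original repository before relying on the weight_resize path.
def resize_complex_weight(origin_weight, new_h, new_w):
    # origin_weight: (size, filter_size, num_filters, 2)
    h, w, num_filters = origin_weight.shape[0:3]
    origin_weight = origin_weight.reshape(1, h, w, num_filters * 2).permute(0, 3, 1, 2)
    new_weight = torch.nn.functional.interpolate(
        origin_weight, size=(new_h, new_w), mode='bicubic', align_corners=True
    ).permute(0, 2, 3, 1).reshape(new_h, new_w, num_filters, 2)
    return new_weight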
class DynamicFilter(nn.Module):
def __init__(self, dim, expansion_ratio=2, reweight_expansion_ratio=.25,
act1_layer=StarReLU, act2_layer=nn.Identity,
bias=False, num_filters=4, size=14, weight_resize=False,
**kwargs):
super().__init__()
size = to_2tuple(size)
self.size = size[0]
self.filter_size = size[1] // 2 + 1
self.num_filters = num_filters
self.dim = dim
self.med_channels = int(expansion_ratio * dim)
self.weight_resize = weight_resize
self.pwconv1 = nn.Linear(dim, self.med_channels, bias=bias)
self.act1 = act1_layer()
self.reweight = Mlp(dim, reweight_expansion_ratio, num_filters * self.med_channels)
self.complex_weights = nn.Parameter(
torch.randn(self.size, self.filter_size, num_filters, 2,
dtype=torch.float32) * 0.02)
self.act2 = act2_layer()
self.pwconv2 = nn.Linear(self.med_channels, dim, bias=bias)
def forward(self, x):
B, H, W, _ = x.shape
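        # routing: global average pooling over (H, W) followed by an MLP produces per-sample
        # softmax weights over the num_filters learned frequency filters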
routeing = self.reweight(x.mean(dim=(1, 2))).view(B, self.num_filters,
-1).softmax(dim=1)
x = self.pwconv1(x)
x = self.act1(x)
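        # move to the frequency domain with a 2-D real FFT over the spatial dimensions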
x = x.to(torch.float32)
x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
if self.weight_resize:
complex_weights = resize_complex_weight(self.complex_weights, x.shape[1],
x.shape[2])
complex_weights = torch.view_as_complex(complex_weights.contiguous())
else:
complex_weights = torch.view_as_complex(self.complex_weights)
routeing = routeing.to(torch.complex64)
weight = torch.einsum('bfc,hwf->bhwc', routeing, complex_weights)
if self.weight_resize:
weight = weight.view(-1, x.shape[1], x.shape[2], self.med_channels)
else:
weight = weight.view(-1, self.size, self.filter_size, self.med_channels)
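        # apply the dynamically routed complex filter, then invert the FFT back to the spatial domain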
x = x * weight
x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
x = self.act2(x)
x = self.pwconv2(x)
return x
if __name__ == '__main__':
block = DynamicFilter(32, size=64) # size==H,W
input = torch.rand(3, 64, 64, 32)
output = block(input)
print(input.size())
print(output.size())

Binary file not shown.

156
componets/SCSA.py Normal file

@@ -0,0 +1,156 @@
import typing as t
import torch
import torch.nn as nn
from einops.einops import rearrange
from mmengine.model import BaseModule
__all__ = ['SCSA']
"""SCSA探索空间注意力和通道注意力之间的协同作用
通道和空间注意力分别在为各种下游视觉任务提取特征依赖性和空间结构关系方面带来了显着的改进
虽然它们的结合更有利于发挥各自的优势但通道和空间注意力之间的协同作用尚未得到充分探索缺乏充分利用多语义信息的协同潜力来进行特征引导和缓解语义差异
我们的研究试图在多个语义层面揭示空间和通道注意力之间的协同关系提出了一种新颖的空间和通道协同注意力模块SCSA我们的SCSA由两部分组成可共享的多语义空间注意力SMSA和渐进式通道自注意力PCSA
SMSA 集成多语义信息并利用渐进式压缩策略将判别性空间先验注入 PCSA 的通道自注意力中有效地指导通道重新校准此外PCSA 中基于自注意力机制的稳健特征交互进一步缓解了 SMSA 中不同子特征之间多语义信息的差异
我们在七个基准数据集上进行了广泛的实验包括 ImageNet-1K 上的分类MSCOCO 2017 上的对象检测ADE20K 上的分割以及其他四个复杂场景检测数据集我们的结果表明我们提出的 SCSA 不仅超越了当前最先进的注意力机制
而且在各种任务场景中表现出增强的泛化能力
"""
class SCSA(BaseModule):
def __init__(
self,
dim: int,
head_num: int,
window_size: int = 7,
group_kernel_sizes: t.List[int] = [3, 5, 7, 9],
qkv_bias: bool = False,
fuse_bn: bool = False,
norm_cfg: t.Dict = dict(type='BN'),
act_cfg: t.Dict = dict(type='ReLU'),
down_sample_mode: str = 'avg_pool',
attn_drop_ratio: float = 0.,
gate_layer: str = 'sigmoid',
):
super(SCSA, self).__init__()
self.dim = dim
self.head_num = head_num
self.head_dim = dim // head_num
self.scaler = self.head_dim ** -0.5
self.group_kernel_sizes = group_kernel_sizes
self.window_size = window_size
self.qkv_bias = qkv_bias
self.fuse_bn = fuse_bn
self.down_sample_mode = down_sample_mode
        assert self.dim % 4 == 0, 'The dimension of input feature should be divisible by 4.'
self.group_chans = group_chans = self.dim // 4
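        # SMSA: split the channels into four groups, each processed by a 1-D depthwise conv
        # with a different kernel size to capture multi-semantic spatial context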
self.local_dwc = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[0],
padding=group_kernel_sizes[0] // 2, groups=group_chans)
self.global_dwc_s = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[1],
padding=group_kernel_sizes[1] // 2, groups=group_chans)
self.global_dwc_m = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[2],
padding=group_kernel_sizes[2] // 2, groups=group_chans)
self.global_dwc_l = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[3],
padding=group_kernel_sizes[3] // 2, groups=group_chans)
self.sa_gate = nn.Softmax(dim=2) if gate_layer == 'softmax' else nn.Sigmoid()
self.norm_h = nn.GroupNorm(4, dim)
self.norm_w = nn.GroupNorm(4, dim)
self.conv_d = nn.Identity()
self.norm = nn.GroupNorm(1, dim)
self.q = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
self.k = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
self.v = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
self.attn_drop = nn.Dropout(attn_drop_ratio)
self.ca_gate = nn.Softmax(dim=1) if gate_layer == 'softmax' else nn.Sigmoid()
if window_size == -1:
self.down_func = nn.AdaptiveAvgPool2d((1, 1))
else:
if down_sample_mode == 'recombination':
self.down_func = self.space_to_chans
# dimensionality reduction
self.conv_d = nn.Conv2d(in_channels=dim * window_size ** 2, out_channels=dim, kernel_size=1, bias=False)
elif down_sample_mode == 'avg_pool':
self.down_func = nn.AvgPool2d(kernel_size=(window_size, window_size), stride=window_size)
elif down_sample_mode == 'max_pool':
self.down_func = nn.MaxPool2d(kernel_size=(window_size, window_size), stride=window_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
The dim of x is (B, C, H, W)
"""
# Spatial attention priority calculation
b, c, h_, w_ = x.size()
# (B, C, H)
x_h = x.mean(dim=3)
l_x_h, g_x_h_s, g_x_h_m, g_x_h_l = torch.split(x_h, self.group_chans, dim=1)
# (B, C, W)
x_w = x.mean(dim=2)
l_x_w, g_x_w_s, g_x_w_m, g_x_w_l = torch.split(x_w, self.group_chans, dim=1)
x_h_attn = self.sa_gate(self.norm_h(torch.cat((
self.local_dwc(l_x_h),
self.global_dwc_s(g_x_h_s),
self.global_dwc_m(g_x_h_m),
self.global_dwc_l(g_x_h_l),
), dim=1)))
x_h_attn = x_h_attn.view(b, c, h_, 1)
x_w_attn = self.sa_gate(self.norm_w(torch.cat((
self.local_dwc(l_x_w),
self.global_dwc_s(g_x_w_s),
self.global_dwc_m(g_x_w_m),
self.global_dwc_l(g_x_w_l)
), dim=1)))
x_w_attn = x_w_attn.view(b, c, 1, w_)
x = x * x_h_attn * x_w_attn
# Channel attention based on self attention
# reduce calculations
y = self.down_func(x)
y = self.conv_d(y)
_, _, h_, w_ = y.size()
# normalization first, then reshape -> (B, H, W, C) -> (B, C, H * W) and generate q, k and v
y = self.norm(y)
q = self.q(y)
k = self.k(y)
v = self.v(y)
# (B, C, H, W) -> (B, head_num, head_dim, N)
q = rearrange(q, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
head_dim=int(self.head_dim))
k = rearrange(k, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
head_dim=int(self.head_dim))
v = rearrange(v, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
head_dim=int(self.head_dim))
# (B, head_num, head_dim, head_dim)
attn = q @ k.transpose(-2, -1) * self.scaler
attn = self.attn_drop(attn.softmax(dim=-1))
# (B, head_num, head_dim, N)
attn = attn @ v
# (B, C, H_, W_)
attn = rearrange(attn, 'b head_num head_dim (h w) -> b (head_num head_dim) h w', h=int(h_), w=int(w_))
# (B, C, 1, 1)
attn = attn.mean((2, 3), keepdim=True)
attn = self.ca_gate(attn)
return attn * x
if __name__ == '__main__':
block = SCSA(
dim=256,
head_num=8,
)
input_tensor = torch.rand(1, 256, 32, 32)
    # run a forward pass through the module
    output_tensor = block(input_tensor)
    # print the sizes of the input and output tensors
print("Input size:", input_tensor.size())
print("Output size:", output_tensor.size())

37
componets/SEBlock.py Normal file

@@ -0,0 +1,37 @@
'''------------- 1. SE block -----------------------------'''
import torch
from torch import nn
# global average pooling + 1x1 conv + ReLU + 1x1 conv + Sigmoid
class SE_Block(nn.Module):
    def __init__(self, inchannel, ratio=16):
        super(SE_Block, self).__init__()
        # global average pooling (the squeeze operation, F_sq)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        # two fully connected layers (the excitation operation, F_ex)
        self.fc = nn.Sequential(
            nn.Linear(inchannel, inchannel // ratio, bias=False),  # c -> c/r
            nn.ReLU(),
            nn.Linear(inchannel // ratio, inchannel, bias=False),  # c/r -> c
            nn.Sigmoid()
        )
    def forward(self, x):
        # read the batch size and channel count
        b, c, h, w = x.size()
        # squeeze: pooling yields a (b, c) matrix
        y = self.gap(x).view(b, c)
        # excitation: the FC layers yield a (b, c, 1, 1) tensor
        y = self.fc(y).view(b, c, 1, 1)
        # scale: multiply the original feature map x by the learned channel weights
        return x * y.expand_as(x)
if __name__ == '__main__':
input = torch.randn(1, 64, 32, 32)
seblock = SE_Block(64)
print(seblock)
output = seblock(input)
print(input.shape)
print(output.shape)

65
componets/SMFA.py Normal file

@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
"""ECCV2024(https://github.com/Zheng-MJ/SMFANet)
基于Transformer的恢复方法取得了显著的效果因为Transformer的自注意力机制SA可以探索非局部信息从而实现更好的高分辨率图像重建然而关键的点积自注意力需要大量的计算资源这限制了其在低功耗设备上的应用
此外自注意力机制的低通滤波特性限制了其捕捉局部细节的能力从而导致重建结果过于平滑为了解决这些问题我们提出了一种自调制特征聚合SMFA模块协同利用局部和非局部特征交互以实现更精确的重建
具体而言SMFA模块采用了高效的自注意力近似EASA分支来建模非局部信息并使用局部细节估计LDE分支来捕捉局部细节此外我们还引入了基于部分卷积的前馈网络PCFN以进一步优化从SMFA提取的代表性特征
大量实验表明所提出的SMFANet系列在公共基准数据集上实现了更好的重建性能与计算效率的平衡
特别是与SwinIR-light的×4放大相比SMFANet+在五个公共测试集上的平均性能提高了0.14dB运行速度提升了约10倍且模型复杂度如FLOPs仅为其约43%
"""
class DMlp(nn.Module):
def __init__(self, dim, growth_rate=2.0):
super().__init__()
hidden_dim = int(dim * growth_rate)
self.conv_0 = nn.Sequential(
nn.Conv2d(dim, hidden_dim, 3, 1, 1, groups=dim),
nn.Conv2d(hidden_dim, hidden_dim, 1, 1, 0)
)
self.act = nn.GELU()
self.conv_1 = nn.Conv2d(hidden_dim, dim, 1, 1, 0)
def forward(self, x):
x = self.conv_0(x)
x = self.act(x)
x = self.conv_1(x)
return x
class SMFA(nn.Module):
def __init__(self, dim=36):
super(SMFA, self).__init__()
self.linear_0 = nn.Conv2d(dim, dim * 2, 1, 1, 0)
self.linear_1 = nn.Conv2d(dim, dim, 1, 1, 0)
self.linear_2 = nn.Conv2d(dim, dim, 1, 1, 0)
self.lde = DMlp(dim, 2)
self.dw_conv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)
self.gelu = nn.GELU()
self.down_scale = 8
self.alpha = nn.Parameter(torch.ones((1, dim, 1, 1)))
self.belt = nn.Parameter(torch.zeros((1, dim, 1, 1)))
def forward(self, f):
_, _, h, w = f.shape
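        # split the projected features into two branches: x drives the efficient approximation
        # of self-attention (EASA), y drives the local detail estimation (LDE) branch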
y, x = self.linear_0(f).chunk(2, dim=1)
x_s = self.dw_conv(F.adaptive_max_pool2d(x, (h // self.down_scale, w // self.down_scale)))
x_v = torch.var(x, dim=(-2, -1), keepdim=True)
x_l = x * F.interpolate(self.gelu(self.linear_1(x_s * self.alpha + x_v * self.belt)), size=(h, w),
mode='nearest')
y_d = self.lde(y)
return self.linear_2(x_l + y_d)
if __name__ == '__main__':
block = SMFA(dim=36)
input = torch.randn(3, 36, 64, 64)
output = block(input)
print(input.size())
print(output.size())

110
componets/TIAM.py Normal file

@@ -0,0 +1,110 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
"""Elsevier2024
变化检测 (CD) 是地球观测中一种重要的监测方法尤其适用于土地利用分析城市管理和灾害损失评估然而在星座互联和空天协作时代感兴趣区域 (ROI) 的变化由于几何透视旋转和时间风格差异而导致许多错误检测
为了应对这些挑战我们引入了 CDNeXt该框架阐明了一种稳健而有效的方法用于将基于预训练主干的 Siamese 网络与用于遥感图像的创新时空交互注意模块 (TIAM) 相结合
CDNeXt 可分为四个主要组件编码器交互器解码器和检测器值得注意的是 TIAM 提供支持的交互器从编码器提取的二进制时间特征中查询和重建空间透视依赖关系和时间风格相关性以扩大 ROI 变化的差异
最后检测器集成解码器生成的分层特征随后生成二进制变化掩码
"""
class SpatiotemporalAttentionFullNotWeightShared(nn.Module):
def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=False):
super(SpatiotemporalAttentionFullNotWeightShared, self).__init__()
assert dimension in [2, ]
self.dimension = dimension
self.sub_sample = sub_sample
self.in_channels = in_channels
self.inter_channels = inter_channels
if self.inter_channels is None:
self.inter_channels = in_channels // 2
if self.inter_channels == 0:
self.inter_channels = 1
self.g1 = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0)
)
self.g2 = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0),
)
self.W1 = nn.Sequential(
nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(self.in_channels)
)
self.W2 = nn.Sequential(
nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(self.in_channels)
)
self.theta = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0),
)
self.phi = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0),
)
def forward(self, x1, x2):
"""
:param x: (b, c, h, w)
:param return_nl_map: if True return z, nl_map, else only return z.
:return:
"""
batch_size = x1.size(0)
g_x11 = self.g1(x1).reshape(batch_size, self.inter_channels, -1)
g_x12 = g_x11.permute(0, 2, 1)
g_x21 = self.g2(x2).reshape(batch_size, self.inter_channels, -1)
g_x22 = g_x21.permute(0, 2, 1)
theta_x1 = self.theta(x1).reshape(batch_size, self.inter_channels, -1)
theta_x2 = theta_x1.permute(0, 2, 1)
phi_x1 = self.phi(x2).reshape(batch_size, self.inter_channels, -1)
phi_x2 = phi_x1.permute(0, 2, 1)
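        # temporal correlation: channel-to-channel affinities between the two time steps;
        # spatial correlation: position-to-position (HW x HW) affinities between them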
energy_time_1 = torch.matmul(theta_x1, phi_x2)
energy_time_2 = energy_time_1.permute(0, 2, 1)
energy_space_1 = torch.matmul(theta_x2, phi_x1)
energy_space_2 = energy_space_1.permute(0, 2, 1)
energy_time_1s = F.softmax(energy_time_1, dim=-1)
energy_time_2s = F.softmax(energy_time_2, dim=-1)
energy_space_2s = F.softmax(energy_space_1, dim=-2)
energy_space_1s = F.softmax(energy_space_2, dim=-2)
        # y1: aggregate g(x1) with the temporal (channel) attention and the spatial attention
        y1 = torch.matmul(torch.matmul(energy_time_2s, g_x11), energy_space_2s).contiguous()
        # y2: aggregate g(x2) with the complementary temporal and spatial attention
        y2 = torch.matmul(torch.matmul(energy_time_1s, g_x21), energy_space_1s).contiguous()
y1 = y1.reshape(batch_size, self.inter_channels, *x2.size()[2:])
y2 = y2.reshape(batch_size, self.inter_channels, *x1.size()[2:])
return x1 + self.W1(y1), x2 + self.W2(y2)
if __name__ == '__main__':
in_channels = 64
batch_size = 8
height = 32
width = 32
block = SpatiotemporalAttentionFullNotWeightShared(in_channels=in_channels)
input1 = torch.rand(batch_size, in_channels, height, width)
input2 = torch.rand(batch_size, in_channels, height, width)
output1, output2 = block(input1, input2)
print(f"Input1 size: {input1.size()}")
print(f"Input2 size: {input2.size()}")
print(f"Output1 size: {output1.size()}")
print(f"Output2 size: {output2.size()}")

Binary file not shown.

123
componets/UFFC.py Normal file

@@ -0,0 +1,123 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
"""ICCV2023
最近提出的图像修复方法 LaMa 以快速傅里叶卷积 (FFC) 为基础构建了其网络该网络最初是为图像分类等高级视觉任务而提出的
FFC 使全卷积网络在其早期层中拥有全局感受野得益于 FFC 模块的独特特性LaMa 能够生成稳健的重复纹理
这是以前的修复方法无法实现的但是原始 FFC 模块是否适合图像修复等低级视觉任务
在本文中我们分析了在图像修复中使用 FFC 的基本缺陷 1) 频谱偏移2) 意外的空间激活和 3) 频率感受野有限
这些缺陷使得基于 FFC 的修复框架难以生成复杂纹理并执行完美重建
基于以上分析我们提出了一种新颖的无偏快速傅里叶卷积 (UFFC) 模块该模块通过
1) 范围变换和逆变换2) 绝对位置嵌入3) 动态跳过连接和 4) 自适应剪辑对原始 FFC 模块进行了修改以克服这些缺陷
实现更好的修复效果在多个基准数据集上进行的大量实验证明了我们方法的有效性在纹理捕捉能力和表现力方面均优于最先进的方法
"""
class FourierUnit_modified(nn.Module):
def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
spectral_pos_encoding=False, use_se=False, ffc3d=False, fft_norm='ortho'):
# bn_layer not used
super(FourierUnit_modified, self).__init__()
self.groups = groups
        self.input_shape = 32  # NOTE: must be set manually to match the input spatial size
self.in_channels = in_channels
self.locMap = nn.Parameter(torch.rand(self.input_shape, self.input_shape // 2 + 1))
self.lambda_base = nn.Parameter(torch.tensor(0.), requires_grad=True)
self.conv_layer_down55 = torch.nn.Conv2d(in_channels=in_channels * 2 + 1, # +1 for locmap
out_channels=out_channels * 2,
kernel_size=1, stride=1, padding=0, dilation=1, groups=self.groups,
bias=False, padding_mode='reflect')
self.conv_layer_down55_shift = torch.nn.Conv2d(in_channels=in_channels * 2 + 1, # +1 for locmap
out_channels=out_channels * 2,
kernel_size=3, stride=1, padding=2, dilation=2,
groups=self.groups, bias=False, padding_mode='reflect')
self.norm = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.spatial_scale_factor = spatial_scale_factor
self.spatial_scale_mode = spatial_scale_mode
self.spectral_pos_encoding = spectral_pos_encoding
self.ffc3d = ffc3d
self.fft_norm = fft_norm
self.img_freq = None
self.distill = None
def forward(self, x):
batch = x.shape[0]
if self.spatial_scale_factor is not None:
orig_size = x.shape[-2:]
x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode,
align_corners=False)
fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
ffted = ffted.view((batch, -1,) + ffted.size()[3:])
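        # absolute position embedding: a learnable location map is broadcast over the batch
        # and later concatenated with the spectrum channels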
locMap = self.locMap.expand_as(ffted[:, :1, :, :]) # B 1 H' W'
ffted_copy = ffted.clone()
cat_img_mask_freq = torch.cat((ffted[:, :self.in_channels, :, :],
ffted[:, self.in_channels:, :, :],
locMap), dim=1)
ffted = self.conv_layer_down55(cat_img_mask_freq)
ffted = torch.fft.fftshift(ffted, dim=-2)
ffted = self.relu(ffted)
locMap_shift = torch.fft.fftshift(locMap, dim=-2) ## ONLY IF NOT SHIFT BACK
# REPEAT CONV
cat_img_mask_freq1 = torch.cat((ffted[:, :self.in_channels, :, :],
ffted[:, self.in_channels:, :, :],
locMap_shift), dim=1)
ffted = self.conv_layer_down55_shift(cat_img_mask_freq1)
ffted = torch.fft.fftshift(ffted, dim=-2)
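        # dynamic skip connection: blend the untouched spectrum with the filtered spectrum
        # through a learnable sigmoid gate (lambda_base)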
lambda_base = torch.sigmoid(self.lambda_base)
ffted = ffted_copy * lambda_base + ffted * (1 - lambda_base)
# irfft
ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2)
ffted = torch.complex(ffted[..., 0], ffted[..., 1])
ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)
if self.spatial_scale_factor is not None:
output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)
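        # adaptive clipping: re-centre the output to the input mean and clip it to the input
        # range (+/- epsilon) to suppress unexpected spatial activations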
epsilon = 0.5
output = output - torch.mean(output) + torch.mean(x)
output = torch.clip(output, float(x.min() - epsilon), float(x.max() + epsilon))
self.distill = output # for self perc
return output
if __name__ == '__main__':
in_channels = 16
out_channels = 16
block = FourierUnit_modified(in_channels=in_channels, out_channels=out_channels)
input_tensor = torch.rand(8, in_channels, 32, 32)
output = block(input_tensor)
print("Input size:", input_tensor.size())
print("Output size:", output.size())