ece5f30c2d
- Add the DynamicFilter module for frequency-domain filtering with dynamically adjusted filter weights
- Add the SCSA module, exploring the synergy between spatial and channel attention
- Add the SMFA module, self-modulated feature aggregation for efficient image reconstruction
- Add the TIAM module, spatiotemporal interaction attention for change detection
- Add the UFFC module, unbiased fast Fourier convolution for image inpainting
- Update net.py to replace the original attention mechanism with the SCSA module
- Adjust the imports in train.py to support the new functionality
116 lines
4.7 KiB
Python
import torch
import torch.nn as nn
from timm.layers.helpers import to_2tuple

"""
|
||
配备多头自注意力 (MHSA) 的模型在计算机视觉方面取得了显着的性能。它们的计算复杂度与输入特征图中的二次像素数成正比,导致处理速度缓慢,尤其是在处理高分辨率图像时。
|
||
为了规避这个问题,提出了一种新型的代币混合器作为MHSA的替代方案:基于FFT的代币混合器涉及类似于MHSA的全局操作,但计算复杂度较低。
|
||
在这里,我们提出了一种名为动态过滤器的新型令牌混合器以缩小上述差距。
|
||
DynamicFilter 模块通过频域滤波和动态调整滤波器权重,能够对图像进行复杂的增强和处理。
|
||
"""
|
||
|
||
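# A minimal sketch (added; not part of the original file) of the idea the
# docstring describes: an elementwise product in the frequency domain mixes
# all H*W tokens globally, like a circular convolution, at O(HW log HW) cost
# rather than the O((HW)^2) pairwise cost of self-attention:
#
#   x_freq = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')  # (B, H, W//2+1, C)
#   x_mix = torch.fft.irfft2(x_freq * w, s=(H, W), dim=(1, 2), norm='ortho')
#
# DynamicFilter below follows this recipe, but generates the filter `w` per
# input by softly routing among `num_filters` learned frequency-domain filters.
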
class StarReLU(nn.Module):
    """
    StarReLU: s * relu(x) ** 2 + b
    """

    def __init__(self, scale_value=1.0, bias_value=0.0,
                 scale_learnable=True, bias_learnable=True,
                 mode=None, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        self.scale = nn.Parameter(scale_value * torch.ones(1),
                                  requires_grad=scale_learnable)
        self.bias = nn.Parameter(bias_value * torch.ones(1),
                                 requires_grad=bias_learnable)

    def forward(self, x):
        return self.scale * self.relu(x) ** 2 + self.bias

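# Note (added): with the default scale=1 and bias=0, StarReLU reduces to the
# squared ReLU relu(x)**2; the learnable scalar pair (s, b) lets the network
# rescale and shift activations at negligible extra cost.
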
class Mlp(nn.Module):
    """ MLP as used in MetaFormer models, e.g. Transformer, MLP-Mixer,
    PoolFormer, the MetaFormer baselines and related networks.
    Mostly copied from timm.
    """

    def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0.,
                 bias=False, **kwargs):
        super().__init__()
        in_features = dim
        out_features = out_features or in_features
        hidden_features = int(mlp_ratio * in_features)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

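def resize_complex_weight(origin_weight, new_h, new_w):
    """Resize a real-valued complex-weight tensor of shape (h, w, num_filters, 2)
    to (new_h, new_w, num_filters, 2) by bicubic interpolation.

    Added: this helper is called by DynamicFilter.forward when weight_resize=True
    but was missing from this file; the version here is a sketch modeled on the
    MetaFormer-style implementation and should be treated as an assumption.
    """
    h, w, num_filters, _ = origin_weight.shape
    # Fold the (num_filters, 2) trailing dims into channels for interpolate.
    weight = origin_weight.reshape(1, h, w, num_filters * 2).permute(0, 3, 1, 2)
    weight = torch.nn.functional.interpolate(
        weight, size=(new_h, new_w), mode='bicubic', align_corners=True)
    return weight.permute(0, 2, 3, 1).reshape(new_h, new_w, num_filters, 2)
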
class DynamicFilter(nn.Module):
    """Dynamic filter token mixer: filters features in the frequency domain,
    with per-input soft routing over a bank of learned filters."""

    def __init__(self, dim, expansion_ratio=2, reweight_expansion_ratio=.25,
                 act1_layer=StarReLU, act2_layer=nn.Identity,
                 bias=False, num_filters=4, size=14, weight_resize=False,
                 **kwargs):
        super().__init__()
        size = to_2tuple(size)
        self.size = size[0]
        self.filter_size = size[1] // 2 + 1  # rfft2 keeps W // 2 + 1 frequency bins
        self.num_filters = num_filters
        self.dim = dim
        self.med_channels = int(expansion_ratio * dim)
        self.weight_resize = weight_resize
        self.pwconv1 = nn.Linear(dim, self.med_channels, bias=bias)
        self.act1 = act1_layer()
        # MLP producing per-sample routing weights over the filter bank.
        self.reweight = Mlp(dim, reweight_expansion_ratio, num_filters * self.med_channels)
        # Learned filter bank, stored as (H, W//2+1, num_filters, 2) real pairs.
        self.complex_weights = nn.Parameter(
            torch.randn(self.size, self.filter_size, num_filters, 2,
                        dtype=torch.float32) * 0.02)
        self.act2 = act2_layer()
        self.pwconv2 = nn.Linear(self.med_channels, dim, bias=bias)

    def forward(self, x):
        B, H, W, _ = x.shape

        # Global average pool, then softmax over the filter bank per channel.
        routeing = self.reweight(x.mean(dim=(1, 2))).view(B, self.num_filters,
                                                          -1).softmax(dim=1)
        x = self.pwconv1(x)
        x = self.act1(x)
        x = x.to(torch.float32)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')

        if self.weight_resize:
            # Interpolate the stored filters to the actual frequency-map size.
            complex_weights = resize_complex_weight(self.complex_weights, x.shape[1],
                                                    x.shape[2])
            complex_weights = torch.view_as_complex(complex_weights.contiguous())
        else:
            complex_weights = torch.view_as_complex(self.complex_weights)
        routeing = routeing.to(torch.complex64)
        # Mix the filter bank into one filter per sample, position and channel.
        weight = torch.einsum('bfc,hwf->bhwc', routeing, complex_weights)
        if self.weight_resize:
            weight = weight.view(-1, x.shape[1], x.shape[2], self.med_channels)
        else:
            weight = weight.view(-1, self.size, self.filter_size, self.med_channels)
        x = x * weight  # elementwise filtering in the frequency domain
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')

        x = self.act2(x)
        x = self.pwconv2(x)
        return x

if __name__ == '__main__':
    block = DynamicFilter(32, size=64)  # size == (H, W) of the input
    x = torch.rand(3, 64, 64, 32)       # channels-last layout: (B, H, W, C)
    output = block(x)
    print(x.size())
    print(output.size())
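
    # Added usage sketch (assumption; relies on the reconstructed
    # resize_complex_weight above): with weight_resize=True the learned
    # filters are resized to the input's resolution, so the module also
    # accepts inputs whose size differs from the `size` it was built with.
    block_resize = DynamicFilter(32, size=14, weight_resize=True)
    print(block_resize(torch.rand(3, 28, 28, 32)).size())  # -> (3, 28, 28, 32)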