import torch
import torch.nn as nn
import torch.nn.functional as F
"""ECCV2024(https://github.com/Zheng-MJ/SMFANet)
|
||
基于Transformer的恢复方法取得了显著的效果,因为Transformer的自注意力机制(SA)可以探索非局部信息,从而实现更好的高分辨率图像重建。然而,关键的点积自注意力需要大量的计算资源,这限制了其在低功耗设备上的应用。
|
||
此外,自注意力机制的低通滤波特性限制了其捕捉局部细节的能力,从而导致重建结果过于平滑。为了解决这些问题,我们提出了一种自调制特征聚合(SMFA)模块,协同利用局部和非局部特征交互,以实现更精确的重建。
|
||
具体而言,SMFA模块采用了高效的自注意力近似(EASA)分支来建模非局部信息,并使用局部细节估计(LDE)分支来捕捉局部细节。此外,我们还引入了基于部分卷积的前馈网络(PCFN),以进一步优化从SMFA提取的代表性特征。
|
||
大量实验表明,所提出的SMFANet系列在公共基准数据集上实现了更好的重建性能与计算效率的平衡。
|
||
特别是,与SwinIR-light的×4放大相比,SMFANet+在五个公共测试集上的平均性能提高了0.14dB,运行速度提升了约10倍,且模型复杂度(如FLOPs)仅为其约43%。
|
||
"""
|
||
|
||
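

# The docstring above also describes a partial-convolution-based feed-forward
# network (PCFN), which is not implemented in this file. The sketch below is a
# minimal reconstruction of that idea (refine only a fraction of the expanded
# channels with a 3x3 conv and pass the rest through); the class name,
# growth_rate, and p_rate split are assumptions rather than the authors' code.
class PCFN(nn.Module):
    def __init__(self, dim, growth_rate=2.0, p_rate=0.25):
        super().__init__()
        hidden_dim = int(dim * growth_rate)
        self.p_dim = int(hidden_dim * p_rate)  # channels that receive the partial 3x3 conv
        self.conv_0 = nn.Conv2d(dim, hidden_dim, 1, 1, 0)
        self.conv_1 = nn.Conv2d(self.p_dim, self.p_dim, 3, 1, 1)
        self.act = nn.GELU()
        self.conv_2 = nn.Conv2d(hidden_dim, dim, 1, 1, 0)

    def forward(self, x):
        x = self.act(self.conv_0(x))
        # Partial convolution: only the first p_dim channels are refined.
        x1, x2 = torch.split(x, [self.p_dim, x.size(1) - self.p_dim], dim=1)
        x1 = self.act(self.conv_1(x1))
        return self.conv_2(torch.cat([x1, x2], dim=1))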


class DMlp(nn.Module):
    """Local detail estimation (LDE) branch: a depthwise-style MLP over feature maps."""

    def __init__(self, dim, growth_rate=2.0):
        super().__init__()
        hidden_dim = int(dim * growth_rate)
        self.conv_0 = nn.Sequential(
            # Grouped 3x3 conv (groups=dim) expands each channel independently,
            # then a 1x1 conv mixes the expanded channels.
            nn.Conv2d(dim, hidden_dim, 3, 1, 1, groups=dim),
            nn.Conv2d(hidden_dim, hidden_dim, 1, 1, 0)
        )
        self.act = nn.GELU()
        self.conv_1 = nn.Conv2d(hidden_dim, dim, 1, 1, 0)

    def forward(self, x):
        x = self.conv_0(x)
        x = self.act(x)
        x = self.conv_1(x)
        return x


class SMFA(nn.Module):
    """Self-modulation feature aggregation: fuses an EASA (non-local) branch
    with an LDE (local detail) branch."""

    def __init__(self, dim=36):
        super(SMFA, self).__init__()
        self.linear_0 = nn.Conv2d(dim, dim * 2, 1, 1, 0)  # splits the input into the two branches
        self.linear_1 = nn.Conv2d(dim, dim, 1, 1, 0)      # modulation inside the EASA branch
        self.linear_2 = nn.Conv2d(dim, dim, 1, 1, 0)      # fuses the two branch outputs

        self.lde = DMlp(dim, 2)                           # local detail estimation branch

        self.dw_conv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)  # depthwise conv on the pooled map

        self.gelu = nn.GELU()
        self.down_scale = 8                               # spatial downsampling factor of EASA

        # Learnable per-channel weights for the pooled features (alpha) and the
        # channel-wise variance statistic (belt).
        self.alpha = nn.Parameter(torch.ones((1, dim, 1, 1)))
        self.belt = nn.Parameter(torch.zeros((1, dim, 1, 1)))

    def forward(self, f):
        _, _, h, w = f.shape  # h and w must be at least down_scale (8) for the pooling below
        y, x = self.linear_0(f).chunk(2, dim=1)
        # EASA branch: downsample, apply a depthwise conv, modulate with the global
        # variance statistic, then upsample and gate the full-resolution features.
        x_s = self.dw_conv(F.adaptive_max_pool2d(x, (h // self.down_scale, w // self.down_scale)))
        x_v = torch.var(x, dim=(-2, -1), keepdim=True)
        x_l = x * F.interpolate(self.gelu(self.linear_1(x_s * self.alpha + x_v * self.belt)),
                                size=(h, w), mode='nearest')
        # LDE branch, then fuse both branches.
        y_d = self.lde(y)
        return self.linear_2(x_l + y_d)
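

# In the upstream net.py (per the source's change notes), SMFA is wired into a
# BasicLayer as a drop-in replacement for a Pooling token mixer:
# norm -> token_mixer -> norm -> MLP, with a residual around each stage.
# A minimal self-contained sketch of that pattern follows; nn.GroupNorm(1, dim)
# stands in for the original 'WithBias' LayerNorm, a 1x1-conv MLP stands in for
# PoolMlp, and the WTConv2d used there is omitted. These stand-ins are
# assumptions, not the original implementation.
class SMFABlock(nn.Module):
    def __init__(self, dim, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.GroupNorm(1, dim)
        self.token_mixer = SMFA(dim=dim)  # replaces the former Pooling mixer
        self.norm2 = nn.GroupNorm(1, dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, 1, 1, 0),
            nn.GELU(),
            nn.Conv2d(hidden_dim, dim, 1, 1, 0),
        )

    def forward(self, x):
        x = x + self.token_mixer(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x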


if __name__ == '__main__':
    # Smoke test: SMFA preserves the input shape.
    block = SMFA(dim=36)
    x = torch.randn(3, 36, 64, 64)
    out = block(x)
    print(x.size())    # torch.Size([3, 36, 64, 64])
    print(out.size())  # torch.Size([3, 36, 64, 64])
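
    # Shape checks for the hedged sketches above (PCFN and SMFABlock are
    # reconstructions, not part of the original file):
    pcfn = PCFN(dim=36)
    print(pcfn(out).size())   # torch.Size([3, 36, 64, 64])
    mixer = SMFABlock(dim=36)
    print(mixer(out).size())  # torch.Size([3, 36, 64, 64])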