import torch import torch.nn as nn import torch.nn.functional as F """ECCV2024(https://github.com/Zheng-MJ/SMFANet) 基于Transformer的恢复方法取得了显著的效果,因为Transformer的自注意力机制(SA)可以探索非局部信息,从而实现更好的高分辨率图像重建。然而,关键的点积自注意力需要大量的计算资源,这限制了其在低功耗设备上的应用。 此外,自注意力机制的低通滤波特性限制了其捕捉局部细节的能力,从而导致重建结果过于平滑。为了解决这些问题,我们提出了一种自调制特征聚合(SMFA)模块,协同利用局部和非局部特征交互,以实现更精确的重建。 具体而言,SMFA模块采用了高效的自注意力近似(EASA)分支来建模非局部信息,并使用局部细节估计(LDE)分支来捕捉局部细节。此外,我们还引入了基于部分卷积的前馈网络(PCFN),以进一步优化从SMFA提取的代表性特征。 大量实验表明,所提出的SMFANet系列在公共基准数据集上实现了更好的重建性能与计算效率的平衡。 特别是,与SwinIR-light的×4放大相比,SMFANet+在五个公共测试集上的平均性能提高了0.14dB,运行速度提升了约10倍,且模型复杂度(如FLOPs)仅为其约43%。 """ class DMlp(nn.Module): def __init__(self, dim, growth_rate=2.0): super().__init__() hidden_dim = int(dim * growth_rate) self.conv_0 = nn.Sequential( nn.Conv2d(dim, hidden_dim, 3, 1, 1, groups=dim), nn.Conv2d(hidden_dim, hidden_dim, 1, 1, 0) ) self.act = nn.GELU() self.conv_1 = nn.Conv2d(hidden_dim, dim, 1, 1, 0) def forward(self, x): x = self.conv_0(x) x = self.act(x) x = self.conv_1(x) return x class SMFA(nn.Module): def __init__(self, dim=36): super(SMFA, self).__init__() self.linear_0 = nn.Conv2d(dim, dim * 2, 1, 1, 0) self.linear_1 = nn.Conv2d(dim, dim, 1, 1, 0) self.linear_2 = nn.Conv2d(dim, dim, 1, 1, 0) self.lde = DMlp(dim, 2) self.dw_conv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim) self.gelu = nn.GELU() self.down_scale = 8 self.alpha = nn.Parameter(torch.ones((1, dim, 1, 1))) self.belt = nn.Parameter(torch.zeros((1, dim, 1, 1))) def forward(self, f): _, _, h, w = f.shape y, x = self.linear_0(f).chunk(2, dim=1) x_s = self.dw_conv(F.adaptive_max_pool2d(x, (h // self.down_scale, w // self.down_scale))) x_v = torch.var(x, dim=(-2, -1), keepdim=True) x_l = x * F.interpolate(self.gelu(self.linear_1(x_s * self.alpha + x_v * self.belt)), size=(h, w), mode='nearest') y_d = self.lde(y) return self.linear_2(x_l + y_d) if __name__ == '__main__': block = SMFA(dim=36) input = torch.randn(3, 36, 64, 64) output = block(input) print(input.size()) print(output.size())