pfcfuse/componets/SMFA.py

65 lines
3.0 KiB
Python
Raw Normal View History

import torch
import torch.nn as nn
import torch.nn.functional as F
"""ECCV2024(https://github.com/Zheng-MJ/SMFANet)
基于Transformer的恢复方法取得了显著的效果因为Transformer的自注意力机制SA可以探索非局部信息从而实现更好的高分辨率图像重建然而关键的点积自注意力需要大量的计算资源这限制了其在低功耗设备上的应用
此外自注意力机制的低通滤波特性限制了其捕捉局部细节的能力从而导致重建结果过于平滑为了解决这些问题我们提出了一种自调制特征聚合SMFA模块协同利用局部和非局部特征交互以实现更精确的重建
具体而言SMFA模块采用了高效的自注意力近似EASA分支来建模非局部信息并使用局部细节估计LDE分支来捕捉局部细节此外我们还引入了基于部分卷积的前馈网络PCFN以进一步优化从SMFA提取的代表性特征
大量实验表明所提出的SMFANet系列在公共基准数据集上实现了更好的重建性能与计算效率的平衡
特别是与SwinIR-light的×4放大相比SMFANet+在五个公共测试集上的平均性能提高了0.14dB运行速度提升了约10倍且模型复杂度如FLOPs仅为其约43%
"""
class DMlp(nn.Module):
    """Small convolutional MLP: grouped 3x3 expand -> 1x1 mix -> GELU -> 1x1 project.

    Args:
        dim: Number of input and output channels.
        growth_rate: Channel expansion factor; the hidden width is
            ``int(dim * growth_rate)``. Because the first convolution uses
            ``groups=dim``, the hidden width has to be a multiple of ``dim``
            for ``nn.Conv2d`` to accept it.
    """

    def __init__(self, dim, growth_rate=2.0):
        super().__init__()
        hidden_dim = int(dim * growth_rate)

        # The layers are created in the same order as they are applied so that
        # seeded parameter initialisation stays reproducible.
        grouped_expand = nn.Conv2d(
            dim, hidden_dim, kernel_size=3, stride=1, padding=1, groups=dim
        )
        pointwise_mix = nn.Conv2d(
            hidden_dim, hidden_dim, kernel_size=1, stride=1, padding=0
        )
        self.conv_0 = nn.Sequential(grouped_expand, pointwise_mix)
        self.act = nn.GELU()
        self.conv_1 = nn.Conv2d(
            hidden_dim, dim, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        """Apply expand -> activation -> project and return the result."""
        expanded = self.conv_0(x)
        return self.conv_1(self.act(expanded))
class SMFA(nn.Module):
    """Self-modulation feature aggregation block.

    A 1x1 convolution splits the input into two ``dim``-channel halves:

    * one half is modulated by a gate built from a max-pooled, depthwise-
      filtered copy of itself plus its per-channel spatial variance
      (the non-local branch);
    * the other half goes through ``DMlp`` (the local-detail branch).

    Both results are summed and projected by a final 1x1 convolution, so the
    output has the same shape as the input.

    Args:
        dim: Number of input/output channels.
        down_scale: Spatial reduction factor used for the pooled map of the
            non-local branch. Defaults to 8, the previously hard-coded value.

    Raises:
        ValueError: If ``down_scale`` is smaller than 1.
    """

    def __init__(self, dim=36, down_scale=8):
        super(SMFA, self).__init__()
        if down_scale < 1:
            raise ValueError(f"down_scale must be >= 1, got {down_scale}")

        self.linear_0 = nn.Conv2d(dim, dim * 2, 1, 1, 0)
        self.linear_1 = nn.Conv2d(dim, dim, 1, 1, 0)
        self.linear_2 = nn.Conv2d(dim, dim, 1, 1, 0)
        self.lde = DMlp(dim, 2)
        self.dw_conv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)
        self.gelu = nn.GELU()
        self.down_scale = down_scale
        # Learnable per-channel weights for the pooled features (starts at 1)
        # and for the variance statistic (starts at 0).
        self.alpha = nn.Parameter(torch.ones((1, dim, 1, 1)))
        self.belt = nn.Parameter(torch.zeros((1, dim, 1, 1)))

    def forward(self, f):
        """Run the block on a 4-D tensor with ``dim`` channels.

        Args:
            f: Input tensor; its shape is unpacked as ``(_, _, h, w)``.

        Returns:
            Tensor with the same shape as ``f``.
        """
        _, _, h, w = f.shape
        y, x = self.linear_0(f).chunk(2, dim=1)

        # Clamp the pooled size to at least 1x1: for inputs whose side is
        # smaller than ``down_scale`` the floor division yields 0, which would
        # produce an empty map that the 3x3 depthwise conv cannot process.
        pooled_size = (
            max(1, h // self.down_scale),
            max(1, w // self.down_scale),
        )
        x_s = self.dw_conv(F.adaptive_max_pool2d(x, pooled_size))

        if h * w == 1:
            # torch.var applies Bessel's correction by default and returns NaN
            # for a single element; NaN * belt would poison the output even
            # while belt is zero. A single value has no spread, so use 0.
            x_v = torch.zeros_like(x)
        else:
            x_v = torch.var(x, dim=(-2, -1), keepdim=True)

        gate = self.gelu(self.linear_1(x_s * self.alpha + x_v * self.belt))
        # Bring the low-resolution gate back to the input resolution.
        x_l = x * F.interpolate(gate, size=(h, w), mode='nearest')

        y_d = self.lde(y)
        return self.linear_2(x_l + y_d)
if __name__ == '__main__':
    # Smoke test: push a random batch through the block and print the input
    # and output sizes. The tensor is not called ``input`` so the Python
    # builtin of that name is not shadowed.
    block = SMFA(dim=36)
    sample = torch.randn(3, 36, 64, 64)
    output = block(sample)
    print(sample.size())
    print(output.size())