pfcfuse/componets/SMFA.py
import torch
import torch.nn as nn
import torch.nn.functional as F
"""ECCV2024(https://github.com/Zheng-MJ/SMFANet)
基于Transformer的恢复方法取得了显著的效果因为Transformer的自注意力机制SA可以探索非局部信息从而实现更好的高分辨率图像重建。然而关键的点积自注意力需要大量的计算资源这限制了其在低功耗设备上的应用。
此外自注意力机制的低通滤波特性限制了其捕捉局部细节的能力从而导致重建结果过于平滑。为了解决这些问题我们提出了一种自调制特征聚合SMFA模块协同利用局部和非局部特征交互以实现更精确的重建。
具体而言SMFA模块采用了高效的自注意力近似EASA分支来建模非局部信息并使用局部细节估计LDE分支来捕捉局部细节。此外我们还引入了基于部分卷积的前馈网络PCFN以进一步优化从SMFA提取的代表性特征。
大量实验表明所提出的SMFANet系列在公共基准数据集上实现了更好的重建性能与计算效率的平衡。
特别是与SwinIR-light的×4放大相比SMFANet+在五个公共测试集上的平均性能提高了0.14dB运行速度提升了约10倍且模型复杂度如FLOPs仅为其约43%
"""


class DMlp(nn.Module):
    """Detail MLP used as the local detail estimation (LDE) branch of SMFA."""
    def __init__(self, dim, growth_rate=2.0):
        super().__init__()
        hidden_dim = int(dim * growth_rate)
        # Grouped 3x3 conv (one group per input channel) followed by a 1x1 pointwise conv.
        self.conv_0 = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, 3, 1, 1, groups=dim),
            nn.Conv2d(hidden_dim, hidden_dim, 1, 1, 0)
        )
        self.act = nn.GELU()
        self.conv_1 = nn.Conv2d(hidden_dim, dim, 1, 1, 0)

    def forward(self, x):
        x = self.conv_0(x)
        x = self.act(x)
        x = self.conv_1(x)
        return x


class SMFA(nn.Module):
    """Self-Modulation Feature Aggregation: EASA branch (non-local) + LDE branch (local)."""
    def __init__(self, dim=36):
        super(SMFA, self).__init__()
        self.linear_0 = nn.Conv2d(dim, dim * 2, 1, 1, 0)  # splits the input into the two branches
        self.linear_1 = nn.Conv2d(dim, dim, 1, 1, 0)       # modulation inside the EASA branch
        self.linear_2 = nn.Conv2d(dim, dim, 1, 1, 0)       # fuses the two branches
        self.lde = DMlp(dim, 2)                            # local detail estimation branch
        self.dw_conv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)
        self.gelu = nn.GELU()
        self.down_scale = 8                                # spatial downsampling factor for EASA
        self.alpha = nn.Parameter(torch.ones((1, dim, 1, 1)))
        self.belt = nn.Parameter(torch.zeros((1, dim, 1, 1)))

    def forward(self, f):
        _, _, h, w = f.shape
        y, x = self.linear_0(f).chunk(2, dim=1)
        # EASA branch: a depthwise conv on a downsampled map approximates non-local context.
        x_s = self.dw_conv(F.adaptive_max_pool2d(x, (h // self.down_scale, w // self.down_scale)))
        # Channel-wise spatial variance serves as a global statistic that modulates the branch.
        x_v = torch.var(x, dim=(-2, -1), keepdim=True)
        x_l = x * F.interpolate(self.gelu(self.linear_1(x_s * self.alpha + x_v * self.belt)),
                                size=(h, w), mode='nearest')
        # LDE branch captures local details; linear_2 fuses both branches.
        y_d = self.lde(y)
        return self.linear_2(x_l + y_d)
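

# --- Illustrative sketch (not part of the original file) ---------------------
# The docstring above mentions a partial-convolution feed-forward network (PCFN)
# that refines the features produced by SMFA, but this file does not include it.
# The class below is a minimal sketch of that idea, not the reference
# implementation: expand channels with a 1x1 conv, apply a 3x3 conv to only a
# fraction of the channels ("partial convolution"), then project back. The
# growth_rate and p_rate defaults are assumptions chosen for illustration.
class PCFN(nn.Module):
    def __init__(self, dim, growth_rate=2.0, p_rate=0.25):
        super().__init__()
        hidden_dim = int(dim * growth_rate)
        self.hidden_dim = hidden_dim
        self.p_dim = int(hidden_dim * p_rate)                      # channels that get the 3x3 conv
        self.conv_0 = nn.Conv2d(dim, hidden_dim, 1, 1, 0)          # channel expansion
        self.conv_1 = nn.Conv2d(self.p_dim, self.p_dim, 3, 1, 1)   # partial 3x3 conv
        self.conv_2 = nn.Conv2d(hidden_dim, dim, 1, 1, 0)          # projection back to dim
        self.act = nn.GELU()

    def forward(self, x):
        x = self.act(self.conv_0(x))
        # Split channels: only the first p_dim channels pass through the 3x3 conv.
        x1, x2 = torch.split(x, [self.p_dim, self.hidden_dim - self.p_dim], dim=1)
        x1 = self.act(self.conv_1(x1))
        return self.conv_2(torch.cat([x1, x2], dim=1))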


if __name__ == '__main__':
    block = SMFA(dim=36)
    x = torch.randn(3, 36, 64, 64)  # (batch, channels, height, width)
    output = block(x)
    print(x.size())
    print(output.size())
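    # Illustrative only: chain SMFA with the PCFN sketch above, the way the
    # docstring describes (SMFA features refined by PCFN). The residual
    # connections and normalization of the full SMFANet block are omitted.
    pcfn = PCFN(dim=36)
    refined = pcfn(output)
    print(refined.size())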