pfcfuse/componets/SMFA.py
zjut ece5f30c2d feat(componets): 新增多个图像处理模块
- 添加 DynamicFilter 模块,用于频域滤波和动态调整滤波器权重
- 添加 SCSA 模块,探索空间注意力和通道注意力之间的协同作用
- 添加 SMFA 模块,自调制特征聚合用于高效图像重建
- 添加 TIAM 模块,时空交互注意力用于变化检测
- 添加 UFFC 模块,无偏快速傅里叶卷积用于图像修复
- 更新 net.py,引入 SCSA 模块替换原有注意力机制
- 优化 train.py,调整导入模块以支持新功能
2024-11-15 09:28:49 +08:00

65 lines
3.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
"""ECCV2024(https://github.com/Zheng-MJ/SMFANet)
基于Transformer的恢复方法取得了显著的效果因为Transformer的自注意力机制SA可以探索非局部信息从而实现更好的高分辨率图像重建。然而关键的点积自注意力需要大量的计算资源这限制了其在低功耗设备上的应用。
此外自注意力机制的低通滤波特性限制了其捕捉局部细节的能力从而导致重建结果过于平滑。为了解决这些问题我们提出了一种自调制特征聚合SMFA模块协同利用局部和非局部特征交互以实现更精确的重建。
具体而言SMFA模块采用了高效的自注意力近似EASA分支来建模非局部信息并使用局部细节估计LDE分支来捕捉局部细节。此外我们还引入了基于部分卷积的前馈网络PCFN以进一步优化从SMFA提取的代表性特征。
大量实验表明所提出的SMFANet系列在公共基准数据集上实现了更好的重建性能与计算效率的平衡。
特别是与SwinIR-light的×4放大相比SMFANet+在五个公共测试集上的平均性能提高了0.14dB运行速度提升了约10倍且模型复杂度如FLOPs仅为其约43%
"""
class DMlp(nn.Module):
def __init__(self, dim, growth_rate=2.0):
super().__init__()
hidden_dim = int(dim * growth_rate)
self.conv_0 = nn.Sequential(
nn.Conv2d(dim, hidden_dim, 3, 1, 1, groups=dim),
nn.Conv2d(hidden_dim, hidden_dim, 1, 1, 0)
)
self.act = nn.GELU()
self.conv_1 = nn.Conv2d(hidden_dim, dim, 1, 1, 0)
def forward(self, x):
x = self.conv_0(x)
x = self.act(x)
x = self.conv_1(x)
return x
class SMFA(nn.Module):
def __init__(self, dim=36):
super(SMFA, self).__init__()
self.linear_0 = nn.Conv2d(dim, dim * 2, 1, 1, 0)
self.linear_1 = nn.Conv2d(dim, dim, 1, 1, 0)
self.linear_2 = nn.Conv2d(dim, dim, 1, 1, 0)
self.lde = DMlp(dim, 2)
self.dw_conv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)
self.gelu = nn.GELU()
self.down_scale = 8
self.alpha = nn.Parameter(torch.ones((1, dim, 1, 1)))
self.belt = nn.Parameter(torch.zeros((1, dim, 1, 1)))
def forward(self, f):
_, _, h, w = f.shape
y, x = self.linear_0(f).chunk(2, dim=1)
x_s = self.dw_conv(F.adaptive_max_pool2d(x, (h // self.down_scale, w // self.down_scale)))
x_v = torch.var(x, dim=(-2, -1), keepdim=True)
x_l = x * F.interpolate(self.gelu(self.linear_1(x_s * self.alpha + x_v * self.belt)), size=(h, w),
mode='nearest')
y_d = self.lde(y)
return self.linear_2(x_l + y_d)
if __name__ == '__main__':
block = SMFA(dim=36)
input = torch.randn(3, 36, 64, 64)
output = block(input)
print(input.size())
print(output.size())