pfcfuse/componets/UFFC.py
zjut ece5f30c2d feat(componets): add several image-processing modules
- Add DynamicFilter module for frequency-domain filtering with dynamically adjusted filter weights
- Add SCSA module, exploring the synergy between spatial and channel attention
- Add SMFA module, self-modulated feature aggregation for efficient image reconstruction
- Add TIAM module, spatio-temporal interaction attention for change detection
- Add UFFC module, unbiased fast Fourier convolution for image inpainting
- Update net.py to replace the original attention mechanism with the SCSA module
- Optimize train.py, adjusting its imports to support the new features
2024-11-15 09:28:49 +08:00

import torch
import torch.nn as nn
import torch.nn.functional as F
"""ICCV2023
最近提出的图像修复方法 LaMa 以快速傅里叶卷积 (FFC) 为基础构建了其网络,该网络最初是为图像分类等高级视觉任务而提出的。
FFC 使全卷积网络在其早期层中拥有全局感受野。得益于 FFC 模块的独特特性LaMa 能够生成稳健的重复纹理,
这是以前的修复方法无法实现的。但是,原始 FFC 模块是否适合图像修复等低级视觉任务?
在本文中,我们分析了在图像修复中使用 FFC 的基本缺陷,即 1) 频谱偏移、2) 意外的空间激活和 3) 频率感受野有限。
这些缺陷使得基于 FFC 的修复框架难以生成复杂纹理并执行完美重建。
基于以上分析,我们提出了一种新颖的无偏快速傅里叶卷积 (UFFC) 模块,该模块通过
1) 范围变换和逆变换、2) 绝对位置嵌入、3) 动态跳过连接和 4) 自适应剪辑对原始 FFC 模块进行了修改,以克服这些缺陷,
实现更好的修复效果。在多个基准数据集上进行的大量实验证明了我们方法的有效性,在纹理捕捉能力和表现力方面均优于最先进的方法。
"""
class FourierUnit_modified(nn.Module):
    def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
                 spectral_pos_encoding=False, use_se=False, ffc3d=False, fft_norm='ortho'):
        # bn_layer not used
        super(FourierUnit_modified, self).__init__()
        self.groups = groups
        self.input_shape = 32  # expected input spatial size; change manually for other resolutions
        self.in_channels = in_channels
        # Learnable absolute position embedding over the (H, W // 2 + 1) rFFT spectrum.
        self.locMap = nn.Parameter(torch.rand(self.input_shape, self.input_shape // 2 + 1))
        # Gate for the dynamic skip connection (passed through a sigmoid in forward()).
        self.lambda_base = nn.Parameter(torch.tensor(0.), requires_grad=True)
        self.conv_layer_down55 = torch.nn.Conv2d(in_channels=in_channels * 2 + 1,  # +1 for locMap
                                                 out_channels=out_channels * 2,
                                                 kernel_size=1, stride=1, padding=0, dilation=1,
                                                 groups=self.groups, bias=False, padding_mode='reflect')
        self.conv_layer_down55_shift = torch.nn.Conv2d(in_channels=in_channels * 2 + 1,  # +1 for locMap
                                                       out_channels=out_channels * 2,
                                                       kernel_size=3, stride=1, padding=2, dilation=2,
                                                       groups=self.groups, bias=False, padding_mode='reflect')
        self.norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.spatial_scale_factor = spatial_scale_factor
        self.spatial_scale_mode = spatial_scale_mode
        self.spectral_pos_encoding = spectral_pos_encoding
        self.ffc3d = ffc3d
        self.fft_norm = fft_norm
        self.img_freq = None
        self.distill = None

    def forward(self, x):
        batch = x.shape[0]

        if self.spatial_scale_factor is not None:
            orig_size = x.shape[-2:]
            x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode,
                              align_corners=False)

        # Real FFT; the last spatial dimension is halved to W // 2 + 1.
        fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
        ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
        ffted = ffted.view((batch, -1,) + ffted.size()[3:])  # (batch, 2c, h, w/2+1)

        locMap = self.locMap.expand_as(ffted[:, :1, :, :])  # B 1 H' W'
        ffted_copy = ffted.clone()

        # First spectral conv, with the position embedding as an extra input channel.
        cat_img_mask_freq = torch.cat((ffted[:, :self.in_channels, :, :],
                                       ffted[:, self.in_channels:, :, :],
                                       locMap), dim=1)
        ffted = self.conv_layer_down55(cat_img_mask_freq)
        ffted = torch.fft.fftshift(ffted, dim=-2)
        ffted = self.relu(ffted)

        locMap_shift = torch.fft.fftshift(locMap, dim=-2)  # ONLY IF NOT SHIFTED BACK

        # Repeat the conv on the shifted spectrum (assumes out_channels == in_channels).
        cat_img_mask_freq1 = torch.cat((ffted[:, :self.in_channels, :, :],
                                        ffted[:, self.in_channels:, :, :],
                                        locMap_shift), dim=1)
        ffted = self.conv_layer_down55_shift(cat_img_mask_freq1)
        ffted = torch.fft.fftshift(ffted, dim=-2)

        # Dynamic skip connection: sigmoid-gated blend of the raw and convolved spectra.
        lambda_base = torch.sigmoid(self.lambda_base)
        ffted = ffted_copy * lambda_base + ffted * (1 - lambda_base)

        # irfft back to the spatial domain.
        ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
            0, 1, 3, 4, 2).contiguous()  # (batch, c, h, w/2+1, 2)
        ffted = torch.complex(ffted[..., 0], ffted[..., 1])

        ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
        output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)

        if self.spatial_scale_factor is not None:
            output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)

        # Adaptive clipping: re-center to the input mean, then clip to the input's value range.
        epsilon = 0.5
        output = output - torch.mean(output) + torch.mean(x)
        output = torch.clip(output, float(x.min() - epsilon), float(x.max() + epsilon))

        self.distill = output  # for self perc
        return output

if __name__ == '__main__':
    in_channels = 16
    out_channels = 16
    block = FourierUnit_modified(in_channels=in_channels, out_channels=out_channels)
    # The spatial size must match block.input_shape (32 by default).
    input_tensor = torch.rand(8, in_channels, 32, 32)
    output = block(input_tensor)
    print("Input size:", input_tensor.size())
    print("Output size:", output.size())