pfcfuse/componets/UFFC.py
zjut e1a339e04b feat(net): 替换 SMFA 为 SCSA 并调整相关配置
- 将 SMFA 模块替换为 SCSA 模块
- 更新项目配置,使用本地 Python 3.8 环境
-调整 SCSA 模块参数,如维度、头数等
- 优化注意力机制,提高模型性能
2024-11-08 12:04:52 +08:00

123 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
"""ICCV2023
最近提出的图像修复方法 LaMa 以快速傅里叶卷积 (FFC) 为基础构建了其网络,该网络最初是为图像分类等高级视觉任务而提出的。
FFC 使全卷积网络在其早期层中拥有全局感受野。得益于 FFC 模块的独特特性LaMa 能够生成稳健的重复纹理,
这是以前的修复方法无法实现的。但是,原始 FFC 模块是否适合图像修复等低级视觉任务?
在本文中,我们分析了在图像修复中使用 FFC 的基本缺陷,即 1) 频谱偏移、2) 意外的空间激活和 3) 频率感受野有限。
这些缺陷使得基于 FFC 的修复框架难以生成复杂纹理并执行完美重建。
基于以上分析,我们提出了一种新颖的无偏快速傅里叶卷积 (UFFC) 模块,该模块通过
1) 范围变换和逆变换、2) 绝对位置嵌入、3) 动态跳过连接和 4) 自适应剪辑对原始 FFC 模块进行了修改,以克服这些缺陷,
实现更好的修复效果。在多个基准数据集上进行的大量实验证明了我们方法的有效性,在纹理捕捉能力和表现力方面均优于最先进的方法。
"""
class FourierUnit_modified(nn.Module):
def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
spectral_pos_encoding=False, use_se=False, ffc3d=False, fft_norm='ortho'):
# bn_layer not used
super(FourierUnit_modified, self).__init__()
self.groups = groups
self.input_shape = 32 # change!!!!!it!!!!!!manually!!!!!!
self.in_channels = in_channels
self.locMap = nn.Parameter(torch.rand(self.input_shape, self.input_shape // 2 + 1))
self.lambda_base = nn.Parameter(torch.tensor(0.), requires_grad=True)
self.conv_layer_down55 = torch.nn.Conv2d(in_channels=in_channels * 2 + 1, # +1 for locmap
out_channels=out_channels * 2,
kernel_size=1, stride=1, padding=0, dilation=1, groups=self.groups,
bias=False, padding_mode='reflect')
self.conv_layer_down55_shift = torch.nn.Conv2d(in_channels=in_channels * 2 + 1, # +1 for locmap
out_channels=out_channels * 2,
kernel_size=3, stride=1, padding=2, dilation=2,
groups=self.groups, bias=False, padding_mode='reflect')
self.norm = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.spatial_scale_factor = spatial_scale_factor
self.spatial_scale_mode = spatial_scale_mode
self.spectral_pos_encoding = spectral_pos_encoding
self.ffc3d = ffc3d
self.fft_norm = fft_norm
self.img_freq = None
self.distill = None
def forward(self, x):
batch = x.shape[0]
if self.spatial_scale_factor is not None:
orig_size = x.shape[-2:]
x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode,
align_corners=False)
fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
ffted = ffted.view((batch, -1,) + ffted.size()[3:])
locMap = self.locMap.expand_as(ffted[:, :1, :, :]) # B 1 H' W'
ffted_copy = ffted.clone()
cat_img_mask_freq = torch.cat((ffted[:, :self.in_channels, :, :],
ffted[:, self.in_channels:, :, :],
locMap), dim=1)
ffted = self.conv_layer_down55(cat_img_mask_freq)
ffted = torch.fft.fftshift(ffted, dim=-2)
ffted = self.relu(ffted)
locMap_shift = torch.fft.fftshift(locMap, dim=-2) ## ONLY IF NOT SHIFT BACK
# REPEAT CONV
cat_img_mask_freq1 = torch.cat((ffted[:, :self.in_channels, :, :],
ffted[:, self.in_channels:, :, :],
locMap_shift), dim=1)
ffted = self.conv_layer_down55_shift(cat_img_mask_freq1)
ffted = torch.fft.fftshift(ffted, dim=-2)
lambda_base = torch.sigmoid(self.lambda_base)
ffted = ffted_copy * lambda_base + ffted * (1 - lambda_base)
# irfft
ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2)
ffted = torch.complex(ffted[..., 0], ffted[..., 1])
ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)
if self.spatial_scale_factor is not None:
output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)
epsilon = 0.5
output = output - torch.mean(output) + torch.mean(x)
output = torch.clip(output, float(x.min() - epsilon), float(x.max() + epsilon))
self.distill = output # for self perc
return output
if __name__ == '__main__':
in_channels = 16
out_channels = 16
block = FourierUnit_modified(in_channels=in_channels, out_channels=out_channels)
input_tensor = torch.rand(8, in_channels, 32, 32)
output = block(input_tensor)
print("Input size:", input_tensor.size())
print("Output size:", output.size())