927d049f13
- Add DEConv, DynamicFilter, SCSA, SEBlock, SMFA, TIAM, and UFFC modules - These modules provide different feature-enhancement capabilities, such as difference convolutions, frequency-domain filtering, and attention mechanisms - Pick the module that fits your needs to improve model performance
123 lines
5.7 KiB
Python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

"""ICCV 2023

The recently proposed image inpainting method LaMa builds its network on the Fast Fourier
Convolution (FFC), which was originally proposed for high-level vision tasks such as image
classification. FFC gives a fully convolutional network a global receptive field even in its
early layers. Thanks to the unique properties of the FFC module, LaMa can generate robust
repetitive textures that previous inpainting methods could not achieve. But is the original FFC
module suitable for low-level vision tasks such as image inpainting?

In this paper, we analyze the fundamental flaws of using FFC in image inpainting, namely
1) spectrum shifting, 2) unexpected spatial activation, and 3) a limited frequency receptive
field. These flaws make it difficult for FFC-based inpainting frameworks to generate complex
textures and perform faithful reconstruction. Based on this analysis, we propose a novel
Unbiased Fast Fourier Convolution (UFFC) module, which modifies the original FFC module with
1) range transform and inverse transform, 2) absolute position embedding, 3) dynamic skip
connection, and 4) adaptive clipping to overcome these flaws and achieve better inpainting
results. Extensive experiments on multiple benchmark datasets demonstrate the effectiveness of
our method, outperforming state-of-the-art approaches in both texture-capturing ability and
expressiveness.
"""

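# Rough mapping of the modifications listed above onto the implementation below (my reading of
# this file, not an annotation from the authors):
#   - absolute position embedding -> the learnable locMap concatenated with the spectral features
#   - dynamic skip connection     -> the sigmoid(lambda_base) blend of the raw spectrum
#                                    (ffted_copy) with the convolved spectrum
#   - adaptive clipping           -> the mean re-centering and torch.clip of the output to the
#                                    input range (plus/minus epsilon)
# The range transform / inverse transform step is less obvious from this file alone; the fftshift
# calls along the frequency axis appear related, but consult the paper for the exact mapping.
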
class FourierUnit_modified(nn.Module):

    def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
                 spectral_pos_encoding=False, use_se=False, ffc3d=False, fft_norm='ortho'):
        # bn_layer not used
        super(FourierUnit_modified, self).__init__()
        self.groups = groups

        # Spatial size of the expected input feature map; the learnable position map below is
        # sized for it, so this value must be changed manually if the input resolution differs.
        self.input_shape = 32
        self.in_channels = in_channels

        # Absolute position embedding in the frequency domain, sized H x (W // 2 + 1) to match rfftn output.
        self.locMap = nn.Parameter(torch.rand(self.input_shape, self.input_shape // 2 + 1))

        # Gate for the dynamic skip connection between the raw and the convolved spectrum.
        self.lambda_base = nn.Parameter(torch.tensor(0.), requires_grad=True)

        self.conv_layer_down55 = torch.nn.Conv2d(in_channels=in_channels * 2 + 1,  # +1 for locMap
                                                 out_channels=out_channels * 2,
                                                 kernel_size=1, stride=1, padding=0, dilation=1, groups=self.groups,
                                                 bias=False, padding_mode='reflect')
        self.conv_layer_down55_shift = torch.nn.Conv2d(in_channels=in_channels * 2 + 1,  # +1 for locMap
                                                       out_channels=out_channels * 2,
                                                       kernel_size=3, stride=1, padding=2, dilation=2,
                                                       groups=self.groups, bias=False, padding_mode='reflect')

        self.norm = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=True)

        self.spatial_scale_factor = spatial_scale_factor
        self.spatial_scale_mode = spatial_scale_mode
        self.spectral_pos_encoding = spectral_pos_encoding
        self.ffc3d = ffc3d
        self.fft_norm = fft_norm

        self.img_freq = None
        self.distill = None

    def forward(self, x):
        batch = x.shape[0]

        if self.spatial_scale_factor is not None:
            orig_size = x.shape[-2:]
            x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode,
                              align_corners=False)

        # Real FFT over the spatial dims; real and imaginary parts are stacked along the channel axis.
        fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
        ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
        ffted = ffted.view((batch, -1,) + ffted.size()[3:])

        # Absolute position embedding in the frequency domain, broadcast over the batch.
        locMap = self.locMap.expand_as(ffted[:, :1, :, :])  # B 1 H' W'
        ffted_copy = ffted.clone()

        cat_img_mask_freq = torch.cat((ffted[:, :self.in_channels, :, :],
                                       ffted[:, self.in_channels:, :, :],
                                       locMap), dim=1)

        ffted = self.conv_layer_down55(cat_img_mask_freq)
        ffted = torch.fft.fftshift(ffted, dim=-2)

        ffted = self.relu(ffted)

        locMap_shift = torch.fft.fftshift(locMap, dim=-2)  # only if not shifted back

        # Repeat the convolution on the shifted spectrum, this time with the shifted position map.
        cat_img_mask_freq1 = torch.cat((ffted[:, :self.in_channels, :, :],
                                        ffted[:, self.in_channels:, :, :],
                                        locMap_shift), dim=1)

        ffted = self.conv_layer_down55_shift(cat_img_mask_freq1)
        ffted = torch.fft.fftshift(ffted, dim=-2)

        # Dynamic skip connection: blend the raw spectrum with the convolved one via a learned gate.
        lambda_base = torch.sigmoid(self.lambda_base)

        ffted = ffted_copy * lambda_base + ffted * (1 - lambda_base)

        # Inverse real FFT back to the spatial domain.
        ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
            0, 1, 3, 4, 2).contiguous()  # (batch, c, t, h, w/2+1, 2)
        ffted = torch.complex(ffted[..., 0], ffted[..., 1])

        ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
        output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)

        if self.spatial_scale_factor is not None:
            output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)

        # Adaptive clipping: re-center the output to the input mean and clip it to the input range.
        epsilon = 0.5
        output = output - torch.mean(output) + torch.mean(x)
        output = torch.clip(output, float(x.min() - epsilon), float(x.max() + epsilon))

        self.distill = output  # for self perc
        return output
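
# The class below is a minimal usage sketch added for illustration, not part of the original
# module: it shows one way FourierUnit_modified could sit inside an FFC-style spectral transform
# (1x1 conv -> Fourier unit with a residual path -> 1x1 conv). The class name and layer sizes are
# assumptions, and the spatial size must stay at 32x32 to match the hard-coded locMap above.
class SpectralTransformSketch(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 2, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels // 2),
            nn.ReLU(inplace=True),
        )
        self.fu = FourierUnit_modified(out_channels // 2, out_channels // 2)
        self.conv2 = nn.Conv2d(out_channels // 2, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        # residual connection around the Fourier unit, as in FFC-style blocks
        x = self.fu(x) + x
        return self.conv2(x)
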
if __name__ == '__main__':
    in_channels = 16
    out_channels = 16

    block = FourierUnit_modified(in_channels=in_channels, out_channels=out_channels)

    input_tensor = torch.rand(8, in_channels, 32, 32)

    output = block(input_tensor)

    print("Input size:", input_tensor.size())
    print("Output size:", output.size())
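
    # Added sketch (not in the original demo): a quick illustration of why locMap has width
    # input_shape // 2 + 1 -- torch.fft.rfftn keeps only the non-negative frequencies of the
    # last axis, so a (B, C, H, W) input yields a complex (B, C, H, W // 2 + 1) spectrum.
    spectrum = torch.fft.rfftn(input_tensor, dim=(-2, -1), norm='ortho')
    print("Spectrum size:", spectrum.size())  # expected: torch.Size([8, 16, 32, 17])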