feat(components): 添加多种特征增强模块
- 新增 DEConv、DynamicFilter、SCSA、SEBlock、SMFA、TIAM 和 UFFC 模块 - 这些模块提供了不同的特征增强功能,如卷积差分、频域滤波、注意力机制等 - 可以根据需求选择合适的模块来提升模型性能
This commit is contained in:
parent
c1eed72f24
commit
927d049f13
171
componets/DEConv.py
Normal file
171
componets/DEConv.py
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
# https://github.com/cecret3350/DEA-Net/blob/main/code/model/modules/deconv.py
|
||||||
|
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from einops.layers.torch import Rearrange
|
||||||
|
|
||||||
|
"""
|
||||||
|
在深度学习和图像处理领域,"vanilla" 和 "difference" 卷积是两种不同的卷积操作,它们各自有不同的特性和应用场景。DEConv(细节增强卷积)的设计思想是结合这两种卷积的特性来增强模型的性能,尤其是在图像去雾等任务中。
|
||||||
|
|
||||||
|
Vanilla Convolution(标准卷积)
|
||||||
|
"Vanilla" 卷积是最基本的卷积类型,通常仅称为卷积。它是卷积神经网络(CNN)中最常用的组件,用于提取输入数据(如图像)的特征。
|
||||||
|
标准卷积通过在输入数据上滑动小的、可学习的过滤器(或称为核),并计算过滤器与数据的局部区域之间的点乘,来工作。通过这种方式,它能够捕获输入数据的局部模式和特征。
|
||||||
|
|
||||||
|
Difference Convolution(差分卷积)
|
||||||
|
差分卷积是一种特殊类型的卷积,它专注于捕捉输入数据中的局部差异信息,例如边缘或纹理的变化。
|
||||||
|
它通过修改标准卷积核的权重或者通过特殊的操作来实现,使得网络更加关注于图像的高频信息,即图像中的细节和纹理变化。在图像处理任务中,如图像去雾、图像增强、边缘检测等,捕获这种高频信息非常重要,因为它们往往包含了关于物体边界和结构的关键信息。
|
||||||
|
|
||||||
|
重参数化技术
|
||||||
|
重参数化技术是一种参数转换方法,它允许模型在不增加额外参数和计算代价的情况下,实现更复杂的功能或改善性能。在DEConv的上下文中,重参数化技术使得将vanilla卷积和difference卷积结合起来的操作,可以等价地转换成一个标准的卷积操作。
|
||||||
|
这意味着DEConv可以在不增加额外参数和计算成本的情况下,通过巧妙地设计卷积核权重,同时利用标准卷积和差分卷积的优势,从而增强模型处理图像的能力。
|
||||||
|
具体来说,通过重参数化,可以将差分卷积的效果整合到一个卷积核中,使得这个卷积核既能捕获图像的基本特征(通过标准卷积部分),也能强调图像的细节和差异信息(通过差分卷积部分)。
|
||||||
|
这种方法特别适用于那些需要同时考虑全局内容和局部细节信息的任务,如图像去雾,其中既需要理解图像的整体结构,也需要恢复由于雾导致的细节丢失。
|
||||||
|
重参数化技术的关键优势在于,它允许模型在维持参数数量和计算复杂度不变的前提下,实现更为复杂或更为精细的功能。
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2d_cd(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
|
||||||
|
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
|
||||||
|
super(Conv2d_cd, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||||
|
dilation=dilation, groups=groups, bias=bias)
|
||||||
|
self.theta = theta
|
||||||
|
|
||||||
|
def get_weight(self):
|
||||||
|
conv_weight = self.conv.weight
|
||||||
|
conv_shape = conv_weight.shape
|
||||||
|
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
|
||||||
|
# conv_weight_cd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
|
||||||
|
conv_weight_cd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
|
||||||
|
conv_weight_cd[:, :, :] = conv_weight[:, :, :]
|
||||||
|
conv_weight_cd[:, :, 4] = conv_weight[:, :, 4] - conv_weight[:, :, :].sum(2)
|
||||||
|
conv_weight_cd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[3])(
|
||||||
|
conv_weight_cd)
|
||||||
|
return conv_weight_cd, self.conv.bias
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2d_ad(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
|
||||||
|
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
|
||||||
|
super(Conv2d_ad, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||||
|
dilation=dilation, groups=groups, bias=bias)
|
||||||
|
self.theta = theta
|
||||||
|
|
||||||
|
def get_weight(self):
|
||||||
|
conv_weight = self.conv.weight
|
||||||
|
conv_shape = conv_weight.shape
|
||||||
|
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
|
||||||
|
conv_weight_ad = conv_weight - self.theta * conv_weight[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]
|
||||||
|
conv_weight_ad = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[3])(
|
||||||
|
conv_weight_ad)
|
||||||
|
return conv_weight_ad, self.conv.bias
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2d_rd(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
|
||||||
|
padding=2, dilation=1, groups=1, bias=False, theta=1.0):
|
||||||
|
|
||||||
|
super(Conv2d_rd, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||||
|
dilation=dilation, groups=groups, bias=bias)
|
||||||
|
self.theta = theta
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
if math.fabs(self.theta - 0.0) < 1e-8:
|
||||||
|
out_normal = self.conv(x)
|
||||||
|
return out_normal
|
||||||
|
else:
|
||||||
|
conv_weight = self.conv.weight
|
||||||
|
conv_shape = conv_weight.shape
|
||||||
|
if conv_weight.is_cuda:
|
||||||
|
conv_weight_rd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 5 * 5).fill_(0)
|
||||||
|
else:
|
||||||
|
conv_weight_rd = torch.zeros(conv_shape[0], conv_shape[1], 5 * 5)
|
||||||
|
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
|
||||||
|
conv_weight_rd[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = conv_weight[:, :, 1:]
|
||||||
|
conv_weight_rd[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -conv_weight[:, :, 1:] * self.theta
|
||||||
|
conv_weight_rd[:, :, 12] = conv_weight[:, :, 0] * (1 - self.theta)
|
||||||
|
conv_weight_rd = conv_weight_rd.view(conv_shape[0], conv_shape[1], 5, 5)
|
||||||
|
out_diff = nn.functional.conv2d(input=x, weight=conv_weight_rd, bias=self.conv.bias,
|
||||||
|
stride=self.conv.stride, padding=self.conv.padding, groups=self.conv.groups)
|
||||||
|
|
||||||
|
return out_diff
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2d_hd(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
|
||||||
|
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
|
||||||
|
super(Conv2d_hd, self).__init__()
|
||||||
|
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||||
|
dilation=dilation, groups=groups, bias=bias)
|
||||||
|
|
||||||
|
def get_weight(self):
|
||||||
|
conv_weight = self.conv.weight
|
||||||
|
conv_shape = conv_weight.shape
|
||||||
|
# conv_weight_hd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
|
||||||
|
conv_weight_hd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
|
||||||
|
conv_weight_hd[:, :, [0, 3, 6]] = conv_weight[:, :, :]
|
||||||
|
conv_weight_hd[:, :, [2, 5, 8]] = -conv_weight[:, :, :]
|
||||||
|
conv_weight_hd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[2])(
|
||||||
|
conv_weight_hd)
|
||||||
|
return conv_weight_hd, self.conv.bias
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2d_vd(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
|
||||||
|
padding=1, dilation=1, groups=1, bias=False):
|
||||||
|
super(Conv2d_vd, self).__init__()
|
||||||
|
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||||
|
dilation=dilation, groups=groups, bias=bias)
|
||||||
|
|
||||||
|
def get_weight(self):
|
||||||
|
conv_weight = self.conv.weight
|
||||||
|
conv_shape = conv_weight.shape
|
||||||
|
# conv_weight_vd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
|
||||||
|
conv_weight_vd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
|
||||||
|
conv_weight_vd[:, :, [0, 1, 2]] = conv_weight[:, :, :]
|
||||||
|
conv_weight_vd[:, :, [6, 7, 8]] = -conv_weight[:, :, :]
|
||||||
|
conv_weight_vd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[2])(
|
||||||
|
conv_weight_vd)
|
||||||
|
return conv_weight_vd, self.conv.bias
|
||||||
|
|
||||||
|
|
||||||
|
class DEConv(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super(DEConv, self).__init__()
|
||||||
|
self.conv1_1 = Conv2d_cd(dim, dim, 3, bias=True)
|
||||||
|
self.conv1_2 = Conv2d_hd(dim, dim, 3, bias=True)
|
||||||
|
self.conv1_3 = Conv2d_vd(dim, dim, 3, bias=True)
|
||||||
|
self.conv1_4 = Conv2d_ad(dim, dim, 3, bias=True)
|
||||||
|
self.conv1_5 = nn.Conv2d(dim, dim, 3, padding=1, bias=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
w1, b1 = self.conv1_1.get_weight()
|
||||||
|
w2, b2 = self.conv1_2.get_weight()
|
||||||
|
w3, b3 = self.conv1_3.get_weight()
|
||||||
|
w4, b4 = self.conv1_4.get_weight()
|
||||||
|
w5, b5 = self.conv1_5.weight, self.conv1_5.bias
|
||||||
|
|
||||||
|
w = w1 + w2 + w3 + w4 + w5
|
||||||
|
b = b1 + b2 + b3 + b4 + b5
|
||||||
|
res = nn.functional.conv2d(input=x, weight=w, bias=b, stride=1, padding=1, groups=1)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 初始化DEConv模块,dim为输入和输出的通道数
|
||||||
|
block = DEConv(dim=16)
|
||||||
|
# 创建一个随机输入张量,假设输入尺寸为(batch_size, channels, height, width)
|
||||||
|
input_tensor = torch.rand(4, 16, 64, 64)
|
||||||
|
# 将输入传递给DEConv模块
|
||||||
|
output_tensor = block(input_tensor)
|
||||||
|
# 打印输入和输出张量的尺寸
|
||||||
|
print("输入尺寸:", input_tensor.size())
|
||||||
|
print("输出尺寸:", output_tensor.size())
|
||||||
|
|
||||||
|
|
116
componets/DynamicFilter(频域模块动态滤波器用于CV2维图像).py
Normal file
116
componets/DynamicFilter(频域模块动态滤波器用于CV2维图像).py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from timm.layers.helpers import to_2tuple
|
||||||
|
|
||||||
|
"""
|
||||||
|
配备多头自注意力 (MHSA) 的模型在计算机视觉方面取得了显着的性能。它们的计算复杂度与输入特征图中的二次像素数成正比,导致处理速度缓慢,尤其是在处理高分辨率图像时。
|
||||||
|
为了规避这个问题,提出了一种新型的代币混合器作为MHSA的替代方案:基于FFT的代币混合器涉及类似于MHSA的全局操作,但计算复杂度较低。
|
||||||
|
在这里,我们提出了一种名为动态过滤器的新型令牌混合器以缩小上述差距。
|
||||||
|
DynamicFilter 模块通过频域滤波和动态调整滤波器权重,能够对图像进行复杂的增强和处理。
|
||||||
|
"""
|
||||||
|
|
||||||
|
class StarReLU(nn.Module):
|
||||||
|
"""
|
||||||
|
StarReLU: s * relu(x) ** 2 + b
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, scale_value=1.0, bias_value=0.0,
|
||||||
|
scale_learnable=True, bias_learnable=True,
|
||||||
|
mode=None, inplace=False):
|
||||||
|
super().__init__()
|
||||||
|
self.inplace = inplace
|
||||||
|
self.relu = nn.ReLU(inplace=inplace)
|
||||||
|
self.scale = nn.Parameter(scale_value * torch.ones(1),
|
||||||
|
requires_grad=scale_learnable)
|
||||||
|
self.bias = nn.Parameter(bias_value * torch.ones(1),
|
||||||
|
requires_grad=bias_learnable)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.scale * self.relu(x) ** 2 + self.bias
|
||||||
|
|
||||||
|
class Mlp(nn.Module):
|
||||||
|
""" MLP as used in MetaFormer models, eg Transformer, MLP-Mixer, PoolFormer, MetaFormer baslines and related networks.
|
||||||
|
Mostly copied from timm.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0.,
|
||||||
|
bias=False, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
in_features = dim
|
||||||
|
out_features = out_features or in_features
|
||||||
|
hidden_features = int(mlp_ratio * in_features)
|
||||||
|
drop_probs = to_2tuple(drop)
|
||||||
|
|
||||||
|
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
||||||
|
self.act = act_layer()
|
||||||
|
self.drop1 = nn.Dropout(drop_probs[0])
|
||||||
|
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
||||||
|
self.drop2 = nn.Dropout(drop_probs[1])
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.drop1(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.drop2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicFilter(nn.Module):
|
||||||
|
def __init__(self, dim, expansion_ratio=2, reweight_expansion_ratio=.25,
|
||||||
|
act1_layer=StarReLU, act2_layer=nn.Identity,
|
||||||
|
bias=False, num_filters=4, size=14, weight_resize=False,
|
||||||
|
**kwargs):
|
||||||
|
super().__init__()
|
||||||
|
size = to_2tuple(size)
|
||||||
|
self.size = size[0]
|
||||||
|
self.filter_size = size[1] // 2 + 1
|
||||||
|
self.num_filters = num_filters
|
||||||
|
self.dim = dim
|
||||||
|
self.med_channels = int(expansion_ratio * dim)
|
||||||
|
self.weight_resize = weight_resize
|
||||||
|
self.pwconv1 = nn.Linear(dim, self.med_channels, bias=bias)
|
||||||
|
self.act1 = act1_layer()
|
||||||
|
self.reweight = Mlp(dim, reweight_expansion_ratio, num_filters * self.med_channels)
|
||||||
|
self.complex_weights = nn.Parameter(
|
||||||
|
torch.randn(self.size, self.filter_size, num_filters, 2,
|
||||||
|
dtype=torch.float32) * 0.02)
|
||||||
|
self.act2 = act2_layer()
|
||||||
|
self.pwconv2 = nn.Linear(self.med_channels, dim, bias=bias)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B, H, W, _ = x.shape
|
||||||
|
|
||||||
|
routeing = self.reweight(x.mean(dim=(1, 2))).view(B, self.num_filters,
|
||||||
|
-1).softmax(dim=1)
|
||||||
|
x = self.pwconv1(x)
|
||||||
|
x = self.act1(x)
|
||||||
|
x = x.to(torch.float32)
|
||||||
|
x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
|
||||||
|
|
||||||
|
if self.weight_resize:
|
||||||
|
complex_weights = resize_complex_weight(self.complex_weights, x.shape[1],
|
||||||
|
x.shape[2])
|
||||||
|
complex_weights = torch.view_as_complex(complex_weights.contiguous())
|
||||||
|
else:
|
||||||
|
complex_weights = torch.view_as_complex(self.complex_weights)
|
||||||
|
routeing = routeing.to(torch.complex64)
|
||||||
|
weight = torch.einsum('bfc,hwf->bhwc', routeing, complex_weights)
|
||||||
|
if self.weight_resize:
|
||||||
|
weight = weight.view(-1, x.shape[1], x.shape[2], self.med_channels)
|
||||||
|
else:
|
||||||
|
weight = weight.view(-1, self.size, self.filter_size, self.med_channels)
|
||||||
|
x = x * weight
|
||||||
|
x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
|
||||||
|
|
||||||
|
x = self.act2(x)
|
||||||
|
x = self.pwconv2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
block = DynamicFilter(32, size=64) # size==H,W
|
||||||
|
input = torch.rand(3, 64, 64, 32)
|
||||||
|
output = block(input)
|
||||||
|
print(input.size())
|
||||||
|
print(output.size())
|
BIN
componets/SCSA(CV2维图像).pdf
Normal file
BIN
componets/SCSA(CV2维图像).pdf
Normal file
Binary file not shown.
156
componets/SCSA.py
Normal file
156
componets/SCSA.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
import typing as t
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from einops.einops import rearrange
|
||||||
|
from mmengine.model import BaseModule
|
||||||
|
__all__ = ['SCSA']
|
||||||
|
|
||||||
|
"""SCSA:探索空间注意力和通道注意力之间的协同作用
|
||||||
|
通道和空间注意力分别在为各种下游视觉任务提取特征依赖性和空间结构关系方面带来了显着的改进。
|
||||||
|
虽然它们的结合更有利于发挥各自的优势,但通道和空间注意力之间的协同作用尚未得到充分探索,缺乏充分利用多语义信息的协同潜力来进行特征引导和缓解语义差异。
|
||||||
|
我们的研究试图在多个语义层面揭示空间和通道注意力之间的协同关系,提出了一种新颖的空间和通道协同注意力模块(SCSA)。我们的SCSA由两部分组成:可共享的多语义空间注意力(SMSA)和渐进式通道自注意力(PCSA)。
|
||||||
|
SMSA 集成多语义信息并利用渐进式压缩策略将判别性空间先验注入 PCSA 的通道自注意力中,有效地指导通道重新校准。此外,PCSA 中基于自注意力机制的稳健特征交互进一步缓解了 SMSA 中不同子特征之间多语义信息的差异。
|
||||||
|
我们在七个基准数据集上进行了广泛的实验,包括 ImageNet-1K 上的分类、MSCOCO 2017 上的对象检测、ADE20K 上的分割以及其他四个复杂场景检测数据集。我们的结果表明,我们提出的 SCSA 不仅超越了当前最先进的注意力机制,
|
||||||
|
而且在各种任务场景中表现出增强的泛化能力。
|
||||||
|
"""
|
||||||
|
|
||||||
|
class SCSA(BaseModule):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
head_num: int,
|
||||||
|
window_size: int = 7,
|
||||||
|
group_kernel_sizes: t.List[int] = [3, 5, 7, 9],
|
||||||
|
qkv_bias: bool = False,
|
||||||
|
fuse_bn: bool = False,
|
||||||
|
norm_cfg: t.Dict = dict(type='BN'),
|
||||||
|
act_cfg: t.Dict = dict(type='ReLU'),
|
||||||
|
down_sample_mode: str = 'avg_pool',
|
||||||
|
attn_drop_ratio: float = 0.,
|
||||||
|
gate_layer: str = 'sigmoid',
|
||||||
|
):
|
||||||
|
super(SCSA, self).__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.head_num = head_num
|
||||||
|
self.head_dim = dim // head_num
|
||||||
|
self.scaler = self.head_dim ** -0.5
|
||||||
|
self.group_kernel_sizes = group_kernel_sizes
|
||||||
|
self.window_size = window_size
|
||||||
|
self.qkv_bias = qkv_bias
|
||||||
|
self.fuse_bn = fuse_bn
|
||||||
|
self.down_sample_mode = down_sample_mode
|
||||||
|
|
||||||
|
assert self.dim // 4, 'The dimension of input feature should be divisible by 4.'
|
||||||
|
self.group_chans = group_chans = self.dim // 4
|
||||||
|
|
||||||
|
self.local_dwc = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[0],
|
||||||
|
padding=group_kernel_sizes[0] // 2, groups=group_chans)
|
||||||
|
self.global_dwc_s = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[1],
|
||||||
|
padding=group_kernel_sizes[1] // 2, groups=group_chans)
|
||||||
|
self.global_dwc_m = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[2],
|
||||||
|
padding=group_kernel_sizes[2] // 2, groups=group_chans)
|
||||||
|
self.global_dwc_l = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[3],
|
||||||
|
padding=group_kernel_sizes[3] // 2, groups=group_chans)
|
||||||
|
self.sa_gate = nn.Softmax(dim=2) if gate_layer == 'softmax' else nn.Sigmoid()
|
||||||
|
self.norm_h = nn.GroupNorm(4, dim)
|
||||||
|
self.norm_w = nn.GroupNorm(4, dim)
|
||||||
|
|
||||||
|
self.conv_d = nn.Identity()
|
||||||
|
self.norm = nn.GroupNorm(1, dim)
|
||||||
|
self.q = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
|
||||||
|
self.k = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
|
||||||
|
self.v = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
|
||||||
|
self.attn_drop = nn.Dropout(attn_drop_ratio)
|
||||||
|
self.ca_gate = nn.Softmax(dim=1) if gate_layer == 'softmax' else nn.Sigmoid()
|
||||||
|
|
||||||
|
if window_size == -1:
|
||||||
|
self.down_func = nn.AdaptiveAvgPool2d((1, 1))
|
||||||
|
else:
|
||||||
|
if down_sample_mode == 'recombination':
|
||||||
|
self.down_func = self.space_to_chans
|
||||||
|
# dimensionality reduction
|
||||||
|
self.conv_d = nn.Conv2d(in_channels=dim * window_size ** 2, out_channels=dim, kernel_size=1, bias=False)
|
||||||
|
elif down_sample_mode == 'avg_pool':
|
||||||
|
self.down_func = nn.AvgPool2d(kernel_size=(window_size, window_size), stride=window_size)
|
||||||
|
elif down_sample_mode == 'max_pool':
|
||||||
|
self.down_func = nn.MaxPool2d(kernel_size=(window_size, window_size), stride=window_size)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
The dim of x is (B, C, H, W)
|
||||||
|
"""
|
||||||
|
# Spatial attention priority calculation
|
||||||
|
b, c, h_, w_ = x.size()
|
||||||
|
# (B, C, H)
|
||||||
|
x_h = x.mean(dim=3)
|
||||||
|
l_x_h, g_x_h_s, g_x_h_m, g_x_h_l = torch.split(x_h, self.group_chans, dim=1)
|
||||||
|
# (B, C, W)
|
||||||
|
x_w = x.mean(dim=2)
|
||||||
|
l_x_w, g_x_w_s, g_x_w_m, g_x_w_l = torch.split(x_w, self.group_chans, dim=1)
|
||||||
|
|
||||||
|
x_h_attn = self.sa_gate(self.norm_h(torch.cat((
|
||||||
|
self.local_dwc(l_x_h),
|
||||||
|
self.global_dwc_s(g_x_h_s),
|
||||||
|
self.global_dwc_m(g_x_h_m),
|
||||||
|
self.global_dwc_l(g_x_h_l),
|
||||||
|
), dim=1)))
|
||||||
|
x_h_attn = x_h_attn.view(b, c, h_, 1)
|
||||||
|
|
||||||
|
x_w_attn = self.sa_gate(self.norm_w(torch.cat((
|
||||||
|
self.local_dwc(l_x_w),
|
||||||
|
self.global_dwc_s(g_x_w_s),
|
||||||
|
self.global_dwc_m(g_x_w_m),
|
||||||
|
self.global_dwc_l(g_x_w_l)
|
||||||
|
), dim=1)))
|
||||||
|
x_w_attn = x_w_attn.view(b, c, 1, w_)
|
||||||
|
|
||||||
|
x = x * x_h_attn * x_w_attn
|
||||||
|
|
||||||
|
# Channel attention based on self attention
|
||||||
|
# reduce calculations
|
||||||
|
y = self.down_func(x)
|
||||||
|
y = self.conv_d(y)
|
||||||
|
_, _, h_, w_ = y.size()
|
||||||
|
|
||||||
|
# normalization first, then reshape -> (B, H, W, C) -> (B, C, H * W) and generate q, k and v
|
||||||
|
y = self.norm(y)
|
||||||
|
q = self.q(y)
|
||||||
|
k = self.k(y)
|
||||||
|
v = self.v(y)
|
||||||
|
# (B, C, H, W) -> (B, head_num, head_dim, N)
|
||||||
|
q = rearrange(q, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
|
||||||
|
head_dim=int(self.head_dim))
|
||||||
|
k = rearrange(k, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
|
||||||
|
head_dim=int(self.head_dim))
|
||||||
|
v = rearrange(v, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
|
||||||
|
head_dim=int(self.head_dim))
|
||||||
|
|
||||||
|
# (B, head_num, head_dim, head_dim)
|
||||||
|
attn = q @ k.transpose(-2, -1) * self.scaler
|
||||||
|
attn = self.attn_drop(attn.softmax(dim=-1))
|
||||||
|
# (B, head_num, head_dim, N)
|
||||||
|
attn = attn @ v
|
||||||
|
# (B, C, H_, W_)
|
||||||
|
attn = rearrange(attn, 'b head_num head_dim (h w) -> b (head_num head_dim) h w', h=int(h_), w=int(w_))
|
||||||
|
# (B, C, 1, 1)
|
||||||
|
attn = attn.mean((2, 3), keepdim=True)
|
||||||
|
attn = self.ca_gate(attn)
|
||||||
|
return attn * x
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
block = SCSA(
|
||||||
|
dim=256,
|
||||||
|
head_num=8,
|
||||||
|
)
|
||||||
|
|
||||||
|
input_tensor = torch.rand(1, 256, 32, 32)
|
||||||
|
|
||||||
|
# 调用模块进行前向传播
|
||||||
|
output_tensor = block(input_tensor)
|
||||||
|
|
||||||
|
# 打印输入和输出张量的大小
|
||||||
|
print("Input size:", input_tensor.size())
|
||||||
|
print("Output size:", output_tensor.size())
|
37
componets/SEBlock.py
Normal file
37
componets/SEBlock.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
'''-------------一、SE模块-----------------------------'''
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
# 全局平均池化+1*1卷积核+ReLu+1*1卷积核+Sigmoid
|
||||||
|
class SE_Block(nn.Module):
|
||||||
|
def __init__(self, inchannel, ratio=16):
|
||||||
|
super(SE_Block, self).__init__()
|
||||||
|
# 全局平均池化(Fsq操作)
|
||||||
|
self.gap = nn.AdaptiveAvgPool2d((1, 1))
|
||||||
|
# 两个全连接层(Fex操作)
|
||||||
|
self.fc = nn.Sequential(
|
||||||
|
nn.Linear(inchannel, inchannel // ratio, bias=False), # 从 c -> c/r
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(inchannel // ratio, inchannel, bias=False), # 从 c/r -> c
|
||||||
|
nn.Sigmoid()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# 读取批数据图片数量及通道数
|
||||||
|
b, c, h, w = x.size()
|
||||||
|
# Fsq操作:经池化后输出b*c的矩阵
|
||||||
|
y = self.gap(x).view(b, c)
|
||||||
|
# Fex操作:经全连接层输出(b,c,1,1)矩阵
|
||||||
|
y = self.fc(y).view(b, c, 1, 1)
|
||||||
|
# Fscale操作:将得到的权重乘以原来的特征图x
|
||||||
|
return x * y.expand_as(x)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
input = torch.randn(1, 64, 32, 32)
|
||||||
|
seblock = SE_Block(64)
|
||||||
|
print(seblock)
|
||||||
|
output = seblock(input)
|
||||||
|
print(input.shape)
|
||||||
|
print(output.shape)
|
||||||
|
|
65
componets/SMFA.py
Normal file
65
componets/SMFA.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
"""ECCV2024(https://github.com/Zheng-MJ/SMFANet)
|
||||||
|
基于Transformer的恢复方法取得了显著的效果,因为Transformer的自注意力机制(SA)可以探索非局部信息,从而实现更好的高分辨率图像重建。然而,关键的点积自注意力需要大量的计算资源,这限制了其在低功耗设备上的应用。
|
||||||
|
此外,自注意力机制的低通滤波特性限制了其捕捉局部细节的能力,从而导致重建结果过于平滑。为了解决这些问题,我们提出了一种自调制特征聚合(SMFA)模块,协同利用局部和非局部特征交互,以实现更精确的重建。
|
||||||
|
具体而言,SMFA模块采用了高效的自注意力近似(EASA)分支来建模非局部信息,并使用局部细节估计(LDE)分支来捕捉局部细节。此外,我们还引入了基于部分卷积的前馈网络(PCFN),以进一步优化从SMFA提取的代表性特征。
|
||||||
|
大量实验表明,所提出的SMFANet系列在公共基准数据集上实现了更好的重建性能与计算效率的平衡。
|
||||||
|
特别是,与SwinIR-light的×4放大相比,SMFANet+在五个公共测试集上的平均性能提高了0.14dB,运行速度提升了约10倍,且模型复杂度(如FLOPs)仅为其约43%。
|
||||||
|
"""
|
||||||
|
|
||||||
|
class DMlp(nn.Module):
|
||||||
|
def __init__(self, dim, growth_rate=2.0):
|
||||||
|
super().__init__()
|
||||||
|
hidden_dim = int(dim * growth_rate)
|
||||||
|
self.conv_0 = nn.Sequential(
|
||||||
|
nn.Conv2d(dim, hidden_dim, 3, 1, 1, groups=dim),
|
||||||
|
nn.Conv2d(hidden_dim, hidden_dim, 1, 1, 0)
|
||||||
|
)
|
||||||
|
self.act = nn.GELU()
|
||||||
|
self.conv_1 = nn.Conv2d(hidden_dim, dim, 1, 1, 0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv_0(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.conv_1(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SMFA(nn.Module):
|
||||||
|
def __init__(self, dim=36):
|
||||||
|
super(SMFA, self).__init__()
|
||||||
|
self.linear_0 = nn.Conv2d(dim, dim * 2, 1, 1, 0)
|
||||||
|
self.linear_1 = nn.Conv2d(dim, dim, 1, 1, 0)
|
||||||
|
self.linear_2 = nn.Conv2d(dim, dim, 1, 1, 0)
|
||||||
|
|
||||||
|
self.lde = DMlp(dim, 2)
|
||||||
|
|
||||||
|
self.dw_conv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)
|
||||||
|
|
||||||
|
self.gelu = nn.GELU()
|
||||||
|
self.down_scale = 8
|
||||||
|
|
||||||
|
self.alpha = nn.Parameter(torch.ones((1, dim, 1, 1)))
|
||||||
|
self.belt = nn.Parameter(torch.zeros((1, dim, 1, 1)))
|
||||||
|
|
||||||
|
def forward(self, f):
|
||||||
|
_, _, h, w = f.shape
|
||||||
|
y, x = self.linear_0(f).chunk(2, dim=1)
|
||||||
|
x_s = self.dw_conv(F.adaptive_max_pool2d(x, (h // self.down_scale, w // self.down_scale)))
|
||||||
|
x_v = torch.var(x, dim=(-2, -1), keepdim=True)
|
||||||
|
x_l = x * F.interpolate(self.gelu(self.linear_1(x_s * self.alpha + x_v * self.belt)), size=(h, w),
|
||||||
|
mode='nearest')
|
||||||
|
y_d = self.lde(y)
|
||||||
|
return self.linear_2(x_l + y_d)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
block = SMFA(dim=36)
|
||||||
|
input = torch.randn(3, 36, 64, 64)
|
||||||
|
output = block(input)
|
||||||
|
print(input.size())
|
||||||
|
print(output.size())
|
110
componets/TIAM.py
Normal file
110
componets/TIAM.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
"""Elsevier2024
|
||||||
|
变化检测 (CD) 是地球观测中一种重要的监测方法,尤其适用于土地利用分析、城市管理和灾害损失评估。然而,在星座互联和空天协作时代,感兴趣区域 (ROI) 的变化由于几何透视旋转和时间风格差异而导致许多错误检测。
|
||||||
|
为了应对这些挑战,我们引入了 CDNeXt,该框架阐明了一种稳健而有效的方法,用于将基于预训练主干的 Siamese 网络与用于遥感图像的创新时空交互注意模块 (TIAM) 相结合。
|
||||||
|
CDNeXt 可分为四个主要组件:编码器、交互器、解码器和检测器。值得注意的是,由 TIAM 提供支持的交互器从编码器提取的二进制时间特征中查询和重建空间透视依赖关系和时间风格相关性,以扩大 ROI 变化的差异。
|
||||||
|
最后,检测器集成解码器生成的分层特征,随后生成二进制变化掩码。
|
||||||
|
"""
|
||||||
|
|
||||||
|
class SpatiotemporalAttentionFullNotWeightShared(nn.Module):
|
||||||
|
def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=False):
|
||||||
|
super(SpatiotemporalAttentionFullNotWeightShared, self).__init__()
|
||||||
|
assert dimension in [2, ]
|
||||||
|
self.dimension = dimension
|
||||||
|
self.sub_sample = sub_sample
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.inter_channels = inter_channels
|
||||||
|
|
||||||
|
if self.inter_channels is None:
|
||||||
|
self.inter_channels = in_channels // 2
|
||||||
|
if self.inter_channels == 0:
|
||||||
|
self.inter_channels = 1
|
||||||
|
|
||||||
|
self.g1 = nn.Sequential(
|
||||||
|
nn.BatchNorm2d(self.in_channels),
|
||||||
|
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
|
||||||
|
kernel_size=1, stride=1, padding=0)
|
||||||
|
)
|
||||||
|
self.g2 = nn.Sequential(
|
||||||
|
nn.BatchNorm2d(self.in_channels),
|
||||||
|
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
|
||||||
|
kernel_size=1, stride=1, padding=0),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.W1 = nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
|
||||||
|
kernel_size=1, stride=1, padding=0),
|
||||||
|
nn.BatchNorm2d(self.in_channels)
|
||||||
|
)
|
||||||
|
self.W2 = nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
|
||||||
|
kernel_size=1, stride=1, padding=0),
|
||||||
|
nn.BatchNorm2d(self.in_channels)
|
||||||
|
)
|
||||||
|
self.theta = nn.Sequential(
|
||||||
|
nn.BatchNorm2d(self.in_channels),
|
||||||
|
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
|
||||||
|
kernel_size=1, stride=1, padding=0),
|
||||||
|
)
|
||||||
|
self.phi = nn.Sequential(
|
||||||
|
nn.BatchNorm2d(self.in_channels),
|
||||||
|
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
|
||||||
|
kernel_size=1, stride=1, padding=0),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x1, x2):
|
||||||
|
"""
|
||||||
|
:param x: (b, c, h, w)
|
||||||
|
:param return_nl_map: if True return z, nl_map, else only return z.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
batch_size = x1.size(0)
|
||||||
|
g_x11 = self.g1(x1).reshape(batch_size, self.inter_channels, -1)
|
||||||
|
g_x12 = g_x11.permute(0, 2, 1)
|
||||||
|
g_x21 = self.g2(x2).reshape(batch_size, self.inter_channels, -1)
|
||||||
|
g_x22 = g_x21.permute(0, 2, 1)
|
||||||
|
|
||||||
|
theta_x1 = self.theta(x1).reshape(batch_size, self.inter_channels, -1)
|
||||||
|
theta_x2 = theta_x1.permute(0, 2, 1)
|
||||||
|
|
||||||
|
phi_x1 = self.phi(x2).reshape(batch_size, self.inter_channels, -1)
|
||||||
|
phi_x2 = phi_x1.permute(0, 2, 1)
|
||||||
|
|
||||||
|
energy_time_1 = torch.matmul(theta_x1, phi_x2)
|
||||||
|
energy_time_2 = energy_time_1.permute(0, 2, 1)
|
||||||
|
energy_space_1 = torch.matmul(theta_x2, phi_x1)
|
||||||
|
energy_space_2 = energy_space_1.permute(0, 2, 1)
|
||||||
|
|
||||||
|
energy_time_1s = F.softmax(energy_time_1, dim=-1)
|
||||||
|
energy_time_2s = F.softmax(energy_time_2, dim=-1)
|
||||||
|
energy_space_2s = F.softmax(energy_space_1, dim=-2)
|
||||||
|
energy_space_1s = F.softmax(energy_space_2, dim=-2)
|
||||||
|
# C1*S(C2) energy_time_1s * C1*H1W1 g_x12 * energy_space_1s S(H2W2)*H1W1 -> C1*H1W1
|
||||||
|
y1 = torch.matmul(torch.matmul(energy_time_2s, g_x11), energy_space_2s).contiguous() # C2*H2W2
|
||||||
|
# C2*S(C1) energy_time_2s * C2*H2W2 g_x21 * energy_space_2s S(H1W1)*H2W2 -> C2*H2W2
|
||||||
|
y2 = torch.matmul(torch.matmul(energy_time_1s, g_x21), energy_space_1s).contiguous() # C1*H1W1
|
||||||
|
y1 = y1.reshape(batch_size, self.inter_channels, *x2.size()[2:])
|
||||||
|
y2 = y2.reshape(batch_size, self.inter_channels, *x1.size()[2:])
|
||||||
|
return x1 + self.W1(y1), x2 + self.W2(y2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
in_channels = 64
|
||||||
|
batch_size = 8
|
||||||
|
height = 32
|
||||||
|
width = 32
|
||||||
|
|
||||||
|
block = SpatiotemporalAttentionFullNotWeightShared(in_channels=in_channels)
|
||||||
|
|
||||||
|
input1 = torch.rand(batch_size, in_channels, height, width)
|
||||||
|
input2 = torch.rand(batch_size, in_channels, height, width)
|
||||||
|
|
||||||
|
output1, output2 = block(input1, input2)
|
||||||
|
|
||||||
|
print(f"Input1 size: {input1.size()}")
|
||||||
|
print(f"Input2 size: {input2.size()}")
|
||||||
|
print(f"Output1 size: {output1.size()}")
|
||||||
|
print(f"Output2 size: {output2.size()}")
|
BIN
componets/UFFC(CV2维任务).pdf
Normal file
BIN
componets/UFFC(CV2维任务).pdf
Normal file
Binary file not shown.
123
componets/UFFC.py
Normal file
123
componets/UFFC.py
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
"""ICCV2023
|
||||||
|
最近提出的图像修复方法 LaMa 以快速傅里叶卷积 (FFC) 为基础构建了其网络,该网络最初是为图像分类等高级视觉任务而提出的。
|
||||||
|
FFC 使全卷积网络在其早期层中拥有全局感受野。得益于 FFC 模块的独特特性,LaMa 能够生成稳健的重复纹理,
|
||||||
|
这是以前的修复方法无法实现的。但是,原始 FFC 模块是否适合图像修复等低级视觉任务?
|
||||||
|
在本文中,我们分析了在图像修复中使用 FFC 的基本缺陷,即 1) 频谱偏移、2) 意外的空间激活和 3) 频率感受野有限。
|
||||||
|
这些缺陷使得基于 FFC 的修复框架难以生成复杂纹理并执行完美重建。
|
||||||
|
基于以上分析,我们提出了一种新颖的无偏快速傅里叶卷积 (UFFC) 模块,该模块通过
|
||||||
|
1) 范围变换和逆变换、2) 绝对位置嵌入、3) 动态跳过连接和 4) 自适应剪辑对原始 FFC 模块进行了修改,以克服这些缺陷,
|
||||||
|
实现更好的修复效果。在多个基准数据集上进行的大量实验证明了我们方法的有效性,在纹理捕捉能力和表现力方面均优于最先进的方法。
|
||||||
|
"""
|
||||||
|
|
||||||
|
class FourierUnit_modified(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
|
||||||
|
spectral_pos_encoding=False, use_se=False, ffc3d=False, fft_norm='ortho'):
|
||||||
|
# bn_layer not used
|
||||||
|
super(FourierUnit_modified, self).__init__()
|
||||||
|
self.groups = groups
|
||||||
|
|
||||||
|
self.input_shape = 32 # change!!!!!it!!!!!!manually!!!!!!
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
self.locMap = nn.Parameter(torch.rand(self.input_shape, self.input_shape // 2 + 1))
|
||||||
|
|
||||||
|
self.lambda_base = nn.Parameter(torch.tensor(0.), requires_grad=True)
|
||||||
|
|
||||||
|
self.conv_layer_down55 = torch.nn.Conv2d(in_channels=in_channels * 2 + 1, # +1 for locmap
|
||||||
|
out_channels=out_channels * 2,
|
||||||
|
kernel_size=1, stride=1, padding=0, dilation=1, groups=self.groups,
|
||||||
|
bias=False, padding_mode='reflect')
|
||||||
|
self.conv_layer_down55_shift = torch.nn.Conv2d(in_channels=in_channels * 2 + 1, # +1 for locmap
|
||||||
|
out_channels=out_channels * 2,
|
||||||
|
kernel_size=3, stride=1, padding=2, dilation=2,
|
||||||
|
groups=self.groups, bias=False, padding_mode='reflect')
|
||||||
|
|
||||||
|
self.norm = nn.BatchNorm2d(out_channels)
|
||||||
|
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
self.spatial_scale_factor = spatial_scale_factor
|
||||||
|
self.spatial_scale_mode = spatial_scale_mode
|
||||||
|
self.spectral_pos_encoding = spectral_pos_encoding
|
||||||
|
self.ffc3d = ffc3d
|
||||||
|
self.fft_norm = fft_norm
|
||||||
|
|
||||||
|
self.img_freq = None
|
||||||
|
self.distill = None
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
batch = x.shape[0]
|
||||||
|
|
||||||
|
if self.spatial_scale_factor is not None:
|
||||||
|
orig_size = x.shape[-2:]
|
||||||
|
x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode,
|
||||||
|
align_corners=False)
|
||||||
|
|
||||||
|
fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
|
||||||
|
ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
|
||||||
|
ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
|
||||||
|
ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
|
||||||
|
ffted = ffted.view((batch, -1,) + ffted.size()[3:])
|
||||||
|
|
||||||
|
locMap = self.locMap.expand_as(ffted[:, :1, :, :]) # B 1 H' W'
|
||||||
|
ffted_copy = ffted.clone()
|
||||||
|
|
||||||
|
cat_img_mask_freq = torch.cat((ffted[:, :self.in_channels, :, :],
|
||||||
|
ffted[:, self.in_channels:, :, :],
|
||||||
|
locMap), dim=1)
|
||||||
|
|
||||||
|
ffted = self.conv_layer_down55(cat_img_mask_freq)
|
||||||
|
ffted = torch.fft.fftshift(ffted, dim=-2)
|
||||||
|
|
||||||
|
ffted = self.relu(ffted)
|
||||||
|
|
||||||
|
locMap_shift = torch.fft.fftshift(locMap, dim=-2) ## ONLY IF NOT SHIFT BACK
|
||||||
|
|
||||||
|
# REPEAT CONV
|
||||||
|
cat_img_mask_freq1 = torch.cat((ffted[:, :self.in_channels, :, :],
|
||||||
|
ffted[:, self.in_channels:, :, :],
|
||||||
|
locMap_shift), dim=1)
|
||||||
|
|
||||||
|
ffted = self.conv_layer_down55_shift(cat_img_mask_freq1)
|
||||||
|
ffted = torch.fft.fftshift(ffted, dim=-2)
|
||||||
|
|
||||||
|
lambda_base = torch.sigmoid(self.lambda_base)
|
||||||
|
|
||||||
|
ffted = ffted_copy * lambda_base + ffted * (1 - lambda_base)
|
||||||
|
|
||||||
|
# irfft
|
||||||
|
ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
|
||||||
|
0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2)
|
||||||
|
ffted = torch.complex(ffted[..., 0], ffted[..., 1])
|
||||||
|
|
||||||
|
ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
|
||||||
|
output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)
|
||||||
|
|
||||||
|
if self.spatial_scale_factor is not None:
|
||||||
|
output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)
|
||||||
|
|
||||||
|
epsilon = 0.5
|
||||||
|
output = output - torch.mean(output) + torch.mean(x)
|
||||||
|
output = torch.clip(output, float(x.min() - epsilon), float(x.max() + epsilon))
|
||||||
|
|
||||||
|
self.distill = output # for self perc
|
||||||
|
return output
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
in_channels = 16
|
||||||
|
out_channels = 16
|
||||||
|
|
||||||
|
block = FourierUnit_modified(in_channels=in_channels, out_channels=out_channels)
|
||||||
|
|
||||||
|
input_tensor = torch.rand(8, in_channels, 32, 32)
|
||||||
|
|
||||||
|
output = block(input_tensor)
|
||||||
|
|
||||||
|
print("Input size:", input_tensor.size())
|
||||||
|
print("Output size:", output.size())
|
Loading…
Reference in New Issue
Block a user