pfcfuse/componets/DynamicFilter(频域模块动态滤波器用于CV2维图像).py
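zjut 927d049f13 feat(components): add several feature-enhancement modules
- Add the DEConv, DynamicFilter, SCSA, SEBlock, SMFA, TIAM, and UFFC modules
- These modules provide different feature-enhancement functions, such as difference convolution, frequency-domain filtering, and attention mechanisms
- Pick the module that fits your needs to improve model performance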
2024-11-15 09:18:34 +08:00


import torch
import torch.nn as nn
from timm.layers.helpers import to_2tuple
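"""
Models equipped with multi-head self-attention (MHSA) have achieved notable performance in computer vision.
Their computational complexity grows with the square of the number of pixels in the input feature map,
which makes processing slow, especially for high-resolution images.
To work around this, a new kind of token mixer has been proposed as an alternative to MHSA: FFT-based
token mixers perform a global operation similar to MHSA, but at lower computational complexity.
Here we propose a new token mixer, called the dynamic filter, to close the gap described above.
The DynamicFilter module filters features in the frequency domain and adjusts the filter weights
dynamically per input, which allows it to enhance and process 2D image features.
"""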
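

# Illustrative sketch, not part of the original module. The docstring above says an
# FFT-based token mixer is a global operation like MHSA. One way to see this is the
# convolution theorem: multiplying two spectra element-wise equals a circular
# convolution whose kernel spans the whole feature map. The function name and the
# default sizes are hypothetical, and the default FFT normalization is used here
# (the module itself uses norm='ortho' with learned frequency-domain weights).
import torch


def _demo_fft_global_mixing(h=4, w=6, seed=0):
    gen = torch.Generator().manual_seed(seed)
    x = torch.rand(h, w, generator=gen, dtype=torch.float64)
    k = torch.rand(h, w, generator=gen, dtype=torch.float64)

    # Frequency-domain path: cost grows roughly as O(HW log HW).
    y_fft = torch.fft.irfft2(torch.fft.rfft2(x) * torch.fft.rfft2(k), s=(h, w))

    # Spatial path: every output pixel sums over all HW input pixels, O((HW)^2).
    y_direct = torch.zeros(h, w, dtype=torch.float64)
    for i in range(h):
        rows = (i - torch.arange(h)) % h  # (i - m) mod h for every m
        for j in range(w):
            cols = (j - torch.arange(w)) % w  # (j - n) mod w for every n
            y_direct[i, j] = (x * k[rows][:, cols]).sum()

    # True when both paths agree up to floating-point tolerance.
    return torch.allclose(y_fft, y_direct)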


class StarReLU(nn.Module):
    """
    StarReLU: s * relu(x) ** 2 + b
    """

    def __init__(self, scale_value=1.0, bias_value=0.0,
                 scale_learnable=True, bias_learnable=True,
                 mode=None, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        self.scale = nn.Parameter(scale_value * torch.ones(1),
                                  requires_grad=scale_learnable)
        self.bias = nn.Parameter(bias_value * torch.ones(1),
                                 requires_grad=bias_learnable)

    def forward(self, x):
        return self.scale * self.relu(x) ** 2 + self.bias
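

# Illustrative sketch, not part of the original module. A scalar, dependency-free
# version of the StarReLU formula s * relu(x) ** 2 + b, handy for checking values
# by hand. The function name is hypothetical; the defaults mirror scale_value=1.0
# and bias_value=0.0 above.
def _star_relu_scalar(x, scale=1.0, bias=0.0):
    # max(x, 0.0) plays the role of relu; ** binds tighter than *, as in forward().
    return scale * max(x, 0.0) ** 2 + bias


# Negative inputs are clamped to zero before squaring; positive inputs are squared.
_star_relu_examples = [
    _star_relu_scalar(-2.0),            # 0.0
    _star_relu_scalar(3.0),             # 9.0
    _star_relu_scalar(3.0, 0.5, 1.0),   # 0.5 * 9.0 + 1.0 = 5.5
]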


class Mlp(nn.Module):
    """ MLP as used in MetaFormer models, e.g. Transformer, MLP-Mixer, PoolFormer, MetaFormer baselines and related networks.
    Mostly copied from timm.
    """

    def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0.,
                 bias=False, **kwargs):
        super().__init__()
        in_features = dim
        out_features = out_features or in_features
        hidden_features = int(mlp_ratio * in_features)
        drop_probs = to_2tuple(drop)  # separate dropout rates for the two layers
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
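

def resize_complex_weight(origin_weight, new_h, new_w):
    # NOTE: forward() calls this helper when weight_resize=True, but the file never
    # defines it. This is an assumed implementation: bicubically interpolate the
    # stacked real and imaginary parts to the target frequency-grid size.
    h, w, num_filters = origin_weight.shape[0:3]  # (h, w, num_filters, 2)
    origin_weight = origin_weight.reshape(1, h, w, num_filters * 2).permute(0, 3, 1, 2)
    new_weight = nn.functional.interpolate(
        origin_weight, size=(new_h, new_w), mode='bicubic', align_corners=True
    ).permute(0, 2, 3, 1).reshape(new_h, new_w, num_filters, 2)
    return new_weight


class DynamicFilter(nn.Module):
    """Frequency-domain token mixer with input-dependent filter weights.

    Expects channels-last input of shape (B, H, W, C). Unless weight_resize=True,
    H and W must equal `size`.
    """

    def __init__(self, dim, expansion_ratio=2, reweight_expansion_ratio=.25,
                 act1_layer=StarReLU, act2_layer=nn.Identity,
                 bias=False, num_filters=4, size=14, weight_resize=False,
                 **kwargs):
        super().__init__()
        size = to_2tuple(size)
        self.size = size[0]
        # rfft2 keeps only the non-redundant half of the last transformed axis.
        self.filter_size = size[1] // 2 + 1
        self.num_filters = num_filters
        self.dim = dim
        self.med_channels = int(expansion_ratio * dim)
        self.weight_resize = weight_resize
        self.pwconv1 = nn.Linear(dim, self.med_channels, bias=bias)
        self.act1 = act1_layer()
        # Predicts one mixing coefficient per (filter, channel) pair from pooled features.
        self.reweight = Mlp(dim, reweight_expansion_ratio, num_filters * self.med_channels)
        # Bank of learnable complex filters stored as (real, imag) pairs.
        self.complex_weights = nn.Parameter(
            torch.randn(self.size, self.filter_size, num_filters, 2,
                        dtype=torch.float32) * 0.02)
        self.act2 = act2_layer()
        self.pwconv2 = nn.Linear(self.med_channels, dim, bias=bias)

    def forward(self, x):
        B, H, W, _ = x.shape

        # Global average pool -> MLP -> softmax over the filter axis: (B, num_filters, C').
        routeing = self.reweight(x.mean(dim=(1, 2))).view(B, self.num_filters,
                                                          -1).softmax(dim=1)
        x = self.pwconv1(x)
        x = self.act1(x)
        x = x.to(torch.float32)  # the FFT is computed in float32
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')

        if self.weight_resize:
            complex_weights = resize_complex_weight(self.complex_weights, x.shape[1],
                                                    x.shape[2])
            complex_weights = torch.view_as_complex(complex_weights.contiguous())
        else:
            complex_weights = torch.view_as_complex(self.complex_weights)
        routeing = routeing.to(torch.complex64)
        # Per-sample, per-channel filter: weighted sum of the filter bank.
        weight = torch.einsum('bfc,hwf->bhwc', routeing, complex_weights)
        if self.weight_resize:
            weight = weight.view(-1, x.shape[1], x.shape[2], self.med_channels)
        else:
            weight = weight.view(-1, self.size, self.filter_size, self.med_channels)
        x = x * weight  # element-wise filtering in the frequency domain
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')

        x = self.act2(x)
        x = self.pwconv2(x)
        return x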
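

# Illustrative sketch, not part of the original module. It isolates the routing
# step of DynamicFilter.forward with random stand-ins for the learned tensors, to
# show the shapes involved: a softmax over the filter axis mixes a bank of complex
# filters into one filter per sample and channel, which then multiplies the rfft2
# spectrum. The function name, argument names, and default sizes are hypothetical.
import torch


def _demo_dynamic_filter_routing(batch=2, height=8, width=8, channels=3, num_filters=4):
    filter_w = width // 2 + 1  # width of the half spectrum returned by rfft2

    # Stand-in for the Mlp output; the softmax makes coefficients sum to 1 per channel.
    routing = torch.randn(batch, num_filters, channels).softmax(dim=1)

    # Stand-in for the learnable filter bank, stored as (real, imag) pairs.
    bank = torch.view_as_complex(torch.randn(height, filter_w, num_filters, 2))

    # Combine the bank into one complex filter per sample and channel.
    weight = torch.einsum('bfc,hwf->bhwc', routing.to(torch.complex64), bank)

    x = torch.rand(batch, height, width, channels)
    spectrum = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
    y = torch.fft.irfft2(spectrum * weight, s=(height, width), dim=(1, 2), norm='ortho')
    return routing, weight, y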
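

if __name__ == '__main__':
    # `size` must match the spatial size (H, W) of the input unless weight_resize=True.
    block = DynamicFilter(32, size=64)
    inputs = torch.rand(3, 64, 64, 32)  # channels-last layout: (B, H, W, C)
    output = block(inputs)
    print(inputs.size())
    print(output.size())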