# pfcfuse/componets/WTConvCV2.py
# (198 lines, 7.3 KiB, Python)
#
# Repository-viewer notice: "This file contains ambiguous Unicode characters
# that might be confused with other characters. If you think that this is
# intentional, you can safely ignore this warning."

import pywt
import pywt.data
import torch
from torch import nn
from torch.autograd import Function
import torch.nn.functional as F
"""WTConv: Wavelet Convolutions for Large Receptive Fields, ECCV2024 https://arxiv.org/abs/2407.05848
近年来,人们尝试增大卷积神经网络 (CNN) 的卷积核尺寸,以模拟 Vision Transformers (ViTs) 自注意力模块的全局感受野。
然而,这种方法在实现全局感受野之前就很快达到上限并饱和。
在这项研究中,我们证明了通过利用小波变换 (WT),实际上可以获得非常大的感受野,而不会遭受过度参数化。所提出的层名为 WTConv,
可直接替换现有架构中的卷积层,产生有效的多频响应,并可随感受野的大小优雅地扩展。
我们证明了 WTConv 层在 ConvNeXt 和 MobileNetV2 架构中用于图像分类以及作为下游任务主干网络时的有效性,并表明它具有其他属性,
例如对图像损坏的鲁棒性,以及对形状(相对于纹理)的响应增强。
"""
def create_wavelet_filter(wave, in_size, out_size, type=torch.float):
    """Build depthwise 2-D wavelet analysis and synthesis filter banks.

    Each bank holds four separable 2-D kernels: outer products of the 1-D
    low-pass (``lo``) and high-pass (``hi``) filters of the PyWavelets wavelet
    ``wave``. They are stacked in the order (lo, lo), (lo, hi), (hi, lo),
    (hi, hi). In each pair, the first filter varies along the width (last
    axis) and the second along the height. The four kernels are then tiled
    once per channel so the bank can serve as a grouped-convolution weight.

    Args:
        wave: PyWavelets wavelet name, e.g. ``'db1'``.
        in_size: number of channels the analysis bank is tiled for.
        out_size: number of channels the synthesis bank is tiled for.
        type: torch dtype of the returned filters. The parameter name shadows
            the builtin but is kept for caller compatibility.

    Returns:
        ``(dec_filters, rec_filters)`` with shapes ``(4 * in_size, 1, K, K)``
        and ``(4 * out_size, 1, K, K)``, where ``K`` is the length of the
        wavelet's 1-D filters.
    """
    wavelet = pywt.Wavelet(wave)

    def _tiled_bank(lo, hi, copies):
        # (width_filter, height_filter) pairs; kernel[i, j] = height[i] * width[j].
        pairs = ((lo, lo), (lo, hi), (hi, lo), (hi, hi))
        bank = torch.stack([torch.outer(height, width) for width, height in pairs], dim=0)
        # (4, K, K) -> (4 * copies, 1, K, K): one copy of the four kernels per channel.
        return bank.unsqueeze(1).repeat(copies, 1, 1, 1)

    # Analysis filters are reversed because F.conv2d computes a cross-correlation.
    dec_lo = torch.tensor(wavelet.dec_lo[::-1], dtype=type)
    dec_hi = torch.tensor(wavelet.dec_hi[::-1], dtype=type)
    # Synthesis filters are used in PyWavelets order (a reversal followed by a
    # flip cancels out, so no reordering is applied).
    rec_lo = torch.tensor(wavelet.rec_lo, dtype=type)
    rec_hi = torch.tensor(wavelet.rec_hi, dtype=type)

    return _tiled_bank(dec_lo, dec_hi, in_size), _tiled_bank(rec_lo, rec_hi, out_size)
def wavelet_transform(x, filters):
    """Apply one level of 2-D wavelet analysis to every channel of ``x``.

    Args:
        x: tensor of shape ``(B, C, H, W)``. The output is reshaped using
            ``H // 2`` and ``W // 2``, so H and W are expected to be even.
        filters: analysis bank of shape ``(4 * C, 1, kH, kW)``. Each group of
            four consecutive kernels acts on one input channel.

    Returns:
        Tensor of shape ``(B, C, 4, H // 2, W // 2)`` holding the four
        sub-band responses of each channel.
    """
    batch, channels, height, width = x.shape
    kernel_h, kernel_w = filters.shape[2:4]
    # Padding of k // 2 - 1 per side (0 for 2-tap db1/Haar kernels).
    padding = (kernel_h // 2 - 1, kernel_w // 2 - 1)
    coeffs = F.conv2d(x, filters, stride=2, groups=channels, padding=padding)
    return coeffs.reshape(batch, channels, 4, height // 2, width // 2)
def inverse_wavelet_transform(x, filters):
    """Apply one level of 2-D wavelet synthesis to per-channel sub-bands.

    Args:
        x: tensor of shape ``(B, C, 4, h, w)``; the four sub-bands of each of
            the ``C`` channels.
        filters: synthesis bank of shape ``(4 * C, 1, kH, kW)``. Each group of
            four consecutive kernels reconstructs one output channel.

    Returns:
        Tensor of shape ``(B, C, H, W)`` produced by a stride-2 grouped
        transposed convolution. For 2-tap kernels, ``H = 2 * h`` and ``W = 2 * w``.
    """
    batch, channels, _, half_h, half_w = x.shape
    kernel_h, kernel_w = filters.shape[2], filters.shape[3]
    # Fold the sub-band axis into channels: (B, C, 4, h, w) -> (B, 4C, h, w).
    stacked = x.reshape(batch, channels * 4, half_h, half_w)
    return F.conv_transpose2d(
        stacked,
        filters,
        stride=2,
        groups=channels,
        padding=(kernel_h // 2 - 1, kernel_w // 2 - 1),
    )
def wavelet_transform_init(filters):
    """Create an autograd-aware forward wavelet transform bound to ``filters``.

    In the forward pass, the returned callable computes
    ``wavelet_transform(x, filters)``. In the backward pass it propagates the
    gradient through ``inverse_wavelet_transform(grad, filters)``, a transposed
    convolution with the same kernels, stride and padding. When no output
    padding is needed, this equals the gradient of the strided grouped
    ``conv2d`` with respect to its input. ``filters`` is captured by closure
    and never receives a gradient.

    Args:
        filters: analysis filter bank passed to ``wavelet_transform``.

    Returns:
        A callable ``f(x)`` computing ``wavelet_transform(x, filters)``.
    """
    class WaveletTransform(Function):
        @staticmethod
        def forward(ctx, input):
            with torch.no_grad():
                x = wavelet_transform(input, filters)
            return x

        @staticmethod
        def backward(ctx, grad_output):
            # forward() has a single tensor argument, so exactly one gradient is returned.
            return inverse_wavelet_transform(grad_output, filters)

    # ``apply`` is a classmethod and is called on the class itself;
    # instantiating an autograd Function is deprecated in recent PyTorch.
    return WaveletTransform.apply
def inverse_wavelet_transform_init(filters):
    """Create an autograd-aware inverse wavelet transform bound to ``filters``.

    In the forward pass, the returned callable computes
    ``inverse_wavelet_transform(x, filters)``, a stride-2 grouped transposed
    convolution. In the backward pass it propagates the gradient through
    ``wavelet_transform(grad, filters)``, a strided grouped ``conv2d`` with the
    same kernels, which is the gradient of the transposed convolution with
    respect to its input. ``filters`` is captured by closure and never receives
    a gradient.

    Args:
        filters: synthesis filter bank passed to ``inverse_wavelet_transform``.

    Returns:
        A callable ``f(x)`` computing ``inverse_wavelet_transform(x, filters)``.
    """
    class InverseWaveletTransform(Function):
        @staticmethod
        def forward(ctx, input):
            with torch.no_grad():
                x = inverse_wavelet_transform(input, filters)
            return x

        @staticmethod
        def backward(ctx, grad_output):
            # forward() has a single tensor argument, so exactly one gradient is returned.
            return wavelet_transform(grad_output, filters)

    # ``apply`` is a classmethod and is called on the class itself;
    # instantiating an autograd Function is deprecated in recent PyTorch.
    return InverseWaveletTransform.apply
class WTConv2d(nn.Module):
    """Depthwise convolution with a wavelet-domain multi-scale branch (WTConv).

    The input is decomposed with ``wt_levels`` levels of a 2-D wavelet
    transform. At every level, all four sub-bands are filtered by one depthwise
    convolution and rescaled. The results are then recombined with the inverse
    transform, starting from the coarsest level. The reconstruction is added to
    a scaled depthwise convolution of the input, and the sum is subsampled by
    ``stride`` (keeping every ``stride``-th row and column) when ``stride > 1``.

    Args:
        in_channels: number of input channels.
        out_channels: must equal ``in_channels`` (every convolution is depthwise).
        kernel_size: kernel size of the base and per-level depthwise convolutions.
        stride: output subsampling factor.
        bias: whether the base convolution has a bias.
        wt_levels: number of wavelet decomposition levels; ``0`` disables the
            wavelet branch.
        wt_type: PyWavelets wavelet name, e.g. ``'db1'``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, bias=True, wt_levels=1, wt_type='db1'):
        super(WTConv2d, self).__init__()

        assert in_channels == out_channels, 'WTConv2d requires in_channels == out_channels'

        self.in_channels = in_channels
        self.wt_levels = wt_levels
        self.stride = stride
        self.dilation = 1

        # Fixed (non-trainable) analysis / synthesis filter banks.
        self.wt_filter, self.iwt_filter = create_wavelet_filter(wt_type, in_channels, in_channels, torch.float)
        self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False)
        self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False)

        # NOTE(review): these callables close over *this* instance's filter
        # Parameters. A copy of the module that reuses its ``__dict__`` entries
        # keeps using the original filter tensors. Verify before running copies
        # or replicas on other devices.
        self.wt_function = wavelet_transform_init(self.wt_filter)
        self.iwt_function = inverse_wavelet_transform_init(self.iwt_filter)

        # Spatial-domain depthwise branch.
        self.base_conv = nn.Conv2d(in_channels, in_channels, kernel_size, padding='same', stride=1, dilation=1,
                                   groups=in_channels, bias=bias)
        self.base_scale = _ScaleModule([1, in_channels, 1, 1])

        # Per level: one depthwise conv + scale over the C * 4 sub-band channels.
        self.wavelet_convs = nn.ModuleList(
            [nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', stride=1, dilation=1,
                       groups=in_channels * 4, bias=False) for _ in range(self.wt_levels)]
        )
        self.wavelet_scale = nn.ModuleList(
            [_ScaleModule([1, in_channels * 4, 1, 1], init_scale=0.1) for _ in range(self.wt_levels)]
        )

        if self.stride > 1:
            # 1x1 unit kernel: convolving with stride s keeps every s-th row/column.
            self.stride_filter = nn.Parameter(torch.ones(in_channels, 1, 1, 1), requires_grad=False)
            # Kept for backward compatibility. forward() calls _stride_subsample,
            # which reads the filter from ``self`` at call time instead of from
            # a closure over the constructing instance.
            self.do_stride = self._stride_subsample
        else:
            self.do_stride = None

    def _stride_subsample(self, x):
        """Keep every ``self.stride``-th row and column of ``x`` (depthwise 1x1 conv)."""
        return F.conv2d(x, self.stride_filter, bias=None, stride=self.stride, groups=self.in_channels)

    def forward(self, x):
        x_ll_in_levels = []
        x_h_in_levels = []
        shapes_in_levels = []

        # Analysis: repeatedly decompose the low-pass band.
        curr_x_ll = x
        for i in range(self.wt_levels):
            curr_shape = curr_x_ll.shape
            shapes_in_levels.append(curr_shape)
            # The stride-2 transform needs even sizes: pad one zero row/column
            # at the bottom/right when odd (cropped again during synthesis).
            if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):
                curr_pads = (0, curr_shape[3] % 2, 0, curr_shape[2] % 2)
                curr_x_ll = F.pad(curr_x_ll, curr_pads)

            curr_x = self.wt_function(curr_x_ll)
            # The *unfiltered* low-pass band feeds the next level.
            curr_x_ll = curr_x[:, :, 0, :, :]

            # Filter all four sub-bands with one depthwise conv over C * 4 channels.
            shape_x = curr_x.shape
            curr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])
            curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag))
            curr_x_tag = curr_x_tag.reshape(shape_x)

            x_ll_in_levels.append(curr_x_tag[:, :, 0, :, :])
            x_h_in_levels.append(curr_x_tag[:, :, 1:4, :, :])

        # Synthesis: coarsest level first; each reconstruction is added to the
        # filtered low-pass band of the next finer level.
        next_x_ll = 0
        for _ in range(self.wt_levels):
            curr_x_ll = x_ll_in_levels.pop()
            curr_x_h = x_h_in_levels.pop()
            curr_shape = shapes_in_levels.pop()

            curr_x_ll = curr_x_ll + next_x_ll
            curr_x = torch.cat([curr_x_ll.unsqueeze(2), curr_x_h], dim=2)
            next_x_ll = self.iwt_function(curr_x)
            # Undo the odd-size padding added during analysis.
            next_x_ll = next_x_ll[:, :, :curr_shape[2], :curr_shape[3]]

        # Assigned after the loop so that wt_levels == 0 contributes zero
        # instead of raising UnboundLocalError.
        x_tag = next_x_ll
        assert len(x_ll_in_levels) == 0

        x = self.base_scale(self.base_conv(x))
        x = x + x_tag

        if self.stride > 1:
            x = self._stride_subsample(x)

        return x
class _ScaleModule(nn.Module):
    """Multiply the input by a learnable tensor of fixed shape ``dims``.

    Args:
        dims: shape of the learnable weight. It is broadcast against the input,
            e.g. ``[1, C, 1, 1]`` gives one scale per channel of an NCHW tensor.
        init_scale: initial value of every weight entry.
        init_bias: accepted but unused. The module has no bias term
            (``self.bias`` is always ``None``).
    """

    def __init__(self, dims, init_scale=1.0, init_bias=0):
        super().__init__()
        self.dims = dims
        # float32 ones times init_scale (an int init_scale still yields float32).
        self.weight = nn.Parameter(init_scale * torch.ones(*dims))
        self.bias = None

    def forward(self, x):
        return self.weight * x
if __name__ == '__main__':
    # Smoke test: run a 3-channel WTConv2d on one random 64x64 image and
    # print the input and output shapes.
    channels = 3
    demo_block = WTConv2d(channels, channels)
    sample = torch.rand(1, channels, 64, 64)
    result = demo_block(sample)
    print(sample.size())
    print(result.size())