pfcfuse/componets/TIAM.py
zjut 927d049f13 feat(components): 添加多种特征增强模块
- 新增 DEConv、DynamicFilter、SCSA、SEBlock、SMFA、TIAM 和 UFFC 模块
- 这些模块提供了不同的特征增强功能,如卷积差分、频域滤波、注意力机制等
- 可以根据需求选择合适的模块来提升模型性能
2024-11-15 09:18:34 +08:00

111 lines
5.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
"""Elsevier2024
变化检测 (CD) 是地球观测中一种重要的监测方法,尤其适用于土地利用分析、城市管理和灾害损失评估。然而,在星座互联和空天协作时代,感兴趣区域 (ROI) 的变化由于几何透视旋转和时间风格差异而导致许多错误检测。
为了应对这些挑战,我们引入了 CDNeXt该框架阐明了一种稳健而有效的方法用于将基于预训练主干的 Siamese 网络与用于遥感图像的创新时空交互注意模块 (TIAM) 相结合。
CDNeXt 可分为四个主要组件:编码器、交互器、解码器和检测器。值得注意的是,由 TIAM 提供支持的交互器从编码器提取的二进制时间特征中查询和重建空间透视依赖关系和时间风格相关性,以扩大 ROI 变化的差异。
最后,检测器集成解码器生成的分层特征,随后生成二进制变化掩码。
"""
class SpatiotemporalAttentionFullNotWeightShared(nn.Module):
def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=False):
super(SpatiotemporalAttentionFullNotWeightShared, self).__init__()
assert dimension in [2, ]
self.dimension = dimension
self.sub_sample = sub_sample
self.in_channels = in_channels
self.inter_channels = inter_channels
if self.inter_channels is None:
self.inter_channels = in_channels // 2
if self.inter_channels == 0:
self.inter_channels = 1
self.g1 = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0)
)
self.g2 = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0),
)
self.W1 = nn.Sequential(
nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(self.in_channels)
)
self.W2 = nn.Sequential(
nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(self.in_channels)
)
self.theta = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0),
)
self.phi = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0),
)
def forward(self, x1, x2):
"""
:param x: (b, c, h, w)
:param return_nl_map: if True return z, nl_map, else only return z.
:return:
"""
batch_size = x1.size(0)
g_x11 = self.g1(x1).reshape(batch_size, self.inter_channels, -1)
g_x12 = g_x11.permute(0, 2, 1)
g_x21 = self.g2(x2).reshape(batch_size, self.inter_channels, -1)
g_x22 = g_x21.permute(0, 2, 1)
theta_x1 = self.theta(x1).reshape(batch_size, self.inter_channels, -1)
theta_x2 = theta_x1.permute(0, 2, 1)
phi_x1 = self.phi(x2).reshape(batch_size, self.inter_channels, -1)
phi_x2 = phi_x1.permute(0, 2, 1)
energy_time_1 = torch.matmul(theta_x1, phi_x2)
energy_time_2 = energy_time_1.permute(0, 2, 1)
energy_space_1 = torch.matmul(theta_x2, phi_x1)
energy_space_2 = energy_space_1.permute(0, 2, 1)
energy_time_1s = F.softmax(energy_time_1, dim=-1)
energy_time_2s = F.softmax(energy_time_2, dim=-1)
energy_space_2s = F.softmax(energy_space_1, dim=-2)
energy_space_1s = F.softmax(energy_space_2, dim=-2)
# C1*S(C2) energy_time_1s * C1*H1W1 g_x12 * energy_space_1s S(H2W2)*H1W1 -> C1*H1W1
y1 = torch.matmul(torch.matmul(energy_time_2s, g_x11), energy_space_2s).contiguous() # C2*H2W2
# C2*S(C1) energy_time_2s * C2*H2W2 g_x21 * energy_space_2s S(H1W1)*H2W2 -> C2*H2W2
y2 = torch.matmul(torch.matmul(energy_time_1s, g_x21), energy_space_1s).contiguous() # C1*H1W1
y1 = y1.reshape(batch_size, self.inter_channels, *x2.size()[2:])
y2 = y2.reshape(batch_size, self.inter_channels, *x1.size()[2:])
return x1 + self.W1(y1), x2 + self.W2(y2)
if __name__ == '__main__':
in_channels = 64
batch_size = 8
height = 32
width = 32
block = SpatiotemporalAttentionFullNotWeightShared(in_channels=in_channels)
input1 = torch.rand(batch_size, in_channels, height, width)
input2 = torch.rand(batch_size, in_channels, height, width)
output1, output2 = block(input1, input2)
print(f"Input1 size: {input1.size()}")
print(f"Input2 size: {input2.size()}")
print(f"Output1 size: {output1.size()}")
print(f"Output2 size: {output2.size()}")