111 lines
5.1 KiB
Python
111 lines
5.1 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
"""Elsevier2024
|
||
变化检测 (CD) 是地球观测中一种重要的监测方法,尤其适用于土地利用分析、城市管理和灾害损失评估。然而,在星座互联和空天协同的时代,几何透视旋转和时相风格差异会导致感兴趣区域 (ROI) 的变化出现大量误检。
|
||
为了应对这些挑战,我们提出了 CDNeXt。该框架给出了一种稳健而有效的方法,将基于预训练主干网络的孪生 (Siamese) 网络与面向遥感图像的创新时空交互注意模块 (TIAM) 相结合。
|
||
CDNeXt 可分为四个主要组件:编码器、交互器、解码器和检测器。值得注意的是,由 TIAM 驱动的交互器从编码器提取的双时相特征中查询并重建空间透视依赖关系和时间风格相关性,以放大 ROI 变化的差异。
|
||
最后,检测器集成解码器生成的分层特征,随后生成二进制变化掩码。
|
||
"""
|
||
|
||
class SpatiotemporalAttentionFullNotWeightShared(nn.Module):
    """Spatio-temporal cross attention between two feature maps (not weight-shared).

    The two inputs are projected by separate branches: ``theta`` and ``g1`` read
    ``x1``; ``phi`` and ``g2`` read ``x2``. From the ``theta``/``phi`` projections
    two affinity maps are built:

    * a channel-by-channel affinity of shape ``(B, C', C')`` (``energy_time``), and
    * a position-by-position affinity of shape ``(B, N1, N2)`` (``energy_space``),

    where ``C'`` is ``inter_channels`` and ``N = H * W``. Each input's value
    projection (``g1``/``g2``) is re-mixed through both softmax-normalised
    affinities, mapped back to ``in_channels`` by ``W1``/``W2`` (1x1 conv + BN)
    and added residually to the input.

    Args:
        in_channels: Number of channels ``C`` of both inputs.
        inter_channels: Width ``C'`` of the internal projections. When ``None``
            it defaults to ``in_channels // 2``, bumped to 1 if that is 0.
        dimension: Spatial dimensionality; only ``2`` is supported.
        sub_sample: Stored on the instance for signature compatibility; it is
            not used by ``forward``.

    Raises:
        ValueError: If ``dimension`` is not 2.
    """

    def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=False):
        super().__init__()
        # Explicit check instead of ``assert`` so it also holds under ``python -O``.
        if dimension != 2:
            raise ValueError(f"only dimension=2 is supported, got {dimension!r}")
        self.dimension = dimension
        self.sub_sample = sub_sample  # stored only; forward() does not read it
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        if self.inter_channels is None:
            # Bottleneck to half the width, but never down to zero channels.
            self.inter_channels = max(in_channels // 2, 1)

        # Creation order and the (BN, Conv) / (Conv, BN) layouts are kept as in
        # the original code so state_dict keys (e.g. "g1.0.weight",
        # "W1.1.running_mean") and seeded initialisation stay checkpoint-compatible.
        self.g1 = self._norm_then_proj(self.in_channels, self.inter_channels)
        self.g2 = self._norm_then_proj(self.in_channels, self.inter_channels)
        self.W1 = self._proj_then_norm(self.inter_channels, self.in_channels)
        self.W2 = self._proj_then_norm(self.inter_channels, self.in_channels)
        self.theta = self._norm_then_proj(self.in_channels, self.inter_channels)
        self.phi = self._norm_then_proj(self.in_channels, self.inter_channels)

    @staticmethod
    def _norm_then_proj(c_in, c_out):
        """Return ``Sequential(BatchNorm2d(c_in), 1x1 Conv2d(c_in -> c_out))``."""
        return nn.Sequential(
            nn.BatchNorm2d(c_in),
            nn.Conv2d(in_channels=c_in, out_channels=c_out,
                      kernel_size=1, stride=1, padding=0),
        )

    @staticmethod
    def _proj_then_norm(c_in, c_out):
        """Return ``Sequential(1x1 Conv2d(c_in -> c_out), BatchNorm2d(c_out))``."""
        return nn.Sequential(
            nn.Conv2d(in_channels=c_in, out_channels=c_out,
                      kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(c_out),
        )

    def forward(self, x1, x2):
        """Cross-attend two feature maps of identical shape.

        Args:
            x1: First feature map, shape ``(B, C, H, W)`` with ``C == in_channels``.
            x2: Second feature map with exactly the same shape as ``x1``.

        Returns:
            Tuple ``(x1 + W1(y1), x2 + W2(y2))``; each element has the inputs' shape.

        Raises:
            ValueError: If ``x1`` and ``x2`` do not have the same shape.
        """
        # y1 is laid out on x2's grid but added to x1 (and y2 vice versa), so the
        # shapes must match; otherwise some mismatches would silently broadcast.
        if x1.shape != x2.shape:
            raise ValueError(
                "x1 and x2 must have the same shape, got "
                f"{tuple(x1.shape)} and {tuple(x2.shape)}"
            )
        batch_size = x1.size(0)
        c_inter = self.inter_channels

        # Project and flatten the spatial dims: each is (B, C', H*W).
        g_x1 = self.g1(x1).reshape(batch_size, c_inter, -1)
        g_x2 = self.g2(x2).reshape(batch_size, c_inter, -1)
        theta_x1 = self.theta(x1).reshape(batch_size, c_inter, -1)
        phi_x2 = self.phi(x2).reshape(batch_size, c_inter, -1)

        # Channel affinity (B, C'_x1, C'_x2) and spatial affinity (B, N1, N2).
        energy_time = torch.matmul(theta_x1, phi_x2.permute(0, 2, 1))
        energy_space = torch.matmul(theta_x1.permute(0, 2, 1), phi_x2)

        # Channel attention: rows indexed by x1 channels, normalised over x2 channels.
        attn_time_12 = F.softmax(energy_time, dim=-1)
        # Channel attention: rows indexed by x2 channels, normalised over x1 channels.
        attn_time_21 = F.softmax(energy_time.permute(0, 2, 1), dim=-1)
        # Spatial attention (B, N1, N2), normalised over x1 positions (dim -2).
        attn_space_12 = F.softmax(energy_space, dim=-2)
        # Spatial attention (B, N2, N1), normalised over x2 positions (dim -2).
        attn_space_21 = F.softmax(energy_space.permute(0, 2, 1), dim=-2)

        # NOTE: as in the reference implementation, y1 mixes x1's own values (g1)
        # onto x2's channels/positions and is added back to x1; y2 is the mirror.
        # (B, C', C') @ (B, C', N1) @ (B, N1, N2) -> (B, C', N2)
        y1 = torch.matmul(torch.matmul(attn_time_21, g_x1), attn_space_12)
        # (B, C', C') @ (B, C', N2) @ (B, N2, N1) -> (B, C', N1)
        y2 = torch.matmul(torch.matmul(attn_time_12, g_x2), attn_space_21)

        y1 = y1.reshape(batch_size, c_inter, *x2.size()[2:])
        y2 = y2.reshape(batch_size, c_inter, *x1.size()[2:])
        return x1 + self.W1(y1), x2 + self.W2(y2)
|
||
|
||
|
||
if __name__ == '__main__':
    # Smoke test: push two random feature maps of the same shape through the
    # block and report the tensor sizes on either side.
    channels, batch, h, w = 64, 8, 32, 32
    module = SpatiotemporalAttentionFullNotWeightShared(in_channels=channels)

    # Drawn in order: first input, then second input.
    inputs = [torch.rand(batch, channels, h, w) for _ in range(2)]
    outputs = module(*inputs)

    for kind, tensors in (("Input", inputs), ("Output", outputs)):
        for idx, tensor in enumerate(tensors, start=1):
            print(f"{kind}{idx} size: {tensor.size()}")
|