pfcfuse/net.py
whaifree 85dc7a92ed feat(net): 添加时空注意力机制- 在 net.py 中引入 SpatiotemporalAttentionFullNotWeightShared 模块
- 在 Restormer_Decoder 类中添加时空注意力机制处理基础特征和细节特征
2024-10-26 19:18:01 +08:00

590 lines
22 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange
from componets.TIAM import SpatiotemporalAttentionFullNotWeightShared
from componets.WTConvCV2 import WTConv2d
# 以一定概率随机丢弃输入张量中的路径,用于正则化模型
def drop_path(x, drop_prob: float = 0., training: bool = False):
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
# work with diff dim tensors, not just 2D ConvNets
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + \
torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
# 改点使用Pooling替换AttentionBase
class Pooling(nn.Module):
def __init__(self, kernel_size=3):
super().__init__()
self.pool = nn.AvgPool2d(
kernel_size, stride=1, padding=kernel_size // 2)
def forward(self, x):
return self.pool(x) - x
class PoolMlp(nn.Module):
"""
实现基于1x1卷积的MLP模块。
输入:形状为[B, C, H, W]的张量。
"""
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
bias=False,
drop=0.):
"""
初始化PoolMlp模块。
参数:
in_features (int): 输入特征的数量。
hidden_features (int, 可选): 隐藏层特征的数量。默认为None设置为与in_features相同。
out_features (int, 可选): 输出特征的数量。默认为None设置为与in_features相同。
act_layer (nn.Module, 可选): 使用的激活层。默认为nn.GELU。
bias (bool, 可选): 是否在卷积层中包含偏置项。默认为False。
drop (float, 可选): Dropout比率。默认为0。
"""
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Conv2d(in_features, hidden_features, 1, bias=bias)
self.act = act_layer()
self.fc2 = nn.Conv2d(hidden_features, out_features, 1, bias=bias)
self.drop = nn.Dropout(drop)
def forward(self, x):
"""
通过PoolMlp模块的前向传播。
参数:
x (torch.Tensor): 形状为[B, C, H, W]的输入张量。
返回:
torch.Tensor: 形状为[B, C, H, W]的输出张量。
"""
x = self.fc1(x) # (B, C, H, W) --> (B, C, H, W)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x) # (B, C, H, W) --> (B, C, H, W)
x = self.drop(x)
return x
# class BaseFeatureExtraction1(nn.Module):
# def __init__(self, dim, pool_size=3, mlp_ratio=4.,
# act_layer=nn.GELU,
# # norm_layer=nn.LayerNorm,
# drop=0., drop_path=0.,
# use_layer_scale=True, layer_scale_init_value=1e-5):
#
# super().__init__()
#
# self.norm1 = LayerNorm(dim, 'WithBias')
# self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
# self.norm2 = LayerNorm(dim, 'WithBias')
# mlp_hidden_dim = int(dim * mlp_ratio)
# self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
# act_layer=act_layer, drop=drop)
#
# # The following two techniques are useful to train deep PoolFormers.
# self.drop_path = DropPath(drop_path) if drop_path > 0. \
# else nn.Identity()
# self.use_layer_scale = use_layer_scale
#
# if use_layer_scale:
# self.layer_scale_1 = nn.Parameter(
# torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
#
# self.layer_scale_2 = nn.Parameter(
# torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
#
# def forward(self, x): # 1 64 128 128
# if self.use_layer_scale:
# # self.layer_scale_1(64,)
# tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
# normal = self.norm1(x) # 1 64 128 128
# token_mix = self.token_mixer(normal) # 1 64 128 128
# x = (x +
# self.drop_path(
# tmp1 * token_mix
# )
# # 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
# )
# x = x + self.drop_path(
# self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
# * self.poolmlp(self.norm2(x)))
# else:
# x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
# x = x + self.drop_path(self.poolmlp(self.norm2(x)))
# return x
class BaseFeatureExtraction(nn.Module):
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
act_layer=nn.GELU,
# norm_layer=nn.LayerNorm,
drop=0., drop_path=0.,
use_layer_scale=True, layer_scale_init_value=1e-5):
super().__init__()
self.WTConv2d = WTConv2d(dim, dim)
self.norm1 = LayerNorm(dim, 'WithBias')
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
self.norm2 = LayerNorm(dim, 'WithBias')
mlp_hidden_dim = int(dim * mlp_ratio)
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)
# The following two techniques are useful to train deep PoolFormers.
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
self.layer_scale_2 = nn.Parameter(
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
def forward(self, x): # 1 64 128 128
if self.use_layer_scale:
# self.layer_scale_1(64,)
tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
normal = self.norm1(x) # 1 64 128 128
token_mix = self.token_mixer(normal) # 1 64 128 128
x = self.WTConv2d(x)
x = (x +
self.drop_path(
tmp1 * token_mix
)
# 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
)
x = x + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
* self.poolmlp(self.norm2(x)))
else:
x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
return x
class InvertedResidualBlock(nn.Module):
def __init__(self, inp, oup, expand_ratio):
super(InvertedResidualBlock, self).__init__()
hidden_dim = int(inp * expand_ratio)
self.bottleneckBlock = nn.Sequential(
# pw
nn.Conv2d(inp, hidden_dim, 1, bias=False),
# nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
# dw
nn.ReflectionPad2d(1),
nn.Conv2d(hidden_dim, hidden_dim, 3, groups=hidden_dim, bias=False),
# nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, bias=False),
# nn.BatchNorm2d(oup),
)
def forward(self, x):
return self.bottleneckBlock(x)
class DetailNode(nn.Module):
# <img src = "http://42.192.130.83:9000/picgo/imgs/小绿鲸英文文献阅读器_ELTITYqm5G.png" / > '
def __init__(self):
super(DetailNode, self).__init__()
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.shffleconv = nn.Conv2d(64, 64, kernel_size=1,
stride=1, padding=0, bias=True)
def separateFeature(self, x):
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
return z1, z2
def forward(self, z1, z2):
z1, z2 = self.separateFeature(
self.shffleconv(torch.cat((z1, z2), dim=1)))
z2 = z2 + self.theta_phi(z1)
z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2)
return z1, z2
class DetailFeatureExtraction(nn.Module):
def __init__(self, num_layers=3):
super(DetailFeatureExtraction, self).__init__()
INNmodules = [DetailNode() for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules)
self.enhancement_module = WTConv2d(32, 32)
def forward(self, x): # 1 64 128 128
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
# 增强并添加残差连接
enhanced_z1 = self.enhancement_module(z1)
enhanced_z2 = self.enhancement_module(z2)
# 残差连接
z1 = z1 + enhanced_z1
z2 = z2 + enhanced_z2
for layer in self.net:
z1, z2 = layer(z1, z2)
return torch.cat((z1, z2), dim=1)
# =============================================================================
# =============================================================================
import numbers
##########################################################################
## Layer Norm
def to_3d(x):
return rearrange(x, 'b c h w -> b (h w) c')
def to_4d(x, h, w):
return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
class BiasFree_LayerNorm(nn.Module):
def __init__(self, normalized_shape):
super(BiasFree_LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = torch.Size(normalized_shape)
assert len(normalized_shape) == 1
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.normalized_shape = normalized_shape
def forward(self, x):
sigma = x.var(-1, keepdim=True, unbiased=False)
return x / torch.sqrt(sigma + 1e-5) * self.weight
class WithBias_LayerNorm(nn.Module):
def __init__(self, normalized_shape):
super(WithBias_LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = torch.Size(normalized_shape)
assert len(normalized_shape) == 1
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.normalized_shape = normalized_shape
def forward(self, x):
mu = x.mean(-1, keepdim=True)
sigma = x.var(-1, keepdim=True, unbiased=False)
return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
class LayerNorm(nn.Module):
def __init__(self, dim, LayerNorm_type):
super(LayerNorm, self).__init__()
if LayerNorm_type == 'BiasFree':
self.body = BiasFree_LayerNorm(dim)
else:
self.body = WithBias_LayerNorm(dim)
def forward(self, x):
h, w = x.shape[-2:]
return to_4d(self.body(to_3d(x)), h, w)
##########################################################################
## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
def __init__(self, dim, ffn_expansion_factor, bias):
super(FeedForward, self).__init__()
hidden_features = int(dim * ffn_expansion_factor)
self.project_in = nn.Conv2d(
dim, hidden_features * 2, kernel_size=1, bias=bias)
self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3,
stride=1, padding=1, groups=hidden_features * 2, bias=bias)
self.project_out = nn.Conv2d(
hidden_features, dim, kernel_size=1, bias=bias)
def forward(self, x):
x = self.project_in(x)
x1, x2 = self.dwconv(x).chunk(2, dim=1)
x = F.gelu(x1) * x2
x = self.project_out(x)
return x
##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
def __init__(self, dim, num_heads, bias):
super(Attention, self).__init__()
self.num_heads = num_heads
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(
dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.qkv_dwconv(self.qkv(x))
q, k, v = qkv.chunk(3, dim=1)
q = rearrange(q, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
out = (attn @ v)
out = rearrange(out, 'b head c (h w) -> b (head c) h w',
head=self.num_heads, h=h, w=w)
out = self.project_out(out)
return out
##########################################################################
class TransformerBlock(nn.Module):
def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
super(TransformerBlock, self).__init__()
self.norm1 = LayerNorm(dim, LayerNorm_type)
self.attn = Attention(dim, num_heads, bias)
self.norm2 = LayerNorm(dim, LayerNorm_type)
self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.ffn(self.norm2(x))
return x
##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
def __init__(self, in_c=3, embed_dim=48, bias=False):
super(OverlapPatchEmbed, self).__init__()
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3,
stride=1, padding=1, bias=bias)
def forward(self, x):
x = self.proj(x)
return x
class BaseFeatureExtractionSAR(nn.Module):
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
act_layer=nn.GELU,
# norm_layer=nn.LayerNorm,
drop=0., drop_path=0.,
use_layer_scale=True, layer_scale_init_value=1e-5):
super().__init__()
self.WTConv2d = WTConv2d(dim, dim)
self.norm1 = LayerNorm(dim, 'WithBias')
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
self.norm2 = LayerNorm(dim, 'WithBias')
mlp_hidden_dim = int(dim * mlp_ratio)
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)
# The following two techniques are useful to train deep PoolFormers.
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
self.layer_scale_2 = nn.Parameter(
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
def forward(self, x): # 1 64 128 128
if self.use_layer_scale:
# self.layer_scale_1(64,)
tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
normal = self.norm1(x) # 1 64 128 128
token_mix = self.token_mixer(normal) # 1 64 128 128
x = self.WTConv2d(x)
x = (x +
self.drop_path(
tmp1 * token_mix
)
# 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
)
x = x + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
* self.poolmlp(self.norm2(x)))
else:
x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
return x
class DetailFeatureExtractionSAR(nn.Module):
def __init__(self, num_layers=3):
super(DetailFeatureExtractionSAR, self).__init__()
INNmodules = [DetailNode() for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules)
self.enhancement_module = WTConv2d(32, 32)
def forward(self, x): # 1 64 128 128
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
# 增强并添加残差连接
enhanced_z1 = self.enhancement_module(z1)
enhanced_z2 = self.enhancement_module(z2)
# 残差连接
z1 = z1 + enhanced_z1
z2 = z2 + enhanced_z2
for layer in self.net:
z1, z2 = layer(z1, z2)
return torch.cat((z1, z2), dim=1)
class Restormer_Encoder(nn.Module):
def __init__(self,
inp_channels=1,
out_channels=1,
dim=64,
num_blocks=[4, 4],
heads=[8, 8, 8],
ffn_expansion_factor=2,
bias=False,
LayerNorm_type='WithBias',
):
super(Restormer_Encoder, self).__init__()
# 区分
self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
self.encoder_level1 = nn.Sequential(
*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
self.baseFeature = BaseFeatureExtraction(dim=dim)
self.detailFeature = DetailFeatureExtraction()
self.baseFeature_sar = BaseFeatureExtractionSAR(dim=dim)
self.detailFeature_sar = DetailFeatureExtractionSAR()
def forward(self, inp_img,is_sar = False):
inp_enc_level1 = self.patch_embed(inp_img)
out_enc_level1 = self.encoder_level1(inp_enc_level1)
if is_sar:
base_feature = self.baseFeature_sar(out_enc_level1) # 1 64 128 128
detail_feature = self.detailFeature_sar(out_enc_level1) # 1 64 128 128
return base_feature, detail_feature, out_enc_level1 # 1 64 128 128
else:
base_feature = self.baseFeature(out_enc_level1) # 1 64 128 128
detail_feature = self.detailFeature(out_enc_level1) # 1 64 128 128
return base_feature, detail_feature, out_enc_level1 # 1 64 128 128
class Restormer_Decoder(nn.Module):
def __init__(self,
inp_channels=1,
out_channels=1,
dim=64,
num_blocks=[4, 4],
heads=[8, 8, 8],
ffn_expansion_factor=2,
bias=False,
LayerNorm_type='WithBias',
):
super(Restormer_Decoder, self).__init__()
self.reduce_channel = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias)
self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
self.output = nn.Sequential(
nn.Conv2d(int(dim), int(dim) // 2, kernel_size=3,
stride=1, padding=1, bias=bias),
nn.LeakyReLU(),
nn.Conv2d(int(dim) // 2, out_channels, kernel_size=3,
stride=1, padding=1, bias=bias), )
self.sigmoid = nn.Sigmoid()
self.spatiotemporalAttentionFullNotWeightShared = SpatiotemporalAttentionFullNotWeightShared(dim=dim)
def forward(self, inp_img, base_feature, detail_feature):
base_feature, detail_feature =self.spatiotemporalAttentionFullNotWeightShared(base_feature, detail_feature)
out_enc_level0 = torch.cat((base_feature, detail_feature), dim=1)
out_enc_level0 = self.reduce_channel(out_enc_level0)
out_enc_level1 = self.encoder_level2(out_enc_level0)
if inp_img is not None:
out_enc_level1 = self.output(out_enc_level1) + inp_img
else:
out_enc_level1 = self.output(out_enc_level1)
return self.sigmoid(out_enc_level1), out_enc_level0
if __name__ == '__main__':
height = 128
width = 128
window_size = 8
modelE = Restormer_Encoder().cuda()
modelD = Restormer_Decoder().cuda()
print(modelE)
print(modelD)