pfcfuse/net.py
zjut e8a0212bbb feat(net): 增加 SAR 图像处理支持
- 新增 BaseFeatureExtractionSAR 和 DetailFeatureExtractionSAR 模块
- 修改 DIDF_Encoder 类,支持 SAR 图像输入
- 更新测试和训练脚本,增加 SAR 图像处理相关逻辑
2024-11-16 12:59:07 +08:00

472 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Randomly zero whole samples of ``x`` (stochastic depth).

    One keep/drop decision is drawn per entry along dim 0 and broadcast over
    all remaining dims; kept samples are divided by ``1 - drop_prob``.

    Args:
        x: Input tensor; dim 0 is treated as the per-sample axis.
        drop_prob: Probability of dropping a sample. ``0.`` and ``None`` both
            disable dropping.
        training: Dropping is only applied when this is true.

    Returns:
        ``x`` itself when dropping is disabled, otherwise a new tensor with
        the same shape as ``x``.
    """
    # ``not drop_prob`` covers both 0. and None. DropPath defaults its
    # probability to None, which previously crashed on ``1 - None`` as soon
    # as the module was put in training mode.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    if keep_prob <= 0.:
        # Everything is dropped. Dividing by keep_prob here would produce
        # inf * 0 = NaN, so return zeros explicitly.
        return torch.zeros_like(x)
    # Works with tensors of any rank, not just 2D ConvNet feature maps.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    keep_mask = keep_prob + \
        torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    keep_mask.floor_()  # binarize: 1 with probability keep_prob, else 0
    return x.div(keep_prob) * keep_mask
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (stochastic depth per sample).

    Intended for the main path of residual blocks. Note that this class
    definition rebinds the name ``DropPath`` that is also imported from
    ``timm.models.layers`` at the top of the file, so this local version is
    the one used by the rest of the module.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        # Stored as given and forwarded unchanged to drop_path on every call.
        self.drop_prob = drop_prob

    def forward(self, x):
        # Dropping is tied to the module's train/eval state.
        is_training = self.training
        return drop_path(x, self.drop_prob, is_training)
class Pooling(nn.Module):
    """Average-pooling token mixer that returns ``avg_pool(x) - x``.

    The pooling uses stride 1 and ``kernel_size // 2`` padding, so for odd
    kernel sizes the output has the same spatial size as the input.
    """

    def __init__(self, kernel_size=3):
        super(Pooling, self).__init__()
        pad = kernel_size // 2
        self.pool = nn.AvgPool2d(kernel_size, stride=1, padding=pad)

    def forward(self, x):
        smoothed = self.pool(x)
        # Subtract the input so only the difference to the local mean remains.
        return smoothed - x
class PoolMlp(nn.Module):
    """
    Two-layer MLP built from 1x1 convolutions.
    Input: tensor with shape [B, C, H, W]

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when they are not given (or falsy). The same dropout module is applied
    after the activation and again after the second convolution.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 bias=False,
                 drop=0.):
        super(PoolMlp, self).__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features,
                             kernel_size=1, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features,
                             kernel_size=1, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # (B, C_in, H, W) --> (B, C_hidden, H, W)
        hidden = self.drop(self.act(self.fc1(x)))
        # (B, C_hidden, H, W) --> (B, C_out, H, W)
        return self.drop(self.fc2(hidden))
class BaseFeatureExtraction(nn.Module):
    """Residual block: norm -> pooling mixer, then norm -> 1x1-conv MLP.

    Both sub-layers are added back to the input through an optional
    ``DropPath`` and, when ``use_layer_scale`` is set, a learnable
    per-channel scale initialised to ``layer_scale_init_value``.

    Args:
        dim: Number of channels of the input feature map.
        pool_size: Kernel size of the pooling token mixer.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        act_layer: Activation class used inside the MLP.
        drop: Dropout probability inside the MLP.
        drop_path: Stochastic-depth probability; ``<= 0`` disables it.
        use_layer_scale: Whether to multiply each branch by a learnable scale.
        layer_scale_init_value: Initial value of those scales.
    """

    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                 act_layer=nn.GELU,
                 # norm_layer=nn.LayerNorm,
                 drop=0., drop_path=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):
        super(BaseFeatureExtraction, self).__init__()

        self.norm1 = LayerNorm(dim, 'WithBias')
        # ViTs use multi-head self-attention and MLP-like models use an MLP
        # as the token mixer; here pooling is used instead.
        self.token_mixer = Pooling(kernel_size=pool_size)
        self.norm2 = LayerNorm(dim, 'WithBias')
        self.poolmlp = PoolMlp(in_features=dim,
                               hidden_features=int(dim * mlp_ratio),
                               act_layer=act_layer,
                               drop=drop)

        # The following two techniques are useful to train deep PoolFormers.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
            self.layer_scale_2 = nn.Parameter(
                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)

    def forward(self, x):
        # Branch 1: pooling token mixer.
        mixed = self.token_mixer(self.norm1(x))
        if self.use_layer_scale:
            # (dim,) -> (dim, 1, 1) so the scale broadcasts over H and W.
            mixed = self.layer_scale_1[:, None, None] * mixed
        x = x + self.drop_path(mixed)

        # Branch 2: channel MLP.
        projected = self.poolmlp(self.norm2(x))
        if self.use_layer_scale:
            projected = self.layer_scale_2[:, None, None] * projected
        return x + self.drop_path(projected)
class BaseFeatureExtractionSAR(nn.Module):
    """Base-feature block reserved for the SAR input branch.

    The layers and forward computation are the same as in
    ``BaseFeatureExtraction`` in this file; it is a separate class so the SAR
    branch owns its own set of weights.

    Args:
        dim: Number of channels of the input feature map.
        pool_size: Kernel size of the pooling token mixer.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        act_layer: Activation class used inside the MLP.
        drop: Dropout probability inside the MLP.
        drop_path: Stochastic-depth probability; ``<= 0`` disables it.
        use_layer_scale: Whether to multiply each branch by a learnable scale.
        layer_scale_init_value: Initial value of those scales.
    """

    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                 act_layer=nn.GELU,
                 # norm_layer=nn.LayerNorm,
                 drop=0., drop_path=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):
        super(BaseFeatureExtractionSAR, self).__init__()

        self.norm1 = LayerNorm(dim, 'WithBias')
        # ViTs use multi-head self-attention and MLP-like models use an MLP
        # as the token mixer; here pooling is used instead.
        self.token_mixer = Pooling(kernel_size=pool_size)
        self.norm2 = LayerNorm(dim, 'WithBias')

        hidden_width = int(dim * mlp_ratio)
        self.poolmlp = PoolMlp(in_features=dim, hidden_features=hidden_width,
                               act_layer=act_layer, drop=drop)

        # The following two techniques are useful to train deep PoolFormers.
        stochastic_depth = drop_path > 0.
        self.drop_path = DropPath(drop_path) if stochastic_depth \
            else nn.Identity()

        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
            self.layer_scale_2 = nn.Parameter(
                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)

    def forward(self, x):
        if not self.use_layer_scale:
            # Plain residual form without learnable branch scales.
            x = x + self.drop_path(self.token_mixer(self.norm1(x)))
            return x + self.drop_path(self.poolmlp(self.norm2(x)))

        # Reshape (dim,) -> (dim, 1, 1) so the scales broadcast over H and W.
        gamma_mixer = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
        gamma_mlp = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
        x = x + self.drop_path(gamma_mixer * self.token_mixer(self.norm1(x)))
        return x + self.drop_path(gamma_mlp * self.poolmlp(self.norm2(x)))
class InvertedResidualBlock(nn.Module):
    """Pointwise-expand -> depthwise 3x3 -> pointwise-project conv stack.

    All convolutions are bias-free and no normalisation is used. The
    depthwise convolution is preceded by a 1-pixel reflection pad, so the
    spatial size is preserved. Despite the name, ``forward`` does not add a
    skip connection; callers combine the output themselves.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        expand_ratio: Hidden width as a multiple of ``inp``.
    """

    def __init__(self, inp, oup, expand_ratio):
        super().__init__()
        hidden_dim = int(inp * expand_ratio)
        # Layer order is part of the state-dict keys (indices 0, 3 and 5
        # carry weights), so it must not change.
        stages = [
            # pw
            nn.Conv2d(inp, hidden_dim, 1, bias=False),
            # nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
            # dw
            nn.ReflectionPad2d(1),
            nn.Conv2d(hidden_dim, hidden_dim, 3,
                      groups=hidden_dim, bias=False),
            # nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, bias=False),
            # nn.BatchNorm2d(oup),
        ]
        self.bottleneckBlock = nn.Sequential(*stages)

    def forward(self, x):
        out = self.bottleneckBlock(x)
        return out
class DetailNode(nn.Module):
    """One coupling step operating on two channel halves ``(z1, z2)``.

    The halves are concatenated, mixed by a 1x1 convolution, split again and
    then updated as::

        z2 <- z2 + phi(z1)
        z1 <- z1 * exp(rho(z2)) + eta(z2)

    Channel counts are fixed: each half has 32 channels (64 after concat).
    """

    def __init__(self):
        super().__init__()
        # Creation order (phi, rho, eta, shuffle conv) is kept so that
        # seeded initialisation matches the original.
        for branch in ('theta_phi', 'theta_rho', 'theta_eta'):
            setattr(self, branch,
                    InvertedResidualBlock(inp=32, oup=32, expand_ratio=2))
        self.shffleconv = nn.Conv2d(64, 64, kernel_size=1,
                                    stride=1, padding=0, bias=True)

    def separateFeature(self, x):
        """Split ``x`` along dim 1 at ``channels // 2`` into two slices."""
        half = x.shape[1] // 2
        return x[:, :half], x[:, half:]

    def forward(self, z1, z2):
        mixed = self.shffleconv(torch.cat((z1, z2), dim=1))
        z1, z2 = self.separateFeature(mixed)
        # Additive update of z2 from z1 ...
        z2 = z2 + self.theta_phi(z1)
        # ... then affine update of z1 from the *updated* z2.
        scale = torch.exp(self.theta_rho(z2))
        shift = self.theta_eta(z2)
        z1 = z1 * scale + shift
        return z1, z2
class DetailFeatureExtraction(nn.Module):
    """Stack of ``DetailNode`` coupling steps applied to a feature map.

    The input is split into two halves along dim 1, every node updates the
    pair in turn, and the halves are concatenated back along dim 1.

    Args:
        num_layers: Number of ``DetailNode`` steps in the stack.
    """

    def __init__(self, num_layers=3):
        super().__init__()
        self.net = nn.Sequential(*(DetailNode() for _ in range(num_layers)))

    def forward(self, x):
        half = x.shape[1] // 2
        z1 = x[:, :half]
        z2 = x[:, half:]
        # Nodes take two tensors, so iterate the container manually instead
        # of calling it like a regular Sequential.
        for node in self.net:
            z1, z2 = node(z1, z2)
        return torch.cat((z1, z2), dim=1)
class DetailFeatureExtractionSAR(nn.Module):
    """Detail-feature stack reserved for the SAR input branch.

    Same structure as ``DetailFeatureExtraction`` in this file (a sequence of
    ``DetailNode`` steps over two channel halves), kept as a separate class
    so the SAR branch has its own weights.

    Args:
        num_layers: Number of ``DetailNode`` steps in the stack.
    """

    def __init__(self, num_layers=3):
        super().__init__()
        nodes = []
        for _ in range(num_layers):
            nodes.append(DetailNode())
        self.net = nn.Sequential(*nodes)

    def forward(self, x):
        channels = x.shape[1]
        split_at = channels // 2
        # narrow() yields the same views as slicing [:split_at] / [split_at:].
        z1 = x.narrow(1, 0, split_at)
        z2 = x.narrow(1, split_at, channels - split_at)
        for node in self.net:
            z1, z2 = node(z1, z2)
        return torch.cat((z1, z2), dim=1)
# =============================================================================
# =============================================================================
import numbers
##########################################################################
## Layer Norm
def to_3d(x):
    """Rearrange a ``(b, c, h, w)`` tensor into ``(b, h*w, c)``."""
    pattern = 'b c h w -> b (h w) c'
    return rearrange(x, pattern)
def to_4d(x, h, w):
    """Rearrange a ``(b, h*w, c)`` tensor back into ``(b, c, h, w)``."""
    pattern = 'b (h w) c -> b c h w'
    return rearrange(x, pattern, h=h, w=w)
class BiasFree_LayerNorm(nn.Module):
    """Scale-only normalisation over the last dimension.

    The input is divided by ``sqrt(var + 1e-5)`` (biased variance over the
    last dim) and multiplied by a learnable weight. The mean is *not*
    subtracted and there is no bias term.

    Args:
        normalized_shape: Size of the last dimension, as an int or a
            one-element shape.
    """

    def __init__(self, normalized_shape):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            shape = torch.Size((normalized_shape,))
        else:
            shape = torch.Size(normalized_shape)

        # Only normalisation over a single trailing dimension is supported.
        assert len(shape) == 1

        self.weight = nn.Parameter(torch.ones(shape))
        self.normalized_shape = shape

    def forward(self, x):
        variance = x.var(-1, keepdim=True, unbiased=False)
        std = torch.sqrt(variance + 1e-5)
        return x / std * self.weight
class WithBias_LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with weight and bias.

    Computes ``(x - mean) / sqrt(var + 1e-5) * weight + bias`` using the
    biased variance over the last dim.

    Args:
        normalized_shape: Size of the last dimension, as an int or a
            one-element shape.
    """

    def __init__(self, normalized_shape):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            shape = torch.Size((normalized_shape,))
        else:
            shape = torch.Size(normalized_shape)

        # Only normalisation over a single trailing dimension is supported.
        assert len(shape) == 1

        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        self.normalized_shape = shape

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        variance = x.var(-1, keepdim=True, unbiased=False)
        centered = x - mean
        return centered / torch.sqrt(variance + 1e-5) * self.weight + self.bias
class LayerNorm(nn.Module):
    """Channel-wise layer norm for 4-D feature maps.

    The input is rearranged to ``(b, h*w, c)`` with ``to_3d``, normalised
    over the last (channel) dimension, and rearranged back with ``to_4d``.

    Args:
        dim: Number of channels.
        LayerNorm_type: ``'BiasFree'`` selects ``BiasFree_LayerNorm``; any
            other value selects ``WithBias_LayerNorm``.
    """

    def __init__(self, dim, LayerNorm_type):
        super().__init__()
        norm_cls = BiasFree_LayerNorm if LayerNorm_type == 'BiasFree' \
            else WithBias_LayerNorm
        self.body = norm_cls(dim)

    def forward(self, x):
        height, width = x.shape[-2:]
        tokens = to_3d(x)
        normed = self.body(tokens)
        return to_4d(normed, height, width)
##########################################################################
## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
    """Gated feed-forward network built from 1x1 and depthwise 3x3 convs.

    The input is projected to ``2 * hidden`` channels, passed through a
    depthwise 3x3 convolution, split into two halves, combined as
    ``gelu(first) * second`` and projected back to ``dim`` channels.

    Args:
        dim: Number of input and output channels.
        ffn_expansion_factor: ``hidden = int(dim * ffn_expansion_factor)``.
        bias: Whether the convolutions carry a bias term.
    """

    def __init__(self, dim, ffn_expansion_factor, bias):
        super().__init__()

        hidden_features = int(dim * ffn_expansion_factor)
        # Value path and gate path are produced together.
        doubled = hidden_features * 2

        self.project_in = nn.Conv2d(dim, doubled, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(doubled, doubled, kernel_size=3, stride=1,
                                padding=1, groups=doubled, bias=bias)
        self.project_out = nn.Conv2d(hidden_features, dim,
                                     kernel_size=1, bias=bias)

    def forward(self, x):
        expanded = self.dwconv(self.project_in(x))
        value, gate = expanded.chunk(2, dim=1)
        gated = F.gelu(value) * gate
        return self.project_out(gated)
##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
    """Multi-head attention computed across channels (transposed attention).

    Queries, keys and values come from a 1x1 convolution followed by a
    depthwise 3x3 convolution. Per head, q and k are L2-normalised along the
    flattened spatial axis and the attention map has shape
    ``(channels_per_head, channels_per_head)``, scaled by a learnable
    per-head temperature.

    Args:
        dim: Number of input and output channels; the ``(head c)`` split in
            ``forward`` requires it to be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        bias: Whether the convolutions carry a bias term.
    """

    def __init__(self, dim, num_heads, bias):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        qkv_channels = dim * 3
        self.qkv = nn.Conv2d(dim, qkv_channels, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            qkv_channels, qkv_channels, kernel_size=3, stride=1, padding=1,
            groups=qkv_channels, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        _, _, h, w = x.shape

        q, k, v = self.qkv_dwconv(self.qkv(x)).chunk(3, dim=1)

        # Split channels into heads and flatten the spatial grid.
        to_heads = 'b (head c) h w -> b head c (h w)'
        q = rearrange(q, to_heads, head=self.num_heads)
        k = rearrange(k, to_heads, head=self.num_heads)
        v = rearrange(v, to_heads, head=self.num_heads)

        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        # (b, head, c, c): similarity between channels, not between pixels.
        scores = (q @ k.transpose(-2, -1)) * self.temperature
        weights = scores.softmax(dim=-1)

        mixed = weights @ v
        mixed = rearrange(mixed, 'b head c (h w) -> b (head c) h w',
                          head=self.num_heads, h=h, w=w)
        return self.project_out(mixed)
##########################################################################
class TransformerBlock(nn.Module):
    """Pre-norm residual block: attention followed by a feed-forward network.

    Computes ``x = x + attn(norm1(x))`` and then ``x = x + ffn(norm2(x))``.

    Args:
        dim: Number of channels.
        num_heads: Number of attention heads.
        ffn_expansion_factor: Hidden-width multiplier of the feed-forward.
        bias: Whether the convolutions carry a bias term.
        LayerNorm_type: Passed through to ``LayerNorm``.
    """

    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super().__init__()

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        attended = self.attn(self.norm1(x))
        x = x + attended
        refined = self.ffn(self.norm2(x))
        return x + refined
##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
    """Embed an image with a single 3x3, stride-1, padding-1 convolution.

    Spatial size is preserved; channels go from ``in_c`` to ``embed_dim``.

    Args:
        in_c: Number of input channels.
        embed_dim: Number of output channels.
        bias: Whether the convolution carries a bias term.
    """

    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super().__init__()
        self.proj = nn.Conv2d(
            in_c, embed_dim,
            kernel_size=3, stride=1, padding=1, bias=bias,
        )

    def forward(self, x):
        return self.proj(x)
class Restormer_Encoder(nn.Module):
    """Shared transformer encoder with base/detail heads per modality.

    The image is embedded, passed through ``num_blocks[0]`` transformer
    blocks, and the resulting shared feature map is fed to a base-feature
    head and a detail-feature head. Two pairs of heads exist: the default
    pair and a dedicated pair for SAR inputs.

    Args:
        inp_channels: Number of channels of the input image.
        out_channels: Accepted for signature parity with the decoder; not
            used by the encoder.
        dim: Channel width of the embedded features.
        num_blocks: Only ``num_blocks[0]`` is read here.
        heads: Only ``heads[0]`` is read here.
        ffn_expansion_factor: Hidden-width multiplier of each feed-forward.
        bias: Whether the transformer convolutions carry a bias term.
        LayerNorm_type: Passed through to the transformer blocks.
    """

    def __init__(self,
                 inp_channels=1,
                 out_channels=1,
                 dim=64,
                 num_blocks=[4, 4],
                 heads=[8, 8, 8],
                 ffn_expansion_factor=2,
                 bias=False,
                 LayerNorm_type='WithBias',
                 ):
        super(Restormer_Encoder, self).__init__()

        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        self.encoder_level1 = nn.Sequential(
            *[TransformerBlock(dim=dim, num_heads=heads[0],
                               ffn_expansion_factor=ffn_expansion_factor,
                               bias=bias, LayerNorm_type=LayerNorm_type)
              for _ in range(num_blocks[0])])

        # Default heads.
        self.baseFeature = BaseFeatureExtraction(dim=dim)
        self.detailFeature = DetailFeatureExtraction()

        # Dedicated heads for SAR inputs.
        self.baseFeatureSar = BaseFeatureExtractionSAR(dim=dim)
        self.detailFeatureSar = DetailFeatureExtractionSAR()

    def forward(self, inp_img, sar_img=False):
        """Encode ``inp_img`` into base, detail and shared features.

        Args:
            inp_img: Input image tensor.
            sar_img: When true, the SAR-specific heads are used instead of
                the default ones.

        Returns:
            Tuple ``(base_feature, detail_feature, shared_feature)`` where
            ``shared_feature`` is the output of the transformer stack.
        """
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)

        # Previously both branches called baseFeature/detailFeature, so the
        # flag had no effect and the SAR heads were never exercised.
        # NOTE(review): checkpoints trained with the old forward would not
        # have trained the SAR heads through this path - confirm before
        # reusing them for SAR inference.
        if sar_img:
            base_head = self.baseFeatureSar
            detail_head = self.detailFeatureSar
        else:
            base_head = self.baseFeature
            detail_head = self.detailFeature

        base_feature = base_head(out_enc_level1)
        detail_feature = detail_head(out_enc_level1)
        return base_feature, detail_feature, out_enc_level1
class Restormer_Decoder(nn.Module):
    """Fuse base and detail features and decode them into an image.

    The two feature maps are concatenated along channels, reduced from
    ``2 * dim`` to ``dim`` by a 1x1 convolution, refined by
    ``num_blocks[1]`` transformer blocks and mapped to ``out_channels`` by
    two 3x3 convolutions. The result goes through a sigmoid.

    Args:
        inp_channels: Not used by the decoder.
        out_channels: Number of channels of the decoded image.
        dim: Channel width of each incoming feature map.
        num_blocks: Only ``num_blocks[1]`` is read here.
        heads: Only ``heads[1]`` is read here.
        ffn_expansion_factor: Hidden-width multiplier of each feed-forward.
        bias: Whether the convolutions carry a bias term.
        LayerNorm_type: Passed through to the transformer blocks.
    """

    def __init__(self,
                 inp_channels=1,
                 out_channels=1,
                 dim=64,
                 num_blocks=[4, 4],
                 heads=[8, 8, 8],
                 ffn_expansion_factor=2,
                 bias=False,
                 LayerNorm_type='WithBias',
                 ):
        super().__init__()

        width = int(dim)
        half_width = int(dim) // 2

        self.reduce_channel = nn.Conv2d(
            int(dim * 2), width, kernel_size=1, bias=bias)

        refine_blocks = [
            TransformerBlock(dim=dim, num_heads=heads[1],
                             ffn_expansion_factor=ffn_expansion_factor,
                             bias=bias, LayerNorm_type=LayerNorm_type)
            for _ in range(num_blocks[1])
        ]
        self.encoder_level2 = nn.Sequential(*refine_blocks)

        self.output = nn.Sequential(
            nn.Conv2d(width, half_width, kernel_size=3,
                      stride=1, padding=1, bias=bias),
            nn.LeakyReLU(),
            nn.Conv2d(half_width, out_channels, kernel_size=3,
                      stride=1, padding=1, bias=bias),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, inp_img, base_feature, detail_feature):
        """Decode fused features, optionally as a residual on ``inp_img``.

        Args:
            inp_img: Image added to the decoded output before the sigmoid,
                or ``None`` to skip the residual.
            base_feature: Base feature map.
            detail_feature: Detail feature map.

        Returns:
            Tuple ``(image, fused)`` where ``image`` is the sigmoid output
            and ``fused`` is the channel-reduced feature map before the
            transformer blocks.
        """
        stacked = torch.cat((base_feature, detail_feature), dim=1)
        fused = self.reduce_channel(stacked)

        decoded = self.output(self.encoder_level2(fused))
        if inp_img is not None:
            decoded = decoded + inp_img

        return self.sigmoid(decoded), fused
if __name__ == '__main__':
    # Smoke test: build both networks with their default configuration.
    # height, width and window_size are not used by the lines below.
    height = 128
    width = 128
    window_size = 8
    # Use the GPU when one is present, but do not crash on CPU-only machines
    # (an unconditional .cuda() raises when CUDA is unavailable).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    modelE = Restormer_Encoder().to(device)
    modelD = Restormer_Decoder().to(device)