548 lines
21 KiB
Python
548 lines
21 KiB
Python
import time
|
||
import math
|
||
from functools import partial
|
||
from typing import Optional, Callable
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
import torch.utils.checkpoint as checkpoint
|
||
from einops import rearrange, repeat
|
||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||
|
||
# Optional CUDA selective-scan kernels. If a package (or its compiled
# extension) cannot be loaded, the corresponding names are simply left
# undefined; code that needs them must cope with their absence.
# `except Exception` keeps this best-effort without also swallowing
# KeyboardInterrupt / SystemExit the way a bare `except:` would.
try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
except Exception:
    pass

# An alternative to mamba_ssm (which needs causal_conv1d).
try:
    from selective_scan import selective_scan_fn as selective_scan_fn_v1
    from selective_scan import selective_scan_ref as selective_scan_ref_v1
except Exception:
    pass

# Shorter repr for DropPath modules when a model is printed.
DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"
|
||
|
||
"""
|
||
CNN在长程建模能力方面的局限性使它们无法有效地提取图像中的特征,而Transformers则受到其二次计算复杂性的阻碍。最近的研究表明,以Mamba为代表的状态空间模型(SSM)可以在保持线性计算复杂度的同时有效地模拟长程相互作用。
|
||
我们介绍了一个新颖的Conv-SSM模块。Conv-SSM将卷积层的局部特征提取能力与SSM捕获长程依赖性的能力相结合。
|
||
"""
|
||
|
||
|
||
def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
    """Estimate the FLOPs of the reference selective scan (``selective_scan_ref``).

    Operand shapes assumed for the reference scan:
        u: r(B D L), delta: r(B D L), A: r(D N), B: r(B N L), C: r(B N L),
        D: r(D), z: r(B D L), delta_bias: r(D), fp32

    Ignored (treated as free):
        [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]

    Each einsum is costed from the report of ``numpy.einsum_path`` (as in
    fvcore's jit handles) and halved so that one multiply-add counts as one
    FLOP. numpy prints that count in rounded scientific notation, so results
    for large shapes are approximate.

    Args:
        B, L, D, N: sizes of the batch, sequence, channel and state axes.
        with_D: count the ``u * D`` skip term.
        with_Z: count the ``* silu(z)`` gating term.
        with_Group: B/C have shape (B, N, L) (shared across channels) instead
            of (B, D, N, L).
        with_complex: complex ``A`` is not supported; must be False.

    Returns:
        float: the estimated FLOP count.

    Raises:
        AssertionError: if ``with_complex`` is True.
        RuntimeError: if numpy's einsum_path report has no optimized-FLOP line.
    """
    import numpy as np

    # Adapted from fvcore.nn.jit_handles.
    def get_flops_einsum(input_shapes, equation):
        np_arrs = [np.zeros(s) for s in input_shapes]
        report = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
        for line in report.split("\n"):
            if "optimized flop" in line.lower():
                # divided by 2 because we count MAC (multiply-add counted as one flop)
                return float(np.floor(float(line.split(":")[-1]) / 2))
        # Previously fell through and returned None, which only failed later
        # with an opaque TypeError at `flops += None`.
        raise RuntimeError(f"no optimized FLOP count in numpy.einsum_path report for {equation!r}")

    assert not with_complex

    # deltaA = exp(einsum('bdl,dn->bdln', delta, A))
    flops = get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")

    # deltaB_u = einsum(delta, B, u); B is (B, N, L) when grouped, else (B, D, N, L)
    if with_Group:
        flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
    else:
        flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")

    # Per step i of the recurrence (L steps):
    #   x = deltaA[:, :, i] * x + deltaB_u[:, :, i]   -> B * D * N
    #   y = einsum(x, C[..., i])                       -> contraction over N,
    #       with C[..., i] of shape (B, N) when grouped, else (B, D, N)
    in_for_flops = B * D * N
    if with_Group:
        in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
    else:
        in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
    flops += L * in_for_flops

    # out = y + u * D
    if with_D:
        flops += B * D * L
    # out = out * silu(z)
    if with_Z:
        flops += B * D * L
    return flops
|
||
|
||
|
||
class PatchEmbed2D(nn.Module):
    r"""Cut an image into non-overlapping patches and embed each one.

    A strided ``nn.Conv2d`` whose kernel equals its stride projects every
    patch to ``embed_dim`` channels; the result is returned channels-last.

    Args:
        patch_size (int or tuple[int, int]): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer applied to the
            channels-last embedding. Default: None
        **kwargs: accepted and ignored.
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs):
        super().__init__()
        kernel = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel, stride=kernel)
        self.norm = None if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x):
        """Map (B, in_chans, H, W) to (B, H // ph, W // pw, embed_dim)."""
        patches = self.proj(x)                 # (B, embed_dim, H', W')
        patches = patches.permute(0, 2, 3, 1)  # (B, H', W', embed_dim)
        if self.norm is None:
            return patches
        return self.norm(patches)
|
||
|
||
|
||
class PatchMerging2D(nn.Module):
    r"""Patch Merging Layer.

    Gathers every 2x2 spatial neighbourhood of a channels-last feature map
    into one token with 4 * dim channels, normalizes it, and linearly reduces
    it to 2 * dim channels, halving H and W.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """Merge (B, H, W, C) into (B, H // 2, W // 2, 2 * C).

        When H or W is odd, a warning is emitted and the trailing row/column
        is dropped.
        """
        B, H, W, C = x.shape
        h_half, w_half = H // 2, W // 2

        if (W % 2 != 0) or (H % 2 != 0):
            import warnings

            # warnings instead of print: the message can be filtered, and under the
            # default filter it is not re-emitted on every forward pass with this shape.
            warnings.warn(f"Warning, x.shape {x.shape} is not match even ===========", stacklevel=2)

        # Crop each strided view to (H // 2, W // 2). This is a no-op for even sizes.
        # For odd sizes it drops the trailing row/column so the four views line up.
        # This also holds when H or W is 1, where the old `> 0` guard skipped the crop.
        x0 = x[:, 0::2, 0::2, :][:, :h_half, :w_half, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :][:, :h_half, :w_half, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :][:, :h_half, :w_half, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :][:, :h_half, :w_half, :]  # B H/2 W/2 C

        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)  # B H/2 W/2 2*C

        return x
|
||
|
||
|
||
class PatchExpand2D(nn.Module):
    """Upsample a channels-last feature map by ``dim_scale`` along H and W.

    The input is expected to carry ``2 * dim`` channels (stored as
    ``self.dim``). A bias-free linear layer widens each token by
    ``dim_scale``. The widened channels are then rearranged into a
    ``dim_scale x dim_scale`` spatial block and normalized.
    """

    def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super().__init__()
        doubled = dim * 2
        self.dim = doubled
        self.dim_scale = dim_scale
        self.expand = nn.Linear(doubled, dim_scale * doubled, bias=False)
        self.norm = norm_layer(doubled // dim_scale)

    def forward(self, x):
        """Map (B, H, W, C) to (B, H * s, W * s, C // s) with s = dim_scale."""
        _, _, _, channels = x.shape
        scale = self.dim_scale
        widened = self.expand(x)
        spread = rearrange(widened, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=scale, p2=scale,
                           c=channels // scale)
        return self.norm(spread)
|
||
|
||
|
||
class Final_PatchExpand2D(nn.Module):
    """Final upsampling of a channels-last feature map by ``dim_scale`` (default 4x).

    A bias-free linear layer widens each ``dim``-channel token by
    ``dim_scale``. The widened channels are then rearranged into a
    ``dim_scale x dim_scale`` spatial block of ``dim // dim_scale`` channels
    and normalized.
    """

    def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.dim_scale = dim_scale
        self.expand = nn.Linear(dim, dim_scale * dim, bias=False)
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        """Map (B, H, W, C) to (B, H * s, W * s, C // s) with s = dim_scale."""
        _, _, _, channels = x.shape
        scale = self.dim_scale
        widened = self.expand(x)
        spread = rearrange(widened, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=scale, p2=scale,
                           c=channels // scale)
        return self.norm(spread)
|
||
|
||
|
||
class SS2D(nn.Module):
    """SS2D (2D selective scan): a Mamba-style state space model for images.

    The channels-last input (B, H, W, d_model) is projected into two branches
    of ``d_inner`` channels. The first branch goes through a depth-wise conv
    and SiLU. It is then scanned by a selective SSM along four directions:
    row-major, column-major, and both of those reversed. The four outputs are
    summed and layer-normed, then gated by SiLU of the second branch and
    projected back to ``d_model`` channels.

    If the CUDA selective-scan kernel is not importable, or the input is not
    on a CUDA device, a slow pure-PyTorch reference scan is used instead.

    Args:
        d_model: input/output channels.
        d_state: SSM state size.
        d_conv: depth-wise conv kernel size (odd sizes keep H and W unchanged).
        expand: d_inner = int(expand * d_model).
        dt_rank: rank of the delta projection; "auto" -> ceil(d_model / 16).
        dt_min, dt_max, dt_init, dt_scale, dt_init_floor: delta projection init.
        dropout: output dropout probability (no dropout layer when 0).
        conv_bias: whether the depth-wise conv has a bias.
        bias: whether the input/output linear projections have a bias.
        device, dtype: factory kwargs for the created layers.
        **kwargs: accepted and ignored.
    """

    def __init__(
            self,
            d_model,
            d_state=16,
            d_conv=3,
            expand=2,
            dt_rank="auto",
            dt_min=0.001,
            dt_max=0.1,
            dt_init="random",
            dt_scale=1.0,
            dt_init_floor=1e-4,
            dropout=0.,
            conv_bias=True,
            bias=False,
            device=None,
            dtype=None,
            **kwargs,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
        self.conv2d = nn.Conv2d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            groups=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            padding=(d_conv - 1) // 2,
            **factory_kwargs,
        )
        self.act = nn.SiLU()

        # One x-projection per scan direction (K = 4); only the stacked weights are kept.
        x_proj = [
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs)
            for _ in range(4)
        ]
        # (K=4, dt_rank + 2 * d_state, d_inner)
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in x_proj], dim=0))

        dt_projs = [
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
                         **factory_kwargs)
            for _ in range(4)
        ]
        self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in dt_projs], dim=0))  # (K=4, d_inner, dt_rank)
        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in dt_projs], dim=0))  # (K=4, d_inner)

        # Honour the factory device; A_logs / Ds stay fp32 regardless of `dtype`.
        self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, device=device, merge=True)  # (K * d_inner, d_state)
        self.Ds = self.D_init(self.d_inner, copies=4, device=device, merge=True)  # (K * d_inner,)

        self.forward_core = self.forward_corev0

        self.out_norm = nn.LayerNorm(self.d_inner)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else None

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
                **factory_kwargs):
        """Create the dt projection Linear(dt_rank -> d_inner) with the Mamba initialization."""
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = dt_rank ** -0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        dt_proj.bias._no_reinit = True

        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
        """S4D-real init: log(1..d_state) per channel, optionally repeated `copies` times."""
        A = torch.arange(1, d_state + 1, dtype=torch.float32, device=device).repeat(d_inner, 1)  # (d_inner, d_state)
        A_log = torch.log(A)  # Keep A_log in fp32
        if copies > 1:
            A_log = A_log.repeat(copies, 1, 1)  # (copies, d_inner, d_state)
            if merge:
                A_log = A_log.flatten(0, 1)  # (copies * d_inner, d_state)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def D_init(d_inner, copies=1, device=None, merge=True):
        """D "skip" parameter: ones per channel, optionally repeated `copies` times."""
        D = torch.ones(d_inner, device=device)
        if copies > 1:
            D = D.repeat(copies, 1)  # (copies, d_inner)
            if merge:
                D = D.flatten(0, 1)  # (copies * d_inner,)
        D = nn.Parameter(D)  # Keep in fp32
        D._no_weight_decay = True
        return D

    @staticmethod
    def _selective_scan_torch(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                              return_last_state=False):
        """Pure-PyTorch selective scan, sequential over the length axis.

        Shapes as used by this class: u, delta: (b, d, l); A: (d, n);
        B, C: (b, g, n, l) with g dividing d; D, delta_bias: (d,); z: (b, d, l).
        Returns y in u's dtype, plus the final state (b, d, n) if requested.
        """
        dtype_in = u.dtype
        u = u.float()
        delta = delta.float()
        if delta_bias is not None:
            delta = delta + delta_bias[..., None].float()
        if delta_softplus:
            delta = F.softplus(delta)
        batch, dim, length = u.shape
        A = A.float()
        # Broadcast the g B/C groups to one per channel, group-major: channel g * (d // g) + h -> group g.
        B = B.float().repeat_interleave(dim // B.shape[1], dim=1)  # (b, d, n, l)
        C = C.float().repeat_interleave(dim // C.shape[1], dim=1)  # (b, d, n, l)
        state = u.new_zeros((batch, dim, A.shape[1]))
        ys = []
        for i in range(length):
            dt = delta[:, :, i, None]  # (b, d, 1)
            state = torch.exp(dt * A) * state + dt * B[:, :, :, i] * u[:, :, i, None]
            ys.append((state * C[:, :, :, i]).sum(dim=-1))
        y = torch.stack(ys, dim=2)  # (b, d, l)
        out = y if D is None else y + u * D.float()[:, None]
        if z is not None:
            out = out * F.silu(z.float())
        out = out.to(dtype=dtype_in)
        return (out, state) if return_last_state else out

    @staticmethod
    def _pick_scan(kernel_name, x):
        """Return the module-level kernel `kernel_name` when it was imported and x is on CUDA,
        otherwise the pure-PyTorch fallback scan."""
        kernel = globals().get(kernel_name)
        if kernel is None or not x.is_cuda:
            return SS2D._selective_scan_torch
        return kernel

    def _cross_scan_inputs(self, x: torch.Tensor):
        """Build the 4-direction scan operands from x of shape (B, d_inner, H, W)."""
        B, C, H, W = x.shape
        L = H * W
        K = 4

        # Directions: row-major, column-major, then both reversed.
        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
                             dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
        # The dt bias is applied inside the scan via delta_bias, not added here.

        xs = xs.float().view(B, -1, L)  # (b, k * d, l)
        dts = dts.contiguous().float().view(B, -1, L)  # (b, k * d, l)
        Bs = Bs.float().view(B, K, -1, L)  # (b, k, d_state, l)
        Cs = Cs.float().view(B, K, -1, L)  # (b, k, d_state, l)
        Ds = self.Ds.float().view(-1)  # (k * d)
        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)  # (k * d, d_state)
        dt_projs_bias = self.dt_projs_bias.float().view(-1)  # (k * d)
        return xs, dts, As, Bs, Cs, Ds, dt_projs_bias

    @staticmethod
    def _cross_merge(out_y: torch.Tensor, H: int, W: int):
        """Undo the flips/transposes of the 4 scan outputs out_y (B, 4, d, H * W),
        returning four row-major (B, d, H * W) tensors."""
        B, _, _, L = out_y.shape
        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y

    def forward_corev0(self, x: torch.Tensor):
        """Four-direction scan using mamba_ssm's `selective_scan_fn` (or the torch fallback)."""
        B, C, H, W = x.shape
        L = H * W
        K = 4
        self.selective_scan = self._pick_scan("selective_scan_fn", x)
        xs, dts, As, Bs, Cs, Ds, dt_projs_bias = self._cross_scan_inputs(x)
        out_y = self.selective_scan(
            xs, dts,
            As, Bs, Cs, Ds, z=None,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
            return_last_state=False,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float
        return self._cross_merge(out_y, H, W)

    # An alternative to forward_corev0, using the standalone `selective_scan` package.
    def forward_corev1(self, x: torch.Tensor):
        """Four-direction scan using `selective_scan_fn_v1` (or the torch fallback)."""
        B, C, H, W = x.shape
        L = H * W
        K = 4
        self.selective_scan = self._pick_scan("selective_scan_fn_v1", x)
        xs, dts, As, Bs, Cs, Ds, dt_projs_bias = self._cross_scan_inputs(x)
        out_y = self.selective_scan(
            xs, dts,
            As, Bs, Cs, Ds,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float
        return self._cross_merge(out_y, H, W)

    def forward(self, x: torch.Tensor, **kwargs):
        """Map (B, H, W, d_model) to (B, H, W, d_model)."""
        B, H, W, C = x.shape

        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)  # (b, h, w, d)

        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.act(self.conv2d(x))  # (b, d, h, w)
        y1, y2, y3, y4 = self.forward_core(x)
        assert y1.dtype == torch.float32
        y = y1 + y2 + y3 + y4
        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
        y = self.out_norm(y)
        y = y * F.silu(z)
        out = self.out_proj(y)
        if self.dropout is not None:
            out = self.dropout(out)
        return out
|
||
|
||
|
||
class ConvSSM(nn.Module):
    """Conv-SSM block: a convolutional branch and an SS2D branch run side by side.

    The channels-last input (B, H, W, hidden_dim) is split channel-wise in
    half. The first half goes through conv3x3-BN-ReLU-conv3x3-BN-ReLU-conv1x1
    for local features. The second half goes through a pre-norm SS2D block
    with a residual connection, for long-range dependencies. Both results are
    concatenated, fused by a 1x1 conv, and added to the input.

    Args:
        hidden_dim: total number of channels; must be a positive even integer.
        drop_path: DropPath rate applied to the SS2D branch output.
        norm_layer: normalization applied before SS2D.
        attn_drop_rate: dropout probability passed to SS2D.
        d_state: SS2D state size.
        **kwargs: forwarded to SS2D.

    Raises:
        ValueError: if hidden_dim is not a positive even integer (the two
            halves would not match the layers built for hidden_dim // 2).
    """

    def __init__(
            self,
            hidden_dim: int = 0,
            drop_path: float = 0,
            norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
            attn_drop_rate: float = 0,
            d_state: int = 16,
            **kwargs,
    ):
        if hidden_dim <= 0 or hidden_dim % 2 != 0:
            raise ValueError(f"hidden_dim must be a positive even integer, got {hidden_dim}")
        super().__init__()
        half = hidden_dim // 2
        self.ln_1 = norm_layer(half)
        self.self_attention = SS2D(d_model=half, dropout=attn_drop_rate, d_state=d_state, **kwargs)
        self.drop_path = DropPath(drop_path)

        self.conv33conv33conv11 = nn.Sequential(
            nn.Conv2d(in_channels=half, out_channels=half, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(half),
            nn.ReLU(),
            nn.Conv2d(in_channels=half, out_channels=half, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(half),
            nn.ReLU(),
            nn.Conv2d(in_channels=half, out_channels=half, kernel_size=1, stride=1)
        )
        self.finalconv11 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=1, stride=1)

    def forward(self, input: torch.Tensor):
        """Map (B, H, W, hidden_dim) to a tensor of the same shape."""
        input_left, input_right = input.chunk(2, dim=-1)
        # SSM branch (second half of the channels), pre-norm + residual.
        x = input_right + self.drop_path(self.self_attention(self.ln_1(input_right)))
        # Conv branch (first half of the channels); convs expect channels-first.
        input_left = input_left.permute(0, 3, 1, 2).contiguous()
        input_left = self.conv33conv33conv11(input_left)
        x = x.permute(0, 3, 1, 2).contiguous()
        output = torch.cat((input_left, x), dim=1)
        output = self.finalconv11(output).permute(0, 2, 3, 1).contiguous()
        return output + input
|
||
|
||
|
||
if __name__ == '__main__':
    # Smoke test: run one ConvSSM block on a random channels-last batch.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # A ConvSSM block with hidden_dim = 128.
    model = ConvSSM(hidden_dim=128).to(device)

    # Random input laid out as [batch, height, width, channels]:
    # batch 4, 32 x 32 spatial, 128 channels (must equal hidden_dim).
    sample = torch.rand(4, 32, 32, 128).to(device)

    # Forward pass through the block.
    result = model(sample)

    # Report input and output sizes.
    for label, tensor in (("Input tensor size:", sample), ("Output tensor size:", result)):
        print(label, tensor.size())