Compare commits

..

1 Commits

Author SHA1 Message Date
f95b13b8aa 新增测试脚本和日志记录功能
- 添加 test_shell.py脚本,实现自动测试和日志记录功能
- 新增 deployment.xml 和 iml 文件,配置项目部署和模块信息
- 添加 log_20241008_success.log 和 status.md 文件,记录测试结果和状态
- 更新 test_IVF.py,修改测试模型和输出路径
2024-10-09 11:52:46 +08:00
29 changed files with 156 additions and 4626 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
.idea/

View File

@ -1,592 +0,0 @@
import time
import math
from functools import partial
from typing import Optional, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
# try:
# from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
# except:
# pass
# an alternative for mamba_ssm (in which causal_conv1d is needed)
try:
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
# from mamba_ssm.selective_scan import selective_scan_fn as selective_scan_fn_v1
# from selective_scan import selective_scan_ref as selective_scan_ref_v1
except:
pass
DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"
"""
CNN在远程建模能力方面的局限性使它们无法有效地提取图像中的特征而Transformers则受到其二次计算复杂性的阻碍最近的研究表明以Mamba为代表的状态空间模型SSM可以在保持线性计算复杂度的同时有效地模拟长程相互作用
我们介绍了一个新颖的Conv-SSM模块Conv-SSM将卷积层的局部特征提取能力与SSM捕获长程依赖性的能力相结合
"""
def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
"""
u: r(B D L)
delta: r(B D L)
A: r(D N)
B: r(B N L)
C: r(B N L)
D: r(D)
z: r(B D L)
delta_bias: r(D), fp32
ignores:
[.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
"""
import numpy as np
# fvcore.nn.jit_handles
def get_flops_einsum(input_shapes, equation):
np_arrs = [np.zeros(s) for s in input_shapes]
optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
for line in optim.split("\n"):
if "optimized flop" in line.lower():
# divided by 2 because we count MAC (multiply-add counted as one flop)
flop = float(np.floor(float(line.split(":")[-1]) / 2))
return flop
assert not with_complex
flops = 0 # below code flops = 0
if False:
...
"""
dtype_in = u.dtype
u = u.float()
delta = delta.float()
if delta_bias is not None:
delta = delta + delta_bias[..., None].float()
if delta_softplus:
delta = F.softplus(delta)
batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
is_variable_B = B.dim() >= 3
is_variable_C = C.dim() >= 3
if A.is_complex():
if is_variable_B:
B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
if is_variable_C:
C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
else:
B = B.float()
C = C.float()
x = A.new_zeros((batch, dim, dstate))
ys = []
"""
flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
if with_Group:
flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
else:
flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")
if False:
...
"""
deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
if not is_variable_B:
deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
else:
if B.dim() == 3:
deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
else:
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
if is_variable_C and C.dim() == 4:
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
last_state = None
"""
in_for_flops = B * D * N
if with_Group:
in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
else:
in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
flops += L * in_for_flops
if False:
...
"""
for i in range(u.shape[2]):
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
if not is_variable_C:
y = torch.einsum('bdn,dn->bd', x, C)
else:
if C.dim() == 3:
y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
else:
y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
if i == u.shape[2] - 1:
last_state = x
if y.is_complex():
y = y.real * 2
ys.append(y)
y = torch.stack(ys, dim=2) # (batch dim L)
"""
if with_D:
flops += B * D * L
if with_Z:
flops += B * D * L
if False:
...
"""
out = y if D is None else y + u * rearrange(D, "d -> d 1")
if z is not None:
out = out * F.silu(z)
out = out.to(dtype=dtype_in)
"""
return flops
class PatchEmbed2D(nn.Module):
r""" Image to Patch Embedding
Args:
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs):
super().__init__()
if isinstance(patch_size, int):
patch_size = (patch_size, patch_size)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
x = self.proj(x).permute(0, 2, 3, 1)
if self.norm is not None:
x = self.norm(x)
return x
class PatchMerging2D(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
B, H, W, C = x.shape
SHAPE_FIX = [-1, -1]
if (W % 2 != 0) or (H % 2 != 0):
print(f"Warning, x.shape {x.shape} is not match even ===========", flush=True)
SHAPE_FIX[0] = H // 2
SHAPE_FIX[1] = W // 2
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
if SHAPE_FIX[0] > 0:
x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, H // 2, W // 2, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
class PatchExpand2D(nn.Module):
def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim * 2
self.dim_scale = dim_scale
self.expand = nn.Linear(self.dim, dim_scale * self.dim, bias=False)
self.norm = norm_layer(self.dim // dim_scale)
def forward(self, x):
B, H, W, C = x.shape
x = self.expand(x)
x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale,
c=C // self.dim_scale)
x = self.norm(x)
return x
class Final_PatchExpand2D(nn.Module):
def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.dim_scale = dim_scale
self.expand = nn.Linear(self.dim, dim_scale * self.dim, bias=False)
self.norm = norm_layer(self.dim // dim_scale)
def forward(self, x):
B, H, W, C = x.shape
x = self.expand(x)
x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale,
c=C // self.dim_scale)
x = self.norm(x)
return x
class SS2D(nn.Module): # SS2DState Space Model 2D它是实现状态空间模型SSM在视觉任务中应用的核心部分。这个类将状态空间模型的长程依赖性建模能力与卷积操作结合起来旨在捕获图像中的长程相互作用同时保持对局部特征的敏感度。
def __init__(
self,
d_model,
d_state=16,
# d_state="auto", # 20240109
d_conv=3,
expand=2,
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
dropout=0.,
conv_bias=True,
bias=False,
device=None,
dtype=None,
**kwargs,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.d_model = d_model
self.d_state = d_state
# self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
self.conv2d = nn.Conv2d(
in_channels=self.d_inner,
out_channels=self.d_inner,
groups=self.d_inner,
bias=conv_bias,
kernel_size=d_conv,
padding=(d_conv - 1) // 2,
**factory_kwargs,
)
self.act = nn.SiLU()
self.x_proj = (
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
)
self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner)
del self.x_proj
self.dt_projs = (
self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
**factory_kwargs),
self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
**factory_kwargs),
self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
**factory_kwargs),
self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
**factory_kwargs),
)
self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank)
self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner)
del self.dt_projs
self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K=4, D, N)
self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K=4, D, N)
self.selective_scan = selective_scan_fn
self.forward_core = self.forward_corev0
self.out_norm = nn.LayerNorm(self.d_inner)
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
self.dropout = nn.Dropout(dropout) if dropout > 0. else None
@staticmethod
def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
**factory_kwargs):
dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)
# Initialize special dt projection to preserve variance at initialization
dt_init_std = dt_rank ** -0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(dt_proj.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
dt_proj.bias.copy_(inv_dt)
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
dt_proj.bias._no_reinit = True
return dt_proj
@staticmethod
def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
# S4D real initialization
A = repeat(
torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
"n -> d n",
d=d_inner,
).contiguous()
A_log = torch.log(A) # Keep A_log in fp32
if copies > 1:
A_log = repeat(A_log, "d n -> r d n", r=copies)
if merge:
A_log = A_log.flatten(0, 1)
A_log = nn.Parameter(A_log)
A_log._no_weight_decay = True
return A_log
@staticmethod
def D_init(d_inner, copies=1, device=None, merge=True):
# D "skip" parameter
D = torch.ones(d_inner, device=device)
if copies > 1:
D = repeat(D, "n1 -> r n1", r=copies)
if merge:
D = D.flatten(0, 1)
D = nn.Parameter(D) # Keep in fp32
D._no_weight_decay = True
return D
def forward_corev0(self, x: torch.Tensor):
self.selective_scan = selective_scan_fn
B, C, H, W = x.shape
L = H * W
K = 4
x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
dim=1).view(B, 2, -1, L)
xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)
x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
# x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
# dts = dts + self.dt_projs_bias.view(1, K, -1, 1)
xs = xs.float().view(B, -1, L) # (b, k * d, l)
dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
Ds = self.Ds.float().view(-1) # (k * d)
As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state)
dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
out_y = self.selective_scan(
xs, dts,
As, Bs, Cs, Ds, z=None,
delta_bias=dt_projs_bias,
delta_softplus=True,
return_last_state=False,
).view(B, K, -1, L)
assert out_y.dtype == torch.float
inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y
# an alternative to forward_corev1
def forward_corev1(self, x: torch.Tensor):
self.selective_scan = selective_scan_fn
B, C, H, W = x.shape
L = H * W
K = 4
x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
dim=1).view(B, 2, -1, L)
xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)
x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
# x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
# dts = dts + self.dt_projs_bias.view(1, K, -1, 1)
xs = xs.float().view(B, -1, L) # (b, k * d, l)
dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
Ds = self.Ds.float().view(-1) # (k * d)
As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state)
dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
out_y = self.selective_scan(
xs, dts,
As, Bs, Cs, Ds,
delta_bias=dt_projs_bias,
delta_softplus=True,
).view(B, K, -1, L)
assert out_y.dtype == torch.float
inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y
def forward(self, x: torch.Tensor, **kwargs):
B, H, W, C = x.shape
xz = self.in_proj(x)
x, z = xz.chunk(2, dim=-1) # (b, h, w, d)
x = x.permute(0, 3, 1, 2).contiguous()
x = self.act(self.conv2d(x)) # (b, d, h, w)
y1, y2, y3, y4 = self.forward_core(x)
assert y1.dtype == torch.float32
y = y1 + y2 + y3 + y4
y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
y = self.out_norm(y)
y = y * F.silu(z)
out = self.out_proj(y)
if self.dropout is not None:
out = self.dropout(out)
return out
class ConvSSM(nn.Module):
"""
这个类组合了自注意力机制和卷积操作旨在融合自注意力的全局感知能力和卷积的局部特征提取能力
输入特征被分成两部分一部分通过SS2D自注意力模块处理另一部分通过一系列卷积层处理处理后的两部分再次合并并通过最终的卷积层生成输出特征
"""
def __init__(
self,
hidden_dim: int = 0,
drop_path: float = 0,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
attn_drop_rate: float = 0,
d_state: int = 16,
**kwargs,
):
super().__init__()
self.ln_1 = norm_layer(hidden_dim // 2)
self.self_attention = SS2D(d_model=hidden_dim // 2, dropout=attn_drop_rate, d_state=d_state, **kwargs)
self.drop_path = DropPath(drop_path)
self.conv33conv33conv11 = nn.Sequential(
nn.Conv2d(in_channels=hidden_dim // 2, out_channels=hidden_dim // 2, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(hidden_dim // 2),
nn.ReLU(),
nn.Conv2d(in_channels=hidden_dim // 2, out_channels=hidden_dim // 2, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(hidden_dim // 2),
nn.ReLU(),
nn.Conv2d(in_channels=hidden_dim // 2, out_channels=hidden_dim // 2, kernel_size=1, stride=1)
)
self.finalconv11 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=1, stride=1)
# def forward(self, input: torch.Tensor):
#
#
# """
# <img src="http://42.192.130.83:9000/picgo/imgs/image-20240502123357829.png" alt="image-202http://42.192.130.83:9000/picgo/imgs/image-20240502123357829.png
#
# """
#
# # 将输入图像在最后一个维度上切分成两半 left 4 32 32 64 right 4 32 32 64
# input_left, input_right = input.chunk(2, dim=-1)
# # 应用自注意力机制并经过drop_path和ln_1层归一化处理input_right
# # self.self_attention为SS2D模块它实现了自注意力机制
# x = input_right + self.drop_path(self.self_attention(self.ln_1(input_right))) # 4 32 32 64
# # 将x的维度置换回与input_left匹配以便后续合并
# x = x.permute(0, 3, 1, 2).contiguous() # 4 64 32 32
#
# # 对input_left进行维度置换以适配后续卷积操作的需要
# input_left = input_left.permute(0, 3, 1, 2).contiguous() # 4 64 32 32
# # 应用特定结构的卷积操作 conv33conv33conv11
# input_left = self.conv33conv33conv11(input_left) # 4 64 32 32
#
# output = torch.cat((input_left, x), dim=1) # 4 128 32 32
# output = self.finalconv11(output).permute(0, 2, 3, 1).contiguous() # 4 32 32 128
# return output + input
def forward(self, input: torch.Tensor):
"""
<img src="http://42.192.130.83:9000/picgo/imgs/image-20240502123357829.png" alt="image-202http://42.192.130.83:9000/picgo/imgs/image-20240502123357829.png
"""
input = input.permute(0, 2, 3, 1)
print(input.is_cuda)
# 将输入图像在最后一个维度上切分成两半 left 4 32 32 64 right 4 32 32 64
input_left, input_right = input.chunk(2, dim=-1)
# 应用自注意力机制并经过drop_path和ln_1层归一化处理input_right
# self.self_attention为SS2D模块它实现了自注意力机制
x = input_right + self.drop_path(self.self_attention(self.ln_1(input_right))) # 4 32 32 64
# 将x的维度置换回与input_left匹配以便后续合并
x = x.permute(0, 3, 1, 2).contiguous() # 4 64 32 32
# 对input_left进行维度置换以适配后续卷积操作的需要
input_left = input_left.permute(0, 3, 1, 2).contiguous() # 4 64 32 32
# 应用特定结构的卷积操作 conv33conv33conv11
input_left = self.conv33conv33conv11(input_left) # 4 64 32 32
output = torch.cat((input_left, x), dim=1) # 4 128 32 32
output = self.finalconv11(output).permute(0, 2, 3, 1).contiguous() # 4 32 32 128
return (output + input).permute(0, 3, 1, 2)
if __name__ == '__main__':
# 初始化ConvSSM模块hidden_dim为128
# block = ConvSSM(hidden_dim=128)
block = ConvSSM(hidden_dim=4)
# 生成随机输入张量,尺寸为[批次大小, 高度, 宽度, 通道数]
# 这里批次大小为4高度和宽度为32通道数为128符合hidden_dim的大小
# input_tensor = torch.rand(4, 32, 32, 128)
# input_tensor = torch.rand(2, 1, 256, 128)
input_tensor = torch.rand(2, 4, 256, 256).cuda()
# 前向传递输入张量通过ConvSSM模块
output = block(input_tensor)
# 打印输入和输出张量的尺寸
print("Input tensor size:", input_tensor.size())
print("Output tensor size:", output.size())

View File

@ -1,548 +0,0 @@
import time
import math
from functools import partial
from typing import Optional, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
try:
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
except:
pass
# an alternative for mamba_ssm (in which causal_conv1d is needed)
try:
from selective_scan import selective_scan_fn as selective_scan_fn_v1
from selective_scan import selective_scan_ref as selective_scan_ref_v1
except:
pass
DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"
"""
CNN在远程建模能力方面的局限性使它们无法有效地提取图像中的特征而Transformers则受到其二次计算复杂性的阻碍最近的研究表明以Mamba为代表的状态空间模型SSM可以在保持线性计算复杂度的同时有效地模拟长程相互作用
我们介绍了一个新颖的Conv-SSM模块Conv-SSM将卷积层的局部特征提取能力与SSM捕获长程依赖性的能力相结合
"""
def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
"""
u: r(B D L)
delta: r(B D L)
A: r(D N)
B: r(B N L)
C: r(B N L)
D: r(D)
z: r(B D L)
delta_bias: r(D), fp32
ignores:
[.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
"""
import numpy as np
# fvcore.nn.jit_handles
def get_flops_einsum(input_shapes, equation):
np_arrs = [np.zeros(s) for s in input_shapes]
optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
for line in optim.split("\n"):
if "optimized flop" in line.lower():
# divided by 2 because we count MAC (multiply-add counted as one flop)
flop = float(np.floor(float(line.split(":")[-1]) / 2))
return flop
assert not with_complex
flops = 0 # below code flops = 0
if False:
...
"""
dtype_in = u.dtype
u = u.float()
delta = delta.float()
if delta_bias is not None:
delta = delta + delta_bias[..., None].float()
if delta_softplus:
delta = F.softplus(delta)
batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
is_variable_B = B.dim() >= 3
is_variable_C = C.dim() >= 3
if A.is_complex():
if is_variable_B:
B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
if is_variable_C:
C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
else:
B = B.float()
C = C.float()
x = A.new_zeros((batch, dim, dstate))
ys = []
"""
flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
if with_Group:
flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
else:
flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")
if False:
...
"""
deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
if not is_variable_B:
deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
else:
if B.dim() == 3:
deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
else:
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
if is_variable_C and C.dim() == 4:
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
last_state = None
"""
in_for_flops = B * D * N
if with_Group:
in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
else:
in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
flops += L * in_for_flops
if False:
...
"""
for i in range(u.shape[2]):
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
if not is_variable_C:
y = torch.einsum('bdn,dn->bd', x, C)
else:
if C.dim() == 3:
y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
else:
y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
if i == u.shape[2] - 1:
last_state = x
if y.is_complex():
y = y.real * 2
ys.append(y)
y = torch.stack(ys, dim=2) # (batch dim L)
"""
if with_D:
flops += B * D * L
if with_Z:
flops += B * D * L
if False:
...
"""
out = y if D is None else y + u * rearrange(D, "d -> d 1")
if z is not None:
out = out * F.silu(z)
out = out.to(dtype=dtype_in)
"""
return flops
class PatchEmbed2D(nn.Module):
r""" Image to Patch Embedding
Args:
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs):
super().__init__()
if isinstance(patch_size, int):
patch_size = (patch_size, patch_size)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
x = self.proj(x).permute(0, 2, 3, 1)
if self.norm is not None:
x = self.norm(x)
return x
class PatchMerging2D(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
B, H, W, C = x.shape
SHAPE_FIX = [-1, -1]
if (W % 2 != 0) or (H % 2 != 0):
print(f"Warning, x.shape {x.shape} is not match even ===========", flush=True)
SHAPE_FIX[0] = H // 2
SHAPE_FIX[1] = W // 2
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
if SHAPE_FIX[0] > 0:
x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, H // 2, W // 2, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
class PatchExpand2D(nn.Module):
def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim * 2
self.dim_scale = dim_scale
self.expand = nn.Linear(self.dim, dim_scale * self.dim, bias=False)
self.norm = norm_layer(self.dim // dim_scale)
def forward(self, x):
B, H, W, C = x.shape
x = self.expand(x)
x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale,
c=C // self.dim_scale)
x = self.norm(x)
return x
class Final_PatchExpand2D(nn.Module):
def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.dim_scale = dim_scale
self.expand = nn.Linear(self.dim, dim_scale * self.dim, bias=False)
self.norm = norm_layer(self.dim // dim_scale)
def forward(self, x):
B, H, W, C = x.shape
x = self.expand(x)
x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale,
c=C // self.dim_scale)
x = self.norm(x)
return x
class SS2D(
nn.Module): # SS2DState Space Model 2D它是实现状态空间模型SSM在视觉任务中应用的核心部分。这个类将状态空间模型的长程依赖性建模能力与卷积操作结合起来旨在捕获图像中的长程相互作用同时保持对局部特征的敏感度。
def __init__(
self,
d_model,
d_state=16,
# d_state="auto", # 20240109
d_conv=3,
expand=2,
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
dropout=0.,
conv_bias=True,
bias=False,
device=None,
dtype=None,
**kwargs,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.d_model = d_model
self.d_state = d_state
# self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
self.conv2d = nn.Conv2d(
in_channels=self.d_inner,
out_channels=self.d_inner,
groups=self.d_inner,
bias=conv_bias,
kernel_size=d_conv,
padding=(d_conv - 1) // 2,
**factory_kwargs,
)
self.act = nn.SiLU()
self.x_proj = (
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
)
self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner)
del self.x_proj
self.dt_projs = (
self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
**factory_kwargs),
self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
**factory_kwargs),
self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
**factory_kwargs),
self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
**factory_kwargs),
)
self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank)
self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner)
del self.dt_projs
self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K=4, D, N)
self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K=4, D, N)
# self.selective_scan = selective_scan_fn
self.forward_core = self.forward_corev0
self.out_norm = nn.LayerNorm(self.d_inner)
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
self.dropout = nn.Dropout(dropout) if dropout > 0. else None
@staticmethod
def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
**factory_kwargs):
dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)
# Initialize special dt projection to preserve variance at initialization
dt_init_std = dt_rank ** -0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(dt_proj.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
dt_proj.bias.copy_(inv_dt)
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
dt_proj.bias._no_reinit = True
return dt_proj
@staticmethod
def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
# S4D real initialization
A = repeat(
torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
"n -> d n",
d=d_inner,
).contiguous()
A_log = torch.log(A) # Keep A_log in fp32
if copies > 1:
A_log = repeat(A_log, "d n -> r d n", r=copies)
if merge:
A_log = A_log.flatten(0, 1)
A_log = nn.Parameter(A_log)
A_log._no_weight_decay = True
return A_log
@staticmethod
def D_init(d_inner, copies=1, device=None, merge=True):
# D "skip" parameter
D = torch.ones(d_inner, device=device)
if copies > 1:
D = repeat(D, "n1 -> r n1", r=copies)
if merge:
D = D.flatten(0, 1)
D = nn.Parameter(D) # Keep in fp32
D._no_weight_decay = True
return D
def forward_corev0(self, x: torch.Tensor):
self.selective_scan = selective_scan_fn
B, C, H, W = x.shape
L = H * W
K = 4
x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
dim=1).view(B, 2, -1, L)
xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)
x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
# x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
# dts = dts + self.dt_projs_bias.view(1, K, -1, 1)
xs = xs.float().view(B, -1, L) # (b, k * d, l)
dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
Ds = self.Ds.float().view(-1) # (k * d)
As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state)
dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
out_y = self.selective_scan(
xs, dts,
As, Bs, Cs, Ds, z=None,
delta_bias=dt_projs_bias,
delta_softplus=True,
return_last_state=False,
).view(B, K, -1, L)
assert out_y.dtype == torch.float
inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y
# an alternative to forward_corev1
def forward_corev1(self, x: torch.Tensor):
self.selective_scan = selective_scan_fn_v1
B, C, H, W = x.shape
L = H * W
K = 4
x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
dim=1).view(B, 2, -1, L)
xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)
x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
# x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
# dts = dts + self.dt_projs_bias.view(1, K, -1, 1)
xs = xs.float().view(B, -1, L) # (b, k * d, l)
dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
Ds = self.Ds.float().view(-1) # (k * d)
As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state)
dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
out_y = self.selective_scan(
xs, dts,
As, Bs, Cs, Ds,
delta_bias=dt_projs_bias,
delta_softplus=True,
).view(B, K, -1, L)
assert out_y.dtype == torch.float
inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y
def forward(self, x: torch.Tensor, **kwargs):
B, H, W, C = x.shape
xz = self.in_proj(x)
x, z = xz.chunk(2, dim=-1) # (b, h, w, d)
x = x.permute(0, 3, 1, 2).contiguous()
x = self.act(self.conv2d(x)) # (b, d, h, w)
y1, y2, y3, y4 = self.forward_core(x)
assert y1.dtype == torch.float32
y = y1 + y2 + y3 + y4
y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
y = self.out_norm(y)
y = y * F.silu(z)
out = self.out_proj(y)
if self.dropout is not None:
out = self.dropout(out)
return out
class ConvSSM(nn.Module):
"""
这个类组合了自注意力机制和卷积操作旨在融合自注意力的全局感知能力和卷积的局部特征提取能力
输入特征被分成两部分一部分通过SS2D自注意力模块处理另一部分通过一系列卷积层处理处理后的两部分再次合并并通过最终的卷积层生成输出特征
"""
def __init__(
self,
hidden_dim: int = 0,
drop_path: float = 0,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
attn_drop_rate: float = 0,
d_state: int = 16,
**kwargs,
):
super().__init__()
self.ln_1 = norm_layer(hidden_dim // 2)
self.self_attention = SS2D(d_model=hidden_dim // 2, dropout=attn_drop_rate, d_state=d_state, **kwargs)
self.drop_path = DropPath(drop_path)
self.conv33conv33conv11 = nn.Sequential(
nn.Conv2d(in_channels=hidden_dim // 2, out_channels=hidden_dim // 2, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(hidden_dim // 2),
nn.ReLU(),
nn.Conv2d(in_channels=hidden_dim // 2, out_channels=hidden_dim // 2, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(hidden_dim // 2),
nn.ReLU(),
nn.Conv2d(in_channels=hidden_dim // 2, out_channels=hidden_dim // 2, kernel_size=1, stride=1)
)
self.finalconv11 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=1, stride=1)
def forward(self, input: torch.Tensor):
input_left, input_right = input.chunk(2, dim=-1)
x = input_right + self.drop_path(self.self_attention(self.ln_1(input_right)))
input_left = input_left.permute(0, 3, 1, 2).contiguous()
input_left = self.conv33conv33conv11(input_left)
x = x.permute(0, 3, 1, 2).contiguous()
output = torch.cat((input_left, x), dim=1)
output = self.finalconv11(output).permute(0, 2, 3, 1).contiguous()
return output + input
if __name__ == '__main__':
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 初始化ConvSSM模块hidden_dim为128
block = ConvSSM(hidden_dim=128).to(device)
# 生成随机输入张量,尺寸为[批次大小, 高度, 宽度, 通道数]
# 这里批次大小为4高度和宽度为32通道数为128符合hidden_dim的大小
input_tensor = torch.rand(4, 32, 32, 128).to(device)
# 前向传递输入张量通过ConvSSM模块
output = block(input_tensor)
# 打印输入和输出张量的尺寸
print("Input tensor size:", input_tensor.size())
print("Output tensor size:", output.size())

View File

@ -1,198 +0,0 @@
import pywt
import pywt.data
import torch
from torch import nn
from torch.autograd import Function
import torch.nn.functional as F
"""ECCV2024 https://arxiv.org/abs/2407.05848
近年来人们尝试增加卷积神经网络 (CNN) 的内核大小以模拟 Vision Transformers (ViTs) 自注意力模块的全局接受场
然而这种方法在实现全局接受场之前就很快达到上限并饱和
在这项研究中我们证明了通过利用小波变换 (WT)实际上可以获得非常大的接受场而不会遭受过度参数化所提出的层名为 WTConv
可用作现有架构中的直接替换产生有效的多频响应并可随接受场的大小优雅地扩展
我们证明了 ConvNeXt MobileNetV2 架构中 WTConv 层对图像分类的有效性以及下游任务的主干并表明它具有其他属性
例如对图像损坏的鲁棒性和对纹理形状的响应增强
"""
def create_wavelet_filter(wave, in_size, out_size, type=torch.float):
w = pywt.Wavelet(wave)
dec_hi = torch.tensor(w.dec_hi[::-1], dtype=type)
dec_lo = torch.tensor(w.dec_lo[::-1], dtype=type)
dec_filters = torch.stack([dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1),
dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1),
dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1),
dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)], dim=0)
dec_filters = dec_filters[:, None].repeat(in_size, 1, 1, 1)
rec_hi = torch.tensor(w.rec_hi[::-1], dtype=type).flip(dims=[0])
rec_lo = torch.tensor(w.rec_lo[::-1], dtype=type).flip(dims=[0])
rec_filters = torch.stack([rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1),
rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1),
rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1),
rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)], dim=0)
rec_filters = rec_filters[:, None].repeat(out_size, 1, 1, 1)
return dec_filters, rec_filters
def wavelet_transform(x, filters):
b, c, h, w = x.shape
pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
x = F.conv2d(x, filters, stride=2, groups=c, padding=pad)
x = x.reshape(b, c, 4, h // 2, w // 2)
return x
def inverse_wavelet_transform(x, filters):
b, c, _, h_half, w_half = x.shape
pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
x = x.reshape(b, c * 4, h_half, w_half)
x = F.conv_transpose2d(x, filters, stride=2, groups=c, padding=pad)
return x
def wavelet_transform_init(filters):
class WaveletTransform(Function):
@staticmethod
def forward(ctx, input):
with torch.no_grad():
x = wavelet_transform(input, filters)
return x
@staticmethod
def backward(ctx, grad_output):
grad = inverse_wavelet_transform(grad_output, filters)
return grad, None
return WaveletTransform().apply
def inverse_wavelet_transform_init(filters):
class InverseWaveletTransform(Function):
@staticmethod
def forward(ctx, input):
with torch.no_grad():
x = inverse_wavelet_transform(input, filters)
return x
@staticmethod
def backward(ctx, grad_output):
grad = wavelet_transform(grad_output, filters)
return grad, None
return InverseWaveletTransform().apply
class WTConv2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, bias=True, wt_levels=1, wt_type='db1'):
super(WTConv2d, self).__init__()
assert in_channels == out_channels
self.in_channels = in_channels
self.wt_levels = wt_levels
self.stride = stride
self.dilation = 1
self.wt_filter, self.iwt_filter = create_wavelet_filter(wt_type, in_channels, in_channels, torch.float)
self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False)
self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False)
self.wt_function = wavelet_transform_init(self.wt_filter)
self.iwt_function = inverse_wavelet_transform_init(self.iwt_filter)
self.base_conv = nn.Conv2d(in_channels, in_channels, kernel_size, padding='same', stride=1, dilation=1,
groups=in_channels, bias=bias)
self.base_scale = _ScaleModule([1, in_channels, 1, 1])
self.wavelet_convs = nn.ModuleList(
[nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', stride=1, dilation=1,
groups=in_channels * 4, bias=False) for _ in range(self.wt_levels)]
)
self.wavelet_scale = nn.ModuleList(
[_ScaleModule([1, in_channels * 4, 1, 1], init_scale=0.1) for _ in range(self.wt_levels)]
)
if self.stride > 1:
self.stride_filter = nn.Parameter(torch.ones(in_channels, 1, 1, 1), requires_grad=False)
self.do_stride = lambda x_in: F.conv2d(x_in, self.stride_filter, bias=None, stride=self.stride,
groups=in_channels)
else:
self.do_stride = None
def forward(self, x):
x_ll_in_levels = []
x_h_in_levels = []
shapes_in_levels = []
curr_x_ll = x
for i in range(self.wt_levels):
curr_shape = curr_x_ll.shape
shapes_in_levels.append(curr_shape)
if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):
curr_pads = (0, curr_shape[3] % 2, 0, curr_shape[2] % 2)
curr_x_ll = F.pad(curr_x_ll, curr_pads)
curr_x = self.wt_function(curr_x_ll)
curr_x_ll = curr_x[:, :, 0, :, :]
shape_x = curr_x.shape
curr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])
curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag))
curr_x_tag = curr_x_tag.reshape(shape_x)
x_ll_in_levels.append(curr_x_tag[:, :, 0, :, :])
x_h_in_levels.append(curr_x_tag[:, :, 1:4, :, :])
next_x_ll = 0
for i in range(self.wt_levels - 1, -1, -1):
curr_x_ll = x_ll_in_levels.pop()
curr_x_h = x_h_in_levels.pop()
curr_shape = shapes_in_levels.pop()
curr_x_ll = curr_x_ll + next_x_ll
curr_x = torch.cat([curr_x_ll.unsqueeze(2), curr_x_h], dim=2)
next_x_ll = self.iwt_function(curr_x)
next_x_ll = next_x_ll[:, :, :curr_shape[2], :curr_shape[3]]
x_tag = next_x_ll
assert len(x_ll_in_levels) == 0
x = self.base_scale(self.base_conv(x))
x = x + x_tag
if self.do_stride is not None:
x = self.do_stride(x)
return x
class _ScaleModule(nn.Module):
def __init__(self, dims, init_scale=1.0, init_bias=0):
super(_ScaleModule, self).__init__()
self.dims = dims
self.weight = nn.Parameter(torch.ones(*dims) * init_scale)
self.bias = None
def forward(self, x):
return torch.mul(self.weight, x)
if __name__ == '__main__':
in_channels = 3
out_channels = 3
block = WTConv2d(in_channels, out_channels)
input = torch.rand(1, in_channels, 64, 64)
output = block(input)
print(input.size())
print(output.size())

View File

@ -1,25 +0,0 @@
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/train.py
2.4.1+cu121
True
Model: PFCFuse
Number of epochs: 60
Epoch gap: 40
Learning rate: 0.0001
Weight decay: 0
Batch size: 1
GPU number: 0
Coefficient of MSE loss VF: 1.0
Coefficient of MSE loss IF: 1.0
Coefficient of RMI loss VF: 1.0
Coefficient of RMI loss IF: 1.0
Coefficient of Cosine loss VF: 1.0
Coefficient of Cosine loss IF: 1.0
Coefficient of Decomposition loss: 2.0
Coefficient of Total Variation loss: 5.0
Clip gradient norm value: 0.01
Optimization step: 20
Optimization gamma: 0.5
[Epoch 39/60] [Batch 6486/6487] [loss: 0.002562] ETA: 3:30:05.95/home/star/whaiDir/PFCFuse/utils/loss.py:15: UserWarning: Using a target size (torch.Size([1, 1, 128, 128])) that is different to the input size (torch.Size([])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
loss_rmi=F.l1_loss(x_rmi_max, generate_img)
[Epoch 59/60] [Batch 6486/6487] [loss: 2.106119] ETA: 0:00:00.08
Process finished with exit code 0

View File

@ -0,0 +1,38 @@
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py
================================================================================
The test result of TNO :
19.png
05.png
21.png
18.png
15.png
22.png
14.png
13.png
08.png
01.png
02.png
03.png
25.png
17.png
11.png
16.png
06.png
07.png
09.png
10.png
12.png
23.png
24.png
20.png
04.png
EN SD SF MI SCD VIF Qabf SSIM
PFCFuse 7.01 40.4 15.51 1.55 1.75 0.66 0.54 0.96
================================================================================
Process finished with exit code 0

45
logs/status.md Normal file
View File

@ -0,0 +1,45 @@
PFCFuse
20241008
```
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py
================================================================================
The test result of TNO :
19.png
05.png
21.png
18.png
15.png
22.png
14.png
13.png
08.png
01.png
02.png
03.png
25.png
17.png
11.png
16.png
06.png
07.png
09.png
10.png
12.png
23.png
24.png
20.png
04.png
EN SD SF MI SCD VIF Qabf SSIM
PFCFuse 7.01 40.4 15.51 1.55 1.75 0.66 0.54 0.96
================================================================================
Process finished with exit code 0
```

View File

@ -1,5 +0,0 @@
# __version__ = "1.2.0.post1"
#
# from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
# from mamba_ssm.modules.mamba_simple import Mamba
# from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

View File

@ -1,15 +0,0 @@
from dataclasses import dataclass, field
@dataclass
class MambaConfig:
d_model: int = 2560
n_layer: int = 64
vocab_size: int = 50277
ssm_cfg: dict = field(default_factory=dict)
rms_norm: bool = True
residual_in_fp32: bool = True
fused_add_norm: bool = True
pad_vocab_size_multiple: int = 8
tie_embeddings: bool = True

View File

@ -1,265 +0,0 @@
# Copyright (c) 2023, Albert Gu, Tri Dao.
import math
from functools import partial
import json
import os
from collections import namedtuple
import torch
import torch.nn as nn
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.modules.mamba_simple import Mamba, Block
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
try:
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
def create_block(
d_model,
ssm_cfg=None,
norm_epsilon=1e-5,
rms_norm=False,
residual_in_fp32=False,
fused_add_norm=False,
layer_idx=None,
device=None,
dtype=None,
):
if ssm_cfg is None:
ssm_cfg = {}
factory_kwargs = {"device": device, "dtype": dtype}
mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
norm_cls = partial(
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
)
block = Block(
d_model,
mixer_cls,
norm_cls=norm_cls,
fused_add_norm=fused_add_norm,
residual_in_fp32=residual_in_fp32,
)
block.layer_idx = layer_idx
return block
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
module,
n_layer,
initializer_range=0.02, # Now only used for embedding layer.
rescale_prenorm_residual=True,
n_residuals_per_layer=1, # Change to 2 if we have MLP
):
if isinstance(module, nn.Linear):
if module.bias is not None:
if not getattr(module.bias, "_no_reinit", False):
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=initializer_range)
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight", "fc2.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
with torch.no_grad():
p /= math.sqrt(n_residuals_per_layer * n_layer)
class MixerModel(nn.Module):
def __init__(
self,
d_model: int,
n_layer: int,
vocab_size: int,
ssm_cfg=None,
norm_epsilon: float = 1e-5,
rms_norm: bool = False,
initializer_cfg=None,
fused_add_norm=False,
residual_in_fp32=False,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.residual_in_fp32 = residual_in_fp32
self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
# We change the order of residual and layer norm:
# Instead of LN -> Attn / MLP -> Add, we do:
# Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
# the main branch (output of MLP / Mixer). The model definition is unchanged.
# This is for performance reason: we can fuse add + layer_norm.
self.fused_add_norm = fused_add_norm
if self.fused_add_norm:
if layer_norm_fn is None or rms_norm_fn is None:
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
self.layers = nn.ModuleList(
[
create_block(
d_model,
ssm_cfg=ssm_cfg,
norm_epsilon=norm_epsilon,
rms_norm=rms_norm,
residual_in_fp32=residual_in_fp32,
fused_add_norm=fused_add_norm,
layer_idx=i,
**factory_kwargs,
)
for i in range(n_layer)
]
)
self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
d_model, eps=norm_epsilon, **factory_kwargs
)
self.apply(
partial(
_init_weights,
n_layer=n_layer,
**(initializer_cfg if initializer_cfg is not None else {}),
)
)
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return {
i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
for i, layer in enumerate(self.layers)
}
def forward(self, input_ids, inference_params=None):
hidden_states = self.embedding(input_ids)
residual = None
for layer in self.layers:
hidden_states, residual = layer(
hidden_states, residual, inference_params=inference_params
)
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
else:
# Set prenorm=False here since we don't need the residual
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
hidden_states = fused_add_norm_fn(
hidden_states,
self.norm_f.weight,
self.norm_f.bias,
eps=self.norm_f.eps,
residual=residual,
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
)
return hidden_states
class MambaLMHeadModel(nn.Module, GenerationMixin):
def __init__(
self,
config: MambaConfig,
initializer_cfg=None,
device=None,
dtype=None,
) -> None:
self.config = config
d_model = config.d_model
n_layer = config.n_layer
vocab_size = config.vocab_size
ssm_cfg = config.ssm_cfg
rms_norm = config.rms_norm
residual_in_fp32 = config.residual_in_fp32
fused_add_norm = config.fused_add_norm
pad_vocab_size_multiple = config.pad_vocab_size_multiple
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if vocab_size % pad_vocab_size_multiple != 0:
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
self.backbone = MixerModel(
d_model=d_model,
n_layer=n_layer,
vocab_size=vocab_size,
ssm_cfg=ssm_cfg,
rms_norm=rms_norm,
initializer_cfg=initializer_cfg,
fused_add_norm=fused_add_norm,
residual_in_fp32=residual_in_fp32,
**factory_kwargs,
)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
# Initialize weights and apply final processing
self.apply(
partial(
_init_weights,
n_layer=n_layer,
**(initializer_cfg if initializer_cfg is not None else {}),
)
)
self.tie_weights()
def tie_weights(self):
if self.config.tie_embeddings:
self.lm_head.weight = self.backbone.embedding.weight
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
"""
"position_ids" is just to be compatible with Transformer generation. We don't use it.
num_last_tokens: if > 0, only return the logits for the last n tokens
"""
hidden_states = self.backbone(input_ids, inference_params=inference_params)
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
lm_logits = self.lm_head(hidden_states)
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=lm_logits)
@classmethod
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
config_data = load_config_hf(pretrained_model_name)
config = MambaConfig(**config_data)
model = cls(config, device=device, dtype=dtype, **kwargs)
model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
return model
def save_pretrained(self, save_directory):
"""
Minimal implementation of save_pretrained for MambaLMHeadModel.
Save the model and its configuration file to a directory.
"""
# Ensure save_directory exists
if not os.path.exists(save_directory):
os.makedirs(save_directory)
# Save the model's state_dict
model_path = os.path.join(save_directory, 'pytorch_model.bin')
torch.save(self.state_dict(), model_path)
# Save the configuration of the model
config_path = os.path.join(save_directory, 'config.json')
with open(config_path, 'w') as f:
json.dump(self.config.__dict__, f)

View File

@ -1,353 +0,0 @@
# Copyright (c) 2023, Tri Dao, Albert Gu.
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange, repeat
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_fn, causal_conv1d_update = None, None
try:
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
selective_state_update = None
try:
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
class Mamba(nn.Module):
def __init__(
self,
d_model,
d_state=16,
d_conv=4,
expand=2,
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
conv_bias=True,
bias=False,
use_fast_path=True, # Fused kernel options
layer_idx=None,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
self.use_fast_path = use_fast_path
self.layer_idx = layer_idx
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
self.conv1d = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
bias=conv_bias,
kernel_size=d_conv,
groups=self.d_inner,
padding=d_conv - 1,
**factory_kwargs,
)
self.activation = "silu"
self.act = nn.SiLU()
self.x_proj = nn.Linear(
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
)
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
# Initialize special dt projection to preserve variance at initialization
dt_init_std = self.dt_rank**-0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(self.dt_proj.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
self.dt_proj.bias.copy_(inv_dt)
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
self.dt_proj.bias._no_reinit = True
# S4D real initialization
A = repeat(
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
"n -> d n",
d=self.d_inner,
).contiguous()
A_log = torch.log(A) # Keep A_log in fp32
self.A_log = nn.Parameter(A_log)
self.A_log._no_weight_decay = True
# D "skip" parameter
self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
self.D._no_weight_decay = True
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
def forward(self, hidden_states, inference_params=None):
"""
hidden_states: (B, L, D)
Returns: same shape as hidden_states
"""
batch, seqlen, dim = hidden_states.shape
conv_state, ssm_state = None, None
if inference_params is not None:
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
if inference_params.seqlen_offset > 0:
# The states are updated inplace
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
return out
# We do matmul and transpose BLH -> HBL at the same time
xz = rearrange(
self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
"d (b l) -> b d l",
l=seqlen,
)
if self.in_proj.bias is not None:
xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
# In the backward pass we write dx and dz next to each other to avoid torch.cat
if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states
out = mamba_inner_fn(
xz,
self.conv1d.weight,
self.conv1d.bias,
self.x_proj.weight,
self.dt_proj.weight,
self.out_proj.weight,
self.out_proj.bias,
A,
None, # input-dependent B
None, # input-dependent C
self.D.float(),
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
)
else:
x, z = xz.chunk(2, dim=1)
# Compute short convolution
if conv_state is not None:
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
if causal_conv1d_fn is None:
x = self.act(self.conv1d(x)[..., :seqlen])
else:
assert self.activation in ["silu", "swish"]
x = causal_conv1d_fn(
x=x,
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
bias=self.conv1d.bias,
activation=self.activation,
)
# We're careful here about the layout, to avoid extra transposes.
# We want dt to have d as the slowest moving dimension
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = self.dt_proj.weight @ dt.t()
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
assert self.activation in ["silu", "swish"]
y = selective_scan_fn(
x,
dt,
A,
B,
C,
self.D.float(),
z=z,
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
return_last_state=ssm_state is not None,
)
if ssm_state is not None:
y, last_state = y
ssm_state.copy_(last_state)
y = rearrange(y, "b d l -> b l d")
out = self.out_proj(y)
return out
def step(self, hidden_states, conv_state, ssm_state):
dtype = hidden_states.dtype
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
x, z = xz.chunk(2, dim=-1) # (B D)
# Conv step
if causal_conv1d_update is None:
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
conv_state[:, :, -1] = x
x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
if self.conv1d.bias is not None:
x = x + self.conv1d.bias
x = self.act(x).to(dtype=dtype)
else:
x = causal_conv1d_update(
x,
conv_state,
rearrange(self.conv1d.weight, "d 1 w -> d w"),
self.conv1d.bias,
self.activation,
)
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
# Don't add dt_bias here
dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
# SSM step
if selective_state_update is None:
# Discretize A and B
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
dB = torch.einsum("bd,bn->bdn", dt, B)
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
y = y + self.D.to(dtype) * x
y = y * self.act(z) # (B D)
else:
y = selective_state_update(
ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
)
out = self.out_proj(y)
return out.unsqueeze(1), conv_state, ssm_state
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
device = self.out_proj.weight.device
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
conv_state = torch.zeros(
batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
)
ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
# ssm_dtype = torch.float32
ssm_state = torch.zeros(
batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
)
return conv_state, ssm_state
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
assert self.layer_idx is not None
if self.layer_idx not in inference_params.key_value_memory_dict:
batch_shape = (batch_size,)
conv_state = torch.zeros(
batch_size,
self.d_model * self.expand,
self.d_conv,
device=self.conv1d.weight.device,
dtype=self.conv1d.weight.dtype,
)
ssm_state = torch.zeros(
batch_size,
self.d_model * self.expand,
self.d_state,
device=self.dt_proj.weight.device,
dtype=self.dt_proj.weight.dtype,
# dtype=torch.float32,
)
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
else:
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
# TODO: What if batch size changes between generation, and we reuse the same states?
if initialize_states:
conv_state.zero_()
ssm_state.zero_()
return conv_state, ssm_state
class Block(nn.Module):
def __init__(
self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
):
"""
Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
This Block has a slightly different structure compared to a regular
prenorm Transformer block.
The standard block is: LN -> MHA/MLP -> Add.
[Ref: https://arxiv.org/abs/2002.04745]
Here we have: Add -> LN -> Mixer, returning both
the hidden_states (output of the mixer) and the residual.
This is purely for performance reasons, as we can fuse add and LayerNorm.
The residual needs to be provided (except for the very first block).
"""
super().__init__()
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.mixer = mixer_cls(dim)
self.norm = norm_cls(dim)
if self.fused_add_norm:
assert RMSNorm is not None, "RMSNorm import fails"
assert isinstance(
self.norm, (nn.LayerNorm, RMSNorm)
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
def forward(
self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
):
r"""Pass the input through the encoder layer.
Args:
hidden_states: the sequence to the encoder layer (required).
residual: hidden_states = Mixer(LN(residual))
"""
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
hidden_states, residual = fused_add_norm_fn(
hidden_states,
self.norm.weight,
self.norm.bias,
residual=residual,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm.eps,
)
hidden_states = self.mixer(hidden_states, inference_params=inference_params)
return hidden_states, residual
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

View File

@ -1,372 +0,0 @@
# Copyright (c) 2023, Tri Dao, Albert Gu.
import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
from einops import rearrange, repeat
try:
from causal_conv1d import causal_conv1d_fn
import causal_conv1d_cuda
except ImportError:
causal_conv1d_fn = None
causal_conv1d_cuda = None
# import selective_scan_cuda
class SelectiveScanFn(torch.autograd.Function):
@staticmethod
def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
return_last_state=False):
if u.stride(-1) != 1:
u = u.contiguous()
if delta.stride(-1) != 1:
delta = delta.contiguous()
if D is not None:
D = D.contiguous()
if B.stride(-1) != 1:
B = B.contiguous()
if C.stride(-1) != 1:
C = C.contiguous()
if z is not None and z.stride(-1) != 1:
z = z.contiguous()
if B.dim() == 3:
B = rearrange(B, "b dstate l -> b 1 dstate l")
ctx.squeeze_B = True
if C.dim() == 3:
C = rearrange(C, "b dstate l -> b 1 dstate l")
ctx.squeeze_C = True
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
ctx.delta_softplus = delta_softplus
ctx.has_z = z is not None
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
if not ctx.has_z:
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
return out if not return_last_state else (out, last_state)
else:
ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
out_z = rest[0]
return out_z if not return_last_state else (out_z, last_state)
@staticmethod
def backward(ctx, dout, *args):
if not ctx.has_z:
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
z = None
out = None
else:
u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
if dout.stride(-1) != 1:
dout = dout.contiguous()
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
# backward of selective_scan_cuda with the backward of chunk).
# Here we just pass in None and dz will be allocated in the C++ code.
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
False # option to recompute out_z, not used here
)
dz = rest[0] if ctx.has_z else None
dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
return (du, ddelta, dA, dB, dC,
dD if D is not None else None,
dz,
ddelta_bias if delta_bias is not None else None,
None,
None)
# def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
# return_last_state=False):
# """if return_last_state is True, returns (out, last_state)
# last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
# not considered in the backward pass.
# """
# return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
return_last_state=False):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
not considered in the backward pass.
"""
return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
return_last_state=False):
"""
u: r(B D L)
delta: r(B D L)
A: c(D N) or r(D N)
B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
D: r(D)
z: r(B D L)
delta_bias: r(D), fp32
out: r(B D L)
last_state (optional): r(B D dstate) or c(B D dstate)
"""
dtype_in = u.dtype
u = u.float()
delta = delta.float()
if delta_bias is not None:
delta = delta + delta_bias[..., None].float()
if delta_softplus:
delta = F.softplus(delta)
batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
is_variable_B = B.dim() >= 3
is_variable_C = C.dim() >= 3
if A.is_complex():
if is_variable_B:
B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
if is_variable_C:
C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
else:
B = B.float()
C = C.float()
x = A.new_zeros((batch, dim, dstate))
ys = []
deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
if not is_variable_B:
deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
else:
if B.dim() == 3:
deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
else:
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
if is_variable_C and C.dim() == 4:
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
last_state = None
for i in range(u.shape[2]):
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
if not is_variable_C:
y = torch.einsum('bdn,dn->bd', x, C)
else:
if C.dim() == 3:
y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
else:
y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
if i == u.shape[2] - 1:
last_state = x
if y.is_complex():
y = y.real * 2
ys.append(y)
y = torch.stack(ys, dim=2) # (batch dim L)
out = y if D is None else y + u * rearrange(D, "d -> d 1")
if z is not None:
out = out * F.silu(z)
out = out.to(dtype=dtype_in)
return out if not return_last_state else (out, last_state)
class MambaInnerFn(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
out_proj_weight, out_proj_bias,
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
"""
xz: (batch, dim, seqlen)
"""
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
assert checkpoint_lvl in [0, 1]
L = xz.shape[-1]
delta_rank = delta_proj_weight.shape[1]
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
if torch.is_autocast_enabled():
x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
if out_proj_bias is not None else None)
if xz.stride(-1) != 1:
xz = xz.contiguous()
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
x, z = xz.chunk(2, dim=1)
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
x, conv1d_weight, conv1d_bias, None, None, None, True
)
# We're being very careful here about the layout, to avoid extra transposes.
# We want delta to have d as the slowest moving dimension
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
ctx.is_variable_B = B is None
ctx.is_variable_C = C is None
ctx.B_proj_bias_is_None = B_proj_bias is None
ctx.C_proj_bias_is_None = C_proj_bias is None
if B is None: # variable B
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
if B_proj_bias is not None:
B = B + B_proj_bias.to(dtype=B.dtype)
if not A.is_complex():
# B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
else:
B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
else:
if B.stride(-1) != 1:
B = B.contiguous()
if C is None: # variable C
C = x_dbl[:, -d_state:] # (bl dstate)
if C_proj_bias is not None:
C = C + C_proj_bias.to(dtype=C.dtype)
if not A.is_complex():
# C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
else:
C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
else:
if C.stride(-1) != 1:
C = C.contiguous()
if D is not None:
D = D.contiguous()
out, scan_intermediates, out_z = selective_scan_cuda.fwd(
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
)
ctx.delta_softplus = delta_softplus
ctx.out_proj_bias_is_None = out_proj_bias is None
ctx.checkpoint_lvl = checkpoint_lvl
if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
conv1d_out, delta = None, None
ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
delta_proj_weight, out_proj_weight, conv1d_out, delta,
A, B, C, D, delta_bias, scan_intermediates, out)
return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
@staticmethod
@custom_bwd
def backward(ctx, dout):
# dout: (batch, seqlen, dim)
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
L = xz.shape[-1]
delta_rank = delta_proj_weight.shape[1]
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
x, z = xz.chunk(2, dim=1)
if dout.stride(-1) != 1:
dout = dout.contiguous()
if ctx.checkpoint_lvl == 1:
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
x, conv1d_weight, conv1d_bias, None, None, None, True
)
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
"d (b l) -> b d l", l = L)
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
# backward of selective_scan_cuda with the backward of chunk).
dxz = torch.empty_like(xz) # (batch, dim, seqlen)
dx, dz = dxz.chunk(2, dim=1)
dout = rearrange(dout, "b l e -> e (b l)")
dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
ctx.delta_softplus,
True # option to recompute out_z
)
dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
dD = dD if D is not None else None
dx_dbl = torch.empty_like(x_dbl)
dB_proj_bias = None
if ctx.is_variable_B:
if not A.is_complex():
dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
else:
dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
dB = None
dC_proj_bias = None
if ctx.is_variable_C:
if not A.is_complex():
dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
else:
dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
dx_dbl[:, -d_state:] = dC # (bl d)
dC = None
ddelta = rearrange(ddelta, "b d l -> d (b l)")
ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
# backward of conv1d with the backward of chunk).
dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
)
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
dout_proj_weight, dout_proj_bias,
dA, dB, dC, dD,
ddelta_bias if delta_bias is not None else None,
dB_proj_bias, dC_proj_bias, None)
# def mamba_inner_fn(
# xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
# out_proj_weight, out_proj_bias,
# A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
# C_proj_bias=None, delta_softplus=True
# ):
# return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
# out_proj_weight, out_proj_bias,
# A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
def mamba_inner_fn(
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
out_proj_weight, out_proj_bias,
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
C_proj_bias=None, delta_softplus=True
):
return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
out_proj_weight, out_proj_bias,
A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
def mamba_inner_ref(
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
out_proj_weight, out_proj_bias,
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
C_proj_bias=None, delta_softplus=True
):
assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
L = xz.shape[-1]
delta_rank = delta_proj_weight.shape[1]
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
x, z = xz.chunk(2, dim=1)
x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
# We're being very careful here about the layout, to avoid extra transposes.
# We want delta to have d as the slowest moving dimension
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
delta = rearrange(delta, "d (b l) -> b d l", l=L)
if B is None: # variable B
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d)
if B_proj_bias is not None:
B = B + B_proj_bias.to(dtype=B.dtype)
if not A.is_complex():
B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
else:
B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
if C is None: # variable B
C = x_dbl[:, -d_state:] # (bl d)
if C_proj_bias is not None:
C = C + C_proj_bias.to(dtype=C.dtype)
if not A.is_complex():
C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
else:
C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)

View File

@ -1,635 +0,0 @@
# Copyright (c) 2023, Tri Dao.
# Implement residual + layer_norm / rms_norm.
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
import math
import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_fwd, custom_bwd
import triton
import triton.language as tl
def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
dtype = x.dtype
if upcast:
weight = weight.float()
bias = bias.float() if bias is not None else None
if upcast:
x = x.float()
residual = residual.float() if residual is not None else residual
if residual is not None:
x = (x + residual).to(x.dtype)
out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
dtype
)
return out if not prenorm else (out, x)
def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
dtype = x.dtype
if upcast:
weight = weight.float()
bias = bias.float() if bias is not None else None
if upcast:
x = x.float()
residual = residual.float() if residual is not None else residual
if residual is not None:
x = (x + residual).to(x.dtype)
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
out = out.to(dtype)
return out if not prenorm else (out, x)
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
RESIDUAL, # pointer to the residual
RESIDUAL_OUT, # pointer to the residual
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row, # how much to increase the pointer when moving by 1 row
stride_y_row,
stride_res_row,
stride_res_out_row,
N, # number of columns in X
eps, # epsilon to avoid division by zero
IS_RMS_NORM: tl.constexpr,
BLOCK_N: tl.constexpr,
HAS_RESIDUAL: tl.constexpr,
STORE_RESIDUAL_OUT: tl.constexpr,
HAS_BIAS: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X += row * stride_x_row
Y += row * stride_y_row
if HAS_RESIDUAL:
RESIDUAL += row * stride_res_row
if STORE_RESIDUAL_OUT:
RESIDUAL_OUT += row * stride_res_out_row
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_RESIDUAL:
residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
x += residual
if STORE_RESIDUAL_OUT:
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean)
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
mask = cols < N
w = tl.load(W + cols, mask=mask).to(tl.float32)
if HAS_BIAS:
b = tl.load(B + cols, mask=mask).to(tl.float32)
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
y = x_hat * w + b if HAS_BIAS else x_hat * w
# Write output
tl.store(Y + cols, y, mask=mask)
def _layer_norm_fwd(
x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
):
if residual is not None:
residual_dtype = residual.dtype
M, N = x.shape
assert x.stride(-1) == 1
if residual is not None:
assert residual.stride(-1) == 1
assert residual.shape == (M, N)
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
# allocate output
y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
assert y.stride(-1) == 1
if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
assert residual_out.stride(-1) == 1
else:
residual_out = None
mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
with torch.cuda.device(x.device.index):
_layer_norm_fwd_1pass_kernel[(M,)](
x,
y,
weight,
bias,
residual,
residual_out,
mean,
rstd,
x.stride(0),
y.stride(0),
residual.stride(0) if residual is not None else 0,
residual_out.stride(0) if residual_out is not None else 0,
N,
eps,
is_rms_norm,
BLOCK_N,
residual is not None,
residual_out is not None,
bias is not None,
)
# residual_out is None if residual is None and residual_dtype == input_dtype
return y, mean, rstd, residual_out if residual_out is not None else x
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
X, # pointer to the input
W, # pointer to the weights
B, # pointer to the biases
Y, # pointer to the output to be recomputed
DY, # pointer to the output gradient
DX, # pointer to the input gradient
DW, # pointer to the partial sum of weights gradient
DB, # pointer to the partial sum of biases gradient
DRESIDUAL,
DRESIDUAL_IN,
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row, # how much to increase the pointer when moving by 1 row
stride_y_row,
stride_dy_row,
stride_dx_row,
stride_dres_row,
stride_dres_in_row,
M, # number of rows in X
N, # number of columns in X
eps, # epsilon to avoid division by zero
rows_per_program,
IS_RMS_NORM: tl.constexpr,
BLOCK_N: tl.constexpr,
HAS_DRESIDUAL: tl.constexpr,
STORE_DRESIDUAL: tl.constexpr,
HAS_BIAS: tl.constexpr,
RECOMPUTE_OUTPUT: tl.constexpr,
):
# Map the program id to the elements of X, DX, and DY it should compute.
row_block_id = tl.program_id(0)
row_start = row_block_id * rows_per_program
cols = tl.arange(0, BLOCK_N)
mask = cols < N
X += row_start * stride_x_row
if HAS_DRESIDUAL:
DRESIDUAL += row_start * stride_dres_row
if STORE_DRESIDUAL:
DRESIDUAL_IN += row_start * stride_dres_in_row
DY += row_start * stride_dy_row
DX += row_start * stride_dx_row
if RECOMPUTE_OUTPUT:
Y += row_start * stride_y_row
w = tl.load(W + cols, mask=mask).to(tl.float32)
if RECOMPUTE_OUTPUT and HAS_BIAS:
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
if HAS_BIAS:
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
row_end = min((row_block_id + 1) * rows_per_program, M)
for row in range(row_start, row_end):
# Load data to SRAM
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
if not IS_RMS_NORM:
mean = tl.load(Mean + row)
rstd = tl.load(Rstd + row)
# Compute dx
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
xhat = tl.where(mask, xhat, 0.0)
if RECOMPUTE_OUTPUT:
y = xhat * w + b if HAS_BIAS else xhat * w
tl.store(Y + cols, y, mask=mask)
wdy = w * dy
dw += dy * xhat
if HAS_BIAS:
db += dy
if not IS_RMS_NORM:
c1 = tl.sum(xhat * wdy, axis=0) / N
c2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat * c1 + c2)) * rstd
else:
c1 = tl.sum(xhat * wdy, axis=0) / N
dx = (wdy - xhat * c1) * rstd
if HAS_DRESIDUAL:
dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
dx += dres
# Write dx
if STORE_DRESIDUAL:
tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
tl.store(DX + cols, dx, mask=mask)
X += stride_x_row
if HAS_DRESIDUAL:
DRESIDUAL += stride_dres_row
if STORE_DRESIDUAL:
DRESIDUAL_IN += stride_dres_in_row
if RECOMPUTE_OUTPUT:
Y += stride_y_row
DY += stride_dy_row
DX += stride_dx_row
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
if HAS_BIAS:
tl.store(DB + row_block_id * N + cols, db, mask=mask)
def _layer_norm_bwd(
dy,
x,
weight,
bias,
eps,
mean,
rstd,
dresidual=None,
has_residual=False,
is_rms_norm=False,
x_dtype=None,
recompute_output=False,
):
M, N = x.shape
assert x.stride(-1) == 1
assert dy.stride(-1) == 1
assert dy.shape == (M, N)
if dresidual is not None:
assert dresidual.stride(-1) == 1
assert dresidual.shape == (M, N)
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
# allocate output
dx = (
torch.empty_like(x)
if x_dtype is None
else torch.empty(M, N, dtype=x_dtype, device=x.device)
)
dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
_dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
_db = (
torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
if bias is not None
else None
)
rows_per_program = math.ceil(M / sm_count)
grid = (sm_count,)
with torch.cuda.device(x.device.index):
_layer_norm_bwd_kernel[grid](
x,
weight,
bias,
y,
dy,
dx,
_dw,
_db,
dresidual,
dresidual_in,
mean,
rstd,
x.stride(0),
0 if not recompute_output else y.stride(0),
dy.stride(0),
dx.stride(0),
dresidual.stride(0) if dresidual is not None else 0,
dresidual_in.stride(0) if dresidual_in is not None else 0,
M,
N,
eps,
rows_per_program,
is_rms_norm,
BLOCK_N,
dresidual is not None,
dresidual_in is not None,
bias is not None,
)
dw = _dw.sum(0).to(weight.dtype)
db = _db.sum(0).to(bias.dtype) if bias is not None else None
# Don't need to compute dresidual_in separately in this case
if has_residual and dx.dtype == x.dtype:
dresidual_in = dx
return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
class LayerNormFn(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x,
weight,
bias,
residual=None,
eps=1e-6,
prenorm=False,
residual_in_fp32=False,
is_rms_norm=False,
):
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
if x.stride(-1) != 1:
x = x.contiguous()
if residual is not None:
assert residual.shape == x_shape_og
residual = residual.reshape(-1, residual.shape[-1])
if residual.stride(-1) != 1:
residual = residual.contiguous()
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
residual_dtype = (
residual.dtype
if residual is not None
else (torch.float32 if residual_in_fp32 else None)
)
y, mean, rstd, residual_out = _layer_norm_fwd(
x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
)
ctx.save_for_backward(residual_out, weight, bias, mean, rstd)
ctx.x_shape_og = x_shape_og
ctx.eps = eps
ctx.is_rms_norm = is_rms_norm
ctx.has_residual = residual is not None
ctx.prenorm = prenorm
ctx.x_dtype = x.dtype
y = y.reshape(x_shape_og)
return y if not prenorm else (y, residual_out.reshape(x_shape_og))
@staticmethod
def backward(ctx, dy, *args):
x, weight, bias, mean, rstd = ctx.saved_tensors
dy = dy.reshape(-1, dy.shape[-1])
if dy.stride(-1) != 1:
dy = dy.contiguous()
assert dy.shape == x.shape
if ctx.prenorm:
dresidual = args[0]
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
if dresidual.stride(-1) != 1:
dresidual = dresidual.contiguous()
assert dresidual.shape == x.shape
else:
dresidual = None
dx, dw, db, dresidual_in = _layer_norm_bwd(
dy,
x,
weight,
bias,
ctx.eps,
mean,
rstd,
dresidual,
ctx.has_residual,
ctx.is_rms_norm,
x_dtype=ctx.x_dtype,
)
return (
dx.reshape(ctx.x_shape_og),
dw,
db,
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
None,
None,
None,
None,
)
def layer_norm_fn(
x,
weight,
bias,
residual=None,
eps=1e-6,
prenorm=False,
residual_in_fp32=False,
is_rms_norm=False,
):
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm)
def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
class RMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.ones_(self.weight)
def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
return rms_norm_fn(
x,
self.weight,
self.bias,
residual=residual,
eps=self.eps,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
)
class LayerNormLinearFn(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(
ctx,
x,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual=None,
eps=1e-6,
prenorm=False,
residual_in_fp32=False,
is_rms_norm=False,
):
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
if x.stride(-1) != 1:
x = x.contiguous()
if residual is not None:
assert residual.shape == x_shape_og
residual = residual.reshape(-1, residual.shape[-1])
if residual.stride(-1) != 1:
residual = residual.contiguous()
norm_weight = norm_weight.contiguous()
if norm_bias is not None:
norm_bias = norm_bias.contiguous()
residual_dtype = (
residual.dtype
if residual is not None
else (torch.float32 if residual_in_fp32 else None)
)
y, mean, rstd, residual_out = _layer_norm_fwd(
x,
norm_weight,
norm_bias,
eps,
residual,
out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
residual_dtype=residual_dtype,
is_rms_norm=is_rms_norm,
)
y = y.reshape(x_shape_og)
dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
linear_weight = linear_weight.to(dtype)
linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
# We don't store y, will be recomputed in the backward pass to save memory
ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
ctx.x_shape_og = x_shape_og
ctx.eps = eps
ctx.is_rms_norm = is_rms_norm
ctx.has_residual = residual is not None
ctx.prenorm = prenorm
ctx.x_dtype = x.dtype
ctx.linear_bias_is_none = linear_bias is None
return out if not prenorm else (out, residual_out.reshape(x_shape_og))
@staticmethod
@custom_bwd
def backward(ctx, dout, *args):
x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
dout = dout.reshape(-1, dout.shape[-1])
dy = F.linear(dout, linear_weight.t())
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
if dy.stride(-1) != 1:
dy = dy.contiguous()
assert dy.shape == x.shape
if ctx.prenorm:
dresidual = args[0]
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
if dresidual.stride(-1) != 1:
dresidual = dresidual.contiguous()
assert dresidual.shape == x.shape
else:
dresidual = None
dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
dy,
x,
norm_weight,
norm_bias,
ctx.eps,
mean,
rstd,
dresidual,
ctx.has_residual,
ctx.is_rms_norm,
x_dtype=ctx.x_dtype,
recompute_output=True,
)
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
return (
dx.reshape(ctx.x_shape_og),
dnorm_weight,
dnorm_bias,
dlinear_weight,
dlinear_bias,
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
None,
None,
None,
None,
)
def layer_norm_linear_fn(
x,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual=None,
eps=1e-6,
prenorm=False,
residual_in_fp32=False,
is_rms_norm=False,
):
return LayerNormLinearFn.apply(
x,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual,
eps,
prenorm,
residual_in_fp32,
is_rms_norm,
)

View File

@ -1,192 +0,0 @@
# Copyright (c) 2023, Tri Dao.
"""We want triton==2.1.0 for this
"""
import math
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange, repeat
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
@triton.jit
def _selective_scan_update_kernel(
# Pointers to matrices
state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
# Matrix dimensions
batch, dim, dstate,
# Strides
stride_state_batch, stride_state_dim, stride_state_dstate,
stride_x_batch, stride_x_dim,
stride_dt_batch, stride_dt_dim,
stride_dt_bias_dim,
stride_A_dim, stride_A_dstate,
stride_B_batch, stride_B_dstate,
stride_C_batch, stride_C_dstate,
stride_D_dim,
stride_z_batch, stride_z_dim,
stride_out_batch, stride_out_dim,
# Meta-parameters
DT_SOFTPLUS: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
HAS_DT_BIAS: tl.constexpr,
HAS_D: tl.constexpr,
HAS_Z: tl.constexpr,
BLOCK_SIZE_DSTATE: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_b = tl.program_id(axis=1)
state_ptr += pid_b * stride_state_batch
x_ptr += pid_b * stride_x_batch
dt_ptr += pid_b * stride_dt_batch
B_ptr += pid_b * stride_B_batch
C_ptr += pid_b * stride_C_batch
if HAS_Z:
z_ptr += pid_b * stride_z_batch
out_ptr += pid_b * stride_out_batch
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
x_ptrs = x_ptr + offs_m * stride_x_dim
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
if HAS_DT_BIAS:
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
B_ptrs = B_ptr + offs_n * stride_B_dstate
C_ptrs = C_ptr + offs_n * stride_C_dstate
if HAS_D:
D_ptrs = D_ptr + offs_m * stride_D_dim
if HAS_Z:
z_ptrs = z_ptr + offs_m * stride_z_dim
out_ptrs = out_ptr + offs_m * stride_out_dim
state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if DT_SOFTPLUS:
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
dA = tl.exp(A * dt[:, None])
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
if HAS_D:
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_Z:
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
dB = B[None, :] * dt[:, None]
state = state * dA + dB * x[:, None]
tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
out = tl.sum(state * C[None, :], axis=1)
if HAS_D:
out += x * D
if HAS_Z:
out *= z * tl.sigmoid(z)
tl.store(out_ptrs, out, mask=offs_m < dim)
def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
"""
Argument:
state: (batch, dim, dstate)
x: (batch, dim)
dt: (batch, dim)
A: (dim, dstate)
B: (batch, dstate)
C: (batch, dstate)
D: (dim,)
z: (batch, dim)
dt_bias: (dim,)
Return:
out: (batch, dim)
"""
batch, dim, dstate = state.shape
assert x.shape == (batch, dim)
assert dt.shape == x.shape
assert A.shape == (dim, dstate)
assert B.shape == (batch, dstate)
assert C.shape == B.shape
if D is not None:
assert D.shape == (dim,)
if z is not None:
assert z.shape == x.shape
if dt_bias is not None:
assert dt_bias.shape == (dim,)
out = torch.empty_like(x)
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)
z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))
# We don't want autotune since it will overwrite the state
# We instead tune by hand.
BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
else ((16, 4) if dstate <= 32 else
((8, 4) if dstate <= 64 else
((4, 4) if dstate <= 128 else
((4, 8))))))
with torch.cuda.device(x.device.index):
_selective_scan_update_kernel[grid](
state, x, dt, dt_bias, A, B, C, D, z, out,
batch, dim, dstate,
state.stride(0), state.stride(1), state.stride(2),
x.stride(0), x.stride(1),
dt.stride(0), dt.stride(1),
dt_bias.stride(0) if dt_bias is not None else 0,
A.stride(0), A.stride(1),
B.stride(0), B.stride(1),
C.stride(0), C.stride(1),
D.stride(0) if D is not None else 0,
z_strides[0], z_strides[1],
out.stride(0), out.stride(1),
dt_softplus,
BLOCK_SIZE_M,
num_warps=num_warps,
)
return out
def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
"""
Argument:
state: (batch, dim, dstate)
x: (batch, dim)
dt: (batch, dim)
A: (dim, dstate)
B: (batch, dstate)
C: (batch, dstate)
D: (dim,)
z: (batch, dim)
dt_bias: (dim,)
Return:
out: (batch, dim)
"""
batch, dim, dstate = state.shape
assert x.shape == (batch, dim)
assert dt.shape == x.shape
assert A.shape == (dim, dstate)
assert B.shape == (batch, dstate)
assert C.shape == B.shape
if D is not None:
assert D.shape == (dim,)
if z is not None:
assert z.shape == x.shape
if dt_bias is not None:
assert dt_bias.shape == (dim,)
dt = dt + dt_bias
dt = F.softplus(dt) if dt_softplus else dt
dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate)
dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n") # (batch, dim, dstate)
state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1")) # (batch, dim, dstate
out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C)
if D is not None:
out += (x * D).to(out.dtype)
return (out if z is None else out * F.silu(z)).to(x.dtype)

View File

@ -1,387 +0,0 @@
# Copyright (c) 2023, Albert Gu, Tri Dao.
import gc
import time
from collections import namedtuple
from dataclasses import dataclass, field
from functools import partial
from typing import Callable, Optional, Sequence, Union
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
@dataclass
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficienly calculate and store the context during inference."""
max_seqlen: int
max_batch_size: int
seqlen_offset: int = 0
batch_size_offset: int = 0
key_value_memory_dict: dict = field(default_factory=dict)
lengths_per_sample: Optional[Tensor] = None
def reset(self, max_seqlen, max_batch_size):
self.max_seqlen = max_seqlen
self.max_batch_size = max_batch_size
self.seqlen_offset = 0
if self.lengths_per_sample is not None:
self.lengths_per_sample.zero_()
def modify_logits_for_min_p_filtering(logits, min_p):
"""Set the logits for none min_p values to -inf. Done in-place."""
if min_p <= 0.0 or min_p >= 1.0:
return
indices_to_remove = logits < min_p
logits.masked_fill_(indices_to_remove, float("-Inf"))
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
def modify_logits_for_top_k_filtering(logits, top_k):
"""Set the logits for none top-k values to -inf. Done in-place."""
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits.masked_fill_(indices_to_remove, float("-Inf"))
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
"""Set the logits for none top-p values to -inf. Done in-place."""
if top_p <= 0.0 or top_p >= 1.0:
return
# First sort and calculate cumulative sum of probabilities.
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
logits.masked_fill_(indices_to_remove, float("-inf"))
def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
"""Apply repetition penalty. See https://arxiv.org/abs/1909.05858
logits: (batch_size, vocab_size)
prev_output_tokens: (batch_size, seq_len)
"""
if repetition_penalty == 1.0:
return logits
score = torch.gather(logits, 1, prev_output_tokens)
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
logits.scatter_(1, prev_output_tokens, score)
return logits
def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
"""Sample from top-k logits.
Arguments:
logits: Tensor of shape (batch_size, vocab_size)
"""
if top_k == 1: # Short-circuit for greedy decoding
return logits.argmax(dim=-1)
else:
if top_p > 0.0:
assert top_p <= 1.0, "top-p should be in (0, 1]."
if top_k > 0:
top_k = min(top_k, logits.size(-1)) # Safety check
logits_top, indices = torch.topk(logits, top_k, dim=-1)
if temperature != 1.0:
logits_top /= temperature
modify_logits_for_top_p_filtering(logits_top, top_p)
return indices[
torch.arange(indices.shape[0], device=indices.device),
torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
]
else:
if min_p > 0.0:
logits_top = logits.clone()
max_prob = logits_top[..., 0].item()
min_prob = max_prob * min_p
modify_logits_for_min_p_filtering(logits_top, min_p)
if temperature != 1.0:
logits_top /= temperature
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
# Clone so that when we modify for top_p we don't change the original logits
logits_top = logits / temperature if temperature != 1.0 else logits.clone()
modify_logits_for_top_p_filtering(logits_top, top_p)
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
dim=-1
)
@torch.inference_mode()
def decode(
input_ids,
model,
max_length,
top_k=1,
top_p=0.0,
min_p=0.0,
temperature=1.0,
repetition_penalty=1.0,
eos_token_id=None,
teacher_outputs=None,
vocab_size=None,
cg=False,
enable_timing=False,
streamer: Optional[TextStreamer] = None
):
"""Decoding, either greedy or with top-k or top-p sampling.
If top-k = 0, don't limit the number of candidates (pure sampling).
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
then top-p.
We assume that all sequences in the same batch have the same length.
Arguments:
input_ids: (batch, seq_len)
max_length: int
teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
logits, the next token is taken from the teacher_outputs. Useful for testing.
Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
sequences: (batch, max_length)
scores: tuples of (batch, vocab_size)
"""
if streamer is not None:
streamer.put(input_ids.cpu())
batch_size, seqlen_og = input_ids.shape
teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
if cg:
if not hasattr(model, "_decoding_cache"):
model._decoding_cache = None
model._decoding_cache = update_graph_cache(
model,
model._decoding_cache,
batch_size,
seqlen_og,
max_length,
)
inference_params = model._decoding_cache.inference_params
inference_params.reset(max_length, batch_size)
else:
inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
def get_logits(input_ids, inference_params):
decoding = inference_params.seqlen_offset > 0
if decoding:
position_ids = torch.full(
(batch_size, 1),
inference_params.seqlen_offset,
dtype=torch.long,
device=input_ids.device,
)
else:
position_ids = None
if not cg or not decoding:
logits = model(
input_ids,
position_ids=position_ids,
inference_params=inference_params,
num_last_tokens=1,
).logits.squeeze(dim=1)
else:
logits = model._decoding_cache.run(
input_ids, position_ids, inference_params.seqlen_offset
).squeeze(dim=1)
return logits[..., :vocab_size] if vocab_size is not None else logits
def sample_tokens(logits, inference_params):
if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
else:
token = teacher_outputs[:, inference_params.seqlen_offset]
# return rearrange(token, "b -> b 1")
return token.unsqueeze(1)
def should_stop(current_token, inference_params):
if inference_params.seqlen_offset == 0:
return False
if eos_token_id is not None and (current_token == eos_token_id).all():
return True
if inference_params.seqlen_offset >= max_length - 1:
return True
return False
start = torch.cuda.Event(enable_timing=enable_timing)
end = torch.cuda.Event(enable_timing=enable_timing)
if enable_timing:
start.record()
scores, sequences = [], [input_ids]
sequences_cat = input_ids
while not should_stop(sequences[-1], inference_params):
scores.append(get_logits(sequences[-1], inference_params))
inference_params.seqlen_offset += sequences[-1].shape[1]
if repetition_penalty == 1.0:
sampled_tokens = sample_tokens(scores[-1], inference_params)
else:
logits = modify_logit_for_repetition_penalty(
scores[-1].clone(), sequences_cat, repetition_penalty
)
sampled_tokens = sample_tokens(logits, inference_params)
sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
sequences.append(sampled_tokens)
if streamer is not None:
streamer.put(sampled_tokens.cpu())
if streamer is not None:
streamer.end()
if enable_timing:
end.record()
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
class GenerationMixin:
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
raise NotImplementedError
def generate(
self,
input_ids,
max_length,
top_k=1,
top_p=0.0,
min_p=0.0,
temperature=1.0,
return_dict_in_generate=False,
output_scores=False,
**kwargs,
):
output = decode(
input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs
)
if not output_scores:
output.scores = None
return output if return_dict_in_generate else output.sequences
@dataclass
class DecodingCGCache:
max_batch_size: int = 0
max_seqlen: int = 0
device = None
dtype = None
callables: dict = field(default_factory=dict)
mempool = None
inference_params: Optional[InferenceParams] = None
run: Optional[Callable] = None
@torch.inference_mode()
def update_graph_cache(
model,
cache,
batch_size,
seqlen_og,
max_seqlen,
decoding_seqlens=(1,),
dtype=None,
n_warmups=2,
):
if cache is None:
cache = DecodingCGCache()
param_example = next(iter(model.parameters()))
device = param_example.device
if dtype is None:
dtype = param_example.dtype
if (
(device, dtype) != (cache.device, cache.dtype)
or batch_size > cache.max_batch_size
or max_seqlen > cache.max_seqlen
): # Invalidate the cache
cache.callables = {}
cache.mempool = None
cache.inference_params = None
gc.collect()
cache.device, cache.dtype = device, dtype
cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
cache.inference_params = InferenceParams(
max_seqlen=max_seqlen,
max_batch_size=batch_size,
seqlen_offset=seqlen_og,
key_value_memory_dict=inf_cache,
lengths_per_sample=lengths_per_sample,
)
cache.mempool = torch.cuda.graphs.graph_pool_handle()
for decoding_seqlen in decoding_seqlens:
if (batch_size, decoding_seqlen) not in cache.callables:
cache.callables[batch_size, decoding_seqlen] = capture_graph(
model,
cache.inference_params,
batch_size,
max_seqlen,
decoding_seqlen=decoding_seqlen,
mempool=cache.mempool,
n_warmups=n_warmups,
)
def dispatch(input_ids, position_ids, seqlen):
batch_size, decoding_seqlen = input_ids.shape[:2]
return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
cache.run = dispatch
cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing
return cache
def capture_graph(
model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
):
device = next(iter(model.parameters())).device
input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
seqlen_offset_og = inference_params.seqlen_offset
inference_params.seqlen_offset = max_seqlen - decoding_seqlen
inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
# Warmup before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(n_warmups):
logits = model(
input_ids,
position_ids=position_ids,
inference_params=inference_params,
num_last_tokens=decoding_seqlen,
).logits
s.synchronize()
# This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
# which requires that graph launch and non-captured launch to not overlap (I think,
# that's how I interpret the documentation). I'm not sure if this is required.
if torch.distributed.is_initialized():
torch.distributed.barrier()
torch.cuda.current_stream().wait_stream(s)
# Captures the graph
# To allow capture, automatically sets a side stream as the current stream in the context
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=mempool):
logits = model(
input_ids,
position_ids=position_ids,
inference_params=inference_params,
num_last_tokens=decoding_seqlen,
).logits
def run(new_input_ids, new_position_ids, seqlen):
inference_params.lengths_per_sample[:] = seqlen
input_ids.copy_(new_input_ids)
position_ids.copy_(new_position_ids)
graph.replay()
return logits.clone()
inference_params.seqlen_offset = seqlen_offset_og
return run

View File

@ -1,23 +0,0 @@
import json
import torch
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file
def load_config_hf(model_name):
resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
return json.load(open(resolved_archive_file))
def load_state_dict_hf(model_name, device=None, dtype=None):
# If not fp32, then we don't want to load directly to the GPU
mapped_device = "cpu" if dtype not in [torch.float32, None] else device
resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
return torch.load(resolved_archive_file, map_location=mapped_device)
# Convert dtype before moving to GPU to save memory
if dtype is not None:
state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
return state_dict

126
net.py
View File

@ -6,10 +6,7 @@ import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange
from componets.WTConvCV2 import WTConv2d
# 以一定概率随机丢弃输入张量中的路径,用于正则化模型
def drop_path(x, drop_prob: float = 0., training: bool = False):
if drop_prob == 0. or not training:
return x
@ -35,9 +32,6 @@ class DropPath(nn.Module):
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
# 改点使用Pooling替换AttentionBase
class Pooling(nn.Module):
def __init__(self, kernel_size=3):
super().__init__()
@ -50,8 +44,8 @@ class Pooling(nn.Module):
class PoolMlp(nn.Module):
"""
实现基于1x1卷积的MLP模块
输入形状为[B, C, H, W]的张量
Implementation of MLP with 1*1 convolutions.
Input: tensor with shape [B, C, H, W]
"""
def __init__(self,
@ -61,17 +55,6 @@ class PoolMlp(nn.Module):
act_layer=nn.GELU,
bias=False,
drop=0.):
"""
初始化PoolMlp模块
参数:
in_features (int): 输入特征的数量
hidden_features (int, 可选): 隐藏层特征的数量默认为None设置为与in_features相同
out_features (int, 可选): 输出特征的数量默认为None设置为与in_features相同
act_layer (nn.Module, 可选): 使用的激活层默认为nn.GELU
bias (bool, 可选): 是否在卷积层中包含偏置项默认为False
drop (float, 可选): Dropout比率默认为0
"""
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
@ -81,15 +64,6 @@ class PoolMlp(nn.Module):
self.drop = nn.Dropout(drop)
def forward(self, x):
"""
通过PoolMlp模块的前向传播
参数:
x (torch.Tensor): 形状为[B, C, H, W]的输入张量
返回:
torch.Tensor: 形状为[B, C, H, W]的输出张量
"""
x = self.fc1(x) # (B, C, H, W) --> (B, C, H, W)
x = self.act(x)
x = self.drop(x)
@ -97,55 +71,6 @@ class PoolMlp(nn.Module):
x = self.drop(x)
return x
# class BaseFeatureExtraction1(nn.Module):
# def __init__(self, dim, pool_size=3, mlp_ratio=4.,
# act_layer=nn.GELU,
# # norm_layer=nn.LayerNorm,
# drop=0., drop_path=0.,
# use_layer_scale=True, layer_scale_init_value=1e-5):
#
# super().__init__()
#
# self.norm1 = LayerNorm(dim, 'WithBias')
# self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
# self.norm2 = LayerNorm(dim, 'WithBias')
# mlp_hidden_dim = int(dim * mlp_ratio)
# self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
# act_layer=act_layer, drop=drop)
#
# # The following two techniques are useful to train deep PoolFormers.
# self.drop_path = DropPath(drop_path) if drop_path > 0. \
# else nn.Identity()
# self.use_layer_scale = use_layer_scale
#
# if use_layer_scale:
# self.layer_scale_1 = nn.Parameter(
# torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
#
# self.layer_scale_2 = nn.Parameter(
# torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
#
# def forward(self, x): # 1 64 128 128
# if self.use_layer_scale:
# # self.layer_scale_1(64,)
# tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
# normal = self.norm1(x) # 1 64 128 128
# token_mix = self.token_mixer(normal) # 1 64 128 128
# x = (x +
# self.drop_path(
# tmp1 * token_mix
# )
# # 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
# )
# x = x + self.drop_path(
# self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
# * self.poolmlp(self.norm2(x)))
# else:
# x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
# x = x + self.drop_path(self.poolmlp(self.norm2(x)))
# return x
class BaseFeatureExtraction(nn.Module):
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
act_layer=nn.GELU,
@ -155,7 +80,6 @@ class BaseFeatureExtraction(nn.Module):
super().__init__()
self.WTConv2d = WTConv2d(dim, dim)
self.norm1 = LayerNorm(dim, 'WithBias')
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
self.norm2 = LayerNorm(dim, 'WithBias')
@ -175,29 +99,20 @@ class BaseFeatureExtraction(nn.Module):
self.layer_scale_2 = nn.Parameter(
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
def forward(self, x): # 1 64 128 128
def forward(self, x):
if self.use_layer_scale:
# self.layer_scale_1(64,)
tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
normal = self.norm1(x) # 1 64 128 128
token_mix = self.token_mixer(normal) # 1 64 128 128
x = self.WTConv2d(x)
x = (x +
self.drop_path(
tmp1 * token_mix
)
# 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
)
x = x + self.drop_path(
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
* self.token_mixer(self.norm1(x)))
x = x + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
* self.poolmlp(self.norm2(x)))
else:
x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
x = x + self.drop_path(self.token_mixer(self.norm1(x)))
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
return x
class InvertedResidualBlock(nn.Module):
def __init__(self, inp, oup, expand_ratio):
super(InvertedResidualBlock, self).__init__()
@ -216,12 +131,12 @@ class InvertedResidualBlock(nn.Module):
nn.Conv2d(hidden_dim, oup, 1, bias=False),
# nn.BatchNorm2d(oup),
)
def forward(self, x):
return self.bottleneckBlock(x)
class DetailNode(nn.Module):
# <img src = "http://42.192.130.83:9000/picgo/imgs/小绿鲸英文文献阅读器_ELTITYqm5G.png" / > '
class DetailNode(nn.Module):
def __init__(self):
super(DetailNode, self).__init__()
@ -248,24 +163,14 @@ class DetailFeatureExtraction(nn.Module):
super(DetailFeatureExtraction, self).__init__()
INNmodules = [DetailNode() for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules)
self.enhancement_module = nn.Sequential(
nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
)
def forward(self, x): # 1 64 128 128
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
# 增强并添加残差连接
enhanced_z1 = self.enhancement_module(z1)
enhanced_z2 = self.enhancement_module(z2)
# 残差连接
z1 = z1 + enhanced_z1
z2 = z2 + enhanced_z2
def forward(self, x):
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
for layer in self.net:
z1, z2 = layer(z1, z2)
return torch.cat((z1, z2), dim=1)
# =============================================================================
# =============================================================================
@ -472,7 +377,8 @@ class Restormer_Decoder(nn.Module):
super(Restormer_Decoder, self).__init__()
self.reduce_channel = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias)
self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
self.encoder_level2 = nn.Sequential(
*[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
self.output = nn.Sequential(
nn.Conv2d(int(dim), int(dim) // 2, kernel_size=3,
@ -499,5 +405,3 @@ if __name__ == '__main__':
window_size = 8
modelE = Restormer_Encoder().cuda()
modelD = Restormer_Decoder().cuda()
print(modelE)
print(modelD)

View File

@ -1,411 +0,0 @@
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
# work with diff dim tensors, not just 2D ConvNets
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + \
torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class AttentionBase(nn.Module):
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,):
super(AttentionBase, self).__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = nn.Parameter(torch.ones(num_heads, 1, 1))
self.qkv1 = nn.Conv2d(dim, dim*3, kernel_size=1, bias=qkv_bias)
self.qkv2 = nn.Conv2d(dim*3, dim*3, kernel_size=3, padding=1, bias=qkv_bias)
self.proj = nn.Conv2d(dim, dim, kernel_size=1, bias=qkv_bias)
def forward(self, x):
# [batch_size, num_patches + 1, total_embed_dim]
b, c, h, w = x.shape
qkv = self.qkv2(self.qkv1(x))
q, k, v = qkv.chunk(3, dim=1)
q = rearrange(q, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)
# transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
# @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
out = (attn @ v)
out = rearrange(out, 'b head c (h w) -> b (head c) h w',
head=self.num_heads, h=h, w=w)
out = self.proj(out)
return out
class Mlp(nn.Module):
"""
MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self,
in_features,
hidden_features=None,
ffn_expansion_factor = 2,
bias = False):
super().__init__()
hidden_features = int(in_features*ffn_expansion_factor)
self.project_in = nn.Conv2d(
in_features, hidden_features*2, kernel_size=1, bias=bias)
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3,
stride=1, padding=1, groups=hidden_features, bias=bias)
self.project_out = nn.Conv2d(
hidden_features, in_features, kernel_size=1, bias=bias)
def forward(self, x):
x = self.project_in(x)
x1, x2 = self.dwconv(x).chunk(2, dim=1)
x = F.gelu(x1) * x2
x = self.project_out(x)
return x
class BaseFeatureExtraction(nn.Module):
def __init__(self,
dim,
num_heads,
ffn_expansion_factor=1.,
qkv_bias=False,):
super(BaseFeatureExtraction, self).__init__()
self.norm1 = LayerNorm(dim, 'WithBias')
# https://zhuanlan.zhihu.com/p/444887088#:~:text=%E5%9C%A8%E6%9C%AC%E6%96%87%E4%B8%AD%EF%BC%8C%E6%88%91%E4%BB%AC%E6%8F%90%E5%87%BA%E4%BA%86
self.attn = AttentionBase(dim, num_heads=num_heads, qkv_bias=qkv_bias,)
self.norm2 = LayerNorm(dim, 'WithBias')
self.mlp = Mlp(in_features=dim,
ffn_expansion_factor=ffn_expansion_factor,)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
class InvertedResidualBlock(nn.Module):
def __init__(self, inp, oup, expand_ratio):
super(InvertedResidualBlock, self).__init__()
hidden_dim = int(inp * expand_ratio)
self.bottleneckBlock = nn.Sequential(
# pw
nn.Conv2d(inp, hidden_dim, 1, bias=False),
# nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
# dw
nn.ReflectionPad2d(1),
nn.Conv2d(hidden_dim, hidden_dim, 3, groups=hidden_dim, bias=False),
# nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, bias=False),
# nn.BatchNorm2d(oup),
)
def forward(self, x):
return self.bottleneckBlock(x)
class DetailNode(nn.Module):
def __init__(self):
super(DetailNode, self).__init__()
# Scale is Ax + b, i.e. affine transformation
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.shffleconv = nn.Conv2d(64, 64, kernel_size=1,
stride=1, padding=0, bias=True)
def separateFeature(self, x):
z1, z2 = x[:, :x.shape[1]//2], x[:, x.shape[1]//2:x.shape[1]]
return z1, z2
def forward(self, z1, z2):
z1, z2 = self.separateFeature(
self.shffleconv(torch.cat((z1, z2), dim=1)))
z2 = z2 + self.theta_phi(z1)
z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2)
return z1, z2
class DetailFeatureExtraction(nn.Module):
def __init__(self, num_layers=3):
super(DetailFeatureExtraction, self).__init__()
INNmodules = [DetailNode() for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules)
def forward(self, x):
z1, z2 = x[:, :x.shape[1]//2], x[:, x.shape[1]//2:x.shape[1]]
for layer in self.net:
z1, z2 = layer(z1, z2)
return torch.cat((z1, z2), dim=1)
if __name__ == '__main__':
x = torch.randn(1, 64, 256, 256)
model = DetailFeatureExtraction(3)
print(model(x).shape)
# =============================================================================
# =============================================================================
import numbers
##########################################################################
## Layer Norm
def to_3d(x):
return rearrange(x, 'b c h w -> b (h w) c')
def to_4d(x, h, w):
return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
class BiasFree_LayerNorm(nn.Module):
def __init__(self, normalized_shape):
super(BiasFree_LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = torch.Size(normalized_shape)
assert len(normalized_shape) == 1
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.normalized_shape = normalized_shape
def forward(self, x):
sigma = x.var(-1, keepdim=True, unbiased=False)
return x / torch.sqrt(sigma+1e-5) * self.weight
class WithBias_LayerNorm(nn.Module):
def __init__(self, normalized_shape):
super(WithBias_LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = torch.Size(normalized_shape)
assert len(normalized_shape) == 1
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.normalized_shape = normalized_shape
def forward(self, x):
mu = x.mean(-1, keepdim=True)
sigma = x.var(-1, keepdim=True, unbiased=False)
return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias
class LayerNorm(nn.Module):
def __init__(self, dim, LayerNorm_type):
super(LayerNorm, self).__init__()
if LayerNorm_type == 'BiasFree':
self.body = BiasFree_LayerNorm(dim)
else:
self.body = WithBias_LayerNorm(dim)
def forward(self, x):
h, w = x.shape[-2:]
return to_4d(self.body(to_3d(x)), h, w)
##########################################################################
## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
def __init__(self, dim, ffn_expansion_factor, bias):
super(FeedForward, self).__init__()
hidden_features = int(dim*ffn_expansion_factor)
self.project_in = nn.Conv2d(
dim, hidden_features*2, kernel_size=1, bias=bias)
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3,
stride=1, padding=1, groups=hidden_features*2, bias=bias)
self.project_out = nn.Conv2d(
hidden_features, dim, kernel_size=1, bias=bias)
def forward(self, x):
x = self.project_in(x)
x1, x2 = self.dwconv(x).chunk(2, dim=1)
x = F.gelu(x1) * x2
x = self.project_out(x)
return x
##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
def __init__(self, dim, num_heads, bias):
super(Attention, self).__init__()
self.num_heads = num_heads
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(
dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.qkv_dwconv(self.qkv(x))
q, k, v = qkv.chunk(3, dim=1)
q = rearrange(q, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
out = (attn @ v)
out = rearrange(out, 'b head c (h w) -> b (head c) h w',
head=self.num_heads, h=h, w=w)
out = self.project_out(out)
return out
##########################################################################
class TransformerBlock(nn.Module):
def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
super(TransformerBlock, self).__init__()
self.norm1 = LayerNorm(dim, LayerNorm_type)
self.attn = Attention(dim, num_heads, bias)
self.norm2 = LayerNorm(dim, LayerNorm_type)
self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.ffn(self.norm2(x))
return x
##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
def __init__(self, in_c=3, embed_dim=48, bias=False):
super(OverlapPatchEmbed, self).__init__()
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3,
stride=1, padding=1, bias=bias)
def forward(self, x):
x = self.proj(x)
return x
class Restormer_Encoder(nn.Module):
def __init__(self,
inp_channels=1,
out_channels=1,
dim=64,
num_blocks=[4, 4],
heads=[8, 8, 8],
ffn_expansion_factor=2,
bias=False,
LayerNorm_type='WithBias',
):
super(Restormer_Encoder, self).__init__()
self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
self.baseFeature = BaseFeatureExtraction(dim=dim, num_heads = heads[2])
self.detailFeature = DetailFeatureExtraction()
def forward(self, inp_img):
inp_enc_level1 = self.patch_embed(inp_img)
out_enc_level1 = self.encoder_level1(inp_enc_level1)
base_feature = self.baseFeature(out_enc_level1)
detail_feature = self.detailFeature(out_enc_level1)
return base_feature, detail_feature, out_enc_level1
class Restormer_Decoder(nn.Module):
def __init__(self,
inp_channels=1,
out_channels=1,
dim=64,
num_blocks=[4, 4],
heads=[8, 8, 8],
ffn_expansion_factor=2,
bias=False,
LayerNorm_type='WithBias',
):
super(Restormer_Decoder, self).__init__()
self.reduce_channel = nn.Conv2d(int(dim*2), int(dim), kernel_size=1, bias=bias)
self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
self.output = nn.Sequential(
nn.Conv2d(int(dim), int(dim)//2, kernel_size=3,
stride=1, padding=1, bias=bias),
nn.LeakyReLU(),
nn.Conv2d(int(dim)//2, out_channels, kernel_size=3,
stride=1, padding=1, bias=bias),)
self.sigmoid = nn.Sigmoid()
def forward(self, inp_img, base_feature, detail_feature):
out_enc_level0 = torch.cat((base_feature, detail_feature), dim=1)
out_enc_level0 = self.reduce_channel(out_enc_level0)
out_enc_level1 = self.encoder_level2(out_enc_level0)
if inp_img is not None:
out_enc_level1 = self.output(out_enc_level1) + inp_img
else:
out_enc_level1 = self.output(out_enc_level1)
return self.sigmoid(out_enc_level1), out_enc_level0
if __name__ == '__main__':
height = 128
width = 128
window_size = 8
modelE = Restormer_Encoder().cuda()
modelD = Restormer_Decoder().cuda()

434
net_me.py
View File

@ -1,434 +0,0 @@
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange
def drop_path(x, drop_prob: float = 0., training: bool = False):
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
# work with diff dim tensors, not just 2D ConvNets
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + \
torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
# 改点使用Pooling替换AttentionBase
class Pooling(nn.Module):
def __init__(self, kernel_size=3):
super().__init__()
self.pool = nn.AvgPool2d(
kernel_size, stride=1, padding=kernel_size // 2)
def forward(self, x):
return self.pool(x) - x
class PoolMlp(nn.Module):
"""
实现基于1x1卷积的MLP模块
输入形状为[B, C, H, W]的张量
"""
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
bias=False,
drop=0.):
"""
初始化PoolMlp模块
参数:
in_features (int): 输入特征的数量
hidden_features (int, 可选): 隐藏层特征的数量默认为None设置为与in_features相同
out_features (int, 可选): 输出特征的数量默认为None设置为与in_features相同
act_layer (nn.Module, 可选): 使用的激活层默认为nn.GELU
bias (bool, 可选): 是否在卷积层中包含偏置项默认为False
drop (float, 可选): Dropout比率默认为0
"""
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Conv2d(in_features, hidden_features, 1, bias=bias)
self.act = act_layer()
self.fc2 = nn.Conv2d(hidden_features, out_features, 1, bias=bias)
self.drop = nn.Dropout(drop)
def forward(self, x):
"""
通过PoolMlp模块的前向传播
参数:
x (torch.Tensor): 形状为[B, C, H, W]的输入张量
返回:
torch.Tensor: 形状为[B, C, H, W]的输出张量
"""
x = self.fc1(x) # (B, C, H, W) --> (B, C, H, W)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x) # (B, C, H, W) --> (B, C, H, W)
x = self.drop(x)
return x
class BaseFeatureExtraction(nn.Module):
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
act_layer=nn.GELU,
# norm_layer=nn.LayerNorm,
drop=0., drop_path=0.,
use_layer_scale=True, layer_scale_init_value=1e-5):
super().__init__()
self.norm1 = LayerNorm(dim, 'WithBias')
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
self.norm2 = LayerNorm(dim, 'WithBias')
mlp_hidden_dim = int(dim * mlp_ratio)
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)
# The following two techniques are useful to train deep PoolFormers.
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
self.layer_scale_2 = nn.Parameter(
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
def forward(self, x):
if self.use_layer_scale:
x = x + self.drop_path(
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
* self.token_mixer(self.norm1(x)))
x = x + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
* self.poolmlp(self.norm2(x)))
else:
x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
return x
class InvertedResidualBlock(nn.Module):
def __init__(self, inp, oup, expand_ratio):
super(InvertedResidualBlock, self).__init__()
hidden_dim = int(inp * expand_ratio)
self.bottleneckBlock = nn.Sequential(
# pw
nn.Conv2d(inp, hidden_dim, 1, bias=False),
# nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
# dw
nn.ReflectionPad2d(1),
nn.Conv2d(hidden_dim, hidden_dim, 3, groups=hidden_dim, bias=False),
# nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, bias=False),
# nn.BatchNorm2d(oup),
)
def forward(self, x):
return self.bottleneckBlock(x)
class DetailNode(nn.Module):
# <img src = "http://42.192.130.83:9000/picgo/imgs/小绿鲸英文文献阅读器_ELTITYqm5G.png" / > '
def __init__(self):
super(DetailNode, self).__init__()
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.shffleconv = nn.Conv2d(64, 64, kernel_size=1,
stride=1, padding=0, bias=True)
def separateFeature(self, x):
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
return z1, z2
def forward(self, z1, z2):
z1, z2 = self.separateFeature(
self.shffleconv(torch.cat((z1, z2), dim=1)))
z2 = z2 + self.theta_phi(z1)
z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2)
return z1, z2
class DetailFeatureExtraction(nn.Module):
def __init__(self, num_layers=3):
super(DetailFeatureExtraction, self).__init__()
INNmodules = [DetailNode() for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules)
def forward(self, x):
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
for layer in self.net:
z1, z2 = layer(z1, z2)
return torch.cat((z1, z2), dim=1)
# =============================================================================
# =============================================================================
import numbers
##########################################################################
## Layer Norm
def to_3d(x):
return rearrange(x, 'b c h w -> b (h w) c')
def to_4d(x, h, w):
return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
class BiasFree_LayerNorm(nn.Module):
def __init__(self, normalized_shape):
super(BiasFree_LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = torch.Size(normalized_shape)
assert len(normalized_shape) == 1
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.normalized_shape = normalized_shape
def forward(self, x):
sigma = x.var(-1, keepdim=True, unbiased=False)
return x / torch.sqrt(sigma + 1e-5) * self.weight
class WithBias_LayerNorm(nn.Module):
def __init__(self, normalized_shape):
super(WithBias_LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = torch.Size(normalized_shape)
assert len(normalized_shape) == 1
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.normalized_shape = normalized_shape
def forward(self, x):
mu = x.mean(-1, keepdim=True)
sigma = x.var(-1, keepdim=True, unbiased=False)
return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
class LayerNorm(nn.Module):
def __init__(self, dim, LayerNorm_type):
super(LayerNorm, self).__init__()
if LayerNorm_type == 'BiasFree':
self.body = BiasFree_LayerNorm(dim)
else:
self.body = WithBias_LayerNorm(dim)
def forward(self, x):
h, w = x.shape[-2:]
return to_4d(self.body(to_3d(x)), h, w)
##########################################################################
## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
def __init__(self, dim, ffn_expansion_factor, bias):
super(FeedForward, self).__init__()
hidden_features = int(dim * ffn_expansion_factor)
self.project_in = nn.Conv2d(
dim, hidden_features * 2, kernel_size=1, bias=bias)
self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3,
stride=1, padding=1, groups=hidden_features * 2, bias=bias)
self.project_out = nn.Conv2d(
hidden_features, dim, kernel_size=1, bias=bias)
def forward(self, x):
x = self.project_in(x)
x1, x2 = self.dwconv(x).chunk(2, dim=1)
x = F.gelu(x1) * x2
x = self.project_out(x)
return x
##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
def __init__(self, dim, num_heads, bias):
super(Attention, self).__init__()
self.num_heads = num_heads
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(
dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.qkv_dwconv(self.qkv(x))
q, k, v = qkv.chunk(3, dim=1)
q = rearrange(q, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
out = (attn @ v)
out = rearrange(out, 'b head c (h w) -> b (head c) h w',
head=self.num_heads, h=h, w=w)
out = self.project_out(out)
return out
##########################################################################
class TransformerBlock(nn.Module):
def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
super(TransformerBlock, self).__init__()
self.norm1 = LayerNorm(dim, LayerNorm_type)
self.attn = Attention(dim, num_heads, bias)
self.norm2 = LayerNorm(dim, LayerNorm_type)
self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.ffn(self.norm2(x))
return x
##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
def __init__(self, in_c=3, embed_dim=48, bias=False):
super(OverlapPatchEmbed, self).__init__()
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3,
stride=1, padding=1, bias=bias)
def forward(self, x):
x = self.proj(x)
return x
class Restormer_Encoder(nn.Module):
def __init__(self,
inp_channels=1,
out_channels=1,
dim=64,
num_blocks=[4, 4],
heads=[8, 8, 8],
ffn_expansion_factor=2,
bias=False,
LayerNorm_type='WithBias',
):
super(Restormer_Encoder, self).__init__()
self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
self.encoder_level1 = nn.Sequential(
*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
self.baseFeature = BaseFeatureExtraction(dim=dim)
self.detailFeature = DetailFeatureExtraction()
def forward(self, inp_img):
inp_enc_level1 = self.patch_embed(inp_img)
out_enc_level1 = self.encoder_level1(inp_enc_level1)
base_feature = self.baseFeature(out_enc_level1)
detail_feature = self.detailFeature(out_enc_level1)
return base_feature, detail_feature, out_enc_level1
class Restormer_Decoder(nn.Module):
def __init__(self,
inp_channels=1,
out_channels=1,
dim=64,
num_blocks=[4, 4],
heads=[8, 8, 8],
ffn_expansion_factor=2,
bias=False,
LayerNorm_type='WithBias',
):
super(Restormer_Decoder, self).__init__()
self.reduce_channel = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias)
self.encoder_level2 = nn.Sequential(
*[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
self.output = nn.Sequential(
nn.Conv2d(int(dim), int(dim) // 2, kernel_size=3,
stride=1, padding=1, bias=bias),
nn.LeakyReLU(),
nn.Conv2d(int(dim) // 2, out_channels, kernel_size=3,
stride=1, padding=1, bias=bias), )
self.sigmoid = nn.Sigmoid()
def forward(self, inp_img, base_feature, detail_feature):
out_enc_level0 = torch.cat((base_feature, detail_feature), dim=1)
out_enc_level0 = self.reduce_channel(out_enc_level0)
out_enc_level1 = self.encoder_level2(out_enc_level0)
if inp_img is not None:
out_enc_level1 = self.output(out_enc_level1) + inp_img
else:
out_enc_level1 = self.output(out_enc_level1)
return self.sigmoid(out_enc_level1), out_enc_level0
if __name__ == '__main__':
height = 128
width = 128
window_size = 8
modelE = Restormer_Encoder().cuda()
modelD = Restormer_Decoder().cuda()
print(modelE)
print(modelD)

View File

@ -1,5 +0,0 @@
scipy==1.9.3
scikit-image==0.19.2
scikit-learn==1.1.3
tqdm==4.62.0

View File

@ -13,13 +13,14 @@ logging.basicConfig(level=logging.CRITICAL)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/PFCFusion10-05-20-46.pth"
ckpt_path= r"PFCFuse_IVF.pth"
for dataset_name in ["TNO","RoadScene"]:
for dataset_name in ["TNO"]:
print("\n"*2+"="*80)
model_name="PFCFuse "
print("The test result of "+dataset_name+' :')
test_folder=os.path.join('/home/star/whaiDir/CDDFuse/test_img/',dataset_name)
test_folder=os.path.join('test_img',dataset_name)
test_out_folder=os.path.join('test_result',dataset_name)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

41
test_shell.py Normal file
View File

@ -0,0 +1,41 @@
import os
import subprocess
import datetime
import sys
# 定义命令
command = "/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py"
# 获取当前时间并格式化为文件名
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
output_file = f"/home/star/whaiDir/PFCFuse/logs/ans_log_{current_time}.log"
try:
# 打开文件用于写入
with open(output_file, 'w') as log_file:
# 运行命令
process = subprocess.Popen(command.split(), stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
universal_newlines=True)
# 循环读取输出
for line in process.stdout:
# 同时打印到控制台和写入文件
print(line.strip())
log_file.write(line)
log_file.flush() # 确保立即写入文件
# 等待命令完成
process.wait()
# 检查返回码
if process.returncode != 0:
raise subprocess.CalledProcessError(process.returncode, command)
# 如果命令成功执行,则打印确认信息
print(f"Command executed successfully. Output has been written to {output_file}")
except subprocess.CalledProcessError as e:
# 如果命令执行失败,则删除文件并打印错误信息
if os.path.exists(output_file):
os.remove(output_file)
print(f"Command failed with return code {e.returncode}. No log file was created.")

View File

@ -34,7 +34,7 @@ criteria_fusion = Fusionloss()
model_str = 'PFCFuse'
# . Set the hyper-parameters for training
num_epochs = 60 # total epoch
num_epochs = 120 # total epoch
epoch_gap = 40 # epoches of Phase I
lr = 1e-4
@ -57,28 +57,6 @@ clip_grad_norm_value = 0.01
optim_step = 20
optim_gamma = 0.5
# 打印所有参数
print(f"Model: {model_str}")
print(f"Number of epochs: {num_epochs}")
print(f"Epoch gap: {epoch_gap}")
print(f"Learning rate: {lr}")
print(f"Weight decay: {weight_decay}")
print(f"Batch size: {batch_size}")
print(f"GPU number: {GPU_number}")
print(f"Coefficient of MSE loss VF: {coeff_mse_loss_VF}")
print(f"Coefficient of MSE loss IF: {coeff_mse_loss_IF}")
print(f"Coefficient of RMI loss VF: {coeff_rmi_loss_VF}")
print(f"Coefficient of RMI loss IF: {coeff_rmi_loss_IF}")
print(f"Coefficient of Cosine loss VF: {coeff_cos_loss_VF}")
print(f"Coefficient of Cosine loss IF: {coeff_cos_loss_IF}")
print(f"Coefficient of Decomposition loss: {coeff_decomp}")
print(f"Coefficient of Total Variation loss: {coeff_tv}")
print(f"Clip gradient norm value: {clip_grad_norm_value}")
print(f"Optimization step: {optim_step}")
print(f"Optimization gamma: {optim_gamma}")
# Model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
@ -109,7 +87,7 @@ Loss_ssim = kornia.losses.SSIM(11, reduction='mean')
HuberLoss = nn.HuberLoss()
# data loader
trainloader = DataLoader(H5Dataset(r"/home/star/whaiDir/CDDFuse/data/MSRS_train_imgsize_128_stride_200.h5"),
trainloader = DataLoader(H5Dataset(r"data/MSRS_train_imgsize_128_stride_200.h5"),
batch_size=batch_size,
shuffle=True,
num_workers=0)
@ -222,19 +200,16 @@ for epoch in range(num_epochs):
time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
epoch_time = time.time() - prev_time
prev_time = time.time()
if step % 100 == 0:
sys.stdout.write(
"\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
% (
epoch,
num_epochs,
i,
len(loader['train']),
loss.item(),
time_left,
)
sys.stdout.write(
"\r[Epoch %d/%d] [Batch %d/%d] [loss: %f]"
% (
epoch,
num_epochs,
i,
len(loader['train']),
loss.item(),
)
)
# adjust the learning rate
@ -260,5 +235,5 @@ if True:
'BaseFuseLayer': BaseFuseLayer.state_dict(),
'DetailFuseLayer': DetailFuseLayer.state_dict(),
}
torch.save(checkpoint, os.path.join("models/whaiFusion"+timestamp+'.pth'))
torch.save(checkpoint, os.path.join("models/PFCFusion"+timestamp+'.pth'))

View File

@ -1,15 +0,0 @@
import subprocess
import datetime
# 定义命令
command = "/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/train.py"
# 获取当前时间并格式化为文件名
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
output_file = f"/home/star/whaiDir/PFCFuse/logs/log_{current_time}.log"
# 运行命令并将输出重定向到文件
with open(output_file, 'w') as file:
subprocess.run(command.split(), stdout=file, stderr=subprocess.STDOUT)
print(f"Command output has been written to {output_file}")