# pfcfuse/mamba_ssm/ops/triton/selective_state_update.py

# Copyright (c) 2023, Tri Dao.
"""We want triton==2.1.0 for this
"""
import math

import torch
import torch.nn.functional as F

import triton
import triton.language as tl

from einops import rearrange, repeat


@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
@triton.jit
def _selective_scan_update_kernel(
    # Pointers to matrices
    state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
    # Matrix dimensions
    batch, dim, dstate,
    # Strides
    stride_state_batch, stride_state_dim, stride_state_dstate,
    stride_x_batch, stride_x_dim,
    stride_dt_batch, stride_dt_dim,
    stride_dt_bias_dim,
    stride_A_dim, stride_A_dstate,
    stride_B_batch, stride_B_dstate,
    stride_C_batch, stride_C_dstate,
    stride_D_dim,
    stride_z_batch, stride_z_dim,
    stride_out_batch, stride_out_dim,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    HAS_D: tl.constexpr,
    HAS_Z: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)
    # Offset all pointers to this program's batch element.
    state_ptr += pid_b * stride_state_batch
    x_ptr += pid_b * stride_x_batch
    dt_ptr += pid_b * stride_dt_batch
    B_ptr += pid_b * stride_B_batch
    C_ptr += pid_b * stride_C_batch
    if HAS_Z:
        z_ptr += pid_b * stride_z_batch
    out_ptr += pid_b * stride_out_batch

    # Each program handles a BLOCK_SIZE_M slice of `dim` and the full `dstate`.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
    state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    if HAS_DT_BIAS:
        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
    A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_D:
        D_ptrs = D_ptr + offs_m * stride_D_dim
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_DT_BIAS:
        dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if DT_SOFTPLUS:
        # Softplus; for dt > 20 the identity softplus(dt) ~= dt avoids exp overflow.
        dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
    A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
    dA = tl.exp(A * dt[:, None])
    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

    # One recurrence step: state <- state * exp(dt * A) + (dt * x) * B,
    # written back in place; the output is the contraction of state with C.
    dB = B[None, :] * dt[:, None]
    state = state * dA + dB * x[:, None]
    tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
    out = tl.sum(state * C[None, :], axis=1)
    if HAS_D:
        out += x * D
    if HAS_Z:
        out *= z * tl.sigmoid(z)
    tl.store(out_ptrs, out, mask=offs_m < dim)
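

# In equations, the kernel above and the functions below each perform one step of
# the selective-scan (Mamba-style SSM) recurrence, vectorized over dstate:
#
#     dt'    = softplus(dt + dt_bias)                (bias and softplus are optional)
#     state <- state * exp(dt' * A) + (dt' * x) * B  (updated in place)
#     out    = sum_n state[..., n] * C[..., n], plus D * x if D is given,
#              then gated by z * sigmoid(z) (SiLU) if z is given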
def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
"""
Argument:
state: (batch, dim, dstate)
x: (batch, dim)
dt: (batch, dim)
A: (dim, dstate)
B: (batch, dstate)
C: (batch, dstate)
D: (dim,)
z: (batch, dim)
dt_bias: (dim,)
Return:
out: (batch, dim)
"""
batch, dim, dstate = state.shape
assert x.shape == (batch, dim)
assert dt.shape == x.shape
assert A.shape == (dim, dstate)
assert B.shape == (batch, dstate)
assert C.shape == B.shape
if D is not None:
assert D.shape == (dim,)
if z is not None:
assert z.shape == x.shape
if dt_bias is not None:
assert dt_bias.shape == (dim,)
out = torch.empty_like(x)
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)
z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))
# We don't want autotune since it will overwrite the state
# We instead tune by hand.
BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
else ((16, 4) if dstate <= 32 else
((8, 4) if dstate <= 64 else
((4, 4) if dstate <= 128 else
((4, 8))))))
with torch.cuda.device(x.device.index):
_selective_scan_update_kernel[grid](
state, x, dt, dt_bias, A, B, C, D, z, out,
batch, dim, dstate,
state.stride(0), state.stride(1), state.stride(2),
x.stride(0), x.stride(1),
dt.stride(0), dt.stride(1),
dt_bias.stride(0) if dt_bias is not None else 0,
A.stride(0), A.stride(1),
B.stride(0), B.stride(1),
C.stride(0), C.stride(1),
D.stride(0) if D is not None else 0,
z_strides[0], z_strides[1],
out.stride(0), out.stride(1),
dt_softplus,
BLOCK_SIZE_M,
num_warps=num_warps,
)
return out
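

# A quick usage sketch (illustrative; the names and shapes here are arbitrary
# assumptions, not part of the original file). Because `state` is mutated in
# place, clone it first if the pre-step value is still needed:
#
#     prev_state = state.clone()
#     out = selective_state_update(state, x, dt, A, B, C, D=D, z=z,
#                                  dt_bias=dt_bias, dt_softplus=True)
#     # `state` now holds the post-step state; `out` has shape (batch, dim).
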
def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
"""
Argument:
state: (batch, dim, dstate)
x: (batch, dim)
dt: (batch, dim)
A: (dim, dstate)
B: (batch, dstate)
C: (batch, dstate)
D: (dim,)
z: (batch, dim)
dt_bias: (dim,)
Return:
out: (batch, dim)
"""
batch, dim, dstate = state.shape
assert x.shape == (batch, dim)
assert dt.shape == x.shape
assert A.shape == (dim, dstate)
assert B.shape == (batch, dstate)
assert C.shape == B.shape
if D is not None:
assert D.shape == (dim,)
if z is not None:
assert z.shape == x.shape
if dt_bias is not None:
assert dt_bias.shape == (dim,)
dt = dt + dt_bias
dt = F.softplus(dt) if dt_softplus else dt
dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate)
dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n") # (batch, dim, dstate)
state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1")) # (batch, dim, dstate
out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C)
if D is not None:
out += (x * D).to(out.dtype)
return (out if z is None else out * F.silu(z)).to(x.dtype)
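

if __name__ == "__main__":
    # A minimal sanity check (an illustrative addition, not part of the upstream
    # file): run the Triton kernel and the PyTorch reference on the same random
    # inputs and compare. The shapes and tolerances below are arbitrary assumptions.
    torch.manual_seed(0)
    batch, dim, dstate = 2, 64, 16
    device = "cuda"  # the Triton kernel requires a CUDA device
    state = torch.randn(batch, dim, dstate, device=device)
    state_ref = state.clone()  # both functions update the state in place
    x = torch.randn(batch, dim, device=device)
    dt = torch.randn(batch, dim, device=device)
    dt_bias = torch.randn(dim, device=device)
    A = -torch.rand(dim, dstate, device=device)  # negative A keeps the recurrence stable
    B = torch.randn(batch, dstate, device=device)
    C = torch.randn(batch, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn(batch, dim, device=device)
    out = selective_state_update(state, x, dt, A, B, C, D=D, z=z,
                                 dt_bias=dt_bias, dt_softplus=True)
    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z,
                                         dt_bias=dt_bias, dt_softplus=True)
    assert torch.allclose(state, state_ref, rtol=1e-4, atol=1e-4)
    assert torch.allclose(out, out_ref, rtol=1e-4, atol=1e-4)
    print("selective_state_update matches the reference implementation")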