# Copyright (c) 2023, Tri Dao.

"""We want triton==2.1.0 for this
"""

import math

import torch
import torch.nn.functional as F

import triton
import triton.language as tl

from einops import rearrange, repeat

@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
|
||
|
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
|
||
|
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
|
||
|
@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
|
||
|
@triton.jit
|
||
|
def _selective_scan_update_kernel(
|
||
|
# Pointers to matrices
|
||
|
state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
|
||
|
# Matrix dimensions
|
||
|
batch, dim, dstate,
|
||
|
# Strides
|
||
|
stride_state_batch, stride_state_dim, stride_state_dstate,
|
||
|
stride_x_batch, stride_x_dim,
|
||
|
stride_dt_batch, stride_dt_dim,
|
||
|
stride_dt_bias_dim,
|
||
|
stride_A_dim, stride_A_dstate,
|
||
|
stride_B_batch, stride_B_dstate,
|
||
|
stride_C_batch, stride_C_dstate,
|
||
|
stride_D_dim,
|
||
|
stride_z_batch, stride_z_dim,
|
||
|
stride_out_batch, stride_out_dim,
|
||
|
# Meta-parameters
|
||
|
DT_SOFTPLUS: tl.constexpr,
|
||
|
BLOCK_SIZE_M: tl.constexpr,
|
||
|
HAS_DT_BIAS: tl.constexpr,
|
||
|
HAS_D: tl.constexpr,
|
||
|
HAS_Z: tl.constexpr,
|
||
|
BLOCK_SIZE_DSTATE: tl.constexpr,
|
||
|
):
|
||
|
    # Each program handles one block of BLOCK_SIZE_M channels for one batch element.
    pid_m = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)
    state_ptr += pid_b * stride_state_batch
    x_ptr += pid_b * stride_x_batch
    dt_ptr += pid_b * stride_dt_batch
    B_ptr += pid_b * stride_B_batch
    C_ptr += pid_b * stride_C_batch
    if HAS_Z:
        z_ptr += pid_b * stride_z_batch
    out_ptr += pid_b * stride_out_batch

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
    state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    if HAS_DT_BIAS:
        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
    A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_D:
        D_ptrs = D_ptr + offs_m * stride_D_dim
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_DT_BIAS:
        dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if DT_SOFTPLUS:
        # softplus(dt), falling back to the identity for large dt to avoid overflow in exp
        dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
    A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
    dA = tl.exp(A * dt[:, None])
    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

    # Discretized single-step state update; the new state is written back in place.
    dB = B[None, :] * dt[:, None]
    state = state * dA + dB * x[:, None]
    tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
    out = tl.sum(state * C[None, :], axis=1)
    if HAS_D:
        out += x * D
    if HAS_Z:
        out *= z * tl.sigmoid(z)
    tl.store(out_ptrs, out, mask=offs_m < dim)


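# What the kernel above computes, per batch element and channel d (a summary of the
# loads/stores above, not additional functionality):
#     dt       = softplus(dt + dt_bias)                        (both steps optional)
#     state[n] = state[n] * exp(dt * A[d, n]) + dt * B[n] * x   for each n < dstate
#     out      = sum_n state[n] * C[n]  (+ D[d] * x)  (* z * sigmoid(z))
# i.e. one step of the discretized selective SSM recurrence, with the updated state
# stored back into `state`.
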
def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
    """
    Argument:
        state: (batch, dim, dstate)
        x: (batch, dim)
        dt: (batch, dim)
        A: (dim, dstate)
        B: (batch, dstate)
        C: (batch, dstate)
        D: (dim,)
        z: (batch, dim)
        dt_bias: (dim,)
    Return:
        out: (batch, dim)
    """
    batch, dim, dstate = state.shape
    assert x.shape == (batch, dim)
    assert dt.shape == x.shape
    assert A.shape == (dim, dstate)
    assert B.shape == (batch, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (dim,)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (dim,)
    out = torch.empty_like(x)
    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)
    z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))
    # We don't want autotune since it will overwrite the state.
    # We instead tune by hand.
    BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
                               else ((16, 4) if dstate <= 32 else
                                     ((8, 4) if dstate <= 64 else
                                      ((4, 4) if dstate <= 128 else
                                       ((4, 8))))))
    with torch.cuda.device(x.device.index):
        _selective_scan_update_kernel[grid](
            state, x, dt, dt_bias, A, B, C, D, z, out,
            batch, dim, dstate,
            state.stride(0), state.stride(1), state.stride(2),
            x.stride(0), x.stride(1),
            dt.stride(0), dt.stride(1),
            dt_bias.stride(0) if dt_bias is not None else 0,
            A.stride(0), A.stride(1),
            B.stride(0), B.stride(1),
            C.stride(0), C.stride(1),
            D.stride(0) if D is not None else 0,
            z_strides[0], z_strides[1],
            out.stride(0), out.stride(1),
            dt_softplus,
            BLOCK_SIZE_M,
            num_warps=num_warps,
        )
    return out


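# Note: selective_state_update modifies `state` in place (the kernel writes the new
# state back through state_ptrs); the pure-PyTorch reference below mirrors this with
# state.copy_(...).
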
def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
    """
    Argument:
        state: (batch, dim, dstate)
        x: (batch, dim)
        dt: (batch, dim)
        A: (dim, dstate)
        B: (batch, dstate)
        C: (batch, dstate)
        D: (dim,)
        z: (batch, dim)
        dt_bias: (dim,)
    Return:
        out: (batch, dim)
    """
    batch, dim, dstate = state.shape
    assert x.shape == (batch, dim)
    assert dt.shape == x.shape
    assert A.shape == (dim, dstate)
    assert B.shape == (batch, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (dim,)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (dim,)
        dt = dt + dt_bias
    dt = F.softplus(dt) if dt_softplus else dt
    dA = torch.exp(rearrange(dt, "b d -> b d 1") * A)  # (batch, dim, dstate)
    dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n")  # (batch, dim, dstate)
    state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1"))  # (batch, dim, dstate)
    out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C)
    if D is not None:
        out += (x * D).to(out.dtype)
    return (out if z is None else out * F.silu(z)).to(x.dtype)
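

if __name__ == "__main__":
    # Minimal sanity-check sketch: compare the Triton kernel against the pure-PyTorch
    # reference on random inputs. Assumes a CUDA device and triton are available;
    # the batch/dim/dstate values below are arbitrary.
    torch.manual_seed(0)
    batch, dim, dstate = 2, 64, 16
    device = "cuda"
    state = torch.randn(batch, dim, dstate, device=device)
    state_ref = state.clone()  # the update is in place, so keep a copy for the reference
    x = torch.randn(batch, dim, device=device)
    dt = torch.randn(batch, dim, device=device)
    dt_bias = torch.randn(dim, device=device)
    A = -torch.rand(dim, dstate, device=device)
    B = torch.randn(batch, dstate, device=device)
    C = torch.randn(batch, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn(batch, dim, device=device)
    out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
    print("out matches reference:", torch.allclose(out, out_ref, atol=1e-4))
    print("state matches reference:", torch.allclose(state, state_ref, atol=1e-4))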
|