from collections import namedtuple
from dataclasses import dataclass
from functools import partial, wraps
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn, einsum, Tensor

from einops import rearrange
from packaging import version

# constants
EfficientAttentionConfig = namedtuple(
"EfficientAttentionConfig",
["enable_flash", "enable_math", "enable_mem_efficient"],
)
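# the three flags map one-to-one onto the keyword arguments of
# torch.backends.cuda.sdp_kernel, selecting which sdpa backends may run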
@dataclass
class Intermediates:
qk_similarities: Optional[Tensor] = None
pre_softmax_attn: Optional[Tensor] = None
post_softmax_attn: Optional[Tensor] = None
def to_tuple(self):
return (
self.qk_similarities,
self.pre_softmax_attn,
self.post_softmax_attn,
)
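# note: the flash attention path returns an empty Intermediates, since the fused
# kernel never materializes the attention matrix; only the einsum path fills it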
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def compact(arr):
return [*filter(exists, arr)]
def once(fn):
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner
print_once = once(print)
# functions for creating causal mask
# need a special one for onnx cpu (no support for .triu)
def create_causal_mask(i, j, device):
return torch.ones((i, j), device=device, dtype=torch.bool).triu(
j - i + 1
)
def onnx_create_causal_mask(i, j, device):
r = torch.arange(i, device=device)
causal_mask = rearrange(r, "i -> i 1") < rearrange(r, "j -> 1 j")
causal_mask = F.pad(causal_mask, (j - i, 0), value=False)
return causal_mask
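# e.g. create_causal_mask(3, 3, device) (and its onnx counterpart) produce
#   [[False,  True,  True],
#    [False, False,  True],
#    [False, False, False]]
# where True marks the future positions that get masked out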
# main class
class Attend(nn.Module):
def __init__(
self,
*,
dropout=0.0,
causal=False,
heads=None,
talking_heads=False,
sparse_topk=None,
scale=None,
qk_norm=False,
flash=False,
add_zero_kv=False,
onnxable=False,
):
super().__init__()
self.scale = scale
self.qk_norm = qk_norm
self.causal = causal
self.create_causal_mask = (
onnx_create_causal_mask
if onnxable
else create_causal_mask
)
self.attn_fn = (
partial(F.softmax, dtype=torch.float32)
if not qk_norm
else F.softmax
)
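        # when qk norm is off, softmax is run in float32 (presumably for
        # numerical stability under fp16 / bf16); with qk norm the logits are
        # bounded by the learned scale, so the input dtype is kept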
self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)
# talking heads
assert not (
flash and talking_heads
), "talking heads not compatible with flash attention"
self.talking_heads = talking_heads
if talking_heads:
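            # 1x1 convolutions over the head dimension, mixing attention maps
            # across heads before and after the softmax (talking-heads attention)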
self.pre_softmax_talking_heads = nn.Conv2d(
heads, heads, 1, bias=False
)
self.post_softmax_talking_heads = nn.Conv2d(
heads, heads, 1, bias=False
)
# sparse topk
assert not (
flash and sparse_topk
), "sparse topk not compatible with flash attention"
self.sparse_topk = sparse_topk
        # add a key / value token composed of zeros
        # in case this helps control outliers, as proposed by https://www.evanmiller.org/attention-is-off-by-one.html
self.add_zero_kv = add_zero_kv
# flash attention
self.flash = flash
assert not (
flash
and version.parse(torch.__version__)
< version.parse("2.0.0")
), (
"in order to use flash attention, you must be using"
" pytorch 2.0 or above"
)
# determine efficient attention configs for cuda and cpu
self.cpu_config = EfficientAttentionConfig(True, True, True)
self.cuda_config = None
if not torch.cuda.is_available() or not flash:
return
device_properties = torch.cuda.get_device_properties(
torch.device("cuda")
)
if (
device_properties.major == 8
and device_properties.minor == 0
):
print_once(
"A100 GPU detected, using flash attention if input"
" tensor is on cuda"
)
self.cuda_config = EfficientAttentionConfig(
True, False, False
)
else:
print_once(
"Non-A100 GPU detected, using math or mem efficient"
" attention if input tensor is on cuda"
)
self.cuda_config = EfficientAttentionConfig(
False, True, True
)
def flash_attn(self, q, k, v, mask=None, attn_bias=None):
batch, heads, q_len, _, k_len, is_cuda, device = (
*q.shape,
k.shape[-2],
q.is_cuda,
q.device,
)
# Recommended for multi-query single-key-value attention by Tri Dao
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
if k.ndim == 3:
k = rearrange(k, "b ... -> b 1 ...").expand_as(q)
if v.ndim == 3:
v = rearrange(v, "b ... -> b 1 ...").expand_as(q)
        # handle scale - sdpa scales by dim_head ** -0.5 internally, so when a
        # custom scale is set (e.g. qk-normed / cosine-sim attention),
        # pre-scale q to compensate
        if self.qk_norm:
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)
        # check if mask exists and expand to compatible shape
        # the mask arrives already broadcastable over 4 dims (e.g. a b j padding
        # mask rearranged to b 1 1 j) and is expanded here to b h i j
causal = self.causal
if exists(mask):
assert mask.ndim == 4
mask = mask.expand(batch, heads, q_len, k_len)
# manually handle causal mask, if another mask was given
if causal:
causal_mask = self.create_causal_mask(
q_len, k_len, device=device
)
mask = mask & ~causal_mask
causal = False
# handle alibi positional bias
# convert from bool to float
if exists(attn_bias):
attn_bias = rearrange(
attn_bias, "h i j -> 1 h i j"
).expand(batch, heads, -1, -1)
# if mask given, the mask would already contain the causal mask from above logic
# otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
mask_value = -torch.finfo(q.dtype).max
if exists(mask):
attn_bias = attn_bias.masked_fill(
~mask, mask_value // 2
)
elif causal:
causal_mask = self.create_causal_mask(
q_len, k_len, device=device
)
attn_bias = attn_bias.masked_fill(
causal_mask, mask_value // 2
)
causal = False
# scaled_dot_product_attention handles attn_mask either as bool or additive bias
# make it an additive bias here
mask = attn_bias
# Check if there is a compatible device for flash attention
config = self.cuda_config if is_cuda else self.cpu_config
# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=causal,
)
return out, Intermediates()
def forward(
self, q, k, v, mask=None, attn_bias=None, prev_attn=None
):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""
n, device = q.shape[-2], q.device
scale = default(self.scale, q.shape[-1] ** -0.5)
if self.add_zero_kv:
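            # F.pad with (0, 0, 1, 0) prepends a single all-zero token along the
            # key / value sequence dimension; mask and bias are padded to match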
k, v = map(
lambda t: F.pad(t, (0, 0, 1, 0), value=0.0), (k, v)
)
if exists(mask):
mask = F.pad(mask, (1, 0), value=True)
if exists(attn_bias):
attn_bias = F.pad(attn_bias, (1, 0), value=0.0)
if self.flash:
assert not exists(prev_attn), (
"residual attention not compatible with flash"
" attention"
)
return self.flash_attn(
q, k, v, mask=mask, attn_bias=attn_bias
)
kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"
dots = (
einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k)
* scale
)
if exists(prev_attn):
dots = dots + prev_attn
qk_similarities = dots.clone()
if self.talking_heads:
dots = self.pre_softmax_talking_heads(dots)
if exists(attn_bias):
dots = dots + attn_bias
i, j, dtype = *dots.shape[-2:], dots.dtype
mask_value = -torch.finfo(dots.dtype).max
if exists(self.sparse_topk) and self.sparse_topk < j:
top_values, _ = dots.topk(self.sparse_topk, dim=-1)
            # keep only the top-k similarities per query (True = keep)
            sparse_topk_mask = dots >= top_values[..., -1:]
mask = (
(mask & sparse_topk_mask)
if exists(mask)
else sparse_topk_mask
)
if exists(mask):
dots = dots.masked_fill(~mask, mask_value)
if self.causal:
causal_mask = self.create_causal_mask(i, j, device=device)
dots = dots.masked_fill(causal_mask, mask_value)
pre_softmax_attn = dots.clone()
attn = self.attn_fn(dots, dim=-1)
attn = attn.type(dtype)
post_softmax_attn = attn.clone()
attn = self.attn_dropout(attn)
if self.talking_heads:
attn = self.post_softmax_talking_heads(attn)
out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
intermediates = Intermediates(
qk_similarities=qk_similarities,
pre_softmax_attn=pre_softmax_attn,
post_softmax_attn=post_softmax_attn,
)
return out, intermediates
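# usage note (illustrative, not an API prescription): with pytorch >= 2.0 and
# flash=True, calls are routed through F.scaled_dot_product_attention and an
# empty Intermediates is returned; otherwise the explicit einsum path above runs
# and the qk similarities and pre / post softmax attention maps are captured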
# cascading heads logic
def to_single_heads(t, dim=1):
heads = t.unbind(dim=dim)
return tuple(head.unsqueeze(dim) for head in heads)
class CascadingHeads(nn.Module):
def __init__(self, attend: Attend):
super().__init__()
self.attend = attend
def forward(
self, q, k, v, mask=None, attn_bias=None, prev_attn=None
):
        assert q.shape[-1] == v.shape[-1], (
            "cascading heads requires the query / key and value head dimensions"
            " to match, since each head's output is added to the next head's query"
        )
# split inputs into per-head inputs
heads = q.shape[1]
queries = to_single_heads(q)
keys = to_single_heads(k) if k.ndim == 4 else ((k,) * heads)
values = to_single_heads(v) if v.ndim == 4 else ((v,) * heads)
mask = (mask,) * heads
attn_bias = (
to_single_heads(attn_bias, dim=0)
if exists(attn_bias)
else ((None,) * heads)
)
prev_attn = (
to_single_heads(prev_attn)
if exists(prev_attn)
else ((None,) * heads)
)
        # now loop through each head, with the output of the previous head summed
        # into the query of the next head, thus cascading
all_outs = []
all_intermediates = []
prev_head_out = None
for h_q, h_k, h_v, h_mask, h_attn_bias, h_prev_attn in zip(
queries, keys, values, mask, attn_bias, prev_attn
):
if exists(prev_head_out):
h_q = h_q + prev_head_out
out, intermediates = self.attend(
h_q,
h_k,
h_v,
mask=h_mask,
attn_bias=h_attn_bias,
prev_attn=h_prev_attn,
)
prev_head_out = out
all_outs.append(out)
all_intermediates.append(intermediates)
# cat all output heads
all_outs = torch.cat(all_outs, dim=1)
# cat all intermediates, if they exist
qk_similarities, pre_softmax_attn, post_softmax_attn = zip(
*map(lambda i: i.to_tuple(), all_intermediates)
)
qk_similarities, pre_softmax_attn, post_softmax_attn = map(
compact,
(qk_similarities, pre_softmax_attn, post_softmax_attn),
)
aggregated_intermediates = Intermediates(
qk_similarities=(
torch.cat(qk_similarities, dim=1)
if len(qk_similarities) > 0
else None
),
pre_softmax_attn=(
torch.cat(pre_softmax_attn, dim=1)
if len(pre_softmax_attn) > 0
else None
),
post_softmax_attn=(
torch.cat(post_softmax_attn, dim=1)
if len(post_softmax_attn) > 0
else None
),
)
return all_outs, aggregated_intermediates
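# a minimal smoke-test sketch; the shapes and hyperparameters below are
# illustrative assumptions for exercising the module, not values it prescribes
if __name__ == "__main__":
    q = torch.randn(2, 8, 256, 64)  # (batch, heads, seq_len, dim_head)
    k = torch.randn(2, 8, 256, 64)
    v = torch.randn(2, 8, 256, 64)

    # plain causal attention through the explicit einsum path
    attend = Attend(causal=True, dropout=0.1)
    out, intermediates = attend(q, k, v)
    assert out.shape == (2, 8, 256, 64)
    assert intermediates.post_softmax_attn.shape == (2, 8, 256, 256)

    # cascading heads wrapping the same attention module
    cascading = CascadingHeads(Attend(causal=True))
    out, _ = cascading(q, k, v)
    assert out.shape == (2, 8, 256, 64)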