test flash attention backward (#12762)
* test flash attention backward
* TODO: fix pcontig
* end ranges
* render colors
* very big
* multiout at every level
* reset ending ranges
* fix tests
* ugh
@@ -28,29 +28,67 @@ class TestRangeifyEdgeCase(unittest.TestCase):
    res = Tensor.cat(a, c, dim=0)
    self.assertEqual(res.numpy()[-1, :16].tolist(), [512] * 16)

if getenv("BIG") > 2:
  # llama 8B (8192)
  BS, HEADS, SEQLEN, EMB = 4, 32, 8192, 128
elif getenv("BIG") > 1:
  # llama 8B
  BS, HEADS, SEQLEN, EMB = 4, 32, 2048, 128
elif getenv("BIG") > 0:
  # bigger
  BS, HEADS, SEQLEN, EMB = 4, 32, 1024, 64
else:
  BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8

@unittest.skipIf(CPU_LVP, "broken in LVP")
class TestPcontig(unittest.TestCase):
  def test_flash_attention(self):
    if getenv("BIG") > 1:
      # llama 8B
      BS, HEADS, SEQLEN, EMB = 4, 32, 2048, 128
    elif getenv("BIG") > 0:
      # bigger
      BS, HEADS, SEQLEN, EMB = 4, 32, 1024, 64
    else:
      BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
  def test_flash_attention_bw(self):
    def fa_bw():
      Tensor.manual_seed(1337)
      with Context(DEBUG=0):
        q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize().requires_grad_() for _ in range(3)]
        attn_output = nn.Linear(HEADS*EMB, HEADS*EMB, bias=False)
        attn_output.weight.requires_grad_().realize()
        target = Tensor.rand(BS, SEQLEN, HEADS*EMB).contiguous().realize()

      GlobalCounters.reset()
      attn = q.scaled_dot_product_attention(k, v).contiguous().contiguous_backward()
      attn = attn.transpose(1, 2).reshape(BS, SEQLEN, -1)
      out = attn_output(attn)
      loss = (out - target).square().mean()
      loss.backward()
      #ret = [out, Tensor.stack(q.grad, k.grad, v.grad)]
      ret = [out, q.grad, k.grad, v.grad]
      Tensor.realize(*ret)
      return ret

    with Context(PCONTIG=2, REAL_SUBSTITUTE=1, DEBUG=2):
      grads = fa_bw()
      print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")

    with Context(DEBUG=2):
      cmp_grads = fa_bw()
      print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")

    with Context(DEBUG=0):
      mses = [((x-y)**2).sum().item() for x,y in zip(grads, cmp_grads)]
      mse = sum(mses)
      print(f"mse: {mse}")
      self.assertLessEqual(mse, 1e-6)

  def test_flash_attention(self):
    def fa():
      Tensor.manual_seed(1337)
      with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)]
      GlobalCounters.reset()
      return q.scaled_dot_product_attention(k, v).realize()

    with Context(PCONTIG=2, DEBUG=2):
      GlobalCounters.reset()
      ret = fa()
      print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
    with Context(DEBUG=2):
      GlobalCounters.reset()
      cmp = fa()
      print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
    with Context(DEBUG=0):
      mse = ((cmp-ret)**2).sum().item()
      print(f"mse: {mse}")

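For a rough sense of the GFLOPS printouts above: scaled dot-product attention does two big matmuls per head (QK^T and the product with V), each costing about 2*BS*HEADS*SEQLEN^2*EMB FLOPs, so the forward pass is roughly 4*BS*HEADS*SEQLEN^2*EMB, ignoring the softmax and, in the backward test, the extra linear layer and the gradient work. A small sanity-check script, not part of the diff (the labels mirror the getenv("BIG") branches above):

# rough forward-pass FLOP estimate for the shapes used in these tests
for name, (BS, HEADS, SEQLEN, EMB) in {
    "default": (4, 2, 16, 8),
    "BIG=1":   (4, 32, 1024, 64),
    "BIG=2":   (4, 32, 2048, 128),
    "BIG=3":   (4, 32, 8192, 128),
}.items():
  flops = 4 * BS * HEADS * SEQLEN**2 * EMB  # QK^T and PV matmuls, softmax ignored
  print(f"{name}: ~{flops/1e9:.2f} GFLOPS forward")

For BIG=2 (the llama-8B shapes) this comes out to roughly 275 GFLOPS for the forward pass alone. The shape is picked with the BIG environment variable; PCONTIG, REAL_SUBSTITUTE and DEBUG are set by the tests themselves through Context.
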
@@ -333,7 +333,7 @@ class TestSchedule(unittest.TestCase):
    r1 = (x - r0).sum(axis=0).div(2)
    out0 = r0 + y
    out1 = r1 + y
    schedule = check_schedule([out0, out1], 4)
    schedule = check_schedule([out0, out1], 3)
    reduceops = [x for si in schedule for x in si.ast.toposort() if x.op in {Ops.REDUCE_AXIS, Ops.REDUCE}]
    self.assertEqual(len(reduceops), 2) # why is RANGEIFY different?

@@ -170,6 +170,7 @@ SPEC = ContextVar("SPEC", 0)
# TODO: disable by default due to speed
IGNORE_OOB = ContextVar("IGNORE_OOB", 1)
PCONTIG = ContextVar("PCONTIG", 0) # partial contiguous in rangeify
REAL_SUBSTITUTE = ContextVar("REAL_SUBSTITUTE", 0)

@dataclass(frozen=True)
class Metadata:

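PCONTIG and REAL_SUBSTITUTE are ContextVars, which is what lets the tests above flip them per-block with "with Context(PCONTIG=2, REAL_SUBSTITUTE=1): ..." instead of exporting environment variables. A minimal sketch of how such a ContextVar/Context pair can behave (a simplified stand-in for illustration, not tinygrad's actual helpers.py):

import os
from contextlib import contextmanager

class ContextVar:
  _registry = {}                                              # name -> ContextVar
  def __init__(self, key, default):
    self.key, self.value = key, int(os.getenv(key, default))  # an env var overrides the default
    ContextVar._registry[key] = self
  def __gt__(self, other): return self.value > other          # so `PCONTIG > 1` reads naturally
  def __bool__(self): return bool(self.value)                 # so `if REAL_SUBSTITUTE:` works

@contextmanager
def Context(**kwargs):
  prev = {k: ContextVar._registry[k].value for k in kwargs}
  for k, v in kwargs.items(): ContextVar._registry[k].value = v
  try: yield
  finally:
    for k, v in prev.items(): ContextVar._registry[k].value = v

PCONTIG = ContextVar("PCONTIG", 0)
REAL_SUBSTITUTE = ContextVar("REAL_SUBSTITUTE", 0)

with Context(PCONTIG=2, REAL_SUBSTITUTE=1):
  assert PCONTIG > 1 and REAL_SUBSTITUTE
assert not (PCONTIG > 1)  # restored after the block (assuming no env override is set)
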
@@ -151,15 +151,11 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
  tsink_reverse_toposort = tsink.reverse_toposort(consumer_map:=tsink.get_consumer_map())

  # explicit rangeify
  ending_ranges: dict[UOp, bool] = {}
  ending_ranges: dict[UOp, list[UOp]] = {}
  for x in tsink_reverse_toposort:
    if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue
    if x.dtype.scalar() == dtypes.index: continue # TODO: why do I need this?
    ending_ranges[x] = any(ending_ranges[u] for u in consumer_map[x])

    # if this element has weight and it's ending a range, we (force) realize it
    if ending_ranges[x] and x.op in GroupOp.Elementwise.union({Ops.REDUCE_AXIS}) and not (PCONTIG>1):
      rctx.realize_map[x] = None
    ending_ranges[x] = sum([ending_ranges.get(u, []) for u in consumer_map[x]], [])

    # *** the ranges on the output are
    # 1. new if this op is realized

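The type change above is the heart of this hunk: ending_ranges used to carry a single boolean per UOp ("some consumer ends a range"), which forced a realize whenever it was set; it now carries the concrete list of ranges the consumers ended, so the later PCONTIG>1 logic can compare against individual ranges instead of realizing unconditionally. A toy illustration of the two propagation styles over a reverse toposort (plain Python, not the actual pass):

consumer_map = {"a": ["b", "c"], "b": ["d"], "c": [], "d": []}  # toy graph: a feeds b and c, b feeds d
reverse_toposort = ["d", "c", "b", "a"]                         # consumers are visited before producers
ranges_ended_here = {"d": ["r1"], "c": [], "b": [], "a": []}    # ranges each node itself ends

# old style: one flag per node, "does any consumer end a range?"
ending_any = {}
for x in reverse_toposort:
  ending_any[x] = any(ending_any[u] for u in consumer_map[x]) or bool(ranges_ended_here[x])

# new style: concatenate the actual ranges ended by all consumers
ending = {}
for x in reverse_toposort:
  ending[x] = sum([ending.get(u, []) for u in consumer_map[x]], []) + ranges_ended_here[x]

print(ending_any["a"], ending["a"])  # True ['r1'] -- the list also says *which* range is ending
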
@@ -169,9 +165,9 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
    consumer_rngs = [rctx.range_map[c][0] for c in consumer_map[x] if c in rctx.range_map]
    if x in rctx.realize_map:
      # if this is in the realize_map, we create new ranges (at the output)
      out_rngs = tuple(rctx.new_range(s) if not isinstance(s, UOp) or s.op is not Ops.RANGE else s for s in x.shape)
      out_rngs = tuple(rctx.new_range(s) for s in x.shape)
      # all ranges are ended now
      ending_ranges[x] = False
      ending_ranges[x] = []
      # mark all ranges as ended
      assert rctx.realize_map[x] is None
      rctx.realize_map[x] = list(range(len(x.shape)))

@@ -195,7 +191,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
      # TODO: in RANGEIFY > 1 all_all_same isn't required
      all_all_same = all(all_same(local_rngs) for local_rngs,_ in rngs_valids)
      _out_rngs = []
      _new_rngs = []
      _realize_axis = []
      for i,(local_rngs,valids) in enumerate(rngs_valids):
        # we compare the ranges without their valids
        if all_all_same or (PCONTIG and all_same(local_rngs)):

@@ -204,11 +200,23 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
          _out_rngs.append(graph_rewrite(minimum_valid.where(local_rngs[0], UOp.invalid()), symbolic, name="minimum_valid"))
        else:
          _out_rngs.append(rctx.new_range(x.shape[i]))
          _new_rngs.append(i)
          _realize_axis.append(i)
      out_rngs = tuple(_out_rngs)

      # we have to (partially) realize here if there's new ranges
      if len(_new_rngs): rctx.realize_map[x] = _new_rngs
      if len(_realize_axis): rctx.realize_map[x] = _realize_axis

    # if this element is a reduce and there's ended ranges, we might have to end some other ranges
    if len(ending_ranges[x]) and x.op in GroupOp.Elementwise.union({Ops.REDUCE_AXIS}):
      _realize_axis = rctx.realize_map.get(x, []) or []
      for i,r in enumerate(out_rngs):
        if i in _realize_axis: continue
        if not (PCONTIG > 1) or any(any(rr.arg > e.arg for e in ending_ranges[x]) for rr in r.ranges):
          _realize_axis.append(i)
      ending_ranges[x] = []
      if len(_realize_axis):
        rctx.realize_map[x] = _realize_axis
        out_rngs = tuple([(rctx.new_range(x.shape[i]) if i in _realize_axis else r) for i,r in enumerate(out_rngs)])

    # TODO: some ops don't have shape, enable this after the `.st` property is removed
    #assert len(out_rngs) == len(x.shape), \

@@ -225,7 +233,8 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
    if x.op in GroupOp.Movement: rngs = apply_movement_op(x.op, x.src[0].shape, x.marg, rngs)
    # if the EXPAND is used to inject a range, we don't mark it as ending_ranges. otherwise we do.
    # NOTE: this doesn't actually always end a range, but this is why convs are realized, so for now we need it
    if x.op is Ops.EXPAND and all(isinstance(y, int) or y.op is not Ops.RANGE for y in x.shape): ending_ranges[x] = True
    if x.op is Ops.EXPAND and all(isinstance(y, int) or y.op is not Ops.RANGE for y in x.shape):
      ending_ranges[x] = list(UOp.sink(*[ro for ri, ro in zip(rngs, out_rngs) if ri is not ro]).ranges.keys())

    # REDUCE_AXIS creates ranges for the axes it is reducing
    if x.op is Ops.REDUCE_AXIS:

@@ -4,7 +4,7 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
from tinygrad.uop.symbolic import symbolic_simple
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata, REAL_SUBSTITUTE
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented
from tinygrad.codegen.opt import Opt
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op

@@ -178,7 +178,7 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
  # if it makes it here, the bufferize is removed
  # this is the ranges replaced
  # NOTE: if buf src is a const, we don't replace it
  if getenv("REAL_SUBSTITUTE"):
  if REAL_SUBSTITUTE:
    return src.substitute({k:v for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST})
  else:
    replaces = flatten([(k,v) for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST])

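The getenv -> ContextVar switch here is what makes the new backward test work: with Context(REAL_SUBSTITUTE=1, ...) only overrides the ContextVar, while getenv reads the process environment, so the old check would never see the test's override. A quick way to see the difference (assumes a tinygrad checkout with this change applied):

from tinygrad.helpers import Context, getenv, REAL_SUBSTITUTE

with Context(REAL_SUBSTITUTE=1):
  print(getenv("REAL_SUBSTITUTE"))  # 0 unless the env var is actually exported: getenv only sees os.environ
  print(REAL_SUBSTITUTE.value)      # 1: the ContextVar picks up the scoped override
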