From 935a60db723f1501b7d1e65acf7aed5183016575 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 17 Oct 2025 16:19:05 +0800 Subject: [PATCH] bring back partial contig and flash attention (#12756) * bring back partial contig and flash attention * why not 2 * work * that * fix pcontig --- test/test_rangeify.py | 57 ++++++++++++++++++----------------- tinygrad/engine/realize.py | 5 +-- tinygrad/helpers.py | 1 + tinygrad/runtime/ops_null.py | 2 +- tinygrad/schedule/indexing.py | 34 +++++++++++++-------- tinygrad/schedule/multi.py | 8 +++-- 6 files changed, 62 insertions(+), 45 deletions(-) diff --git a/test/test_rangeify.py b/test/test_rangeify.py index ca06d353e4..a0f5b3351b 100644 --- a/test/test_rangeify.py +++ b/test/test_rangeify.py @@ -1,6 +1,6 @@ import unittest from tinygrad import Tensor, nn -from tinygrad.helpers import Context, GlobalCounters, CI +from tinygrad.helpers import Context, GlobalCounters, CI, CPU_LVP from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops class TestRangeifyAssign(unittest.TestCase): @@ -28,6 +28,34 @@ class TestRangeifyEdgeCase(unittest.TestCase): res = Tensor.cat(a, c, dim=0) self.assertEqual(res.numpy()[-1, :16].tolist(), [512] * 16) +@unittest.skipIf(CPU_LVP, "broken in LVP") +class TestPcontig(unittest.TestCase): + def test_flash_attention(self): + BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8 + + # bigger + #BS, HEADS, SEQLEN, EMB = 4, 32, 1024, 64 + + # llama 8B + #BS, HEADS, SEQLEN, EMB = 4, 32, 2048, 128 + + def fa(): + Tensor.manual_seed(1337) + with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)] + return q.scaled_dot_product_attention(k, v).realize() + + with Context(PCONTIG=2, DEBUG=2): + GlobalCounters.reset() + ret = fa() + with Context(DEBUG=2): + GlobalCounters.reset() + cmp = fa() + with Context(DEBUG=0): + mse = ((cmp-ret)**2).sum().item() + print(f"mse: {mse}") + self.assertLessEqual(mse, 1e-6) + + # *** non CI rangeify tests below this line *** N = 256 @@ -215,33 +243,6 @@ class TestRangeify(unittest.TestCase): out = blk._feed_forward(x) out.realize() - @unittest.skip("RANGEIFY=0 does nothing") - def test_flash_attention(self): - BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8 - - # bigger - #BS, HEADS, SEQLEN, EMB = 4, 32, 1024, 64 - - # llama 8B - #BS, HEADS, SEQLEN, EMB = 4, 32, 2048, 128 - - def fa(): - Tensor.manual_seed(1337) - with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)] - return q.scaled_dot_product_attention(k, v).realize() - - with Context(DEBUG=4): - GlobalCounters.reset() - ret = fa() - with Context(RANGEIFY=0): - with Context(DEBUG=2): - GlobalCounters.reset() - cmp = fa() - with Context(DEBUG=0): - mse = ((cmp-ret)**2).sum().item() - print(f"mse: {mse}") - self.assertLessEqual(mse, 1e-6) - # contiguous + reduce can support ranges? @unittest.skip("pm_rangeify no longer exists. test this in a different way") diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index b889908cd7..26eec0d604 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -182,9 +182,10 @@ class ExecItem: ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else "" flops, membw, ldsbw = op_est/(et or 1e-20), mem_est/(et or 1e-20), lds_est/(et or 1e-20) flops_str = f"{flops*1e-9:9.2f} GFLOPS" if flops < 1e14 else colored(f"{flops*1e-12:9.2f} TFLOPS", 'green') - mem_str = f"{membw*1e-9:6.1f}|{ldsbw*1e-9:<7.1f} GB/s" if membw < 1e13 else colored(f"{membw*1e-12:6.1f}|{ldsbw*1e-12:<7.1f} TB/s", 'green') + mem_str = f"{membw*1e-9:6.1f}|{ldsbw*1e-9:<8.1f} GB/s" if membw < 1e13 and ldsbw < 1e15 else \ + colored(f"{membw*1e-12:6.1f}|{ldsbw*1e-12:<8.1f} TB/s", 'green') print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', header_color)}"+ - f" {self.prg.display_name+' '*(44-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB"+ + f" {self.prg.display_name+' '*(44-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:6.2f} GB"+ ("" if et is None else f" tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({flops_str} {mem_str})")+ f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}") self.prg.first_run = False diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index e5fffe2330..aeb1fc5d8a 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -169,6 +169,7 @@ VIZ = PROFILE = ContextVar("VIZ", 0) SPEC = ContextVar("SPEC", 0) # TODO: disable by default due to speed IGNORE_OOB = ContextVar("IGNORE_OOB", 1) +PCONTIG = ContextVar("PCONTIG", 0) # partial contiguous in rangeify @dataclass(frozen=True) class Metadata: diff --git a/tinygrad/runtime/ops_null.py b/tinygrad/runtime/ops_null.py index 7d64fee1c0..6377369292 100644 --- a/tinygrad/runtime/ops_null.py +++ b/tinygrad/runtime/ops_null.py @@ -17,7 +17,7 @@ class NullRenderer(CStyleLanguage): class NullProgram: def __init__(self, device:str, name:str, lib:bytes): self.device, self.name = device, name def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False): - with cpu_profile(self.name, self.device): return 1e-4 + with cpu_profile(self.name, self.device): return 1e-3 class NullAllocator(Allocator['NullDevice']): def _alloc(self, size, options): pass diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index baa0c4bb5b..e3533555c3 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from tinygrad.dtype import dtypes, AddrSpace from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses -from tinygrad.helpers import argsort, all_same, cpu_profile, TracingKey +from tinygrad.helpers import argsort, all_same, cpu_profile, TracingKey, PCONTIG ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL, @@ -40,12 +40,13 @@ class BufferizeOpts: @dataclass class IndexingContext: - realize_map: dict[UOp, None] = field(default_factory=dict) + realize_map: dict[UOp, None|list[int]] = field(default_factory=dict) range_map: dict[UOp, tuple[tuple[UOp, ...], tuple[UOp, ...]]] = field(default_factory=dict) # create ranges range_idx: Iterator[int] = field(default_factory=itertools.count) def new_range(self, s:sint, axistype:AxisType=AxisType.LOOP): + # if a range has a 1 src, it's the same as UOp.const(dtypes.index, 0) return UOp.range(s, next(self.range_idx), axistype) if resolve(s!=1) else UOp.const(dtypes.index, 0) def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp): @@ -57,8 +58,12 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp): if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.ASSIGN and s.src[1].op is Ops.KERNEL): if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0]) elif s in ctx.realize_map: - new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+tuple(ctx.range_map[s][1]), arg=BufferizeOpts(device=s.device), tag=s.tag) - if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0]) + realized_ranges = ctx.realize_map[s] + assert isinstance(realized_ranges, list), "realize map must contain range list" + closed_ranges = tuple([r for i,r in enumerate(ctx.range_map[s][1]) if i in realized_ranges]) + opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL) + new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None) + if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges]) new_srcs.append(new_src) # NOTE: do we need this? return x.replace(src=tns) if x.src != (tns:=tuple(new_srcs)) else None @@ -151,7 +156,8 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]: ending_ranges[x] = any(ending_ranges[u] for u in consumer_map[x]) # if this element has weight and it's ending a range, we (force) realize it - if ending_ranges[x] and x.op in GroupOp.Elementwise.union({Ops.REDUCE_AXIS}): rctx.realize_map[x] = None + if ending_ranges[x] and x.op in GroupOp.Elementwise.union({Ops.REDUCE_AXIS}) and not (PCONTIG>1): + rctx.realize_map[x] = None # *** the ranges on the output are # 1. new if this op is realized @@ -164,6 +170,9 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]: out_rngs = tuple(rctx.new_range(s) if not isinstance(s, UOp) or s.op is not Ops.RANGE else s for s in x.shape) # all ranges are ended now ending_ranges[x] = False + # mark all ranges as ended + assert rctx.realize_map[x] is None + rctx.realize_map[x] = list(range(len(out_rngs))) elif x.op in {Ops.MSTACK, Ops.MSELECT}: # treat MSTACK/MSELECT like SINK continue @@ -175,29 +184,29 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]: out_rngs = consumer_rngs[0] elif len(consumer_rngs) > 1: # if this has two consumers, we have to merge the ranges and might create new ones - all_rngs = list(zip(*consumer_rngs)) + all_rngs: list[tuple[UOp, ...]] = list(zip(*consumer_rngs)) rngs_valids = [] for valid_rngs in all_rngs: local_rngs, valids = zip(*[(r.get_idx(), r.get_valid()) for r in valid_rngs]) - # if a range has a 1 src, it's the same as UOp.const(dtypes.index, 0) - same_rngs = [x if x.op is not Ops.RANGE or resolve(x.src[0] != 1) else UOp.const(dtypes.index, 0) for x in local_rngs] - rngs_valids.append((local_rngs, valids, all_same(same_rngs))) + rngs_valids.append((local_rngs, valids, all_same(local_rngs))) # TODO: in RANGEIFY > 1 all_all_same isn't required all_all_same = all(same_rngs for _,_,same_rngs in rngs_valids) _out_rngs = [] + _new_rngs = [] for i,(local_rngs,valids,same_rngs) in enumerate(rngs_valids): # we compare the ranges without their valids - if all_all_same: + if all_all_same or (PCONTIG and same_rngs): # the new valid is the OR of all the children valids minimum_valid = functools.reduce(operator.or_, valids, UOp.const(dtypes.bool, False)) _out_rngs.append(graph_rewrite(minimum_valid.where(local_rngs[0], UOp.invalid()), symbolic, name="minimum_valid")) else: _out_rngs.append(rctx.new_range(x.shape[i])) + _new_rngs.append(i) out_rngs = tuple(_out_rngs) - # we have to realize here if there's new ranges - if not all_all_same: rctx.realize_map[x] = None + # we have to (partially) realize here if there's new ranges + if len(_new_rngs): rctx.realize_map[x] = _new_rngs # TODO: some ops don't have shape, enable this after the `.st` property is removed #assert len(out_rngs) == len(x.shape), \ @@ -213,6 +222,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]: # apply movement ops if x.op in GroupOp.Movement: rngs = apply_movement_op(x.op, x.src[0].shape, x.marg, rngs) # if the EXPAND is used to inject a range, we don't mark it as ending_ranges. otherwise we do. + # NOTE: this doesn't actually always end a range, but this is why convs are realized, so for now we need it if x.op is Ops.EXPAND and all(isinstance(y, int) or y.op is not Ops.RANGE for y in x.shape): ending_ranges[x] = True # REDUCE_AXIS creates ranges for the axes it is reducing diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py index 6db4c24f25..2fc58f46b0 100644 --- a/tinygrad/schedule/multi.py +++ b/tinygrad/schedule/multi.py @@ -1,7 +1,7 @@ from typing import cast import functools, itertools, operator from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv -from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, track_rewrites, graph_rewrite_map +from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, track_rewrites, graph_rewrite_map, graph_rewrite from tinygrad.device import Device # *** allreduce implementation *** @@ -219,4 +219,8 @@ multi_pm = PatternMatcher([ ])+replace_allreduce @track_rewrites() -def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]: return graph_rewrite_map(big_sink, multi_pm, name="multi_pm") +def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]: + if getenv("VIZ"): graph_rewrite(big_sink, PatternMatcher([]), name="View Multi AST") + ret = graph_rewrite_map(big_sink, multi_pm, name="multi_pm") + if getenv("VIZ"): graph_rewrite(ret[big_sink], PatternMatcher([]), name="View Post Multi AST") + return ret