diff --git a/test/test_rangeify.py b/test/test_rangeify.py index 84d2a1dd00..72317f0984 100644 --- a/test/test_rangeify.py +++ b/test/test_rangeify.py @@ -42,29 +42,35 @@ elif getenv("BIG") > 0: else: BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8 +def fa(): + Tensor.manual_seed(1337) + with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)] + GlobalCounters.reset() + return q.scaled_dot_product_attention(k, v) + +def fa_bw(): + Tensor.manual_seed(1337) + with Context(DEBUG=0): + q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize().requires_grad_() for _ in range(3)] + attn_output = nn.Linear(HEADS*EMB, HEADS*EMB, bias=False) + attn_output.weight.requires_grad_().realize() + target = Tensor.rand(BS, SEQLEN, HEADS*EMB).contiguous().realize() + + GlobalCounters.reset() + attn = q.scaled_dot_product_attention(k, v).contiguous().contiguous_backward() + attn = attn.transpose(1, 2).reshape(BS, SEQLEN, -1) + out = attn_output(attn) + loss = (out - target).square().mean() + loss.backward() + #ret = [out, Tensor.stack(q.grad, k.grad, v.grad, dim=-1)] + #ret = [out, Tensor.stack(q.grad, k.grad, dim=-1), v.grad] + ret = [out, q.grad, k.grad, v.grad] + Tensor.realize(*ret) + return ret + @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer)), "broken in LVP and PTX") class TestPcontig(unittest.TestCase): def test_flash_attention_bw(self): - def fa_bw(): - Tensor.manual_seed(1337) - with Context(DEBUG=0): - q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize().requires_grad_() for _ in range(3)] - attn_output = nn.Linear(HEADS*EMB, HEADS*EMB, bias=False) - attn_output.weight.requires_grad_().realize() - target = Tensor.rand(BS, SEQLEN, HEADS*EMB).contiguous().realize() - - GlobalCounters.reset() - attn = q.scaled_dot_product_attention(k, v).contiguous().contiguous_backward() - attn = attn.transpose(1, 2).reshape(BS, SEQLEN, -1) - out = attn_output(attn) - loss = (out - target).square().mean() - loss.backward() - #ret = [out, Tensor.stack(q.grad, k.grad, v.grad, dim=-1)] - #ret = [out, Tensor.stack(q.grad, k.grad, dim=-1), v.grad] - ret = [out, q.grad, k.grad, v.grad] - Tensor.realize(*ret) - return ret - with Context(PCONTIG=max(2, PCONTIG.value), DEBUG=2): grads = fa_bw() print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS") @@ -80,17 +86,11 @@ class TestPcontig(unittest.TestCase): self.assertLessEqual(mse, 1e-6) def test_flash_attention(self): - def fa(): - Tensor.manual_seed(1337) - with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)] - GlobalCounters.reset() - return q.scaled_dot_product_attention(k, v).realize() - with Context(PCONTIG=2, DEBUG=2): - ret = fa() + ret = fa().realize() print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS") with Context(DEBUG=2): - cmp = fa() + cmp = fa().realize() print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS") with Context(DEBUG=0): mse = ((cmp-ret)**2).sum().item() diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py index f720712b44..269103134c 100644 --- a/tinygrad/codegen/opt/postrange.py +++ b/tinygrad/codegen/opt/postrange.py @@ -87,7 +87,7 @@ class Scheduler: self.ast = self.ast.substitute(dict(zip(self.rngs, rng))) def colors(self) -> list[str]: - output_rngs = flatten([s.src[2:] for s in self.ast.src]) + output_rngs = flatten([list(UOp.sink(*s.src[2:]).ranges) for s in self.ast.src]) ret = [] for x,r in zip(self.axis_types, self.rngs): if self.dont_use_locals and x == AxisType.GLOBAL: ret.append("BLUE") diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index df8a49baf3..d28cfdc4af 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -247,8 +247,8 @@ pm_remove_bufferize = PatternMatcher([ def late_buffer_view(t:UOp, b:UOp): if isinstance(b.device, str) and (b.device.startswith("DISK") or b.device.startswith("TINYFS")): - rngs = b.src[1:] - size = prod(shape := [int(r.vmax+1) for r in rngs]) + shape = b.shape + size = prod(shape) # walk up for the INDEX x = t @@ -301,9 +301,9 @@ pm_limit_bufs = PatternMatcher([(UPat(set.union(GroupOp.Binary, GroupOp.Ternary) def bufferize_to_store(x:UOp, allow_locals=True): rngs = x.src[1:] - shape = tuple([int(r.vmax+1) for r in rngs]) + shape = x.shape size = prod(shape) - assert size > 0, f"no zero sized buffers {shape}" + assert size > 0 and isinstance(size, int), f"no zero sized or symbolic sized buffers {shape}" sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace) if x.src[0].op is Ops.ASSIGN: diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index c0303a85ac..0852a12cb1 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -379,7 +379,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ ((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)), # only RANGE/IF/STORE/KERNEL have side effects (UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+ - tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.IF, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END} else y.src for y in x.src[1:]])))), + tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.IF, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))), # after with 1 src is just src[0] (UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s), ])+gep_pushing diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 2708e47332..00975704fb 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -84,8 +84,8 @@ def uop_to_json(x:UOp, ignore_indexing=False) -> dict[int, dict]: label += f"\n{shape_to_str(u.shape)}" if u.op in {Ops.INDEX, Ops.BUFFERIZE}: label += f"\n{u.render()}" - if u.op in {Ops.END, Ops.STORE, Ops.REDUCE}: - label += "\n"+' '.join([f"{colored(s.arg[0], axis_colors[s.arg[-1]])}({s.vmax+1})" for s in u.src[range_start[u.op]:]]) + if u.op in {Ops.END, Ops.STORE, Ops.REDUCE} and len(trngs:=list(UOp.sink(*u.src[range_start[u.op]:]).ranges)): + label += "\n"+' '.join([f"{colored(s.arg[0], axis_colors[s.arg[-1]])}({s.vmax+1})" for s in trngs]) except Exception: label += "\n" if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"