cleanups from flash attention branch (#12897)

George Hotz
2025-10-24 14:14:56 +08:00
committed by GitHub
parent 9dac505565
commit 0bde87d8d7
5 changed files with 36 additions and 36 deletions

View File

@@ -42,29 +42,35 @@ elif getenv("BIG") > 0:
 else:
   BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
+def fa():
+  Tensor.manual_seed(1337)
+  with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)]
+  GlobalCounters.reset()
+  return q.scaled_dot_product_attention(k, v)
+def fa_bw():
+  Tensor.manual_seed(1337)
+  with Context(DEBUG=0):
+    q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize().requires_grad_() for _ in range(3)]
+    attn_output = nn.Linear(HEADS*EMB, HEADS*EMB, bias=False)
+    attn_output.weight.requires_grad_().realize()
+    target = Tensor.rand(BS, SEQLEN, HEADS*EMB).contiguous().realize()
+  GlobalCounters.reset()
+  attn = q.scaled_dot_product_attention(k, v).contiguous().contiguous_backward()
+  attn = attn.transpose(1, 2).reshape(BS, SEQLEN, -1)
+  out = attn_output(attn)
+  loss = (out - target).square().mean()
+  loss.backward()
+  #ret = [out, Tensor.stack(q.grad, k.grad, v.grad, dim=-1)]
+  #ret = [out, Tensor.stack(q.grad, k.grad, dim=-1), v.grad]
+  ret = [out, q.grad, k.grad, v.grad]
+  Tensor.realize(*ret)
+  return ret
 @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer)), "broken in LVP and PTX")
 class TestPcontig(unittest.TestCase):
   def test_flash_attention_bw(self):
-    def fa_bw():
-      Tensor.manual_seed(1337)
-      with Context(DEBUG=0):
-        q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize().requires_grad_() for _ in range(3)]
-        attn_output = nn.Linear(HEADS*EMB, HEADS*EMB, bias=False)
-        attn_output.weight.requires_grad_().realize()
-        target = Tensor.rand(BS, SEQLEN, HEADS*EMB).contiguous().realize()
-      GlobalCounters.reset()
-      attn = q.scaled_dot_product_attention(k, v).contiguous().contiguous_backward()
-      attn = attn.transpose(1, 2).reshape(BS, SEQLEN, -1)
-      out = attn_output(attn)
-      loss = (out - target).square().mean()
-      loss.backward()
-      #ret = [out, Tensor.stack(q.grad, k.grad, v.grad, dim=-1)]
-      #ret = [out, Tensor.stack(q.grad, k.grad, dim=-1), v.grad]
-      ret = [out, q.grad, k.grad, v.grad]
-      Tensor.realize(*ret)
-      return ret
     with Context(PCONTIG=max(2, PCONTIG.value), DEBUG=2):
       grads = fa_bw()
       print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
@@ -80,17 +86,11 @@ class TestPcontig(unittest.TestCase):
     self.assertLessEqual(mse, 1e-6)
   def test_flash_attention(self):
-    def fa():
-      Tensor.manual_seed(1337)
-      with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)]
-      GlobalCounters.reset()
-      return q.scaled_dot_product_attention(k, v).realize()
     with Context(PCONTIG=2, DEBUG=2):
-      ret = fa()
+      ret = fa().realize()
       print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
     with Context(DEBUG=2):
-      cmp = fa()
+      cmp = fa().realize()
       print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
     with Context(DEBUG=0):
       mse = ((cmp-ret)**2).sum().item()
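As a reference for what these tests exercise: q.scaled_dot_product_attention(k, v) with no mask computes softmax(q @ k.T / sqrt(EMB)) @ v, which flash attention fuses so the SEQLEN x SEQLEN score matrix is never materialized. A minimal unfused sketch using only public Tensor ops (ref_attention is an illustrative name, not something from this diff):

import math
from tinygrad import Tensor

def ref_attention(q:Tensor, k:Tensor, v:Tensor) -> Tensor:
  # (BS, HEADS, SEQLEN, EMB) @ (BS, HEADS, EMB, SEQLEN) -> (BS, HEADS, SEQLEN, SEQLEN)
  scores = (q @ k.transpose(2, 3)) / math.sqrt(q.shape[-1])
  # attention weights over keys, then weighted sum of values -> (BS, HEADS, SEQLEN, EMB)
  return scores.softmax(-1) @ v

The PCONTIG run and the plain DEBUG=2 run above should produce the same values for this computation, which is what the mse check at the end of each test asserts.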

View File

@@ -87,7 +87,7 @@ class Scheduler:
     self.ast = self.ast.substitute(dict(zip(self.rngs, rng)))
   def colors(self) -> list[str]:
-    output_rngs = flatten([s.src[2:] for s in self.ast.src])
+    output_rngs = flatten([list(UOp.sink(*s.src[2:]).ranges) for s in self.ast.src])
     ret = []
     for x,r in zip(self.axis_types, self.rngs):
       if self.dont_use_locals and x == AxisType.GLOBAL: ret.append("BLUE")
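Both this colors() change and the viz label change in the last file swap a direct source slice for UOp.sink(*...).ranges, presumably so ranges are collected transitively from whatever the sources happen to be rather than assumed to sit directly in src. A rough pure-Python stand-in for that idea (Node and the ranges helper here are hypothetical, not the UOp API):

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:                       # hypothetical stand-in for UOp
  op: str
  src: tuple["Node", ...] = ()

def ranges(roots:tuple[Node, ...]) -> list[Node]:
  # collect every RANGE reachable from roots, not just the direct sources
  seen:set[int] = set()
  out:list[Node] = []
  stack = list(roots)
  while stack:
    n = stack.pop()
    if id(n) in seen: continue
    seen.add(id(n))
    if n.op == "RANGE": out.append(n)
    stack.extend(n.src)
  return out

r0, r1 = Node("RANGE"), Node("RANGE")
srcs = (Node("ADD", (r0, Node("MUL", (r1,)))),)   # ranges hidden behind arithmetic
assert {id(r) for r in ranges(srcs)} == {id(r0), id(r1)}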

View File

@@ -247,8 +247,8 @@ pm_remove_bufferize = PatternMatcher([
 def late_buffer_view(t:UOp, b:UOp):
   if isinstance(b.device, str) and (b.device.startswith("DISK") or b.device.startswith("TINYFS")):
-    rngs = b.src[1:]
-    size = prod(shape := [int(r.vmax+1) for r in rngs])
+    shape = b.shape
+    size = prod(shape)
     # walk up for the INDEX
     x = t
@@ -301,9 +301,9 @@ pm_limit_bufs = PatternMatcher([(UPat(set.union(GroupOp.Binary, GroupOp.Ternary)
 def bufferize_to_store(x:UOp, allow_locals=True):
   rngs = x.src[1:]
-  shape = tuple([int(r.vmax+1) for r in rngs])
+  shape = x.shape
   size = prod(shape)
-  assert size > 0, f"no zero sized buffers {shape}"
+  assert size > 0 and isinstance(size, int), f"no zero sized or symbolic sized buffers {shape}"
   sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace)
   if x.src[0].op is Ops.ASSIGN:
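Since shape now comes from x.shape (and b.shape above) rather than being rebuilt from the ranges' vmax, prod(shape) can be a symbolic expression when a dimension is symbolic, which is what the tightened assert rejects early. A small sketch of that distinction, assuming the public Variable and tinygrad.helpers prod behave as they currently do:

from tinygrad import Variable
from tinygrad.helpers import prod

n = Variable("n", 1, 128)        # a symbolic dimension
concrete = prod((4, 2, 16, 8))   # plain int 1024: passes the new assert
symbolic = prod((4, n, 8))       # a symbolic expression, not an int: now rejected
assert isinstance(concrete, int) and concrete > 0
assert not isinstance(symbolic, int)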

View File

@@ -379,7 +379,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
   ((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
   # only RANGE/IF/STORE/KERNEL have side effects
   (UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
-    tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.IF, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END} else y.src for y in x.src[1:]])))),
+    tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.IF, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))),
   # after with 1 src is just src[0]
   (UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
 ])+gep_pushing
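The AFTER rule keeps dependencies whose op is in this set as direct sources and looks through everything else to its own sources; the change adds Ops.UNROLL to the keep-set, so UNROLL deps are no longer replaced by their sources. A one-pass sketch with a hypothetical Dep type (the real rule re-applies through the PatternMatcher until fixed point):

from dataclasses import dataclass

@dataclass(frozen=True)
class Dep:                        # hypothetical stand-in for UOp
  op: str
  src: tuple["Dep", ...] = ()

KEEP = {"RANGE", "IF", "STORE", "KERNEL", "BARRIER", "END", "UNROLL"}   # UNROLL is the new entry

def flatten_deps(deps:tuple[Dep, ...]) -> list[Dep]:
  # keep deps in KEEP as-is, replace everything else with its own sources
  out:list[Dep] = []
  for y in deps:
    out.extend([y] if y.op in KEEP else list(y.src))
  return out

store = Dep("STORE")
add = Dep("ADD", (Dep("UNROLL"), Dep("LOAD")))
assert [d.op for d in flatten_deps((store, add))] == ["STORE", "UNROLL", "LOAD"]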

View File

@@ -84,8 +84,8 @@ def uop_to_json(x:UOp, ignore_indexing=False) -> dict[int, dict]:
       label += f"\n{shape_to_str(u.shape)}"
       if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
         label += f"\n{u.render()}"
-      if u.op in {Ops.END, Ops.STORE, Ops.REDUCE}:
-        label += "\n"+' '.join([f"{colored(s.arg[0], axis_colors[s.arg[-1]])}({s.vmax+1})" for s in u.src[range_start[u.op]:]])
+      if u.op in {Ops.END, Ops.STORE, Ops.REDUCE} and len(trngs:=list(UOp.sink(*u.src[range_start[u.op]:]).ranges)):
+        label += "\n"+' '.join([f"{colored(s.arg[0], axis_colors[s.arg[-1]])}({s.vmax+1})" for s in trngs])
     except Exception:
       label += "\n<ISSUE GETTING LABEL>"
     if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"