mirror of https://github.com/tinygrad/tinygrad.git

cleanups from flash attention branch (#12897)
@@ -42,29 +42,35 @@ elif getenv("BIG") > 0:
 else:
   BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
 
+def fa():
+  Tensor.manual_seed(1337)
+  with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)]
+  GlobalCounters.reset()
+  return q.scaled_dot_product_attention(k, v)
+
+def fa_bw():
+  Tensor.manual_seed(1337)
+  with Context(DEBUG=0):
+    q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize().requires_grad_() for _ in range(3)]
+    attn_output = nn.Linear(HEADS*EMB, HEADS*EMB, bias=False)
+    attn_output.weight.requires_grad_().realize()
+    target = Tensor.rand(BS, SEQLEN, HEADS*EMB).contiguous().realize()
+
+  GlobalCounters.reset()
+  attn = q.scaled_dot_product_attention(k, v).contiguous().contiguous_backward()
+  attn = attn.transpose(1, 2).reshape(BS, SEQLEN, -1)
+  out = attn_output(attn)
+  loss = (out - target).square().mean()
+  loss.backward()
+  #ret = [out, Tensor.stack(q.grad, k.grad, v.grad, dim=-1)]
+  #ret = [out, Tensor.stack(q.grad, k.grad, dim=-1), v.grad]
+  ret = [out, q.grad, k.grad, v.grad]
+  Tensor.realize(*ret)
+  return ret
+
 @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer)), "broken in LVP and PTX")
 class TestPcontig(unittest.TestCase):
   def test_flash_attention_bw(self):
-    def fa_bw():
-      Tensor.manual_seed(1337)
-      with Context(DEBUG=0):
-        q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize().requires_grad_() for _ in range(3)]
-        attn_output = nn.Linear(HEADS*EMB, HEADS*EMB, bias=False)
-        attn_output.weight.requires_grad_().realize()
-        target = Tensor.rand(BS, SEQLEN, HEADS*EMB).contiguous().realize()
-
-      GlobalCounters.reset()
-      attn = q.scaled_dot_product_attention(k, v).contiguous().contiguous_backward()
-      attn = attn.transpose(1, 2).reshape(BS, SEQLEN, -1)
-      out = attn_output(attn)
-      loss = (out - target).square().mean()
-      loss.backward()
-      #ret = [out, Tensor.stack(q.grad, k.grad, v.grad, dim=-1)]
-      #ret = [out, Tensor.stack(q.grad, k.grad, dim=-1), v.grad]
-      ret = [out, q.grad, k.grad, v.grad]
-      Tensor.realize(*ret)
-      return ret
-
     with Context(PCONTIG=max(2, PCONTIG.value), DEBUG=2):
       grads = fa_bw()
       print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
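For context, q.scaled_dot_product_attention(k, v) with no mask is plain non-causal attention: softmax(q @ k^T / sqrt(EMB)) @ v. Below is a minimal NumPy reference sketch of what fa() computes, useful as an independent check; it assumes tinygrad's default scale of 1/sqrt of the head dimension and no causal mask, and the helper name sdpa_reference is illustrative, not part of this commit.

import numpy as np

def sdpa_reference(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
  # q, k, v: (BS, HEADS, SEQLEN, EMB); matmul broadcasts over the batch and head dims
  scale = 1.0 / np.sqrt(q.shape[-1])
  scores = (q @ k.swapaxes(-1, -2)) * scale      # (BS, HEADS, SEQLEN, SEQLEN)
  scores -= scores.max(axis=-1, keepdims=True)   # numerically stable softmax
  weights = np.exp(scores)
  weights /= weights.sum(axis=-1, keepdims=True)
  return weights @ v                             # (BS, HEADS, SEQLEN, EMB)

# hypothetical usage against numpy copies of the test's q, k, v tensors:
#   np.testing.assert_allclose(fa().numpy(), sdpa_reference(q_np, k_np, v_np), atol=1e-5)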
@@ -80,17 +86,11 @@ class TestPcontig(unittest.TestCase):
     self.assertLessEqual(mse, 1e-6)
 
   def test_flash_attention(self):
-    def fa():
-      Tensor.manual_seed(1337)
-      with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)]
-      GlobalCounters.reset()
-      return q.scaled_dot_product_attention(k, v).realize()
-
     with Context(PCONTIG=2, DEBUG=2):
-      ret = fa()
+      ret = fa().realize()
       print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
     with Context(DEBUG=2):
-      cmp = fa()
+      cmp = fa().realize()
       print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
     with Context(DEBUG=0):
       mse = ((cmp-ret)**2).sum().item()
@@ -87,7 +87,7 @@ class Scheduler:
     self.ast = self.ast.substitute(dict(zip(self.rngs, rng)))
 
   def colors(self) -> list[str]:
-    output_rngs = flatten([s.src[2:] for s in self.ast.src])
+    output_rngs = flatten([list(UOp.sink(*s.src[2:]).ranges) for s in self.ast.src])
     ret = []
     for x,r in zip(self.axis_types, self.rngs):
       if self.dont_use_locals and x == AxisType.GLOBAL: ret.append("BLUE")
@@ -247,8 +247,8 @@ pm_remove_bufferize = PatternMatcher([
 
 def late_buffer_view(t:UOp, b:UOp):
   if isinstance(b.device, str) and (b.device.startswith("DISK") or b.device.startswith("TINYFS")):
-    rngs = b.src[1:]
-    size = prod(shape := [int(r.vmax+1) for r in rngs])
+    shape = b.shape
+    size = prod(shape)
 
     # walk up for the INDEX
     x = t
@@ -301,9 +301,9 @@ pm_limit_bufs = PatternMatcher([(UPat(set.union(GroupOp.Binary, GroupOp.Ternary)
 
 def bufferize_to_store(x:UOp, allow_locals=True):
   rngs = x.src[1:]
-  shape = tuple([int(r.vmax+1) for r in rngs])
+  shape = x.shape
   size = prod(shape)
-  assert size > 0, f"no zero sized buffers {shape}"
+  assert size > 0 and isinstance(size, int), f"no zero sized or symbolic sized buffers {shape}"
 
   sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace)
   if x.src[0].op is Ops.ASSIGN:
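The tightened assertion above rejects not only zero-sized but also symbolically-sized buffers: once shape comes from x.shape it may contain symbolic dimensions, and their product is not a plain Python int. A small standalone sketch of that check, using functools.reduce in place of tinygrad's prod helper (buffer_size is an illustrative name, not part of this commit):

from functools import reduce
import operator

def buffer_size(shape) -> int:
  # product of all dims; with a symbolic dimension in shape this is no longer a Python int
  size = reduce(operator.mul, shape, 1)
  assert size > 0 and isinstance(size, int), f"no zero sized or symbolic sized buffers {shape}"
  return size

print(buffer_size((4, 2, 16, 8)))  # 1024: every dim is concrete, so the buffer can be allocated
# a shape containing a symbolic (runtime-bound) dimension does not produce an int product,
# so the check fails instead of silently allocating a wrongly-sized buffer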
@@ -379,7 +379,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
   ((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
   # only RANGE/IF/STORE/KERNEL have side effects
   (UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
-    tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.IF, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END} else y.src for y in x.src[1:]])))),
+    tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.IF, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))),
   # after with 1 src is just src[0]
   (UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
 ])+gep_pushing
@@ -84,8 +84,8 @@ def uop_to_json(x:UOp, ignore_indexing=False) -> dict[int, dict]:
         label += f"\n{shape_to_str(u.shape)}"
       if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
         label += f"\n{u.render()}"
-      if u.op in {Ops.END, Ops.STORE, Ops.REDUCE}:
-        label += "\n"+' '.join([f"{colored(s.arg[0], axis_colors[s.arg[-1]])}({s.vmax+1})" for s in u.src[range_start[u.op]:]])
+      if u.op in {Ops.END, Ops.STORE, Ops.REDUCE} and len(trngs:=list(UOp.sink(*u.src[range_start[u.op]:]).ranges)):
+        label += "\n"+' '.join([f"{colored(s.arg[0], axis_colors[s.arg[-1]])}({s.vmax+1})" for s in trngs])
     except Exception:
       label += "\n<ISSUE GETTING LABEL>"
     if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"