mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-15 01:48:23 -05:00
more fa multi fix (#14152)
This commit is contained in:
@@ -413,7 +413,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
|
||||
else:
|
||||
attn_mask = Tensor.zeros((B, 1, N, N), requires_grad=False, device=single_device, dtype=dtypes.float32)
|
||||
if isinstance(xq.device, tuple):
|
||||
if isinstance(xq.device, tuple) and not isinstance(attn_mask.device, tuple):
|
||||
attn_mask = attn_mask.shard(xq.device, axis=0)
|
||||
|
||||
attn = _sharded_empty_like(xq, axis=0)
|
||||
|
||||
@@ -9,6 +9,10 @@ from tinygrad.engine.realize import ExecItem
|
||||
|
||||
# **** schedule linearizer
|
||||
|
||||
def _unwrap_src(s: UOp) -> UOp:
  """Follow the first-source chain starting at `s` and return the node where it stops.

  The walk descends through `src[0]` and stops at the first node that either has no
  sources or whose op is one of AFTER, BUFFER, MSELECT, MSTACK or BIND. That node is
  returned; it is `s` itself when `s` already satisfies the stop condition.
  """
  stop_ops = {Ops.AFTER, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.BIND}
  node = s
  while True:
    # a node without sources cannot be descended any further
    if not len(node.src): break
    # these ops terminate the walk and are returned unchanged
    if node.op in stop_ops: break
    node = node.src[0]
  return node
|
||||
|
||||
def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
||||
with cpu_profile(TracingKey("toposort sched_sink")):
|
||||
# construct the KERNEL children graph based on assigns
|
||||
@@ -22,7 +26,7 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
||||
k = u.src[1]
|
||||
in_degree.setdefault(k, 0)
|
||||
for s in k.src[0].src if k.op is Ops.END else k.src:
|
||||
while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0]
|
||||
s = _unwrap_src(s)
|
||||
if s.op is Ops.AFTER:
|
||||
children.setdefault(s.src[1], []).append(k)
|
||||
in_degree[k] += 1
|
||||
@@ -50,7 +54,7 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
||||
if k.op is Ops.RANGE: schedule.append(k)
|
||||
elif k.op is Ops.KERNEL:
|
||||
ast = k.arg.ast
|
||||
buf_uops = tuple(s.buf_uop for s in k.src if s.op is not Ops.BIND)
|
||||
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src if s.op is not Ops.BIND)
|
||||
bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
|
||||
schedule.append((ast, buf_uops, k.arg.metadata, {}, bound_ranges))
|
||||
if rk.op is Ops.END: schedule.append(rk)
|
||||
|
||||
Reference in New Issue
Block a user