This commit is contained in:
George Hotz
2025-05-16 14:06:48 -07:00
parent 0f52dab7d2
commit bb2e430ac3
7 changed files with 44 additions and 31 deletions

View File

@@ -7,7 +7,7 @@ from tinygrad.codegen.lowerer import get_contraction_with_reduce, get_contractio
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, ContextVar, Context, diskcache_put
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP, CAPTURE_PROCESS_REPLAY, RING
from tinygrad.dtype import ImageDType
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape
from tinygrad.spec import type_verify, sched_spec
@@ -86,8 +86,7 @@ sym = symbolic_simple+PatternMatcher([
(UPat(Ops.COPY, name="root", src=(UPat.cvar("x"), UPat(Ops.DEVICE))), lambda root,x: root.const_like(x.arg)),
# store a shrink before COPY, otherwise view after the COPY
(UPat(Ops.COPY, src=(UPat(Ops.VIEW, name="v"), UPat(Ops.DEVICE)), name="copy"), lambda copy,v:
v.contiguous().copy_to_device(copy.device, arg=copy.arg) if prod(v.shape) < prod(v.base.shape) else \
v.base.copy_to_device(copy.device, arg=copy.arg).view(v.st)),
v.contiguous().copy_to_device(copy.device) if prod(v.shape) < prod(v.base.shape) else v.base.copy_to_device(copy.device).view(v.st)),
# remove cast to image when it's already a contiguous image
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)),
lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
@@ -244,7 +243,7 @@ def create_kernel(x:UOp, b:UOp|None=None):
buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
return buffer.assign(kernel).reshape(x.shape)
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER}
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT}
def append_to_kernel(ctx:dict[UOp, None], x:UOp):
new_srcs: list[UOp] = []
metadata = dict.fromkeys(x.arg.metadata)
@@ -393,6 +392,9 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
if s.op is Ops.ASSIGN:
for out in s.src[1].arg.ast.src: parents_rep[out] = s.buf_uop.view(unwrap(out.st))
parents_rep[s] = s.buf_uop
if s.op is Ops.MSELECT:
for out in s.src[0].src[1].arg.ast.src: parents_rep[out] = s.src[0].buf_uop.view(unwrap(out.st))
parents_rep[s] = s.buf_uop
ast = k.arg.ast.substitute(parents_rep, name="replace realized")
# push views to edges
ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right")
@@ -496,8 +498,9 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
buf = buf.contiguous()
# copy to all devices. if you shrink later, that'll be handled
if not use_ring: return functools.reduce(lambda x,y: x.alu(red.arg, y),
[UOp(Ops.COPY, buf.dtype, (buf, red.src[1]), arg=i) for i in range(len(buf.device))])
use_ring = False
if not use_ring:
return functools.reduce(lambda x,y: x.alu(red.arg, y), [buf.mselect(i).copy_to_device(buf.device) for i in range(len(buf.device))])
# new ring reduce
factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
@@ -521,10 +524,12 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
return functools.reduce(operator.add, [c.copy_to_device(buf.device).pad(pad) for pad,c in zip(pads, reduced_chunks)]).reshape(shape)
replace_allreduce = PatternMatcher([
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
# copy on specific copy
(UPat(Ops.COPY, src=(UPat(Ops.COPY, name="c1"), UPat(Ops.DEVICE, name="out_device")), name="c2"),
lambda c1,c2,out_device: c1.src[0].copy_to_device(out_device, arg=c1.arg) if c2.arg is None else None),
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"),), name="red"), handle_allreduce),
])
# Push an MSELECT through a VIEW: apply the same selection (m.arg) to the view's source, then
# re-apply the view's arg with the view's device_num() variable replaced by the constant m.arg.
# NOTE: src is written as a 1-tuple (trailing comma), matching the other single-source patterns in this file.
view_mselect = PatternMatcher([
  (UPat(Ops.MSELECT, src=(UPat(Ops.VIEW, name="v"),), name="m"), lambda m,v:
    v.src[0].mselect(m.arg).view(v.arg.substitute({v.device_num():UOp.const(dtypes.int, m.arg)}))),
])
@track_rewrites(name_fxn=get_name)
@@ -533,7 +538,7 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
tensor_map = graph_rewrite_map(big_sink, replace_allreduce, name="replace_allreduce")
# merge_views + simplify
tensor_map = graph_rewrite_map(tensor_map[big_sink], insert_fuse+do_fuse+merge_views+sym+replace_contiguous, ctx={},
tensor_map = graph_rewrite_map(tensor_map[big_sink], insert_fuse+do_fuse+merge_views+sym+replace_contiguous+view_mselect, ctx={},
input_map=tensor_map, name="merge_views")
# display the cleaned up tensor graph

View File

@@ -118,7 +118,7 @@ def shrink_multi(root:UOp, multi:UOp):
"cannot shrink sharded and non-sharded axis at the same time"
# NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
# we just copy it to all the devices, no real. this will be optimized out later
return multi.src[0].copy_to_device(multi.device, arg=multi.bounds.index(root.arg[multi.axis]))
return multi.src[0].mselect(multi.bounds.index(root.arg[multi.axis])).copy_to_device(multi.device)
return UOp.multi(*[x.shrink(tuple((0, x.shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))) for x in multi.src],
axis=multi.axis)
@@ -139,7 +139,7 @@ def copy_multi(multi:UOp, device:UOp):
bsz, dcount = multi.shape[multi.axis]//len(multi.device), len(multi.device)
dnum = UOp.variable("_device_num", 0, len(multi.device)-1)
padded = multi.src[0].pad(tuple((0,0) if a != multi.axis else (bsz*dnum, bsz*(dcount-1) - bsz*dnum) for a in range(len(multi.shape))))
ret = padded.allreduce(Ops.ADD, device)
ret = padded.allreduce(Ops.ADD)
return ret if isinstance(device.arg, str) else ret.multi(axis=None)
def assign_multi(dest:UOp, src:UOp):
@@ -160,8 +160,7 @@ multi_pm = PatternMatcher([
(UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
(UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
(UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi),
(UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"),
lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
(UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"),), name="red"), lambda multi,red: multi.src[0].allreduce(red.arg).multi(axis=multi.axis)),
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE),
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
])

View File

@@ -116,6 +116,9 @@ class Ops(FastEnum):
# reduce
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() # noqa: E702
# multi
MCAT = auto(); MSELECT = auto()
# helper ops
GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
@@ -430,9 +433,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def contiguous(self): return self.alu(Ops.CONTIGUOUS)
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
def fuse(self): return self.alu(Ops.FUSE)
def allreduce(self, op, device:str|tuple[str, ...]|UOp):
assert isinstance(self.device, tuple), f"allreduce must be on tuple {self.device} isn't"
return UOp(Ops.ALLREDUCE, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), op)
def allreduce(self, op:Ops):
  """Wrap this UOp as the single source of an ALLREDUCE UOp of the same dtype, with `op` stored as its arg."""
  return UOp(Ops.ALLREDUCE, self.dtype, (self,), arg=op)
# *** from MultiLazyBuffer ***
@@ -488,8 +489,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
assert op is Ops.BIND, f"unknown op {op}"
var, val = arg.unbind()
return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
def copy_to_device(self, device:str|tuple[str, ...]|UOp, arg=None):
return UOp(Ops.COPY, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), arg)
def copy_to_device(self, device:str|tuple[str, ...]|UOp):
  """Build a COPY UOp of the same dtype whose sources are (self, device).

  `device` is used as-is when it is already a UOp; otherwise it is wrapped in a DEVICE UOp carrying it as arg.
  """
  device_uop = device if isinstance(device, UOp) else UOp(Ops.DEVICE, arg=device)
  return UOp(Ops.COPY, self.dtype, (self, device_uop))
def mselect(self, idx:int) -> UOp:
  """Build an MSELECT UOp of the same dtype with this UOp as its only source and `idx` as its arg."""
  return UOp(Ops.MSELECT, self.dtype, (self,), arg=idx)
def device_num(self):
  """Return the variable named `_device_num`, created with bounds 0 and len(self.device)-1."""
  last_index = len(self.device)-1
  return UOp.variable("_device_num", 0, last_index)
def clone(self) -> UOp:
  """Copy this UOp to its own device, i.e. `copy_to_device(self.device)`."""
  own_device = self.device
  return self.copy_to_device(own_device)
@property
def metadata(self) -> Metadata|None: return all_metadata.get(self, None)
@@ -530,11 +533,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@functools.cached_property
def _device(self) -> Optional[str|tuple[str, ...]]:
if self.op is Ops.DEVICE: return self.arg
if self.op in {Ops.COPY, Ops.BUFFER, Ops.ALLREDUCE}: return self.src[1].device
if self.op is Ops.MSELECT: return self.src[0].device[self.arg]
if self.op in {Ops.COPY, Ops.BUFFER}: return self.src[1].device
return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
@property
def buf_uop(self) -> UOp:
# Resolve this UOp to its buffer UOp. Only BUFFER, MSELECT and ASSIGN are accepted; anything else trips the assert.
# a BUFFER is its own buffer
if self.op is Ops.BUFFER: return self
# MSELECT: resolve the source's buffer first, then re-apply the same selection index on it
if self.op is Ops.MSELECT: return self.src[0].buf_uop.mselect(self.arg)
assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}"
# ASSIGN: the buffer is the base of src[0]
return self.src[0].base
@property
@@ -570,8 +575,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
return UOp(Ops.BIND, self.dtype, (self, self.const_like(val)))
def unbind(self, optional=False) -> tuple[Variable, int]:
if optional and self.op is Ops.DEFINE_VAR: return (self, -1)
def unbind(self) -> tuple[Variable, int]:
# Split a BIND into its two parts: returns (src[0], src[1].arg).
# Requires op BIND with src[0] a DEFINE_VAR and src[1] a CONST; anything else fails the assert.
assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
return self.src[0], self.src[1].arg
@property

View File

@@ -106,6 +106,7 @@ class ShapeTracker:
unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
if all(len(x) == 0 for x in var_vals): return self, {}
return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
# Call .substitute(dvars) on every entry of self.views and build a new ShapeTracker from the results.
# dvars is annotated dict[UOp, UOp] to agree with View.substitute, which forwards it to the .substitute of each non-int entry.
def substitute(self, dvars:dict[UOp, UOp]): return ShapeTracker(tuple(x.substitute(dvars) for x in self.views))
def real_strides(self, ignore_valid=False) -> tuple[Optional[sint], ...]:
with Context(TRACK_MATCH_STATS=0): return views_to_real_strides(self.views, ignore_valid)

View File

@@ -141,12 +141,15 @@ class View:
def unbind(self) -> tuple[View, dict[Variable, int]]:
var_unboundvar_val = [(v, v.unbind()) for v in self.vars() if v.op is Ops.BIND]
unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
def substitute(x:sint): return x if isinstance(x, int) else x.substitute(unbound_vars)
new_shape = tuple(map(substitute, self.shape))
new_strides = tuple(map(substitute, self.strides))
new_offset = substitute(self.offset)
new_mask = tuple((substitute(x[0]), substitute(x[1])) for x in self.mask) if self.mask is not None else None
return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)
return self.substitute(unbound_vars), dict(x[1] for x in var_unboundvar_val)
def substitute(self, dvars:dict[UOp, UOp]):
  """Return a View rebuilt through `View.create` with `dvars` substituted into shape, strides, offset and mask.

  Plain `int` entries pass through unchanged; every other entry has its own `.substitute(dvars)` called.
  A `None` mask stays `None`.
  """
  def _sub(v:sint): return v if isinstance(v, int) else v.substitute(dvars)
  shape = tuple(_sub(s) for s in self.shape)
  strides = tuple(_sub(s) for s in self.strides)
  offset = _sub(self.offset)
  mask = None if self.mask is None else tuple((_sub(m[0]), _sub(m[1])) for m in self.mask)
  return View.create(shape, strides, offset, mask)
@functools.cache # pylint: disable=method-cache-max-size-none
def __add__(self, vm1:View) -> Optional[View]:

View File

@@ -78,7 +78,8 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
# COPY/ALLREDUCE/MSELECT
(UPat(Ops.COPY, name="copy", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda copy,x: copy.dtype == x.dtype),
(UPat(Ops.ALLREDUCE, name="red", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda red,x: red.dtype == x.dtype and isinstance(red.arg, Ops)),
(UPat(Ops.ALLREDUCE, name="red", src=(UPat.var("x"),)), lambda red,x: red.dtype == x.dtype and isinstance(red.arg, Ops)),
(UPat(Ops.MSELECT, name="m"), lambda m: m.arg >= 0 and m.arg < len(m.src[0].device))
])
# ***** uop type spec *****

View File

@@ -15,7 +15,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.IGNORE: "#00C000",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
Ops.ALLREDUCE: "#ff40a0"}
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#ffc0e0", Ops.MCAT: "#ffc0e0"}
# VIZ API