From a4c4e483856dea5a5a0e01b4800bd60c2f6123ec Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Wed, 3 Dec 2025 14:34:34 -0800
Subject: [PATCH] add LUNIQUE op (#13554)

---
 examples/gradaccum_mnist.py | 15 ++++++++++++++-
 test/test_schedule.py       |  7 ++-----
 tinygrad/tensor.py          |  4 +++-
 tinygrad/uop/__init__.py    |  3 +++
 tinygrad/uop/ops.py         |  3 ++-
 tinygrad/uop/spec.py        |  5 +++--
 tinygrad/viz/serve.py       |  2 +-
 7 files changed, 28 insertions(+), 11 deletions(-)

diff --git a/examples/gradaccum_mnist.py b/examples/gradaccum_mnist.py
index c46cd0ced3..376fc5785a 100644
--- a/examples/gradaccum_mnist.py
+++ b/examples/gradaccum_mnist.py
@@ -1,8 +1,21 @@
 import itertools
-from examples.beautiful_mnist import Model
+from typing import Callable
 from tinygrad import nn, Tensor, dtypes, Device
 from tinygrad.helpers import getenv, trange, partition
 
+class Model:
+  def __init__(self):
+    self.layers: list[Callable[[Tensor], Tensor]] = [
+      nn.Conv2d(1, 32, 5), Tensor.relu,
+      nn.Conv2d(32, 32, 5), Tensor.relu,
+      nn.BatchNorm(32), Tensor.max_pool2d,
+      nn.Conv2d(32, 64, 3), Tensor.relu,
+      nn.Conv2d(64, 64, 3), Tensor.relu,
+      nn.BatchNorm(64), Tensor.max_pool2d,
+      lambda x: x.flatten(1), nn.Linear(576, 10)]
+
+  def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
+
 # TODO: refactor this into optim/onnx
 def functional_adam(g:Tensor, m:Tensor, v:Tensor, b1_t:Tensor, b2_t:Tensor, lr=0.001, b1=0.9, b2=0.999, eps=1e-6) -> Tensor:
   b1_t *= b1
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 993bee08f1..5d513808e2 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -12,8 +12,7 @@ from tinygrad.device import is_dtype_supported
 from tinygrad.dtype import DType, ImageDType
 from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
 from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
-from tinygrad.schedule.rangeify import get_rangeify_map, Kernel
-from tinygrad.engine.schedule import create_schedule_with_vars
+from tinygrad.schedule.rangeify import Kernel
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
 
 class KernelCountException(Exception): pass
@@ -24,9 +23,7 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te
   elif isinstance(t, list) and isinstance(t[0], Tensor): sched = Tensor.schedule(*t)
   else:
     assert isinstance(t, UOp), f"can't schedule {t}"
-    sink = UOp.sink(t) if t.op is not Ops.SINK else t
-    becomes_map = get_rangeify_map(sink)
-    sched, _ = create_schedule_with_vars(sink.substitute(becomes_map))
+    sched = Tensor(t).schedule()
   # test lowering all the ScheduleItems to ExecItems
   kernel_cnt = len([si for si,ei in lower_schedule(sched.copy()) if isinstance(ei.prg, CompiledRunner) or not filter_sink])
   if kernel_cnt != allowed:
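With the rangeify internals no longer imported, check_schedule now builds a schedule for a raw UOp through the public Tensor API. A minimal sketch of that path, assuming a default device; the tensors and shapes are illustrative, not taken from the test suite:

from tinygrad import Tensor

# build a small graph, then schedule it the way check_schedule now does:
# wrap the UOp in a Tensor and call .schedule() instead of going through get_rangeify_map
a, b = Tensor.empty(4, 4), Tensor.empty(4, 4)
out = (a + b).sum()
sched = Tensor(out.uop).schedule()  # list of ScheduleItems, ready for lower_schedule
print(len(sched))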
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 77801f5fe4..64cc898975 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -27,7 +27,7 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
   # get tensors in scope
   in_scope: dict[UOp, bool] = {}
   def visitor(node: UOp) -> bool: return True if node in applied_map else any(in_scope.get(s, False) for s in node.src)
-  scope_tensors = [t for tref in list(all_tensors) if (t:=tref()) is not None and t.uop.topovisit(visitor, in_scope)]
+  scope_tensors: list[Tensor] = [t for tref in list(all_tensors) if (t:=tref()) is not None and t.uop.topovisit(visitor, in_scope)]
 
   # get all Tensors and apply the map
   sink = UOp.sink(*[t.uop for t in scope_tensors])
@@ -242,6 +242,8 @@ class Tensor(OpMixin):
     NOTE: A Tensor can only be scheduled once.
     """
     big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
+
+    # this is where the schedule cache should go
     becomes_map, schedule, var_vals = complete_create_schedule_with_vars(big_sink)
     _apply_map_to_tensors(becomes_map, name="Apply Schedule Map")
     return schedule, var_vals
diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py
index 11eb13617a..201117b13f 100644
--- a/tinygrad/uop/__init__.py
+++ b/tinygrad/uop/__init__.py
@@ -76,6 +76,9 @@ class Ops(FastEnum):
   # tensor graph ops
   UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); ASSIGN = auto()
 
+  # local unique
+  LUNIQUE = auto()
+
   # ops that adjust the behavior of the scheduler
   CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto()
 
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 5ed65c93e4..83130beef0 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -216,7 +216,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
   def _shape(self) -> tuple[sint, ...]|None:
     match self.op:
       # late ops don't have shape
-      case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
+      case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
            Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT:
         return None
 
@@ -663,6 +663,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
       assert all_same([x.size for x in ret.bufs]) and all_same([x.dtype for x in ret.bufs]), "multibuffers mismatch buffers"
       return ret
     assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
+    assert self.src[0].op is Ops.UNIQUE, "buffer src[0] must be UNIQUE"
     if (cret:=buffers.get(self)) is not None: return cret
     rdtype = self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base
     if isinstance(self.device, tuple): ret = MultiBuffer(self.device, self.size, rdtype).ref(1)
diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py
index 6035936acc..d3ef6ba84e 100644
--- a/tinygrad/uop/spec.py
+++ b/tinygrad/uop/spec.py
@@ -60,9 +60,10 @@ movement_ops = PatternMatcher([
 _tensor_spec = PatternMatcher([
   # buffer spec
   (UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
+  (UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True),
   (UPat(Ops.DEVICE, dtypes.void, (), name="d"),
    lambda d: isinstance(d.arg, str) or (isinstance(d.arg, tuple) and all(isinstance(s, str) for s in d.arg))),
-  (UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), allow_any_len=True, name="buf"),
+  (UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), allow_any_len=True, name="buf"),
    lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
   (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
    lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)),
@@ -85,7 +86,7 @@ _tensor_spec = PatternMatcher([
 
   # device or unique
   (UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True),
-  (UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE))), lambda: True),
+  (UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat((Ops.LUNIQUE, Ops.UNIQUE)))), lambda: True),
 
   # DETACH and CONTIGUOUS change how we interpret the source UOp
   # CONTIGUOUS ensures the source UOp realizes
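The spec changes make LUNIQUE acceptable anywhere UNIQUE previously served as a buffer or const identity source. A minimal sketch of how a UPat over a tuple of ops matches either identity node; the UOps are constructed by hand with an illustrative arg of 0, since this patch does not show how LUNIQUE args are assigned:

from tinygrad.uop.ops import UOp, Ops, UPat
from tinygrad.dtype import dtypes

# a UPat given a tuple of ops matches a UOp whose op is any member of the tuple
pat = UPat((Ops.LUNIQUE, Ops.UNIQUE), dtypes.void, ())
u_global = UOp(Ops.UNIQUE, dtypes.void, (), 0)   # globally unique id, arg illustrative
u_local  = UOp(Ops.LUNIQUE, dtypes.void, (), 0)  # locally unique id, arg illustrative
print(len(pat.match(u_global, {})) > 0, len(pat.match(u_local, {})) > 0)  # both should match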
diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py
index 52db1b6b67..473b421e73 100755
--- a/tinygrad/viz/serve.py
+++ b/tinygrad/viz/serve.py
@@ -69,7 +69,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
   excluded: set[UOp] = set()
   for u in (toposort:=x.toposort()):
     # always exclude DEVICE/CONST/UNIQUE
-    if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE} and u is not x: excluded.add(u)
+    if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
     if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.index and u is not x: excluded.add(u)
     if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u)
   for u in toposort:
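In the viz server, LUNIQUE nodes are now folded away just like DEVICE/CONST/UNIQUE so rendered graphs stay readable. A standalone sketch of the same filtering idea over a toposorted UOp graph, using only the public UOp methods seen above; the real rendering lives in uop_to_json, and the example graph here is illustrative:

from tinygrad import Tensor
from tinygrad.uop.ops import Ops

x = (Tensor.empty(2, 2) + 1).uop
hidden = {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE}
shown = [u for u in x.toposort() if u.op not in hidden or u is x]
print([u.op for u in shown])  # identity-only sources are hidden unless they are the root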