From a59439d0135e96b0f56bfa296b691bc0e9597e21 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 15 Oct 2025 10:01:34 +0800 Subject: [PATCH] use UOp.shape property instead of UOp.st (#12664) * work on shape property * reshape causing issues * more mops * all mops * need to cache it * _shape is like _device * mostly works * shape is good * const uses _shape * fix tests * size doesn't use st * close * test is broken * one less st * hack for 3 op assign * oops, i didn't mean to change that * support emulate in the NullDevice * reproed failure in emulation * fix wmma --- .github/workflows/test.yml | 2 + test/test_multitensor.py | 2 +- tinygrad/runtime/ops_null.py | 16 +++-- tinygrad/schedule/rangeify.py | 2 +- tinygrad/uop/ops.py | 116 ++++++++++++++++++++++++++++++---- 5 files changed, 119 insertions(+), 19 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 72d7f1a458..9ea57b843b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -272,6 +272,8 @@ jobs: # run: NULL=1 DEBUG=1 python3 examples/sdxl.py --seed 0 --noshow --timing --fakeweights - name: Run Clip tests for SD MLPerf on NULL backend run: NULL=1 python -m pytest -n=auto test/external/mlperf_stable_diffusion/external_test_models.py::TestOpenClip --durations=20 + - name: Run AMD emulated BERT training on NULL backend + run: EMULATE=AMD_RDNA4 NULL=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py # TODO: support fake weights #- name: Run LLaMA 7B on 4 fake devices # run: NULL=1 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 3 --temperature 0 --timing diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 2fa3a614b8..f987676dbc 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -658,7 +658,7 @@ class TestMultiTensor(unittest.TestCase): # it doesn't work like this anymore # NOTE: this never failed in assign_multi, it failed tensor spec because MULTI was never pushed in the graph - @unittest.expectedFailure + @unittest.skip("this test is broken") def test_mlb_assign_change_axis(self): t_none = Tensor.zeros((16, 16)).shard(devices_2).contiguous().realize() t_zero = Tensor.ones((16, 16)).shard(devices_2, axis=0) diff --git a/tinygrad/runtime/ops_null.py b/tinygrad/runtime/ops_null.py index c8f5a6b59f..7d64fee1c0 100644 --- a/tinygrad/runtime/ops_null.py +++ b/tinygrad/runtime/ops_null.py @@ -1,9 +1,11 @@ import functools +from typing import cast from tinygrad.device import Compiled, Compiler, Allocator from tinygrad.engine.jit import MultiGraphRunner -from tinygrad.renderer.cstyle import CStyleLanguage +from tinygrad.renderer.cstyle import Renderer, CStyleLanguage +from tinygrad.renderer.llvmir import AMDLLVMRenderer from tinygrad.uop.ops import Ops -from tinygrad.helpers import cpu_profile +from tinygrad.helpers import cpu_profile, EMULATE class NullRenderer(CStyleLanguage): device = "NULL" @@ -29,5 +31,11 @@ class NullGraph(MultiGraphRunner): def __call__(self, input_rawbuffers, var_vals, wait=False) -> float|None: return 1e-3 class NullDevice(Compiled): - def __init__(self, device:str): super().__init__(device, NullAllocator(self), [(NullRenderer, Compiler)], functools.partial(NullProgram, device), - NullGraph) + def __init__(self, device:str): + renderer:functools.partial|type[Renderer] + match cast(str, EMULATE.value): + case "AMD": renderer = functools.partial(AMDLLVMRenderer, "gfx1100") + case "AMD_RDNA4": renderer = functools.partial(AMDLLVMRenderer, "gfx1201") + case "": renderer = NullRenderer + case _: raise RuntimeError(f"can't EMULATE device: {EMULATE.value}") + super().__init__(device, NullAllocator(self), [(renderer, Compiler)], functools.partial(NullProgram, device), NullGraph) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 6e4ad755f7..56eb8d24a1 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -60,7 +60,7 @@ earliest_rewrites = PatternMatcher([ lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None), # handle size 0 - (UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x.st is not None and x.size == 0 else None), + (UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x._shape is not None and x.size == 0 else None), # remove contiguous on movement ops before a copy on disk (UPat(GroupOp.Movement-{Ops.SHRINK, Ops.RESHAPE}, name="x").f(Ops.CONTIGUOUS).f(Ops.COPY, allow_any_len=True, name="copy"), diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 01a83de380..517554c7bb 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -175,6 +175,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # *** uop shape stuff *** + # TODO: remove this. it's used by the jit and split_reduceop @recursive_property def st(self) -> ShapeTracker|None: if self.op is Ops.INDEX and self.src[0].op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.MSTACK, @@ -223,12 +224,98 @@ class UOp(MathTrait, metaclass=UOpMetaClass): shape = tuple(1 if i in axis_arg else s for i,s in enumerate(shape)) return ShapeTracker.from_shape(shape) + @recursive_property + def _shape(self) -> tuple[sint, ...]|None: + match self.op: + # late ops don't have shape + case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.INDEX | Ops.LOAD | Ops.IF | Ops.BARRIER | \ + Ops.VECTORIZE | Ops.VCONST | Ops.SUBSTITUTE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST: + return None + + # some ops init the shape + case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND: return () if self._device is not None else None + case Ops.BUFFER: return (self.arg,) + case Ops.BUFFER_VIEW: return (self.arg[0],) + case Ops.BUFFERIZE: return tuple([int(r.vmax+1) for r in self.src[1:]]) + case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,) + + # passthrough ops + case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.FUSE: return self.src[0]._shape + + # ops with custom handling + case Ops.KERNEL: return self.arg.ast._shape + case Ops.STORE: + if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size,) + if self.dtype is not dtypes.void: return self.src[0].src[0].shape + return None + + # TODO: disallow shape changing bitcast + case Ops.BITCAST: + ps = self.src[0]._shape + if ps is None: return None + if (output_sz:=self.dtype.itemsize) != (input_sz:=self.src[0].dtype.itemsize): return ps[:-1]+(ssimplify((ps[-1]*input_sz) // output_sz),) + return ps + + # TODO: disallow reshape from nothing. tested by TestOpenClip.test_multigpu_clip_score + case Ops.RESHAPE: + if self.src[0]._shape is None: return tuple(ssimplify(s) for s in self.arg) + + # movement ops change the shape. this is the logic from the old ShapeTracker + # NOTE: ssimplify is required because the shape needs to be canonical for broadcasting and same shape checking + if self.op in GroupOp.Movement.union({Ops.MULTI, Ops.REDUCE_AXIS, Ops.WMMA}): + ps = self.src[0]._shape + # TODO: WMMA is used for both axis WMMA and op WMMA. fix this and remove this hack. tested by BERT on AMD LLVM + if ps is None and self.op is Ops.WMMA: return None + if ps is None: raise RuntimeError(f"movement op {self.op} requires shape") + match self.op: + case Ops.RESHAPE: + if not all(x >= 0 for x in self.arg): raise ValueError(f"shape can't contain negative numbers {self.arg}") + if prod(ps) != prod(self.arg): raise ValueError(f"bad reshape: {ps} -> {self.arg}") + return tuple(ssimplify(s) for s in self.arg) + case Ops.EXPAND: + if len(ps) != len(self.arg) or not all(s==ns or (s==1 and ns>=0) for s,ns in zip(ps, self.arg)): + raise ValueError(f"bad expand: {ps} -> {self.arg}") + return tuple(ssimplify(s) for s in self.arg) + case Ops.PERMUTE: + if sorted(self.arg) != list(range(len(ps))): raise ValueError(f"invalid permutation {self.arg} of len {len(ps)}") + return tuple(ps[i] for i in self.arg) + case Ops.PAD: + # TODO: why do i need resolve here? + if len(ps) != len(self.arg) or not all(resolve(b>=0) and resolve(e>=0) for b,e in self.arg): raise ValueError(f"invalid pad {self.arg}") + return tuple(ssimplify(s+b+e) for s,(b,e) in zip(ps, self.arg)) + case Ops.SHRINK: + # TODO: why do i need resolve here? + if len(ps) != len(self.arg) or not all(resolve(0<=b) and resolve(b<=e) and resolve(e<=s) for s,(b,e) in zip(ps, self.arg)): + raise ValueError(f"invalid shrink {self.arg} for {ps}") + return tuple(ssimplify(e-s) for s,e in self.arg) + case Ops.FLIP: + if len(ps) != len(self.arg) or not all(isinstance(x, bool) for x in self.arg): raise ValueError(f"bad flip on {ps}, {self.arg}") + return ps + case Ops.MULTI: return tuple(s*len(self.device) if a == self.axis else s for a,s in enumerate(ps)) + case Ops.REDUCE_AXIS | Ops.WMMA: + axis_arg = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7] + if not isinstance(axis_arg, tuple) or not all(isinstance(x, int) and x>=0 and x tuple[sint, ...]: - assert self.st is not None, f"{self.op} doesn't have a shape" - return unwrap(self.st).shape + if (ret:=self._shape) is None: raise RuntimeError(f"shape requested, but {self.op} doesn't have a shape") + return ret + @property - def size(self) -> int: return self.arg[0] if self.op is Ops.BUFFER_VIEW else self.arg if self.op is Ops.BUFFER else unwrap(self.st).size + def size(self) -> int: return prod([int(x.vmax) if isinstance(x, UOp) else x for x in self.shape]) # determine what ranges this is in @recursive_property @@ -290,7 +377,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def __getitem__(self, idx): return self.index(idx) def const_like(self, b:ConstLike): # constants can optionally have a DEVICE source - return UOp.const(self.dtype, b, device=self._device, shape=self.shape if self.st is not None else None) + return UOp.const(self.dtype, b, device=self._device, shape=self._shape) def broadcast(self, count:int): assert self.dtype.count == 1 if count == 1: return self @@ -428,19 +515,22 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.op is Ops.MULTI: return self.src[0].base # MULTI is really a VIEW return self - def _mop(self, op:Ops, arg) -> UOp: + def _mop(self, op:Ops, arg, no_reshape_is_no_op:bool=False) -> UOp: ret = UOp(op, self.dtype, (self,), arg) - if self.st == ret.st: return self # ignore NOOPs, also check ret.st + # for all movement ops, we check shape property + if ret.shape == self.shape and no_reshape_is_no_op: return self return ret - def forced_reshape(self, arg:tuple[sint, ...], **kwargs): return UOp(Ops.RESHAPE, kwargs.pop("dtype", self.dtype), src=(self,), arg=arg) - def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg) - def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg) - def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg) - def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg) - def permute(self, arg:tuple[int, ...]): return self._mop(Ops.PERMUTE, arg) - def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg) + # in these four, if the shape doesn't change we can return self + def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, no_reshape_is_no_op=True) + def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, no_reshape_is_no_op=True) + def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, no_reshape_is_no_op=True) + def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg, no_reshape_is_no_op=True) + + # in these two, we have custom logic to check if they are a no-op + def permute(self, arg:tuple[int, ...]): return self._mop(Ops.PERMUTE, arg) if arg != tuple(range(len(self.shape))) else self + def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg) if any(arg) and len(arg) == len(self.shape) else self # *** uop UNIQUE ***