From b0da173f2f9753198a1474024fab5b0225ff8dff Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 28 Oct 2025 15:11:37 +0800 Subject: [PATCH] add unique to const, fix longstanding bug (#12965) * add unique to const, fix longstanding bug * _force_unique=True * fix tests * fix more tests --- test/models/test_real_world.py | 4 ++-- test/test_schedule.py | 9 +++++---- test/test_tensor.py | 33 +++++++++++++++++++++++++++++++++ tinygrad/tensor.py | 10 +++++----- tinygrad/uop/ops.py | 9 +++++++-- tinygrad/uop/spec.py | 2 ++ 6 files changed, 54 insertions(+), 13 deletions(-) diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py index c96a1d7846..0d8389b76c 100644 --- a/test/models/test_real_world.py +++ b/test/models/test_real_world.py @@ -112,7 +112,7 @@ class TestRealWorld(unittest.TestCase): loss.backward() optimizer.step() - helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 102) + helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 103) @unittest.skipIf(CI and Device.DEFAULT in {"CPU", "CL"}, "slow") def test_forward_cifar(self): @@ -176,7 +176,7 @@ class TestRealWorld(unittest.TestCase): for v in data.values(): v.to_(Device.DEFAULT) helper_test("train_bert", lambda: (data["input_ids"], data["segment_ids"], data["input_mask"], data["masked_lm_positions"], \ - data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.31, 358) + data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.31, 427) if __name__ == '__main__': unittest.main() diff --git a/test/test_schedule.py b/test/test_schedule.py index 6b090d911e..c49459e6dc 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -370,6 +370,7 @@ class TestSchedule(unittest.TestCase): # NOTE: this is causing "LAZYCACHE=1 incorrectly reuses contiguous const" #4562 # should contiguous dedup? + @unittest.skip("we do the exact opposite now") def test_dedup_contiguous(self): a = Tensor.ones(4).contiguous() b = Tensor.ones(4).contiguous() @@ -446,7 +447,7 @@ class TestSchedule(unittest.TestCase): @unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong") def test_fold_conv_batchnorm_optim(self): # this is too high - for optim, cnt in [(nn.optim.Adam, 21), (nn.optim.SGD, 8)]: + for optim, cnt in [(nn.optim.Adam, 28), (nn.optim.SGD, 8)]: with self.subTest(optim=optim.__name__): with Tensor.train(): img = Tensor.ones(1,3,4,4) @@ -1220,7 +1221,7 @@ class TestSchedule(unittest.TestCase): _realize_weights(layer) opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4) layer(x).relu().sum().backward() - check_schedule(opt.schedule_step(), 16) + check_schedule(opt.schedule_step(), 19) def test_adam_conv_fuse(self): with Tensor.train(): @@ -1230,7 +1231,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4) opt.zero_grad() c1(img).relu().sum().backward() - check_schedule(opt.schedule_step(), 16) + check_schedule(opt.schedule_step(), 19) def test_adam_2convs_fuse(self): with Tensor.train(): @@ -1241,7 +1242,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4) opt.zero_grad() c2(c1(img).relu()).relu().sum().backward() - check_schedule(opt.schedule_step(), 18) + check_schedule(opt.schedule_step(), 21) def test_sgd_conv_fuse(self): with Tensor.train(): diff --git a/test/test_tensor.py b/test/test_tensor.py index 468633e560..5dce85e0e2 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -919,5 +919,38 @@ class TestIdxUpcast(unittest.TestCase): a = Tensor.empty(2**11, 2**11, 1, dtype=dtypes.int8).permute((2, 0, 1)).expand((2**9+10, -1, -1)).contiguous() a.realize() +class TestTensorUnique(unittest.TestCase): + def test_empty_bufs_unique(self): + a = Tensor.empty(10, 10).contiguous() + b = Tensor.empty(10, 10).contiguous() + Tensor.realize(a,b) + self.assertIsNot(a.uop.buffer, b.uop.buffer) + + def test_zeros_bufs_unique_sep(self): + a = Tensor.zeros(10, 10).contiguous() + Tensor.realize(a) + b = Tensor.zeros(10, 10).contiguous() + Tensor.realize(b) + self.assertIsNot(a.uop.buffer, b.uop.buffer) + + def test_zeros_bufs_unique(self): + a = Tensor.zeros(10, 10).contiguous() + b = Tensor.zeros(10, 10).contiguous() + Tensor.realize(a,b) + self.assertIsNot(a.uop.buffer, b.uop.buffer) + + def test_eye_bufs_unique(self): + a = Tensor.eye(10).contiguous() + b = Tensor.eye(10).contiguous() + Tensor.realize(a,b) + self.assertIsNot(a.uop.buffer, b.uop.buffer) + + def test_times_2_not_unique(self): + a = Tensor.zeros(10, 10).contiguous() + b = a * 2 + c = a * 2 + Tensor.realize(b,c) + self.assertIs(b.uop.buffer, c.uop.buffer) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 06b0dc3feb..6569b821dc 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -115,7 +115,7 @@ class Tensor(MathTrait): training: ClassVar[bool] = False def __init__(self, data:ConstType|bytes|list|tuple|UOp|'np.ndarray'|pathlib.Path|None, # type: ignore [name-defined] # noqa: F821 - device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None): + device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None, _force_unique:bool=False): if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None _dtype:DType|None = to_dtype(dtype) if dtype is not None else None _device:str|tuple[str, ...] = tuple(canonicalize_device(x) for x in device) if isinstance(device, (tuple, list)) else canonicalize_device(device) @@ -138,8 +138,8 @@ class Tensor(MathTrait): # give the bound constant a device const = UOp.const(var.dtype, val, _device, ()) data = data.replace(src=(var.replace(src=const.src), const)) # type: ignore - elif data is None: data = UOp.const(_dtype or dtypes.default_float, 0, _device, ()) - elif isinstance(data, get_args(ConstType)): data = UOp.const(_dtype or dtypes.from_py(data), data, _device, ()) + elif data is None: data = UOp.const(_dtype or dtypes.default_float, 0, _device, (), unique=_force_unique) + elif isinstance(data, get_args(ConstType)): data = UOp.const(_dtype or dtypes.from_py(data), data, _device, (), unique=_force_unique) elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if _dtype is None else _dtype) elif isinstance(data, (list, tuple)): if _dtype is None: @@ -150,7 +150,7 @@ class Tensor(MathTrait): elif is_numpy_ndarray(data): import numpy as np assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}" - if data.shape == (): data = UOp.const(_dtype or _from_np_dtype(data.dtype), data.item(), _device, ()) + if data.shape == (): data = UOp.const(_dtype or _from_np_dtype(data.dtype), data.item(), _device, (), unique=_force_unique) else: data = _fromnp(data.astype(npdtype) if _dtype is not None and (npdtype:=_to_np_dtype(_dtype)) is not None else data) # type: ignore [name-defined] elif isinstance(data, pathlib.Path): _dtype = _dtype or dtypes.uint8 @@ -625,7 +625,7 @@ class Tensor(MathTrait): print(Tensor.full((2, 3), False).numpy()) ``` """ - return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape) + return Tensor(fill_value, _force_unique=True, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape) @staticmethod def zeros(*shape, **kwargs) -> Tensor: diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index a90069452b..2edb54b5e6 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -371,13 +371,16 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if op in {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool return UOp(op, out_dtype, (self,)+src, **kwargs) @staticmethod - def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None, src=None): + def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None, src=None, unique:bool|int=False): if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same # NOTE: float('nan') != float('nan'), so we canonicalize here if isinstance(b, float) and math.isnan(b): b = math.nan ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype), src=() if src is None else (src,)) - if device is not None: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),)) + if device is not None: + if unique or not isinstance(unique, bool): ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device), UOp.unique(None if unique is True else unique))) + else: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),)) + elif unique or not isinstance(unique, bool): raise RuntimeError("unique consts only with DEVICE") if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape) return ret @staticmethod @@ -1252,6 +1255,8 @@ def render_marg(ctx,x:UOp): sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER, Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER} pm_pyrender_extra = PatternMatcher([ + (UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"), UPat(Ops.UNIQUE, name="u")), name="x"), + lambda x,d,u: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)}, unique={u.arg})"), (UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"),), name="x"), lambda x,d: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)})"), (UPat(Ops.CONST, name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"), (UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index ced3cea7ee..0bb0e66796 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -77,7 +77,9 @@ tensor_spec = PatternMatcher([ # Tensor variable bindings (UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True), + # device or unique (UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True), + (UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE))), lambda: True), # DETACH and CONTIGUOUS change how we interpret the source UOp # CONTIGUOUS ensures the source UOp realizes