add unique to const, fix longstanding bug (#12965)

* add unique to const, fix longstanding bug

* _force_unique=True

* fix tests

* fix more tests
George Hotz
2025-10-28 15:11:37 +08:00
committed by GitHub
parent e110f4632a
commit b0da173f2f
6 changed files with 54 additions and 13 deletions
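For context: the longstanding bug was that identical const-backed tensors were hash-consed into the same UOp, so two independently created tensors could realize into one shared buffer (see also the LAZYCACHE note in #4562 below). A minimal reproduction, mirroring the new tests added in this commit:

```python
from tinygrad import Tensor

a = Tensor.zeros(10, 10).contiguous()
b = Tensor.zeros(10, 10).contiguous()
Tensor.realize(a, b)
# before this commit both tensors could end up backed by the same buffer,
# silently aliasing each other
assert a.uop.buffer is not b.uop.buffer  # holds after this fix
```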

View File

@@ -112,7 +112,7 @@ class TestRealWorld(unittest.TestCase):
loss.backward()
optimizer.step()
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 102)
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 103)
@unittest.skipIf(CI and Device.DEFAULT in {"CPU", "CL"}, "slow")
def test_forward_cifar(self):
@@ -176,7 +176,7 @@ class TestRealWorld(unittest.TestCase):
for v in data.values(): v.to_(Device.DEFAULT)
helper_test("train_bert", lambda: (data["input_ids"], data["segment_ids"], data["input_mask"], data["masked_lm_positions"], \
data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.31, 358)
data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.31, 427)
if __name__ == '__main__':
unittest.main()

View File

@@ -370,6 +370,7 @@ class TestSchedule(unittest.TestCase):
# NOTE: this is causing "LAZYCACHE=1 incorrectly reuses contiguous const" #4562
# should contiguous dedup?
@unittest.skip("we do the exact opposite now")
def test_dedup_contiguous(self):
a = Tensor.ones(4).contiguous()
b = Tensor.ones(4).contiguous()
@@ -446,7 +447,7 @@ class TestSchedule(unittest.TestCase):
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
def test_fold_conv_batchnorm_optim(self):
# this is too high
-for optim, cnt in [(nn.optim.Adam, 21), (nn.optim.SGD, 8)]:
+for optim, cnt in [(nn.optim.Adam, 28), (nn.optim.SGD, 8)]:
with self.subTest(optim=optim.__name__):
with Tensor.train():
img = Tensor.ones(1,3,4,4)
@@ -1220,7 +1221,7 @@ class TestSchedule(unittest.TestCase):
_realize_weights(layer)
opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
layer(x).relu().sum().backward()
-check_schedule(opt.schedule_step(), 16)
+check_schedule(opt.schedule_step(), 19)
def test_adam_conv_fuse(self):
with Tensor.train():
@@ -1230,7 +1231,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
opt.zero_grad()
c1(img).relu().sum().backward()
-check_schedule(opt.schedule_step(), 16)
+check_schedule(opt.schedule_step(), 19)
def test_adam_2convs_fuse(self):
with Tensor.train():
@@ -1241,7 +1242,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
opt.zero_grad()
c2(c1(img).relu()).relu().sum().backward()
-check_schedule(opt.schedule_step(), 18)
+check_schedule(opt.schedule_step(), 21)
def test_sgd_conv_fuse(self):
with Tensor.train():

View File

@@ -919,5 +919,38 @@ class TestIdxUpcast(unittest.TestCase):
a = Tensor.empty(2**11, 2**11, 1, dtype=dtypes.int8).permute((2, 0, 1)).expand((2**9+10, -1, -1)).contiguous()
a.realize()
+class TestTensorUnique(unittest.TestCase):
+  def test_empty_bufs_unique(self):
+    a = Tensor.empty(10, 10).contiguous()
+    b = Tensor.empty(10, 10).contiguous()
+    Tensor.realize(a,b)
+    self.assertIsNot(a.uop.buffer, b.uop.buffer)
+  def test_zeros_bufs_unique_sep(self):
+    a = Tensor.zeros(10, 10).contiguous()
+    Tensor.realize(a)
+    b = Tensor.zeros(10, 10).contiguous()
+    Tensor.realize(b)
+    self.assertIsNot(a.uop.buffer, b.uop.buffer)
+  def test_zeros_bufs_unique(self):
+    a = Tensor.zeros(10, 10).contiguous()
+    b = Tensor.zeros(10, 10).contiguous()
+    Tensor.realize(a,b)
+    self.assertIsNot(a.uop.buffer, b.uop.buffer)
+  def test_eye_bufs_unique(self):
+    a = Tensor.eye(10).contiguous()
+    b = Tensor.eye(10).contiguous()
+    Tensor.realize(a,b)
+    self.assertIsNot(a.uop.buffer, b.uop.buffer)
+  def test_times_2_not_unique(self):
+    a = Tensor.zeros(10, 10).contiguous()
+    b = a * 2
+    c = a * 2
+    Tensor.realize(b,c)
+    self.assertIs(b.uop.buffer, c.uop.buffer)
if __name__ == '__main__':
unittest.main()

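Note the boundary the last test pins down: uniqueness is only forced on the base consts built by the tensor constructors, while derived expressions such as `a * 2` over the same base still dedup to a single buffer, since caching identical computations remains desirable.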
View File

@@ -115,7 +115,7 @@ class Tensor(MathTrait):
training: ClassVar[bool] = False
def __init__(self, data:ConstType|bytes|list|tuple|UOp|'np.ndarray'|pathlib.Path|None, # type: ignore [name-defined] # noqa: F821
-device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None):
+device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None, _force_unique:bool=False):
if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
_dtype:DType|None = to_dtype(dtype) if dtype is not None else None
_device:str|tuple[str, ...] = tuple(canonicalize_device(x) for x in device) if isinstance(device, (tuple, list)) else canonicalize_device(device)
@@ -138,8 +138,8 @@ class Tensor(MathTrait):
# give the bound constant a device
const = UOp.const(var.dtype, val, _device, ())
data = data.replace(src=(var.replace(src=const.src), const)) # type: ignore
-elif data is None: data = UOp.const(_dtype or dtypes.default_float, 0, _device, ())
-elif isinstance(data, get_args(ConstType)): data = UOp.const(_dtype or dtypes.from_py(data), data, _device, ())
+elif data is None: data = UOp.const(_dtype or dtypes.default_float, 0, _device, (), unique=_force_unique)
+elif isinstance(data, get_args(ConstType)): data = UOp.const(_dtype or dtypes.from_py(data), data, _device, (), unique=_force_unique)
elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if _dtype is None else _dtype)
elif isinstance(data, (list, tuple)):
if _dtype is None:
@@ -150,7 +150,7 @@ class Tensor(MathTrait):
elif is_numpy_ndarray(data):
import numpy as np
assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}"
-if data.shape == (): data = UOp.const(_dtype or _from_np_dtype(data.dtype), data.item(), _device, ())
+if data.shape == (): data = UOp.const(_dtype or _from_np_dtype(data.dtype), data.item(), _device, (), unique=_force_unique)
else: data = _fromnp(data.astype(npdtype) if _dtype is not None and (npdtype:=_to_np_dtype(_dtype)) is not None else data) # type: ignore [name-defined]
elif isinstance(data, pathlib.Path):
_dtype = _dtype or dtypes.uint8
@@ -625,7 +625,7 @@ class Tensor(MathTrait):
print(Tensor.full((2, 3), False).numpy())
```
"""
-return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)
+return Tensor(fill_value, _force_unique=True, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)
@staticmethod
def zeros(*shape, **kwargs) -> Tensor:

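Tensor.full is the choke point on the tensor side (zeros, ones, and friends build on it, broadcasting a scalar fill value), so passing `_force_unique=True` here gives every such tensor its own CONST. A sketch of the resulting graph, with the `Ops` import path and the exact src layout being assumptions:

```python
from tinygrad import Tensor
from tinygrad.uop.ops import Ops  # import path is an assumption

t = Tensor.full((2, 3), 1.0)
base = t.uop.base  # the CONST underneath the reshape/expand views
print(base.op)                   # Ops.CONST
print([s.op for s in base.src])  # expected: [Ops.DEVICE, Ops.UNIQUE] after this change
```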
View File

@@ -371,13 +371,16 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if op in {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
return UOp(op, out_dtype, (self,)+src, **kwargs)
@staticmethod
-def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None, src=None):
+def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None, src=None, unique:bool|int=False):
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
# NOTE: float('nan') != float('nan'), so we canonicalize here
if isinstance(b, float) and math.isnan(b): b = math.nan
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype), src=() if src is None else (src,))
-if device is not None: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
+if device is not None:
+  if unique or not isinstance(unique, bool): ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device), UOp.unique(None if unique is True else unique)))
+  else: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
+elif unique or not isinstance(unique, bool): raise RuntimeError("unique consts only with DEVICE")
if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
return ret
@staticmethod
@@ -1252,6 +1255,8 @@ def render_marg(ctx,x:UOp):
sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER,
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER}
pm_pyrender_extra = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"), UPat(Ops.UNIQUE, name="u")), name="x"),
lambda x,d,u: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)}, unique={u.arg})"),
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"),), name="x"), lambda x,d: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)})"),
(UPat(Ops.CONST, name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x:

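To make the new `unique:bool|int=False` parameter concrete, a hedged sketch of its three cases (import paths and the "CPU" device name are assumptions):

```python
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp  # import path is an assumption

c1 = UOp.const(dtypes.float, 1.0, device="CPU", unique=True)
c2 = UOp.const(dtypes.float, 1.0, device="CPU", unique=True)
assert c1 is not c2  # unique=True mints a fresh UNIQUE src per call, defeating hash-consing

c3 = UOp.const(dtypes.float, 1.0, device="CPU", unique=c1.src[1].arg)
assert c3 is c1      # the int case reuses an existing id and hash-conses back to the same UOp

# per the RuntimeError in the diff above, unique without a device is rejected:
#   UOp.const(dtypes.float, 1.0, unique=True)  # raises RuntimeError
```

The int case is what the new pm_pyrender_extra pattern relies on: a rendered `UOp.const(..., unique=<id>)` re-creates the exact same unique const instead of minting a new one.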
View File

@@ -77,7 +77,9 @@ tensor_spec = PatternMatcher([
# Tensor variable bindings
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True),
# device or unique
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True),
+(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE))), lambda: True),
# DETACH and CONTIGUOUS change how we interpret the source UOp
# CONTIGUOUS ensures the source UOp realizes
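Correspondingly, the tensor spec now accepts a CONST whose sources are (DEVICE, UNIQUE) alongside the plain (DEVICE,) form. A quick check of the new pattern, with import paths assumed:

```python
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp, UPat, Ops  # import paths are assumptions

pat = UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)))
c = UOp.const(dtypes.float, 0.0, device="CPU", unique=True)
assert len(pat.match(c, {})) > 0  # the unique const form passes the spec
```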