mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
add unique to const, fix longstanding bug (#12965)
* add unique to const, fix longstanding bug * _force_unique=True * fix tests * fix more tests
This commit is contained in:
@@ -112,7 +112,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 102)
|
||||
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 103)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"CPU", "CL"}, "slow")
|
||||
def test_forward_cifar(self):
|
||||
@@ -176,7 +176,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
for v in data.values(): v.to_(Device.DEFAULT)
|
||||
|
||||
helper_test("train_bert", lambda: (data["input_ids"], data["segment_ids"], data["input_mask"], data["masked_lm_positions"], \
|
||||
data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.31, 358)
|
||||
data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.31, 427)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -370,6 +370,7 @@ class TestSchedule(unittest.TestCase):
|
||||
|
||||
# NOTE: this is causing "LAZYCACHE=1 incorrectly reuses contiguous const" #4562
|
||||
# should contiguous dedup?
|
||||
@unittest.skip("we do the exact opposite now")
|
||||
def test_dedup_contiguous(self):
|
||||
a = Tensor.ones(4).contiguous()
|
||||
b = Tensor.ones(4).contiguous()
|
||||
@@ -446,7 +447,7 @@ class TestSchedule(unittest.TestCase):
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
|
||||
def test_fold_conv_batchnorm_optim(self):
|
||||
# this is too high
|
||||
for optim, cnt in [(nn.optim.Adam, 21), (nn.optim.SGD, 8)]:
|
||||
for optim, cnt in [(nn.optim.Adam, 28), (nn.optim.SGD, 8)]:
|
||||
with self.subTest(optim=optim.__name__):
|
||||
with Tensor.train():
|
||||
img = Tensor.ones(1,3,4,4)
|
||||
@@ -1220,7 +1221,7 @@ class TestSchedule(unittest.TestCase):
|
||||
_realize_weights(layer)
|
||||
opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
|
||||
layer(x).relu().sum().backward()
|
||||
check_schedule(opt.schedule_step(), 16)
|
||||
check_schedule(opt.schedule_step(), 19)
|
||||
|
||||
def test_adam_conv_fuse(self):
|
||||
with Tensor.train():
|
||||
@@ -1230,7 +1231,7 @@ class TestSchedule(unittest.TestCase):
|
||||
opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
|
||||
opt.zero_grad()
|
||||
c1(img).relu().sum().backward()
|
||||
check_schedule(opt.schedule_step(), 16)
|
||||
check_schedule(opt.schedule_step(), 19)
|
||||
|
||||
def test_adam_2convs_fuse(self):
|
||||
with Tensor.train():
|
||||
@@ -1241,7 +1242,7 @@ class TestSchedule(unittest.TestCase):
|
||||
opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
|
||||
opt.zero_grad()
|
||||
c2(c1(img).relu()).relu().sum().backward()
|
||||
check_schedule(opt.schedule_step(), 18)
|
||||
check_schedule(opt.schedule_step(), 21)
|
||||
|
||||
def test_sgd_conv_fuse(self):
|
||||
with Tensor.train():
|
||||
|
||||
@@ -919,5 +919,38 @@ class TestIdxUpcast(unittest.TestCase):
|
||||
a = Tensor.empty(2**11, 2**11, 1, dtype=dtypes.int8).permute((2, 0, 1)).expand((2**9+10, -1, -1)).contiguous()
|
||||
a.realize()
|
||||
|
||||
class TestTensorUnique(unittest.TestCase):
|
||||
def test_empty_bufs_unique(self):
|
||||
a = Tensor.empty(10, 10).contiguous()
|
||||
b = Tensor.empty(10, 10).contiguous()
|
||||
Tensor.realize(a,b)
|
||||
self.assertIsNot(a.uop.buffer, b.uop.buffer)
|
||||
|
||||
def test_zeros_bufs_unique_sep(self):
|
||||
a = Tensor.zeros(10, 10).contiguous()
|
||||
Tensor.realize(a)
|
||||
b = Tensor.zeros(10, 10).contiguous()
|
||||
Tensor.realize(b)
|
||||
self.assertIsNot(a.uop.buffer, b.uop.buffer)
|
||||
|
||||
def test_zeros_bufs_unique(self):
|
||||
a = Tensor.zeros(10, 10).contiguous()
|
||||
b = Tensor.zeros(10, 10).contiguous()
|
||||
Tensor.realize(a,b)
|
||||
self.assertIsNot(a.uop.buffer, b.uop.buffer)
|
||||
|
||||
def test_eye_bufs_unique(self):
|
||||
a = Tensor.eye(10).contiguous()
|
||||
b = Tensor.eye(10).contiguous()
|
||||
Tensor.realize(a,b)
|
||||
self.assertIsNot(a.uop.buffer, b.uop.buffer)
|
||||
|
||||
def test_times_2_not_unique(self):
|
||||
a = Tensor.zeros(10, 10).contiguous()
|
||||
b = a * 2
|
||||
c = a * 2
|
||||
Tensor.realize(b,c)
|
||||
self.assertIs(b.uop.buffer, c.uop.buffer)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -115,7 +115,7 @@ class Tensor(MathTrait):
|
||||
training: ClassVar[bool] = False
|
||||
|
||||
def __init__(self, data:ConstType|bytes|list|tuple|UOp|'np.ndarray'|pathlib.Path|None, # type: ignore [name-defined] # noqa: F821
|
||||
device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None):
|
||||
device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None, _force_unique:bool=False):
|
||||
if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
|
||||
_dtype:DType|None = to_dtype(dtype) if dtype is not None else None
|
||||
_device:str|tuple[str, ...] = tuple(canonicalize_device(x) for x in device) if isinstance(device, (tuple, list)) else canonicalize_device(device)
|
||||
@@ -138,8 +138,8 @@ class Tensor(MathTrait):
|
||||
# give the bound constant a device
|
||||
const = UOp.const(var.dtype, val, _device, ())
|
||||
data = data.replace(src=(var.replace(src=const.src), const)) # type: ignore
|
||||
elif data is None: data = UOp.const(_dtype or dtypes.default_float, 0, _device, ())
|
||||
elif isinstance(data, get_args(ConstType)): data = UOp.const(_dtype or dtypes.from_py(data), data, _device, ())
|
||||
elif data is None: data = UOp.const(_dtype or dtypes.default_float, 0, _device, (), unique=_force_unique)
|
||||
elif isinstance(data, get_args(ConstType)): data = UOp.const(_dtype or dtypes.from_py(data), data, _device, (), unique=_force_unique)
|
||||
elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if _dtype is None else _dtype)
|
||||
elif isinstance(data, (list, tuple)):
|
||||
if _dtype is None:
|
||||
@@ -150,7 +150,7 @@ class Tensor(MathTrait):
|
||||
elif is_numpy_ndarray(data):
|
||||
import numpy as np
|
||||
assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}"
|
||||
if data.shape == (): data = UOp.const(_dtype or _from_np_dtype(data.dtype), data.item(), _device, ())
|
||||
if data.shape == (): data = UOp.const(_dtype or _from_np_dtype(data.dtype), data.item(), _device, (), unique=_force_unique)
|
||||
else: data = _fromnp(data.astype(npdtype) if _dtype is not None and (npdtype:=_to_np_dtype(_dtype)) is not None else data) # type: ignore [name-defined]
|
||||
elif isinstance(data, pathlib.Path):
|
||||
_dtype = _dtype or dtypes.uint8
|
||||
@@ -625,7 +625,7 @@ class Tensor(MathTrait):
|
||||
print(Tensor.full((2, 3), False).numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)
|
||||
return Tensor(fill_value, _force_unique=True, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)
|
||||
|
||||
@staticmethod
|
||||
def zeros(*shape, **kwargs) -> Tensor:
|
||||
|
||||
@@ -371,13 +371,16 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if op in {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
|
||||
return UOp(op, out_dtype, (self,)+src, **kwargs)
|
||||
@staticmethod
|
||||
def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None, src=None):
|
||||
def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None, src=None, unique:bool|int=False):
|
||||
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
|
||||
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
|
||||
# NOTE: float('nan') != float('nan'), so we canonicalize here
|
||||
if isinstance(b, float) and math.isnan(b): b = math.nan
|
||||
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype), src=() if src is None else (src,))
|
||||
if device is not None: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
|
||||
if device is not None:
|
||||
if unique or not isinstance(unique, bool): ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device), UOp.unique(None if unique is True else unique)))
|
||||
else: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
|
||||
elif unique or not isinstance(unique, bool): raise RuntimeError("unique consts only with DEVICE")
|
||||
if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
|
||||
return ret
|
||||
@staticmethod
|
||||
@@ -1252,6 +1255,8 @@ def render_marg(ctx,x:UOp):
|
||||
sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER,
|
||||
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER}
|
||||
pm_pyrender_extra = PatternMatcher([
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"), UPat(Ops.UNIQUE, name="u")), name="x"),
|
||||
lambda x,d,u: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)}, unique={u.arg})"),
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"),), name="x"), lambda x,d: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)})"),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
|
||||
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x:
|
||||
|
||||
@@ -77,7 +77,9 @@ tensor_spec = PatternMatcher([
|
||||
# Tensor variable bindings
|
||||
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True),
|
||||
|
||||
# device or unique
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True),
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE))), lambda: True),
|
||||
|
||||
# DETACH and CONTIGUOUS change how we interpret the source UOp
|
||||
# CONTIGUOUS ensures the source UOp realizes
|
||||
|
||||
Reference in New Issue
Block a user