diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 57094b030f..68c814de23 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -134,13 +134,7 @@ class Tensor(OpMixin): if isinstance(data, UOp): assert _dtype is None or _dtype==data.dtype or data.dtype==dtypes.weakint, f"dtype mismatch: {_dtype} vs {data.dtype}" # if data is dtype.weakint that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of - # TODO: remove this and stay in weakint - if data.dtype==dtypes.weakint: data = _index_to_concrete_int(data) - if data.op is Ops.BIND: - var, val = data.unbind() - # give the bound constant a device - const = UOp.const(var.dtype, val, _device, ()) - data = data.replace(src=(var.replace(src=const.src), const)) + if data.dtype == dtypes.weakint: data = Tensor.from_uop(data).uop elif data is None: data = UOp.const(_dtype or dtypes.default_float, 0, _device) elif isinstance(data, get_args(ConstType)): @@ -503,7 +497,13 @@ class Tensor(OpMixin): @staticmethod def from_uop(y:UOp, **kwargs) -> Tensor: - if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False) + # TODO: remove this and stay in weakint + if y.dtype == dtypes.weakint: y = _index_to_concrete_int(y) + if y.op is Ops.BIND: + var, val = y.unbind() + _device = canonicalize_device(kwargs.get("device")) + const = UOp.const(var.dtype, val, _device, ()) + return Tensor(y.replace(src=(var.replace(src=const.src), const)), **kwargs, requires_grad=False) if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False) if y.op is Ops.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1]) if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])