mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
push Tensor(symbolic) logic to Tensor.from_uop (#15420)
This commit is contained in:
@@ -134,13 +134,7 @@ class Tensor(OpMixin):
|
||||
if isinstance(data, UOp):
|
||||
assert _dtype is None or _dtype==data.dtype or data.dtype==dtypes.weakint, f"dtype mismatch: {_dtype} vs {data.dtype}"
|
||||
# if data is dtype.weakint that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
|
||||
# TODO: remove this and stay in weakint
|
||||
if data.dtype==dtypes.weakint: data = _index_to_concrete_int(data)
|
||||
if data.op is Ops.BIND:
|
||||
var, val = data.unbind()
|
||||
# give the bound constant a device
|
||||
const = UOp.const(var.dtype, val, _device, ())
|
||||
data = data.replace(src=(var.replace(src=const.src), const))
|
||||
if data.dtype == dtypes.weakint: data = Tensor.from_uop(data).uop
|
||||
elif data is None:
|
||||
data = UOp.const(_dtype or dtypes.default_float, 0, _device)
|
||||
elif isinstance(data, get_args(ConstType)):
|
||||
@@ -503,7 +497,13 @@ class Tensor(OpMixin):
|
||||
|
||||
@staticmethod
|
||||
def from_uop(y:UOp, **kwargs) -> Tensor:
|
||||
if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False)
|
||||
# TODO: remove this and stay in weakint
|
||||
if y.dtype == dtypes.weakint: y = _index_to_concrete_int(y)
|
||||
if y.op is Ops.BIND:
|
||||
var, val = y.unbind()
|
||||
_device = canonicalize_device(kwargs.get("device"))
|
||||
const = UOp.const(var.dtype, val, _device, ())
|
||||
return Tensor(y.replace(src=(var.replace(src=const.src), const)), **kwargs, requires_grad=False)
|
||||
if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
|
||||
if y.op is Ops.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
|
||||
if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
|
||||
|
||||
Reference in New Issue
Block a user