push Tensor(symbolic) logic to Tensor.from_uop (#15420)

This commit is contained in:
chenyu
2026-03-22 23:49:35 -04:00
committed by GitHub
parent 2087df814f
commit 67dcc79fdd

View File

@@ -134,13 +134,7 @@ class Tensor(OpMixin):
if isinstance(data, UOp):
assert _dtype is None or _dtype==data.dtype or data.dtype==dtypes.weakint, f"dtype mismatch: {_dtype} vs {data.dtype}"
# if data is dtype.weakint that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
# TODO: remove this and stay in weakint
if data.dtype==dtypes.weakint: data = _index_to_concrete_int(data)
if data.op is Ops.BIND:
var, val = data.unbind()
# give the bound constant a device
const = UOp.const(var.dtype, val, _device, ())
data = data.replace(src=(var.replace(src=const.src), const))
if data.dtype == dtypes.weakint: data = Tensor.from_uop(data).uop
elif data is None:
data = UOp.const(_dtype or dtypes.default_float, 0, _device)
elif isinstance(data, get_args(ConstType)):
@@ -503,7 +497,13 @@ class Tensor(OpMixin):
@staticmethod
def from_uop(y:UOp, **kwargs) -> Tensor:
if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False)
# TODO: remove this and stay in weakint
if y.dtype == dtypes.weakint: y = _index_to_concrete_int(y)
if y.op is Ops.BIND:
var, val = y.unbind()
_device = canonicalize_device(kwargs.get("device"))
const = UOp.const(var.dtype, val, _device, ())
return Tensor(y.replace(src=(var.replace(src=const.src), const)), **kwargs, requires_grad=False)
if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
if y.op is Ops.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])