root cause fix for UOps.CONST bad args (#4638)

* delete that

* real fix
This commit is contained in:
qazal
2024-05-18 14:15:25 +08:00
committed by GitHub
parent 9b464e34ea
commit d0a2d40df3
2 changed files with 1 additions and 4 deletions

View File

@@ -52,7 +52,7 @@ class Linearizer(Kernel):
# NOTE: the consts have to be cached for deduping of downstream uops to work
def const(self, b:ConstType, dtype:DType=dtypes.int32) -> UOp:
if isinstance(b, Variable): return self.uops.add(UOps.DEFINE_VAR, dtype, tuple(), b.unbind()[0])
else: return self.uops.add(UOps.CONST, dtype, tuple(), b)
else: return self.uops.add(UOps.CONST, dtype, tuple(), dtypes.as_const(b, dtype))
def cast(self, val: UOp, dtype) -> UOp: return self.uops.add(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val

View File

@@ -360,9 +360,6 @@ class UOpGraph:
if type_verify: self.type_verify()
def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None) -> UOp:
if uop is UOps.CONST:
assert dtype is not None
arg = dtypes.as_const(arg, dtype) # TODO: this doesn't belong here
if found:=self.nodes.get(key:=(uop, dtype, vin, arg)): return found
self.nodes[key] = ret = UOp(*key)
return ret