mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 23:08:06 -05:00
root cause fix for UOps.CONST bad args (#4638)
* delete that * real fix
This commit is contained in:
@@ -52,7 +52,7 @@ class Linearizer(Kernel):
|
||||
# NOTE: the consts have to be cached for deduping of downstream uops to work
|
||||
def const(self, b:ConstType, dtype:DType=dtypes.int32) -> UOp:
|
||||
if isinstance(b, Variable): return self.uops.add(UOps.DEFINE_VAR, dtype, tuple(), b.unbind()[0])
|
||||
else: return self.uops.add(UOps.CONST, dtype, tuple(), b)
|
||||
else: return self.uops.add(UOps.CONST, dtype, tuple(), dtypes.as_const(b, dtype))
|
||||
|
||||
def cast(self, val: UOp, dtype) -> UOp: return self.uops.add(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val
|
||||
|
||||
|
||||
@@ -360,9 +360,6 @@ class UOpGraph:
|
||||
if type_verify: self.type_verify()
|
||||
|
||||
def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None) -> UOp:
|
||||
if uop is UOps.CONST:
|
||||
assert dtype is not None
|
||||
arg = dtypes.as_const(arg, dtype) # TODO: this doesn't belong here
|
||||
if found:=self.nodes.get(key:=(uop, dtype, vin, arg)): return found
|
||||
self.nodes[key] = ret = UOp(*key)
|
||||
return ret
|
||||
|
||||
Reference in New Issue
Block a user