mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
allow keyword args in UOp.store [run_process_replay] (#5008)
* allow keyword args in UOp.store [run_process_replay] * same for load * typing can stay
This commit is contained in:
@@ -67,9 +67,9 @@ class UOp:
|
||||
@staticmethod
|
||||
def alu(arg, *vin:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else vin[-1].dtype, vin, arg)
|
||||
@staticmethod
|
||||
def load(*vin: UOp, dtype:Optional[DType]=None): return UOp(UOps.LOAD, dtype, tuple(vin))
|
||||
def load(*vin:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.LOAD, dtype, tuple(vin)+tuple(kwargs.values()))
|
||||
@staticmethod
|
||||
def store(*vin: UOp, dtype:Optional[DType]=None): return UOp(UOps.STORE, dtype, tuple(vin))
|
||||
def store(*vin:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.STORE, dtype, tuple(vin)+tuple(kwargs.values()))
|
||||
@staticmethod
|
||||
def var(name: Optional[str]=None, dtype: Optional[DType]=None): return UOp(UOps.VAR, dtype=dtype, arg=name)
|
||||
@staticmethod
|
||||
@@ -249,10 +249,8 @@ constant_folder = PatternMatcher([
|
||||
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
|
||||
lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
|
||||
# store float4/float2 directly (remove CAST/GEP)
|
||||
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, vin=tuple(UOp(UOps.GEP, arg=i, vin=(UOp.var("val"),)) for i in range(4)))),
|
||||
lambda buf, idx, val: UOp.store(buf, idx, val)), # pylint: disable=unnecessary-lambda
|
||||
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, vin=tuple(UOp(UOps.GEP, arg=i, vin=(UOp.var("val"),)) for i in range(2)))),
|
||||
lambda buf, idx, val: UOp.store(buf, idx, val)), # pylint: disable=unnecessary-lambda
|
||||
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, vin=tuple(UOp(UOps.GEP, arg=i, vin=(UOp.var("val"),)) for i in range(4)))), UOp.store),
|
||||
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, vin=tuple(UOp(UOps.GEP, arg=i, vin=(UOp.var("val"),)) for i in range(2)))), UOp.store),
|
||||
# CAST-PHI-GEP -> PHI-CAST
|
||||
(UPat(UOps.CAST, name="root", vin=tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
|
||||
lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
|
||||
|
||||
Reference in New Issue
Block a user