uops cleanup (#4634)

* def add cleanup

* minor speedup

* add back ptx speed

* a little faster

* merge that

* only linearize once for ptx

* two graph rewrites for ptx, bug?
This commit is contained in:
George Hotz
2024-05-17 20:02:38 -07:00
committed by GitHub
parent 07b350a8f4
commit b74cc1d01a
4 changed files with 50 additions and 45 deletions

View File

@@ -259,6 +259,11 @@ ptx_matcher = PatternMatcher([
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{},{"__name__": "g", "dtype": dtypes.int})},
lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
# ptr_ar (load/store)
({"__name__": "root", "uop": {UOps.LOAD, UOps.STORE}, "__allow_len__":[2,3,4,5], "vin": ({"uop":{UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}},
{"uop": UOps.ALU, "arg": BinaryOps.ADD,"vin":[{"__name__": "alu"}, {"__name__": "const", "uop":UOps.CONST}]})},
lambda root, alu, const: UOp(root.uop, root.dtype,
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.vin[0].dtype.itemsize)+root.vin[0].cast(dtypes.int64),
UOp.const(const.dtype, root.vin[0].dtype.itemsize)*const)+root.vin[2:])),
({"__name__": "root", "uop": {UOps.LOAD, UOps.STORE}, "__allow_len__":[2,3,4,5], "vin": ({"uop":{UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}},
{"__name__": "const", "uop":UOps.CONST})},
lambda root, const: UOp(root.uop, root.dtype, (root.vin[0].cast(dtypes.int64),