Ptx mulacc (#3937)

* mulacc

* Move more stuff to pattern matcher

* disable callable from the == check

* disable function passing in pattern matcher

* Add set of dtypes pattern matching + refactor mulacc pattern
This commit is contained in:
Szymon Ożóg
2024-04-04 09:15:25 +02:00
committed by GitHub
parent 3e72d745ea
commit 92378fb5b6
3 changed files with 8 additions and 1 deletions

View File

@@ -50,6 +50,8 @@ def _match(uop:UOp, pattern:Dict[str, Any], store:Dict[str, UOp]) -> bool:
for k,v in new_store.items(): store[k] = v
return True
return False
elif k == "dtype":
if uop.__getattribute__(k) not in (v if isinstance(v, set) else set([v])): return False
else:
if uop.__getattribute__(k) != v: return False
return True

View File

@@ -16,7 +16,7 @@ from tinygrad.buffer import Buffer
class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto() # noqa: E702
class BinaryOps(Enum):
ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPEQ = auto(); XOR = auto() # noqa: E702
class TernaryOps(Enum): WHERE = auto() # noqa: E702
class TernaryOps(Enum): WHERE = auto(); MULACC = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto() # noqa: E702

View File

@@ -60,6 +60,9 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},
lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.ADD, "dtype": set([dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]),
"vin": [{"__name__": "non_muls"}, {"__name__": "muls", "uop": UOps.ALU, "arg": BinaryOps.MUL}]},
lambda root, muls, non_muls : UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)),
*[({"__name__": "x", "uop": UOps.ALU, "dtype": dtypes.half, "arg": op},
lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
for op in lang.asm_for_op.keys() if op not in lang.supports_half],
@@ -72,6 +75,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
])
# here we do a pretransform on UOps to fix some shortcomings of PTX
@@ -216,6 +220,7 @@ class PTXLanguage(AssemblyLanguage):
BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
BinaryOps.CMPEQ: lambda d,a,b,dt,name: f"setp.eq.{name} {d}, {a}, {b};",
TernaryOps.MULACC: lambda d,a,b,c,dt,name: f"fma.rn.{name} {d}, {a}, {b}, {c};",
TernaryOps.WHERE: lambda d,a,b,c,dt,name:
f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
}