mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 23:38:58 -05:00
add optional NEG and SUB (#6750)
* add optional NEG and SUB * describe that compute + optional mulacc * ptx cleanup * lil cleanups
This commit is contained in:
@@ -4,7 +4,7 @@ import functools, itertools, heapq, math, operator
|
||||
from collections import defaultdict
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, ConstType
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, UOps, END_FOR_UOP, type_verify, print_uops, identity_element
|
||||
from tinygrad.ops import UPat, PatternMatcher, graph_rewrite
|
||||
from tinygrad.ops import UPat, PatternMatcher, graph_rewrite, TernaryOps
|
||||
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, CI, partition, all_same
|
||||
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
if TYPE_CHECKING: from tinygrad.renderer import Renderer
|
||||
@@ -284,6 +284,11 @@ def get_extra_patterns(ops):
|
||||
UOp(UOps.ALU, root.dtype, (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL) if const.arg in shiftable_consts else None),
|
||||
(UPat(UOps.ALU, arg=BinaryOps.IDIV, name="root", src=(UPat.var("div"), UPat.cvar("const"))), lambda root, div, const:
|
||||
UOp(UOps.ALU, root.dtype, (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR) if const.arg in shiftable_consts else None)]
|
||||
if UnaryOps.NEG in ops:
|
||||
pat += [(UPat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))]
|
||||
if BinaryOps.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(UnaryOps.NEG), lambda x,y: x.alu(BinaryOps.SUB, y))]
|
||||
if TernaryOps.MULACC in ops:
|
||||
pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(TernaryOps.MULACC, b, c))]
|
||||
return PatternMatcher(pat)
|
||||
|
||||
# ***** threefry *****
|
||||
@@ -361,9 +366,10 @@ def no_vectorized_wmma(wmma:UOp):
|
||||
|
||||
# this is symbolic 2.0
|
||||
constant_folder = PatternMatcher([
|
||||
# bool ADD is OR, MUL is AND. This prevents other rules from rewriting bool ADD/MUL incorrectly
|
||||
(UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.ADD, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.OR)),
|
||||
(UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.MUL, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.AND)),
|
||||
# bool MUL is AND, ADD/MAX is OR. This prevents other rules from rewriting bool ADD/MUL incorrectly
|
||||
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y),
|
||||
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y'), lambda x,y: x|y),
|
||||
(UPat.var('x', dtype=dtypes.bool).max(UPat.var('y')), lambda x,y: x|y),
|
||||
# self ASSIGN is just self
|
||||
(UPat(UOps.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
|
||||
# ASSIGN to global is just self
|
||||
|
||||
@@ -33,19 +33,15 @@ asm_for_op: Dict[Op, Callable] = {
|
||||
}
|
||||
|
||||
supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
|
||||
shiftable_consts = set([2**i for i in range(64)])
|
||||
ptx_matcher = constant_folder+PatternMatcher([
|
||||
(UPat(UOps.ALU, arg=BinaryOps.CMPNE, src=(UPat(dtype=dtypes.bool),UPat()), name="root"),
|
||||
lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
|
||||
(UPat(UOps.ALU, arg=BinaryOps.CMPLT, src=(UPat.var("x", dtype=dtypes.bool),UPat.var("y")), name="root"),
|
||||
lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x, UOp.const(dtypes.bool, True)), BinaryOps.CMPNE), y), BinaryOps.MUL)),
|
||||
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat.var("non_muls"), UPat(UOps.ALU, arg=BinaryOps.MUL, name="muls")], name="root"),
|
||||
lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)),
|
||||
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
|
||||
(UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
|
||||
(UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y),
|
||||
# upcast to float32 all the ops that don't support half
|
||||
*[(UPat(UOps.ALU, arg=op, dtype=dtypes.half, name="x"),
|
||||
lambda x: (UOp(x.op, dtypes.float32, tuple([vv.cast(dtypes.float32) for vv in x.src]), x.arg).cast(dtypes.half)))
|
||||
lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half)))
|
||||
for op in asm_for_op.keys() if op not in supports_half],
|
||||
(UPat(UOps.ALU, name="x", dtype=dtypes.bool, arg=BinaryOps.MAX),
|
||||
lambda x: UOp(UOps.ALU, dtypes.uint8, tuple(s.cast(dtypes.uint8) for s in x.src), x.arg).cast(dtypes.bool)),
|
||||
# fix the gates for load/store (low quality!)
|
||||
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat.var("x"),UPat.var("y"),UPat.var("z"),UPat.var("k"))),
|
||||
lambda root,x,y,z,k: UOp(root.op, dtypes.uint8, (x,y,z.cast(dtypes.uint8),k)).cast(dtypes.bool)),
|
||||
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
|
||||
@@ -66,8 +62,7 @@ ptx_matcher = constant_folder+PatternMatcher([
|
||||
lambda root, const: UOp(root.op, root.dtype,
|
||||
(root.src[0].cast(dtypes.int64),
|
||||
UOp.const(dtypes.int64, const.arg*root.src[0].dtype.itemsize),)+root.src[2:])),
|
||||
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)),
|
||||
UPat.var("alu"))), # no const here
|
||||
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)), UPat.var("alu"))),
|
||||
lambda root, alu: UOp(root.op, root.dtype,
|
||||
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
|
||||
UOp.const(dtypes.int64, 0))+root.src[2:])),
|
||||
|
||||
@@ -112,9 +112,10 @@ class CStyleLanguage(Renderer):
|
||||
code_for_op: Dict = {
|
||||
UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
|
||||
UnaryOps.RECIP: lambda x,dtype: f"(1/{x})",
|
||||
UnaryOps.NEG: lambda x,dtype: f"-{x}",
|
||||
UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})",
|
||||
BinaryOps.SHL: lambda a,b,dtype: f"({a}<<{b})", BinaryOps.SHR: lambda a,b,dtype: f"({a}>>{b})",
|
||||
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})",
|
||||
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})",
|
||||
BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
|
||||
BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
|
||||
BinaryOps.AND: lambda a,b,dtype: f"({a}&{b})", BinaryOps.OR: lambda a,b,dtype: f"({a}|{b})",
|
||||
|
||||
Reference in New Issue
Block a user