From b199b699ed8b91e295b6cd65b7fafbb365daf2a8 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Thu, 26 Sep 2024 09:59:36 +0800
Subject: [PATCH] use shl everywhere (#6744)

* use shl everywhere

* fix parens

* late patterns

* works as an extra pass

* ptx
---
 test/test_uops_stats.py       |  1 +
 tinygrad/codegen/uopgraph.py  | 33 +++++++++++++++++++++++----------
 tinygrad/dtype.py             | 14 ++++++++++----
 tinygrad/renderer/assembly.py | 10 +---------
 tinygrad/renderer/cstyle.py   |  5 +++--
 5 files changed, 38 insertions(+), 25 deletions(-)

diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py
index 2c129e7448..2e38925e05 100644
--- a/test/test_uops_stats.py
+++ b/test/test_uops_stats.py
@@ -94,6 +94,7 @@ class TestUOpsStats(unittest.TestCase):
     # NOTE; ops also include indexing ops
     assert expected_ops <= ops and ops <= expected_ops * 2
 
+  @unittest.skipIf(getenv("PTX"), "wrong in PTX")
   def test_simple_add_sq(self):
     a = Tensor.empty(100,100)
     b = Tensor.empty(100,100)
diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py
index 4d0df88e89..d9602988a9 100644
--- a/tinygrad/codegen/uopgraph.py
+++ b/tinygrad/codegen/uopgraph.py
@@ -266,12 +266,25 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
     new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None
   return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid else (buf, idx)))
 
-# ***** transcendental *****
+# ***** optional patterns *****
+
+transcendental_patterns = [
+  (UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.EXP2), xexp2),
+  (UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.LOG2), xlog2),
+  (UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.SIN), xsin),
+]
 
 @functools.lru_cache(None)
-def transcendental_folding(ops):
-  return PatternMatcher([(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=k), cast(Callable, v))
-                         for k,v in ((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if k not in ops])
+def get_extra_patterns(ops):
+  pat = [(p[0], cast(Callable, p[1])) for p in transcendental_patterns if p[0].arg not in ops or TRANSCENDENTAL >= 2]
+  if BinaryOps.SHL in ops and BinaryOps.SHR in ops:
+    shiftable_consts = set([2**i for i in range(64)])
+    pat += [
+      (UPat(UOps.ALU, arg=BinaryOps.MUL, name="root", dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda root, mul, const:
+        UOp(UOps.ALU, root.dtype, (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL) if const.arg in shiftable_consts else None),
+      (UPat(UOps.ALU, arg=BinaryOps.IDIV, name="root", src=(UPat.var("div"), UPat.cvar("const"))), lambda root, div, const:
+        UOp(UOps.ALU, root.dtype, (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR) if const.arg in shiftable_consts else None)]
+  return PatternMatcher(pat)
 
 # ***** threefry *****
 
@@ -715,11 +728,10 @@ linearize_cnt = 0
 def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
   global linearize_cnt, acc_number
   assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
-  folder = constant_folder + transcendental_folding(tuple() if TRANSCENDENTAL >= 2 or opts is None else tuple(opts.code_for_op.keys()))
 
   # do graph rewrite
   acc_number = 0
-  sink = graph_rewrite(sink, folder)
+  sink = graph_rewrite(sink, constant_folder)
 
   # rewrite pyint to int32
   sink = graph_rewrite(sink, no_pyint)
@@ -727,11 +739,12 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
   # expand
   linearize_cnt += 1
   if linearize_cnt != (de:=getenv("DEBUG_EXPAND", 0)) and de != -1:
-    sink = graph_rewrite(sink, folder+expander)
+    sink = graph_rewrite(sink, constant_folder+expander)
     if getenv("DO_REDUCE", 1):
-      sink = graph_rewrite(sink, folder+just_reduce)
-      sink = graph_rewrite(sink, folder+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
-      sink = graph_rewrite(sink, folder+reducer)
+      sink = graph_rewrite(sink, constant_folder+just_reduce)
+      sink = graph_rewrite(sink, constant_folder+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
+      sink = graph_rewrite(sink, constant_folder+reducer)
+  sink = graph_rewrite(sink, constant_folder+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else ()))
 
   if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, opts.extra_matcher)
   return sink
diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py
index 2e475dbab2..f4e9d32b41 100644
--- a/tinygrad/dtype.py
+++ b/tinygrad/dtype.py
@@ -44,13 +44,13 @@ class PtrDType(DType):
 class dtypes:
   @staticmethod
   @functools.lru_cache(None)
-  def is_float(x: DType) -> bool: return x.scalar() in {dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64}
+  def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats
   @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
   @functools.lru_cache(None)
-  def is_int(x: DType) -> bool: return x.scalar() in {dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.pyint} or dtypes.is_unsigned(x)
+  def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints
   @staticmethod
   @functools.lru_cache(None)
-  def is_unsigned(x: DType) -> bool: return x.scalar() in {dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64}
+  def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints
   @staticmethod
   def from_py(x) -> DType:
     if x.__class__ is float: return dtypes.default_float
@@ -114,6 +114,11 @@ class dtypes:
   default_float: ClassVar[DType] = float32
   default_int: ClassVar[DType] = int32
 
+  floats = (float16, bfloat16, float32, float64)
+  uints = (uint8, uint16, uint32, uint64)
+  sints = (int8, int16, int32, int64, pyint)
+  ints = uints + sints
+
 if (env_default_float := getenv("DEFAULT_FLOAT", "")):
   dtypes.default_float = getattr(dtypes, env_default_float.lower())
   assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
@@ -137,7 +142,8 @@ def least_upper_dtype(*ds:DType) -> DType:
 def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
 
 # HACK: staticmethods are not callable in 3.8 so we have to compare the class
-DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'void')) or v.__class__ is staticmethod)}
+DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'void'))
+                                                                or v.__class__ is staticmethod or isinstance(v, tuple))}
 INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
 INVERSE_DTYPES_DICT['void'] = 'void'
diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py
index 3ff56a47c3..7e491514a5 100644
--- a/tinygrad/renderer/assembly.py
+++ b/tinygrad/renderer/assembly.py
@@ -1,5 +1,5 @@
 from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
-import struct, math
+import struct
 from collections import defaultdict
 from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat
 from tinygrad.codegen.uopgraph import constant_folder
@@ -35,14 +35,6 @@ asm_for_op: Dict[Op, Callable] = {
 supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
 shiftable_consts = set([2**i for i in range(64)])
 ptx_matcher = constant_folder+PatternMatcher([
-  (UPat(UOps.ALU, arg=BinaryOps.MUL, name="root", dtype=tuple([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
-    src=[UPat.cvar("const"), UPat.var("mul")]),
-   lambda root, mul, const: UOp(UOps.ALU, root.dtype,
-    (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL) if const.arg in shiftable_consts else None),
-  (UPat(UOps.ALU, arg=BinaryOps.IDIV, name="root", dtype=tuple([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
-    src=[UPat.cvar("const"), UPat.var("div")]),
-   lambda root, div, const: UOp(UOps.ALU, root.dtype,
-    (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR) if const.arg in shiftable_consts else None),
   (UPat(UOps.ALU, arg=BinaryOps.CMPNE, src=(UPat(dtype=dtypes.bool),UPat()), name="root"), lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
   (UPat(UOps.ALU, arg=BinaryOps.CMPLT, src=(UPat.var("x", dtype=dtypes.bool),UPat.var("y")), name="root"),
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index bf013a50f4..980caaba3f 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -8,7 +8,7 @@ from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
 from tinygrad.renderer import Renderer, TensorCore
 
 def render_load(r:CStyleLanguage, load:UOp, buf:UOp) -> str:
-  sidx = strip_parens(r[load.src[1]])
+  sidx = strip_parens(r[load.src[1]]) if load.src[1].arg == BinaryOps.ADD else r[load.src[1]]
   if isinstance(buf.dtype, ImageDType):
     assert load.dtype == dtypes.float.vec(4), f"images must be float4, getting {load.dtype}"
     val = f"read_imagef({r[buf]}, smp, {sidx})"
@@ -24,7 +24,7 @@ def render_load(r:CStyleLanguage, load:UOp, buf:UOp) -> str:
   return val
 
 def render_store(r:CStyleLanguage, store:UOp, buf:UOp, var:UOp) -> str:
-  sidx = strip_parens(r[store.src[1]])
+  sidx = strip_parens(r[store.src[1]]) if store.src[1].arg == BinaryOps.ADD else r[store.src[1]]
   if isinstance(buf.dtype, ImageDType):
     assert var.dtype == dtypes.float.vec(4), f"images must be float4, getting {var.dtype}"
     val = f"write_imagef({r[buf]}, {sidx}, {r[var]});"
@@ -113,6 +113,7 @@ class CStyleLanguage(Renderer):
     UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})", UnaryOps.RECIP: lambda x,dtype: f"(1/{x})",
     UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})",
+    BinaryOps.SHL: lambda a,b,dtype: f"({a}<<{b})", BinaryOps.SHR: lambda a,b,dtype: f"({a}>>{b})",
     BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})",
     BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})",
     BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})", BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})",
     BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
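
Review note: below is a minimal standalone sketch (assumed names, not
tinygrad's actual API) of the strength reduction that get_extra_patterns now
applies as a late pass: an integer MUL or IDIV whose constant operand is a
power of two is rewritten to a SHL/SHR by log2 of that constant, and the rule
only fires when the target renderer advertises both shift ops in code_for_op.
SHIFTABLE_CONSTS, rewrite_mul, and rewrite_idiv are hypothetical names used
for illustration only.

import math

SHIFTABLE_CONSTS = {2**i for i in range(64)}  # mirrors shiftable_consts in the patch

def rewrite_mul(x: int, const: int) -> int:
  # x * const  ->  x << log2(const), only for power-of-two constants
  return x << int(math.log2(const)) if const in SHIFTABLE_CONSTS else x * const

def rewrite_idiv(x: int, const: int) -> int:
  # x // const  ->  x >> log2(const), only for power-of-two constants
  return x >> int(math.log2(const)) if const in SHIFTABLE_CONSTS else x // const

assert rewrite_mul(5, 8) == 5 << 3 == 40
assert rewrite_idiv(40, 4) == 40 >> 2 == 10
assert rewrite_mul(5, 3) == 15  # non-power-of-two constants are left alone

Two small observations. First, the cstyle change gating strip_parens on
BinaryOps.ADD matches the "fix parens" bullet: in C, << binds looser than +,
so an index like (a<<2) must keep its parentheses when spliced into a larger
expression (e.g. "buf+a<<2" parses as "(buf+a)<<2"), while stripping them
from a top-level a+b remains safe. Second, Python's // floors and >> shifts
arithmetically, so the identity in the sketch holds for negative x, but
C-style integer division truncates toward zero, so IDIV -> SHR only agrees
for non-negative dividends; the IDIV pattern in the patch is not restricted
to unsigned dtypes, which may be worth a second look.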