From 5dd1ffd5d012aa83d0ad073aae35deb9cebb8fd1 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 31 Oct 2024 18:16:49 +0700 Subject: [PATCH] don't const rewrite in cstyle (#7442) * don't const rewrite in cstyle * Update cstyle.py * simple_symbolic * fix bfloat16 const on AMD --- test/helpers.py | 2 +- test_driven_development.sh | 8 ++++++++ tinygrad/codegen/uopgraph.py | 4 ++-- tinygrad/ops.py | 11 +++++++---- tinygrad/renderer/cstyle.py | 10 ++++++---- tinygrad/renderer/ptx.py | 4 ++-- 6 files changed, 26 insertions(+), 13 deletions(-) create mode 100755 test_driven_development.sh diff --git a/test/helpers.py b/test/helpers.py index d66849a6a2..5aa62f3513 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -34,7 +34,7 @@ def assert_jit_cache_len(fxn, expected_len): def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT): if dtype == dtypes.bfloat16: # NOTE: this requires bf16 buffer support - return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX")) + return device in {"AMD"} or (device in {"CUDA", "NV", "METAL"} and not CI and not getenv("PTX")) if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32] # for CI GPU and OSX, cl_khr_fp16 isn't supported # for CI LLVM, it segfaults because it can't link to the casting function diff --git a/test_driven_development.sh b/test_driven_development.sh new file mode 100755 index 0000000000..e6334e6aa9 --- /dev/null +++ b/test_driven_development.sh @@ -0,0 +1,8 @@ +#!/bin/bash +python3 test/external/process_replay/reset.py +RUN_PROCESS_REPLAY=1 pytest -n auto test/test_tiny.py test/test_uop_graph.py test/test_ops.py test/test_linearizer.py +while true; do + if python3 test/test_tiny.py; then + PYTHONPATH="." python3 test/external/process_replay/process_replay.py + fi +done \ No newline at end of file diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 8d611b5c0d..b48c153165 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -3,7 +3,7 @@ from typing import Optional, Tuple, Dict, List, cast, TYPE_CHECKING, Any, Defaul import functools, itertools, operator from collections import defaultdict from tinygrad.dtype import dtypes, ImageDType, PtrDType -from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat +from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat, symbolic_simple from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, identity_element, uop_given_valid, parse_valid, is_increasing, simplify_valid from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES @@ -525,5 +525,5 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp: sink = graph_rewrite(sink, sym+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2)) # for rendering without sym (including the rules from the renderer) - sink = graph_rewrite(sink, pm_render+opts.extra_matcher if opts is not None and opts.extra_matcher is not None else pm_render) + sink = graph_rewrite(sink, symbolic_simple+(pm_render+opts.extra_matcher if opts is not None and opts.extra_matcher is not None else pm_render)) return sink diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 5a576675d6..93c77aafcc 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1003,7 +1003,7 @@ def max_var_const(x:UOp, c1:UOp, c2:UOp): if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2 if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1 -symbolic = PatternMatcher([ +symbolic_simple = PatternMatcher([ # ** self folding ** (UPat.var("x") + 0, lambda x: x), # x+0 -> x (UPat.var("x") * 1, lambda x: x), # x*1 -> x @@ -1036,6 +1036,12 @@ symbolic = PatternMatcher([ (UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y), (UPat.var('x', dtype=dtypes.bool) + UPat.var('y'), lambda x,y: x|y), (UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y')), lambda x,y: x|y), + # *** cast *** + (UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)), + (UPat(UOps.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None), +]) + +symbolic = symbolic_simple+PatternMatcher([ # group like ((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y), # ** combine terms ** @@ -1091,9 +1097,6 @@ symbolic = PatternMatcher([ # ** mod ** # mod folding (UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None), - # *** cast *** - (UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)), - (UPat(UOps.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None), ]) symbolic_flat = symbolic+PatternMatcher([ diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 24af804c79..d571b99178 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -34,6 +34,11 @@ base_rewrite = PatternMatcher([ (UPat(UOps.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"), (UPat(UOps.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"), (UPat(UOps.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"), + # consts are rendered to larger type and casted + (UPat(UOps.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}f)"), + (UPat(UOps.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}u)"), + (UPat(UOps.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg})"), + # default const render (UPat(UOps.CONST, name="x"), lambda ctx,x: str(x.arg)), # new load/store (UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var('idx'))), @@ -49,10 +54,6 @@ base_rewrite = PatternMatcher([ ]) extra_pm = PatternMatcher([ - # consts are rendered to larger type and casted - (UPat(UOps.CONST, (dtypes.bfloat16, dtypes.half), name="c"), lambda c: UOp.const(dtypes.float, c.arg).cast(c.dtype)), - (UPat(UOps.CONST, (dtypes.uint8, dtypes.uint16), name="c"), lambda c: UOp.const(dtypes.uint32, c.arg).cast(c.dtype)), - (UPat(UOps.CONST, (dtypes.int8, dtypes.int16), name="c"), lambda c: UOp.const(dtypes.int32, c.arg).cast(c.dtype)), # insert a NOOP before BITCAST to force it to be rendered. not needed on all backends? (UPat(UOps.BITCAST, name="x"), lambda x: UOp(UOps.BITCAST, x.dtype, (UOp(UOps.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not UOps.NOOP else None), @@ -396,6 +397,7 @@ class AMDRenderer(CStyleLanguage): (UPat(UOps.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None), (UPat(UOps.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None), # bfloat16 casting + (UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))), (UPat(UOps.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)), lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)), (UPat(UOps.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index 9f7fe9998a..19d00b3d8e 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -1,7 +1,7 @@ from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable import struct from collections import defaultdict -from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat, symbolic +from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat from tinygrad.dtype import dtypes, DType, PtrDType, ConstType from tinygrad.renderer import Renderer from tinygrad.renderer.cstyle import CUDARenderer @@ -33,7 +33,7 @@ asm_for_op: Dict[Op, Callable] = { } supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE] -ptx_matcher = symbolic+PatternMatcher([ +ptx_matcher = PatternMatcher([ # bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only) (UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y), (UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y),