don't const rewrite in cstyle (#7442)

* don't const rewrite in cstyle

* Update cstyle.py

* simple_symbolic

* fix bfloat16 const on AMD
This commit is contained in:
George Hotz
2024-10-31 18:16:49 +07:00
committed by GitHub
parent bdde795239
commit 5dd1ffd5d0
6 changed files with 26 additions and 13 deletions

View File

@@ -34,7 +34,7 @@ def assert_jit_cache_len(fxn, expected_len):
def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
if dtype == dtypes.bfloat16:
# NOTE: this requires bf16 buffer support
return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
return device in {"AMD"} or (device in {"CUDA", "NV", "METAL"} and not CI and not getenv("PTX"))
if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
# for CI GPU and OSX, cl_khr_fp16 isn't supported
# for CI LLVM, it segfaults because it can't link to the casting function

8
test_driven_development.sh Executable file
View File

@@ -0,0 +1,8 @@
#!/bin/bash
# Test-driven development helper.
#   1. Run the process replay reset script.
#   2. Run the listed test files once under pytest (parallel, -n auto) with
#      RUN_PROCESS_REPLAY=1 set for that run.
#   3. Loop forever: every time test/test_tiny.py passes, run the process
#      replay script with the current directory on PYTHONPATH.
# Stop the loop manually (e.g. Ctrl-C); it never exits on its own.

python3 test/external/process_replay/reset.py

# Test files exercised by the initial pytest pass.
suites=(
  test/test_tiny.py
  test/test_uop_graph.py
  test/test_ops.py
  test/test_linearizer.py
)
RUN_PROCESS_REPLAY=1 pytest -n auto "${suites[@]}"

while :; do
  # Only run process replay when the quick test passes; a failure just retries.
  python3 test/test_tiny.py && PYTHONPATH="." python3 test/external/process_replay/process_replay.py
done

View File

@@ -3,7 +3,7 @@ from typing import Optional, Tuple, Dict, List, cast, TYPE_CHECKING, Any, Defaul
import functools, itertools, operator
from collections import defaultdict
from tinygrad.dtype import dtypes, ImageDType, PtrDType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat, symbolic_simple
from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, identity_element, uop_given_valid, parse_valid, is_increasing, simplify_valid
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
@@ -525,5 +525,5 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
sink = graph_rewrite(sink, sym+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
# for rendering without sym (including the rules from the renderer)
sink = graph_rewrite(sink, pm_render+opts.extra_matcher if opts is not None and opts.extra_matcher is not None else pm_render)
sink = graph_rewrite(sink, symbolic_simple+(pm_render+opts.extra_matcher if opts is not None and opts.extra_matcher is not None else pm_render))
return sink

View File

@@ -1003,7 +1003,7 @@ def max_var_const(x:UOp, c1:UOp, c2:UOp):
if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2
if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1
symbolic = PatternMatcher([
symbolic_simple = PatternMatcher([
# ** self folding **
(UPat.var("x") + 0, lambda x: x), # x+0 -> x
(UPat.var("x") * 1, lambda x: x), # x*1 -> x
@@ -1036,6 +1036,12 @@ symbolic = PatternMatcher([
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y),
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y'), lambda x,y: x|y),
(UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y')), lambda x,y: x|y),
# *** cast ***
(UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
])
symbolic = symbolic_simple+PatternMatcher([
# group like
((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y),
# ** combine terms **
@@ -1091,9 +1097,6 @@ symbolic = PatternMatcher([
# ** mod **
# mod folding
(UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
# *** cast ***
(UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
])
symbolic_flat = symbolic+PatternMatcher([

View File

@@ -34,6 +34,11 @@ base_rewrite = PatternMatcher([
(UPat(UOps.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"),
(UPat(UOps.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"),
(UPat(UOps.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
# consts of small dtypes are rendered as a wider literal and cast to their own dtype
(UPat(UOps.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}f)"),
(UPat(UOps.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}u)"),
(UPat(UOps.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg})"),
# default const render
(UPat(UOps.CONST, name="x"), lambda ctx,x: str(x.arg)),
# new load/store
(UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
@@ -49,10 +54,6 @@ base_rewrite = PatternMatcher([
])
extra_pm = PatternMatcher([
# consts are rendered to larger type and casted
(UPat(UOps.CONST, (dtypes.bfloat16, dtypes.half), name="c"), lambda c: UOp.const(dtypes.float, c.arg).cast(c.dtype)),
(UPat(UOps.CONST, (dtypes.uint8, dtypes.uint16), name="c"), lambda c: UOp.const(dtypes.uint32, c.arg).cast(c.dtype)),
(UPat(UOps.CONST, (dtypes.int8, dtypes.int16), name="c"), lambda c: UOp.const(dtypes.int32, c.arg).cast(c.dtype)),
# insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
(UPat(UOps.BITCAST, name="x"),
lambda x: UOp(UOps.BITCAST, x.dtype, (UOp(UOps.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not UOps.NOOP else None),
@@ -396,6 +397,7 @@ class AMDRenderer(CStyleLanguage):
(UPat(UOps.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
(UPat(UOps.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
# bfloat16 casting
(UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
(UPat(UOps.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)),
lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
(UPat(UOps.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm

View File

@@ -1,7 +1,7 @@
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
import struct
from collections import defaultdict
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat, symbolic
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer
@@ -33,7 +33,7 @@ asm_for_op: Dict[Op, Callable] = {
}
supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
ptx_matcher = symbolic+PatternMatcher([
ptx_matcher = PatternMatcher([
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
(UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
(UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y),