mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
don't const rewrite in cstyle (#7442)
* don't const rewrite in cstyle
* Update cstyle.py
* simple_symbolic
* fix bfloat16 const on AMD
This commit is contained in:
@@ -34,7 +34,7 @@ def assert_jit_cache_len(fxn, expected_len):
|
||||
def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
|
||||
if dtype == dtypes.bfloat16:
|
||||
# NOTE: this requires bf16 buffer support
|
||||
return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
|
||||
return device in {"AMD"} or (device in {"CUDA", "NV", "METAL"} and not CI and not getenv("PTX"))
|
||||
if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
|
||||
# for CI GPU and OSX, cl_khr_fp16 isn't supported
|
||||
# for CI LLVM, it segfaults because it can't link to the casting function
|
||||
|
||||
test_driven_development.sh
Executable file (new, 8 lines added)
@@ -0,0 +1,8 @@
|
||||
#!/bin/bash
|
||||
python3 test/external/process_replay/reset.py
|
||||
RUN_PROCESS_REPLAY=1 pytest -n auto test/test_tiny.py test/test_uop_graph.py test/test_ops.py test/test_linearizer.py
|
||||
while true; do
|
||||
if python3 test/test_tiny.py; then
|
||||
PYTHONPATH="." python3 test/external/process_replay/process_replay.py
|
||||
fi
|
||||
done
|
||||
@@ -3,7 +3,7 @@ from typing import Optional, Tuple, Dict, List, cast, TYPE_CHECKING, Any, Defaul
|
||||
import functools, itertools, operator
|
||||
from collections import defaultdict
|
||||
from tinygrad.dtype import dtypes, ImageDType, PtrDType
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat, symbolic_simple
|
||||
from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, identity_element, uop_given_valid, parse_valid, is_increasing, simplify_valid
|
||||
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
|
||||
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
@@ -525,5 +525,5 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
||||
sink = graph_rewrite(sink, sym+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
|
||||
|
||||
# for rendering without sym (including the rules from the renderer)
|
||||
sink = graph_rewrite(sink, pm_render+opts.extra_matcher if opts is not None and opts.extra_matcher is not None else pm_render)
|
||||
sink = graph_rewrite(sink, symbolic_simple+(pm_render+opts.extra_matcher if opts is not None and opts.extra_matcher is not None else pm_render))
|
||||
return sink
|
||||
|
||||
@@ -1003,7 +1003,7 @@ def max_var_const(x:UOp, c1:UOp, c2:UOp):
|
||||
if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2
|
||||
if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1
|
||||
|
||||
symbolic = PatternMatcher([
|
||||
symbolic_simple = PatternMatcher([
|
||||
# ** self folding **
|
||||
(UPat.var("x") + 0, lambda x: x), # x+0 -> x
|
||||
(UPat.var("x") * 1, lambda x: x), # x*1 -> x
|
||||
@@ -1036,6 +1036,12 @@ symbolic = PatternMatcher([
|
||||
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y),
|
||||
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y'), lambda x,y: x|y),
|
||||
(UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y')), lambda x,y: x|y),
|
||||
# *** cast ***
|
||||
(UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
|
||||
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
|
||||
])
|
||||
|
||||
symbolic = symbolic_simple+PatternMatcher([
|
||||
# group like
|
||||
((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y),
|
||||
# ** combine terms **
|
||||
@@ -1091,9 +1097,6 @@ symbolic = PatternMatcher([
|
||||
# ** mod **
|
||||
# mod folding
|
||||
(UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
|
||||
# *** cast ***
|
||||
(UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
|
||||
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
|
||||
])
|
||||
|
||||
symbolic_flat = symbolic+PatternMatcher([
|
||||
|
||||
@@ -34,6 +34,11 @@ base_rewrite = PatternMatcher([
|
||||
(UPat(UOps.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"),
|
||||
(UPat(UOps.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"),
|
||||
(UPat(UOps.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
|
||||
# consts are rendered to larger type and casted
|
||||
(UPat(UOps.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}f)"),
|
||||
(UPat(UOps.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}u)"),
|
||||
(UPat(UOps.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg})"),
|
||||
# default const render
|
||||
(UPat(UOps.CONST, name="x"), lambda ctx,x: str(x.arg)),
|
||||
# new load/store
|
||||
(UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
|
||||
@@ -49,10 +54,6 @@ base_rewrite = PatternMatcher([
|
||||
])
|
||||
|
||||
extra_pm = PatternMatcher([
|
||||
# consts are rendered to larger type and casted
|
||||
(UPat(UOps.CONST, (dtypes.bfloat16, dtypes.half), name="c"), lambda c: UOp.const(dtypes.float, c.arg).cast(c.dtype)),
|
||||
(UPat(UOps.CONST, (dtypes.uint8, dtypes.uint16), name="c"), lambda c: UOp.const(dtypes.uint32, c.arg).cast(c.dtype)),
|
||||
(UPat(UOps.CONST, (dtypes.int8, dtypes.int16), name="c"), lambda c: UOp.const(dtypes.int32, c.arg).cast(c.dtype)),
|
||||
# insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
|
||||
(UPat(UOps.BITCAST, name="x"),
|
||||
lambda x: UOp(UOps.BITCAST, x.dtype, (UOp(UOps.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not UOps.NOOP else None),
|
||||
@@ -396,6 +397,7 @@ class AMDRenderer(CStyleLanguage):
|
||||
(UPat(UOps.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
|
||||
(UPat(UOps.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
|
||||
# bfloat16 casting
|
||||
(UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
|
||||
(UPat(UOps.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)),
|
||||
lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
|
||||
(UPat(UOps.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
|
||||
import struct
|
||||
from collections import defaultdict
|
||||
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat, symbolic
|
||||
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.renderer.cstyle import CUDARenderer
|
||||
@@ -33,7 +33,7 @@ asm_for_op: Dict[Op, Callable] = {
|
||||
}
|
||||
|
||||
supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
|
||||
ptx_matcher = symbolic+PatternMatcher([
|
||||
ptx_matcher = PatternMatcher([
|
||||
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
|
||||
(UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
|
||||
(UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y),
|
||||
|
||||
Reference in New Issue
Block a user