* `global_load` and `global_store` using buffer dtype
* `UOps.PHI` in all dtypes
* `UOps.ALU` in all dtypes
* `UOps.CONST` & `UOps.DEFINE_ACC` in all dtypes
* -- end of implementation -- + tiny lint changes
* these tests require the fp16 extension; you can run them locally to confirm they're green (the GPT2 test is broken in master for mac, see [this](https://discord.com/channels/1068976834382925865/1069001075828469790/1177993277958533261)):
  `GPU=1 python3 -m pytest test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_dequantizelinear_e4m3fn_float16_cpu test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_max_float16_cpu test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_min_float16_cpu test/models/test_real_world.py::TestRealWorld::test_llama test/models/test_real_world.py::TestRealWorld::test_gpt2 test/models/test_whisper.py test/test_specific_conv.py::TestSpecific::test_big_vec_mul`
  skip the new test_linearizer_failures in CI GPU because of the fp16 extension; this passes on a real GPU since the extension is available:
  `GPU=1 python3 -m pytest test/test_linearizer_failures.py::TestLinearizerFailures::test_failure_8`
  see CI logs [here](https://github.com/tinygrad/tinygrad/actions/runs/6996590597/job/19032641427#step:14:644)
* these tests fail in CI due to segfaults and CPU crashes. To confirm they're green locally, you can run the following commands:
  1. For the tests skipped in test_ops.py (note: CLANG is very slow):
  `for var in GPU CUDA CLANG; do export $var=1; for test in test/test_ops.py::TestOps::test_slice_fancy_indexing_no_dim_collapse test/test_ops.py::TestOps::test_slice_fancy_indexing_dim_collapse_int test/test_ops.py::TestOps::test_slice_fancy_indexing_dim_inject_none test/test_ops.py::TestOps::test_slice_fancy_indexing_dim_inject_and_collapse; do python3 -m pytest $test; done; unset $var; done`
  2. For the ONNX tests skipped in CLANG:
```
CLANG=1 python3 -m pytest test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_ai_onnx_ml_array_feature_extractor_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_gather_elements_0_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_expanded_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_mean_weight_ii_3d_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_gather_elements_1_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_NCd1_mean_weight_negative_ii_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_weight_expanded_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_ii_expanded_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_mean_weight_ii_4d_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_mean_weight_ii_3d_log_prob_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_gather_elements_negative_indices_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_NCd1d2d3d4d5_mean_weight_log_prob_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_NCd1_mean_weight_negative_ii_log_prob_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_NCd1d2d3d4d5_mean_weight_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3d4d5_mean_weight_expanded_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_mean_weight_negative_ii_expanded_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_mean_weight_ii_4d_log_prob_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_reduction_mean_expanded_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_weight_ii_expanded_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_reduction_sum_expanded_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_expanded_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_reduction_sum_expanded_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_reduction_mean_expanded_cpu \
  test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_expanded_cpu
```
  3. The LLVM test I skipped here is already [skipped in master for all backends](https://github.com/tinygrad/tinygrad/blob/master/test/external/external_test_onnx_backend.py#L186), I just made it more specific:
  `LLVM=1 python3 -m pytest test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_dequantizelinear_e4m3fn_float16_cpu`
* Revert "these tests fail in CI due to segfaults and CPU crashes"
  This reverts commit 15db570143.
* merge with cleanup-vectorized-hip-renders
* barely working HIP P1, ALU ops need a refactor?
* manage the fact that in HIP [half2 is actually an unsigned int vec](f921880387/hip/include/hip/amd_detail/amd_hip_fp16.h (L59)) and half is a totally different __half that [has an unsigned int element in it](f921880387/hip/include/hip/amd_detail/amd_hip_fp16.h (L50)) but can't be accessed [because it's private](f921880387/hip/include/hip/amd_detail/amd_hip_fp16.h (L86)). If you just do this:
```
half2 val0 = // ...
half val1 = // ...
```
  then you can't do:
```
val0.x + val1 // error: use of overloaded operator '+' is ambiguous (with operand types 'unsigned short' and 'half' (aka '__half'))
```
* update the sign definition to avoid division by zero in all dtypes
* diff cleanup p1: why were these in the diff anyways
* less hacky HIP, enable CIFAR fp16 benchmark, test ops for HIP in CI!
  add ALU ops overloads for HIP; this will make HIP max work
  handle mod
  Revert "handle mod" (this reverts commit 370fd4b3fbe99b6ae8cc293d005b106628205933)
  update max to use hmax
  add HIP GEP render logic
  enable CIFAR fp16 benchmark
  test ops for HIP
  back to store as float because this only works for float4 grouping right now
  test_ops for hip!!
  always sign
* back to the sign we had before because we can't do a backward pass on a Less node
* remove old hacks
  HIP compiling test_ops in CI takes ~9 mins, not doing it for now
  new HIP ALUs
* reduce accs done right
* refactor to function
* no device hacks
  hacks p2 the other way
* LLVM ALU ops
  half, float and double are all float
  update max
* update test_uops, cmplt is always a bool in the real linearizer. assertAlmostEqual is wrong when ret is bool
* cleanup LLVM wrong code
* dummy change for the CUDA install glitch

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
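The two sign-related items above trade off two formulations; here is a minimal plain-Python sketch of that trade-off (illustrative only, not the code this PR touches): dividing by the absolute value is undefined at zero, while a comparison-based sign is defined everywhere but is built from Less-style comparison nodes, which is what the backward-pass caveat refers to.
```python
# Illustrative sketch only; neither helper is part of the tinygrad diff.
def sign_div(x):
  return x / abs(x)         # division by zero when x == 0

def sign_cmp(x):
  return (x > 0) - (x < 0)  # defined at 0, but made of comparison nodes
```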
132 lines
7.3 KiB
Python
from typing import Dict, List, Final, Callable, DefaultDict
from collections import defaultdict
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Op
from tinygrad.helpers import DType, dtypes, ImageDType, DEBUG, getenv
from tinygrad.codegen.linearizer import UOp, UOps
from triton.compiler import compile as triton_compile
import linecache
import math
import re

triton_dtypes = {dtypes.double: "tl.float64", dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.bool: "tl.int1", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64", dtypes.uint32: "tl.uint32", dtypes.uint64: "tl.uint64", dtypes.int16: "tl.int16", dtypes.uint16: "tl.uint16"}
signature_dtypes = {dtypes.double: "*fp64",dtypes.float32: "*fp32", dtypes.float16: "*fp16", dtypes.bool: "*i8", dtypes.int8: "*i1", dtypes.uint8: "*u8", dtypes._arg_int32: "i32", dtypes.int32: "*i32", dtypes.int64: "*i64", dtypes.uint32: "*u32", dtypes.uint64: "*u64", dtypes.int16: "*i16", dtypes.uint16: "*u16"}

def next_power_of_2(x):
  return 1 << (x - 1).bit_length()
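# e.g. next_power_of_2(5) == 8 and next_power_of_2(8) == 8; used to round local dims up to the power-of-2 sizes tl.arange needs
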
def render_valid(valid):
  return '(' * (len(valid) -1) + ') and '.join(valid) if len(valid) else 'True'
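# e.g. render_valid(["gidx0<4","lidx1<3"]) == "(gidx0<4) and lidx1<3"; an empty list renders as "True"
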
#NOTE Triton requires matching dimensions for load/store, disable this and see TestOps::test_output_padded_conv_transpose2d fail to compile
def fill_dims_for_idx(idx, dims):
  return "(" + idx + "+ (" + (f"0*({'+'.join(d for d in dims)})))") if len(dims) else idx
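# e.g. fill_dims_for_idx("idx0", ["gidx0","lidx1"]) == "(idx0+ (0*(gidx0+lidx1)))"; the 0* term pulls every dim into the index expression without changing its value
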
def get_max(var):
  if isinstance(var, int): return var
  return re.sub(r'\[(.*?)\]', '', str(var))[1:-1]

#NOTE can be removed after https://github.com/gpuocelot/gpuocelot/issues/8 gets resolved
def remove_single_scalar_curly_braces(ptx_code):
  return '\n'.join([re.sub(r'\{\s*(%\w+)\s*\}', r'\1', line) for line in ptx_code.split('\n')])
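# strips braces around a single register in the generated PTX, e.g. "{%f1}" -> "%f1", as a workaround for the gpuocelot parser
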
def render_const(args,dtype:DType):
  return (('-' if args<0 else '') + 'tl.where(1,float("inf"),0)') if math.isinf(args) else ('tl.where(1,float("nan"),0)' if math.isnan(args) else f"{int(args)}" if dtypes.is_int(dtype) else str(args))
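# e.g. render_const(float("inf"), dtypes.float32) == 'tl.where(1,float("inf"),0)' while render_const(3.0, dtypes.int32) == "3"
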
def render_cast(x:str, dtype:DType):
  return f"{x}.to({triton_dtypes[dtype]})"

def define_scalar(local_size, dtype, args):
  if len(local_size) > 0: return f"tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},),{render_const(args,dtype)}, dtype={triton_dtypes[dtype]})"
  return render_const(args,dtype)

def uops_to_triton(function_name:str, uops:List[UOp]):
  local_size: List[int] = []
  depth = 1
  signatures, dims, bufs, kernel, valid = [], [], [], [], [] #type: ignore

  c: DefaultDict[str, int] = defaultdict(int)
  r: Dict[UOp, str] = {}
  def ssa(u, prefix="t"):
    nonlocal c, r
    c[prefix] += 1
    r[u]=f"{prefix}{c[prefix]-1}"
    return r[u]

  child_count: DefaultDict[UOp, int] = defaultdict(int)
  for ru in uops:
    for v in ru.vin:
      child_count[v] += 1
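  # child_count counts how many uops consume each value; ALU results used at most once (and all integer expressions) are substituted inline below instead of being assigned to an ssa temporary
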
  def kk(s): kernel.append(" "*depth+s)
  code_for_op: Final[Dict[Op, Callable]] = {
    UnaryOps.EXP2: lambda x,dtype: f"tl.math.exp2({x})",
    UnaryOps.LOG2: lambda x,dtype: f"tl.math.log2({x})",
    UnaryOps.SIN: lambda x,dtype: f"tl.sin({x})",
    UnaryOps.SQRT: lambda x,dtype: f"tl.sqrt({x})",
    UnaryOps.NEG: lambda x,dtype: f"-{x}" if dtype != dtypes.bool else f"tl.where({x}, 0, 1)",
    BinaryOps.ADD: lambda x,y,dtype: f"({x}+{y})", BinaryOps.SUB: lambda x,y,dtype: f"({x}-{y})",
    BinaryOps.MUL: lambda x,y,dtype: f"({x}*{y})", BinaryOps.DIV: lambda x,y,dtype: f"({x}/{y})" if y != '0.0' else f"{x}*tl.where({x}==0.0, float('nan'), float('inf'))",
    BinaryOps.MAX: lambda x,y,dtype: f"tl.maximum({x},{y})",
    BinaryOps.CMPLT: lambda x,y,dtype: f"({x}<{y})",
    BinaryOps.MOD: lambda x,y,dtype: f"tl.abs({x})%tl.abs({y})*tl.where({x}<0,-1,1)",
    TernaryOps.MULACC: lambda x,y,z,dtype: f"(({x}*{y})+{z})",
    TernaryOps.WHERE: lambda x,y,z,dtype: f"tl.where({x},{y},{z})",
  }
  def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, float('nan'), float('inf'))"
  for u in uops:
    uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
    if uop == UOps.LOOP:
      kk(f"for {ssa(u, 'ridx')} in range({vin[0].arg}, {r[vin[1]]}):")
      depth += 1
    elif uop == UOps.END: depth -= 1
    elif uop == UOps.ALU:
      assert dtype is not None
      val = code_for_op[args](*[r[x] for x in vin], dtype)
      if child_count[u] <=1 or dtypes.is_int(dtype): r[u] = int_div(*[r[x] for x in vin]) if args == BinaryOps.DIV and dtypes.is_int(dtype) else val
      else: kk(f"{ssa(u, 'alu')} = ({val})")
    elif uop == UOps.LOAD:
      assert dtype is not None
      if len(vin) == 2: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.load({r[vin[0]]} + { fill_dims_for_idx(r[vin[1]], dims)}, mask = {render_valid(valid)})', dtype)}")
      else: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.where({r[vin[2]]}, tl.load({r[vin[0]]}+{fill_dims_for_idx(r[vin[1]],dims)} , mask={render_valid(valid+[r[vin[2]]])}), 0.0)', dtype)}")
    elif uop == UOps.DEFINE_ACC: kk(f"{ssa(u, 'acc')} = {define_scalar(local_size, dtype, args).replace('//', '/')}")
    elif uop == UOps.CONST: r[u] = define_scalar([], dtype, args)
    elif uop == UOps.PHI:
      kk(f"{r[vin[0]]} = {r[vin[1]].replace('//', '/')}")
      r[u] = r[vin[0]]
    elif uop == UOps.STORE:
      assert not isinstance(dtype, ImageDType), "unimplemented: image store"
      kk(f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ")
    elif uop == UOps.DEFINE_GLOBAL:
      bufs.append(args)
      signatures.append(signature_dtypes[args[1]])
      r[u] = args[0]
    elif uop == UOps.SPECIAL:
      dims.append(args[1])
      valid.append(f"{args[1]}<{get_max(args[2])}")
      if args[1].startswith("g"): kk(f"{args[1]} = tl.program_id({args[0]}) # {args[2]}")
      elif args[1].startswith("l"):
        kk(f"{args[1]} = tl.arange({0}, {next_power_of_2(args[2])})")
        local_size.append(args[2])
      r[u] = args[1]
    elif uop == UOps.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype)
    else: raise NotImplementedError(f"unimplemented: {uop}")

prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(f"{buf[0]}" for buf in bufs)+"):\n"
|
|
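  # each tl.arange line gets an index suffix like "[:, None]" / "[None, :]" so the local dims broadcast against each other instead of landing on the same axis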
  for i, line in enumerate(list(filter(lambda line: "tl.arange" in line, kernel))): kernel[kernel.index(line)] += f"[{', '.join([':' if i == j else 'None' for j in range(len(local_size))])}]"
  prg += "\n".join(kernel)

  acc_local_size = 1
  for x in local_size: acc_local_size *= next_power_of_2(x)
  local_size = [acc_local_size] + [1] * (len(local_size) - 1)

  if DEBUG >= 4: print(prg)
  getlines = linecache.getlines
  linecache.getlines = lambda filename, module_globals=None: prg.splitlines(keepends=True) if "<triton>" == filename else getlines(filename, module_globals)
  exec(compile(prg, "<triton>", "exec"), globals()) # pylint: disable=W0122
  compiled = triton_compile(globals()[function_name], signature=",".join(signatures), device_type="cuda", debug=False, cc=(35 if getenv("CUDACPU", 0) else None))
  prg = remove_single_scalar_curly_braces(compiled.asm["ptx"].split(".file")[0].split(".visible .func")[0])
  max_local_size = [int(x) for x in prg.split(".maxntid ")[1].split("\n")[0].split(", ")]
  for i in range(len(local_size)): local_size[i] = min(local_size[i], max_local_size[i])

return prg, {"shared":compiled.metadata["shared"], "local_size":local_size + [1]*(3-len(local_size))}
|
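
# Rough usage sketch (illustrative only, not part of this module): the Triton backend hands the
# linearized uop list to uops_to_triton and gets back PTX text plus launch metadata, roughly:
#   ptx, info = uops_to_triton("r_256_16", lin.uops)  # "lin" being a hypothetical Linearizer
#   # info == {"shared": <bytes of shared memory>, "local_size": [x, y, z]}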