tinygrad/extra/triton/triton.py
qazal 4380ccb169 Non fp32 math (#2264)
* `global_load` and `global_store` using buffer dtype

* `UOps.PHI` in all dtypes

* `UOps.ALU` in all dtypes

* `UOps.CONST` & `UOps.DEFINE_ACC` in all dtypes

* -- endof implementation --
+tiny lint changes

* these tests require the fp16 extension

you can run them locally to confirm they're green (the GPT2 test is broken in master for mac, see [this](https://discord.com/channels/1068976834382925865/1069001075828469790/1177993277958533261)):

`GPU=1 python3 -m pytest test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_dequantizelinear_e4m3fn_float16_cpu test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_max_float16_cpu test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_min_float16_cpu test/models/test_real_world.py::TestRealWorld::test_llama test/models/test_real_world.py::TestRealWorld::test_gpt2 test/models/test_whisper.py test/test_specific_conv.py::TestSpecific::test_big_vec_mul`

skip the new test_linearizer_failures in CI GPU because of the fp16 extension

This passes on a real GPU since the extension is available:
`GPU=1 python3 -m pytest test/test_linearizer_failures.py::TestLinearizerFailures::test_failure_8`

see CI logs [here](https://github.com/tinygrad/tinygrad/actions/runs/6996590597/job/19032641427#step:14:644)

* these tests fail in CI due to segfaults and CPU crashes

To confirm they're green locally, you can run the following commands:

1. For the tests skipped in test_ops.py (note: CLANG is very slow)

`for var in GPU CUDA CLANG; do export $var=1; for test in test/test_ops.py::TestOps::test_slice_fancy_indexing_no_dim_collapse test/test_ops.py::TestOps::test_slice_fancy_indexing_dim_collapse_int test/test_ops.py::TestOps::test_slice_fancy_indexing_dim_inject_none test/test_ops.py::TestOps::test_slice_fancy_indexing_dim_inject_and_collapse; do python3 -m pytest $test; done; unset $var; done`

2. For the ONNX tests skipped in CLANG:

```
CLANG=1 python3 -m pytest test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_ai_onnx_ml_array_feature_extractor_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_gather_elements_0_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_mean_weight_ii_3d_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_gather_elements_1_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_NCd1_mean_weight_negative_ii_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_weight_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_ii_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_mean_weight_ii_4d_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_mean_weight_ii_3d_log_prob_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_gather_elements_negative_indices_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_NCd1d2d3d4d5_mean_weight_log_prob_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_NCd1_mean_weight_negative_ii_log_prob_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_NCd1d2d3d4d5_mean_weight_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3d4d5_mean_weight_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_mean_weight_negative_ii_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_mean_weight_ii_4d_log_prob_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_reduction_mean_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_weight_ii_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_reduction_sum_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_reduction_sum_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_reduction_mean_expanded_cpu \
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_expanded_cpu
```

3. The LLVM test I skipped here is already [skipped in master for all backends](https://github.com/tinygrad/tinygrad/blob/master/test/external/external_test_onnx_backend.py#L186), I just made it more specific

`LLVM=1 python3 -m pytest test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_dequantizelinear_e4m3fn_float16_cpu`

* Revert "these tests fail in CI due to segfaults and CPU crashes"

This reverts commit 15db570143.

* merge with cleanup-vectorized-hip-renders

* barely working HIP P1, ALU ops need a refactor?

* manage the fact that in HIP [half2 is actually an unsigned int vec](f921880387/hip/include/hip/amd_detail/amd_hip_fp16.h (L59)) and half is a totally different __half that [has an unsigned int element in it](f921880387/hip/include/hip/amd_detail/amd_hip_fp16.h (L50)) but can't be accessed [because it's private](f921880387/hip/include/hip/amd_detail/amd_hip_fp16.h (L86)). If you just do this:

```
half2 val0 = // ...
half val1 = // ...
```
then you can't do:
```
val0.x + val1 // error: use of overloaded operator '+' is ambiguous (with operand types 'unsigned short' and 'half' (aka '__half'))
```

* update the sign definition to avoid division by zero in all dtypes

* diff cleanup p1: why were these in the diff anyways

* less hacky HIP, enable CIFAR fp16 benchmark, test ops for HIP in CI!

add ALU ops overloads for HIP

this will make HIP max work

handle mod

Revert "handle mod"

This reverts commit 370fd4b3fbe99b6ae8cc293d005b106628205933.

update max to use hmax

add HIP GEP render logic

enable CIFAR fp16 benchmark

test ops for HIP

back to store as float because this only works for float4 grouping right now

test_ops for hip!!

always sign

* back to the sign we had before because we can't do a backward pass on a Less node

* remove old hacks

HIP compiling test_ops in CI takes ~9 mins, not doing it for now

new HIP ALUs

* reduce accs done right

* refactor to function

* no device hacks

hacks p2

the other way

* LLVM ALU ops

half, float and double are all float

update max

* update test_uops, cmplt is always a bool in the real linearizer. assertAlmostEqual is wrong when ret is bool

* cleanup LLVM wrong code

* dummy change for the CUDA install glitch

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
2023-12-03 13:45:49 -08:00

132 lines
7.3 KiB
Python

from typing import Dict, List, Final, Callable, DefaultDict
from collections import defaultdict
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Op
from tinygrad.helpers import DType, dtypes, ImageDType, DEBUG, getenv
from tinygrad.codegen.linearizer import UOp, UOps
from triton.compiler import compile as triton_compile
import linecache
import math
import re
triton_dtypes = {dtypes.double: "tl.float64", dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.bool: "tl.int1", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64", dtypes.uint32: "tl.uint32", dtypes.uint64: "tl.uint64", dtypes.int16: "tl.int16", dtypes.uint16: "tl.uint16"}
signature_dtypes = {dtypes.double: "*fp64",dtypes.float32: "*fp32", dtypes.float16: "*fp16", dtypes.bool: "*i8", dtypes.int8: "*i1", dtypes.uint8: "*u8", dtypes._arg_int32: "i32", dtypes.int32: "*i32", dtypes.int64: "*i64", dtypes.uint32: "*u32", dtypes.uint64: "*u64", dtypes.int16: "*i16", dtypes.uint16: "*u16"}

def next_power_of_2(x):
  return 1 << (x - 1).bit_length()

def render_valid(valid):
  return '(' * (len(valid) -1) + ') and '.join(valid) if len(valid) else 'True'

#NOTE Triton requires matching dimensions for load/store, disable this and see TestOps::test_output_padded_conv_transpose2d fail to compile
def fill_dims_for_idx(idx, dims):
  return "(" + idx + "+ (" + (f"0*({'+'.join(d for d in dims)})))") if len(dims) else idx

def get_max(var):
  if isinstance(var, int): return var
  return re.sub(r'\[(.*?)\]', '', str(var))[1:-1]

#NOTE can be removed after https://github.com/gpuocelot/gpuocelot/issues/8 gets resolved
def remove_single_scalar_curly_braces(ptx_code):
  return '\n'.join([re.sub(r'\{\s*(%\w+)\s*\}', r'\1', line) for line in ptx_code.split('\n')])

def render_const(args,dtype:DType):
  return (('-' if args<0 else '') + 'tl.where(1,float("inf"),0)') if math.isinf(args) else ('tl.where(1,float("nan"),0)' if math.isnan(args) else f"{int(args)}" if dtypes.is_int(dtype) else str(args))

def render_cast(x:str, dtype:DType):
  return f"{x}.to({triton_dtypes[dtype]})"

def define_scalar(local_size, dtype, args):
  if len(local_size) > 0: return f"tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},),{render_const(args,dtype)}, dtype={triton_dtypes[dtype]})"
  return render_const(args,dtype)

def uops_to_triton(function_name:str, uops:List[UOp]):
  local_size: List[int] = []
  depth = 1
  signatures, dims, bufs, kernel, valid = [], [], [], [], [] #type: ignore

  c: DefaultDict[str, int] = defaultdict(int)
  r: Dict[UOp, str] = {}
  def ssa(u, prefix="t"):
    nonlocal c, r
    c[prefix] += 1
    r[u]=f"{prefix}{c[prefix]-1}"
    return r[u]

  child_count: DefaultDict[UOp, int] = defaultdict(int)
  for ru in uops:
    for v in ru.vin:
      child_count[v] += 1

  def kk(s): kernel.append(" "*depth+s)
  # every renderer takes its operands followed by the result dtype
  code_for_op: Final[Dict[Op, Callable]] = {
    UnaryOps.EXP2: lambda x,dtype: f"tl.math.exp2({x})",
    UnaryOps.LOG2: lambda x,dtype: f"tl.math.log2({x})",
    UnaryOps.SIN: lambda x,dtype: f"tl.sin({x})",
    UnaryOps.SQRT: lambda x,dtype: f"tl.sqrt({x})",
    UnaryOps.NEG: lambda x,dtype: f"-{x}" if dtype != dtypes.bool else f"tl.where({x}, 0, 1)",
    BinaryOps.ADD: lambda x,y,dtype: f"({x}+{y})", BinaryOps.SUB: lambda x,y,dtype: f"({x}-{y})",
    BinaryOps.MUL: lambda x,y,dtype: f"({x}*{y})", BinaryOps.DIV: lambda x,y,dtype: f"({x}/{y})" if y != '0.0' else f"{x}*tl.where({x}==0.0, float('nan'), float('inf'))",
    BinaryOps.MAX: lambda x,y,dtype: f"tl.maximum({x},{y})",
    BinaryOps.CMPLT: lambda x,y,dtype: f"({x}<{y})",
    BinaryOps.MOD: lambda x,y,dtype: f"tl.abs({x})%tl.abs({y})*tl.where({x}<0,-1,1)",
    TernaryOps.MULACC: lambda x,y,z,dtype: f"(({x}*{y})+{z})",
    TernaryOps.WHERE: lambda x,y,z,dtype: f"tl.where({x},{y},{z})",
  }
  def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, float('nan'), float('inf'))"

  for u in uops:
    uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
    if uop == UOps.LOOP:
      kk(f"for {ssa(u, 'ridx')} in range({vin[0].arg}, {r[vin[1]]}):")
      depth += 1
    elif uop == UOps.END: depth -= 1
    elif uop == UOps.ALU:
      assert dtype is not None
      # pass dtype as the last argument: the renderers above declare it, so calling them with the operands alone raises TypeError
      val = code_for_op[args](*[r[x] for x in vin], dtype)
      if child_count[u] <=1 or dtypes.is_int(dtype): r[u] = int_div(*[r[x] for x in vin]) if args == BinaryOps.DIV and dtypes.is_int(dtype) else val
      else: kk(f"{ssa(u, 'alu')} = ({val})")
    elif uop == UOps.LOAD:
      assert dtype is not None
      if len(vin) == 2: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.load({r[vin[0]]} + { fill_dims_for_idx(r[vin[1]], dims)}, mask = {render_valid(valid)})', dtype)}")
      else: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.where({r[vin[2]]}, tl.load({r[vin[0]]}+{fill_dims_for_idx(r[vin[1]],dims)} , mask={render_valid(valid+[r[vin[2]]])}), 0.0)', dtype)}")
    elif uop == UOps.DEFINE_ACC: kk(f"{ssa(u, 'acc')} = {define_scalar(local_size, dtype, args).replace('//', '/')}")
    elif uop == UOps.CONST: r[u] = define_scalar([], dtype, args)
    elif uop == UOps.PHI:
      kk(f"{r[vin[0]]} = {r[vin[1]].replace('//', '/')}")
      r[u] = r[vin[0]]
    elif uop == UOps.STORE:
      assert not isinstance(dtype, ImageDType), "unimplemented: image store"
      kk(f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ")
    elif uop == UOps.DEFINE_GLOBAL:
      bufs.append(args)
      signatures.append(signature_dtypes[args[1]])
      r[u] = args[0]
    elif uop == UOps.SPECIAL:
      dims.append(args[1])
      valid.append(f"{args[1]}<{get_max(args[2])}")
      if args[1].startswith("g"): kk(f"{args[1]} = tl.program_id({args[0]}) # {args[2]}")
      elif args[1].startswith("l"):
        kk(f"{args[1]} = tl.arange({0}, {next_power_of_2(args[2])})")
        local_size.append(args[2])
      r[u] = args[1]
    elif uop == UOps.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype)
    else: raise NotImplementedError(f"unimplemented: {uop}")

  prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(f"{buf[0]}" for buf in bufs)+"):\n"
  for i, line in enumerate(list(filter(lambda line: "tl.arange" in line, kernel))): kernel[kernel.index(line)] += f"[{', '.join([':' if i == j else 'None' for j in range(len(local_size))])}]"
  prg += "\n".join(kernel)

  acc_local_size = 1
  for x in local_size: acc_local_size *= next_power_of_2(x)
  local_size = [acc_local_size] + [1] * (len(local_size) - 1)

  if DEBUG >= 4: print(prg)
  # serve the generated source through linecache so it can be looked up under the "<triton>" filename
  getlines = linecache.getlines
  linecache.getlines = lambda filename, module_globals=None: prg.splitlines(keepends=True) if "<triton>" == filename else getlines(filename, module_globals)
  exec(compile(prg, "<triton>", "exec"), globals()) # pylint: disable=W0122
  compiled = triton_compile(globals()[function_name], signature=",".join(signatures), device_type="cuda", debug=False, cc=(35 if getenv("CUDACPU", 0) else None))
  prg = remove_single_scalar_curly_braces(compiled.asm["ptx"].split(".file")[0].split(".visible .func")[0])
  max_local_size = [int(x) for x in prg.split(".maxntid ")[1].split("\n")[0].split(", ")]
  for i in range(len(local_size)): local_size[i] = min(local_size[i], max_local_size[i])
  return prg, {"shared":compiled.metadata["shared"], "local_size":local_size + [1]*(3-len(local_size))}
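
Two parts of the listing can be tried without Triton, tinygrad, or a GPU: the string helpers, and the `linecache` patch that makes source text available for code created with `exec`. The sketch below copies four helpers verbatim so that it runs on its own. The input strings, the `add_one` function, the `<generated>` filename, and the `namespace` dict are made up for illustration and do not come from the repository. The listing itself executes into `globals()` under the name `<triton>`; a separate dict is used here only to keep the demo contained.

```python
import inspect
import linecache
import re

# Helpers copied from the listing above so this block is self-contained.
def next_power_of_2(x):
  return 1 << (x - 1).bit_length()

def render_valid(valid):
  return '(' * (len(valid) -1) + ') and '.join(valid) if len(valid) else 'True'

def fill_dims_for_idx(idx, dims):
  return "(" + idx + "+ (" + (f"0*({'+'.join(d for d in dims)})))") if len(dims) else idx

def remove_single_scalar_curly_braces(ptx_code):
  return '\n'.join([re.sub(r'\{\s*(%\w+)\s*\}', r'\1', line) for line in ptx_code.split('\n')])

# Local sizes are rounded up to the next power of two before use in tl.arange.
print([next_power_of_2(n) for n in (1, 5, 8, 9)])        # [1, 8, 8, 16]

# Bounds checks are joined into a single mask expression.
print(render_valid([]))                                  # True
print(render_valid(["gidx0<4", "lidx1<3"]))              # (gidx0<4) and lidx1<3

# A term that is always zero but mentions every dimension is added to the index.
print(fill_dims_for_idx("gidx0*3", ["gidx0", "lidx1"]))  # (gidx0*3+ (0*(gidx0+lidx1)))

# Braces around a single register are removed; a register list is left alone.
# Both input lines are made-up PTX-like strings.
print(remove_single_scalar_curly_braces("ld.global.b32 { %r1 }, [ %rd1 + 0 ];"))
print(remove_single_scalar_curly_braces("ld.global.v2.b32 { %r1, %r2 }, [ %rd1 + 0 ];"))

# linecache patch: code compiled from a string has no file on disk, so
# inspect.getsource would normally fail. Serving the lines for one fake
# filename lets the lookup succeed.
src = "def add_one(x):\n  return x + 1\n"  # hypothetical generated source
original_getlines = linecache.getlines
linecache.getlines = lambda filename, module_globals=None: (
  src.splitlines(keepends=True) if filename == "<generated>"
  else original_getlines(filename, module_globals))

namespace = {}
exec(compile(src, "<generated>", "exec"), namespace)
recovered = inspect.getsource(namespace["add_one"])
print("return x + 1" in recovered)                       # True
```

The patch stays installed after the call, as it does in the listing, so later source lookups for the fake filename keep working; every other filename falls through to the original `linecache.getlines`.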