HIP cleanups (#2843)

* move everything to code_for_op so it's easier to reason about

* loop the loopable parts

* it's not that unreadable

* these are loopable too

* nitpick

* tests p1: replace these with tests that run ALU ops through the actual compiler

* tests p2: compile test_dtype_alu in HIP!

+add to CI

* nobody liked test_renderer

* revert test_dtypes change

* isolated mockhip tests

* don't need the WHERE hack after #2782

+ruff

* bf16 is broken in HIP

job failed in: https://github.com/tinygrad/tinygrad/actions/runs/7232101987/job/19705951290?pr=2778#step:8:73

* picking this back up

* add compile tests for unary ops and binary ops

* MOD is only valid for ints

* CMPLT won't work after the dtypes PR is merged because its output will always be bool

* test all combinations

* Update cstyle.py

* don't use vload

* no getenv

* set seed

---------

Co-authored-by: qazal <qazal.software@gmail.com>
Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
commit 07df14aa0e (parent b6d71b131e)
Author: George Hotz
Date: 2023-12-18 21:09:32 -08:00
Committed by: GitHub
3 changed files with 49 additions and 24 deletions

@@ -38,6 +38,7 @@ def fuzz_invert():
return start, st_sum
if __name__ == "__main__":
+random.seed(42)
total = getenv("CNT", 100)
for fuzz in [globals()[f'fuzz_{x}'] for x in getenv("FUZZ", "invert,plus").split(",")]:
good = 0
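
Aside: pinning the RNG is what makes a CI fuzz failure reproducible locally. A minimal standalone illustration of the idea (illustrative code, not from the repo):

import random

def draws(seed=42, n=5):
  # the same seed yields the same "random" sequence, so a failing fuzz case can be replayed
  random.seed(seed)
  return [random.randint(0, 100) for _ in range(n)]

assert draws() == draws()  # deterministic across runs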

@@ -1,9 +1,18 @@
#!/usr/bin/env python
import unittest
import operator
from tinygrad import Tensor, Device, dtypes
from tinygrad.helpers import DEBUG, to_function_name
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.renderer.cstyle import HIPRenderer
from examples.beautiful_mnist import Model as MNIST
from examples.hlb_cifar10 import SpeedyResNet
from hypothesis import given, strategies as st, settings
settings.register_profile("my_profile", deadline=None)
settings.load_profile("my_profile")
print(settings.default)
@unittest.skipIf(Device.DEFAULT != "HIP", reason="testing HIP->rdna3 compilation needs HIP=1")
class TestHIPCompilationRDNA(unittest.TestCase):
def test_compile_hip_mnist(self):
@@ -34,5 +43,34 @@ class TestHIPCompilationRDNA(unittest.TestCase):
dtypes.default_float = old_default_float
def compile_ast_to_hip(out: Tensor):
from tinygrad.runtime.ops_hip import compile_hip
lin = Linearizer(out.lazydata.schedule()[-1].ast)
lin.hand_coded_optimizations()
lin.linearize()
code = HIPRenderer(to_function_name(lin.name), lin.uops)[0]
if DEBUG >= 4: print(code)
compile_hip(code)
binary_operations = [operator.add, operator.sub, operator.mul]
unary_operations = [Tensor.exp, Tensor.log, operator.neg, Tensor.sin, Tensor.sqrt, Tensor.reciprocal]
float_dtypes = [dtypes.float16, dtypes.float32]
@unittest.skipIf(Device.DEFAULT != "HIP", reason="testing HIP->rdna3 compilation needs HIP=1")
class TestHIPALUCompilation(unittest.TestCase):
@given(st.sampled_from(unary_operations), st.sampled_from(float_dtypes))
def test_unary_ops(self, op, dtype):
a = Tensor.randn(4,4, dtype=dtype)
out = op(a)
compile_ast_to_hip(out)
@given(st.sampled_from(binary_operations), st.sampled_from(float_dtypes))
def test_binary_ops(self, op, dtype):
a = Tensor.randn(4,4, dtype=dtype)
b = Tensor.randn(4,4, dtype=dtype)
out = op(a,b)
compile_ast_to_hip(out)
if __name__ == "__main__":
unittest.main()
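
Design note: sampling each argument with st.sampled_from over a small finite list is how hypothesis ends up exercising every op x dtype pair ("test all combinations" above); with the default 100 examples it hits all of them in practice. A self-contained sketch of the pattern, with stand-in names (not repo code):

import unittest
from hypothesis import given, strategies as st

OPS = ["add", "sub", "mul"]      # stand-ins for the sampled operator list
DTYPES = ["float16", "float32"]  # stand-ins for float_dtypes

class TestCrossProduct(unittest.TestCase):
  @given(st.sampled_from(OPS), st.sampled_from(DTYPES))
  def test_pair(self, op, dtype):
    # hypothesis draws (op, dtype) pairs from the cross product of both lists
    self.assertIn(op, OPS)
    self.assertIn(dtype, DTYPES)

if __name__ == "__main__":
  unittest.main()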

--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py

@@ -243,6 +243,14 @@ class MetalLanguage(CStyleLanguage):
return f"as_type<{var_dtype.name}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
MetalRenderer = functools.partial(uops_to_cstyle, MetalLanguage())
+code_for_op_half = {
+  BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})" if dtype != dtypes.half else f"__hmax({a},{b})",
+  UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})" if dtype != dtypes.half else f"hsqrt({x})",
+  UnaryOps.SIN: lambda x,dtype: f"sin({x})" if dtype != dtypes.half else f"hsin({x})",
+  UnaryOps.LOG2: lambda x,dtype: f"log2({x})" if dtype != dtypes.half else f"hlog2({x})",
+  UnaryOps.EXP2: lambda x,dtype: f"exp2({x})" if dtype != dtypes.half else f"hexp2({x})",
+}
class CUDALanguage(CStyleLanguage):
kernel_prefix = "#define INFINITY (__int_as_float(0x7f800000))\n#define NAN (__int_as_float(0x7fffffff))\nextern \"C\" __global__ "
smem_prefix = "__shared__ "
@@ -252,8 +260,7 @@ class CUDALanguage(CStyleLanguage):
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)]
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)]
xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)]
-code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})" if dtype != dtypes.half else f"__hmax({a},{b})",
-  UnaryOps.EXP2: lambda x,dtype: f"exp2({x})" if dtype != dtypes.half else f"hexp2({x})"}
+code_for_op = {**CStyleLanguage().code_for_op, **code_for_op_half}
half_prekernel = """
#include <cuda_fp16.h>
struct half4 { half x, y, z, w; };
@@ -263,12 +270,6 @@ CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage())
class HIPLanguage(CStyleLanguage):
kernel_prefix = "#include <hip/hip_common.h>\n#define INFINITY (__builtin_inff())\n#define NAN (__builtin_nanf(\"\"))" + """
-__device__ float4 max(float4 x, float4 y) { return float4(max(x.x, y.x), max(x.y, y.y), max(x.z, y.z), max(x.w, y.w)); }
-__device__ float4 pow(float x, float4 y) { return float4(pow(x, y.x), pow(x, y.y), pow(x, y.z), pow(x, y.w)); }
-__device__ float4 pow(float4 x, float4 y) { return float4(pow(x.x, y.x), pow(x.y, y.y), pow(x.z, y.z), pow(x.w, y.w)); }
-__device__ float4 log2(float4 x) { return float4(log2(x.x), log2(x.y), log2(x.z), log2(x.w)); }
-__device__ float4 exp2(float4 x) { return float4(exp2(x.x), exp2(x.y), exp2(x.z), exp2(x.w)); }
-__device__ float4 sin(float4 x) { return float4(sin(x.x), sin(x.y), sin(x.z), sin(x.w)); }
typedef float float8 __attribute__((ext_vector_type(8)));
__device__ float8 make_float8(float x, float y, float z, float w, float a, float b, float c, float d) { return {x, y, z, w, a, b, c, d}; }
extern "C" __global__
@@ -278,7 +279,6 @@ class HIPLanguage(CStyleLanguage):
smem_prefix_for_cast=False
barrier = "__syncthreads();"
float4 = "make_float4"
-uses_vload=True
uses_ptr_arithmetic=True
half_prekernel = "#include <hip/hip_fp16.h>\n" + """
typedef union { struct { half x, y, z, w; } __attribute__((aligned(8))); half data[4]; } half4;
@@ -289,25 +289,11 @@ __device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half
__device__ half16 make_half16(half x, half y, half z, half w, half a, half b, half c, half d,
half e, half f, half g, half h, half i, half j, half k, half l) {
return {x, y, z, w, a, b, c, d, e, f, g, h, i, j, k, l}; }
-__device__ float vload_half(size_t offset, const half *p) { return (float)*(p + offset); }
-__device__ float2 vload_half2(size_t offset, const half *p) { return make_float2((float)*(p + offset*2), (float)*(p + offset*2 + 1)); }
-__device__ float4 vload_half4(size_t offset, const half *p) {
-  return make_float4((float)*(p + offset*4), (float)*(p + offset*4 + 1), (float)*(p + offset*4 + 2), (float)*(p + offset*4 + 3)); }
-__device__ void vstore_half(float data, size_t offset, half *p) { *(p + offset) = (half)data; }
-__device__ void vstore_half2(float2 data, size_t offset, half *p) { *(p + offset*2) = (half)data.x; *(p + offset*2 + 1) = (half)data.y; }
-__device__ void vstore_half4(float4 data, size_t offset, half *p) {
-  *(p + offset*4) = (half)data.x; *(p + offset*4 + 1) = (half)data.y; *(p + offset*4 + 2) = (half)data.z; *(p + offset*4 + 3) = (half)data.w; }
-__device__ half exp2(half x) { return hexp2(x); }
-__device__ half log2(half x) { return hlog2(x); }
-__device__ half sin(half x) { return hsin(x); }
-__device__ half sqrt(half x) { return hsqrt(x); }
-__device__ half hmax(half a, half b) { return __hgt(a, b) ? a : b; }
-__device__ half operator%(const half &a, const half &b) { return __hsub(a, __hmul(b, __float2half(floorf(__half2float(a) / __half2float(b))))); }
"""
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)]
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)]
xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)]
-code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})" if dtype != dtypes.half else f"hmax({a},{b})",}
+code_for_op = {**CStyleLanguage().code_for_op, **code_for_op_half}
HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage())
# TODO: how much of this can be merged with above?
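
Aside: the consolidation relies on Python dict-merge precedence: in {**a, **b}, keys from b override keys from a, so the half-aware entries shadow the base lambdas for any op they both define. A standalone sketch of that dispatch (plain strings stand in for tinygrad dtypes; not the real renderer):

# base table: dtype-agnostic rendering
base_code_for_op = {"SQRT": lambda x,dtype: f"sqrt({x})"}

# shared half-precision overrides, mirroring code_for_op_half above
half_overrides = {
  "SQRT": lambda x,dtype: f"hsqrt({x})" if dtype == "half" else f"sqrt({x})",
  "EXP2": lambda x,dtype: f"hexp2({x})" if dtype == "half" else f"exp2({x})",
}

# the later unpack wins, so the half-aware SQRT replaces the base one
code_for_op = {**base_code_for_op, **half_overrides}

print(code_for_op["SQRT"]("v0", "float32"))  # -> sqrt(v0)
print(code_for_op["SQRT"]("v0", "half"))     # -> hsqrt(v0)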