HIP cleanups (#2843)
* move everything to code_for_op to reason about it
* loop the loopable parts
* it's not that unreadable
* these are loopable too
* nitpick
* tests p1: replace these with the actual compiler running ALU ops tests
* tests p2: compile test_dtype_alu in HIP! + add to CI
* nobody liked test_renderer
* revert test_dtypes change
* isolated mockhip tests
* don't need the WHERE hack after #2782 + ruff
* bf16 is broken in HIP; job failed in: https://github.com/tinygrad/tinygrad/actions/runs/7232101987/job/19705951290?pr=2778#step:8:73
* picking this back up
* add compile tests for unary ops and binary ops
* MOD is only in ints
* CMPLT won't work after the dtypes PR is merged because it will always be bool
* test all combinations
* Update cstyle.py
* don't use vload
* no getenv
* set seed

---------

Co-authored-by: qazal <qazal.software@gmail.com>
Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
test/external/fuzz_shapetracker_math.py
@@ -38,6 +38,7 @@ def fuzz_invert():
   return start, st_sum

 if __name__ == "__main__":
+  random.seed(42)
   total = getenv("CNT", 100)
   for fuzz in [globals()[f'fuzz_{x}'] for x in getenv("FUZZ", "invert,plus").split(",")]:
     good = 0
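The fuzzer's workload is driven entirely by environment variables, so the newly seeded run is reproducible from the command line, e.g. CNT=200 FUZZ=invert python test/external/fuzz_shapetracker_math.py (the CNT and FUZZ values here are illustrative, not from the commit).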
@@ -1,9 +1,18 @@
 #!/usr/bin/env python
 import unittest
+import operator
 from tinygrad import Tensor, Device, dtypes
+from tinygrad.helpers import DEBUG, to_function_name
+from tinygrad.codegen.linearizer import Linearizer
+from tinygrad.renderer.cstyle import HIPRenderer
 from examples.beautiful_mnist import Model as MNIST
 from examples.hlb_cifar10 import SpeedyResNet
+
+from hypothesis import given, strategies as st, settings
+settings.register_profile("my_profile", deadline=None)
+settings.load_profile("my_profile")
+print(settings.default)

 @unittest.skipIf(Device.DEFAULT != "HIP", reason="testing HIP->rdna3 compilation needs HIP=1")
 class TestHIPCompilationRDNA(unittest.TestCase):
   def test_compile_hip_mnist(self):
@@ -34,5 +43,34 @@ class TestHIPCompilationRDNA(unittest.TestCase):

     dtypes.default_float = old_default_float

+def compile_ast_to_hip(out: Tensor):
+  from tinygrad.runtime.ops_hip import compile_hip
+
+  lin = Linearizer(out.lazydata.schedule()[-1].ast)
+  lin.hand_coded_optimizations()
+  lin.linearize()
+  code = HIPRenderer(to_function_name(lin.name), lin.uops)[0]
+  if DEBUG >= 4: print(code)
+  compile_hip(code)
+
+binary_operations = [operator.add, operator.sub, operator.mul]
+unary_operations = [Tensor.exp, Tensor.log, operator.neg, Tensor.sin, Tensor.sqrt, Tensor.reciprocal]
+float_dtypes = [dtypes.float16, dtypes.float32]
+
+@unittest.skipIf(Device.DEFAULT != "HIP", reason="testing HIP->rdna3 compilation needs HIP=1")
+class TestHIPALUCompilation(unittest.TestCase):
+  @given(st.sampled_from(unary_operations), st.sampled_from(float_dtypes))
+  def test_unary_ops(self, op, dtype):
+    a = Tensor.randn(4,4, dtype=dtype)
+    out = op(a)
+    compile_ast_to_hip(out)
+
+  @given(st.sampled_from(binary_operations), st.sampled_from(float_dtypes))
+  def test_binary_ops(self, op, dtype):
+    a = Tensor.randn(4,4, dtype=dtype)
+    b = Tensor.randn(4,4, dtype=dtype)
+    out = op(a,b)
+    compile_ast_to_hip(out)
+
 if __name__ == "__main__":
   unittest.main()
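The hypothesis decorators above sample (op, dtype) pairs and the profile disables the deadline so slow HIP compiles don't flake. A minimal sketch of the equivalent exhaustive loop, assuming HIP=1 and that the compile_ast_to_hip helper from this diff is importable, would be:

import itertools, operator
from tinygrad import Tensor, dtypes

# hypothetical driver, not part of the commit: tries every (op, dtype) combination
for op, dtype in itertools.product([operator.add, operator.sub, operator.mul],
                                   [dtypes.float16, dtypes.float32]):
  a, b = Tensor.randn(4, 4, dtype=dtype), Tensor.randn(4, 4, dtype=dtype)
  compile_ast_to_hip(op(a, b))  # raises if the rendered kernel fails to compile

Run the real tests with HIP=1 python -m pytest; DEBUG=4 additionally prints each rendered kernel, per the helper above.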
tinygrad/renderer/cstyle.py
@@ -243,6 +243,14 @@ class MetalLanguage(CStyleLanguage):
     return f"as_type<{var_dtype.name}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
 MetalRenderer = functools.partial(uops_to_cstyle, MetalLanguage())

+code_for_op_half = {
+  BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})" if dtype != dtypes.half else f"__hmax({a},{b})",
+  UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})" if dtype != dtypes.half else f"hsqrt({x})",
+  UnaryOps.SIN: lambda x,dtype: f"sin({x})" if dtype != dtypes.half else f"hsin({x})",
+  UnaryOps.LOG2: lambda x,dtype: f"log2({x})" if dtype != dtypes.half else f"hlog2({x})",
+  UnaryOps.EXP2: lambda x,dtype: f"exp2({x})" if dtype != dtypes.half else f"hexp2({x})",
+}
+
 class CUDALanguage(CStyleLanguage):
   kernel_prefix = "#define INFINITY (__int_as_float(0x7f800000))\n#define NAN (__int_as_float(0x7fffffff))\nextern \"C\" __global__ "
   smem_prefix = "__shared__ "
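Each code_for_op_half entry is a renderer lambda: it takes the already-rendered operand string(s) plus the dtype and returns C source, picking the half intrinsic only when the dtype is half. A quick sketch of invoking them directly, assuming code_for_op_half is imported from tinygrad.renderer.cstyle after this commit:

from tinygrad import dtypes
from tinygrad.ops import UnaryOps, BinaryOps
from tinygrad.renderer.cstyle import code_for_op_half  # added by this commit

print(code_for_op_half[UnaryOps.SQRT]("x0", dtypes.half))      # -> hsqrt(x0)
print(code_for_op_half[UnaryOps.SQRT]("x0", dtypes.float32))   # -> sqrt(x0)
print(code_for_op_half[BinaryOps.MAX]("a", "b", dtypes.half))  # -> __hmax(a,b)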
@@ -252,8 +260,7 @@ class CUDALanguage(CStyleLanguage):
   gid = [f'blockIdx.{chr(120+i)}' for i in range(3)]
   lid = [f'threadIdx.{chr(120+i)}' for i in range(3)]
   xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)]
-  code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})" if dtype != dtypes.half else f"__hmax({a},{b})",
-    UnaryOps.EXP2: lambda x,dtype: f"exp2({x})" if dtype != dtypes.half else f"hexp2({x})"}
+  code_for_op = {**CStyleLanguage().code_for_op, **code_for_op_half}
   half_prekernel = """
 #include <cuda_fp16.h>
 struct half4 { half x, y, z, w; };
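The merge relies on Python dict unpacking, where later keys win: {**CStyleLanguage().code_for_op, **code_for_op_half} keeps every base renderer and overrides only the five ops that have half-aware variants. A toy illustration of that semantics:

base = {"a": 1, "b": 2}
override = {"b": 99}
merged = {**base, **override}
assert merged == {"a": 1, "b": 99}  # the later unpack wins on key collisions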
@@ -263,12 +270,6 @@ CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage())

 class HIPLanguage(CStyleLanguage):
   kernel_prefix = "#include <hip/hip_common.h>\n#define INFINITY (__builtin_inff())\n#define NAN (__builtin_nanf(\"\"))" + """
-__device__ float4 max(float4 x, float4 y) { return float4(max(x.x, y.x), max(x.y, y.y), max(x.z, y.z), max(x.w, y.w)); }
-__device__ float4 pow(float x, float4 y) { return float4(pow(x, y.x), pow(x, y.y), pow(x, y.z), pow(x, y.w)); }
-__device__ float4 pow(float4 x, float4 y) { return float4(pow(x.x, y.x), pow(x.y, y.y), pow(x.z, y.z), pow(x.w, y.w)); }
-__device__ float4 log2(float4 x) { return float4(log2(x.x), log2(x.y), log2(x.z), log2(x.w)); }
-__device__ float4 exp2(float4 x) { return float4(exp2(x.x), exp2(x.y), exp2(x.z), exp2(x.w)); }
-__device__ float4 sin(float4 x) { return float4(sin(x.x), sin(x.y), sin(x.z), sin(x.w)); }
 typedef float float8 __attribute__((ext_vector_type(8)));
 __device__ float8 make_float8(float x, float y, float z, float w, float a, float b, float c, float d) { return {x, y, z, w, a, b, c, d}; }
 extern "C" __global__
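These per-component float4 wrappers are exactly the "loopable parts" from the commit message: six near-identical lines that only vary in the wrapped function name. A hedged sketch of how such boilerplate can be generated with a loop instead of hand-written (this generator is illustrative, not the code the commit landed):

# hypothetical generator for component-wise float4 wrappers
def float4_unary(name: str) -> str:
  args = ", ".join(f"{name}(x.{c})" for c in "xyzw")
  return f"__device__ float4 {name}(float4 x) {{ return float4({args}); }}"

prekernel = "\n".join(float4_unary(fn) for fn in ["log2", "exp2", "sin"])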
@@ -278,7 +279,6 @@ class HIPLanguage(CStyleLanguage):
   smem_prefix_for_cast=False
   barrier = "__syncthreads();"
   float4 = "make_float4"
-  uses_vload=True
   uses_ptr_arithmetic=True
   half_prekernel = "#include <hip/hip_fp16.h>\n" + """
 typedef union { struct { half x, y, z, w; } __attribute__((aligned(8))); half data[4]; } half4;
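Dropping uses_vload means the renderer emits plain pointer loads and stores for half buffers (uses_ptr_arithmetic is kept), so the vload_half*/vstore_half* shims in the prekernel become dead code; the next hunk deletes them.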
@@ -289,25 +289,11 @@ __device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half
 __device__ half16 make_half16(half x, half y, half z, half w, half a, half b, half c, half d,
                               half e, half f, half g, half h, half i, half j, half k, half l) {
   return {x, y, z, w, a, b, c, d, e, f, g, h, i, j, k, l}; }
-__device__ float vload_half(size_t offset, const half *p) { return (float)*(p + offset); }
-__device__ float2 vload_half2(size_t offset, const half *p) { return make_float2((float)*(p + offset*2), (float)*(p + offset*2 + 1)); }
-__device__ float4 vload_half4(size_t offset, const half *p) {
-  return make_float4((float)*(p + offset*4), (float)*(p + offset*4 + 1), (float)*(p + offset*4 + 2), (float)*(p + offset*4 + 3)); }
-__device__ void vstore_half(float data, size_t offset, half *p) { *(p + offset) = (half)data; }
-__device__ void vstore_half2(float2 data, size_t offset, half *p) { *(p + offset*2) = (half)data.x; *(p + offset*2 + 1) = (half)data.y; }
-__device__ void vstore_half4(float4 data, size_t offset, half *p) {
-  *(p + offset*4) = (half)data.x; *(p + offset*4 + 1) = (half)data.y; *(p + offset*4 + 2) = (half)data.z; *(p + offset*4 + 3) = (half)data.w; }
-__device__ half exp2(half x) { return hexp2(x); }
-__device__ half log2(half x) { return hlog2(x); }
-__device__ half sin(half x) { return hsin(x); }
-__device__ half sqrt(half x) { return hsqrt(x); }
-__device__ half hmax(half a, half b) { return __hgt(a, b) ? a : b; }
-__device__ half operator%(const half &a, const half &b) { return __hsub(a, __hmul(b, __float2half(floorf(__half2float(a) / __half2float(b))))); }
 """
   gid = [f'blockIdx.{chr(120+i)}' for i in range(3)]
   lid = [f'threadIdx.{chr(120+i)}' for i in range(3)]
   xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)]
-  code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})" if dtype != dtypes.half else f"hmax({a},{b})",}
+  code_for_op = {**CStyleLanguage().code_for_op, **code_for_op_half}
 HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage())

 # TODO: how much of this can be merged with above?