diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index 05839f173b..59df42d196 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -56,6 +56,10 @@ jobs:
       run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=97.5 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
     - name: Run 10 CIFAR training steps
       run: STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
+    - name: Run 10 CIFAR training steps w HALF
+      run: STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
+    - name: Run 10 CIFAR training steps w BF16
+      run: STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
     # TODO: this is flaky too
     # - name: Run 10 CIFAR training steps w winograd
     #   run: WINO=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
@@ -65,9 +69,6 @@ jobs:
         path: |
           onnx_inference_speed.csv
           torch_speed.txt
-          beautiful_mnist.txt
-          train_cifar.txt
-          train_cifar_wino.txt
           llama_unjitted.txt
           llama_jitted.txt
           llama_beam.txt
@@ -78,6 +79,11 @@ jobs:
           matmul.txt
           matmul_half.txt
           sd.txt
+          beautiful_mnist.txt
+          train_cifar.txt
+          train_cifar_half.txt
+          train_cifar_bf16.txt
+          train_cifar_wino.txt
 
   testnvidiabenchmark:
     name: NVIDIA Benchmark
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 0642e6a08d..0d209aff74 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -204,6 +204,14 @@ class MetalLanguage(CStyleLanguage):
   uses_ptr_arithmetic = True
   code_for_workitem = {"g": lambda x: f"gid.{chr(120+x)}", "l": lambda x: f"lid.{chr(120+x)}"}
   extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
+  type_map = {dtypes.bfloat16: "bfloat"}
+  code_for_op = {**CStyleLanguage().code_for_op,
+    BinaryOps.MAX: lambda a,b,dtype: f"(bfloat)max((float){a},(float){b})" if dtype == dtypes.bfloat16 else f"max({a},{b})",
+    UnaryOps.SQRT: lambda x,dtype: f"(bfloat)sqrt({x})" if dtype == dtypes.bfloat16 else f"sqrt({x})",
+    UnaryOps.EXP2: lambda x,dtype: f"(bfloat)exp2({x})" if dtype == dtypes.bfloat16 else f"exp2({x})",
+    UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
+    UnaryOps.SIN: lambda x,dtype: f"(bfloat)sin({x})" if dtype == dtypes.bfloat16 else f"sin({x})",}
+
   def render_cast(self, x: List[str], var_dtype: DType, bitcast=False) -> str:
     return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
 
@@ -242,20 +250,18 @@ class CUDALanguage(CStyleLanguage):
       """__device__ float4 __cuda_mma_m16n8k16_f16_f32(half8 a, half4 b, float4 c) { int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
 asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
 : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
-return c;}""",
-    ]
+return c;}""",]
+    if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("#include <cuda_bf16.h>")
     return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
 CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage())
 
-code_for_op_hip = {
-  # TODO: MAX with int uses fmax_f32?
-  BinaryOps.MAX: lambda a,b,dtype: f"__ocml_fmax_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32) }({a},{b})",
-  UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
-  UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
-  UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
-  UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
-}
+code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+                    UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+                    UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+                    UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+                    # TODO: MAX with int uses fmax_f32?
+                    BinaryOps.MAX: lambda a,b,dtype: f"__ocml_fmax_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32) }({a},{b})",}
 
 def _make_hip_code_for_op():
   def wrapper(key, func):