Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-07 03:00:26 -04:00.
move renderer into options (#4514)
* move renderer into options * fix tests * renderers are functions
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, cast, Literal, Callable
|
||||
import math, functools
|
||||
import math
|
||||
from collections import defaultdict, Counter
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
@@ -179,7 +179,7 @@ class ClangLanguage(CStyleLanguage):
|
||||
buffer_suffix = " restrict"
|
||||
type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
|
||||
code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"(({a}>{b})?{a}:{b})"}
|
||||
ClangRenderer = functools.partial(uops_to_cstyle, ClangLanguage())
|
||||
def ClangRenderer(name:str, uops:UOpGraph) -> str:
  # Render the given UOp graph as C source for the Clang backend by delegating
  # to the shared C-style renderer with a fresh ClangLanguage dialect.
  return uops_to_cstyle(ClangLanguage(), name, uops)
|
||||
|
||||
class OpenCLLanguage(CStyleLanguage):
|
||||
kernel_prefix = "__kernel "
|
||||
@@ -197,7 +197,7 @@ class OpenCLLanguage(CStyleLanguage):
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
|
||||
if any(uop.dtype == dtypes.half for uop in uops): prefix = ["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"]
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
|
||||
OpenCLRenderer = functools.partial(uops_to_cstyle, OpenCLLanguage())
|
||||
def OpenCLRenderer(name:str, uops:UOpGraph) -> str:
  # Render the given UOp graph as OpenCL kernel source by delegating to the
  # shared C-style renderer with a fresh OpenCLLanguage dialect.
  return uops_to_cstyle(OpenCLLanguage(), name, uops)
|
||||
|
||||
class MetalLanguage(CStyleLanguage):
|
||||
kernel_prefix = "kernel "
|
||||
@@ -227,7 +227,7 @@ class MetalLanguage(CStyleLanguage):
|
||||
b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
|
||||
return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
|
||||
MetalRenderer = functools.partial(uops_to_cstyle, MetalLanguage())
|
||||
def MetalRenderer(name:str, uops:UOpGraph) -> str:
  # Render the given UOp graph as Metal shading-language source by delegating
  # to the shared C-style renderer with a fresh MetalLanguage dialect.
  return uops_to_cstyle(MetalLanguage(), name, uops)
|
||||
|
||||
code_for_op_half = {BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype in (dtypes.half, dtypes.bfloat16) else f"max({a},{b})",
|
||||
UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
|
||||
@@ -271,7 +271,7 @@ asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}
|
||||
return c;}}""")
|
||||
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
|
||||
CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage())
|
||||
def CUDARenderer(name:str, uops:UOpGraph) -> str:
  # Render the given UOp graph as CUDA source by delegating to the shared
  # C-style renderer with a fresh CUDALanguage dialect.
  return uops_to_cstyle(CUDALanguage(), name, uops)
|
||||
|
||||
code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
@@ -358,4 +358,4 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
|
||||
# NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
|
||||
return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
|
||||
|
||||
HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage())
|
||||
def HIPRenderer(name:str, uops:UOpGraph) -> str:
  # Render the given UOp graph as HIP source by delegating to the shared
  # C-style renderer with a fresh HIPLanguage dialect.
  return uops_to_cstyle(HIPLanguage(), name, uops)
|
||||
|
||||
Reference in New Issue
Block a user