use ConstType in various const function type hint (#4074)

This commit is contained in:
chenyu
2024-04-04 20:32:07 -04:00
committed by GitHub
parent c1cffed1df
commit 5e6e6c9a67
4 changed files with 14 additions and 15 deletions

View File

@@ -3,7 +3,7 @@ import functools, struct
from collections import defaultdict
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
from tinygrad.dtype import dtypes, DType, PtrDType, INVERSE_DTYPES_DICT
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType, INVERSE_DTYPES_DICT
from tinygrad.codegen.uops import UOpGraph, PatternMatcher
def render_val(x, dtype):
@@ -46,7 +46,7 @@ class AssemblyLanguage(NamedTuple):
types: Dict[DType, str] = INVERSE_DTYPES_DICT
supports_half: List[Op] = []
def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]: raise NotImplementedError()
def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]: raise NotImplementedError()
def render_local(self, dest, name, size, dtype) -> List[str]: raise NotImplementedError()
def render_loop(self, idx, start, label, acc=None) -> List[str]: raise NotImplementedError()
@@ -112,7 +112,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
r_label[u] = f"{lang.label_prefix}{prefix}_{c_label[prefix]-1}"
return r_label[u]
def const(x:Union[float,int,bool], dtype, mov=False):
def const(x:ConstType, dtype:DType, mov=False):
if mov or dtype in lang.const_requires_mov:
kk(*lang.render_const(x, dtype, mov=(out:=ssa(None, 'const', lang.types[dtype]))))
return out
@@ -249,7 +249,7 @@ class PTXLanguage(AssemblyLanguage):
const_requires_mov = [dtypes.half, dtypes.bool]
def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]:
def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]:
val = render_val(x, dtype)
if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"]
return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val

View File

@@ -4,7 +4,7 @@ from collections import defaultdict, Counter
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.helpers import strip_parens, getenv, prod
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
from tinygrad.codegen.uops import UOpGraph
class CStyleLanguage(NamedTuple):
@@ -41,13 +41,12 @@ class CStyleLanguage(NamedTuple):
return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})"
# returns a str expression of the const with the given type
def render_const(self, x:Union[float,int,bool], var_dtype:DType) -> str:
def render_const(self, x:ConstType, dtype:DType) -> str:
if math.isnan(x): val = "NAN"
elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
elif var_dtype == dtypes.float64: val = f"{x}"
else: val = f"{x}f" if dtypes.is_float(var_dtype) else f"{x}" if dtypes.is_int(var_dtype) else f"{x}".lower()
return (self.render_cast([val]*var_dtype.count, var_dtype)
if var_dtype.count > 1 or var_dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
elif dtype == dtypes.float64: val = f"{x}"
else: val = f"{x}f" if dtypes.is_float(dtype) else f"{x}" if dtypes.is_int(dtype) else f"{x}".lower()
return (self.render_cast([val] * dtype.count, dtype) if dtype.count > 1 or dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
# returns a str expression of the loaded value with the output type
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str: