mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
render_locals takes a dtype (#2873)
Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -68,8 +68,8 @@ class CStyleLanguage(NamedTuple):
|
||||
|
||||
return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val
|
||||
|
||||
def render_local(self, name:str, size:int):
|
||||
return self.smem_align + self.smem_prefix + f"float {name}[{size}];"
|
||||
def render_local(self, name:str, dtype:DType, size:int):
|
||||
return self.smem_align + self.smem_prefix + f"{dtype.name} {name}[{size}];"
|
||||
|
||||
def render_for(self, expr: str, _min:Union[int,str], _max:Union[int,str]) -> str:
|
||||
return f"for ({self.generic_var_prefix if self.generic_var_prefix else 'int'} {expr} = {_min}; {expr} < {_max}; {expr}++) {{"
|
||||
@@ -190,10 +190,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
|
||||
if child_count[u] <= 1: r[u] = val
|
||||
else: kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'cast')} = {val};")
|
||||
elif uop == UOps.DEFINE_LOCAL:
|
||||
assert dtype is not None
|
||||
if lang.external_local_bufs:
|
||||
prekernel.append(lang.render_local(args[0], args[1]))
|
||||
prekernel.append(lang.render_local(args[0], dtype, args[1]))
|
||||
else:
|
||||
kk(lang.render_local(args[0], args[1]))
|
||||
kk(lang.render_local(args[0], dtype, args[1]))
|
||||
r[u] = args[0]
|
||||
elif uop == UOps.DEFINE_GLOBAL:
|
||||
assert dtype is not None
|
||||
@@ -309,8 +310,8 @@ class WGSLLanguage(CStyleLanguage):
|
||||
# HACK: write bool as f32. remove after elementwise op cast inputs properly
|
||||
type_map = {dtypes.float: "f32", dtypes.half: "f16", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "f32"}
|
||||
|
||||
def render_local(self, name: str, size: int):
|
||||
return f"var<workgroup> {name}: array<f32,{size}>;"
|
||||
def render_local(self, name: str, dtype:DType, size: int):
|
||||
return f"var<workgroup> {name}: array<{self.type_map[dtype]},{size}>;"
|
||||
|
||||
def render_const(self, x:Union[float,int], var_dtype) -> str:
|
||||
if math.isnan(x): return "nan()"
|
||||
|
||||
Reference in New Issue
Block a user