render_locals takes a dtype (#2873)

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
qazal
2023-12-21 21:15:28 +02:00
committed by GitHub
parent 6116039f7b
commit d2e9245de8

View File

@@ -68,8 +68,8 @@ class CStyleLanguage(NamedTuple):
return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val
def render_local(self, name:str, size:int):
return self.smem_align + self.smem_prefix + f"float {name}[{size}];"
def render_local(self, name:str, dtype:DType, size:int):
return self.smem_align + self.smem_prefix + f"{dtype.name} {name}[{size}];"
def render_for(self, expr: str, _min:Union[int,str], _max:Union[int,str]) -> str:
return f"for ({self.generic_var_prefix if self.generic_var_prefix else 'int'} {expr} = {_min}; {expr} < {_max}; {expr}++) {{"
@@ -190,10 +190,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
if child_count[u] <= 1: r[u] = val
else: kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'cast')} = {val};")
elif uop == UOps.DEFINE_LOCAL:
assert dtype is not None
if lang.external_local_bufs:
prekernel.append(lang.render_local(args[0], args[1]))
prekernel.append(lang.render_local(args[0], dtype, args[1]))
else:
kk(lang.render_local(args[0], args[1]))
kk(lang.render_local(args[0], dtype, args[1]))
r[u] = args[0]
elif uop == UOps.DEFINE_GLOBAL:
assert dtype is not None
@@ -309,8 +310,8 @@ class WGSLLanguage(CStyleLanguage):
# HACK: write bool as f32. remove after elementwise op cast inputs properly
type_map = {dtypes.float: "f32", dtypes.half: "f16", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "f32"}
def render_local(self, name: str, size: int):
return f"var<workgroup> {name}: array<f32,{size}>;"
def render_local(self, name: str, dtype:DType, size: int):
return f"var<workgroup> {name}: array<{self.type_map[dtype]},{size}>;"
def render_const(self, x:Union[float,int], var_dtype) -> str:
if math.isnan(x): return "nan()"