floats for nvidia

This commit is contained in:
George Hotz
2023-01-23 16:36:10 -08:00
parent 6fe9edf30f
commit c22554f44a

View File

@@ -162,7 +162,7 @@ class CLASTKernel(ASTKernel):
if self.bufs[buf_index].st.needs_valid():
for o in self.buftokens[buf_index].offsets():
_, valid = self.compute_buf_index_symbolic(self.bufs[buf_index].st, buf_index, o)
tokens.append(Token(f"({valid.cl} ? {const.tok} : 0.0)", const.typ) if str(valid) != "1" else const)
tokens.append(Token(f"({valid.cl} ? {const.tok} : 0.0f)", const.typ) if str(valid) != "1" else const)
return tokens
else:
return [const]*self.buftokens[buf_index].size()
@@ -185,7 +185,7 @@ class CLASTKernel(ASTKernel):
ldr = Token(f"({valid.cl} ? \\ \n {ldrt} : (float4)(0.0, 0.0, 0.0, 0.0))" if str(valid) != "1" and valid is not None else ldrt, Types.FLOAT4)
else:
ldr = Token(f"{self.buftokens[buf_index].tok}[{(idxy//(4 if self.buftokens[buf_index].typ == Types.FLOAT4 else 1)).cl}]", self.buftokens[buf_index].typ)
ldr = Token(f"({valid.cl} ? {ldr.tok} : 0.0)", ldr.typ) if str(valid) != "1" else ldr
ldr = Token(f"({valid.cl} ? {ldr.tok} : 0.0f)", ldr.typ) if str(valid) != "1" else ldr
self.kernel.append(f"{ldr.decltype()} val{buf_index}_{o} = {ldr.tok};\n")
self.loaded_keys[(buf_index,o)] = Token(f"val{buf_index}_{o}", ldr.typ)
tokens.append(self.loaded_keys[(buf_index,o)])