ugh, that's getting ugle

This commit is contained in:
George Hotz
2023-03-10 17:41:19 -08:00
parent 4780f9a6df
commit c7d17c25d9

View File

@@ -73,7 +73,7 @@ class GPUCodegen(ASTKernel):
assert len(self.sts[buf_index].views) == 1, "store has more than one view"
# all stores can merge, since they have one view and are valid
should_upcast = self.lang.float4 and self.buftokens[buf_index].can_float4() and (self.bufs[buf_index].dtype != dtypes.float16 and not hasattr(self.bufs[buf_index]._buf, "IMAGE"))
should_upcast = self.lang.float4 and self.buftokens[buf_index].can_float4() and (self.bufs[buf_index].dtype != dtypes.float16 or hasattr(self.bufs[buf_index]._buf, "IMAGE"))
to_store = {o:v for o,v in zip(self.buftokens[buf_index].offsets(), value)}
did_store = set()
@@ -101,7 +101,7 @@ class GPUCodegen(ASTKernel):
val = self.bufs[buf_index]._backing[0]
assert not math.isnan(val)
const = Token(f"({val}f)", Types.FLOAT)
should_upcast = self.lang.float4 and const is None and self.buftokens[buf_index].can_float4() and (self.bufs[buf_index].dtype != dtypes.float16 and not hasattr(self.bufs[buf_index]._buf, "IMAGE"))
should_upcast = self.lang.float4 and const is None and self.buftokens[buf_index].can_float4() and (self.bufs[buf_index] is None or self.bufs[buf_index].dtype != dtypes.float16 or hasattr(self.bufs[buf_index]._buf, "IMAGE"))
tokens = []
test_idy = []
for o in self.buftokens[buf_index].offsets():