out_tokens -> grouped in linearizer (#2989)

no more token now
This commit is contained in:
chenyu
2024-01-03 11:45:28 -05:00
committed by GitHub
parent f494b9d463
commit 590268fa03

View File

@@ -158,11 +158,11 @@ class Linearizer(Kernel):
_idx = k[:upcast_dim[0]] + (float4_expand[0],) + k[upcast_dim[0]+1:]
grouped_store_offset[_idx].append(store_offset[k])
store_offset_new = {}
for k,out_tokens in grouped_store_offset.items():
amt = len(out_tokens)
for k,grouped in grouped_store_offset.items():
amt = len(grouped)
idx, valid = self.sts[i].expr_idxs(k)
assert idx == ((idx//amt)*amt), "float4 stores are always aligned"
store_offset_new[k] = self.uop(UOps.CAST, buf.dtype.vec(amt), tuple(out_tokens))
store_offset_new[k] = self.uop(UOps.CAST, buf.dtype.vec(amt), tuple(grouped))
store_offset = store_offset_new
stores = []