add returns between views

This commit is contained in:
George Hotz
2023-01-15 08:58:10 -08:00
parent 287699c32c
commit 7ea89779fa

View File

@@ -113,11 +113,11 @@ class CLASTKernel(ASTKernel):
key = f"{buf_index}_{offset}"
# add the index if we don't have it
if key not in self.seen_idx:
idx_pieces = [str(st.offset + offset)] + [(f"idx{i}*{st}" if st != 1 else f"idx{i}") for i,(sh,st) in enumerate(zip(self.shapes[buf_index][0:self.last_reduce], self.strides[buf_index][0:self.last_reduce])) if sh != 1 and st != 0]
idx_pieces = [str(self.offsets[buf_index] + offset)] + [(f"idx{i}*{st}" if st != 1 else f"idx{i}") for i,(sh,st) in enumerate(zip(self.shapes[buf_index][0:self.last_reduce], self.strides[buf_index][0:self.last_reduce])) if sh != 1 and st != 0]
if st.needs_valid(): self.kernel.append(f"bool bufvalid{key} = true;")
self.kernel.append(f"int bufi{key} = " + '('+' + '.join(idx_pieces)+');\n')
if len(st.views) > 1:
extra_idx = ';'.join([v.expr for v in st.views[0:-1][::-1] if v.expr not in ['', 'idx=idx', 'valid=valid']])
extra_idx = ';\n '.join([v.expr for v in st.views[0:-1][::-1] if v.expr not in ['', 'idx=idx', 'valid=valid']])
self.kernel.append(extra_idx.replace("//", "/").replace("idx", f"bufi{key}").replace("valid", f"bufvalid{key}") + ";\n")
self.seen_idx.add(key)
return key