mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
add returns between views
This commit is contained in:
@@ -113,11 +113,11 @@ class CLASTKernel(ASTKernel):
|
||||
key = f"{buf_index}_{offset}"
|
||||
# add the index if we don't have it
|
||||
if key not in self.seen_idx:
|
||||
idx_pieces = [str(st.offset + offset)] + [(f"idx{i}*{st}" if st != 1 else f"idx{i}") for i,(sh,st) in enumerate(zip(self.shapes[buf_index][0:self.last_reduce], self.strides[buf_index][0:self.last_reduce])) if sh != 1 and st != 0]
|
||||
idx_pieces = [str(self.offsets[buf_index] + offset)] + [(f"idx{i}*{st}" if st != 1 else f"idx{i}") for i,(sh,st) in enumerate(zip(self.shapes[buf_index][0:self.last_reduce], self.strides[buf_index][0:self.last_reduce])) if sh != 1 and st != 0]
|
||||
if st.needs_valid(): self.kernel.append(f"bool bufvalid{key} = true;")
|
||||
self.kernel.append(f"int bufi{key} = " + '('+' + '.join(idx_pieces)+');\n')
|
||||
if len(st.views) > 1:
|
||||
extra_idx = ';'.join([v.expr for v in st.views[0:-1][::-1] if v.expr not in ['', 'idx=idx', 'valid=valid']])
|
||||
extra_idx = ';\n '.join([v.expr for v in st.views[0:-1][::-1] if v.expr not in ['', 'idx=idx', 'valid=valid']])
|
||||
self.kernel.append(extra_idx.replace("//", "/").replace("idx", f"bufi{key}").replace("valid", f"bufvalid{key}") + ";\n")
|
||||
self.seen_idx.add(key)
|
||||
return key
|
||||
|
||||
Reference in New Issue
Block a user