mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 22:08:08 -05:00
idx idy
This commit is contained in:
@@ -109,13 +109,14 @@ class CLASTKernel(ASTKernel):
|
||||
def __init__(self, ast:LazyOp):
|
||||
super().__init__(ast)
|
||||
|
||||
def compute_buf_index(self, st, buf_index, offset=0):
|
||||
key = f"{buf_index}_{offset}"
|
||||
def compute_buf_index(self, st, buf_index, offset=0, div=1, mod=None):
|
||||
key = f"{buf_index}_{offset}" + (f"_d{div}" if div != 1 else "") + (f"_m{mod}" if mod is not None else "")
|
||||
# add the index if we don't have it
|
||||
if key not in self.seen_idx:
|
||||
# TODO: do the div and mod in a smarter way
|
||||
idx_pieces = [str(self.offsets[buf_index] + offset)] + [(f"idx{i}*{st}" if st != 1 else f"idx{i}") for i,(sh,st) in enumerate(zip(self.shapes[buf_index][0:self.last_reduce], self.strides[buf_index][0:self.last_reduce])) if sh != 1 and st != 0]
|
||||
self.kernel.append(f"int bufi{key} = " + '(('+' + '.join(idx_pieces)+f')/{div})' + (f'%{mod};\n' if mod is not None else ';\n'))
|
||||
if st.needs_valid(): self.kernel.append(f"bool bufvalid{key} = true;")
|
||||
self.kernel.append(f"int bufi{key} = " + '('+' + '.join(idx_pieces)+');\n')
|
||||
if len(st.views) > 1:
|
||||
extra_idx = ';\n '.join([v.expr for v in st.views[0:-1][::-1] if v.expr not in ['', 'idx=idx', 'valid=valid']])
|
||||
self.kernel.append(extra_idx.replace("//", "/").replace("idx", f"bufi{key}").replace("valid", f"bufvalid{key}") + ";\n")
|
||||
@@ -145,7 +146,6 @@ class CLASTKernel(ASTKernel):
|
||||
if key not in self.loaded_keys:
|
||||
st = self.bufs[buf_index].st
|
||||
if offset > 0: assert len(st.views) == 1
|
||||
key = self.compute_buf_index(st, buf_index, offset)
|
||||
|
||||
# constant folding
|
||||
constant_fold = None
|
||||
@@ -155,9 +155,17 @@ class CLASTKernel(ASTKernel):
|
||||
|
||||
if isinstance(self.bufs[buf_index]._buf, CLImage):
|
||||
W = self.bufs[buf_index]._base_shape[1]
|
||||
ldrt = f"read_imagef(data{buf_index}, smp, (int2)(((bufi{key})/4)%{W}, (bufi{key})/{W*4})) /* {self.bufs[buf_index]._base_shape} */"
|
||||
if len(st.views) == 1:
|
||||
idx = self.compute_buf_index(st, buf_index, offset, 4, W)
|
||||
idy = self.compute_buf_index(st, buf_index, offset, W*4, self.bufs[buf_index]._base_shape[0])
|
||||
ldrt = f"read_imagef(data{buf_index}, smp, (int2)(bufi{idx}, bufi{idy})) /* {self.bufs[buf_index]._base_shape} */"
|
||||
else:
|
||||
self.kernel.append(f"/* computing {st} */\n")
|
||||
key = self.compute_buf_index(st, buf_index, offset)
|
||||
ldrt = f"read_imagef(data{buf_index}, smp, (int2)(((bufi{key})/4)%{W}, (bufi{key})/{W*4})) /* {self.bufs[buf_index]._base_shape} */"
|
||||
ldr = Token(f"(bufvalid{key} ? {ldrt} : 0.0)" if st.needs_valid() else ldrt, Types.FLOAT4)
|
||||
else:
|
||||
key = self.compute_buf_index(st, buf_index, offset)
|
||||
if self.late_are_float4 or (self.early_loads_are_float4 and self.bufs[buf_index] in self.earlybufs):
|
||||
if self.strides[buf_index][-1] == 1 and len(st.views) == 1 and not st.needs_valid():
|
||||
ldr = Token(f"((__global float4*)data{buf_index})[bufi{key}/4]", Types.FLOAT4)
|
||||
|
||||
Reference in New Issue
Block a user