From 70b771a175fd1d141492fbdc55ff528349805aaa Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 15 Jan 2023 09:39:22 -0800 Subject: [PATCH] idx idy --- tinygrad/llops/ops_gpu.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index 330bbc52d9..198ef61083 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -109,13 +109,14 @@ class CLASTKernel(ASTKernel): def __init__(self, ast:LazyOp): super().__init__(ast) - def compute_buf_index(self, st, buf_index, offset=0): - key = f"{buf_index}_{offset}" + def compute_buf_index(self, st, buf_index, offset=0, div=1, mod=None): + key = f"{buf_index}_{offset}" + (f"_d{div}" if div != 1 else "") + (f"_m{mod}" if mod is not None else "") # add the index if we don't have it if key not in self.seen_idx: + # TODO: do the div and mod in a smarter way idx_pieces = [str(self.offsets[buf_index] + offset)] + [(f"idx{i}*{st}" if st != 1 else f"idx{i}") for i,(sh,st) in enumerate(zip(self.shapes[buf_index][0:self.last_reduce], self.strides[buf_index][0:self.last_reduce])) if sh != 1 and st != 0] + self.kernel.append(f"int bufi{key} = " + '(('+' + '.join(idx_pieces)+f')/{div})' + (f'%{mod};\n' if mod is not None else ';\n')) if st.needs_valid(): self.kernel.append(f"bool bufvalid{key} = true;") - self.kernel.append(f"int bufi{key} = " + '('+' + '.join(idx_pieces)+');\n') if len(st.views) > 1: extra_idx = ';\n '.join([v.expr for v in st.views[0:-1][::-1] if v.expr not in ['', 'idx=idx', 'valid=valid']]) self.kernel.append(extra_idx.replace("//", "/").replace("idx", f"bufi{key}").replace("valid", f"bufvalid{key}") + ";\n") @@ -145,7 +146,6 @@ class CLASTKernel(ASTKernel): if key not in self.loaded_keys: st = self.bufs[buf_index].st if offset > 0: assert len(st.views) == 1 - key = self.compute_buf_index(st, buf_index, offset) # constant folding constant_fold = None @@ -155,9 +155,17 @@ class CLASTKernel(ASTKernel): if isinstance(self.bufs[buf_index]._buf, CLImage): W = self.bufs[buf_index]._base_shape[1] - ldrt = f"read_imagef(data{buf_index}, smp, (int2)(((bufi{key})/4)%{W}, (bufi{key})/{W*4})) /* {self.bufs[buf_index]._base_shape} */" + if len(st.views) == 1: + idx = self.compute_buf_index(st, buf_index, offset, 4, W) + idy = self.compute_buf_index(st, buf_index, offset, W*4, self.bufs[buf_index]._base_shape[0]) + ldrt = f"read_imagef(data{buf_index}, smp, (int2)(bufi{idx}, bufi{idy})) /* {self.bufs[buf_index]._base_shape} */" + else: + self.kernel.append(f"/* computing {st} */\n") + key = self.compute_buf_index(st, buf_index, offset) + ldrt = f"read_imagef(data{buf_index}, smp, (int2)(((bufi{key})/4)%{W}, (bufi{key})/{W*4})) /* {self.bufs[buf_index]._base_shape} */" ldr = Token(f"(bufvalid{key} ? {ldrt} : 0.0)" if st.needs_valid() else ldrt, Types.FLOAT4) else: + key = self.compute_buf_index(st, buf_index, offset) if self.late_are_float4 or (self.early_loads_are_float4 and self.bufs[buf_index] in self.earlybufs): if self.strides[buf_index][-1] == 1 and len(st.views) == 1 and not st.needs_valid(): ldr = Token(f"((__global float4*)data{buf_index})[bufi{key}/4]", Types.FLOAT4)