From b99798f08ef994b07aad1261a7df84cd9664b360 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Wed, 29 Mar 2023 08:02:11 +0400
Subject: [PATCH] acc function not needed

---
 tinygrad/codegen/linearizer.py | 29 ++++++++---------------------
 1 file changed, 8 insertions(+), 21 deletions(-)

diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 423e3e3f4c..3b39feee47 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -124,30 +124,17 @@ class Linearizer:
     if DEBUG >= 5: print(f"fuse buf {i} {ret} :", check_no_mul(idxy_test, float4_index), idxy_test, idxy_test//4, valid_test//4)
     return ret
 
-  # TODO: this is very similar to load
-  def acc(self, ssa, i, idxs:List[Variable], name='acc') -> List[Token]:
-    should_upcast = self.supports_float4 and self.can_float4(i)
-    cache: Dict[int, Token] = {}
-    def op(offset):
-      if offset in cache: return cache[offset]
-      will_merge = should_upcast and self.can_merge_float4(i, idxs, offset)
-      assert self.reduceop is not None
-      reg = self.uop(UOps.CONST, ssa(name, LocalTypes.float4 if will_merge else LocalTypes.float), [], {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
-      if will_merge:
-        for j in range(0, 4): cache[offset+j] = Token(reg.name, LocalTypes.float4, j)
-      else:
-        cache[offset] = reg
-      return cache[offset]
-    return [op(o) for o in self.offsets(i)]
-
-  def global_load(self, i, idxs:List[Variable]) -> List[Token]:
+  def global_load(self, i, idxs:List[Variable], const=None) -> List[Token]:
     should_upcast = self.supports_float4 and self.can_float4(i) and self.bufs[i].dtype != dtypes.float16
     cache: Dict[int, Token] = {}
     def op(offset):
       if offset in cache: return cache[offset]
       will_merge = should_upcast and self.can_merge_float4(i, idxs, offset)
-      assert will_merge or not isinstance(self.bufs[i].dtype, ImageDType), "image must merge float4"
-      reg = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{mnum(offset)}", LocalTypes.float4 if will_merge else LocalTypes.float), [], MemOp(i, *self.sts[i].expr_idxs(offset, idxs)))
+      if const is not None:
+        reg = self.uop(UOps.CONST, Token(f"acc{mnum(i)}_{mnum(offset)}", LocalTypes.float4 if will_merge else LocalTypes.float), [], const)
+      else:
+        assert will_merge or not isinstance(self.bufs[i].dtype, ImageDType), "image must merge float4"
+        reg = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{mnum(offset)}", LocalTypes.float4 if will_merge else LocalTypes.float), [], MemOp(i, *self.sts[i].expr_idxs(offset, idxs)))
       if will_merge:
         for j in range(0, 4): cache[offset+j] = Token(reg.name, LocalTypes.float4, j)
       else:
@@ -219,7 +206,7 @@ class Linearizer:
     # reduce op
     if self.reduceop is not None:
       # define accumulator
-      acc = self.acc(ssa, 0, gl_idxs)
+      acc = self.global_load(0, gl_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
 
       # reduce loop
       reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)]
@@ -248,7 +235,7 @@ class Linearizer:
 
       # NOTE: this structure is the same as the reduce op above
       # define late accumulator
-      acc = self.acc(ssa, -1, local_idxs, 'lacc')
+      acc = self.global_load(-1, local_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
 
       # late reduce loop
       end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
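
Note: the patch folds the old acc() helper into global_load() behind a new optional const argument. When const is given, global_load() emits a CONST uop holding the reduce identity (0.0 for SUM, -inf for MAX) instead of a LOAD, while reusing the same offset cache and float4-merge handling that acc() had duplicated. Below is a minimal standalone sketch of that pattern; this ReduceOps, REDUCE_IDENTITY, and the global_load() here are simplified illustrative stand-ins, not tinygrad's actual classes.

import math
from enum import Enum, auto
from typing import Dict, List, Optional

class ReduceOps(Enum):
  SUM = auto()
  MAX = auto()

# identity element per reduce op, matching the dict at the patch's call sites
REDUCE_IDENTITY = {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}

def global_load(offsets: List[int], const: Optional[float] = None) -> List[str]:
  # one register per offset, with a cache so merged offsets can share a
  # register (standing in for the float4-merge cache in the real code)
  cache: Dict[int, str] = {}
  def op(offset: int) -> str:
    if offset in cache: return cache[offset]
    if const is not None:
      cache[offset] = f"acc_{offset} = {const}"        # CONST path: accumulator init
    else:
      cache[offset] = f"val_{offset} = buf[{offset}]"  # LOAD path: real memory read
    return cache[offset]
  return [op(o) for o in offsets]

# defining the accumulator is now just a "load" of the reduce identity:
acc  = global_load([0, 1, 2, 3], REDUCE_IDENTITY[ReduceOps.SUM])  # acc_0 = 0.0, ...
vals = global_load([0, 1, 2, 3])                                  # val_0 = buf[0], ...

With a single entry point, any future change to the offset cache or the float4-merge rule has to be made only once, which is what makes the separate acc function "not needed".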