mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 22:38:16 -05:00
acc function not needed
This commit is contained in:
@@ -124,30 +124,17 @@ class Linearizer:
|
||||
if DEBUG >= 5: print(f"fuse buf {i} {ret} :", check_no_mul(idxy_test, float4_index), idxy_test, idxy_test//4, valid_test//4)
|
||||
return ret
|
||||
|
||||
# TODO: this is very similar to load
|
||||
def acc(self, ssa, i, idxs:List[Variable], name='acc') -> List[Token]:
  """Emit the accumulator start-value constants for buffer `i`.

  One token is returned for every offset yielded by `self.offsets(i)`.
  Each accumulator is emitted as a `UOps.CONST` whose value is 0.0 for
  `ReduceOps.SUM` and -inf for `ReduceOps.MAX`.

  When float4 is supported, buffer `i` can use it, and
  `self.can_merge_float4` accepts the offset, a single float4 register
  is emitted and offsets `offset .. offset+3` map to its four lanes.
  Otherwise a scalar float register is emitted for that offset alone.

  `ssa` is called as `ssa(name, local_type)` to obtain the destination
  token of each emitted constant.
  """
  float4_allowed = self.supports_float4 and self.can_float4(i)
  tokens_by_offset: Dict[int, Token] = {}

  def emit(offset):
    # Offsets already covered by an earlier register (e.g. lanes 1..3
    # of a float4) reuse their token instead of emitting a new uop.
    try:
      return tokens_by_offset[offset]
    except KeyError:
      pass

    merge = float4_allowed and self.can_merge_float4(i, idxs, offset)
    assert self.reduceop is not None
    dest = ssa(name, LocalTypes.float4 if merge else LocalTypes.float)
    start_value = {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)]
    reg = self.uop(UOps.CONST, dest, [], start_value)

    if not merge:
      tokens_by_offset[offset] = reg
      return reg

    # One float4 register serves four consecutive offsets, one lane each.
    for lane in range(0, 4):
      tokens_by_offset[offset+lane] = Token(reg.name, LocalTypes.float4, lane)
    return tokens_by_offset[offset]

  return list(map(emit, self.offsets(i)))
|
||||
|
||||
def global_load(self, i, idxs:List[Variable]) -> List[Token]:
|
||||
def global_load(self, i, idxs:List[Variable], const=None) -> List[Token]:
|
||||
should_upcast = self.supports_float4 and self.can_float4(i) and self.bufs[i].dtype != dtypes.float16
|
||||
cache: Dict[int, Token] = {}
|
||||
def op(offset):
|
||||
if offset in cache: return cache[offset]
|
||||
will_merge = should_upcast and self.can_merge_float4(i, idxs, offset)
|
||||
assert will_merge or not isinstance(self.bufs[i].dtype, ImageDType), "image must merge float4"
|
||||
reg = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{mnum(offset)}", LocalTypes.float4 if will_merge else LocalTypes.float), [], MemOp(i, *self.sts[i].expr_idxs(offset, idxs)))
|
||||
if const is not None:
|
||||
reg = self.uop(UOps.CONST, Token(f"acc{mnum(i)}_{mnum(offset)}", LocalTypes.float4 if will_merge else LocalTypes.float), [], const)
|
||||
else:
|
||||
assert will_merge or not isinstance(self.bufs[i].dtype, ImageDType), "image must merge float4"
|
||||
reg = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{mnum(offset)}", LocalTypes.float4 if will_merge else LocalTypes.float), [], MemOp(i, *self.sts[i].expr_idxs(offset, idxs)))
|
||||
if will_merge:
|
||||
for j in range(0, 4): cache[offset+j] = Token(reg.name, LocalTypes.float4, j)
|
||||
else:
|
||||
@@ -219,7 +206,7 @@ class Linearizer:
|
||||
# reduce op
|
||||
if self.reduceop is not None:
|
||||
# define accumulator
|
||||
acc = self.acc(ssa, 0, gl_idxs)
|
||||
acc = self.global_load(0, gl_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
|
||||
|
||||
# reduce loop
|
||||
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)]
|
||||
@@ -248,7 +235,7 @@ class Linearizer:
|
||||
# NOTE: this structure is the same as the reduce op above
|
||||
|
||||
# define late accumulator
|
||||
acc = self.acc(ssa, -1, local_idxs, 'lacc')
|
||||
acc = self.global_load(-1, local_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
|
||||
|
||||
# late reduce loop
|
||||
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
|
||||
|
||||
Reference in New Issue
Block a user