acc function not needed

This commit is contained in:
George Hotz
2023-03-29 08:02:11 +04:00
parent 20894991ed
commit b99798f08e

View File

@@ -124,30 +124,17 @@ class Linearizer:
if DEBUG >= 5: print(f"fuse buf {i} {ret} :", check_no_mul(idxy_test, float4_index), idxy_test, idxy_test//4, valid_test//4)
return ret
# TODO: this is very similar to global_load
def acc(self, ssa, i, idxs:List[Variable], name='acc') -> List[Token]:
  """Create accumulator registers for buffer ``i``, one token per offset.

  For every value produced by ``self.offsets(i)`` a CONST uop is emitted whose
  argument is the starting value of the reduction: 0.0 for ReduceOps.SUM and
  -inf for ReduceOps.MAX. When float4 is supported, usable for this buffer and
  ``self.can_merge_float4`` accepts the offset, a single float4 register is
  created and the four consecutive offsets starting there are mapped onto it.

  Args:
    ssa: callable invoked as ``ssa(name, local_type)`` to obtain the output token.
    i: buffer index, forwarded to the float4 checks and to ``self.offsets``.
    idxs: index variables, forwarded to ``self.can_merge_float4``.
    name: prefix handed to ``ssa`` for the register name.

  Returns:
    A list of tokens in the same order as ``self.offsets(i)``.
  """
  upcast_ok = self.supports_float4 and self.can_float4(i)
  by_offset: Dict[int, Token] = {}
  tokens: List[Token] = []
  for off in self.offsets(i):
    # an offset already covered by an earlier (possibly float4) register is reused
    if off not in by_offset:
      merged = upcast_ok and self.can_merge_float4(i, idxs, off)
      assert self.reduceop is not None
      out = ssa(name, LocalTypes.float4 if merged else LocalTypes.float)
      # starting value of the accumulator depends on the reduce operation
      init = {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)]
      reg = self.uop(UOps.CONST, out, [], init)
      if merged:
        # one float4 register backs four consecutive offsets; j selects the
        # component (third Token argument -- presumably the lane index)
        for j in range(0, 4):
          by_offset[off+j] = Token(reg.name, LocalTypes.float4, j)
      else:
        by_offset[off] = reg
    tokens.append(by_offset[off])
  return tokens
def global_load(self, i, idxs:List[Variable]) -> List[Token]:
def global_load(self, i, idxs:List[Variable], const=None) -> List[Token]:
should_upcast = self.supports_float4 and self.can_float4(i) and self.bufs[i].dtype != dtypes.float16
cache: Dict[int, Token] = {}
def op(offset):
if offset in cache: return cache[offset]
will_merge = should_upcast and self.can_merge_float4(i, idxs, offset)
assert will_merge or not isinstance(self.bufs[i].dtype, ImageDType), "image must merge float4"
reg = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{mnum(offset)}", LocalTypes.float4 if will_merge else LocalTypes.float), [], MemOp(i, *self.sts[i].expr_idxs(offset, idxs)))
if const is not None:
reg = self.uop(UOps.CONST, Token(f"acc{mnum(i)}_{mnum(offset)}", LocalTypes.float4 if will_merge else LocalTypes.float), [], const)
else:
assert will_merge or not isinstance(self.bufs[i].dtype, ImageDType), "image must merge float4"
reg = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{mnum(offset)}", LocalTypes.float4 if will_merge else LocalTypes.float), [], MemOp(i, *self.sts[i].expr_idxs(offset, idxs)))
if will_merge:
for j in range(0, 4): cache[offset+j] = Token(reg.name, LocalTypes.float4, j)
else:
@@ -219,7 +206,7 @@ class Linearizer:
# reduce op
if self.reduceop is not None:
# define accumulator
acc = self.acc(ssa, 0, gl_idxs)
acc = self.global_load(0, gl_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
# reduce loop
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)]
@@ -248,7 +235,7 @@ class Linearizer:
# NOTE: this structure is the same as the reduce op above
# define late accumulator
acc = self.acc(ssa, -1, local_idxs, 'lacc')
acc = self.global_load(-1, local_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
# late reduce loop
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]