move acc_dt out of lazy (#4382)

Move the acc_dtype logic to tensor.py for forward, and to function.py for the two places in backward that need it (Expand and Max).
Author: chenyu
Date: 2024-05-06 10:41:25 -04:00
Committed by: GitHub
Parent: 113c2f00b9
Commit: 292ce64ad7
5 changed files with 29 additions and 29 deletions
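For context on what acc_dtype buys: sum-like reductions accumulate in a wider dtype than the input (that is what sum_acc_dtype picks), so low-precision inputs do not lose mass while accumulating. This commit moves that choice out of the lazy layer, into tensor.py for forward and into the two backward sites in function.py shown below. The following is a minimal numpy stand-in for the cast-accumulate-cast-back pattern, not tinygrad code; the toy sum_acc_dtype rule here is an assumption made only for the example.

import numpy as np

def sum_acc_dtype(dt):
  # assumption for this sketch: half-precision floats accumulate in float32, everything else keeps its dtype
  return np.float32 if dt == np.float16 else dt.type

def naive_sum(x, acc_dtype):
  acc = acc_dtype(0)
  for v in x:
    acc = acc_dtype(acc + v)   # the running total stays in acc_dtype
  return x.dtype.type(acc)     # cast the final result back to the input dtype

x = np.full(4096, 0.1, dtype=np.float16)
print(naive_sum(x, np.float16))              # ~256.0: a float16 accumulator stops absorbing 0.1-sized increments
print(naive_sum(x, sum_acc_dtype(x.dtype)))  # ~409.5: accumulate in float32, cast the result back to float16

Presumably tensor.py now performs this cast around the reduce before dispatching Sum, which lines up with acc_dtype disappearing from the Sum and Max forward signatures in the diff below.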

tinygrad/function.py

@@ -2,7 +2,7 @@
 import math
 from typing import Tuple, Optional
 from tinygrad.helpers import argsort
-from tinygrad.dtype import DType
+from tinygrad.dtype import dtypes, DType, sum_acc_dtype
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
@@ -146,22 +146,22 @@ class Where(Function):

 # ************* reduce ops *************

 class Sum(Function):
-  def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
     self.input_shape = x.shape
-    return x.r(ReduceOps.SUM, axis, acc_dtype)
+    return x.r(ReduceOps.SUM, axis)

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)

 class Max(Function):
-  def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
     self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     # 1s in locations where the max was chosen (can be two locations)
-    max_is_1s = self.x.e(BinaryOps.CMPEQ, self.ret.expand(self.x.shape)).cast(self.x.dtype)
+    max_is_1s = self.x.e(BinaryOps.CMPEQ, self.ret.expand(self.x.shape)).cast(dtypes.float)
     div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
-    return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
+    return max_is_1s.e(BinaryOps.DIV, div).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output.expand(self.x.shape))

 # ************* movement ops *************
@@ -171,7 +171,8 @@ class Expand(Function):
     self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if si != so)
     return x.expand(shape)

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.r(ReduceOps.SUM, self.expanded_axis)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+    return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(ReduceOps.SUM, self.expanded_axis).cast(grad_output.dtype)

 class Reshape(Function):
   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
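The Max.backward change computes the tie mask in dtypes.float so the DIV is a real float division regardless of the tensor's dtype, then casts back to the gradient's dtype before the multiply. A rough numpy rendering of that backward rule, not the tinygrad implementation; it assumes a single reduced axis and an upstream gradient already shaped for broadcasting.

import numpy as np

def max_backward(x, grad_output, axis):
  ret = x.max(axis=axis, keepdims=True)          # forward result, kept broadcastable against x
  max_is_1s = (x == ret).astype(np.float32)      # 1s where the max was chosen (ties give several 1s per slice)
  div = max_is_1s.sum(axis=axis, keepdims=True)  # how many entries tied for the max
  return (max_is_1s / div).astype(grad_output.dtype) * grad_output  # split the gradient evenly across ties

x = np.array([[1., 3., 3.], [2., 5., 4.]], dtype=np.float16)
g = np.ones((2, 1), dtype=np.float16)            # gradient of the reduced output
print(max_backward(x, g, axis=1))
# [[0.  0.5 0.5]
#  [0.  1.  0. ]]

Expand.backward gets the same treatment as sum: the gradient is cast to sum_acc_dtype before the SUM reduce and cast back afterwards, so the accumulation happens in the wider dtype.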