Mirror of https://github.com/tinygrad/tinygrad.git
move acc_dt out of lazy (#4382)
move the acc_dtype logic to tensor.py for the forward pass, and to function.py for the two places that need it in backward (Expand and Max)
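
For context, a minimal sketch of what this means on the forward side, assuming tinygrad is installed: the Tensor-level sum now resolves the accumulation dtype (via sum_acc_dtype) and casts around the reduce, so Sum.forward in the diff below no longer takes an acc_dtype argument. The helper name sum_with_acc_dtype is made up for illustration; this is a sketch of the pattern, not the actual tensor.py code.

from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes, sum_acc_dtype

def sum_with_acc_dtype(t:Tensor, axis=None, acc_dtype=None) -> Tensor:
  # resolve the accumulation dtype at the Tensor level (caller override or default)
  acc = acc_dtype if acc_dtype is not None else sum_acc_dtype(t.dtype)
  # cast up, reduce, cast back: the Function/lazy layer only ever sees a plain SUM
  return t.cast(acc).sum(axis).cast(t.dtype)

x = Tensor.ones(8, 8, dtype=dtypes.half)
print(sum_with_acc_dtype(x).item())  # 64.0, accumulated in float32, returned as half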
tinygrad/function.py
@@ -2,7 +2,7 @@
 import math
 from typing import Tuple, Optional
 from tinygrad.helpers import argsort
-from tinygrad.dtype import DType
+from tinygrad.dtype import dtypes, DType, sum_acc_dtype
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
@@ -146,22 +146,22 @@ class Where(Function):
 # ************* reduce ops *************
 
 class Sum(Function):
-  def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
     self.input_shape = x.shape
-    return x.r(ReduceOps.SUM, axis, acc_dtype)
+    return x.r(ReduceOps.SUM, axis)
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
 
 class Max(Function):
-  def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
     self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     # 1s in locations where the max was chosen (can be two locations)
-    max_is_1s = self.x.e(BinaryOps.CMPEQ, self.ret.expand(self.x.shape)).cast(self.x.dtype)
+    max_is_1s = self.x.e(BinaryOps.CMPEQ, self.ret.expand(self.x.shape)).cast(dtypes.float)
     div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
-    return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
+    return max_is_1s.e(BinaryOps.DIV, div).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
 
 # ************* movement ops *************
 
@@ -171,7 +171,8 @@ class Expand(Function):
     self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if si != so)
     return x.expand(shape)
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.r(ReduceOps.SUM, self.expanded_axis)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+    return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(ReduceOps.SUM, self.expanded_axis).cast(grad_output.dtype)
 
 class Reshape(Function):
   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
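For reference, the behavior of the new sum_acc_dtype helper imported above can be probed directly; roughly, integer dtypes accumulate in at least a 32-bit int and floats in at least float32. This is a quick check, not a spec; the expected values in the comment are approximate.

from tinygrad.dtype import dtypes, sum_acc_dtype

for dt in (dtypes.uint8, dtypes.int16, dtypes.half, dtypes.float32, dtypes.float64):
  print(dt, "->", sum_acc_dtype(dt))
# roughly: uint8 -> uint32, int16 -> int32, half -> float32, float32 -> float32, float64 -> float64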