mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Move transcendental to UOp level (#5367)
* move uopgraph to file [run_process_replay] * transcendental uops * tests pass * no skip --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -1,25 +1,12 @@
|
||||
"""This is where the forwards and backwards passes live."""
|
||||
import math
|
||||
from typing import Tuple, Optional
|
||||
from tinygrad.helpers import argsort, TRANSCENDENTAL
|
||||
from tinygrad.helpers import argsort
|
||||
from tinygrad.dtype import dtypes, DType, sum_acc_dtype
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
|
||||
from tinygrad.tensor import Function
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.shape.symbolic import sint
|
||||
from tinygrad.transcendental import xsin, xlog2, xexp2, is_dtype_transcendental_supported
|
||||
|
||||
transcendental_supported_devices = ["CLANG", "LLVM"]
|
||||
def use_transcendental(d:LazyBuffer) -> bool:
|
||||
# TRANSCENDENTAL=0 to always ignore.
|
||||
# TRANSCENDENTAL=1 to run only in CLANG/LLVM (default).
|
||||
# TRANSCENDENTAL=2 to always run it.
|
||||
if TRANSCENDENTAL >= 2:
|
||||
return is_dtype_transcendental_supported(d.dtype)
|
||||
if TRANSCENDENTAL >= 1:
|
||||
return (is_dtype_transcendental_supported(d.dtype) and
|
||||
d.device in transcendental_supported_devices)
|
||||
return False
|
||||
|
||||
class Contiguous(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
|
||||
@@ -52,11 +39,10 @@ class Reciprocal(Function):
|
||||
class Sin(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
self.x = x
|
||||
return xsin(x) if use_transcendental(x) else x.e(UnaryOps.SIN)
|
||||
return x.e(UnaryOps.SIN)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
def _xsin(x): return xsin(x) if use_transcendental(x) else x.e(UnaryOps.SIN)
|
||||
return _xsin(self.x.const(math.pi / 2).e(BinaryOps.ADD, self.x.e(UnaryOps.NEG))).e(BinaryOps.MUL, grad_output)
|
||||
return self.x.const(math.pi / 2).e(BinaryOps.ADD, self.x.e(UnaryOps.NEG)).e(UnaryOps.SIN).e(BinaryOps.MUL, grad_output)
|
||||
|
||||
# NOTE: maximum(x, 0) behaves differently where x=0
|
||||
class Relu(Function):
|
||||
@@ -70,15 +56,13 @@ class Relu(Function):
|
||||
class Log(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
self.x = x
|
||||
def _xlog2(x): return xlog2(x) if use_transcendental(x) else x.e(UnaryOps.LOG2)
|
||||
return _xlog2(x).e(BinaryOps.MUL, x.const(math.log(2)))
|
||||
return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.MUL, self.x.e(UnaryOps.RECIP))
|
||||
|
||||
class Exp(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
def _xexp2(x): return xexp2(x) if use_transcendental(x) else x.e(UnaryOps.EXP2)
|
||||
self.ret = _xexp2(x.e(BinaryOps.MUL, x.const(1/math.log(2))))
|
||||
self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
|
||||
return self.ret
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.e(BinaryOps.MUL, grad_output)
|
||||
@@ -96,8 +80,7 @@ class Sqrt(Function):
|
||||
# TODO: have the backend automatically find this
|
||||
class Sigmoid(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
def _xexp2(x): return xexp2(x) if use_transcendental(x) else x.e(UnaryOps.EXP2)
|
||||
self.ret = x.const(1).e(BinaryOps.ADD, _xexp2(x.e(BinaryOps.MUL, x.const(-1/math.log(2))))).e(UnaryOps.RECIP)
|
||||
self.ret = x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)).e(UnaryOps.RECIP)
|
||||
return self.ret
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
|
||||
Reference in New Issue
Block a user