Move transcendental to UOp level (#5367)

* move uopgraph to file [run_process_replay]

* transcendental uops

* tests pass

* no skip

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
George Hotz
2024-07-10 19:06:25 -07:00
committed by GitHub
parent 64986f949c
commit 0215c952c5
5 changed files with 80 additions and 90 deletions

View File

@@ -1,25 +1,12 @@
"""This is where the forwards and backwards passes live."""
import math
from typing import Tuple, Optional
from tinygrad.helpers import argsort, TRANSCENDENTAL
from tinygrad.helpers import argsort
from tinygrad.dtype import dtypes, DType, sum_acc_dtype
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
from tinygrad.tensor import Function
from tinygrad.lazy import LazyBuffer
from tinygrad.shape.symbolic import sint
from tinygrad.transcendental import xsin, xlog2, xexp2, is_dtype_transcendental_supported
transcendental_supported_devices = ["CLANG", "LLVM"]
def use_transcendental(d:LazyBuffer) -> bool:
# TRANSCENDENTAL=0 to always ignore.
# TRANSCENDENTAL=1 to run only in CLANG/LLVM (default).
# TRANSCENDENTAL=2 to always run it.
if TRANSCENDENTAL >= 2:
return is_dtype_transcendental_supported(d.dtype)
if TRANSCENDENTAL >= 1:
return (is_dtype_transcendental_supported(d.dtype) and
d.device in transcendental_supported_devices)
return False
class Contiguous(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
@@ -52,11 +39,10 @@ class Reciprocal(Function):
class Sin(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.x = x
return xsin(x) if use_transcendental(x) else x.e(UnaryOps.SIN)
return x.e(UnaryOps.SIN)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
def _xsin(x): return xsin(x) if use_transcendental(x) else x.e(UnaryOps.SIN)
return _xsin(self.x.const(math.pi / 2).e(BinaryOps.ADD, self.x.e(UnaryOps.NEG))).e(BinaryOps.MUL, grad_output)
return self.x.const(math.pi / 2).e(BinaryOps.ADD, self.x.e(UnaryOps.NEG)).e(UnaryOps.SIN).e(BinaryOps.MUL, grad_output)
# NOTE: maximum(x, 0) behaves differently where x=0
class Relu(Function):
@@ -70,15 +56,13 @@ class Relu(Function):
class Log(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.x = x
def _xlog2(x): return xlog2(x) if use_transcendental(x) else x.e(UnaryOps.LOG2)
return _xlog2(x).e(BinaryOps.MUL, x.const(math.log(2)))
return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.MUL, self.x.e(UnaryOps.RECIP))
class Exp(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer:
def _xexp2(x): return xexp2(x) if use_transcendental(x) else x.e(UnaryOps.EXP2)
self.ret = _xexp2(x.e(BinaryOps.MUL, x.const(1/math.log(2))))
self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.e(BinaryOps.MUL, grad_output)
@@ -96,8 +80,7 @@ class Sqrt(Function):
# TODO: have the backend automatically find this
class Sigmoid(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer:
def _xexp2(x): return xexp2(x) if use_transcendental(x) else x.e(UnaryOps.EXP2)
self.ret = x.const(1).e(BinaryOps.ADD, _xexp2(x.e(BinaryOps.MUL, x.const(-1/math.log(2))))).e(UnaryOps.RECIP)
self.ret = x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)).e(UnaryOps.RECIP)
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: