@@ -4,7 +4,6 @@ import numpy as np
 from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any, Iterable
 from math import prod # noqa: F401 # pylint:disable=unused-import

-ShapeType = Tuple[int, ...]
 # NOTE: helpers is not allowed to import from anything else in tinygrad
 OSX = platform.system() == "Darwin"
 CI = os.getenv("CI", "") != ""
@@ -93,7 +92,7 @@ class dtypes:
   def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
   @staticmethod
   def fields() -> Dict[str, DType]: return DTYPES_DICT
-  bool: Final[DType] = DType(0, 1, "bool", bool)
+  bool: Final[DType] = DType(0, 1, "bool", np.bool_)
   float16: Final[DType] = DType(0, 2, "half", np.float16)
   half = float16
   float32: Final[DType] = DType(4, 4, "float", np.float32)

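The dtype table change above only swaps the numpy type stored for bool. Assuming DTYPES_DICT is keyed by numpy dtype names, as from_np suggests, lookups behave the same either way (illustrative sketch, not part of the diff):

    import numpy as np
    from tinygrad.helpers import dtypes
    dtypes.from_np(np.float32)  # -> dtypes.float32, looked up via np.dtype(np.float32).name
    dtypes.from_np(np.bool_)    # -> dtypes.bool, since np.dtype(np.bool_).name == "bool"
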
@@ -195,20 +195,20 @@ class LazyBuffer:
     assert not arg[1] or self.dtype.itemsize == arg[0].itemsize, "can't bitcast mismatched dtype itemsizes"
     return elementwise_op(UnaryOps.CAST, self, arg=arg) if self.dtype != arg[0] else self
   def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self)
-  def binary_op(self:LazyBuffer, op:BinaryOps, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(op, self, y)
-  def ternary_op(self:LazyBuffer, op:TernaryOps, y:Union[LazyBuffer, float, int], z:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(op, self, y, z)
+  def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y)
+  def ternary_op(self:LazyBuffer, op:TernaryOps, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y, z)

-  def __add__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.ADD, self, y)
-  def __radd__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.ADD, y, self)
-  def __mul__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.MUL, self, y)
-  def __rmul__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.MUL, y, self)
-  def __truediv__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.DIV, self, y)
-  def __rtruediv__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.DIV, y, self)
-  def __sub__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.SUB, self, y)
-  def __rsub__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.SUB, y, self)
-  def __lt__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.CMPLT, self, y)
-  def __gt__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.CMPLT, y, self)
-  def __neg__(self) -> LazyBuffer: return 0.0-self
+  def __add__(self, y:LazyBuffer) -> LazyBuffer: return elementwise_op(BinaryOps.ADD, self, y)
+  def __radd__(self, y:LazyBuffer) -> LazyBuffer: return elementwise_op(BinaryOps.ADD, y, self)
+  def __mul__(self, y:LazyBuffer) -> LazyBuffer: return elementwise_op(BinaryOps.MUL, self, y)
+  def __rmul__(self, y:LazyBuffer) -> LazyBuffer: return elementwise_op(BinaryOps.MUL, y, self)
+  def __truediv__(self, y:LazyBuffer) -> LazyBuffer: return elementwise_op(BinaryOps.DIV, self, y)
+  def __rtruediv__(self, y:LazyBuffer) -> LazyBuffer: return elementwise_op(BinaryOps.DIV, y, self)
+  def __sub__(self, y:LazyBuffer) -> LazyBuffer: return elementwise_op(BinaryOps.SUB, self, y)
+  def __rsub__(self, y:LazyBuffer) -> LazyBuffer: return elementwise_op(BinaryOps.SUB, y, self)
+  def __lt__(self, y:LazyBuffer) -> LazyBuffer: return elementwise_op(BinaryOps.CMPLT, self, y)
+  def __gt__(self, y:LazyBuffer) -> LazyBuffer: return elementwise_op(BinaryOps.CMPLT, y, self)
+  def __neg__(self) -> LazyBuffer: return self.const_like(0.0)-self

   def contiguous(self:LazyBuffer) -> LazyBuffer:
     if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
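
With the Union overloads gone, callers have to materialize scalar operands themselves before they reach elementwise_op. A sketch of the new calling convention, using names from this diff (const_like is assumed to build a constant LazyBuffer matching the receiver):

    # old: the Python int was promoted to a LazyBuffer inside elementwise_op
    ret = x.binary_op(BinaryOps.MAX, 0)
    # new: the caller builds the constant buffer explicitly
    ret = x.binary_op(BinaryOps.MAX, x.const_like(0))
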
@@ -316,11 +316,7 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
     new_srcs.append(x)
   return tuple(new_srcs)

-def elementwise_op(op:Union[UnaryOps, BinaryOps, TernaryOps], *_srcs:Union[LazyBuffer, float, int], arg:Optional[Any]=None) -> LazyBuffer:
-  # make them all LazyBuffers
-  first_src = [x for x in _srcs if isinstance(x, LazyBuffer)][0]
-  srcs:Tuple[LazyBuffer, ...] = tuple(x if isinstance(x, LazyBuffer) else first_src.const_like(x) for x in _srcs)
-
+def elementwise_op(op:Union[UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
   # if we are separated from other binary ops by movement ops, we push those movement ops above those binaryops
   if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs)

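The deleted promotion logic picked the first real LazyBuffer among the arguments and used it as a template for turning Python scalars into constant buffers. A standalone sketch of that behavior in plain Python (not the tinygrad API):

    def promote_consts(srcs, const_like):
      # first "real" buffer acts as the template for the generated constants
      first = [s for s in srcs if not isinstance(s, (int, float))][0]
      return tuple(s if not isinstance(s, (int, float)) else const_like(first, s) for s in srcs)

    # promote_consts((buf, 2.0), lambda ref, v: ref.const_like(v)) would yield (buf, buf.const_like(2.0))
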
@@ -1,5 +1,5 @@
 from typing import Tuple, Optional
-from tinygrad.helpers import argsort, ShapeType, DType
+from tinygrad.helpers import argsort, DType
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
@@ -23,28 +23,28 @@ class Sin(Function):
     self.x = x
     return x.unary_op(UnaryOps.SIN)
   def backward(self, grad:LazyBuffer) -> LazyBuffer:
-    return ((math.pi / 2) - self.x).unary_op(UnaryOps.SIN) * grad
+    return (self.x.const_like(math.pi / 2) - self.x).unary_op(UnaryOps.SIN) * grad

 # NOTE: maximum(x, 0) behaves differently where x=0
 class Relu(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.binary_op(BinaryOps.MAX, 0)
+    self.ret = x.binary_op(BinaryOps.MAX, x.const_like(0))
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return (0 < self.ret) * grad_output
+    return (self.ret.const_like(0) < self.ret) * grad_output

 class Log(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return x.unary_op(UnaryOps.LOG2) * math.log(2)
+    return x.unary_op(UnaryOps.LOG2) * x.const_like(math.log(2))

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     return grad_output / self.x

 class Exp(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = (x * (1/math.log(2))).unary_op(UnaryOps.EXP2)
+    self.ret = (x * x.const_like(1/math.log(2))).unary_op(UnaryOps.EXP2)
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
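
Log and Exp still go through the base-2 primitives; only the scale factors are now built with const_like. The identities the forwards rely on can be checked with numpy (plain math, not tinygrad code):

    import math
    import numpy as np
    x = np.array([0.5, 1.0, 4.0])
    np.allclose(np.log2(x) * math.log(2), np.log(x))      # Log: ln(x) = log2(x) * ln(2)
    np.allclose(np.exp2(x * (1/math.log(2))), np.exp(x))  # Exp: e**x = 2**(x / ln(2))
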
@@ -56,23 +56,23 @@ class Sqrt(Function):
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output / (self.ret * 2)
+    return grad_output / (self.ret * self.ret.const_like(2))

 # NOTE: the implicit derivative of sigmoid is not stable
 # https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
 # TODO: have the backend automatically find this
 class Sigmoid(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = 1 / (1 + (x * (-1/math.log(2))).unary_op(UnaryOps.EXP2))
+    self.ret = x.const_like(1) / (x.const_like(1) + (x * x.const_like(-1/math.log(2))).unary_op(UnaryOps.EXP2))
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return (self.ret * (1 - self.ret)) * grad_output
+    return (self.ret * (self.ret.const_like(1) - self.ret)) * grad_output

 # ************* reduce ops *************

 class Sum(Function):
-  def forward(self, x:LazyBuffer, new_shape:ShapeType) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
     self.input_shape = x.shape
     return x.reduce_op(ReduceOps.SUM, new_shape)

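Sigmoid keeps the same EXP2 formulation, just with explicit constants. A quick numpy check of the forward identity and of the sigma' = sigma*(1-sigma) form the backward uses (plain math, not tinygrad code):

    import math
    import numpy as np
    x = np.linspace(-4, 4, 9)
    sig = 1 / (1 + np.exp2(x * (-1/math.log(2))))  # same shape as the forward above
    np.allclose(sig, 1 / (1 + np.exp(-x)))         # equals the textbook sigmoid
    grad = sig * (1 - sig)                         # the identity the backward relies on
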
@@ -80,13 +80,13 @@ class Sum(Function):
     return grad_output.expand(self.input_shape)

 class Max(Function):
-  def forward(self, x:LazyBuffer, new_shape:ShapeType) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
     self.x, self.ret = x, x.reduce_op(ReduceOps.MAX, new_shape)
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     # 1s in locations where the max was chosen (can be two locations)
-    max_is_1s = 1.0 - (self.x < self.ret.expand(self.x.shape))
+    max_is_1s = self.x.const_like(1.0) - (self.x < self.ret.expand(self.x.shape))
     div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
     return (max_is_1s / div) * grad_output.expand(self.x.shape)

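Max.backward spreads the incoming gradient evenly over tied maxima: max_is_1s marks every position equal to the max, and div counts the ties. A worked numpy example of that splitting (not tinygrad code):

    import numpy as np
    x = np.array([1., 3., 3., 2.])
    m = x.max()                       # reduced max, broadcast back over x
    max_is_1s = 1.0 - (x < m)         # 1.0 where the max was chosen (two ties here)
    div = max_is_1s.sum()             # number of tied maxima -> 2.0
    grad_x = (max_is_1s / div) * 1.0  # upstream grad of 1.0 -> [0., 0.5, 0.5, 0.]
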
@@ -139,14 +139,14 @@ class Where(Function):

   def backward(self, grad_output:LazyBuffer):
     return None, \
-      self.x.ternary_op(TernaryOps.WHERE, grad_output, 0) if self.needs_input_grad[1] else None, \
-      self.x.ternary_op(TernaryOps.WHERE, 0, grad_output) if self.needs_input_grad[2] else None
+      self.x.ternary_op(TernaryOps.WHERE, grad_output, grad_output.const_like(0)) if self.needs_input_grad[1] else None, \
+      self.x.ternary_op(TernaryOps.WHERE, grad_output.const_like(0), grad_output) if self.needs_input_grad[2] else None

 # ************* movement ops *************

 # NOTE: this is sum in reverse
 class Expand(Function):
-  def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
     self.input_shape = x.shape
     return x.expand(shape)

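Where.backward routes the upstream gradient to whichever branch the condition selected and a zero constant to the other; the diff only changes how that zero is built. The routing itself, in numpy terms (not tinygrad code):

    import numpy as np
    cond = np.array([True, False, True])
    g = np.array([10., 20., 30.])
    grad_x = np.where(cond, g, 0.0)  # analogue of WHERE(cond, grad, const 0)
    grad_y = np.where(cond, 0.0, g)  # analogue of WHERE(cond, const 0, grad)
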
@@ -154,7 +154,7 @@ class Expand(Function):
     return grad_output.reduce_op(ReduceOps.SUM, self.input_shape)

 class Reshape(Function):
-  def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
     self.input_shape = x.shape
     return x.reshape(shape)

@@ -1,13 +1,13 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time
+import time, operator
 from functools import partialmethod, reduce
 from itertools import accumulate, filterfalse
-import operator
 import numpy as np
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast
-from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
 from math import ceil, pi, prod, sqrt, log, cos, copysign, isinf
+
+from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
 from tinygrad.lazy import Device, LazyBuffer
 from tinygrad.ops import LoadOps

@@ -604,6 +604,7 @@ class Tensor:
       if x == 2.0: return self*self
       if x == 1.0: return self
       if x == 0.5: return self.sqrt()
+    if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(log(x)).exp()
     ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(log(abs(x))).exp()
     # correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the power)
     sign = (x * pi).cos() if isinstance(x, Tensor) else cos(x * pi) if not reverse else (self * pi).cos()

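The sign fix-up relies on cos(n*pi) being (-1)**n: the magnitude is computed through abs/log/exp, and the cosine term restores the sign of a negative base raised to an odd power. A worked example with plain math (not the Tensor code path):

    import math
    base, p = -2.0, 3.0
    ar = math.exp(math.log(abs(base)) * p)  # ~ 8.0, magnitude via log/exp as in `ar` above
    sign = math.cos(p * math.pi)            # cos(3*pi) ~ -1.0, odd power keeps the minus
    ar * sign                               # ~ -8.0
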