teeny changes (#1589)

* teeny changes

* import order
George Hotz
2023-08-20 13:38:38 -07:00
committed by GitHub
parent 012ee7d162
commit d627349af0
4 changed files with 35 additions and 39 deletions

tinygrad/helpers.py

@@ -4,7 +4,6 @@ import numpy as np
 from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any, Iterable
 from math import prod # noqa: F401 # pylint:disable=unused-import
-ShapeType = Tuple[int, ...]
 # NOTE: helpers is not allowed to import from anything else in tinygrad
 OSX = platform.system() == "Darwin"
 CI = os.getenv("CI", "") != ""
@@ -93,7 +92,7 @@ class dtypes:
   def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
   @staticmethod
   def fields() -> Dict[str, DType]: return DTYPES_DICT
-  bool: Final[DType] = DType(0, 1, "bool", bool)
+  bool: Final[DType] = DType(0, 1, "bool", np.bool_)
   float16: Final[DType] = DType(0, 2, "half", np.float16)
   half = float16
   float32: Final[DType] = DType(4, 4, "float", np.float32)
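An aside on the np.bool_ swap above, with a minimal check (not part of the commit): np.dtype() canonicalizes both the Python builtin bool and numpy's np.bool_ to the dtype named "bool", which is exactly the key from_np() looks up, so the swap keeps DTYPES_DICT consistently numpy-typed without changing lookups.

import numpy as np

# both the builtin and the numpy scalar type map to the same dtype name,
# which is the DTYPES_DICT key that dtypes.from_np() resolves
assert np.dtype(bool).name == "bool"
assert np.dtype(np.bool_).name == "bool"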

tinygrad/lazy.py

@@ -195,20 +195,20 @@ class LazyBuffer:
     assert not arg[1] or self.dtype.itemsize == arg[0].itemsize, "can't bitcast mismatched dtype itemsizes"
     return elementwise_op(UnaryOps.CAST, self, arg=arg) if self.dtype != arg[0] else self
   def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self)
-  def binary_op(self:LazyBuffer, op:BinaryOps, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(op, self, y)
-  def ternary_op(self:LazyBuffer, op:TernaryOps, y:Union[LazyBuffer, float, int], z:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(op, self, y, z)
+  def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y)
+  def ternary_op(self:LazyBuffer, op:TernaryOps, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y, z)
-  def __add__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.ADD, self, y)
-  def __radd__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.ADD, y, self)
-  def __mul__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.MUL, self, y)
-  def __rmul__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.MUL, y, self)
-  def __truediv__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.DIV, self, y)
-  def __rtruediv__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.DIV, y, self)
-  def __sub__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.SUB, self, y)
-  def __rsub__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.SUB, y, self)
-  def __lt__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.CMPLT, self, y)
-  def __gt__(self, y:Union[LazyBuffer, float, int]) -> LazyBuffer: return elementwise_op(BinaryOps.CMPLT, y, self)
-  def __neg__(self) -> LazyBuffer: return 0.0-self
+  def __add__(self, y:LazyBuffer) -> LazyBuffer: return elementwise_op(BinaryOps.ADD, self, y)
+  def __radd__(self, y:LazyBuffer) -> LazyBuffer: return elementwise_op(BinaryOps.ADD, y, self)
+  def __mul__(self, y:LazyBuffer) -> LazyBuffer: return elementwise_op(BinaryOps.MUL, self, y)
+  def __rmul__(self, y:LazyBuffer) -> LazyBuffer: return elementwise_op(BinaryOps.MUL, y, self)
+  def __truediv__(self, y:LazyBuffer) -> LazyBuffer: return elementwise_op(BinaryOps.DIV, self, y)
+  def __rtruediv__(self, y:LazyBuffer) -> LazyBuffer: return elementwise_op(BinaryOps.DIV, y, self)
+  def __sub__(self, y:LazyBuffer) -> LazyBuffer: return elementwise_op(BinaryOps.SUB, self, y)
+  def __rsub__(self, y:LazyBuffer) -> LazyBuffer: return elementwise_op(BinaryOps.SUB, y, self)
+  def __lt__(self, y:LazyBuffer) -> LazyBuffer: return elementwise_op(BinaryOps.CMPLT, self, y)
+  def __gt__(self, y:LazyBuffer) -> LazyBuffer: return elementwise_op(BinaryOps.CMPLT, y, self)
+  def __neg__(self) -> LazyBuffer: return self.const_like(0.0)-self
   def contiguous(self:LazyBuffer) -> LazyBuffer:
     if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
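Worth noting in the hunk above: there is no CMPGT op, so __gt__ reuses BinaryOps.CMPLT with swapped operands, since a > b is exactly b < a. A one-line float check of that equivalence:

# a > b is evaluated as b < a, so a single CMPLT op covers both comparisons
a, b = 3.0, 2.0
assert (a > b) == (b < a)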
@@ -316,11 +316,7 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
       new_srcs.append(x)
   return tuple(new_srcs)
-def elementwise_op(op:Union[UnaryOps, BinaryOps, TernaryOps], *_srcs:Union[LazyBuffer, float, int], arg:Optional[Any]=None) -> LazyBuffer:
-  # make them all LazyBuffers
-  first_src = [x for x in _srcs if isinstance(x, LazyBuffer)][0]
-  srcs:Tuple[LazyBuffer, ...] = tuple(x if isinstance(x, LazyBuffer) else first_src.const_like(x) for x in _srcs)
+def elementwise_op(op:Union[UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
   # if we are separated from other binary ops by movement ops, we push those movement ops above those binaryops
   if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs)
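With the promotion logic deleted, elementwise_op now assumes every source is already a LazyBuffer, and call sites wrap constants via const_like. A minimal sketch of that calling convention, using a hypothetical Buf class that mirrors the relevant bit of LazyBuffer's API (not tinygrad's real classes):

class Buf:
  def __init__(self, val: float): self.val = val
  # hypothetical stand-in for LazyBuffer.const_like: a constant that shares this buffer's metadata
  def const_like(self, x) -> "Buf": return Buf(float(x))

def elementwise_op(op, *srcs: "Buf") -> "Buf":
  # no promotion anymore: every source must already be a Buf
  assert all(isinstance(s, Buf) for s in srcs), "wrap scalars with const_like()"
  return Buf(op(*[s.val for s in srcs]))

x = Buf(3.0)
out = elementwise_op(max, x, x.const_like(0))  # was: elementwise_op(MAX, x, 0)
assert out.val == 3.0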

tinygrad/mlops.py

@@ -1,5 +1,5 @@
 from typing import Tuple, Optional
-from tinygrad.helpers import argsort, ShapeType, DType
+from tinygrad.helpers import argsort, DType
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
@@ -23,28 +23,28 @@ class Sin(Function):
     self.x = x
     return x.unary_op(UnaryOps.SIN)
   def backward(self, grad:LazyBuffer) -> LazyBuffer:
-    return ((math.pi / 2) - self.x).unary_op(UnaryOps.SIN) * grad
+    return (self.x.const_like(math.pi / 2) - self.x).unary_op(UnaryOps.SIN) * grad
 # NOTE: maximum(x, 0) behaves differently where x=0
 class Relu(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.binary_op(BinaryOps.MAX, 0)
+    self.ret = x.binary_op(BinaryOps.MAX, x.const_like(0))
     return self.ret
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return (0 < self.ret) * grad_output
+    return (self.ret.const_like(0) < self.ret) * grad_output
 class Log(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return x.unary_op(UnaryOps.LOG2) * math.log(2)
+    return x.unary_op(UnaryOps.LOG2) * x.const_like(math.log(2))
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     return grad_output / self.x
 class Exp(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = (x * (1/math.log(2))).unary_op(UnaryOps.EXP2)
+    self.ret = (x * x.const_like(1/math.log(2))).unary_op(UnaryOps.EXP2)
     return self.ret
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
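Log and Exp above route through LOG2/EXP2 plus a constant rescale, now built with const_like. The underlying identities, checked with plain floats standing in for buffers (a sketch, not tinygrad code):

from math import exp, log, log2, isclose

x = 1.7
assert isclose(log(x), log2(x) * log(2))        # LOG2, then * const_like(ln 2)
assert isclose(exp(x), 2 ** (x * (1/log(2))))   # * const_like(1/ln 2), then EXP2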
@@ -56,23 +56,23 @@ class Sqrt(Function):
     return self.ret
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output / (self.ret * 2)
+    return grad_output / (self.ret * self.ret.const_like(2))
 # NOTE: the implicit derivative of sigmoid is not stable
 # https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
 # TODO: have the backend automatically find this
 class Sigmoid(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = 1 / (1 + (x * (-1/math.log(2))).unary_op(UnaryOps.EXP2))
+    self.ret = x.const_like(1) / (x.const_like(1) + (x * x.const_like(-1/math.log(2))).unary_op(UnaryOps.EXP2))
     return self.ret
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return (self.ret * (1 - self.ret)) * grad_output
+    return (self.ret * (self.ret.const_like(1) - self.ret)) * grad_output
 # ************* reduce ops *************
 class Sum(Function):
-  def forward(self, x:LazyBuffer, new_shape:ShapeType) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
     self.input_shape = x.shape
     return x.reduce_op(ReduceOps.SUM, new_shape)
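The Sigmoid forward above is the logistic function written in EXP2 form, 1/(1 + 2^(-x/ln 2)) = 1/(1 + e^(-x)), and the backward uses the explicit derivative sigma*(1 - sigma) rather than the unstable implicit one the NOTE warns about. Both identities checked with plain floats (a sketch, not tinygrad code):

from math import exp, log, isclose

x = 0.3
sig = 1 / (1 + 2 ** (x * (-1/log(2))))  # the EXP2 form used in forward
assert isclose(sig, 1 / (1 + exp(-x)))  # plain logistic function
assert isclose(sig * (1 - sig), exp(-x) / (1 + exp(-x))**2)  # backward's sigma*(1-sigma)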
@@ -80,13 +80,13 @@ class Sum(Function):
     return grad_output.expand(self.input_shape)
 class Max(Function):
-  def forward(self, x:LazyBuffer, new_shape:ShapeType) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
     self.x, self.ret = x, x.reduce_op(ReduceOps.MAX, new_shape)
     return self.ret
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     # 1s in locations where the max was chosen (can be two locations)
-    max_is_1s = 1.0 - (self.x < self.ret.expand(self.x.shape))
+    max_is_1s = self.x.const_like(1.0) - (self.x < self.ret.expand(self.x.shape))
     div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
     return (max_is_1s / div) * grad_output.expand(self.x.shape)
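To see what Max.backward computes above, here is the same arithmetic with numpy arrays standing in for LazyBuffers (a sketch, not tinygrad code). When the max value is attained in more than one location, each tied location gets an equal share of the incoming gradient:

import numpy as np

x = np.array([1.0, 3.0, 3.0])
ret = x.max()                      # reduce_op(ReduceOps.MAX, ...) -> 3.0
max_is_1s = 1.0 - (x < ret)        # [0., 1., 1.]: 1 where the max was chosen
div = max_is_1s.sum()              # 2.0: how many locations tied for the max
grad_x = (max_is_1s / div) * 1.0   # with grad_output == 1 -> [0., 0.5, 0.5]
assert grad_x.tolist() == [0.0, 0.5, 0.5]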
@@ -139,14 +139,14 @@ class Where(Function):
   def backward(self, grad_output:LazyBuffer):
     return None, \
-           self.x.ternary_op(TernaryOps.WHERE, grad_output, 0) if self.needs_input_grad[1] else None, \
-           self.x.ternary_op(TernaryOps.WHERE, 0, grad_output) if self.needs_input_grad[2] else None
+           self.x.ternary_op(TernaryOps.WHERE, grad_output, grad_output.const_like(0)) if self.needs_input_grad[1] else None, \
+           self.x.ternary_op(TernaryOps.WHERE, grad_output.const_like(0), grad_output) if self.needs_input_grad[2] else None
 # ************* movement ops *************
 # NOTE: this is sum in reverse
 class Expand(Function):
-  def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
     self.input_shape = x.shape
     return x.expand(shape)
@@ -154,7 +154,7 @@ class Expand(Function):
     return grad_output.reduce_op(ReduceOps.SUM, self.input_shape)
 class Reshape(Function):
-  def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
     self.input_shape = x.shape
     return x.reshape(shape)
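Where.backward above routes the incoming gradient to whichever input the condition selected, substituting a const_like(0) for the other side; with numpy's where standing in for the ternary op (a sketch, not tinygrad code):

import numpy as np

cond = np.array([True, False, True])
grad = np.array([10.0, 20.0, 30.0])
grad_y = np.where(cond, grad, 0.0)   # ternary_op(WHERE, grad, const_like(0))
grad_z = np.where(cond, 0.0, grad)   # ternary_op(WHERE, const_like(0), grad)
assert grad_y.tolist() == [10.0, 0.0, 30.0]
assert grad_z.tolist() == [0.0, 20.0, 0.0]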

tinygrad/tensor.py

@@ -1,13 +1,13 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time
+import time, operator
 from functools import partialmethod, reduce
 from itertools import accumulate, filterfalse
-import operator
 import numpy as np
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast
-from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
 from math import ceil, pi, prod, sqrt, log, cos, copysign, isinf
+from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
 from tinygrad.lazy import Device, LazyBuffer
 from tinygrad.ops import LoadOps
@@ -604,6 +604,7 @@ class Tensor:
       if x == 2.0: return self*self
       if x == 1.0: return self
       if x == 0.5: return self.sqrt()
+    if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(log(x)).exp()
     ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(log(abs(x))).exp()
     # correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the power)
     sign = (x * pi).cos() if isinstance(x, Tensor) else cos(x * pi) if not reverse else (self * pi).cos()
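The added line is a fast path for reverse pow with a positive scalar base, i.e. c ** tensor with c > 0: since the result is always positive, it can be rewritten as exp(t * log c) with no sign correction needed. The identity with plain floats standing in for Tensors (a sketch, not tinygrad code):

from math import exp, log, isclose

c, t = 2.0, 3.0                          # positive scalar base, "tensor" exponent
assert isclose(c ** t, exp(t * log(c)))  # matches self.mul(log(x)).exp() above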