From 197dbbda0ff63270ff3d2912439bd6ec9d057305 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Thu, 26 Sep 2024 10:36:33 +0800 Subject: [PATCH] add UnaryOps.NEG + BinaryOps.SUB so process replay can work --- tinygrad/ops.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index cb8cb9f65f..ca6a90682e 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -21,11 +21,11 @@ class FastEnum(IntEnum): # NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division class UnaryOps(FastEnum): """A -> A (elementwise)""" - EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto() # noqa: E702 + EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702 class BinaryOps(FastEnum): """A + A -> A (elementwise)""" ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702 - SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto() # noqa: E702 + SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto() # noqa: E702 class TernaryOps(FastEnum): """A + A + A -> A (elementwise)""" WHERE = auto(); MULACC = auto() # noqa: E702 @@ -304,7 +304,8 @@ python_alu: Dict[Op, Callable] = { UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x), UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x), UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, - BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift, BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, + UnaryOps.NEG: operator.neg, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, + BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift, BinaryOps.MUL: operator.mul, BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt, BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_, BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,