mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 23:38:58 -05:00
add UnaryOps.NEG + BinaryOps.SUB so process replay can work
This commit is contained in:
@@ -21,11 +21,11 @@ class FastEnum(IntEnum):
|
||||
# NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
|
||||
class UnaryOps(FastEnum):
|
||||
"""A -> A (elementwise)"""
|
||||
EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto() # noqa: E702
|
||||
EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
|
||||
class BinaryOps(FastEnum):
|
||||
"""A + A -> A (elementwise)"""
|
||||
ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
|
||||
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto() # noqa: E702
|
||||
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto() # noqa: E702
|
||||
class TernaryOps(FastEnum):
|
||||
"""A + A + A -> A (elementwise)"""
|
||||
WHERE = auto(); MULACC = auto() # noqa: E702
|
||||
@@ -304,7 +304,8 @@ python_alu: Dict[Op, Callable] = {
|
||||
UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
|
||||
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
|
||||
UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
|
||||
BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift, BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
|
||||
UnaryOps.NEG: operator.neg, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub,
|
||||
BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift, BinaryOps.MUL: operator.mul,
|
||||
BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
|
||||
BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_,
|
||||
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
|
||||
|
||||
Reference in New Issue
Block a user