Revert "bring reciprocal back (#3687)" (#3692)

This reverts commit bcf6fbd3b2.
This commit is contained in:
George Hotz
2024-03-11 15:55:14 -07:00
committed by GitHub
parent ef44c8959b
commit 3af1c1051a
7 changed files with 5 additions and 15 deletions

View File

@@ -36,7 +36,6 @@ python_alu = {
UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))),
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin,
UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.nan,
BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.XOR: operator.xor,
BinaryOps.MAX: max, BinaryOps.CMPEQ: operator.eq, BinaryOps.CMPLT: operator.lt, BinaryOps.MOD: operator.mod,
BinaryOps.DIV: lambda x,y: int(x/y) if isinstance(x, int) else (x/y if y != 0 else math.nan),

View File

@@ -32,13 +32,6 @@ class Neg(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer: return x.e(UnaryOps.NEG)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(UnaryOps.NEG)
class Reciprocal(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.ret = x.e(UnaryOps.RECIP)
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.ret).e(BinaryOps.MUL, self.ret)
class Sin(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.x = x
@@ -83,7 +76,7 @@ class Sqrt(Function):
# TODO: have the backend automatically find this
class Sigmoid(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.ret = x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)).e(UnaryOps.RECIP)
self.ret = x.const(1).e(BinaryOps.DIV, x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)))
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:

View File

@@ -11,7 +11,7 @@ from dataclasses import dataclass
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
# NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
# NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto(); RECIP = auto() # noqa: E702
class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto() # noqa: E702
class BinaryOps(Enum):
ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPEQ = auto(); XOR = auto() # noqa: E702
class TernaryOps(Enum): WHERE = auto() # noqa: E702

View File

@@ -188,7 +188,6 @@ class PTXLanguage(AssemblyLanguage):
UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};",
UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
UnaryOps.RECIP: lambda d,a,dt,name: f"rcp.rn.{name} {d}, {a};",
BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
BinaryOps.SUB: lambda d,a,b,dt,name: f"sub.{name} {d}, {a}, {b};",
BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",

View File

@@ -30,7 +30,7 @@ class CStyleLanguage(NamedTuple):
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})",
BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPEQ: lambda a,b,dtype: f"({a}=={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
UnaryOps.RECIP: lambda x,dtype: f"(1.0/{x})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
# returns a str expression of the casted xs with the given type
def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
@@ -133,7 +133,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
val = lang.code_for_op[args](*operands, dtype)
assert child_count[u] != 0, f"childless ALU op found {u}"
# TODO: fix index rendering issue. fix clang nested max macro issue
if child_count[u] <= 1 and args not in {UnaryOps.RECIP, BinaryOps.MAX} and not getenv("EXPAND_SSA"): r[u] = val
if child_count[u] <= 1 and args != BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
else: kk(f"{dtype.name} {ssa(u,'alu')} = {val};")
elif uop is UOps.SPECIAL:
kk(f"int {args[1]} = {lang.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")

View File

@@ -16,7 +16,6 @@ code_for_op: Final[Dict[Op, Callable]] = {
UnaryOps.LOG2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=MFLAGS),
UnaryOps.SIN: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=MFLAGS),
UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(ir.Constant(x.type, 1), x, flags=MFLAGS),
BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501
BinaryOps.SUB: lambda builder, x, y, dtype: builder.sub(x, y) if dtypes.is_int(dtype) else builder.fsub(x, y, flags=MFLAGS),
BinaryOps.MUL: lambda builder, x, y, dtype: builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype) else builder.fmul(x, y, flags=MFLAGS), # noqa: E501

View File

@@ -789,7 +789,7 @@ class Tensor:
def clip(self, min_, max_): return self.maximum(min_).minimum(max_)
def abs(self): return self.relu() + (-self).relu()
def sign(self): return ((self.float()) / (self.float().abs() + 1e-12)).cast(self.dtype)
def reciprocal(self): return mlops.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))
def reciprocal(self): return 1.0/self
# ***** activation functions (unary) *****