Add from_number_like so div works with plain-number divisors

This commit is contained in:
George Hotz
2022-09-03 16:19:16 -07:00
parent c2a030fe55
commit 39e1d23c88
2 changed files with 6 additions and 3 deletions

View File

@@ -66,7 +66,7 @@ class CLProgram:
class GPUBuffer:
# Templates mapping each op to its OpenCL C expression; "A"/"B" are operand
# placeholders and "acc" is the reduce accumulator, substituted at codegen time.
# NOTE: the diff showed both versions of the RECIPROCAL row; only the fixed
# post-commit line is kept: the literal is cast to float so the expression
# stays single-precision instead of promoting to double (which many OpenCL
# devices don't enable by default).
code_for_op = {
  UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.SIGN: "sign(A)", UnaryOps.RECIPROCAL: "((float)1.0/A)",
  BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
  ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)"
}

View File

@@ -242,10 +242,13 @@ class Tensor:
# ***** broadcasted binary ops *****
# TODO: cache common number Tensors?
def from_number_like(self, t): return Tensor([t], device=self.device, requires_grad=False) if not isinstance(t, Tensor) else t
@staticmethod
def broadcasted(fxn, x, y):
    """Apply binary op *fxn* to x and y with numpy-style broadcasting.

    At least one of x, y must be a Tensor; the other may be a plain number,
    which is promoted via from_number_like onto the Tensor's device.
    Shapes are right-padded with 1s to equal rank, then each axis is
    expanded to the pairwise max before calling fxn.
    NOTE: the diff showed both the old inline conversion and the new
    from_number_like call; only the post-commit line is kept.
    """
    tt = [arg for arg in [x,y] if isinstance(arg, Tensor)][0]  # this is the prototype tensor
    x,y = [tt.from_number_like(t) for t in [x,y]]
    # pad trailing singleton dims so both operands have the same rank
    x,y = [t.reshape(list(t.shape) + [1]*(max(len(x.shape), len(y.shape))-len(t.shape))) for t in [x,y]]
    shape_ret = tuple(max(sx, sy) for sx,sy in zip(x.shape, y.shape))
    return fxn(x.expand(shape_ret), y.expand(shape_ret))
@@ -255,7 +258,7 @@ class Tensor:
def sub(self, x): return Tensor.broadcasted(Tensor._sub, self, x)
def mul(self, x): return Tensor.broadcasted(Tensor._mul, self, x)
def pow(self, x): return Tensor.broadcasted(Tensor._pow, self, x)
def div(self, y):
    """Element-wise division: self / y, implemented as self * (1/y).

    NOTE: the diff showed both the old body (which called .reciprocal()
    directly on y and so broke for plain numbers) and the fixed one; only
    the post-commit version is kept: y is first promoted to a Tensor via
    from_number_like so numeric divisors work.
    """
    return self * self.from_number_like(y).reciprocal()
__truediv__ = div
# ***** functional nn ops *****