mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
from_number_like to fix div issue
This commit is contained in:
@@ -66,7 +66,7 @@ class CLProgram:
|
||||
|
||||
class GPUBuffer:
|
||||
code_for_op = {
|
||||
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.SIGN: "sign(A)", UnaryOps.RECIPROCAL: "(1.0/A)",
|
||||
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.SIGN: "sign(A)", UnaryOps.RECIPROCAL: "((float)1.0/A)",
|
||||
BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
|
||||
ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)"
|
||||
}
|
||||
|
||||
@@ -242,10 +242,13 @@ class Tensor:
|
||||
|
||||
# ***** broadcasted binary ops *****
|
||||
|
||||
# TODO: cache common number Tensors?
|
||||
def from_number_like(self, t): return Tensor([t], device=self.device, requires_grad=False) if not isinstance(t, Tensor) else t
|
||||
|
||||
@staticmethod
|
||||
def broadcasted(fxn, x, y):
|
||||
tt = [arg for arg in [x,y] if isinstance(arg, Tensor)][0] # this is the prototype tensor
|
||||
x,y = [Tensor([t], device=tt.device, requires_grad=False) if not isinstance(t, Tensor) else t for t in [x,y]]
|
||||
x,y = [tt.from_number_like(t) for t in [x,y]]
|
||||
x,y = [t.reshape(list(t.shape) + [1]*(max(len(x.shape), len(y.shape))-len(t.shape))) for t in [x,y]]
|
||||
shape_ret = tuple(max(sx, sy) for sx,sy in zip(x.shape, y.shape))
|
||||
return fxn(x.expand(shape_ret), y.expand(shape_ret))
|
||||
@@ -255,7 +258,7 @@ class Tensor:
|
||||
def sub(self, x): return Tensor.broadcasted(Tensor._sub, self, x)
|
||||
def mul(self, x): return Tensor.broadcasted(Tensor._mul, self, x)
|
||||
def pow(self, x): return Tensor.broadcasted(Tensor._pow, self, x)
|
||||
def div(self, y): return self * y.reciprocal()
|
||||
def div(self, y): return self * self.from_number_like(y).reciprocal()
|
||||
__truediv__ = div
|
||||
|
||||
# ***** functional nn ops *****
|
||||
|
||||
Reference in New Issue
Block a user