diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index 5dc95d28e2..001fca6f28 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -78,8 +78,8 @@ class Sqrt(Function):
 # TODO: have the backend automatically find this
 class Sigmoid(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.cast(ftype:=least_upper_float(x.dtype)).const(1).e(
-      BinaryOps.DIV, x.cast(ftype).const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.cast(ftype).const(-1/math.log(2))).e(UnaryOps.EXP2)))
+    x = x.cast(least_upper_float(x.dtype))
+    self.ret = x.const(1).e(BinaryOps.DIV, x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)))
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
@@ -89,11 +89,13 @@ class Sigmoid(Function):
 
 class Less(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.e(BinaryOps.CMPLT, y)
+    output_dtype = least_upper_dtype(x.dtype, y.dtype)
+    return x.cast(output_dtype).e(BinaryOps.CMPLT, y.cast(output_dtype))
 
 class Xor(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.e(BinaryOps.XOR, y)
+    output_dtype = least_upper_dtype(x.dtype, y.dtype)
+    return x.cast(output_dtype).e(BinaryOps.XOR, y.cast(output_dtype))
 
 class Add(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py
index 6e74a1272e..5ebf15f03f 100644
--- a/tinygrad/runtime/ops_cpu.py
+++ b/tinygrad/runtime/ops_cpu.py
@@ -10,9 +10,6 @@ def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple
 
 # TODO: this should be global infrastructure
 def output_type(x, y): return x.dtype if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority else y.dtype
-def match_types(x, y):
-  up = output_type(x, y)
-  return x.astype(up, copy=False), y.astype(up, copy=False)
 
 def einsum_mulacc(einsum, get_strides, expand):
   def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])
@@ -24,20 +21,23 @@ def einsum_mulacc(einsum, get_strides, expand):
     return expand(ret.reshape([(1 if i not in a_axes and i not in b_axes else s) for i,s in enumerate(new_shape)]), new_shape)
   return mulacc
 
+def as_strided(x, arg):
+  return np.ndarray(shape=arg[0], dtype=x.dtype, buffer=np.require(x, requirements='C'),
+                    offset=arg[2]*x.dtype.itemsize, strides=tuple(y*x.dtype.itemsize for y in arg[1]))
+
 numpy_fxn_for_op: Dict[Op, Callable] = {
   BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np),
   UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
   UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
   UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
-  BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x
diff --git a/tinygrad/runtime/ops_torch.py b/tinygrad/runtime/ops_torch.py
--- a/tinygrad/runtime/ops_torch.py
+++ b/tinygrad/runtime/ops_torch.py
@@ ... @@
 def output_type(x, y): return x.dtype if type_map[x.dtype].priority > type_map[y.dtype].priority else y.dtype
-def match_types(x, y, disallow_bool=False):
-  up = output_type(x, y)
-  if disallow_bool and up == torch.bool: up = torch.float
-  return x.type(up), y.type(up)
 
 def as_strided(x, arg):
   if any(i < 0 for i in arg[1]):
@@ -33,11 +30,11 @@ torch_fxn_for_op: Dict[Op, Callable] = {
   UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(inverse_type_map[y[0]]),
   UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x),
   BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: (x
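
A note on the Sigmoid change: the forward pass keeps the same exp2-based formulation, it only hoists the cast so `x` is promoted to `least_upper_float(x.dtype)` once instead of re-casting inside the expression. The identity it relies on is exp(-x) = exp2(x * (-1/ln 2)), so sigmoid(x) = 1 / (1 + exp2(x * (-1/log 2))). A quick numerical sanity check of that identity in plain numpy (illustrative only, not tinygrad code):

```python
import math
import numpy as np

x = np.linspace(-8, 8, 9, dtype=np.float32)

# Reference sigmoid.
ref = 1.0 / (1.0 + np.exp(-x))

# exp2 formulation used by Sigmoid.forward: exp(-x) == exp2(x * (-1/ln 2)).
via_exp2 = 1.0 / (1.0 + np.exp2(x * (-1.0 / math.log(2))))

assert np.allclose(ref, via_exp2, atol=1e-6)
```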
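
The Less/Xor changes move dtype promotion out of the runtimes: both operands are cast to `least_upper_dtype(x.dtype, y.dtype)` in mlops, which is what allows `match_types` to be deleted from ops_cpu.py and ops_torch.py. A minimal sketch of that promote-then-apply pattern, using a hypothetical priority table in plain numpy (`least_upper_dtype` below is a stand-in, not tinygrad's implementation):

```python
import numpy as np

# Hypothetical priority table standing in for tinygrad's dtype promotion rules.
PRIORITY = {np.dtype(np.bool_): 0, np.dtype(np.int32): 1, np.dtype(np.float32): 2}

def least_upper_dtype(a: np.dtype, b: np.dtype) -> np.dtype:
  # Pick the higher-priority dtype, analogous to tinygrad's least_upper_dtype.
  return a if PRIORITY[a] >= PRIORITY[b] else b

def less(x: np.ndarray, y: np.ndarray) -> np.ndarray:
  # Cast both operands up front so the element-wise kernel sees matching dtypes.
  up = least_upper_dtype(x.dtype, y.dtype)
  return x.astype(up) < y.astype(up)

print(less(np.array([1, 2, 3], dtype=np.int32), np.array([1.5], dtype=np.float32)))  # [ True False False]
```

With the casts done once at the graph level, each backend's CMPLT/XOR entry no longer has to normalize its inputs' dtypes itself.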
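
The new `as_strided` helper added to ops_cpu.py reads `arg` as a `(shape, strides, offset)` triple in element units and builds a zero-copy numpy view over the C-contiguous buffer, scaling by `itemsize` to get byte offsets and strides. An illustrative example of that mapping, with hypothetical shape/stride/offset values:

```python
import numpy as np

x = np.arange(6, dtype=np.float32)          # buffer: [0, 1, 2, 3, 4, 5]
shape, strides, offset = (2, 2), (1, 2), 1  # element units, like arg[0], arg[1], arg[2]

# Same construction as the helper in the diff: element counts scaled to bytes.
view = np.ndarray(shape=shape, dtype=x.dtype, buffer=np.require(x, requirements='C'),
                  offset=offset * x.dtype.itemsize,
                  strides=tuple(s * x.dtype.itemsize for s in strides))
print(view)  # [[1. 3.]
             #  [2. 4.]]
```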