mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
Add WHERE ternary (or trinary?) op (#1196)
* Rename FusedOps to TernaryOps
* Support ternary broadcast
* Add where llop and mlop
* Make where op work in cstyle codegen
* Don't skip test_inf_where
* Add backward path to where op
* Use bool in cstyle codegen
* Add LLVM where op
* Add numpy where op
* Add torch where op
* Simplify where mlop
* Update documentation
* Forgot a rename
* Merged relevant changes from PR #1195 onto PR #1196
* Add test to cover changes to linearizer.ast_parse for WHERE op

  Without this METAL will try to use ternary op on float4 and fail

* Make where op work in wgsl backend
* Allow ternary ops to be merged
* Make mypy happy

---------

Co-authored-by: Francis Lam <flam@alum.mit.edu>
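For context, a minimal usage sketch of the op this PR adds. self acts as the condition: where it is nonzero the result takes input_, elsewhere other. (A sketch against the tinygrad API at this commit; scalar promotion and rank alignment follow the new trinary broadcast path shown in the diff below.)

  from tinygrad.tensor import Tensor

  cond = Tensor([1.0, 0.0, 1.0])
  a = Tensor([10.0, 20.0, 30.0])
  b = Tensor([-1.0, -2.0, -3.0])

  print(cond.where(a, b).numpy())      # expected: [10. -2. 30.]

  # scalars and mismatched ranks are broadcast by the new trinary path
  print(cond.where(1.0, 0.0).numpy())  # expected: [1. 0. 1.]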
@@ -170,10 +170,6 @@ class Tensor:
   @staticmethod
   def eye(dim, **kwargs): return Tensor([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim)
 
-  def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]):
-    cond = (self != 0.0)
-    return cond * input_ + (1.0 - cond) * other
-
   # ***** rng hlops *****
 
   @staticmethod
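The removed hlop above emulated select arithmetically as cond * input_ + (1.0 - cond) * other. That identity holds for finite values but leaks NaN whenever the unselected branch holds inf, since 0 * inf is NaN; that is presumably why test_inf_where had been skipped, and a dedicated WHERE llop avoids it. A numpy sketch of the difference:

  import numpy as np

  cond = np.array([1.0, 0.0, 1.0])
  a = np.array([np.inf, np.inf, np.inf])
  b = np.array([-1.0, -2.0, -3.0])

  old = cond * a + (1.0 - cond) * b  # arithmetic emulation (removed)
  new = np.where(cond != 0.0, a, b)  # true select (the new llop's semantics)

  print(old)  # [inf nan inf]  <- 0 * inf poisons the unselected lane
  print(new)  # [inf -2. inf]

(For reference, the eye context line above builds an identity matrix by padding [1] out to dim+1 entries, tiling dim rows of [1, 0, ..., 0], flattening, and truncating to dim*dim, so each row's 1 shifts one slot right onto the diagonal.)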
@@ -582,6 +578,31 @@ class Tensor:
   def minimum(self, x:Union[Tensor, float]) -> Tensor: return -((-self).maximum(-x))
   def eq(self, x) -> Tensor: return self._broadcasted(mlops.Equal, x, False)
+
+  # ***** broadcasted trinary mlops *****
+
+  def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]):
+    # TODO: ensure self is non-differentiable, could mess with ceil/float though
+    dtype = self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32
+    x: Tensor = self
+    y: Tensor = Tensor(cast(float, input_), device=self.device, requires_grad=False, dtype=dtype) if input_.__class__ is not Tensor else cast(Tensor, input_)
+    z: Tensor = Tensor(cast(float, other), device=self.device, requires_grad=False, dtype=dtype) if other.__class__ is not Tensor else cast(Tensor, other)
+    if x.shape == y.shape and y.shape == z.shape: return mlops.Where.apply(x, y, z)
+
+    # TODO refactor this code along with the binary version above
+    len_x_shape, len_y_shape, len_z_shape = len(x.shape), len(y.shape), len(z.shape)
+    max_shape = max(len_x_shape, len_y_shape, len_z_shape)
+
+    if len_x_shape != max_shape: x = x.reshape((1,) * (max_shape - len_x_shape) + x.shape)
+    if len_y_shape != max_shape: y = y.reshape((1,) * (max_shape - len_y_shape) + y.shape)
+    if len_z_shape != max_shape: z = z.reshape((1,) * (max_shape - len_z_shape) + z.shape)
+
+    shape_ret = tuple([max(x, y, z) for x, y, z in zip(x.shape, y.shape, z.shape)])
+    if x.shape != shape_ret: x = x.expand(shape_ret)
+    if y.shape != shape_ret: y = y.expand(shape_ret)
+    if z.shape != shape_ret: z = z.expand(shape_ret)
+
+    return mlops.Where.apply(x, y, z)
 
   # ***** binary op wrappers (18 wasted lines to make the typechecker happy) *****
 
   # NOTE: __pow__ and friends are broken in mypyc with the ** operator
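The broadcast in the new where follows the same two-step rule as the binary case: first left-pad the lower-rank shapes with singleton axes up to the max rank (the reshape lines), then expand each axis to the elementwise maximum of the three shapes (the expand lines). The dtype line falls back to float32 when the condition is bool or image-typed, so promoted scalar arguments get a concrete arithmetic dtype. A numpy sketch of the shape logic (broadcast3 is an illustrative name, not tinygrad API):

  import numpy as np

  def broadcast3(x, y, z):
    # step 1: left-pad with singleton axes to the max rank
    max_rank = max(x.ndim, y.ndim, z.ndim)
    x, y, z = (a.reshape((1,) * (max_rank - a.ndim) + a.shape) for a in (x, y, z))
    # step 2: expand every axis to the elementwise max of the three shapes
    shape_ret = tuple(max(sx, sy, sz) for sx, sy, sz in zip(x.shape, y.shape, z.shape))
    return tuple(np.broadcast_to(a, shape_ret) for a in (x, y, z))

  cond = np.array([[1.0], [0.0]])   # (2, 1)
  a = np.array([10.0, 20.0, 30.0])  # (3,)  -> padded to (1, 3)
  b = np.array(0.0)                 # ()    -> padded to (1, 1)
  cx, ay, bz = broadcast3(cond, a, b)  # all broadcast to (2, 3)
  print(np.where(cx != 0.0, ay, bz))
  # [[10. 20. 30.]
  #  [ 0.  0.  0.]]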
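The "Add backward path to where op" item maps to a simple gradient rule: the condition is non-differentiable, and grad_output is routed to whichever branch each element selected. A self-contained numpy sketch of that rule (the actual mlop operates on lazy buffers, whose API is not shown in this hunk):

  import numpy as np

  def where_backward(cond, grad_output):
    grad_cond = None                                      # condition gets no gradient
    grad_input = np.where(cond != 0.0, grad_output, 0.0)  # selected lanes
    grad_other = np.where(cond != 0.0, 0.0, grad_output)  # unselected lanes
    return grad_cond, grad_input, grad_other

  g = np.ones(3)
  print(where_backward(np.array([1.0, 0.0, 1.0]), g))
  # (None, array([1., 0., 1.]), array([0., 1., 0.]))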