Add WHERE ternary (or trinary?) op (#1196)

* Rename FusedOps to TernaryOps

* Support ternary broadcast

* Add where llop and mlop

* Make where op work in cstyle codegen

* Don't skip test_inf_where

* Add backward path to where op

* Use bool in cstyle codegen

* Add LLVM where op

* Add numpy where op

* Add torch where op

* Simplify where mlop

* Update documentation

* Forgot a rename

* Merged relevant changes from PR #1195 onto PR #1196

* Add test to cover changes to linearizer.ast_parse for WHERE op

Without this METAL will try to use ternary op on float4 and fail

* Make where op work in wgsl backend

* Allow ternary ops to be merged

* Make mypy happy

---------

Co-authored-by: Francis Lam <flam@alum.mit.edu>
This commit is contained in:
Adrian Kretz
2023-07-16 09:31:55 +02:00
committed by GitHub
parent 91f797cd52
commit 5a8ad57163
16 changed files with 86 additions and 44 deletions

View File

@@ -2,7 +2,7 @@ import numpy as np
import operator
from typing import Callable, Dict, Tuple, Optional
from tinygrad.helpers import dtypes, DType
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, FusedOps, Op, Interpreted
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op, Interpreted
from tinygrad.runtime.lib import RawBuffer
def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
@@ -37,7 +37,8 @@ numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
FusedOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, *match_types(a.copy(), b.copy()), optimize=True), lambda x: x.strides, np.broadcast_to),
TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, *match_types(a.copy(), b.copy()), optimize=True), lambda x: x.strides, np.broadcast_to),
TernaryOps.WHERE: np.where,
}}
class RawNumpyBuffer(RawBuffer):