mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Add WHERE ternary (or trinary?) op (#1196)
* Rename FusedOps to TernaryOps * Support ternary broadcast * Add where llop and mlop * Make where op work in cstyle codegen * Don't skip test_inf_where * Add backward path to where op * Use bool in cstyle codegen * Add LLVM where op * Add numpy where op * Add torch where op * Simplify where mlop * Update documentation * Forgot a rename * Merged relevant changes from PR #1195 onto PR #1196 * Add test to cover changes to linearizer.ast_parse for WHERE op Without this METAL will try to use ternary op on float4 and fail * Make where op work in wgsl backend * Allow ternary ops to be merged * Make mypy happy --------- Co-authored-by: Francis Lam <flam@alum.mit.edu>
This commit is contained in:
@@ -2,7 +2,7 @@ import numpy as np
|
||||
import operator
|
||||
from typing import Callable, Dict, Tuple, Optional
|
||||
from tinygrad.helpers import dtypes, DType
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, FusedOps, Op, Interpreted
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op, Interpreted
|
||||
from tinygrad.runtime.lib import RawBuffer
|
||||
|
||||
def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
@@ -37,7 +37,8 @@ numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
|
||||
BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
|
||||
MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
|
||||
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
|
||||
FusedOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, *match_types(a.copy(), b.copy()), optimize=True), lambda x: x.strides, np.broadcast_to),
|
||||
TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, *match_types(a.copy(), b.copy()), optimize=True), lambda x: x.strides, np.broadcast_to),
|
||||
TernaryOps.WHERE: np.where,
|
||||
}}
|
||||
|
||||
class RawNumpyBuffer(RawBuffer):
|
||||
|
||||
Reference in New Issue
Block a user