Add WHERE ternary (or trinary?) op (#1196)

* Rename FusedOps to TernaryOps

* Support ternary broadcast

* Add where llop and mlop

* Make where op work in cstyle codegen

* Don't skip test_inf_where

* Add backward path to where op

* Use bool in cstyle codegen

* Add LLVM where op

* Add numpy where op

* Add torch where op

* Simplify where mlop

* Update documentation

* Forgot a rename

* Merged relevant changes from PR #1195 onto PR #1196

* Add test to cover changes to linearizer.ast_parse for WHERE op

Without this METAL will try to use ternary op on float4 and fail

* Make where op work in wgsl backend

* Allow ternary ops to be merged

* Make mypy happy

---------

Co-authored-by: Francis Lam <flam@alum.mit.edu>
This commit is contained in:
Adrian Kretz
2023-07-16 09:31:55 +02:00
committed by GitHub
parent 91f797cd52
commit 5a8ad57163
16 changed files with 86 additions and 44 deletions

View File

@@ -1,6 +1,6 @@
# Adding a new accelerator to tinygrad
It's pretty easy to add a new accelerator to tinygrad. All you need to do is implement a total of 26 (optionally 27) low level ops. Then tinygrad takes care of the rest, handling derivatives and syntactic sugar.
It's pretty easy to add a new accelerator to tinygrad. All you need to do is implement a total of 27 (optionally 28) low level ops. Then tinygrad takes care of the rest, handling derivatives and syntactic sugar.
## llops
@@ -12,7 +12,8 @@ reduce_op (SUM, MAX) # A -> B (smaller s
binary_op (ADD, SUB, MUL, DIV, CMPEQ, MAX) # A + A -> A (all the same size)
movement_op (EXPAND, RESHAPE, PERMUTE, PAD, SHRINK, STRIDE) # A -> B (different size)
load_op (EMPTY, RAND, CONST, FROM, CONTIGUOUS, CUSTOM) # -> A (initialize data on device)
fused_op [[optional]] (MULACC) # A * A -> B
ternary_op (WHERE) # A, A, A -> A
ternary_op [[optional]] (MULACC) # A * A -> B
```
## mlops
@@ -23,6 +24,7 @@ Relu, Log, Exp, Sin # unary ops
Sum, Max # reduce ops (with axis argument)
Maximum, Add, Sub, Mul, Pow, Div, Equal # binary ops (no broadcasting, use expand)
Expand, Reshape, Permute, Pad, Shrink, Flip # movement ops
Where # ternary ops
```
These are implemented in [mlops.py](/tinygrad/mlops.py).