mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
Convert BinaryOps.DIV to UnaryOps.RECIP and BinaryOps.IDIV (#4887)
* Create UnaryOps.RECIP and BinaryOps.IDIV and changing uses of BinaryOps.DIV * Delete unused import * Add cstyle renderer * Fix formatting text * Fix test error due to bad implementation of renderer * Add PTX support * Add RECIP to LLVMIR * Remove BinaryOps.DIV from symbolic test * Change some test and fix C floor division * Change references to DIV for the RECIP or IDIV * Add mimic idiv for symbolic test * Restore floor * Mimic idiv * cast to int * Fix some test and renderer * Remove DIV for render nodes * Resolve issue with div * Add TestRenderer * Fix test * fix error * Fix PAD test * Fix div implementation * Remove DIV * Add upcast to rshift, due to use of MUL and RECIP on DIV * Fix linter * Remove complete BinaryOps.DIV * Fix lint * Fix some test * Revert mul modification * Fix tests * Fix CLANG for uops * Revert IDIV function * Minor fix * modify pattern matching rule to support nan * Fix UNSAFE_PADS_OPS to add UnaryOps.RECIP * Remove const folding for IDIV and fix PTX * Complete remove IDIV from extra * Remove test_div from TestFloatUOps due to test on recip * Fix linearizer * fix * Fix test_22 * Fix llvm * Apply trunc function for llvmlit * use floor instead of trunc * Use correct type * Generate new fuzz db * Fix rshift, do not cast to float to support idiv * Return upcast=false to rshift * Add to unsafepad BinaryOps.IDIV * Remove RECIP override for CUDA * add atol / rtol for the test * Remove cast to int on IDIV * Regenerate sops * delete sops.gz * regenerate * regenerate * regenerate * Reduce margins * pass atol and rtol as parametersg for _test_metrics * regenerated dataset * Regenerate * Remove duplicated * Revert changes on extra * Remove changes extra and NOQA for test * Remove E501 * Remove and change line * Remove E501 * Fix atan2 * Revert import and E501 * Remove E501 * Add hrcp to halp ops * Remove 1 of hrcp * Remove last DIV and add type check on uops for IDIV * Fix new tests * Fix tests and custom function * Regenerate dataset * Regenerate 
dataset * Revert dataset * Change generate dataset script * Remove line * Change IDIV, type checker validate if x,y and z are int --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
Binary file not shown.
@@ -18,7 +18,10 @@ python3 examples/beautiful_cartpole.py
|
||||
python3 examples/mlperf/model_spec.py
|
||||
python3 examples/yolov8.py ./test/models/efficientnet/Chicken.jpg
|
||||
examples/openpilot/go.sh
|
||||
BIG=1 MPS=1 pytest test/ --ignore=test/test_fusion_op.py --ignore=test/test_linearizer_failures.py
|
||||
JIT=2 BIG=1 MPS=1 pytest test/ --ignore=test/test_fusion_op.py --ignore=test/test_linearizer_failures.py --ignore=test/test_gc.py --ignore=test/test_speed_v_torch.py --ignore=test/test_jit.py
|
||||
JIT=2 BIG=1 MPS=1 python -m pytest test/test_gc.py
|
||||
JIT=2 BIG=1 MPS=1 python -m pytest test/test_jit.py
|
||||
JIT=2 BIG=1 MPS=1 python -m pytest test/test_speed_v_torch.py
|
||||
|
||||
# sort and uniq
|
||||
sort -u /tmp/ops > /tmp/sops
|
||||
|
||||
4
test/external/external_test_metrics.py
vendored
4
test/external/external_test_metrics.py
vendored
@@ -7,10 +7,10 @@ import torch
|
||||
import unittest
|
||||
|
||||
class ExternalTestMetrics(unittest.TestCase):
|
||||
def _test_metrics(self, tinygrad_metrics, orig_metrics, pred, label):
|
||||
def _test_metrics(self, tinygrad_metrics, orig_metrics, pred, label, atol=1e-8, rtol=1e-7):
|
||||
tinygrad_metrics_res = tinygrad_metrics(Tensor(pred), Tensor(label)).squeeze().numpy()
|
||||
orig_metrics_res = orig_metrics(torch.from_numpy(pred), torch.from_numpy(label)).numpy()
|
||||
np.testing.assert_equal(tinygrad_metrics_res, orig_metrics_res)
|
||||
np.testing.assert_allclose(tinygrad_metrics_res, orig_metrics_res, atol=atol, rtol=rtol)
|
||||
|
||||
def test_dice(self):
|
||||
pred, label = np.random.rand(1, 3, 128, 128, 128).astype(np.float32), np.ones((1, 1, 128, 128, 128)).astype(np.uint8)
|
||||
|
||||
@@ -31,7 +31,7 @@ def atan2_cpu(ret:Buffer, a:Buffer, b:Buffer): ret.copyin(np.require(np.arctan2(
|
||||
# NOTE: The derivative of atan2 doesn't need a custom op! https://www.liquisearch.com/atan2/derivative
|
||||
# In general, it is also optional to write a backward function, just your backward pass won't work without it
|
||||
|
||||
from tinygrad.ops import LoadOps, BinaryOps
|
||||
from tinygrad.ops import LoadOps, BinaryOps, UnaryOps
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.tensor import Function
|
||||
|
||||
@@ -42,9 +42,9 @@ class ATan2(Function):
|
||||
return create_lazybuffer(a.device, ShapeTracker.from_shape(a.shape), max(a.dtype, b.dtype), LoadOps.CUSTOM,
|
||||
arg={"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device], srcs=(a.contiguous(), b.contiguous()))
|
||||
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
|
||||
denom = (self.a.e(BinaryOps.MUL, self.a)).e(BinaryOps.ADD, self.b.e(BinaryOps.MUL, self.b))
|
||||
return grad_output.e(BinaryOps.MUL, self.b.e(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None, \
|
||||
grad_output.e(BinaryOps.MUL, self.a.const(0).e(BinaryOps.SUB, self.a).e(BinaryOps.DIV, denom)) if self.needs_input_grad[1] else None
|
||||
recip = (self.a.e(BinaryOps.MUL, self.a)).e(BinaryOps.ADD, self.b.e(BinaryOps.MUL, self.b)).e(UnaryOps.RECIP)
|
||||
return grad_output.e(BinaryOps.MUL, self.b.e(BinaryOps.MUL, recip)) if self.needs_input_grad[0] else None, \
|
||||
grad_output.e(BinaryOps.MUL, self.a.const(0).e(BinaryOps.SUB, self.a).e(BinaryOps.MUL, recip)) if self.needs_input_grad[1] else None
|
||||
|
||||
# *** third, we use our lovely new mlop in some tests ***
|
||||
|
||||
|
||||
@@ -283,9 +283,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here")
|
||||
def test_multireduce_loop_scope(self):
|
||||
# when rendering multiple reduceops, any arithmetic on the result of one reduceop will be rendered within the loop of the next
|
||||
# these ops need to be moved out of the loop so they can be accessed in any scope
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.DIV, src=(LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None))),LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None), LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 
32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None)), arg=None),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),),),), arg=(2,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),),))), # noqa: E501
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None))), LazyOp(op=UnaryOps.RECIP, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None), LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, 
st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None)), arg=None),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),)),),),), arg=(2,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),),))), # noqa: E501
|
||||
k = Linearizer(*ast)
|
||||
k.hand_coded_optimizations()
|
||||
k.linearize()
|
||||
@@ -327,7 +325,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
centered_x = LazyOp(op=BinaryOps.SUB, src=(x_ast, max_x))
|
||||
exp_x = LazyOp(op=UnaryOps.EXP2, src=(centered_x,))
|
||||
sum_exp_x = LazyOp(op=ReduceOps.SUM, src=(exp_x,), arg=(1,))
|
||||
y = LazyOp(op=BinaryOps.DIV, src=(exp_x, sum_exp_x))
|
||||
y = LazyOp(op=BinaryOps.MUL, src=(exp_x, LazyOp(op=UnaryOps.RECIP, src=(sum_exp_x,))))
|
||||
y_reduced = LazyOp(op=ReduceOps.SUM, src=(y,), arg=(1,))
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(y_reduced,), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((4,1))))
|
||||
expected = ((np_exp2:=np.exp2(x.numpy() - x.numpy().max(axis=-1, keepdims=True)))/np_exp2.sum(axis=-1, keepdims=True)).sum(axis=-1)
|
||||
@@ -340,7 +338,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
max_x = LazyOp(op=ReduceOps.MAX, src=(x_ast,), arg=(1,))
|
||||
exp_x = LazyOp(op=UnaryOps.EXP2, src=(LazyOp(op=BinaryOps.SUB, src=(x_ast, max_x)),))
|
||||
sum_exp_x = LazyOp(op=ReduceOps.SUM, src=(exp_x,), arg=(1,))
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.DIV, src=(exp_x, sum_exp_x)),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((4,1)))) # noqa: E501
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(exp_x, LazyOp(op=UnaryOps.RECIP, src=(sum_exp_x,)))),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((4,1)))) # noqa: E501
|
||||
max_x_ast = LazyOp(op=BufferOps.STORE, src=(max_x,), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker.from_shape((4,1))))
|
||||
sum_exp_x_ast = LazyOp(op=BufferOps.STORE, src=(sum_exp_x,), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker.from_shape((4,1))))
|
||||
expected = [
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -23,7 +23,7 @@ def _test_overflow(ast, opts):
|
||||
@unittest.skip("unneeded without launch bounds")
|
||||
class TestLinearizerOverflow(unittest.TestCase):
|
||||
def test_overflow_1(self):
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.MAX, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 64, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 64), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(7, 6, 5)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.DIV, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.CONST, src=(), 
arg=ConstBuffer(val=1e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None),), arg=None)), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=4, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))))
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.MAX, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 64, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 64), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(7, 6, 5)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=UnaryOps.RECIP, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), 
LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),))), arg=None),), arg=None)), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=4, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))))
|
||||
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0)]
|
||||
_test_overflow(ast, opts)
|
||||
|
||||
|
||||
@@ -104,12 +104,12 @@ class TestFloatUOps(TestUOps):
|
||||
def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a))
|
||||
def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan'))
|
||||
def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a))
|
||||
def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1/a if a != 0 else float('inf'))
|
||||
def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
|
||||
|
||||
def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b)
|
||||
def test_sub(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: a-b)
|
||||
def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
|
||||
def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b if b != 0 else a*float('inf'))
|
||||
def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
|
||||
def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: a<b)
|
||||
# MOD isn't tested on floats
|
||||
@@ -127,7 +127,7 @@ class TestNonFloatUOps(TestUOps):
|
||||
@unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts")
|
||||
def test_shl_int32(self): self._test_bop_fxn(BinaryOps.SHL, lambda a,b: int(a)<<int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
|
||||
def test_div_int32(self):
|
||||
self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), (dtypes.int32, dtypes.int32), no_b_zero=True)
|
||||
self._test_bop_fxn(BinaryOps.IDIV, lambda a,b: int(a/b), (dtypes.int32, dtypes.int32), no_b_zero=True)
|
||||
def test_mod_int32(self):
|
||||
self._test_bop_fxn(BinaryOps.MOD,
|
||||
lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], (dtypes.int32, dtypes.int32), no_b_zero=True)
|
||||
@@ -170,14 +170,23 @@ class TestExecALU(TestUOps):
|
||||
self.assertEqual(exec_alu(UnaryOps.SQRT, dtypes.float, (0.0,)), 0.0)
|
||||
|
||||
def test_div(self):
|
||||
self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.int8, (8, 2)), 4)
|
||||
self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.int8, (7, 3)), 2)
|
||||
self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.int8, (7, -3)), -2)
|
||||
self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.int8, (-50, 6)), -8)
|
||||
self.assertEqual(exec_alu(BinaryOps.IDIV, dtypes.int8, (8, 2)), 4)
|
||||
self.assertEqual(exec_alu(BinaryOps.IDIV, dtypes.int8, (7, 3)), 2)
|
||||
self.assertEqual(exec_alu(BinaryOps.IDIV, dtypes.int8, (7, -3)), -2)
|
||||
self.assertEqual(exec_alu(BinaryOps.IDIV, dtypes.int8, (-50, 6)), -8)
|
||||
|
||||
np.testing.assert_allclose(exec_alu(BinaryOps.DIV, dtypes.float32, (8.0, 2.0)), 4.0)
|
||||
np.testing.assert_allclose(exec_alu(BinaryOps.DIV, dtypes.float32, (7.0, 3.0)), 2+(1.0/3.0))
|
||||
np.testing.assert_allclose(exec_alu(BinaryOps.DIV, dtypes.float32, (7.0, -3.0)), -2-(1.0/3.0))
|
||||
np.testing.assert_allclose(exec_alu(BinaryOps.MUL, dtypes.float32, (7.0, exec_alu(UnaryOps.RECIP, dtypes.float32, (3.0,)))), 2+(1.0/3.0))
|
||||
np.testing.assert_allclose(exec_alu(BinaryOps.MUL, dtypes.float32, (7.0, exec_alu(UnaryOps.RECIP, dtypes.float32, (-3.0,)))), -2-(1.0/3.0))
|
||||
|
||||
def test_recip(self):
|
||||
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (8,)), 1/8)
|
||||
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (7,)), 1/7)
|
||||
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (-3,)), 1/-3)
|
||||
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (-50,)), 1/-50)
|
||||
|
||||
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, ((32+521+3),)), 1/(32+521+3))
|
||||
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, ((34**2),)), 1/(34**2))
|
||||
np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (10,)), 1/10)
|
||||
|
||||
def test_bool_neg(self):
|
||||
self.assertEqual(exec_alu(UnaryOps.NEG, dtypes.bool, (False,)), True)
|
||||
@@ -306,11 +315,11 @@ class TestAssembly(unittest.TestCase):
|
||||
c1 = uops.add(UOps.CONST, dtypes.int, (), 2)
|
||||
c2 = uops.add(UOps.CONST, dtypes.int, (), 3)
|
||||
l1 = uops.add(UOps.LOAD, dtypes.int, (g1, c1))
|
||||
a1 = uops.add(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.DIV)
|
||||
a2 = uops.add(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.DIV)
|
||||
a1 = uops.add(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
|
||||
a2 = uops.add(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
|
||||
uops.add(UOps.SINK, None, (a1,a2))
|
||||
Device[Device.DEFAULT].renderer.render("test", uops)
|
||||
self.assertEqual(uops.uops[-1].arg, BinaryOps.DIV)
|
||||
self.assertEqual(uops.uops[-1].arg, BinaryOps.IDIV)
|
||||
self.assertEqual(uops.uops[-2].arg, BinaryOps.SHR)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -20,7 +20,7 @@ def render(self) -> str:
|
||||
if DEBUG>=5: graph.print()
|
||||
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||
class TestRenderer(CStyleLanguage):
|
||||
code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.DIV: lambda a,b,dtype: f"({a}//{b})"}
|
||||
code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.IDIV: lambda a,b,dtype: f"({a}//{b})"}
|
||||
fxn = TestRenderer().render("", graph)
|
||||
return fxn.split("data0[0] = ")[1].split(";")[0]
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ class Linearizer(Kernel):
|
||||
|
||||
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
|
||||
MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
|
||||
DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV),
|
||||
DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.IDIV),
|
||||
ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
|
||||
LtNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.CMPLT),
|
||||
SumNode: lambda self,ops,ctx:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast, TypeVar
|
||||
import functools, itertools, heapq
|
||||
import functools, itertools, heapq, math
|
||||
from collections import defaultdict
|
||||
from enum import Enum, auto
|
||||
from dataclasses import dataclass, field
|
||||
@@ -49,7 +49,7 @@ class UOp:
|
||||
def __sub__(self, x): return UOp.alu(BinaryOps.SUB, self, ufix(self.dtype, x))
|
||||
def __mul__(self, x): return UOp.alu(BinaryOps.MUL, self, ufix(self.dtype, x))
|
||||
def __rmul__(self, x): return UOp.alu(BinaryOps.MUL, ufix(self.dtype, x), self)
|
||||
def __floordiv__(self, x): return UOp.alu(BinaryOps.DIV, self, ufix(self.dtype, x))
|
||||
def __floordiv__(self, x): return UOp.alu(BinaryOps.IDIV, self, ufix(self.dtype, x))
|
||||
def __mod__(self, x): return UOp.alu(BinaryOps.MOD, self, ufix(self.dtype, x))
|
||||
@staticmethod
|
||||
def max(x, y): return UOp.alu(BinaryOps.MAX, x, y)
|
||||
@@ -139,7 +139,7 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst):
|
||||
# TODO: support and test this with other mvals and loop_starts
|
||||
if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mval:{mval.arg} loop_start:{loop_start.arg}")
|
||||
return None
|
||||
comprange = UOp.min(loop_end, UOp.max(UOp.alu(BinaryOps.DIV, idx-compval-mval, mval) + (loop_end-loop_start), loop_start))
|
||||
comprange = UOp.min(loop_end, UOp.max(UOp.alu(BinaryOps.IDIV, idx-compval-mval, mval) + (loop_end-loop_start), loop_start))
|
||||
return UOp(UOps.UNMUL, multconst.dtype, (comprange.cast(multconst.dtype) * multconst, loop_end-loop_start))
|
||||
|
||||
# this is symbolic 2.0
|
||||
@@ -194,10 +194,13 @@ constant_folder = PatternMatcher([
|
||||
(UPat(UOps.ALU, BinaryOps.ADD, [UPat(name="x"), UPat(UOps.CONST, 0)]), lambda x: x), # x+0 -> x or 0+x -> x
|
||||
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(name="x"), UPat(UOps.CONST, 1)]), lambda x: x), # x*1 -> x or 1*x -> x
|
||||
(UPat(UOps.ALU, BinaryOps.SUB, (UPat(name="x"), UPat(UOps.CONST, 0))), lambda x: x), # x-0 -> x
|
||||
(UPat(UOps.ALU, BinaryOps.DIV, (UPat(name="x"), UPat(UOps.CONST, 1))), lambda x: x), # x/1 -> x
|
||||
(UPat(UOps.ALU, BinaryOps.DIV, (UPat(name="x"), UPat(UOps.CONST, -1))), lambda x: -x), # x/-1 -> -x
|
||||
(UPat(UOps.ALU, BinaryOps.IDIV, (UPat(name="x"), UPat(UOps.CONST, 1))), lambda x: x), # x/1 -> x
|
||||
(UPat(UOps.ALU, BinaryOps.IDIV, (UPat(name="x"), UPat(UOps.CONST, -1))), lambda x: -x), # x/-1 -> -x
|
||||
# ** zero folding **
|
||||
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(), UPat(UOps.CONST, 0, name="c")]), lambda c: c), # x*0 -> 0 or 0*x -> 0
|
||||
#x*0 -> 0 or 0*x -> 0
|
||||
# if x.arg is a float NaN, return x instead of the zero constant so the NaN is preserved
|
||||
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(name="x"), UPat(UOps.CONST, 0, name="c")]),
|
||||
lambda x,c: x if isinstance(x.arg, float) and math.isnan(x.arg) else c),
|
||||
(UPat(UOps.ALU, BinaryOps.SUB, (UPat(name="x"), UPat(name="x"))), lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
|
||||
# ** load/store folding **
|
||||
(UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"),
|
||||
@@ -214,9 +217,9 @@ constant_folder = PatternMatcher([
|
||||
(UPat(UOps.ALU, BinaryOps.ADD, (UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c0"), UPat(name="x")]), # (x*c0)+(x*c1) -> x*(c0+c1)
|
||||
UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"), UPat(name="x")]),)),
|
||||
lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])),
|
||||
(UPat(UOps.ALU, BinaryOps.DIV, (UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c0"), UPat(name="x")]),
|
||||
(UPat(UOps.ALU, BinaryOps.IDIV, (UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c0"), UPat(name="x")]),
|
||||
UPat(UOps.CONST, name="c0"))), lambda x,c0: x if c0.arg != 0 else None), # (x*c0)/c0 -> x
|
||||
(UPat(UOps.ALU, BinaryOps.DIV, (UPat(UOps.ALU, BinaryOps.DIV, (UPat(name="x"), UPat(UOps.CONST, name="c0"))), UPat(UOps.CONST, name="c1"))),
|
||||
(UPat(UOps.ALU, BinaryOps.IDIV, (UPat(UOps.ALU, BinaryOps.IDIV, (UPat(name="x"), UPat(UOps.CONST, name="c0"))), UPat(UOps.CONST, name="c1"))),
|
||||
lambda x,c0,c1: x//UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c0.arg, c1.arg]))), # (x/c0)/c1 -> x/(c0*c1)
|
||||
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(UOps.ALU, BinaryOps.ADD, [UPat(UOps.CONST, name="c0"), UPat(name="x")]), UPat(UOps.CONST, name="c1"))),
|
||||
lambda x,c0,c1: UOp.alu(BinaryOps.CMPLT, x, UOp.const(x.dtype, exec_alu(BinaryOps.SUB, x.dtype, [c1.arg, c0.arg])))),
|
||||
@@ -452,6 +455,10 @@ class UOpGraph:
|
||||
elif arg in (BinaryOps.CMPLT, BinaryOps.CMPNE):
|
||||
assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
|
||||
assert vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
|
||||
elif arg is BinaryOps.IDIV:
|
||||
assert dtypes.is_int(vin[0].dtype) and dtypes.is_int(vin[1].dtype), \
|
||||
f"input dtype mismatch {dtypes.int} != {vin[0].dtype=} != {vin[1].dtype=}"
|
||||
assert dtypes.is_int(dtype), f"{arg} output dtype mismatch {dtype=} != {dtypes.int}"
|
||||
elif arg in BinaryOps:
|
||||
assert dtype == vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
|
||||
elif arg == TernaryOps.WHERE:
|
||||
|
||||
@@ -31,7 +31,7 @@ class Neg(Function):
|
||||
|
||||
class Reciprocal(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
self.ret = x.const(1).e(BinaryOps.DIV, x)
|
||||
self.ret = x.e(UnaryOps.RECIP)
|
||||
return self.ret
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.ret).e(BinaryOps.MUL, self.ret)
|
||||
@@ -58,7 +58,7 @@ class Log(Function):
|
||||
self.x = x
|
||||
return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
  # forward stores its input as self.x; the gradient is grad_output * (1/x), written as MUL by RECIP
  return grad_output.e(BinaryOps.MUL, self.x.e(UnaryOps.RECIP))
|
||||
|
||||
class Exp(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
@@ -73,14 +73,14 @@ class Sqrt(Function):
|
||||
return self.ret
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
  # gradient is grad_output * 1/(2*self.ret); the division is expressed as MUL by RECIP
  return grad_output.e(BinaryOps.MUL, self.ret.e(BinaryOps.MUL, self.ret.const(2)).e(UnaryOps.RECIP))
|
||||
|
||||
# NOTE: the implicit derivative of sigmoid is not stable
|
||||
# https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
|
||||
# TODO: have the backend automatically find this
|
||||
class Sigmoid(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
  # RECIP(1 + EXP2(x * (-1/ln 2))): exp2 with the -1/ln(2) scale stands in for e^-x
  # the result is stored on self as well as returned
  self.ret = x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)).e(UnaryOps.RECIP)
  return self.ret
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
@@ -132,11 +132,11 @@ class Mul(Function):
|
||||
class Div(Function):
  """Elementwise division: MUL by RECIP for non-integer dtypes, IDIV for integer dtypes."""
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    # both operands are saved because backward reads self.x and self.y
    self.x, self.y = x, y
    # integer inputs use the dedicated integer divide; everything else multiplies by the reciprocal of y
    return x.e(BinaryOps.MUL, y.e(UnaryOps.RECIP)) if not dtypes.is_int(x.dtype) else x.e(BinaryOps.IDIV, y)

  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    # returns (grad wrt x, grad wrt y); an entry is None when needs_input_grad says it is not wanted
    # d(x/y)/dx = 1/y and d(x/y)/dy = -x/(y*y), both written as MUL by RECIP
    return grad_output.e(BinaryOps.MUL, self.y.e(UnaryOps.RECIP)) if self.needs_input_grad[0] else None, \
           grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.MUL, self.y.e(BinaryOps.MUL, self.y).e(UnaryOps.RECIP)) if self.needs_input_grad[1] else None # noqa: E501
|
||||
|
||||
# ************* ternary ops *************
|
||||
|
||||
@@ -168,7 +168,7 @@ class Max(Function):
|
||||
# 1s in locations where the max was chosen (can be two locations)
|
||||
max_is_1s = self.x.const(1.0).cast(dtypes.float).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPNE, self.ret.expand(self.x.shape)).cast(dtypes.float))
|
||||
div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
|
||||
return max_is_1s.e(BinaryOps.DIV, div).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
|
||||
return max_is_1s.e(BinaryOps.MUL, div.e(UnaryOps.RECIP)).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
|
||||
|
||||
# ************* movement ops *************
|
||||
|
||||
|
||||
@@ -163,8 +163,6 @@ class LazyBuffer:
|
||||
return y if val == 1 else y.const(0) if val == 0 else y.e(UnaryOps.NEG)
|
||||
if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0, -1):
|
||||
return x if val == 1 else x.const(0) if val == 0 else x.e(UnaryOps.NEG)
|
||||
if op is BinaryOps.DIV and dtypes.is_float(x.dtype) and y.is_unrealized_unmasked_const() and y.base.arg != 0:
|
||||
return x.e(BinaryOps.MUL, x.const(1 / y.base.arg))
|
||||
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
|
||||
|
||||
|
||||
@@ -14,10 +14,10 @@ from tinygrad.shape.shapetracker import ShapeTracker
|
||||
# NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
|
||||
class UnaryOps(Enum):
  """A -> A (elementwise)"""
  # each member is declared exactly once: an Enum body that re-assigns a member name raises TypeError
  EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto(); RECIP = auto() # noqa: E702
|
||||
class BinaryOps(Enum):
  """A + A -> A (elementwise)"""
  # there is no DIV member: integer division is IDIV, other division is MUL by UnaryOps.RECIP
  ADD = auto(); SUB = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
  SHR = auto(); SHL = auto() # noqa: E702
|
||||
class TernaryOps(Enum):
|
||||
"""A + A + A -> A (elementwise)"""
|
||||
@@ -31,7 +31,7 @@ class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS =
|
||||
Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps]
|
||||
|
||||
# these ops do not preserve f(0) = 0
UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MemBuffer:
|
||||
@@ -119,12 +119,12 @@ python_alu = {
|
||||
UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
|
||||
UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))),
|
||||
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin,
|
||||
UnaryOps.RECIP: lambda x: 1/x if x != 0 else float('inf'),
|
||||
UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
|
||||
BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift,
|
||||
BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.XOR: operator.xor,
|
||||
BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
|
||||
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0],
|
||||
BinaryOps.DIV: lambda x,y: int(x/y) if isinstance(x, int) else (x/y if y != 0 else x*math.inf),
|
||||
BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub,
|
||||
BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
|
||||
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x, y: int(x/y) if y != 0 else x*math.inf,
|
||||
TernaryOps.WHERE: lambda x,y,z: y if x else z}
|
||||
|
||||
truncate: Dict[DType, Callable] = {dtypes.bool: bool,
|
||||
|
||||
@@ -35,6 +35,7 @@ class PTXRenderer(Renderer):
|
||||
lid = [f'%tid.{chr(120+i)}' for i in range(3)]
|
||||
asm_for_op: Dict[Op, Callable] = {
|
||||
UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"neg.{name} {d}, {a};",
|
||||
UnaryOps.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
|
||||
UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
|
||||
UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
|
||||
BinaryOps.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", BinaryOps.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
|
||||
@@ -42,7 +43,7 @@ class PTXRenderer(Renderer):
|
||||
BinaryOps.SUB: lambda d,a,b,dt,name: f"sub.{name} {d}, {a}, {b};",
|
||||
BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
|
||||
BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
|
||||
BinaryOps.DIV: lambda d,a,b,dt,name: f"div{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a}, {b};",
|
||||
BinaryOps.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
|
||||
BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
|
||||
BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
|
||||
BinaryOps.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
|
||||
@@ -228,7 +229,7 @@ ptx_matcher = PatternMatcher([
|
||||
(UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
|
||||
vin=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="mul")]),
|
||||
lambda root, mul, const: UOp(UOps.ALU, root.dtype, (mul, UOp.const(root.dtype, int(math.log2(const.arg)))), BinaryOps.SHL)),
|
||||
(UPat(UOps.ALU, BinaryOps.DIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
|
||||
(UPat(UOps.ALU, BinaryOps.IDIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
|
||||
vin=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="div")]),
|
||||
lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(root.dtype, int(math.log2(const.arg)))), BinaryOps.SHR)),
|
||||
(UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR)),
|
||||
|
||||
@@ -25,9 +25,10 @@ class CStyleLanguage(Renderer):
|
||||
type_map: Dict[DType, str] = {}
|
||||
# maps each op to a callable that renders its C-style source expression from already-rendered operand strings
# every key appears once; division is RECIP (rendered as 1/x) or IDIV (rendered as a/b)
code_for_op: Dict = {
  UnaryOps.NEG: lambda x,dtype: f"(!{x})" if dtype == dtypes.bool else f"(-{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
  UnaryOps.RECIP: lambda x,dtype: f"(1/{x})",
  UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})",
  BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})",
  BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
  BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
  TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
|
||||
|
||||
@@ -239,7 +240,8 @@ class MetalRenderer(CStyleLanguage):
|
||||
return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
|
||||
|
||||
code_for_op_half = {BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype in (dtypes.half, dtypes.bfloat16) else f"max({a},{b})",
|
||||
code_for_op_half = {UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"1/{x}",
|
||||
BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype in (dtypes.half, dtypes.bfloat16) else f"max({a},{b})",
|
||||
UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
|
||||
UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
|
||||
UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
|
||||
|
||||
@@ -15,12 +15,13 @@ code_for_op: Final[Dict[Op, Callable]] = {
|
||||
(builder.not_(x) if dtype == dtypes.bool else builder.fneg(x, flags=MFLAGS)),
|
||||
UnaryOps.EXP2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=MFLAGS),
|
||||
UnaryOps.LOG2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=MFLAGS),
|
||||
UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS),
|
||||
UnaryOps.SIN: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=MFLAGS),
|
||||
UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
|
||||
BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501
|
||||
BinaryOps.SUB: lambda builder, x, y, dtype: builder.sub(x, y) if dtypes.is_int(dtype) else builder.fsub(x, y, flags=MFLAGS),
|
||||
BinaryOps.MUL: lambda builder, x, y, dtype: builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype) else builder.fmul(x, y, flags=MFLAGS), # noqa: E501
|
||||
BinaryOps.DIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y) if dtypes.is_int(dtype) else builder.fdiv(x, y, flags=MFLAGS), # noqa: E501
|
||||
BinaryOps.IDIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y),
|
||||
BinaryOps.CMPLT: lambda builder, x, y, dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501
|
||||
BinaryOps.CMPNE: lambda builder, x, y, dtype: builder.icmp_unsigned("!=", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("!=", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("!=", x, y, flags=MFLAGS), # noqa: E501
|
||||
BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
|
||||
|
||||
Reference in New Issue
Block a user