mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Invalid in Tensor (#15154)
This commit is contained in:
committed by
GitHub
parent
68c7c3ca84
commit
ffaafd391a
@@ -1189,5 +1189,11 @@ class TestBufferView(unittest.TestCase):
|
||||
b = a.shrink(((200, 800),)).shrink(((0, 300),)).reshape((30, 10)).shrink(((20, 25), (0, 10))).contiguous()
|
||||
run_schedule(check_schedule(b, 0))
|
||||
|
||||
class TestInvalidTensor(unittest.TestCase):
  def test_full_invalid_is_zero_kernels(self):
    """A tensor that is entirely Invalid should produce no kernels when scheduled."""
    from tinygrad.dtype import Invalid
    check_schedule(Tensor.full((4,), Invalid, dtype=dtypes.float), 0)
|
||||
|
||||
if __name__ == "__main__":
  unittest.main(verbosity=2)
|
||||
|
||||
167
test/unit/test_invalid_tensor.py
Normal file
167
test/unit/test_invalid_tensor.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.dtype import Invalid, dtypes
|
||||
from tinygrad.helpers import unwrap_class_type
|
||||
|
||||
class TestInvalidTensor(unittest.TestCase):
  def _realize_and_capture(self, out):
    """Realize `out` and return (pre, result).

    `result` is out.tolist(). `pre` is the contents of the first kernel
    argument (presumably the output buffer -- confirm against the runtime's
    calling convention) snapshotted just before the last kernel ran, so
    tests can verify which elements the kernel left untouched.
    """
    snapshot = None
    runtime_cls = unwrap_class_type(Device[Device.DEFAULT].runtime)
    original_call = runtime_cls.__call__

    def spy(prg, *bufs, **kwargs):
      nonlocal snapshot
      # capture bufs[0] before the kernel writes it; the last capture wins
      snapshot = Device[Device.DEFAULT].allocator._as_buffer(bufs[0]).cast(out.dtype.fmt).tolist()
      return original_call(prg, *bufs, **kwargs)

    with patch.object(runtime_cls, '__call__', spy):
      result = out.tolist()
    return snapshot, result

  def test_where_x_invalid(self):
    # last two lanes are Invalid: the kernel must not write them
    pre, got = self._realize_and_capture((Tensor.arange(4) < 2).where(Tensor([1.0, 2.0, 3.0, 4.0]), Invalid))
    assert got[0] == 1.0 and got[1] == 2.0
    assert pre[2] == got[2] and pre[3] == got[3]

  def test_where_invalid_x(self):
    # Invalid in the x position: the first two lanes stay untouched
    pre, got = self._realize_and_capture((Tensor.arange(4) < 2).where(Invalid, Tensor([1.0, 2.0, 3.0, 4.0])))
    assert got[2] == 3.0 and got[3] == 4.0
    assert pre[0] == got[0] and pre[1] == got[1]

  def test_where_invalid_2d(self):
    vals = Tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    pre, got = self._realize_and_capture((Tensor.arange(6).reshape(2, 3) < 3).where(vals, Invalid))
    assert got[0] == [1.0, 2.0, 3.0]
    # the snapshot is flat; the second row of the output must be untouched
    assert pre[3] == got[1][0] and pre[4] == got[1][1] and pre[5] == got[1][2]

  def test_where_invalid_int(self):
    pre, got = self._realize_and_capture((Tensor.arange(3) < 2).where(Tensor([10, 20, 30]), Invalid))
    assert got[0] == 10 and got[1] == 20
    assert pre[2] == got[2]

  def test_where_invalid_add(self):
    # binary op on a gated tensor: invalidity propagates past the add
    gated = (Tensor.arange(3) < 2).where(Tensor([10.0, 20.0, 30.0]), Invalid)
    pre, got = self._realize_and_capture(gated + Tensor([1.0, 2.0, 3.0]))
    assert got[0] == 11.0 and got[1] == 22.0
    assert pre[2] == got[2]

  def test_where_invalid_add_left(self):
    # same as above but with the gated tensor on the right of the add
    gated = (Tensor.arange(3) < 2).where(Tensor([10.0, 20.0, 30.0]), Invalid)
    pre, got = self._realize_and_capture(Tensor([1.0, 2.0, 3.0]) + gated)
    assert got[0] == 11.0 and got[1] == 22.0
    assert pre[2] == got[2]

  def test_where_always_true(self):
    # mask is true everywhere, so no element is Invalid
    _, got = self._realize_and_capture((Tensor.arange(3) < 10).where(Tensor([10.0, 20.0, 30.0]), Invalid))
    assert got == [10.0, 20.0, 30.0]

  def test_where_cast(self):
    out = (Tensor.arange(4) < 2).where(Tensor([1.0, 2.0, 3.0, 4.0]), Invalid).cast(dtypes.int)
    pre, got = self._realize_and_capture(out)
    assert got[0] == 1 and got[1] == 2
    assert pre[2] == got[2] and pre[3] == got[3]

  def test_where_compare(self):
    out = (Tensor.arange(4) < 2).where(Tensor([1.0, 2.0, 3.0, 4.0]), Invalid) > 1
    pre, got = self._realize_and_capture(out)
    assert not got[0] and got[1]
    assert pre[2] == got[2] and pre[3] == got[3]

  def test_where_unary(self):
    out = (Tensor.arange(4) < 2).where(Tensor([1.0, 4.0, 9.0, 16.0]), Invalid).sqrt()
    pre, got = self._realize_and_capture(out)
    assert got[0] == 1.0 and got[1] == 2.0
    assert pre[2] == got[2] and pre[3] == got[3]

  def test_where_where(self):
    # nested gating: only lane 1 passes both masks
    inner = (Tensor.arange(4) < 2).where(Tensor([1.0, 2.0, 3.0, 4.0]), Invalid)
    pre, got = self._realize_and_capture((Tensor.arange(4) > 0).where(inner, Invalid))
    assert got[1] == 2.0
    assert pre[0] == got[0] and pre[2] == got[2] and pre[3] == got[3]

  def test_where_reduce_always_true(self):
    # mask is always true, so the reduce sees no Invalid elements
    _, got = self._realize_and_capture((Tensor.arange(4) < 9).where(Tensor([1.0, 2.0, 3.0, 4.0]), Invalid).sum())
    assert got == 10.0

  def test_invalid_unary(self):
    inv = Tensor.full((4,), Invalid, dtype=dtypes.float).sqrt()
    pre, got = self._realize_and_capture((Tensor.arange(4) < 2).where(Tensor([1.0, 2.0, 3.0, 4.0]), inv))
    assert got[0] == 1.0 and got[1] == 2.0
    assert pre[2] == got[2] and pre[3] == got[3]

  def test_invalid_binary(self):
    inv = Tensor.full((4,), Invalid, dtype=dtypes.float) + 2
    pre, got = self._realize_and_capture((Tensor.arange(4) < 2).where(Tensor([1.0, 2.0, 3.0, 4.0]), inv))
    assert got[0] == 1.0 and got[1] == 2.0
    assert pre[2] == got[2] and pre[3] == got[3]

  def test_invalid_binary_left(self):
    inv = 2 + Tensor.full((4,), Invalid, dtype=dtypes.float)
    pre, got = self._realize_and_capture((Tensor.arange(4) < 2).where(Tensor([1.0, 2.0, 3.0, 4.0]), inv))
    assert got[0] == 1.0 and got[1] == 2.0
    assert pre[2] == got[2] and pre[3] == got[3]

  def test_invalid_reshape(self):
    out = (Tensor.arange(4) < 2).where(Tensor([1.0, 2.0, 3.0, 4.0]), Invalid).reshape(2, 2)
    pre, got = self._realize_and_capture(out)
    assert got[0] == [1.0, 2.0]
    assert got[1] == [pre[2], pre[3]]

  def test_invalid_cast(self):
    inv = Tensor.full((4,), Invalid, dtype=dtypes.int).cast(dtypes.float)
    pre, got = self._realize_and_capture((Tensor.arange(4) < 2).where(Tensor([1.0, 2.0, 3.0, 4.0]), inv))
    assert got[0] == 1.0 and got[1] == 2.0
    assert pre[2] == got[2] and pre[3] == got[3]

  def test_invalid_bitcast(self):
    inv = Tensor.full((4,), Invalid, dtype=dtypes.int).bitcast(dtypes.float)
    pre, got = self._realize_and_capture((Tensor.arange(4) < 2).where(Tensor([1.0, 2.0, 3.0, 4.0]), inv))
    assert got[0] == 1.0 and got[1] == 2.0
    assert pre[2] == got[2] and pre[3] == got[3]

  def test_where_bitcast(self):
    out = (Tensor.arange(4) < 2).where(Tensor([1.0, 2.0, 3.0, 4.0]), Tensor.full((4,), Invalid, dtype=dtypes.int)).bitcast(dtypes.int)
    pre, got = self._realize_and_capture(out)
    # bit patterns of float 1.0 and 2.0 respectively
    assert got[0] == 0x3f800000 and got[1] == 0x40000000
    assert pre[2] == got[2] and pre[3] == got[3]

  # tensor indexing uses reduce, so the entire result becomes invalid
  @unittest.expectedFailure
  def test_tensor_index(self):
    idx = (Tensor.arange(4) < 2).where(Tensor([0, 1, 2, 3]), Invalid)
    pre, got = self._realize_and_capture(Tensor([1.0, 2.0, 3.0, 4.0])[idx])
    assert got[0] == 1.0 and got[1] == 2.0
    assert pre[2] == got[2] and pre[3] == got[3]
|
||||
|
||||
if __name__ == "__main__":
  unittest.main()
|
||||
@@ -30,6 +30,7 @@ class InvalidType:
|
||||
def __hash__(self): return id(self)  # identity-based hash: there is only the one module-level instance
|
||||
def __repr__(self): return "Invalid"  # repr as the bare singleton name
|
||||
def __reduce__(self): return (InvalidType, ()) # unpickle returns the singleton
|
||||
def __format__(self, spec): return "Invalid"  # render as "Invalid" regardless of the format spec
|
||||
|
||||
Invalid = InvalidType()  # the module-level singleton instance of InvalidType
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from contextlib import ContextDecorator
|
||||
from typing import Any, Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar, Generic, TYPE_CHECKING
|
||||
if TYPE_CHECKING: import numpy
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
|
||||
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst
|
||||
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid, InvalidType
|
||||
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
|
||||
from tinygrad.helpers import IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, ASM_GEMM, ceildiv, fetch, is_numpy_ndarray, TracingKey, cpu_profile
|
||||
from tinygrad.helpers import suppress_finalizing, disable_gc
|
||||
@@ -113,7 +113,7 @@ class Tensor(OpMixin):
|
||||
__slots__ = "uop", "requires_grad", "grad"
|
||||
training: ClassVar[bool] = False
|
||||
|
||||
def __init__(self, data:PyConst|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None,
|
||||
def __init__(self, data:ConstType|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None,
|
||||
device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None, _force_unique:bool=False):
|
||||
if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
|
||||
_dtype:DType|None = to_dtype(dtype) if dtype is not None else None
|
||||
@@ -141,6 +141,9 @@ class Tensor(OpMixin):
|
||||
data = Tensor(0, device=_device, dtype=_dtype or dtypes.default_float, requires_grad=requires_grad).uop
|
||||
elif isinstance(data, get_args(PyConst)):
|
||||
data = (UOp.unique_const if _force_unique or requires_grad else UOp.const)(_dtype or dtypes.from_py(data), data, _device)
|
||||
elif isinstance(data, InvalidType):
|
||||
assert _dtype is not None
|
||||
data = UOp.const(_dtype, data, _device)
|
||||
elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if _dtype is None else _dtype)
|
||||
elif isinstance(data, (list, tuple)):
|
||||
if _dtype is None:
|
||||
@@ -2944,7 +2947,8 @@ class Tensor(OpMixin):
|
||||
if not isinstance(y, Tensor):
|
||||
# make y a Tensor
|
||||
assert isinstance(y, (*get_args(ConstType), UOp)), f"{type(y)=}, {y=}"
|
||||
if isinstance(x.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
|
||||
if y is Invalid or isinstance(x.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)):
|
||||
y_dtype = x.dtype
|
||||
elif not isinstance(y, UOp): y_dtype = dtypes.from_py(y)
|
||||
if isinstance(y, UOp): y = Tensor.from_uop(y, device=x.device)
|
||||
else: y = Tensor(dtypes.as_const(y, y_dtype), x.device, y_dtype, requires_grad=False)
|
||||
|
||||
@@ -267,8 +267,6 @@ full_spec = PatternMatcher([
|
||||
# linearizer: outputs + intermediate KERNELs
|
||||
(UPat(Ops.CALL, dtype=dtypes.void), lambda: True),
|
||||
|
||||
# Invalid must have type Index
|
||||
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),
|
||||
# where on index in rhs position is fine
|
||||
(UPat(Ops.WHERE, dtype=dtypes.index, src=(UPat(dtype=dtypes.bool), UPat(), UPat(dtype=dtypes.index))), lambda: True),
|
||||
# allow index dtype on a restricted set of UOps
|
||||
|
||||
@@ -28,13 +28,22 @@ invalid_gate = UPat.var("cond").where(UPat.var("x"), invalid_pat)
|
||||
# this needs to be before symbolic so that 0*something_that_might_be_invalid doesnt become 0
|
||||
propagate_invalid = PatternMatcher([
|
||||
# propagate invalid, push it past children
|
||||
(invalid_gate.cast(name="cast"), lambda i,x,cond,cast: x.cast(cast.dtype)),
|
||||
*((invalid_gate.alu(op, UPat.var("y")).named("alu"), lambda cond,x,y,alu,i: cond.where(x.alu(alu.op,y), i))
|
||||
for op in GroupOp.Binary-GroupOp.Comparison),
|
||||
(invalid_gate.cast(name="cast"), lambda i,x,cond,cast: x.cast(cast.dtype) if i.dtype is dtypes.index else None),
|
||||
(UPat(GroupOp.Unary, src=(invalid_gate,), name="alu"), lambda cond,x,alu,i: cond.where(x.alu(alu.op), i)),
|
||||
(UPat(GroupOp.Binary-GroupOp.Comparison, src=(invalid_gate, UPat.var("y")), name="alu"), lambda cond,x,y,alu,i: cond.where(x.alu(alu.op,y), i)),
|
||||
(UPat(GroupOp.Binary-GroupOp.Comparison, src=(UPat.var("y"), invalid_gate), name="alu"), lambda cond,x,y,alu,i: cond.where(y.alu(alu.op,x), i)),
|
||||
# TODO: when can this happen? and is it always safe to just drop invalid?
|
||||
*((invalid_gate.alu(op, UPat.var("y")).named("alu"), lambda cond,x,y,alu,i: x.alu(alu.op,y)) for op in GroupOp.Comparison),
|
||||
(UPat(GroupOp.Comparison, src=(invalid_gate, UPat.var("y")), name="alu"), lambda cond,x,y,alu,i:
|
||||
x.alu(alu.op,y) if i.dtype is dtypes.index else cond.where(x.alu(alu.op,y), i.cast(dtypes.bool))),
|
||||
(UPat(GroupOp.Comparison, src=(UPat.var("y"), invalid_gate), name="alu"), lambda cond,x,y,alu,i:
|
||||
y.alu(alu.op,x) if i.dtype is dtypes.index else cond.where(y.alu(alu.op,x), i.cast(dtypes.bool))),
|
||||
# alu with invalid -> invalid
|
||||
*((invalid_pat.alu(op, UPat(dtype=dtypes.index)), lambda i: i) for op in GroupOp.Binary-GroupOp.Comparison),
|
||||
(UPat(GroupOp.Unary, src=(invalid_pat,)), lambda i: i),
|
||||
(UPat(GroupOp.Binary-GroupOp.Comparison, src=[invalid_pat, UPat()]), lambda i: i),
|
||||
# normalize where(cond, Invalid, val) -> where(~cond, val, Invalid)
|
||||
(UPat.var("cond").where(invalid_pat, UPat.var("val")), lambda cond, i, val: cond.logical_not().where(val, i) if val.arg != Invalid else i),
|
||||
(UPat(Ops.BITCAST, src=(invalid_pat,), name="bc"), lambda bc,i: i.cast(bc.dtype)),
|
||||
(UPat(Ops.BITCAST, src=(invalid_gate,), name="bc"), lambda bc,cond,x,i: cond.where(x.bitcast(bc.dtype), i.bitcast(bc.dtype))),
|
||||
])
|
||||
|
||||
symbolic_simple = propagate_invalid + PatternMatcher([
|
||||
@@ -122,6 +131,12 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
||||
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||
# a.where(b.where(c, d), d) -> (a & b).where(c, d)
|
||||
(UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
|
||||
# where over invalid -> invalid
|
||||
(invalid_gate.where(UPat.var("a"), UPat.var("b")), lambda a,b,cond,x,i: i.cast(a.dtype)),
|
||||
(invalid_pat.where(UPat.var("a"), UPat.var("b")), lambda a,b,i: i.cast(a.dtype)),
|
||||
# reduce with invalid -> invalid
|
||||
(UPat(Ops.REDUCE, src=(invalid_gate,), allow_any_len=True, name="r"), lambda r,cond,x,i: i.cast(r.dtype)),
|
||||
(UPat(Ops.REDUCE, src=(invalid_pat,), allow_any_len=True, name="r"), lambda r,i: i.cast(r.dtype)),
|
||||
])
|
||||
|
||||
# ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********
|
||||
@@ -412,6 +427,10 @@ sym = symbolic+pm_simplify_valid+PatternMatcher([
|
||||
# fold gated LOAD/STORE
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"),
|
||||
lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # invalid store does nothing. invalid load produces 0
|
||||
(UPat(Ops.STORE, src=(UPat(), invalid_pat), allow_any_len=True), lambda i: UOp(Ops.NOOP)),
|
||||
# store of where with invalid -> gated store
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, name="index"), UPat.var("cond").where(UPat.var("val"), invalid_pat)), allow_any_len=True, name="store"),
|
||||
lambda index, cond, val, store, i: UOp.store(index.src[0].index(cond.where(index.src[1], UOp.invalid())), val, *store.src[2:])),
|
||||
((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c
|
||||
((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()),
|
||||
((UPat.var("x") * UPat.cvar("c")).reciprocal(), lambda x,c: x.reciprocal()*c.reciprocal()), # 1/(x*c) -> (1/c)*(1/x)
|
||||
|
||||
Reference in New Issue
Block a user