mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
rewrite bool ADD to OR and MUL to AND (#6084)
* rewrite bool ADD to OR and MUL to AND. This fixes running `tinyphysics.onnx`, which contains a getitem from a boolean tensor. It can only be reproduced through BEAM_COMPARE, which I think is a different bug in test_linearizer_failure.
* fold those, and fix tests
* only for bool
* move dtypes.bool
This commit is contained in:
@@ -474,5 +474,22 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
opts = [Opt(op=OptOps.TC, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=2)]
|
||||
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
|
||||
|
||||
def test_failure_50(self):
|
||||
# from BEAM_COMPARE=2 running tinyphysics.onnx model
|
||||
ast = LazyOp(MetaOps.KERNEL, arg=None, src=(
|
||||
LazyOp(BufferOps.STORE, arg=MemBuffer(idx=0, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 20, 0, 1), offset=0, mask=None, contiguous=True),))), src=(
|
||||
LazyOp(BinaryOps.CMPNE, arg=None, src=(
|
||||
LazyOp(ReduceOps.SUM, arg=(3,), src=(
|
||||
LazyOp(BinaryOps.MUL, arg=None, src=(
|
||||
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=1, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 20, 1), offset=0, mask=None, contiguous=False),))), src=()),
|
||||
LazyOp(BinaryOps.CMPNE, arg=None, src=(
|
||||
LazyOp(BinaryOps.CMPNE, arg=None, src=(
|
||||
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 1, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),
|
||||
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=3, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 1, 0), offset=0, mask=None, contiguous=False),))), src=()),)),
|
||||
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=True, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),)),
|
||||
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=True, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),))
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=1, amt=2)]
|
||||
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -566,7 +566,7 @@ class TestIFUOps(TestUOps):
|
||||
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 4))
|
||||
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
|
||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
|
||||
gate = valid*(lidx.ne(2))
|
||||
gate = valid&(lidx.ne(2))
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
st = UOp(UOps.STORE, None, (sbuf, idx, UOp.const(dtypes.float, 42)))
|
||||
barrier = UOp(UOps.BARRIER, None, (st,))
|
||||
@@ -585,7 +585,7 @@ class TestIFUOps(TestUOps):
|
||||
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 16))
|
||||
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1)
|
||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
||||
gate = valid*(lidx.ne(2))
|
||||
gate = valid&(lidx.ne(2))
|
||||
st = UOp(UOps.STORE, None, (sbuf, lidx, UOp.const(dtypes.float, 42)))
|
||||
barrier = UOp(UOps.BARRIER, None, (st,))
|
||||
lbufs = [UOp(UOps.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, i), barrier)) for i in range(4)]
|
||||
@@ -604,7 +604,7 @@ class TestIFUOps(TestUOps):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
|
||||
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
|
||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
|
||||
gate = valid*(lidx.ne(2))
|
||||
gate = valid&(lidx.ne(2))
|
||||
stores = [UOp(UOps.STORE, None, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
|
||||
sink = UOp(UOps.SINK, None, tuple(stores))
|
||||
sink = gate_rewrite(sink)
|
||||
|
||||
@@ -89,7 +89,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_ge_divides_and(self):
|
||||
expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
|
||||
create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)])
|
||||
self.helper_test_variable(expr, 0, 1, {"((idx1<128) and (idx2<128))", "((idx1<128)*(idx2<128))"})
|
||||
self.helper_test_variable(expr, 0, 1, {"((idx1<128) and (idx2<128))", "((idx1<128)&(idx2<128))"})
|
||||
# # bool divided by int is not allowed
|
||||
# expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
|
||||
# create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7), 512)])
|
||||
|
||||
Reference in New Issue
Block a user