mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
shared_codegen_spec and fix index spec (#12967)
* split shared_codegen_spec and fix index * add VCONST to program_spec and move index to shared_codegen_spec * working ignore_oob=0 * cleanup * fix spec * undo that * move barrier and special earlier * fix more spec issues * more updates * remove special from program_spec * cleanup and fixes * move more to shared * special is not in shared_spec * some comments * dont do bounds check there
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
import unittest, itertools, math
|
import unittest, itertools, math
|
||||||
from tinygrad import Tensor, Device, dtypes
|
from tinygrad import Tensor, Device, dtypes, Context
|
||||||
from tinygrad.dtype import DType, ConstType
|
from tinygrad.dtype import DType, ConstType
|
||||||
from tinygrad.uop.ops import Ops, UOp
|
from tinygrad.uop.ops import Ops, UOp
|
||||||
from tinygrad.codegen import full_rewrite_to_sink
|
from tinygrad.codegen import full_rewrite_to_sink
|
||||||
@@ -126,7 +126,8 @@ class TestBitcastConstFolding(unittest.TestCase):
|
|||||||
t({dtypes.int64: 4598983288165178391, dtypes.uint64: 4598983288165178391, dtypes.float64: 0.29485681936461233})
|
t({dtypes.int64: 4598983288165178391, dtypes.uint64: 4598983288165178391, dtypes.float64: 0.29485681936461233})
|
||||||
|
|
||||||
def test_vec_bitcast(self):
|
def test_vec_bitcast(self):
|
||||||
r = full_rewrite_to_sink(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src[0]
|
with Context(SPEC=0):
|
||||||
|
r = full_rewrite_to_sink(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src[0]
|
||||||
self.assertEqual(r.op, Ops.VECTORIZE)
|
self.assertEqual(r.op, Ops.VECTORIZE)
|
||||||
self.assertEqual(r.dtype, dtypes.uint32.vec(3))
|
self.assertEqual(r.dtype, dtypes.uint32.vec(3))
|
||||||
self.assertEqual(tuple(x.arg for x in r.src), (2**32-1, 2**31, 75))
|
self.assertEqual(tuple(x.arg for x in r.src), (2**32-1, 2**31, 75))
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ class TestRendererFailures(unittest.TestCase):
|
|||||||
def test_gated_store_with_alu(self):
|
def test_gated_store_with_alu(self):
|
||||||
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||||
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
|
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
|
||||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, gate_alu), UOp.const(dtypes.int, 1)))
|
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate_alu)), UOp.const(dtypes.int, 1)))
|
||||||
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
|
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
|
||||||
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
|
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
|
||||||
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
|
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
|
||||||
@@ -57,7 +57,7 @@ class TestRendererFailures(unittest.TestCase):
|
|||||||
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||||
gate_alu_0 = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
|
gate_alu_0 = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
|
||||||
gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 2),), 'lidx1')).ne(0)
|
gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 2),), 'lidx1')).ne(0)
|
||||||
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0+lidx1*4, gate_alu_0&gate_alu_1), UOp.const(dtypes.int, 1)))
|
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index((lidx0+lidx1*4).valid(gate_alu_0&gate_alu_1)), UOp.const(dtypes.int, 1)))
|
||||||
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
|
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
|
||||||
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
|
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
|
||||||
ret = _test_uop_result([], uops, local_size=[4, 2, 1])[0]
|
ret = _test_uop_result([], uops, local_size=[4, 2, 1])[0]
|
||||||
|
|||||||
@@ -307,9 +307,10 @@ class TestUOpGraph(unittest.TestCase):
|
|||||||
for vec_size in [2, 4, 8]:
|
for vec_size in [2, 4, 8]:
|
||||||
consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)]
|
consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)]
|
||||||
vec = UOp(Ops.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
|
vec = UOp(Ops.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
|
||||||
uops = to_uops_list([UOp(Ops.GEP, dtypes.float, (vec,), (i,)) for i in range(vec_size)])
|
with Context(SPEC=0):
|
||||||
for uop, const in zip(uops, consts):
|
uops = to_uops_list([UOp(Ops.GEP, dtypes.float, (vec,), (i,)) for i in range(vec_size)])
|
||||||
self.assertEqual(uop, const)
|
for uop, const in zip(uops, consts):
|
||||||
|
self.assertEqual(uop, const)
|
||||||
|
|
||||||
@unittest.skip("no longer testable standalone")
|
@unittest.skip("no longer testable standalone")
|
||||||
def test_wmma_vectorize_fold(self):
|
def test_wmma_vectorize_fold(self):
|
||||||
@@ -505,10 +506,10 @@ class TestUOpGraph(unittest.TestCase):
|
|||||||
with Context(IGNORE_OOB=0):
|
with Context(IGNORE_OOB=0):
|
||||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), src=(), arg=0)
|
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), src=(), arg=0)
|
||||||
v = Variable("v", 0, 20)
|
v = Variable("v", 0, 20)
|
||||||
st0 = UOp(Ops.STORE, dtypes.void, src=(glbl0.index(v, v<16), UOp.const(dtypes.int, 0)))
|
st0 = UOp(Ops.STORE, dtypes.void, src=(glbl0.index(v.valid(v<16)), UOp.const(dtypes.int, 0)))
|
||||||
to_uops_list([st0])
|
to_uops_list([st0])
|
||||||
|
|
||||||
st1 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v), v, v<20))
|
st1 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v.valid(v<20)), v))
|
||||||
with self.assertRaises(RuntimeError): to_uops_list([st1])
|
with self.assertRaises(RuntimeError): to_uops_list([st1])
|
||||||
|
|
||||||
@unittest.skip("if not allowed in graph")
|
@unittest.skip("if not allowed in graph")
|
||||||
@@ -541,7 +542,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||||||
ridx = UOp.range(20, 0)
|
ridx = UOp.range(20, 0)
|
||||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||||
i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int)
|
i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int)
|
||||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16))),))
|
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid((0<=i)&(i<16))),))
|
||||||
to_uops_list([ld0])
|
to_uops_list([ld0])
|
||||||
glblfloat = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(20), (), 0)
|
glblfloat = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(20), (), 0)
|
||||||
ldfloat = UOp(Ops.LOAD, dtypes.float, (glblfloat.index(ridx),))
|
ldfloat = UOp(Ops.LOAD, dtypes.float, (glblfloat.index(ridx),))
|
||||||
@@ -552,7 +553,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||||||
with Context(IGNORE_OOB=0):
|
with Context(IGNORE_OOB=0):
|
||||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
|
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
|
||||||
ridx = UOp.range(20, 0)
|
ridx = UOp.range(20, 0)
|
||||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx, ridx.cast(dtypes.bool).logical_not()),))
|
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx.valid(ridx.cast(dtypes.bool).logical_not())),))
|
||||||
to_uops_list([ld0])
|
to_uops_list([ld0])
|
||||||
|
|
||||||
@unittest.skip("Bool load is not supported yet")
|
@unittest.skip("Bool load is not supported yet")
|
||||||
@@ -574,23 +575,23 @@ class TestUOpGraph(unittest.TestCase):
|
|||||||
with Context(IGNORE_OOB=0):
|
with Context(IGNORE_OOB=0):
|
||||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||||
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
|
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
|
||||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, (5<gidx0)&(gidx0<16)),))
|
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid((5<gidx0)&(gidx0<16))),))
|
||||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<16),))
|
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<16)),))
|
||||||
to_uops_list([ld0, ld1])
|
to_uops_list([ld0, ld1])
|
||||||
|
|
||||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<17),))
|
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<17)),))
|
||||||
with self.assertRaises(RuntimeError): to_uops_list([ld0])
|
with self.assertRaises(RuntimeError): to_uops_list([ld0])
|
||||||
|
|
||||||
def test_in_out_of_bounds_access_symbolic_mask(self):
|
def test_in_out_of_bounds_access_symbolic_mask(self):
|
||||||
with Context(IGNORE_OOB=0):
|
with Context(IGNORE_OOB=0):
|
||||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||||
i = Variable("i", 1, 80)
|
i = Variable("i", 1, 80)
|
||||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, i<10),))
|
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<10)),))
|
||||||
to_uops_list([ld0])
|
to_uops_list([ld0])
|
||||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, i<15),))
|
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<15)),))
|
||||||
to_uops_list([ld0])
|
to_uops_list([ld0])
|
||||||
|
|
||||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, i<20),))
|
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<20)),))
|
||||||
with self.assertRaises(RuntimeError): to_uops_list([ld0])
|
with self.assertRaises(RuntimeError): to_uops_list([ld0])
|
||||||
|
|
||||||
def test_in_out_of_bounds_access_index_load(self):
|
def test_in_out_of_bounds_access_index_load(self):
|
||||||
@@ -598,11 +599,11 @@ class TestUOpGraph(unittest.TestCase):
|
|||||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||||
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
|
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
|
||||||
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
|
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
|
||||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<8),)).cast(dtypes.index)
|
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<8)),)).cast(dtypes.index)
|
||||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<32)),))
|
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<32))),))
|
||||||
to_uops_list([ld1])
|
to_uops_list([ld1])
|
||||||
|
|
||||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<64)),))
|
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<64))),))
|
||||||
with self.assertRaises(RuntimeError): to_uops_list([ld1])
|
with self.assertRaises(RuntimeError): to_uops_list([ld1])
|
||||||
|
|
||||||
def test_bounds_with_loaded_bool(self):
|
def test_bounds_with_loaded_bool(self):
|
||||||
@@ -620,7 +621,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||||||
glbl2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
|
glbl2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
|
||||||
idx = UOp.const(dtypes.int, 0)
|
idx = UOp.const(dtypes.int, 0)
|
||||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(UOp.invalid()),))
|
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(UOp.invalid()),))
|
||||||
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2.index(idx, UOp.const(dtypes.bool, True)),))
|
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2.index(idx.valid(UOp.const(dtypes.bool, True))),))
|
||||||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
|
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
|
||||||
ld0 = uops[-1].src[-1]
|
ld0 = uops[-1].src[-1]
|
||||||
# the gate and invalid value are deleted from ld1
|
# the gate and invalid value are deleted from ld1
|
||||||
@@ -633,7 +634,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||||||
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int)))
|
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int)))
|
||||||
barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
|
barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
|
||||||
ld0 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(UOp.invalid()),))
|
ld0 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(UOp.invalid()),))
|
||||||
ld1 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index(lidx+2, UOp.const(dtypes.bool, True)),))
|
ld1 = UOp(Ops.LOAD, dtypes.int, (smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True))),))
|
||||||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])
|
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])
|
||||||
|
|
||||||
ld0 = uops[-1].src[-1]
|
ld0 = uops[-1].src[-1]
|
||||||
@@ -646,7 +647,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||||||
idx1 = UOp.const(dtypes.int, 0)
|
idx1 = UOp.const(dtypes.int, 0)
|
||||||
val = UOp.const(dtypes.int, 42)
|
val = UOp.const(dtypes.int, 42)
|
||||||
st0 = glbl.index(UOp.invalid()).store(val)
|
st0 = glbl.index(UOp.invalid()).store(val)
|
||||||
st1 = glbl.index(idx0, UOp.const(dtypes.bool, True)).store(val)
|
st1 = glbl.index(idx0.valid(UOp.const(dtypes.bool, True))).store(val)
|
||||||
uops = to_uops_list([st0, st1])
|
uops = to_uops_list([st0, st1])
|
||||||
# only the second store happens
|
# only the second store happens
|
||||||
self.assertEqual(len(uops), 5)
|
self.assertEqual(len(uops), 5)
|
||||||
|
|||||||
@@ -277,7 +277,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
|||||||
gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
|
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
|
||||||
gate = gidx0<UOp.const(dtypes.int, 1)
|
gate = gidx0<UOp.const(dtypes.int, 1)
|
||||||
idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, gidx0 * UOp.const(dtypes.int, 2), gate))
|
idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, (gidx0 * UOp.const(dtypes.int, 2)).valid(gate)))
|
||||||
val = UOp.const(dtypes.float, 42.0)
|
val = UOp.const(dtypes.float, 42.0)
|
||||||
store = UOp(Ops.STORE, dtypes.void, (idx, val))
|
store = UOp(Ops.STORE, dtypes.void, (idx, val))
|
||||||
uops = to_uops_list([store])
|
uops = to_uops_list([store])
|
||||||
@@ -294,7 +294,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
|||||||
gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
|
gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
|
||||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
|
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
|
||||||
idx = gidx0 * UOp.const(dtypes.int, 2)
|
idx = gidx0 * UOp.const(dtypes.int, 2)
|
||||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx, gidx0<UOp.const(dtypes.int, 1)))
|
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx.valid(gidx0<UOp.const(dtypes.int, 1))))
|
||||||
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
|
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
|
||||||
val = UOp.const(dtypes.float, 42.0)
|
val = UOp.const(dtypes.float, 42.0)
|
||||||
stores = [UOp.store(idx0, val), UOp.store(idx1, val)]
|
stores = [UOp.store(idx0, val), UOp.store(idx1, val)]
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import unittest, functools
|
import unittest, functools
|
||||||
from tinygrad import Tensor
|
from tinygrad import Tensor, Context
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
def orthogonality_helper(A:Tensor, tolerance=1e-5):
|
def orthogonality_helper(A:Tensor, tolerance=1e-5):
|
||||||
@@ -27,15 +27,16 @@ class TestLinAlg(unittest.TestCase):
|
|||||||
reconstruction_helper([U,s_diag,V],a)
|
reconstruction_helper([U,s_diag,V],a)
|
||||||
|
|
||||||
def _test_svd_nonfull(self, size):
|
def _test_svd_nonfull(self, size):
|
||||||
a = Tensor.randn(size).realize()
|
with Context(IGNORE_OOB=1): # sometimes this is slow in CI
|
||||||
U,S,V = a.svd(full_matrices=False)
|
a = Tensor.randn(size).realize()
|
||||||
b_shape,m,n = size[0:-2],size[-2],size[-1]
|
U,S,V = a.svd(full_matrices=False)
|
||||||
k = min(m,n)
|
b_shape,m,n = size[0:-2],size[-2],size[-1]
|
||||||
s_diag = (S.unsqueeze(-2) * Tensor.eye(k).reshape((1,) * len(b_shape) + (k,k)).expand(b_shape + (k,k)))
|
k = min(m,n)
|
||||||
#reduced U,V is only orthogonal along smaller dim
|
s_diag = (S.unsqueeze(-2) * Tensor.eye(k).reshape((1,) * len(b_shape) + (k,k)).expand(b_shape + (k,k)))
|
||||||
if (m < n): orthogonality_helper(U),orthogonality_helper(V)
|
#reduced U,V is only orthogonal along smaller dim
|
||||||
else: orthogonality_helper(U.transpose(-2,-1)),orthogonality_helper(V.transpose(-2,-1))
|
if (m < n): orthogonality_helper(U),orthogonality_helper(V)
|
||||||
reconstruction_helper([U,s_diag,V],a)
|
else: orthogonality_helper(U.transpose(-2,-1)),orthogonality_helper(V.transpose(-2,-1))
|
||||||
|
reconstruction_helper([U,s_diag,V],a)
|
||||||
|
|
||||||
# faster for parallel pytest
|
# faster for parallel pytest
|
||||||
def test_svd_nonfull_2_2(self): self._test_svd_nonfull((2,2))
|
def test_svd_nonfull_2_2(self): self._test_svd_nonfull((2,2))
|
||||||
@@ -75,4 +76,4 @@ class TestLinAlg(unittest.TestCase):
|
|||||||
orthogonality_helper(b if size[-1] > size[-2] else b.transpose(-2, -1), tolerance=1e-3)
|
orthogonality_helper(b if size[-1] > size[-2] else b.transpose(-2, -1), tolerance=1e-3)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ shared_spec = PatternMatcher([
|
|||||||
(UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x:
|
(UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x:
|
||||||
rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
|
rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
|
||||||
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
|
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
|
||||||
|
(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
|
||||||
])
|
])
|
||||||
|
|
||||||
# ***** UOp spec in the Tensor graph *****
|
# ***** UOp spec in the Tensor graph *****
|
||||||
@@ -105,9 +106,9 @@ tensor_spec = PatternMatcher([
|
|||||||
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),
|
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),
|
||||||
])+shared_spec
|
])+shared_spec
|
||||||
|
|
||||||
# ***** UOp spec in linearized programs *****
|
# ***** UOp spec in codegen shared between kernel and program *****
|
||||||
|
|
||||||
program_spec = PatternMatcher([
|
shared_codegen_spec = PatternMatcher([
|
||||||
# DEFINEs
|
# DEFINEs
|
||||||
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and x.dtype.addrspace == AddrSpace.GLOBAL),
|
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and x.dtype.addrspace == AddrSpace.GLOBAL),
|
||||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.addrspace == AddrSpace.LOCAL),
|
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.addrspace == AddrSpace.LOCAL),
|
||||||
@@ -117,42 +118,59 @@ program_spec = PatternMatcher([
|
|||||||
(UPat(Ops.AFTER, src=(UPat(GroupOp.Defines),), allow_any_len=True), lambda: True),
|
(UPat(Ops.AFTER, src=(UPat(GroupOp.Defines),), allow_any_len=True), lambda: True),
|
||||||
(UPat(Ops.GROUP, dtypes.void), lambda: True),
|
(UPat(Ops.GROUP, dtypes.void), lambda: True),
|
||||||
|
|
||||||
# INDEX is used in new style load/store
|
|
||||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
|
|
||||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat())), lambda: True),
|
|
||||||
|
|
||||||
# LOAD (idx, alt_value) / STORE(if gated) / LOAD(idx) / STORE(idx, val)
|
|
||||||
(UPat().index(UPat(), UPat(dtype=dtypes.bool, name="gate"), name="idx").or_casted().load(UPat()), validate_index),
|
|
||||||
(UPat().index(UPat(), UPat(dtype=dtypes.bool, name="gate"), name="idx").or_casted().store(UPat()), validate_index),
|
|
||||||
(UPat().index(UPat(), name="idx").or_casted().load(), validate_index),
|
|
||||||
(UPat().index(UPat(), name="idx").or_casted().store(UPat()), validate_index),
|
|
||||||
|
|
||||||
# RANGE/SPECIAL define loops, END closes them
|
# RANGE/SPECIAL define loops, END closes them
|
||||||
(UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),
|
|
||||||
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
|
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
|
||||||
|
|
||||||
# make sure all index dtypes have been lowered
|
|
||||||
(UPat(GroupOp.All, dtype=dtypes.index), lambda: False),
|
|
||||||
(UPat(Ops.CONST, arg=Invalid), lambda: False),
|
|
||||||
(UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.src)),
|
|
||||||
|
|
||||||
# WMMA has a <a, b, acc>
|
# WMMA has a <a, b, acc>
|
||||||
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
|
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
|
||||||
|
|
||||||
# if has a <gate, index_for_dedup>
|
# UNROLL/CONTRACT is used here for WMMA
|
||||||
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX)))), lambda: True),
|
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
|
||||||
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
|
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
|
||||||
|
|
||||||
# VECTORIZE/GEP
|
# VECTORIZE/GEP
|
||||||
(UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
|
(UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
|
||||||
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
|
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
|
||||||
|
|
||||||
# BARRIER
|
# LOAD(idx) / STORE(idx, val) / LOAD with alt value only exists in program_spec
|
||||||
(UPat(Ops.BARRIER, dtypes.void, src=(UPat(),)), lambda: True),
|
(UPat().index(UPat()).or_casted().load(), lambda: True),
|
||||||
|
(UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True),
|
||||||
|
|
||||||
# all CUSTOM + PRECAST
|
# all CUSTOM + PRECAST
|
||||||
(UPat((Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
|
(UPat((Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
|
||||||
])+shared_spec
|
|
||||||
|
# INDEX
|
||||||
|
(UPat(GroupOp.Defines, name="buf").or_after().index(UPat.var("idx")), validate_index),
|
||||||
|
|
||||||
|
# SPECIAL
|
||||||
|
(UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.index, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)),
|
||||||
|
|
||||||
|
# BARRIER
|
||||||
|
(UPat(Ops.BARRIER, dtypes.void, src=(UPat(),)), lambda: True),
|
||||||
|
])
|
||||||
|
|
||||||
|
# ***** UOp spec in linearized programs *****
|
||||||
|
|
||||||
|
program_spec = PatternMatcher([
|
||||||
|
# INDEX with a gate as third src
|
||||||
|
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines, name="buf").or_after(), UPat.var("idx"), UPat.var("gate", dtype=dtypes.bool))), validate_index),
|
||||||
|
|
||||||
|
# LOAD (idx, alt_value), LOAD can have an alt value, but only if the index has a gate
|
||||||
|
(UPat().index(UPat(), UPat(dtype=dtypes.bool)).or_casted().load(UPat()), lambda: True),
|
||||||
|
|
||||||
|
# END closes ranges
|
||||||
|
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
|
||||||
|
|
||||||
|
# make sure all index dtypes have been lowered
|
||||||
|
(UPat(GroupOp.All, dtype=dtypes.index), lambda: False),
|
||||||
|
(UPat(Ops.CONST, arg=Invalid), lambda: False),
|
||||||
|
(UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.arg) and len(x.arg)==x.dtype.vcount>1 and
|
||||||
|
type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
|
||||||
|
|
||||||
|
# if has a <gate, index_for_dedup>
|
||||||
|
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX)))), lambda: True),
|
||||||
|
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
|
||||||
|
])+shared_codegen_spec+shared_spec
|
||||||
|
|
||||||
# ***** UOp spec in kernel graph *****
|
# ***** UOp spec in kernel graph *****
|
||||||
|
|
||||||
@@ -160,14 +178,6 @@ kernel_spec = PatternMatcher([
|
|||||||
# index is allowed here
|
# index is allowed here
|
||||||
(UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),
|
(UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),
|
||||||
|
|
||||||
# LOAD(idx) / STORE(idx, val) -- NOTE: we do this here to not run validate_index since z3 doesn't support Invalid
|
|
||||||
(UPat(Ops.INDEX).or_casted().load(), lambda: True),
|
|
||||||
(UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True),
|
|
||||||
|
|
||||||
# UNROLL/CONTRACT is used here for WMMA
|
|
||||||
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
|
|
||||||
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
|
|
||||||
|
|
||||||
# END can end multiple axes here
|
# END can end multiple axes here
|
||||||
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True, dtype=dtypes.void), lambda: True),
|
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True, dtype=dtypes.void), lambda: True),
|
||||||
|
|
||||||
@@ -176,10 +186,7 @@ kernel_spec = PatternMatcher([
|
|||||||
|
|
||||||
# reduce must be on ranges
|
# reduce must be on ranges
|
||||||
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
|
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
|
||||||
|
])+shared_codegen_spec+shared_spec
|
||||||
# intermediate index
|
|
||||||
(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
|
|
||||||
])+program_spec+shared_spec
|
|
||||||
|
|
||||||
# *** this spec should match all UOps ever created ***
|
# *** this spec should match all UOps ever created ***
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Callable
|
from typing import Callable
|
||||||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, python_alu, graph_rewrite
|
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, python_alu, graph_rewrite
|
||||||
from tinygrad.dtype import ImageDType, dtypes
|
from tinygrad.dtype import ImageDType, dtypes, Invalid
|
||||||
from tinygrad.helpers import IGNORE_OOB, Context, cpu_profile
|
from tinygrad.helpers import IGNORE_OOB, Context, cpu_profile
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -25,15 +25,19 @@ try:
|
|||||||
# ctx is (solver, load_number_dict)
|
# ctx is (solver, load_number_dict)
|
||||||
# each uop gets rewritten to NOOP(arg=(solver, z3_object)), the arg has the solver first due to UOpMetaClass caching. z3 objects from different
|
# each uop gets rewritten to NOOP(arg=(solver, z3_object)), the arg has the solver first due to UOpMetaClass caching. z3 objects from different
|
||||||
# contexts can have the same hash but error on comparison
|
# contexts can have the same hash but error on comparison
|
||||||
|
def add_valid(ctx, cond, x):
|
||||||
|
ctx[0].add(cond.arg[1])
|
||||||
|
return x
|
||||||
z3_renderer = PatternMatcher([
|
z3_renderer = PatternMatcher([
|
||||||
|
(UPat(Ops.NOOP, name="cond").where(UPat(Ops.NOOP, name="x"), UPat(Ops.CONST, arg=Invalid)), add_valid),
|
||||||
(UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg, 0, x.src[0].arg[1]-1, ctx[0])))),
|
(UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg, 0, x.src[0].arg[1]-1, ctx[0])))),
|
||||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])))),
|
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])))),
|
||||||
(UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"ridx{x.arg}", 0, x.src[0].arg[1]-1, ctx[0])))),
|
(UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"ridx{x.arg}", 0, x.src[0].arg[1]-1, ctx[0])))),
|
||||||
# loaded bools become a z3 int with min max of 0-1
|
# loaded bools become a z3 int with min max of 0-1
|
||||||
(UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda x,ctx:
|
(UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda x,ctx:
|
||||||
UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0]))).cast(x.dtype)),
|
UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0]))).cast(x.dtype)),
|
||||||
(UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"),
|
(UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"), lambda x,ctx:
|
||||||
lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx)))),
|
UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx))) if x.arg is not Invalid else None),
|
||||||
# z3 can cast from bool to int automatically
|
# z3 can cast from bool to int automatically
|
||||||
(UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
|
(UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
|
||||||
(UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], x.src[0].arg[1]!=0))),
|
(UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], x.src[0].arg[1]!=0))),
|
||||||
@@ -55,24 +59,26 @@ try:
|
|||||||
z3_imported = True
|
z3_imported = True
|
||||||
except (ImportError, AttributeError): z3_imported = False
|
except (ImportError, AttributeError): z3_imported = False
|
||||||
|
|
||||||
def validate_index(idx:UOp, gate:UOp|None=None):
|
def validate_index(buf:UOp, idx:UOp, gate:UOp|None=None):
|
||||||
|
if idx.op is Ops.CONST and idx.arg is Invalid: return True
|
||||||
if gate is None: gate = UOp.const(dtypes.bool, True)
|
if gate is None: gate = UOp.const(dtypes.bool, True)
|
||||||
# TODO: check for overflow
|
# TODO: check for overflow
|
||||||
if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := idx.src[0].ptrdtype.size) == -1: return True
|
if IGNORE_OOB or isinstance(buf.dtype, ImageDType) or (sz := buf.ptrdtype.size) == -1: return True
|
||||||
# We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask
|
# We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask
|
||||||
if 0<=idx.src[1].vmin and idx.src[1].vmax<sz: return True
|
if 0<=idx.vmin and idx.vmax<sz: return True
|
||||||
|
|
||||||
# WEBGPU has a BITCAST in the index. TODO: fix
|
# WEBGPU has a BITCAST in the index. TODO: fix
|
||||||
if any(x.op is Ops.BITCAST for x in idx.toposort()): return True
|
if any(x.op is Ops.BITCAST for x in idx.toposort()): return True
|
||||||
|
|
||||||
if not z3_imported: raise ImportError("z3 >= 4.12.4 is required for bounds checking, try IGNORE_OOB=0 or \"pip install 'z3-solver>=4.12.4\"")
|
if not z3_imported: raise ImportError("z3 >= 4.12.4 is required for bounds checking, try IGNORE_OOB=0 or \"pip install 'z3-solver>=4.12.4\"")
|
||||||
solver = z3.Solver(ctx=z3.Context())
|
solver = z3.Solver(ctx=z3.Context())
|
||||||
z3_idx, z3_mask = uops_to_z3(solver, idx.src[1], gate)
|
z3_idx, z3_mask = uops_to_z3(solver, idx, gate)
|
||||||
solver.add(z3_mask)
|
solver.add(z3_mask)
|
||||||
with cpu_profile("validate index with z3", "TINY"):
|
with cpu_profile("validate index with z3", "TINY"):
|
||||||
if solver.check((z3_idx<0)|(sz<=z3_idx)) == z3.sat:
|
match solver.check((z3_idx<0)|(sz<=z3_idx)):
|
||||||
print(f"idx={idx.src[1].render(simplify=False)}")
|
case z3.unsat: return True
|
||||||
print(f"gate={gate.render(simplify=False)}")
|
case z3.sat: print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}")
|
||||||
print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}")
|
case z3.unknown: print(f"# UNKNOWN RESULT FROM Z3: {solver.reason_unknown()}\nconstraints = {solver}")
|
||||||
return False
|
print(f"idx={idx.render(simplify=False)}")
|
||||||
return True
|
print(f"mask={gate.render(simplify=False)}")
|
||||||
|
return False
|
||||||
|
|||||||
Reference in New Issue
Block a user