# Secure record-linkage benchmark: compares every record of party 0 (rows of
# ``a``) against every record of party 1 (rows of ``b``) by Hamming distance
# on secret-shared bits.
#
# Optional program arguments:
#   args[1]  n_batches  -- batches of rows in ``a`` (default 78)
#   args[2]  m_batches  -- batches of rows in ``b`` (default n_batches)
#   args[3]  n_threads  -- number of threads        (default n_batches)
#
# Names such as sbits, sbitint, sbitintvec, Matrix, MultiArray, MemValue,
# program, print_ln and the for_range* decorators are not imported here;
# they are expected to be provided by the environment this script runs in.

import math
import util

xor_op = lambda x, y: x ^ y

n_bits = 64
full_t = sbits.get_type(64)
sbits.n = n_bits

if len(program.args) > 1:
    n_batches = int(program.args[1])
else:
    n_batches = 78

if len(program.args) > 2:
    m_batches = int(program.args[2])
else:
    m_batches = n_batches

# The thread count is always set here, so no earlier default is needed.
if len(program.args) > 3:
    n_threads = int(program.args[3])
else:
    n_threads = n_batches

batch_size = 64
n = n_batches * batch_size
m = m_batches * batch_size
l = 16

# Each record is a row of l words of type full_t.
a = Matrix(n, l, full_t)
b = Matrix(m, l, full_t)
# Counter type: ceil(log2(batch_size * l)) + 1 bits.
t = sbitint.get_type(int(math.ceil(math.log(batch_size * l, 2))) + 1)
matches = Matrix(n, m, t.bit_type)
mismatches = Matrix(n, m, t)
threshold = MemValue(t(10))

for i in range(n):
    for j in range(l):
        a[i][j] = full_t.get_input_from(0)

for i in range(m):
    for j in range(l):
        b[i][j] = full_t.get_input_from(1)

# Test data: the inputs read above are overwritten so that a[0] matches b[1]
# while a[1] has no match.
a.assign_all(0)
b.assign_all(0)
a[0][0] = -1
b[1][0] = -1
a[1][1] = -1


# Pass 1: mismatches[i][j] = popcount(a[i] XOR b[j]), computed for one batch
# of batch_size rows of b at a time.
@for_range_multithread(n_threads, 1, n)
def _(i):
    print_ln('%s', i)

    @for_range_parallel(m_batches, m_batches)
    def _(j):
        j = j * batch_size
        # Row a[i] is repeated batch_size times so that it lines up with the
        # batch_size rows b[j], ..., b[j + batch_size - 1].
        av = sbitintvec.from_matrix((a[i][kk] for _ in range(batch_size))
                                    for kk in range(l))
        bv = sbitintvec.from_matrix((b[j + k][kk] for k in range(batch_size))
                                    for kk in range(l))
        res = xor_op(av, bv).popcnt()
        mismatches[i].set_range(j, (t(x) for x in res.elements()))


# Pass 2: matches[i][j] is the result of comparing mismatches[i][j] against
# the threshold, again one batch of columns at a time.
@for_range_multithread(n_threads, 8, n)
def _(i):
    print_ln('%s', i)

    @for_range_parallel(m_batches, m_batches)
    def _(j):
        j = j * batch_size
        v = sbitintvec(mismatches[i].get_range(j, batch_size))
        vv = sbitintvec([threshold.read()] * batch_size)
        matches[i].set_range(j, v.less_than(vv, 10).elements())


# Scratch storage for pass 3, one slice per batch of rows.
mg = MultiArray([n_batches, m, t.n], full_t)
ag = Matrix(n_batches, m, full_t)


# Pass 3: for each batch of batch_size rows, reduce over all m columns and
# AND the resulting per-column masks into matches.
# NOTE(review): this appears to keep, for every row, only the column with the
# smallest mismatch count -- confirm against util.tree_reduce / util.if_else.
@for_range_multithread(n_threads, 1, n_batches)
def _(i):
    mgi = mg[i]
    # Local ``a`` is this batch's row of ag; it shadows the global matrix a.
    a = ag[i]
    i = i * batch_size
    print_ln('best %s', i)

    @for_range(m)
    def _(j):
        # Pack column j of this batch's batch_size rows into one vector.
        mgi[j].assign(
            sbitintvec(mismatches[i + k][j] for k in range(batch_size)).v)

    mgi = [sbitintvec.from_vec(mgi[j]) for j in range(m)]

    def reducer(a, b):
        """Combine two (value, mask list) pairs.

        Selects on ``a[0].less_than(b[0])``: one branch keeps a's value with
        a's masks followed by zeros, the other keeps b's value with zeros
        followed by b's masks. The mask list therefore grows to the combined
        length of both inputs.
        """
        c = a[0].less_than(b[0])
        return util.if_else(c, (a[0], a[1] + [0] * len(b[1])),
                            (b[0], [0] * len(a[1]) + b[1]))

    # Every column starts with the mask 2**batch_size - 1 (batch_size ones).
    mm = util.tree_reduce(reducer, ((x, [2**batch_size - 1]) for x in mgi))
    a.assign(mm[1])

    @for_range_parallel(100, len(a))
    def _(j):
        x = a[j]
        pm = sbitintvec(matches[i + k][j] for k in range(batch_size))
        x = sbitintvec.from_vec([x])
        # Mask the previous match bits for column j with the reduced mask.
        for k, y in enumerate((pm & x).elements()):
            matches[i + k][j] = y


def test(result, expected):
    """Print the revealed ``result`` next to ``expected``.

    This only prints both values for visual comparison; it does not assert.
    """
    print_ln('%s ?= %s', result.reveal(), expected)


test(matches[0][1], 1)
test(matches[0][0], 0)
test(matches[1][0], 0)
test(matches[1][1], 0)
test(sum(matches[2]), 1)
test(mismatches[0][1], 0)
test(mismatches[0][0], 64)
test(mismatches[1][0], 64)
test(mismatches[1][1], 128)

# The comparison is n rows of a against m rows of b, so report n x m.
print_ln('%sx%s linkage of %s bits', n, m, l * batch_size)