import itertools from Compiler import types, library, instructions from Compiler import comparison, util def dest_comp(B): Bt = B.transpose() St_flat = Bt.get_vector().prefix_sum() Tt_flat = Bt.get_vector() * St_flat.get_vector() Tt = types.Matrix(*Bt.sizes, B.value_type) Tt.assign_vector(Tt_flat) return sum(Tt) - 1 def reveal_sort(k, D, reverse=False): """ Sort in place according to "perfect" key. The name hints at the fact that a random order of the keys is revealed. :param k: vector or Array of sint containing exactly :math:`0,\dots,n-1` in any order :param D: Array or MultiArray to sort :param reverse: wether :py:obj:`key` is a permutation in forward or backward order """ comparison.require_ring_size(util.log2(len(k)) + 1, 'sorting') assert len(k) == len(D) library.break_point() shuffle = types.sint.get_secure_shuffle(len(k)) k_prime = k.get_vector().secure_permute(shuffle).reveal() idx = types.Array.create_from(k_prime) if reverse: D.assign_vector(D.get_slice_vector(idx)) library.break_point() D.secure_permute(shuffle, reverse=True) else: D.secure_permute(shuffle) library.break_point() v = D.get_vector() D.assign_slice_vector(idx, v) library.break_point() instructions.delshuffle(shuffle) def radix_sort(k, D, n_bits=None, signed=True): """ Sort in place according to key. :param k: keys (vector or Array of sint or sfix) :param D: Array or MultiArray to sort :param n_bits: number of bits in keys (int) :param signed: whether keys are signed (bool) """ assert len(k) == len(D) bs = types.Matrix.create_from(k.get_vector().bit_decompose(n_bits)) if signed and len(bs) > 1: bs[-1][:] = bs[-1][:].bit_not() radix_sort_from_matrix(bs, D) def radix_sort_from_matrix(bs, D): n = len(D) for b in bs: assert(len(b) == n) B = types.sint.Matrix(n, 2) h = types.Array.create_from(types.sint(types.regint.inc(n))) @library.for_range(len(bs)) def _(i): b = bs[i] B.set_column(0, 1 - b.get_vector()) B.set_column(1, b.get_vector()) c = types.Array.create_from(dest_comp(B)) reveal_sort(c, h, reverse=False) @library.if_e(i < len(bs) - 1) def _(): reveal_sort(h, bs[i + 1], reverse=True) @library.else_ def _(): reveal_sort(h, D, reverse=True)