from Compiler.types import * from Compiler.sorting import * from Compiler.library import * from Compiler import util, oram from itertools import accumulate import math debug = False debug_split = False max_leaves = None def get_type(x): if isinstance(x, (Array, SubMultiArray)): return x.value_type elif isinstance(x, (tuple, list)): x = x[0] + x[-1] if util.is_constant(x): return cint else: return type(x) else: return type(x) def PrefixSum(x): return x.get_vector().prefix_sum() def PrefixSumR(x): tmp = get_type(x).Array(len(x)) tmp.assign_vector(x) break_point() tmp[:] = tmp.get_reverse_vector().prefix_sum() break_point() return tmp.get_reverse_vector() def PrefixSum_inv(x): tmp = get_type(x).Array(len(x) + 1) tmp.assign_vector(x, base=1) tmp[0] = 0 return tmp.get_vector(size=len(x), base=1) - tmp.get_vector(size=len(x)) def PrefixSumR_inv(x): tmp = get_type(x).Array(len(x) + 1) tmp.assign_vector(x) tmp[-1] = 0 return tmp.get_vector(size=len(x)) - tmp.get_vector(base=1, size=len(x)) class SortPerm: def __init__(self, x): B = sint.Matrix(len(x), 2) B.set_column(0, 1 - x.get_vector()) B.set_column(1, x.get_vector()) self.perm = Array.create_from(dest_comp(B)) def apply(self, x): res = Array.create_from(x) reveal_sort(self.perm, res, False) return res def unapply(self, x): res = Array.create_from(x) reveal_sort(self.perm, res, True) return res def Sort(keys, *to_sort, n_bits=None, time=False): if time: start_timer(1) for k in keys: assert len(k) == len(keys[0]) n_bits = n_bits or [None] * len(keys) bs = Matrix.create_from( sum([k.get_vector().bit_decompose(nb) for k, nb in reversed(list(zip(keys, n_bits)))], [])) get_vec = lambda x: x[:] if isinstance(x, Array) else x res = Matrix.create_from(get_vec(x).v if isinstance(get_vec(x), sfix) else x for x in to_sort) res = res.transpose() if time: start_timer(11) radix_sort_from_matrix(bs, res) if time: stop_timer(11) stop_timer(1) res = res.transpose() return [sfix._new(get_vec(x), k=get_vec(y).k, f=get_vec(y).f) if isinstance(get_vec(y), sfix) else x for (x, y) in zip(res, to_sort)] def VectMax(key, *data, debug=False): def reducer(x, y): b = x[0] > y[0] if debug: print_ln('max b=%s', b.reveal()) return [b.if_else(xx, yy) for xx, yy in zip(x, y)] if debug: key = list(key) data = [list(x) for x in data] print_ln('vect max key=%s data=%s', util.reveal(key), util.reveal(data)) res = util.tree_reduce(reducer, zip(key, *data))[1:] if debug: print_ln('vect max res=%s', util.reveal(res)) return res def GroupSum(g, x): assert len(g) == len(x) p = PrefixSumR(x) * g pi = SortPerm(g.get_vector().bit_not()) p1 = pi.apply(p) s1 = PrefixSumR_inv(p1) d1 = PrefixSum_inv(s1) d = pi.unapply(d1) * g return PrefixSum(d) def GroupPrefixSum(g, x): assert len(g) == len(x) s = get_type(x).Array(len(x) + 1) s[0] = 0 s.assign_vector(PrefixSum(x), base=1) q = get_type(s).Array(len(x)) q.assign_vector(s.get_vector(size=len(x)) * g) return s.get_vector(size=len(x), base=1) - GroupSum(g, q) def GroupMax(g, keys, *x): if debug: print_ln('group max input g=%s keys=%s x=%s', util.reveal(g), util.reveal(keys), util.reveal(x)) assert len(keys) == len(g) for xx in x: assert len(xx) == len(g) n = len(g) m = int(math.ceil(math.log(n, 2))) keys = Array.create_from(keys) x = [Array.create_from(xx) for xx in x] g_new = Array.create_from(g) g_old = g_new.same_shape() for d in range(m): w = 2 ** d g_old[:] = g_new[:] break_point() vsize = n - w g_new.assign_vector(g_old.get_vector(size=vsize).bit_or( g_old.get_vector(size=vsize, base=w)), base=w) b = keys.get_vector(size=vsize) > keys.get_vector(size=vsize, base=w) for xx in [keys] + x: a = b.if_else(xx.get_vector(size=vsize), xx.get_vector(size=vsize, base=w)) xx.assign_vector(g_old.get_vector(size=vsize, base=w).if_else( xx.get_vector(size=vsize, base=w), a), base=w) break_point() if debug: print_ln('group max w=%s b=%s a=%s keys=%s x=%s g=%s', w, b.reveal(), util.reveal(a), util.reveal(keys), util.reveal(x), g_new.reveal()) t = sint.Array(len(g)) t[-1] = 1 t.assign_vector(g.get_vector(size=n - 1, base=1)) if debug: print_ln('group max end g=%s t=%s keys=%s x=%s', util.reveal(g), util.reveal(t), util.reveal(keys), util.reveal(x)) return [GroupSum(g, t[:] * xx) for xx in [keys] + x] def ModifiedGini(g, y, debug=False): assert len(g) == len(y) y = [y.get_vector().bit_not(), y] u = [GroupPrefixSum(g, yy) for yy in y] s = [GroupSum(g, yy) for yy in y] w = [ss - uu for ss, uu in zip(s, u)] us = sum(u) ws = sum(w) uqs = u[0] ** 2 + u[1] ** 2 wqs = w[0] ** 2 + w[1] ** 2 res = sfix(uqs) / us + sfix(wqs) / ws if debug: print_ln('g=%s y=%s s=%s', util.reveal(g), util.reveal(y), util.reveal(s)) print_ln('u0=%s', util.reveal(u[0])) print_ln('u0=%s', util.reveal(u[1])) print_ln('us=%s', util.reveal(us)) print_ln('w0=%s', util.reveal(w[0])) print_ln('w1=%s', util.reveal(w[1])) print_ln('ws=%s', util.reveal(ws)) print_ln('uqs=%s', util.reveal(uqs)) print_ln('wqs=%s', util.reveal(wqs)) if debug: print_ln('gini %s %s', type(res), util.reveal(res)) return res MIN_VALUE = -10000 def FormatLayer(h, g, *a): return CropLayer(h, *FormatLayer_without_crop(g, *a)) def FormatLayer_without_crop(g, *a, debug=False): for x in a: assert len(x) == len(g) v = [g.if_else(aa, 0) for aa in a] if debug: print_ln('format in %s', util.reveal(a)) print_ln('format mux %s', util.reveal(v)) v = Sort([g.bit_not()], *v, n_bits=[1]) if debug: print_ln('format sort %s', util.reveal(v)) return v def CropLayer(k, *v): if max_leaves: n = min(2 ** k, max_leaves) else: n = 2 ** k return [vv[:min(n, len(vv))] for vv in v] def TrainLeafNodes(h, g, y, NID): assert len(g) == len(y) assert len(g) == len(NID) Label = GroupSum(g, y.bit_not()) < GroupSum(g, y) return FormatLayer(h, g, NID, Label) def GroupSame(g, y): assert len(g) == len(y) s = GroupSum(g, [sint(1)] * len(g)) s0 = GroupSum(g, y.bit_not()) s1 = GroupSum(g, y) if debug_split: print_ln('group same g=%s', util.reveal(g)) print_ln('group same y=%s', util.reveal(y)) return (s == s0).bit_or(s == s1) def GroupFirstOne(g, b): assert len(g) == len(b) s = GroupPrefixSum(g, b) return s * b == 1 class TreeTrainer: """ Decision tree training by `Hamada et al.`_ :param x: sample data (by attribute, list or :py:obj:`~Compiler.types.Matrix`) :param y: binary labels (list or sint vector) :param h: height (int) :param binary: binary attributes instead of continuous :param attr_lengths: attribute description for mixed data (list of 0/1 for continuous/binary) :param n_threads: number of threads (default: single thread) .. _`Hamada et al.`: https://arxiv.org/abs/2112.12906 """ def ApplyTests(self, x, AID, Threshold): m = len(x) n = len(AID) assert len(AID) == len(Threshold) for xx in x: assert len(xx) == len(AID) e = sint.Matrix(m, n) AID = Array.create_from(AID) @for_range_multithread(self.n_threads, 1, m) def _(j): e[j][:] = AID[:] == j xx = sum(x[j] * e[j] for j in range(m)) if self.debug > 1: print_ln('apply e=%s xx=%s', util.reveal(e), util.reveal(xx)) print_ln('threshold %s', util.reveal(Threshold)) return 2 * xx < Threshold def AttributeWiseTestSelection(self, g, x, y, time=False, debug=False): assert len(g) == len(x) assert len(g) == len(y) if time: start_timer(2) s = ModifiedGini(g, y, debug=debug or self.debug > 2) if time: stop_timer(2) if debug or self.debug > 1: print_ln('gini %s', s.reveal()) xx = x t = get_type(x).Array(len(x)) t[-1] = MIN_VALUE t.assign_vector(xx.get_vector(size=len(x) - 1) + \ xx.get_vector(size=len(x) - 1, base=1)) gg = g p = sint.Array(len(x)) p[-1] = 1 p.assign_vector(gg.get_vector(base=1, size=len(x) - 1).bit_or( xx.get_vector(size=len(x) - 1) == \ xx.get_vector(size=len(x) - 1, base=1))) break_point() if debug: print_ln('attribute t=%s p=%s', util.reveal(t), util.reveal(p)) s = p[:].if_else(MIN_VALUE, s) t = p[:].if_else(MIN_VALUE, t[:]) if debug: print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t)) if time: start_timer(3) s, t = GroupMax(gg, s, t) if time: stop_timer(3) if debug: print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t)) return t, s def GlobalTestSelection(self, x, y, g): assert len(y) == len(g) for xx in x: assert(len(xx) == len(g)) m = len(x) n = len(y) u, t = [get_type(x).Matrix(m, n) for i in range(2)] v = get_type(y).Matrix(m, n) s = sfix.Matrix(m, n) @for_range_multithread(self.n_threads, 1, m) def _(j): single = not self.n_threads or self.n_threads == 1 time = self.time and single if debug: print_ln('run %s', j) @if_e(self.attr_lengths[j]) def _(): u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y, n_bits=[util.log2(n), 1], time=time) @else_ def _(): u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y, n_bits=[util.log2(n), None], time=time) if self.debug_threading: print_ln('global sort %s %s %s', j, util.reveal(u[j]), util.reveal(v[j])) t[j][:], s[j][:] = self.AttributeWiseTestSelection( g, u[j], v[j], time=time, debug=self.debug_selection) if self.debug_threading: print_ln('global attribute %s %s %s', j, util.reveal(t[j]), util.reveal(s[j])) n = len(g) a = sint.Array(n) if self.debug_threading: print_ln('global s=%s', util.reveal(s)) if self.debug_gini: print_ln('Gini indices ' + ' '.join(str(i) + ':%s' for i in range(m)), *(ss[0].reveal() for ss in s)) if self.time: start_timer(4) if self.debug > 1: print_ln('s=%s', s.reveal_nested()) print_ln('t=%s', t.reveal_nested()) a[:], tt = VectMax((s[j][:] for j in range(m)), range(m), (t[j][:] for j in range(m)), debug=self.debug > 1) tt = Array.create_from(tt) if self.time: stop_timer(4) if self.debug > 1: print_ln('a=%s', util.reveal(a)) print_ln('tt=%s', util.reveal(tt)) return a[:], tt[:] def TrainInternalNodes(self, k, x, y, g, NID): assert len(g) == len(y) for xx in x: assert len(xx) == len(g) AID, Threshold = self.GlobalTestSelection(x, y, g) s = GroupSame(g[:], y[:]) if self.debug > 1 or debug_split: print_ln('AID=%s', util.reveal(AID)) print_ln('Threshold=%s', util.reveal(Threshold)) print_ln('GroupSame=%s', util.reveal(s)) AID, Threshold = s.if_else(0, AID), s.if_else(MIN_VALUE, Threshold) if self.debug > 1 or debug_split: print_ln('AID=%s', util.reveal(AID)) print_ln('Threshold=%s', util.reveal(Threshold)) b = self.ApplyTests(x, AID, Threshold) layer = FormatLayer_without_crop(g[:], NID, AID, Threshold, debug=self.debug > 1) return *layer, b @method_block def train_layer(self, k): x = self.x y = self.y g = self.g NID = self.NID if self.debug > 1: print_ln('g=%s', g.reveal()) print_ln('y=%s', y.reveal()) print_ln('x=%s', x.reveal_nested()) self.nids[k], self.aids[k], self.thresholds[k], b = \ self.TrainInternalNodes(k, x, y, g, NID) if self.debug > 1: print_ln('layer %s:', k) for name, data in zip(('NID', 'AID', 'Thr'), (self.nids[k], self.aids[k], self.thresholds[k])): print_ln(' %s: %s', name, data.reveal()) NID[:] = 2 ** k * b + NID b_not = b.bit_not() if self.debug > 1: print_ln('b_not=%s', b_not.reveal()) g[:] = GroupFirstOne(g, b_not) + GroupFirstOne(g, b) y[:], g[:], NID[:], *xx = Sort([b], y, g, NID, *x, n_bits=[1]) for i, xxx in enumerate(xx): x[i] = xxx def __init__(self, x, y, h, binary=False, attr_lengths=None, n_threads=None): assert not (binary and attr_lengths) if binary: attr_lengths = [1] * len(x) else: attr_lengths = attr_lengths or ([0] * len(x)) for l in attr_lengths: assert l in (0, 1) self.attr_lengths = Array.create_from(regint(attr_lengths)) Array.check_indices = False Matrix.disable_index_checks() for xx in x: assert len(xx) == len(y) n = len(y) self.g = sint.Array(n) self.g.assign_all(0) self.g[0] = 1 self.NID = sint.Array(n) self.NID.assign_all(1) self.y = Array.create_from(y) self.x = Matrix.create_from(x) self.nids, self.aids = [sint.Matrix(h, n) for i in range(2)] self.thresholds = self.x.value_type.Matrix(h, n) self.n_threads = n_threads self.debug_selection = False self.debug_threading = False self.debug_gini = False self.debug = False self.time = False get_program().reading('decision tree learning', 'HIKC23') def train(self): """ Train and return decision tree. """ h = len(self.nids) @for_range(h) def _(k): self.train_layer(k) return self.get_tree(h) def train_with_testing(self, *test_set, output=False): """ Train decision tree and test against test data. :param y: binary labels (list or sint vector) :param x: sample data (by attribute, list or :py:obj:`~Compiler.types.Matrix`) :param output: output tree after every level :returns: tree """ for k in range(len(self.nids)): self.train_layer(k) tree = self.get_tree(k + 1) if output: output_decision_tree(tree) test_decision_tree('train', tree, self.y, self.x, n_threads=self.n_threads) if test_set: test_decision_tree('test', tree, *test_set, n_threads=self.n_threads) return tree def get_tree(self, h): Layer = [None] * (h + 1) for k in range(h): Layer[k] = CropLayer(k, self.nids[k], self.aids[k], self.thresholds[k]) Layer[h] = TrainLeafNodes(h, self.g[:], self.y[:], self.NID) return Layer def DecisionTreeTraining(x, y, h, binary=False): return TreeTrainer(x, y, h, binary=binary).train() def output_decision_tree(layers): """ Print decision tree output by :py:class:`TreeTrainer`. """ print_ln('full model %s', util.reveal(layers)) for i, layer in enumerate(layers[:-1]): print_ln('level %s:', i) for j, x in enumerate(('NID', 'AID', 'Thr')): print_ln(' %s: %s', x, util.reveal(layer[j])) print_ln('leaves:') for j, x in enumerate(('NID', 'result')): print_ln(' %s: %s', x, util.reveal(layers[-1][j])) def pick(bits, x): if len(bits) == 1: return bits[0] * x[0] else: try: return x[0].dot_product(bits, x) except: return sum(aa * bb for aa, bb in zip(bits, x)) def run_decision_tree(layers, data): """ Run decision tree against sample data. :param layers: tree output by :py:class:`TreeTrainer` :param data: sample data (:py:class:`~Compiler.types.Array`) :returns: binary label """ h = len(layers) - 1 index = 1 for k, layer in enumerate(layers[:-1]): assert len(layer) == 3 for x in layer: assert len(x) <= 2 ** k bits = layer[0].equal(index, k) threshold = pick(bits, layer[2]) key_index = pick(bits, layer[1]) if key_index.is_clear: key = data[key_index] else: key = pick( oram.demux(key_index.bit_decompose(util.log2(len(data)))), data) child = 2 * key < threshold index += child * 2 ** k bits = layers[h][0].equal(index, h) return pick(bits, layers[h][1]) def test_decision_tree(name, layers, y, x, n_threads=None, time=False): if time: start_timer(100) n = len(y) x = x.transpose().reveal() y = y.reveal() guess = regint.Array(n) truth = regint.Array(n) correct = regint.Array(2) parts = regint.Array(2) layers = [[Array.create_from(util.reveal(x)) for x in layer] for layer in layers] @for_range_multithread(n_threads, 1, n) def _(i): guess[i] = run_decision_tree([[part[:] for part in layer] for layer in layers], x[i]).reveal() truth[i] = y[i].reveal() @for_range(n) def _(i): parts[truth[i]] += 1 c = (guess[i].bit_xor(truth[i]).bit_not()) correct[truth[i]] += c print_ln('%s for height %s: %s/%s (%s/%s, %s/%s)', name, len(layers) - 1, sum(correct), n, correct[0], parts[0], correct[1], parts[1]) if time: stop_timer(100) class TreeClassifier: """ Tree classification with convenient interface. Uses :py:class:`TreeTrainer` internally. :param max_depth: the depth of the decision tree :param n_threads: number of threads used in training """ def __init__(self, max_depth, n_threads=None): self.max_depth = max_depth self.n_threads = n_threads @staticmethod def get_attr_lengths(attr_types): if attr_types == None: return None else: return [1 if x == 'b' else 0 for x in attr_types] def fit(self, X, y, attr_types=None): """ Train tree. :param X: sample data with row-wise samples (sint/sfix matrix) :param y: binary labels (sint list/array) """ self.tree = TreeTrainer( X.transpose(), y, self.max_depth, attr_lengths=self.get_attr_lengths(attr_types), n_threads=self.n_threads).train() def fit_with_testing(self, X_train, y_train, X_test, y_test, attr_types=None, output_trees=False, debug=False): """ Train tree with accuracy output after every level. :param X_train: training data with row-wise samples (sint/sfix matrix) :param y_train: training binary labels (sint list/array) :param X_test: testing data with row-wise samples (sint/sfix matrix) :param y_test: testing binary labels (sint list/array) :param attr_types: attributes types (list of 'b'/'c' for binary/continuous; default is all continuous) :param output_trees: output tree after every level :param debug: output debugging information """ trainer = TreeTrainer(X_train.transpose(), y_train, self.max_depth, attr_lengths=self.get_attr_lengths(attr_types), n_threads=self.n_threads) trainer.debug = debug trainer.debug_gini = debug trainer.debug_threading = debug > 1 self.tree = trainer.train_with_testing(y_test, X_test.transpose(), output=output_trees) def predict(self, X): """ Use tree for prediction. :param X: sample data with row-wise samples (sint/sfix matrix) :returns: sint array """ res = sint.Array(len(X)) @for_range(len(X)) def _(i): res[i] = run_decision_tree(self.tree, X[i]) return res def output(self): """ Output decision tree. """ output_decision_tree(self.tree) def preprocess_pandas(data): """ Preprocess pandas data frame to suit :py:class:`TreeClassifier` by expanding non-continuous attributes to several binary attributes as a unary encoding. :returns: a tuple of the processed data and a type list for the :py:obj:`attr_types` argument. """ import pandas import numpy res = [] types = [] for i, t in enumerate(data.dtypes): if pandas.api.types.is_int64_dtype(t): res.append(data.iloc[:,i].to_numpy()) types.append('c') elif pandas.api.types.is_object_dtype(t): values = list(filter(lambda x: isinstance(x, str), list(data.iloc[:,i].unique()))) print('converting the following to unary:', values) if len(values) == 2: res.append(data.iloc[:,i].to_numpy() == values[1]) types.append('b') else: for value in values: res.append(data.iloc[:,i].to_numpy() == value) types.append('b') else: raise CompilerError('unknown pandas type: ' + t) res = numpy.array(res) res = numpy.swapaxes(res, 0, 1) return res, types