mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 05:27:56 -05:00
653 lines
22 KiB
Python
653 lines
22 KiB
Python
from Compiler.types import *
|
|
from Compiler.sorting import *
|
|
from Compiler.library import *
|
|
from Compiler import util, oram
|
|
|
|
from itertools import accumulate
|
|
import math
|
|
|
|
# Global switch for verbose protocol-level debug output (revealed values).
debug = False
# Extra debug output for the node-splitting decision (see GroupSame and
# TrainInternalNodes).
debug_split = False
# If set to a positive int, CropLayer caps the per-layer output size at
# min(2 ** k, max_leaves); None means the full 2 ** k entries.
max_leaves = None
|
|
|
|
def get_type(x):
    """ Infer the element type of a container or vector.

    :param x: container (:py:obj:`Array`/:py:obj:`SubMultiArray`),
        tuple/list, or scalar
    :returns: the value type of a container, :py:obj:`cint` for
        constant-only sequences, otherwise the (promoted) element type
    """
    if isinstance(x, (Array, SubMultiArray)):
        return x.value_type
    if isinstance(x, (tuple, list)):
        # Adding first and last entry promotes mixed clear/secret
        # sequences to the more general type.
        combined = x[0] + x[-1]
        if util.is_constant(combined):
            return cint
        return type(combined)
    return type(x)
|
|
|
|
def PrefixSum(x):
    """ Running (inclusive) sum over *x*, returned as a vector. """
    vec = x.get_vector()
    return vec.prefix_sum()
|
|
|
|
def PrefixSumR(x):
    """ Reverse running sum of *x* (suffix sums), returned as a vector. """
    tmp = get_type(x).Array(len(x))
    tmp.assign_vector(x)
    # break_point() delimits compiler optimization blocks
    # (Compiler.library) around the in-place reversal below.
    break_point()
    # Suffix sum = reverse, prefix-sum, reverse again.
    tmp[:] = tmp.get_reverse_vector().prefix_sum()
    break_point()
    return tmp.get_reverse_vector()
|
|
|
|
def PrefixSum_inv(x):
    """ Invert a forward prefix sum: first differences of *x*. """
    n = len(x)
    # Pad with a leading zero so the difference is defined at index 0.
    padded = get_type(x).Array(n + 1)
    padded[0] = 0
    padded.assign_vector(x, base=1)
    return padded.get_vector(size=n, base=1) - padded.get_vector(size=n)
|
|
|
|
def PrefixSumR_inv(x):
    """ Invert a reverse prefix sum: backward differences of *x*. """
    n = len(x)
    # Pad with a trailing zero so the difference is defined at the end.
    padded = get_type(x).Array(n + 1)
    padded.assign_vector(x)
    padded[-1] = 0
    return padded.get_vector(size=n) - padded.get_vector(base=1, size=n)
|
|
|
|
class SortPerm:
    """ Stable two-way sorting permutation derived from a secret bit
    vector *x*: entries with bit 0 are moved to the front, entries with
    bit 1 to the back, preserving relative order. """

    def __init__(self, x):
        # One-hot destination computation over the two buckets.
        buckets = sint.Matrix(len(x), 2)
        buckets.set_column(0, 1 - x.get_vector())
        buckets.set_column(1, x.get_vector())
        self.perm = Array.create_from(dest_comp(buckets))

    def apply(self, x):
        """ Return a copy of *x* permuted by this permutation. """
        out = Array.create_from(x)
        reveal_sort(self.perm, out, False)
        return out

    def unapply(self, x):
        """ Return a copy of *x* permuted by the inverse permutation. """
        out = Array.create_from(x)
        reveal_sort(self.perm, out, True)
        return out
|
|
|
|
def Sort(keys, *to_sort, n_bits=None, time=False):
    """ Radix-sort *to_sort* by the composite key given in *keys*.

    :param keys: list of key vectors, most significant first
    :param to_sort: vectors/arrays to reorder alongside the keys
    :param n_bits: bit length per key (list, None for full decomposition)
    :param time: enable timers 1 and 11
    :returns: list of sorted vectors, preserving sfix-ness of inputs
    """
    if time:
        start_timer(1)
    for k in keys:
        assert len(k) == len(keys[0])
    n_bits = n_bits or [None] * len(keys)
    # Bit-decompose all keys into one matrix, least significant key
    # first (reversed) as required by radix sorting.
    bs = Matrix.create_from(
        sum([k.get_vector().bit_decompose(nb)
             for k, nb in reversed(list(zip(keys, n_bits)))], []))
    get_vec = lambda x: x[:] if isinstance(x, Array) else x
    # Sort on the raw integer representation of sfix values (.v);
    # the fixed-point wrapper is restored below.
    res = Matrix.create_from(get_vec(x).v if isinstance(get_vec(x), sfix) else x
                             for x in to_sort)
    res = res.transpose()
    if time:
        start_timer(11)
    radix_sort_from_matrix(bs, res)
    if time:
        stop_timer(11)
        stop_timer(1)
    res = res.transpose()
    # Re-wrap raw sfix values with the original precision parameters.
    return [sfix._new(get_vec(x), k=get_vec(y).k, f=get_vec(y).f)
            if isinstance(get_vec(y), sfix)
            else x for (x, y) in zip(res, to_sort)]
|
|
|
|
def VectMax(key, *data, debug=False):
    """ Obliviously select the entries of *data* at the position of the
    maximal *key* via a tree reduction.

    :param key: comparison keys
    :param data: sequences aligned with *key*
    :returns: list of selected entries, one per sequence in *data*
    """
    def take_larger(left, right):
        cond = left[0] > right[0]
        if debug:
            print_ln('max b=%s', cond.reveal())
        return [cond.if_else(l, r) for l, r in zip(left, right)]
    if debug:
        key = list(key)
        data = [list(seq) for seq in data]
        print_ln('vect max key=%s data=%s', util.reveal(key), util.reveal(data))
    # Reduce over (key, *data) tuples; drop the key from the result.
    winner = util.tree_reduce(take_larger, zip(key, *data))[1:]
    if debug:
        print_ln('vect max res=%s', util.reveal(winner))
    return winner
|
|
|
|
def GroupSum(g, x):
    """ Sum *x* within the groups delimited by *g* and broadcast the
    group total to every position of the group.

    :param g: group indicator vector (1 marks the start of a group)
    :param x: values to sum, same length as *g*
    """
    assert len(g) == len(x)
    # Suffix sums, masked so only group-start positions keep a value.
    p = PrefixSumR(x) * g
    # Compact the group starts to the front via a sorting permutation.
    pi = SortPerm(g.get_vector().bit_not())
    p1 = pi.apply(p)
    # Undo the two prefix sums to obtain per-group deltas ...
    s1 = PrefixSumR_inv(p1)
    d1 = PrefixSum_inv(s1)
    # ... move them back to the group-start positions and spread the
    # totals across each group with a final prefix sum.
    d = pi.unapply(d1) * g
    return PrefixSum(d)
|
|
|
|
def GroupPrefixSum(g, x):
    """ Running sum of *x* restarting at every group boundary in *g*.

    :param g: group indicator vector (1 marks the start of a group)
    :param x: values to sum, same length as *g*
    """
    assert len(g) == len(x)
    # Global prefix sums, shifted by one (s[0] = 0).
    s = get_type(x).Array(len(x) + 1)
    s[0] = 0
    s.assign_vector(PrefixSum(x), base=1)
    # At each group start, q holds the prefix sum just before the group.
    q = get_type(s).Array(len(x))
    q.assign_vector(s.get_vector(size=len(x)) * g)
    # Subtract the group's starting offset from the global prefix sum.
    return s.get_vector(size=len(x), base=1) - GroupSum(g, q)
|
|
|
|
def GroupMax(g, keys, *x):
    """ Per-group maximum of *keys*, carrying along the aligned vectors
    in *x*, broadcast to every position of the group.

    :param g: group indicator vector (1 marks the start of a group)
    :param keys: comparison keys
    :param x: vectors selected at the position of the group maximum
    :returns: list of vectors: max keys followed by the selected *x*
    """
    if debug:
        print_ln('group max input g=%s keys=%s x=%s', util.reveal(g),
                 util.reveal(keys), util.reveal(x))
    assert len(keys) == len(g)
    for xx in x:
        assert len(xx) == len(g)
    n = len(g)
    # Tournament over log2(n) rounds of doubling stride.
    m = int(math.ceil(math.log(n, 2)))
    keys = Array.create_from(keys)
    x = [Array.create_from(xx) for xx in x]
    g_new = Array.create_from(g)
    g_old = g_new.same_shape()
    for d in range(m):
        w = 2 ** d
        g_old[:] = g_new[:]
        break_point()
        vsize = n - w
        # Propagate group boundaries across the current stride.
        g_new.assign_vector(g_old.get_vector(size=vsize).bit_or(
            g_old.get_vector(size=vsize, base=w)), base=w)
        # Compare each position with the one w steps ahead.
        b = keys.get_vector(size=vsize) > keys.get_vector(size=vsize, base=w)
        for xx in [keys] + x:
            a = b.if_else(xx.get_vector(size=vsize),
                          xx.get_vector(size=vsize, base=w))
            # Only merge across positions that are not group boundaries.
            xx.assign_vector(g_old.get_vector(size=vsize, base=w).if_else(
                xx.get_vector(size=vsize, base=w), a), base=w)
        break_point()
        if debug:
            print_ln('group max w=%s b=%s a=%s keys=%s x=%s g=%s', w, b.reveal(),
                     util.reveal(a), util.reveal(keys),
                     util.reveal(x), g_new.reveal())
    # t marks the last position of every group (shifted group starts).
    t = sint.Array(len(g))
    t[-1] = 1
    t.assign_vector(g.get_vector(size=n - 1, base=1))
    if debug:
        print_ln('group max end g=%s t=%s keys=%s x=%s', util.reveal(g),
                 util.reveal(t), util.reveal(keys), util.reveal(x))
    # The group result has accumulated at the last position; broadcast
    # it back over the whole group.
    return [GroupSum(g, t[:] * xx) for xx in [keys] + x]
|
|
|
|
def ModifiedGini(g, y, debug=False):
    """ Modified Gini gain of every possible split point within the
    groups delimited by *g* (higher is better).

    :param g: group indicator vector (1 marks the start of a group)
    :param y: binary labels, same length as *g*
    :returns: sfix vector of gain values
    """
    assert len(g) == len(y)
    # Indicator vectors for label 0 and label 1.
    y = [y.get_vector().bit_not(), y]
    # u[i]: running count of label i left of (and including) each
    # position within its group; s[i]: group total; w[i]: count right
    # of the position.
    u = [GroupPrefixSum(g, yy) for yy in y]
    s = [GroupSum(g, yy) for yy in y]
    w = [ss - uu for ss, uu in zip(s, u)]
    us = sum(u)
    ws = sum(w)
    uqs = u[0] ** 2 + u[1] ** 2
    wqs = w[0] ** 2 + w[1] ** 2
    # Gain = sum of squared label counts over side size, for both sides.
    res = sfix(uqs) / us + sfix(wqs) / ws
    if debug:
        print_ln('g=%s y=%s s=%s',
                 util.reveal(g), util.reveal(y),
                 util.reveal(s))
        print_ln('u0=%s', util.reveal(u[0]))
        # Fixed debug label: this prints u[1], not u[0].
        print_ln('u1=%s', util.reveal(u[1]))
        print_ln('us=%s', util.reveal(us))
        print_ln('w0=%s', util.reveal(w[0]))
        print_ln('w1=%s', util.reveal(w[1]))
        print_ln('ws=%s', util.reveal(ws))
        print_ln('uqs=%s', util.reveal(uqs))
        print_ln('wqs=%s', util.reveal(wqs))
    if debug:
        print_ln('gini %s %s', type(res), util.reveal(res))
    return res
|
|
|
|
# Sentinel acting as minus infinity for Gini gains and thresholds.
MIN_VALUE = -10000
|
|
|
|
def FormatLayer(h, g, *a):
    """ Format one tree layer and crop it to its maximal size (2 ** h,
    possibly limited by the :py:obj:`max_leaves` module setting). """
    formatted = FormatLayer_without_crop(g, *a)
    return CropLayer(h, *formatted)
|
|
|
|
def FormatLayer_without_crop(g, *a, debug=False):
    """ Keep only group-start entries of the vectors in *a* (zeroing the
    rest) and compact them to the front by sorting on the inverted
    group bits. """
    for col in a:
        assert len(col) == len(g)
    masked = [g.if_else(col, 0) for col in a]
    if debug:
        print_ln('format in %s', util.reveal(a))
        print_ln('format mux %s', util.reveal(masked))
    masked = Sort([g.bit_not()], *masked, n_bits=[1])
    if debug:
        print_ln('format sort %s', util.reveal(masked))
    return masked
|
|
|
|
def CropLayer(k, *v):
    """ Truncate the layer vectors in *v* to at most 2 ** k entries
    (further capped by :py:obj:`max_leaves` if that is set). """
    limit = min(2 ** k, max_leaves) if max_leaves else 2 ** k
    return [vv[:min(limit, len(vv))] for vv in v]
|
|
|
|
def TrainLeafNodes(h, g, y, NID):
    """ Determine leaf labels by majority vote within each group and
    format the result as the final tree layer.

    :param h: tree height
    :param g: group indicator vector
    :param y: binary labels
    :param NID: node identifiers
    """
    assert len(g) == len(y)
    assert len(g) == len(NID)
    # Label 1 wins where it is the strict majority of the group.
    zero_votes = GroupSum(g, y.bit_not())
    one_votes = GroupSum(g, y)
    return FormatLayer(h, g, NID, zero_votes < one_votes)
|
|
|
|
def GroupSame(g, y):
    """ Indicate for every position whether all labels in its group are
    equal (all zero or all one).

    :param g: group indicator vector (1 marks the start of a group)
    :param y: binary labels, same length as *g*
    """
    assert len(g) == len(y)
    # Group size and per-label counts.
    s = GroupSum(g, [sint(1)] * len(g))
    s0 = GroupSum(g, y.bit_not())
    s1 = GroupSum(g, y)
    if debug_split:
        print_ln('group same g=%s', util.reveal(g))
        print_ln('group same y=%s', util.reveal(y))
    # Pure group iff one label accounts for the whole group.
    return (s == s0).bit_or(s == s1)
|
|
|
|
def GroupFirstOne(g, b):
    """ Indicator of the first set bit of *b* within each group of *g*:
    the position where the within-group running sum of *b* first hits
    one while *b* itself is set. """
    assert len(g) == len(b)
    running = GroupPrefixSum(g, b)
    return (running * b) == 1
|
|
|
|
class TreeTrainer:
    """ Decision tree training by `Hamada et al.`_

    :param x: sample data (by attribute, list or
        :py:obj:`~Compiler.types.Matrix`)
    :param y: binary labels (list or sint vector)
    :param h: height (int)
    :param binary: binary attributes instead of continuous
    :param attr_lengths: attribute description for mixed data
        (list of 0/1 for continuous/binary)
    :param n_threads: number of threads (default: single thread)

    .. _`Hamada et al.`: https://arxiv.org/abs/2112.12906

    """
    def ApplyTests(self, x, AID, Threshold):
        """ Evaluate the selected test per sample: is the sample's value
        of attribute AID below the (doubled) threshold?

        :param x: sample data by attribute
        :param AID: selected attribute id per sample
        :param Threshold: selected threshold per sample
        :returns: secret bit vector of test results
        """
        m = len(x)
        n = len(AID)
        assert len(AID) == len(Threshold)
        for xx in x:
            assert len(xx) == len(AID)
        # One-hot encode the attribute id to select obliviously.
        e = sint.Matrix(m, n)
        AID = Array.create_from(AID)
        @for_range_multithread(self.n_threads, 1, m)
        def _(j):
            e[j][:] = AID[:] == j
        xx = sum(x[j] * e[j] for j in range(m))
        if self.debug > 1:
            print_ln('apply e=%s xx=%s', util.reveal(e), util.reveal(xx))
            print_ln('threshold %s', util.reveal(Threshold))
        # Thresholds are stored doubled (sum of two adjacent values),
        # hence the comparison against 2 * xx.
        return 2 * xx < Threshold

    def AttributeWiseTestSelection(self, g, x, y, time=False, debug=False):
        """ Find the best split point per group for one attribute.

        :param g: group indicator vector
        :param x: sorted attribute values
        :param y: labels aligned with *x*
        :returns: (threshold vector, Gini gain vector)
        """
        assert len(g) == len(x)
        assert len(g) == len(y)
        if time:
            start_timer(2)
        s = ModifiedGini(g, y, debug=debug or self.debug > 2)
        if time:
            stop_timer(2)
        if debug or self.debug > 1:
            print_ln('gini %s', s.reveal())
        xx = x
        # Candidate thresholds: sum of adjacent attribute values
        # (kept doubled; see ApplyTests).
        t = get_type(x).Array(len(x))
        t[-1] = MIN_VALUE
        t.assign_vector(xx.get_vector(size=len(x) - 1) + \
                        xx.get_vector(size=len(x) - 1, base=1))
        gg = g
        # p marks invalid split points: group boundaries and positions
        # where adjacent attribute values are equal.
        p = sint.Array(len(x))
        p[-1] = 1
        p.assign_vector(gg.get_vector(base=1, size=len(x) - 1).bit_or(
            xx.get_vector(size=len(x) - 1) == \
            xx.get_vector(size=len(x) - 1, base=1)))
        break_point()
        if debug:
            print_ln('attribute t=%s p=%s', util.reveal(t), util.reveal(p))
        # Disable invalid split points via the MIN_VALUE sentinel.
        s = p[:].if_else(MIN_VALUE, s)
        t = p[:].if_else(MIN_VALUE, t[:])
        if debug:
            print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t))
        if time:
            start_timer(3)
        # Best split per group by Gini gain.
        s, t = GroupMax(gg, s, t)
        if time:
            stop_timer(3)
        if debug:
            print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t))
        return t, s

    def GlobalTestSelection(self, x, y, g):
        """ Find the best (attribute, threshold) pair per group across
        all attributes.

        :param x: sample data by attribute
        :param y: binary labels
        :param g: group indicator vector
        :returns: (attribute id vector, threshold vector)
        """
        assert len(y) == len(g)
        for xx in x:
            assert(len(xx) == len(g))
        m = len(x)
        n = len(y)
        u, t = [get_type(x).Matrix(m, n) for i in range(2)]
        v = get_type(y).Matrix(m, n)
        s = sfix.Matrix(m, n)
        @for_range_multithread(self.n_threads, 1, m)
        def _(j):
            # Timers are only meaningful single-threaded.
            single = not self.n_threads or self.n_threads == 1
            time = self.time and single
            if debug:
                print_ln('run %s', j)
            # Sort samples by (group, attribute value); binary
            # attributes need only one bit for the value.
            @if_e(self.attr_lengths[j])
            def _():
                u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y,
                                        n_bits=[util.log2(n), 1], time=time)
            @else_
            def _():
                u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y,
                                        n_bits=[util.log2(n), None],
                                        time=time)
            if self.debug_threading:
                print_ln('global sort %s %s %s', j, util.reveal(u[j]),
                         util.reveal(v[j]))
            t[j][:], s[j][:] = self.AttributeWiseTestSelection(
                g, u[j], v[j], time=time, debug=self.debug_selection)
            if self.debug_threading:
                print_ln('global attribute %s %s %s', j, util.reveal(t[j]),
                         util.reveal(s[j]))
        n = len(g)
        a = sint.Array(n)
        if self.debug_threading:
            print_ln('global s=%s', util.reveal(s))
        if self.debug_gini:
            print_ln('Gini indices ' + ' '.join(str(i) + ':%s' for i in range(m)),
                     *(ss[0].reveal() for ss in s))
        if self.time:
            start_timer(4)
        if self.debug > 1:
            print_ln('s=%s', s.reveal_nested())
            print_ln('t=%s', t.reveal_nested())
        # Pick the attribute with the maximal gain per position.
        a[:], tt = VectMax((s[j][:] for j in range(m)), range(m),
                           (t[j][:] for j in range(m)), debug=self.debug > 1)
        tt = Array.create_from(tt)
        if self.time:
            stop_timer(4)
        if self.debug > 1:
            print_ln('a=%s', util.reveal(a))
            print_ln('tt=%s', util.reveal(tt))
        return a[:], tt[:]

    def TrainInternalNodes(self, k, x, y, g, NID):
        """ Train the internal nodes of layer *k*.

        :returns: formatted layer (NID, AID, Threshold) plus the secret
            per-sample test-result bits
        """
        assert len(g) == len(y)
        for xx in x:
            assert len(xx) == len(g)
        AID, Threshold = self.GlobalTestSelection(x, y, g)
        # Groups with uniform labels become (premature) leaves:
        # disable their test via AID 0 / MIN_VALUE.
        s = GroupSame(g[:], y[:])
        if self.debug > 1 or debug_split:
            print_ln('AID=%s', util.reveal(AID))
            print_ln('Threshold=%s', util.reveal(Threshold))
            print_ln('GroupSame=%s', util.reveal(s))
        AID, Threshold = s.if_else(0, AID), s.if_else(MIN_VALUE, Threshold)
        if self.debug > 1 or debug_split:
            print_ln('AID=%s', util.reveal(AID))
            print_ln('Threshold=%s', util.reveal(Threshold))
        b = self.ApplyTests(x, AID, Threshold)
        layer = FormatLayer_without_crop(g[:], NID, AID, Threshold,
                                         debug=self.debug > 1)
        return *layer, b

    @method_block
    def train_layer(self, k):
        """ Train layer *k* and update the internal state (groups, node
        ids, sample order) for the next layer. """
        x = self.x
        y = self.y
        g = self.g
        NID = self.NID
        if self.debug > 1:
            print_ln('g=%s', g.reveal())
            print_ln('y=%s', y.reveal())
            print_ln('x=%s', x.reveal_nested())
        self.nids[k], self.aids[k], self.thresholds[k], b = \
            self.TrainInternalNodes(k, x, y, g, NID)
        if self.debug > 1:
            print_ln('layer %s:', k)
            for name, data in zip(('NID', 'AID', 'Thr'),
                                  (self.nids[k], self.aids[k],
                                   self.thresholds[k])):
                print_ln(' %s: %s', name, data.reveal())
        # Child node id: test result decides the 2 ** k offset.
        NID[:] = 2 ** k * b + NID
        b_not = b.bit_not()
        if self.debug > 1:
            print_ln('b_not=%s', b_not.reveal())
        # New group starts: first sample of each child group.
        g[:] = GroupFirstOne(g, b_not) + GroupFirstOne(g, b)
        # Reorder all state by the test result to keep groups contiguous.
        y[:], g[:], NID[:], *xx = Sort([b], y, g, NID, *x, n_bits=[1])
        for i, xxx in enumerate(xx):
            x[i] = xxx

    def __init__(self, x, y, h, binary=False, attr_lengths=None,
                 n_threads=None):
        assert not (binary and attr_lengths)
        if binary:
            attr_lengths = [1] * len(x)
        else:
            attr_lengths = attr_lengths or ([0] * len(x))
        for l in attr_lengths:
            assert l in (0, 1)
        self.attr_lengths = Array.create_from(regint(attr_lengths))
        # Index checks disabled globally for performance.
        Array.check_indices = False
        Matrix.disable_index_checks()
        for xx in x:
            assert len(xx) == len(y)
        n = len(y)
        # g: group indicator (initially one group covering everything).
        self.g = sint.Array(n)
        self.g.assign_all(0)
        self.g[0] = 1
        # NID: node id per sample (root is 1).
        self.NID = sint.Array(n)
        self.NID.assign_all(1)
        self.y = Array.create_from(y)
        self.x = Matrix.create_from(x)
        # Per-layer training outputs.
        self.nids, self.aids = [sint.Matrix(h, n) for i in range(2)]
        self.thresholds = self.x.value_type.Matrix(h, n)
        self.n_threads = n_threads
        self.debug_selection = False
        self.debug_threading = False
        self.debug_gini = False
        self.debug = False
        self.time = False
        get_program().reading('decision tree learning', 'HIKC23')

    def train(self):
        """ Train and return decision tree. """
        h = len(self.nids)
        @for_range(h)
        def _(k):
            self.train_layer(k)
        return self.get_tree(h)

    def train_with_testing(self, *test_set, output=False):
        """ Train decision tree and test against test data.

        :param y: binary labels (list or sint vector)
        :param x: sample data (by attribute, list or
            :py:obj:`~Compiler.types.Matrix`)
        :param output: output tree after every level
        :returns: tree

        """
        for k in range(len(self.nids)):
            self.train_layer(k)
            tree = self.get_tree(k + 1)
            if output:
                output_decision_tree(tree)
            test_decision_tree('train', tree, self.y, self.x,
                               n_threads=self.n_threads)
            if test_set:
                test_decision_tree('test', tree, *test_set,
                                   n_threads=self.n_threads)
        return tree

    def get_tree(self, h):
        """ Assemble the tree as a list of cropped layers plus the leaf
        layer trained from the current state. """
        Layer = [None] * (h + 1)
        for k in range(h):
            Layer[k] = CropLayer(k, self.nids[k], self.aids[k],
                                 self.thresholds[k])
        Layer[h] = TrainLeafNodes(h, self.g[:], self.y[:], self.NID)
        return Layer
|
|
|
|
def DecisionTreeTraining(x, y, h, binary=False):
    """ Train a decision tree (convenience wrapper around
    :py:class:`TreeTrainer`). """
    trainer = TreeTrainer(x, y, h, binary=binary)
    return trainer.train()
|
|
|
|
def output_decision_tree(layers):
    """ Print decision tree output by :py:class:`TreeTrainer`. """
    print_ln('full model %s', util.reveal(layers))
    for level, layer in enumerate(layers[:-1]):
        print_ln('level %s:', level)
        for label, part in zip(('NID', 'AID', 'Thr'), layer):
            print_ln(' %s: %s', label, util.reveal(part))
    print_ln('leaves:')
    for label, part in zip(('NID', 'result'), layers[-1]):
        print_ln(' %s: %s', label, util.reveal(part))
|
|
|
|
def pick(bits, x):
    """ Obliviously select an element of *x* via the one-hot vector
    *bits* (i.e. the inner product of the two).

    :param bits: one-hot selection vector
    :param x: values to select from, aligned with *bits*
    """
    if len(bits) == 1:
        return bits[0] * x[0]
    else:
        try:
            # Single dot-product instruction where the type supports it.
            return x[0].dot_product(bits, x)
        except Exception:
            # Narrowed from a bare except so KeyboardInterrupt/SystemExit
            # are not swallowed; fall back to an explicit inner product.
            return sum(aa * bb for aa, bb in zip(bits, x))
|
|
|
|
def run_decision_tree(layers, data):
    """ Run decision tree against sample data.

    :param layers: tree output by :py:class:`TreeTrainer`
    :param data: sample data (:py:class:`~Compiler.types.Array`)
    :returns: binary label

    """
    h = len(layers) - 1
    # Node ids start at 1 for the root.
    index = 1
    for k, layer in enumerate(layers[:-1]):
        assert len(layer) == 3
        for x in layer:
            assert len(x) <= 2 ** k
        # One-hot match of the current node id within this layer.
        bits = layer[0].equal(index, k)
        threshold = pick(bits, layer[2])
        key_index = pick(bits, layer[1])
        if key_index.is_clear:
            key = data[key_index]
        else:
            # Secret attribute index: oblivious selection via demuxing.
            key = pick(
                oram.demux(key_index.bit_decompose(util.log2(len(data)))), data)
        # Thresholds are stored doubled (see TreeTrainer.ApplyTests).
        child = 2 * key < threshold
        index += child * 2 ** k
    # Look up the label at the reached leaf.
    bits = layers[h][0].equal(index, h)
    return pick(bits, layers[h][1])
|
|
|
|
def test_decision_tree(name, layers, y, x, n_threads=None, time=False):
    """ Reveal the tree and data, run every sample through the tree,
    and print per-class accuracy counts.

    :param name: label used in the output line (e.g. 'train'/'test')
    :param layers: tree output by :py:class:`TreeTrainer`
    :param y: binary labels
    :param x: sample data (by attribute)
    :param n_threads: number of threads for the evaluation loop
    :param time: enable timer 100
    """
    if time:
        start_timer(100)
    n = len(y)
    # Everything is revealed here: evaluation runs on clear values.
    x = x.transpose().reveal()
    y = y.reveal()
    guess = regint.Array(n)
    truth = regint.Array(n)
    correct = regint.Array(2)
    parts = regint.Array(2)
    layers = [[Array.create_from(util.reveal(x)) for x in layer]
              for layer in layers]
    @for_range_multithread(n_threads, 1, n)
    def _(i):
        guess[i] = run_decision_tree([[part[:] for part in layer]
                                      for layer in layers], x[i]).reveal()
        truth[i] = y[i].reveal()
    # Tally correct guesses per class.
    @for_range(n)
    def _(i):
        parts[truth[i]] += 1
        c = (guess[i].bit_xor(truth[i]).bit_not())
        correct[truth[i]] += c
    print_ln('%s for height %s: %s/%s (%s/%s, %s/%s)', name, len(layers) - 1,
             sum(correct), n, correct[0], parts[0], correct[1], parts[1])
    if time:
        stop_timer(100)
|
|
|
|
class TreeClassifier:
    """ Tree classification with convenient interface. Uses
    :py:class:`TreeTrainer` internally.

    :param max_depth: the depth of the decision tree
    :param n_threads: number of threads used in training

    """
    def __init__(self, max_depth, n_threads=None):
        self.max_depth = max_depth
        self.n_threads = n_threads

    @staticmethod
    def get_attr_lengths(attr_types):
        """ Translate an attribute type description into the 0/1 list
        expected by :py:class:`TreeTrainer` ('b' -> 1, else 0).

        :param attr_types: iterable of 'b'/'c' or None
        """
        # Identity comparison instead of '== None' (PEP 8): equality may
        # invoke arbitrary __eq__, e.g. elementwise on array-likes.
        if attr_types is None:
            return None
        else:
            return [1 if x == 'b' else 0 for x in attr_types]

    def fit(self, X, y, attr_types=None):
        """ Train tree.

        :param X: sample data with row-wise samples (sint/sfix matrix)
        :param y: binary labels (sint list/array)

        """
        self.tree = TreeTrainer(
            X.transpose(), y, self.max_depth,
            attr_lengths=self.get_attr_lengths(attr_types),
            n_threads=self.n_threads).train()

    def fit_with_testing(self, X_train, y_train, X_test, y_test,
                         attr_types=None, output_trees=False, debug=False):
        """ Train tree with accuracy output after every level.

        :param X_train: training data with row-wise samples (sint/sfix matrix)
        :param y_train: training binary labels (sint list/array)
        :param X_test: testing data with row-wise samples (sint/sfix matrix)
        :param y_test: testing binary labels (sint list/array)
        :param attr_types: attributes types (list of 'b'/'c' for
            binary/continuous; default is all continuous)
        :param output_trees: output tree after every level
        :param debug: output debugging information

        """
        trainer = TreeTrainer(X_train.transpose(), y_train, self.max_depth,
                              attr_lengths=self.get_attr_lengths(attr_types),
                              n_threads=self.n_threads)
        trainer.debug = debug
        trainer.debug_gini = debug
        trainer.debug_threading = debug > 1
        self.tree = trainer.train_with_testing(y_test, X_test.transpose(),
                                               output=output_trees)

    def predict(self, X):
        """ Use tree for prediction.

        :param X: sample data with row-wise samples (sint/sfix matrix)
        :returns: sint array

        """
        res = sint.Array(len(X))
        @for_range(len(X))
        def _(i):
            res[i] = run_decision_tree(self.tree, X[i])
        return res

    def output(self):
        """ Output decision tree. """
        output_decision_tree(self.tree)
|
|
|
|
def preprocess_pandas(data):
    """ Preprocess pandas data frame to suit
    :py:class:`TreeClassifier` by expanding non-continuous attributes
    to several binary attributes as a unary encoding.

    :returns: a tuple of the processed data and a type list for the
        :py:obj:`attr_types` argument.

    """
    import pandas
    import numpy
    res = []
    types = []
    for i, t in enumerate(data.dtypes):
        if pandas.api.types.is_int64_dtype(t):
            # Integer columns pass through as continuous attributes.
            res.append(data.iloc[:,i].to_numpy())
            types.append('c')
        elif pandas.api.types.is_object_dtype(t):
            # Categorical column: collect its string values.
            values = list(filter(lambda x: isinstance(x, str),
                                 list(data.iloc[:,i].unique())))
            print('converting the following to unary from %d: %s' %
                  (len(res), values))
            if len(values) == 2:
                # Two categories need just one binary attribute.
                res.append(data.iloc[:,i].to_numpy() == values[1])
                types.append('b')
            else:
                # Unary (one-hot) encoding for other category counts.
                for value in values:
                    res.append(data.iloc[:,i].to_numpy() == value)
                    types.append('b')
        else:
            # %s formatting: the dtype object cannot be concatenated to
            # a str, so the previous '+' raised TypeError instead of
            # the intended CompilerError.
            raise CompilerError('unknown pandas type: %s' % t)
    res = numpy.array(res)
    # Attribute-major -> sample-major.
    res = numpy.swapaxes(res, 0, 1)
    return res, types
|