Files
MP-SPDZ/Programs/Source/breast_logistic.mpc
2024-11-21 13:14:54 +11:00

56 lines
2.1 KiB
Plaintext

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
# Load the breast-cancer data set as a (features, labels) pair.
X, y = load_breast_cancer(return_X_y=True)

# normalize column-wise: divide every column by its own maximum
column_max = X.max(axis=0)
X /= column_max

# Hold out a test set; random_state is fixed so the split is the same on
# every run.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
# Secret-share the training data. The keyword found in program.args
# decides which party inputs which part of X_train / y_train.
if 'horizontal' in program.args:
    # split by sample: party 0 inputs the rows from the midpoint onwards,
    # party 1 the rows before it; the labels are split the same way
    sample_mid = len(X_train) // 2
    label_mid = len(y_train) // 2
    tail_rows = sfix.input_tensor_via(0, X_train[sample_mid:])
    head_rows = sfix.input_tensor_via(1, X_train[:sample_mid])
    X_train = tail_rows.concat(head_rows)
    tail_labels = sint.input_tensor_via(0, y_train[label_mid:])
    head_labels = sint.input_tensor_via(1, y_train[:label_mid])
    y_train = tail_labels.concat(head_labels)
elif 'vertical' in program.args:
    # split by feature: party 0 inputs the first half of the columns and
    # all labels, party 1 the remaining columns
    print(X_train.shape, X_train.shape[1])
    feature_mid = X_train.shape[1] // 2
    first_share = sfix.input_tensor_via(0, X_train[:, :feature_mid])
    last_share = sfix.input_tensor_via(1, X_train[:, feature_mid:])
    X_train = first_share.concat_columns(last_share)
    y_train = sint.input_tensor_via(0, y_train)
elif 'party0' in program.args or 'party1' in program.args:
    # same column split as 'vertical', but content is only passed for the
    # selected party's own part; the other part is declared by shape alone
    party = 1 if 'party1' in program.args else 0
    feature_mid = X_train.shape[1] // 2
    first_cols = X_train[:, :feature_mid]
    last_cols = X_train[:, feature_mid:]
    first_share = sfix.input_tensor_via(
        0, first_cols if party == 0 else None, shape=first_cols.shape)
    last_share = sfix.input_tensor_via(
        1, last_cols if party == 1 else None, shape=last_cols.shape)
    X_train = first_share.concat_columns(last_share)
    # the labels belong to party 0, so party 1 only declares their shape
    own_labels = y_train if party == 0 else None
    y_train = sint.input_tensor_via(0, own_labels, shape=y_train.shape)
else:
    # default: party 0 inputs the complete training set
    X_train = sfix.input_tensor_via(0, X_train)
    y_train = sint.input_tensor_via(0, y_train)
# The test set always enters via party 0. With 'party1' in program.args no
# content is passed and the tensors are declared by shape only.
has_test_content = 'party1' not in program.args
if has_test_content:
    X_test = sfix.input_tensor_via(0, X_test)
    y_test = sint.input_tensor_via(0, y_test)
else:
    X_test = sfix.input_tensor_via(0, shape=X_test.shape)
    y_test = sint.input_tensor_via(0, shape=y_test.shape)
from Compiler import ml
# Logistic-regression model trained with SGD.
# NOTE(review): 20 and 2 are presumably the number of epochs and the batch
# size -- confirm against the Compiler.ml documentation.
log = ml.SGDLogistic(20, 2, program)

# First run: plain training, then reveal the difference between the
# predictions and the test labels.
log.fit(X_train, y_train)
prediction_diff = log.predict(X_test) - y_test.get_vector()
print_ln('%s', prediction_diff.reveal())

# Second run: fit again, this time handing the test set to the trainer,
# then reveal the difference between predict_proba and the test labels.
log.fit_with_testing(X_train, y_train, X_test, y_test)
proba_diff = log.predict_proba(X_test) - y_test.get_vector()
print_ln('%s', proba_diff.reveal())