Functionality to call high-level code from C++.

This commit is contained in:
Marcel Keller
2024-11-21 12:18:10 +11:00
parent e7554ccbfd
commit 91321ff8cd
245 changed files with 3875 additions and 1139 deletions

View File

@@ -22,16 +22,17 @@ elif 'vertical' in program.args:
b = sfix.input_tensor_via(1, X_train[:,X_train.shape[1] // 2:])
X_train = a.concat_columns(b)
y_train = sint.input_tensor_via(0, y_train)
elif 'party0' in program.args:
a = sfix.input_tensor_via(0, X_train[:,:X_train.shape[1] // 2])
b = sfix.input_tensor_via(1, shape=X_train[:,X_train.shape[1] // 2:].shape)
elif 'party0' in program.args or 'party1' in program.args:
party = int('party1' in program.args)
a = sfix.input_tensor_via(
0, X_train[:,:X_train.shape[1] // 2] if party == 0 else None,
shape=X_train[:,:X_train.shape[1] // 2].shape)
b = sfix.input_tensor_via(
1, X_train[:,X_train.shape[1] // 2:] if party == 1 else None,
shape=X_train[:,X_train.shape[1] // 2:].shape)
X_train = a.concat_columns(b)
y_train = sint.input_tensor_via(0, y_train)
elif 'party1' in program.args:
a = sfix.input_tensor_via(0, shape=X_train[:,:X_train.shape[1] // 2].shape)
b = sfix.input_tensor_via(1, X_train[:,X_train.shape[1] // 2:])
X_train = a.concat_columns(b)
y_train = sint.input_tensor_via(0, shape=y_train.shape)
y_train = sint.input_tensor_via(0, y_train if party == 0 else None,
shape=y_train.shape)
else:
X_train = sfix.input_tensor_via(0, X_train)
y_train = sint.input_tensor_via(0, y_train)

View File

@@ -0,0 +1,7 @@
@export
def a2b(x, res):
print_ln('x=%s', x.reveal())
res[:] = sbitvec(x, length=16)
print_ln('res=%s', x.reveal())
a2b(sint(size=10), sbitvec.get_type(16).Array(10))

View File

@@ -0,0 +1,7 @@
@export
def b2a(res, x):
print_ln('x=%s', x.reveal())
res[:] = sint(x[:])
print_ln('res=%s', x.reveal())
b2a(sint.Array(size=10), sbitvec.get_type(16).Array(10))

View File

@@ -0,0 +1,7 @@
@export
def sort(x):
print_ln('x=%s', x.reveal())
res = x.sort()
print_ln('res=%s', x.reveal())
sort(sint.Array(1000))

View File

@@ -0,0 +1,8 @@
@export
def trunc_pr(x):
print_ln('x=%s', x.reveal())
res = x.round(32, 2)
print_ln('res=%s', res.reveal())
return res
trunc_pr(sint(0, size=1000))

View File

@@ -63,9 +63,6 @@ else:
if 'nearest' in program.args:
sfix.round_nearest = True
if program.options.ring:
assert sfix.f * 4 == int(program.options.ring)
debug_ml = ('debug_ml' in program.args) * 2 ** (sfix.f / 2)
if '1dense' in program.args:

View File

@@ -21,3 +21,13 @@ test(a, 10000, 10000)
test(b, 10000, 20000)
test(a, 1000000, 1000000)
test(b, 1000000, 2000000)
a = 1
if True:
if True:
a = 2
if True:
a = 3
else:
a = 4
crash()

View File

@@ -41,7 +41,9 @@ secret_input = sfix.input_tensor_via(
layers = ml.layers_from_torch(model, secret_input.shape, 1, input_via=0)
optimizer = ml.Optimizer(layers)
optimizer = ml.Optimizer(layers, time_layers='time_layers' in program.args)
start_timer(1)
print_ln('Secure computation says %s',
optimizer.eval(secret_input, top=True)[0].reveal())
stop_timer(1)

View File

@@ -0,0 +1,42 @@
# this tests the pretrained VGG in secure computation
program.options_from_args()
from Compiler import ml
try:
ml.set_n_threads(int(program.args[2]))
except:
pass
import torchvision
import torch
import numpy
import requests
import io
import PIL
from torchvision import transforms
name = 'vgg' + program.args[1]
model = getattr(torchvision.models, name)(weights='DEFAULT')
r = requests.get('https://github.com/pytorch/hub/raw/master/images/dog.jpg')
input_image = PIL.Image.open(io.BytesIO(r.content))
input_tensor = transforms._presets.ImageClassification(crop_size=32)(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
with torch.no_grad():
output = int(model(input_batch).argmax())
print('Model says %d' % output)
secret_input = sfix.input_tensor_via(
0, numpy.moveaxis(input_batch.numpy(), 1, -1))
layers = ml.layers_from_torch(model, secret_input.shape, 1, input_via=0)
optimizer = ml.Optimizer(layers)
optimizer.time_layers = True
print_ln('Secure computation says %s',
optimizer.eval(secret_input, top=True)[0].reveal())