mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-04-20 03:01:31 -04:00
81 lines
1.7 KiB
Plaintext
81 lines
1.7 KiB
Plaintext
# Layer entries are (l, n, type), consumed by the benchmark loop:
# 'MAXP' reduces l vectors of size n; any other type runs dot products
# of length l inside a multithreaded range over n.
NetworkA = [
    (784, 128, 'FC'),
    (128, 128, 'FC'),
    (128, 10, 'FC'),
]
|
|
|
|
# Layer entries are (l, n, type), consumed by the benchmark loop:
# 'MAXP' reduces l vectors of size n; any other type runs dot products
# of length l inside a multithreaded range over n.
NetworkD = [
    (25, 980, 'CONV2D'),
    (980, 100, 'FC'),
    (100, 10, 'FC'),
]
|
|
|
|
# Layer entries are (l, n, type), consumed by the benchmark loop:
# 'MAXP' reduces l vectors of size n; any other type runs dot products
# of length l inside a multithreaded range over n.
NetworkB = [
    (25, 9216, 'CONV2D'),
    (4, 2304, 'MAXP'),
    (400, 1024, 'CONV2D'),
    (4, 256, 'MAXP'),
    (256, 100, 'FC'),
    (100, 10, 'FC'),
]
|
|
|
|
# Layer entries are (l, n, type), consumed by the benchmark loop:
# 'MAXP' reduces l vectors of size n; any other type runs dot products
# of length l inside a multithreaded range over n.
NetworkC = [
    (25, 11520, 'CONV2D'),
    (4, 2880, 'MAXP'),
    (500, 3200, 'CONV2D'),
    (4, 800, 'MAXP'),
    (800, 500, 'FC'),
    (500, 10, 'FC'),
]
|
|
|
|
# Pick the layer table whose name is 'Network' + program.args[1]
# (e.g. 'A' selects NetworkA). An unknown suffix raises KeyError.
network = globals()['Network' + program.args[1]]

# c5.9xlarge has 36 cores
n_threads = 8
|
|
|
|
# Quantization parameter sets; the arguments are S, Z, bit length,
# using sfloat for secret floats and sint for secret int.
# p1 and p2 parametrize the two dot-product operands, p3 the result.
p1 = squant_params(sfloat(.001), sint(1), 8)
p2 = squant_params(sfloat(.002), sint(2), 8)
p3 = squant_params(sfloat(.003), sint(3), 8)

# precompute multiplication of p1 and p2 to p3
p3.precompute(p1, p2)
|
|
|
|
# need to do this to have arrays with specific parameters
|
|
class squant1(squant):
    """squant subtype tied to parameter set ``p1``.

    Used as the element type of an Array so that its entries carry
    these specific parameters.
    """
    params = p1
|
|
class squant2(squant):
    """squant subtype tied to parameter set ``p2``.

    Used as the element type of an Array so that its entries carry
    these specific parameters.
    """
    params = p2
|
|
|
|
# NOTE(review): presumably chosen to match the 8-bit length passed to the
# squant_params calls above — confirm.
program.set_bit_length(8)
|
|
|
|
import util
|
|
|
|
|
|
def maxpool(ln, num):
    """Emit a tree reduction over ``ln`` secret vectors of size ``num``.

    The inputs are all-zero placeholders, so only the cost of the
    reduction matters. The result is written with ``store_in_mem(0)``,
    presumably to keep the computation from being eliminated as dead code.

    NOTE(review): the reducer is ``min`` although the function is named
    maxpool — presumably the cost is the same for benchmarking; confirm.
    """
    # the list holds ln references to one and the same zero vector
    items = [sint(0, size=num)] * ln
    util.tree_reduce(min, items).store_in_mem(0)
|
|
|
|
# Zero constants of each parametrized squant type, used to fill the
# benchmark arrays.
zero1 = squant1(0)
zero2 = squant2(0)
|
|
|
|
# Emit the benchmark workload for every layer of the selected network.
for l, n, typ in network:
    if typ == 'MAXP':
        # pooling layer: reduce l vectors of size n
        maxpool(l, n)
    else:
        # every other layer type is benchmarked as dot products of two
        # zero-filled arrays of length l
        a = Array(l, squant1)
        b = Array(l, squant2)
        a.assign_all(zero1)
        b.assign_all(zero2)

        # parallelization for optimization
        @for_range_multithread(n_threads, 1000, n)
        def _(i):
            # only for optimization
            aa = a.get_vector()
            bb = b.get_vector()
            # store in memory to prevent dead code elimination
            squant.dot_product(aa, bb, res_params=p3).store_in_mem(regint(0))
|