# File: MP-SPDZ/Programs/Source/test_dot.mpc
# Repository viewer header: 2024-06-25 09:15:35 +02:00, 166 lines, 3.8 KiB, Plaintext

# Secret-shared fixtures for the small dot-product tests.
# a (length 4) is multiplied with c (4 x 3); b (length 3) with d (3 x 4).
a = Array.create_from([sint(value) for value in (1, 2, 3, 4)])
b = Array.create_from([sint(value) for value in (3, 2, 1)])

# c counts upwards from 1 to 12, row by row.
c = Matrix.create_from([
    [sint(value) for value in row]
    for row in (
        (1, 2, 3),
        (4, 5, 6),
        (7, 8, 9),
        (10, 11, 12),
    )
])

# d counts downwards from 12 to 1, row by row.
d = Matrix.create_from([
    [sint(value) for value in row]
    for row in (
        (12, 11, 10, 9),
        (8, 7, 6, 5),
        (4, 3, 2, 1),
    )
])
def test_array(expected, actual):
    """Crash the program unless ``actual`` reveals to the values in ``expected``.

    :param expected: Python sequence of plain integers known at compile time.
    :param actual: array to check; it is revealed before the comparison.
    :raises ValueError: at compile time if the two lengths differ.
    """
    # Only the first len(expected) entries are compared below, so a result
    # of the wrong length would otherwise go unnoticed (or read past its end).
    if len(actual) != len(expected):
        raise ValueError("Length mismatch: expected %d entries, got %d"
                         % (len(expected), len(actual)))

    actual = actual.reveal()
    expected = Array.create_from([cint(x) for x in expected])

    @for_range(len(expected))
    def _(i):
        @if_(actual[i] != expected[i])
        def fail():
            # Print both arrays in full so the mismatch can be diagnosed,
            # then abort.
            print_ln("Unexpected entry at index %s", i)
            print_ln("Expected:")
            expected.print_reveal_nested()
            print_ln("Actual:")
            actual.print_reveal_nested()
            crash()
def test_matrix(expected, actual):
    """Crash the program unless ``actual`` reveals to the values in ``expected``.

    :param expected: non-empty Python list of equally long rows of plain
        integers known at compile time.
    :param actual: matrix to check; it is revealed before the comparison.
    :raises ValueError: at compile time if the shapes differ.
    """
    # Only the entries covered by `expected` are compared below, so a result
    # of the wrong shape would otherwise go unnoticed (or read out of range).
    expected_shape = [len(expected), len(expected[0])]
    if list(actual.shape) != expected_shape:
        raise ValueError("Shape mismatch: expected %s, got %s"
                         % (expected_shape, list(actual.shape)))

    actual = actual.reveal()
    expected = Matrix.create_from([[cint(x) for x in row] for row in expected])

    @for_range(len(expected))
    def outer(i):
        @for_range(len(expected[0]))
        def inner(j):
            @if_(actual[i][j] != expected[i][j])
            def fail():
                # Print both matrices in full so the mismatch can be
                # diagnosed, then abort.
                print_ln("Unexpected entry at index %s,%s", i, j)
                print_ln("Expected:")
                expected.print_reveal_nested()
                print_ln("Actual:")
                actual.print_reveal_nested()
                crash()
# NOTE(review): presumably separates the fixture setup from the timed
# sections that follow. The original indentation was lost, so confirm whether
# this call belongs at top level or at the end of test_matrix.
break_point()
def hacky_array_dot_matrix(arr, mat):
    """Multiply the array ``arr`` (as a row vector) with the matrix ``mat``.

    :param arr: secret array with as many entries as ``mat`` has rows.
    :param mat: secret matrix.
    :returns: secret array of length ``mat.shape[1]``.

    No data is copied: the input is re-interpreted in place and the returned
    array points at the memory of the intermediate 1 x m product.
    """
    # Arrays sadly do not have a dot function, therefore the array is
    # converted into a 1 times n Matrix by reusing its memory address.
    tmp = sint.Matrix(rows=1, columns=len(arr), address=arr.address)
    result = tmp.dot(mat)
    # View the single result row as a flat array at the same address.
    return sint.Array(mat.shape[1], result.address)
# Test 3: two independent array-times-matrix products, timed together.
start_timer(3)
e3 = hacky_array_dot_matrix(a, c)
# Disabled variant that would make the second product depend on the first:
# b[0] = e3[0]
f3 = hacky_array_dot_matrix(b, d)
stop_timer(3)
# Reveal after the timer has been stopped, print both results, then compare.
e3 = e3.reveal()
f3 = f3.reveal()
e3.print_reveal_nested()
f3.print_reveal_nested()
# With a = [1, 2, 3, 4] and b = [3, 2, 1]:
# a . c = [70, 80, 90] and b . d = [56, 50, 44, 38].
test_array([70, 80, 90], e3)
test_array([56, 50, 44, 38], f3)
# Test 4: the second product depends on the first one, because the last
# entry of b is overwritten with e4[0] before b is multiplied with d.
start_timer(4)
e4 = hacky_array_dot_matrix(a, c)
b[-1] = e4[0]
f4 = hacky_array_dot_matrix(b, d)
stop_timer(4)
test_array([70, 80, 90], e4)
# Expected values correspond to b = [3, 2, 70], i.e. e4[0] == 70.
test_array([332, 257, 182, 107], f4)
f4.print_reveal_nested()
# TODO: Crashes
# NOTE(review): the TODO does not say which statement crashes or under which
# setup; confirm whether it is still accurate.
# Test 5: matrix-matrix product of c (4 x 3) and d (3 x 4), giving 4 x 4.
start_timer(5)
g = c.dot(d)
stop_timer(5)
test_matrix([
[ 40, 34, 28, 22],
[112, 97, 82, 67],
[184, 160, 136, 112],
[256, 223, 190, 157]
], g)
g.print_reveal_nested()
# Big matrix tests.
# These are intended to test matrix multiplications that require multiple batches.
def identity(size):
    """Return a secret ``size`` x ``size`` identity matrix.

    :param size: number of rows and columns; must be a compile-time integer
        because the diagonal is written by a plain Python loop.
    """
    result = sint.Matrix(rows=size, columns=size)
    result.assign_all(0)
    # Plain Python loop: one diagonal assignment is emitted per index.
    for i in range(size):
        result[i][i] = 1
    return result
def counting_matrix(rows, columns):
    """Return a secret ``rows`` x ``columns`` matrix counting up from 0.

    Entry ``(i, j)`` holds ``i * columns + j``, i.e. the entries enumerate
    the matrix row by row.
    """
    result = sint.Matrix(rows, columns)

    @for_range(rows)
    def outer(i):
        @for_range(columns)
        def inner(j):
            result[i][j] = i * columns + j

    return result
def clear_counting_matrix(rows, columns):
    """Return a ``rows`` x ``columns`` nested list counting up from 0.

    Entry ``[i][j]`` is ``i * columns + j``: the plain-Python counterpart of
    ``counting_matrix``, used as the expected value in the tests.
    """
    return [list(range(i * columns, (i + 1) * columns)) for i in range(rows)]
# Single matrix multiplication requiring multiple batches.
# Note: a, b and c are rebound here and no longer refer to the small fixtures.
a = counting_matrix(20, 20)
b = identity(20)
start_timer(6)
c = a * b
stop_timer(6)
# b is the identity, so the product must equal a itself.
test_matrix(clear_counting_matrix(20, 20), c)
# Multiple matrix multiplications requiring multiple batches.
# Both products are issued inside the same timed section; e multiplies c,
# which is itself the result of an earlier product.
start_timer(7)
d = a * b
e = c * b
stop_timer(7)
# b is the identity, so both results must equal the counting matrix.
test_matrix(clear_counting_matrix(20, 20), d)
test_matrix(clear_counting_matrix(20, 20), e)
# Test 8: the same product as above, computed via dot() with n_threads=2.
start_timer(8)
d = a.dot(b, n_threads=2)
stop_timer(8)
test_matrix(clear_counting_matrix(20, 20), d)
# Test 9: direct_mul() with an explicit list of four index vectors.
# NOTE(review): M is never assigned and the product is discarded, so this
# only checks that the call compiles and runs; confirm that is the intent.
start_timer(9)
M = sint.Matrix(10, 10)
M.direct_mul(M, indices=[regint(0), regint.inc(10), regint.inc(10),
regint(0)])
stop_timer(9)
# Test 10: product of two 1000 x 1000 matrices.
# NOTE(review): neither operand is assigned and the result is discarded, so
# this looks like a pure size/timing test; confirm.
start_timer(10)
sint.Matrix(1000, 1000) * sint.Matrix(1000, 1000)
stop_timer(10)