# Mirror of https://github.com/data61/MP-SPDZ.git
# Synced 2026-01-10 14:08:09 -05:00 (148 lines, 3.6 KiB, plain text).
# Secret fixtures used by the small dot-product tests (timers 3-5) below.
a = Array.create_from([sint(v) for v in (1, 2, 3, 4)])
b = Array.create_from([sint(v) for v in (3, 2, 1)])

# 4x3 matrix holding 1..12 row by row.
c = Matrix.create_from([
    [sint(v) for v in row]
    for row in ((1, 2, 3), (4, 5, 6), (7, 8, 9), (10, 11, 12))
])

# 3x4 matrix holding 12..1 row by row.
d = Matrix.create_from([
    [sint(v) for v in row]
    for row in ((12, 11, 10, 9), (8, 7, 6, 5), (4, 3, 2, 1))
])


def test_array(expected: list[int], actual: Array) -> None:
    """Check at runtime that ``actual`` equals ``expected`` element-wise.

    ``actual`` is revealed and compared entry by entry with the clear values
    in ``expected``. At the first mismatching index, the index and both
    arrays are printed and the program is aborted with ``crash()``.

    :param expected: expected clear values.
    :param actual: array to check; callers below pass both revealed and
        secret arrays.
    :raises ValueError: at compile time if the lengths differ. The runtime
        loop only visits ``len(expected)`` entries, so without this check a
        longer ``actual`` would pass unnoticed.
    """
    if len(actual) != len(expected):
        raise ValueError(
            f"test_array: expected {len(expected)} entries, got {len(actual)}"
        )
    actual = actual.reveal()
    expected_arr = Array.create_from([cint(x) for x in expected])

    # MP-SPDZ's for_range hands the loop body a public regint counter.
    @for_range(len(expected_arr))
    def _(i: regint) -> None:
        @if_(actual[i] != expected_arr[i])
        def fail():
            print_ln("Unexpected entry at index %s", i)
            print_ln("Expected:")
            expected_arr.print_reveal_nested()
            print_ln("Actual:")
            actual.print_reveal_nested()

            crash()


def test_matrix(expected: list[list[int]], actual: Matrix) -> None:
    """Check at runtime that ``actual`` equals ``expected`` entry-wise.

    ``actual`` is revealed and compared with the clear values in
    ``expected``. At the first mismatching entry, its indices and both
    matrices are printed and the program is aborted with ``crash()``.

    :param expected: rows of expected clear values. Must be non-empty; every
        row is assumed to have ``len(expected[0])`` entries.
    :param actual: matrix to check.
    :raises ValueError: at compile time if ``expected`` is empty or the
        shapes differ. The runtime loops only visit the shape of
        ``expected``, so without this check a larger ``actual`` would pass
        unnoticed.
    """
    if not expected:
        raise ValueError("test_matrix: expected must contain at least one row")
    shape = (len(expected), len(expected[0]))
    if tuple(actual.shape) != shape:
        raise ValueError(
            f"test_matrix: expected shape {shape}, got {tuple(actual.shape)}"
        )
    actual = actual.reveal()
    expected_mat = Matrix.create_from([[cint(x) for x in row] for row in expected])

    # MP-SPDZ's for_range hands the loop bodies public regint counters.
    @for_range(shape[0])
    def outer(i: regint) -> None:

        @for_range(shape[1])
        def inner(j: regint) -> None:
            @if_(actual[i][j] != expected_mat[i][j])
            def fail():
                print_ln("Unexpected entry at index %s,%s", i, j)
                print_ln("Expected:")
                expected_mat.print_reveal_nested()
                print_ln("Actual:")
                actual.print_reveal_nested()

                crash()

    # NOTE(review): this break_point() was recovered from an unindented copy
    # of the file. Confirm it belongs inside test_matrix rather than at
    # module level.
    break_point()


def hacky_array_dot_matrix(arr: Array, mat: Matrix) -> Array:
    """Return the vector-matrix product ``arr · mat`` as an Array.

    Arrays have no ``dot`` method. So ``arr`` is viewed as a
    1 x len(arr) matrix by constructing a Matrix over the same memory
    address; nothing is copied. The 1 x mat.shape[1] product is viewed back
    as an Array over the product's address in the same way.

    :param arr: vector of length ``mat.shape[0]``; it is viewed as ``sint``
        data.
    :param mat: matrix to multiply with.
    :returns: ``sint`` Array of length ``mat.shape[1]`` that shares memory
        with the product matrix.
    :raises ValueError: at compile time if ``len(arr) != mat.shape[0]``.
        Without this check the aliased view would silently pair mismatched
        dimensions.
    """
    if len(arr) != mat.shape[0]:
        raise ValueError(
            f"hacky_array_dot_matrix: array length {len(arr)} does not match "
            f"matrix row count {mat.shape[0]}"
        )
    # Both views share arr's memory, so writes through one are visible in the other.
    row_view = sint.Matrix(rows=1, columns=len(arr), address=arr.address)
    product = row_view.dot(mat)
    return sint.Array(mat.shape[1], product.address)


# Independent vector-matrix products: a·c (4 · 4x3) and b·d (3 · 3x4).
start_timer(3)
e3 = hacky_array_dot_matrix(a, c)
f3 = hacky_array_dot_matrix(b, d)
stop_timer(3)

# Reveal outside the timed section, print both results, then check them.
e3, f3 = e3.reveal(), f3.reveal()
e3.print_reveal_nested()
f3.print_reveal_nested()

test_array([70, 80, 90], e3)
test_array([56, 50, 44, 38], f3)


# Same two products, but the second depends on the first: the last entry of
# b is overwritten with (a·c)[0] = 70, so b·d is computed with b = [3, 2, 70].
start_timer(4)
e4 = hacky_array_dot_matrix(a, c)
b[-1] = e4[0]
f4 = hacky_array_dot_matrix(b, d)
stop_timer(4)

test_array([70, 80, 90], e4)
test_array([332, 257, 182, 107], f4)

f4.print_reveal_nested()


# TODO: Crashes
start_timer(5)
g = c.dot(d)
stop_timer(5)

# (4x3) · (3x4) -> 4x4 result.
test_matrix(
    [
        [40, 34, 28, 22],
        [112, 97, 82, 67],
        [184, 160, 136, 112],
        [256, 223, 190, 157],
    ],
    g,
)
g.print_reveal_nested()


# Large-matrix tests.
# These exercise matrix multiplications that are large enough to require multiple batches.
def identity(size: int) -> Matrix:
    """Build a secret ``size`` x ``size`` identity matrix.

    Every entry is cleared explicitly before the diagonal is written. The
    diagonal loop is a plain Python loop, so it is unrolled at compile time.
    """
    eye = sint.Matrix(rows=size, columns=size)
    eye.assign_all(0)
    for k in range(size):
        eye[k][k] = 1
    return eye


def counting_matrix(rows: int, columns: int) -> Matrix:
    """Return a secret rows x columns matrix that counts up in row-major order.

    Entry (i, j) holds ``i * columns + j``, i.e. the values
    0 .. rows*columns - 1. ``clear_counting_matrix`` builds the same values
    as plain Python lists.
    """
    result = sint.Matrix(rows, columns)

    # MP-SPDZ's for_range hands the loop bodies public regint counters.
    @for_range(rows)
    def outer(i: regint) -> None:
        @for_range(columns)
        def inner(j: regint) -> None:
            result[i][j] = i * columns + j

    return result


def clear_counting_matrix(rows: int, columns: int) -> list[list[int]]:
    """Return the clear counterpart of ``counting_matrix(rows, columns)``.

    Row ``r`` holds ``r * columns .. (r + 1) * columns - 1``. When
    ``columns <= 0``, every row is empty.
    """
    return [[r * columns + k for k in range(columns)] for r in range(rows)]


# One 20x20 multiplication, which requires multiple batches.
a = counting_matrix(20, 20)
b = identity(20)

start_timer(6)
c = a * b
stop_timer(6)

# Multiplying by the identity must give back the counting matrix.
test_matrix(clear_counting_matrix(20, 20), c)


# Two independent 20x20 multiplications in one timed section, each of which
# requires multiple batches.
start_timer(7)
d = a * b
e = c * b
stop_timer(7)

# Both right-hand factors are the identity, so each product is the counting matrix.
test_matrix(clear_counting_matrix(20, 20), d)
test_matrix(clear_counting_matrix(20, 20), e)