mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
Make MATMULSM mergeable
This commit is contained in:
147
Programs/Source/test_dot.mpc
Normal file
147
Programs/Source/test_dot.mpc
Normal file
@@ -0,0 +1,147 @@
|
||||
a = Array.create_from([sint(1), sint(2), sint(3), sint(4)])
|
||||
b = Array.create_from([sint(3), sint(2), sint(1)])
|
||||
|
||||
c = Matrix.create_from([
|
||||
[sint(1), sint(2), sint(3)],
|
||||
[sint(4), sint(5), sint(6)],
|
||||
[sint(7), sint(8), sint(9)],
|
||||
[sint(10), sint(11), sint(12)]
|
||||
])
|
||||
|
||||
d = Matrix.create_from([
|
||||
[sint(12), sint(11), sint(10), sint(9)],
|
||||
[sint(8), sint(7), sint(6), sint(5)],
|
||||
[sint(4), sint(3), sint(2), sint(1)]
|
||||
])
|
||||
|
||||
|
||||
def test_array(expected: list[int], actual: Array) -> None:
|
||||
actual = actual.reveal()
|
||||
expected = Array.create_from([cint(x) for x in expected])
|
||||
@for_range(len(expected))
|
||||
def _(i: cint) -> None:
|
||||
@if_(actual[i] != expected[i])
|
||||
def fail():
|
||||
print_ln("Unexpected entry at index %s", i)
|
||||
print_ln("Expected:")
|
||||
expected.print_reveal_nested()
|
||||
print_ln("Actual:")
|
||||
actual.print_reveal_nested()
|
||||
|
||||
crash()
|
||||
|
||||
|
||||
def test_matrix(expected: list[list[int]], actual: Matrix) -> None:
|
||||
actual = actual.reveal()
|
||||
expected = Matrix.create_from([[cint(x) for x in row] for row in expected])
|
||||
@for_range(len(expected))
|
||||
def outer(i: cint) -> None:
|
||||
|
||||
@for_range(len(expected[0]))
|
||||
def inner(j: cint) -> None:
|
||||
@if_(actual[i][j] != expected[i][j])
|
||||
def fail():
|
||||
print_ln("Unexpected entry at index %s,%s", i, j)
|
||||
print_ln("Expected:")
|
||||
expected.print_reveal_nested()
|
||||
print_ln("Actual:")
|
||||
actual.print_reveal_nested()
|
||||
|
||||
crash()
|
||||
|
||||
break_point()
|
||||
def hacky_array_dot_matrix(arr: Array, mat: Matrix) -> Array:
|
||||
# Arrays sadly do not have a dot function, therefore the array is converted into a 1 times n Matrix by copying memory addresses.
|
||||
tmp = sint.Matrix(rows=1, columns=len(arr), address=arr.address)
|
||||
result = tmp.dot(mat)
|
||||
return sint.Array(mat.shape[1], result.address)
|
||||
|
||||
start_timer(3)
|
||||
|
||||
e3 = hacky_array_dot_matrix(a, c)
|
||||
# b[0] = e3[0]
|
||||
f3 = hacky_array_dot_matrix(b, d)
|
||||
|
||||
stop_timer(3)
|
||||
|
||||
e3 = e3.reveal()
|
||||
f3 = f3.reveal()
|
||||
|
||||
e3.print_reveal_nested()
|
||||
f3.print_reveal_nested()
|
||||
|
||||
test_array([70, 80, 90], e3)
|
||||
test_array([56, 50, 44, 38], f3)
|
||||
|
||||
start_timer(4)
|
||||
|
||||
e4 = hacky_array_dot_matrix(a, c)
|
||||
b[-1] = e4[0]
|
||||
f4 = hacky_array_dot_matrix(b, d)
|
||||
|
||||
stop_timer(4)
|
||||
|
||||
test_array([70, 80, 90], e4)
|
||||
test_array([332, 257, 182, 107], f4)
|
||||
|
||||
f4.print_reveal_nested()
|
||||
|
||||
# TODO: Crashes
|
||||
|
||||
|
||||
start_timer(5)
|
||||
g = c.dot(d)
|
||||
stop_timer(5)
|
||||
|
||||
test_matrix([
|
||||
[ 40, 34, 28, 22],
|
||||
[112, 97, 82, 67],
|
||||
[184, 160, 136, 112],
|
||||
[256, 223, 190, 157]
|
||||
], g)
|
||||
g.print_reveal_nested()
|
||||
|
||||
|
||||
# Big matrix tests.
|
||||
# These are intended to test matrix multiplications that require multiple batches.
|
||||
|
||||
def identity(size: int) -> Matrix:
|
||||
result = sint.Matrix(rows=size, columns=size)
|
||||
result.assign_all(0)
|
||||
for i in range(size):
|
||||
result[i][i] = 1
|
||||
return result
|
||||
|
||||
|
||||
def counting_matrix(rows: int, columns: int) -> Matrix:
|
||||
result = sint.Matrix(rows, columns)
|
||||
@for_range(rows)
|
||||
def outer(i: cint) -> None:
|
||||
@for_range(columns)
|
||||
def inner(j: cint) -> None:
|
||||
result[i][j] = i * columns + j
|
||||
return result
|
||||
|
||||
|
||||
def clear_counting_matrix(rows: int, columns: int) -> list[list[int]]:
|
||||
return [list(range(i * columns, (i + 1) * columns)) for i in range(rows)]
|
||||
|
||||
|
||||
# Single matrix multiplication requiring multiple batches.
|
||||
a = counting_matrix(20, 20)
|
||||
b = identity(20)
|
||||
|
||||
start_timer(6)
|
||||
c = a * b
|
||||
stop_timer(6)
|
||||
|
||||
test_matrix(clear_counting_matrix(20, 20), c)
|
||||
|
||||
# Multiple matrix multiplications requiring multiple batches.
|
||||
start_timer(7)
|
||||
d = a * b
|
||||
e = c * b
|
||||
stop_timer(7)
|
||||
|
||||
test_matrix(clear_counting_matrix(20, 20), d)
|
||||
test_matrix(clear_counting_matrix(20, 20), e)
|
||||
Reference in New Issue
Block a user