Files
MP-SPDZ/Programs/Source/test_gc.mpc
2022-11-09 11:22:18 +11:00

140 lines
3.5 KiB
Plaintext

def test(actual, expected):
    """Reveal ``actual`` and print it next to the value it should equal.

    ``actual`` must provide ``reveal()``, and the revealed value must expose
    its bit length as ``n``. ``expected`` is a plain integer.

    Callers write ``expected`` as a non-negative number. When it is at least
    ``2 ** (n - 1)`` and the value is wider than one bit, ``2 ** n`` is
    subtracted before printing.
    NOTE(review): presumably this is because revealed multi-bit values are
    printed as signed n-bit numbers -- confirm against the runtime output.

    Nothing is asserted here; the function only emits an
    ``expected ..., got ...`` line via ``print_ln``.
    """
    actual = actual.reveal()
    # Map the upper half of the n-bit range onto negative numbers; single-bit
    # values are left untouched.
    if expected >= 2 ** (actual.n - 1) and actual.n != 1:
        expected -= 2 ** actual.n
    print_ln('expected %s, got %s', expected, actual)
# Basic operator checks for sbits / cbits / sbit values. Every line prints
# the expected result next to the revealed one through test().
# sbits '+' is expected to behave as XOR (expected value written as 3 ^ 5).
test(sbits(3) + sbits(5), 3 ^ 5)
# cbits '+', '-' and unary '-' are expected to match integer arithmetic, with
# either cbits or plain Python integers as operands.
test(cbits(3) + cbits(5), 3 + 5)
test(cbits(3) + (5), 3 + 5)
test(cbits(5) - cbits(3), 5 - 3)
test(cbits(5) - (3), 5 - 3)
test((5) - cbits(3), 5 - 3)
test(-cbits(-3), 3)
# cbits '^' is expected to match integer XOR.
test(cbits(3) ^ cbits(5), 3 ^ 5)
test(cbits(3) ^ (5), 3 ^ 5)
# sbits '+' with a plain integer and sbits '-' are both expected to give XOR.
test(sbits(3) + 5, 3 ^ 5)
test(sbits(3) - sbits(5), 3 ^ 5)
# Multiplying by sbit(1) is expected to leave the other operand unchanged.
test(sbit(1) * sbits(3), 3)
# NOTE(review): the cbits * cbits check below is disabled; the reason is not
# visible in this file.
#test(cbits(1) * cbits(3), 3)
test(sbit(1) * 3, 3)
# Bitwise NOT of the 64-bit value 1 is expected to be 2**64 - 2.
test(~sbits.new(1, n=64), 2**64 - 2)
test(sbits(5) & sbits(3), 5 & 3)
# equal() is expected to yield 1 for identical inputs and 0 otherwise.
test(sbits(3).equal(sbits(3)), 1)
test(sbits(3).equal(sbits(2)), 0)
# if_else() on sbit(1) is expected to pick the first argument.
test(sbit(1).if_else(sbits(3), sbits(5)), 3)
# Shifts: secret left shift and clear right shift.
test(sbits(7) << 1, 14)
test(cbits(5) >> 1, 2)
test(sbit.bit_compose((sbit(1), sbit(0), sbit(1))), 5)
# if_else() with plain integer branches: 1 selects the first argument,
# 0 selects the second.
test(sbit(0).if_else(1, 2), 2)
test(sbit(1).if_else(1, 2), 1)
test(sbit(0).if_else(2, 1), 1)
test(sbit(1).if_else(2, 1), 2)
# Composing the 2-bit values 2 and 1 is expected to give 6 = 2 + (1 << 2),
# i.e. the first element ends up in the low bits.
test(sbits.compose((sbits.new(2, n=2), sbits.new(1, n=2)), 2), 6)
# Memory round trips: store a value, start a new basic block, then read the
# value back and print it via test().
# NOTE(review): the explicit start_new_basicblock() calls presumably ensure
# the value is read back across a basic-block boundary -- confirm.
# Secret value held in a MemValue.
x = MemValue(sbits(1234))
program.curr_tape.start_new_basicblock()
test(x, 1234)
# Clear value held in a MemValue.
x = MemValue(cbits(123))
program.curr_tape.start_new_basicblock()
test(x, 123)
# Clear value passed through memorize() / unmemorize().
x = memorize(cbits(234))
program.curr_tape.start_new_basicblock()
test(unmemorize(x), 234)
# Clear value stored at the explicit memory address 1234 and loaded again.
cbits(456).store_in_mem(1234)
program.curr_tape.start_new_basicblock()
test(cbits.load_mem(1234), 456)
# 64-bit value with only the top bit set.
test(sbits.new(1 << 63, n=64), 1 << 63)
# Decomposing a 40-bit value into bits and composing them again is expected
# to reproduce the original value.
bits = sbits(0x1234, n=40).bit_decompose(40)
test(sbits.bit_compose(bits), 0x1234)
# XOR of operands with different bit lengths (4 vs. 3 bits, in both orders).
test(sbits.new(5, n=4) ^ sbits.new(3, n=3), 6)
test(sbits.new(5, n=3) ^ sbits.new(3, n=4), 6)
# Same, with bit 3 set in the 4-bit operand (13 and 11).
test(sbits.new(13, n=4) ^ sbits.new(3, n=3), 14)
test(sbits.new(5, n=3) ^ sbits.new(11, n=4), 14)
# Random bits: b * (1 - b) is 0 for b in {0, 1}, so the expected output does
# not depend on the random value.
b = sbits.get_random_bit()
test(b * (1 - b), 0)
# Compose 40 random bits and print the revealed value (no expected value).
bits = [sbits.get_random_bit() for i in range(40)]
print_ln('random: %s', sbits.bit_compose(bits).reveal())
r = sbits.bit_compose(bits)
# Both products are expected to equal r, and their sum is expected to be 0
# (consistent with sbits '+' acting as XOR in the checks earlier in the file).
test(r * sbit(1) + sbit(1) * r, 0)
# Population count of a 64-bit all-ones value.
test(sbits.get_type(64)(2**64 - 1).popcnt(), 64)
# Transpose check on the 2-bit values 0, 1, 2, 3.
a = [sbits.new(x, 2) for x in range(4)]
# Only the first two results are checked; any further ones land in z.
x, y, *z = sbits.trans(a)
# 0xa = 0b1010 matches bit 0 of the inputs (0, 1, 0, 1 from element 0 to 3).
test(x, 0xa)
# 0xc = 0b1100 matches bit 1 of the inputs (0, 0, 1, 1 from element 0 to 3).
test(y, 0xc)
# sbitvec round trip: build a vector from three 64-bit values and expect
# elements() to return the same values.
aa = [1, 2**63, 2**64 - 1]
a = sbitvec(sbits.new(x, n=64) for x in aa).elements()
test(a[0], aa[0])
test(a[1], aa[1])
test(a[2], aa[2])
# Element-wise population count of the same three values.
a = sbitvec(sbits.new(x, n=64) for x in [1, 2**63, 2**64 - 1]).popcnt().elements()
test(a[0], 1)
test(a[1], 1)
test(a[2], 64)
# AND of a 64-bit value created from -1 with itself; expected all ones.
a = sbits.new(-1, n=64)
test(a & a, 2**64 - 1)
# Sets the class attribute n on sbits for the statements that follow.
# NOTE(review): which later checks depend on this is not visible here.
sbits.n = 64
# Population count over a vector of 64 all-ones elements; the first and last
# elements are checked.
a = sbitvec(64 * [sbits.new(2**64 - 1, n=64)]).popcnt().elements()
test(a[0], 64)
test(a[63], 64)
# Element-wise addition of 64-bit sbitintvec values:
# (2**63 - 1) + 1 is expected as 2**63, and 1 + (-1) as 0.
a = sbitintvec(sbits.new(x, n=64) for x in [2**63 - 1, 1])
b = sbitintvec(sbits.new(x, n=64) for x in [1, -1])
c = (a + b).elements()
test(c[0], 2**63)
test(c[1], 0)
# Element-wise less_than() on 64-bit values.
a = sbitintvec(sbits.new(x, n=64) for x in [1, 1, 2**63 - 1, 2**63])
b = sbitintvec(sbits.new(x, n=64) for x in [1, 2, 2**63, 2**63 - 1])
c = a.less_than(b).elements()
# 1 < 1 -> 0 and 1 < 2 -> 1.
test(c[0], 0)
test(c[1], 1)
# The expected values for the last two pairs are consistent with 2**63 being
# treated as a negative number (signed comparison).
test(c[2], 0)
test(c[3], 1)
# Conversions between sbits and sbit in both directions.
test(sbit(sbits(3)), 1)
test(sbits(sbit(1)), 1)
# 32-bit sbitint arithmetic; expected values match integer arithmetic.
si32 = sbitint.get_type(32)
test(si32(5) + si32(3), 8)
test(si32(5) - si32(3), 2)
test(si32(5) * si32(3), 15)
# The comparison result is wrapped in sbit() before being revealed.
test(sbit(si32(5) < si32(3)), 0)
# 32-bit sbitintvec arithmetic on the element pairs (3, 4) and (5, 6).
sb32 = sbits.get_type(32)
siv32 = sbitintvec.get_type(32)
a = siv32(sbitvec([sb32(3), sb32(5)]))
b = siv32(sbitvec([sb32(4), sb32(6)]))
# Element-wise addition: 3 + 4 and 5 + 6.
c = (a + b).elements()
test(c[0], 7)
test(c[1], 11)
# Adding two of the resulting elements is expected to give the integer sum.
test(c[0] + c[1], 18)
# Element-wise multiplication: 3 * 4 and 5 * 6.
c = (a * b).elements()
test(c[0], 12)
test(c[1], 30)
# Element-wise subtraction: 3 - 4 and 5 - 6. The expected value is written as
# 2 ** 32 - 1, the 32-bit wrap-around of -1.
c = (a - b).elements()
test(c[0], 2 ** 32 - 1)
test(c[1], 2 ** 32 - 1)
# Element-wise '<': 3 < 4 and 5 < 6.
c = (a < b).elements()
test(c[0], 1)
test(c[1], 1)
# Bitwise NOT of a 2-bit clear zero is expected to be 3 (all ones).
test(~cbits.get_type(2)(0), 3)
# Here the 64-bit secret zero is revealed first, so '~' is applied to the
# revealed value; test() then calls reveal() on that result again.
test(~sbits.get_type(64)(0).reveal(), 2 ** 64 - 1)