This commit is contained in:
Marcel Keller
2019-02-14 15:53:09 +11:00
parent b6a18675e8
commit b59f9ca8cf
4 changed files with 746 additions and 0 deletions

View File

@@ -0,0 +1,593 @@
# Syntax is (longest vector, number of sops, name of layer)
v1_025_128 = [
(27, 32768, 'CONV2D'),
(9, 32768 , 'DWCONV2D'),
(8, 65536, 'CONV2D'),
(9, 16384 , 'DWCONV2D'),
(16, 32768, 'CONV2D'),
(9, 32768 , 'DWCONV2D'),
(32, 32768, 'CONV2D'),
(9, 8192 , 'DWCONV2D'),
(32, 16384, 'CONV2D'),
(9, 16384 , 'DWCONV2D'),
(64, 16384, 'CONV2D'),
(9, 4096 , 'DWCONV2D'),
(64, 8192, 'CONV2D'),
(9, 8192 , 'DWCONV2D'),
(128, 8192, 'CONV2D'),
(9, 8192 , 'DWCONV2D'),
(128, 8192, 'CONV2D'),
(9, 8192 , 'DWCONV2D'),
(128, 8192, 'CONV2D'),
(9, 8192 , 'DWCONV2D'),
(128, 8192, 'CONV2D'),
(9, 8192 , 'DWCONV2D'),
(128, 8192, 'CONV2D'),
(9, 2048 , 'DWCONV2D'),
(128, 4096, 'CONV2D'),
(9, 4096 , 'DWCONV2D'),
(256, 4096, 'CONV2D'),
(256, 1001, 'CONV2D')
]
v1_025_160 = [
(27, 51200, 'CONV2D'),
(9, 51200 , 'DWCONV2D'),
(8, 102400, 'CONV2D'),
(9, 25600 , 'DWCONV2D'),
(16, 51200, 'CONV2D'),
(9, 51200 , 'DWCONV2D'),
(32, 51200, 'CONV2D'),
(9, 12800 , 'DWCONV2D'),
(32, 25600, 'CONV2D'),
(9, 25600 , 'DWCONV2D'),
(64, 25600, 'CONV2D'),
(9, 6400 , 'DWCONV2D'),
(64, 12800, 'CONV2D'),
(9, 12800 , 'DWCONV2D'),
(128, 12800, 'CONV2D'),
(9, 12800 , 'DWCONV2D'),
(128, 12800, 'CONV2D'),
(9, 12800 , 'DWCONV2D'),
(128, 12800, 'CONV2D'),
(9, 12800 , 'DWCONV2D'),
(128, 12800, 'CONV2D'),
(9, 12800 , 'DWCONV2D'),
(128, 12800, 'CONV2D'),
(9, 3200 , 'DWCONV2D'),
(128, 6400, 'CONV2D'),
(9, 6400 , 'DWCONV2D'),
(256, 6400, 'CONV2D'),
(256, 1001, 'CONV2D')
]
v1_025_192 = [
(27, 73728, 'CONV2D'),
(9, 73728 , 'DWCONV2D'),
(8, 147456, 'CONV2D'),
(9, 36864 , 'DWCONV2D'),
(16, 73728, 'CONV2D'),
(9, 73728 , 'DWCONV2D'),
(32, 73728, 'CONV2D'),
(9, 18432 , 'DWCONV2D'),
(32, 36864, 'CONV2D'),
(9, 36864 , 'DWCONV2D'),
(64, 36864, 'CONV2D'),
(9, 9216 , 'DWCONV2D'),
(64, 18432, 'CONV2D'),
(9, 18432 , 'DWCONV2D'),
(128, 18432, 'CONV2D'),
(9, 18432 , 'DWCONV2D'),
(128, 18432, 'CONV2D'),
(9, 18432 , 'DWCONV2D'),
(128, 18432, 'CONV2D'),
(9, 18432 , 'DWCONV2D'),
(128, 18432, 'CONV2D'),
(9, 18432 , 'DWCONV2D'),
(128, 18432, 'CONV2D'),
(9, 4608 , 'DWCONV2D'),
(128, 9216, 'CONV2D'),
(9, 9216 , 'DWCONV2D'),
(256, 9216, 'CONV2D'),
(256, 1001, 'CONV2D')
]
v1_025_224 = [
(27, 100352, 'CONV2D'),
(9, 100352 , 'DWCONV2D'),
(8, 200704, 'CONV2D'),
(9, 50176 , 'DWCONV2D'),
(16, 100352, 'CONV2D'),
(9, 100352 , 'DWCONV2D'),
(32, 100352, 'CONV2D'),
(9, 25088 , 'DWCONV2D'),
(32, 50176, 'CONV2D'),
(9, 50176 , 'DWCONV2D'),
(64, 50176, 'CONV2D'),
(9, 12544 , 'DWCONV2D'),
(64, 25088, 'CONV2D'),
(9, 25088 , 'DWCONV2D'),
(128, 25088, 'CONV2D'),
(9, 25088 , 'DWCONV2D'),
(128, 25088, 'CONV2D'),
(9, 25088 , 'DWCONV2D'),
(128, 25088, 'CONV2D'),
(9, 25088 , 'DWCONV2D'),
(128, 25088, 'CONV2D'),
(9, 25088 , 'DWCONV2D'),
(128, 25088, 'CONV2D'),
(9, 6272 , 'DWCONV2D'),
(128, 12544, 'CONV2D'),
(9, 12544 , 'DWCONV2D'),
(256, 12544, 'CONV2D'),
(256, 1001, 'CONV2D')
]
v1_05_128 = [
(27, 65536, 'CONV2D'),
(9, 65536 , 'DWCONV2D'),
(16, 131072, 'CONV2D'),
(9, 32768 , 'DWCONV2D'),
(32, 65536, 'CONV2D'),
(9, 65536 , 'DWCONV2D'),
(64, 65536, 'CONV2D'),
(9, 16384 , 'DWCONV2D'),
(64, 32768, 'CONV2D'),
(9, 32768 , 'DWCONV2D'),
(128, 32768, 'CONV2D'),
(9, 8192 , 'DWCONV2D'),
(128, 16384, 'CONV2D'),
(9, 16384 , 'DWCONV2D'),
(256, 16384, 'CONV2D'),
(9, 16384 , 'DWCONV2D'),
(256, 16384, 'CONV2D'),
(9, 16384 , 'DWCONV2D'),
(256, 16384, 'CONV2D'),
(9, 16384 , 'DWCONV2D'),
(256, 16384, 'CONV2D'),
(9, 16384 , 'DWCONV2D'),
(256, 16384, 'CONV2D'),
(9, 4096 , 'DWCONV2D'),
(256, 8192, 'CONV2D'),
(9, 8192 , 'DWCONV2D'),
(512, 8192, 'CONV2D'),
(512, 1001, 'CONV2D')
]
v1_05_160 = [
(27, 102400, 'CONV2D'),
(9, 102400 , 'DWCONV2D'),
(16, 204800, 'CONV2D'),
(9, 51200 , 'DWCONV2D'),
(32, 102400, 'CONV2D'),
(9, 102400 , 'DWCONV2D'),
(64, 102400, 'CONV2D'),
(9, 25600 , 'DWCONV2D'),
(64, 51200, 'CONV2D'),
(9, 51200 , 'DWCONV2D'),
(128, 51200, 'CONV2D'),
(9, 12800 , 'DWCONV2D'),
(128, 25600, 'CONV2D'),
(9, 25600 , 'DWCONV2D'),
(256, 25600, 'CONV2D'),
(9, 25600 , 'DWCONV2D'),
(256, 25600, 'CONV2D'),
(9, 25600 , 'DWCONV2D'),
(256, 25600, 'CONV2D'),
(9, 25600 , 'DWCONV2D'),
(256, 25600, 'CONV2D'),
(9, 25600 , 'DWCONV2D'),
(256, 25600, 'CONV2D'),
(9, 6400 , 'DWCONV2D'),
(256, 12800, 'CONV2D'),
(9, 12800 , 'DWCONV2D'),
(512, 12800, 'CONV2D'),
(512, 1001, 'CONV2D')
]
v1_05_192 = [
(27, 147456, 'CONV2D'),
(9, 147456 , 'DWCONV2D'),
(16, 294912, 'CONV2D'),
(9, 73728 , 'DWCONV2D'),
(32, 147456, 'CONV2D'),
(9, 147456 , 'DWCONV2D'),
(64, 147456, 'CONV2D'),
(9, 36864 , 'DWCONV2D'),
(64, 73728, 'CONV2D'),
(9, 73728 , 'DWCONV2D'),
(128, 73728, 'CONV2D'),
(9, 18432 , 'DWCONV2D'),
(128, 36864, 'CONV2D'),
(9, 36864 , 'DWCONV2D'),
(256, 36864, 'CONV2D'),
(9, 36864 , 'DWCONV2D'),
(256, 36864, 'CONV2D'),
(9, 36864 , 'DWCONV2D'),
(256, 36864, 'CONV2D'),
(9, 36864 , 'DWCONV2D'),
(256, 36864, 'CONV2D'),
(9, 36864 , 'DWCONV2D'),
(256, 36864, 'CONV2D'),
(9, 9216 , 'DWCONV2D'),
(256, 18432, 'CONV2D'),
(9, 18432 , 'DWCONV2D'),
(512, 18432, 'CONV2D'),
(512, 1001, 'CONV2D')
]
v1_05_224 = [
(27, 200704, 'CONV2D'),
(9, 200704 , 'DWCONV2D'),
(16, 401408, 'CONV2D'),
(9, 100352 , 'DWCONV2D'),
(32, 200704, 'CONV2D'),
(9, 200704 , 'DWCONV2D'),
(64, 200704, 'CONV2D'),
(9, 50176 , 'DWCONV2D'),
(64, 100352, 'CONV2D'),
(9, 100352 , 'DWCONV2D'),
(128, 100352, 'CONV2D'),
(9, 25088 , 'DWCONV2D'),
(128, 50176, 'CONV2D'),
(9, 50176 , 'DWCONV2D'),
(256, 50176, 'CONV2D'),
(9, 50176 , 'DWCONV2D'),
(256, 50176, 'CONV2D'),
(9, 50176 , 'DWCONV2D'),
(256, 50176, 'CONV2D'),
(9, 50176 , 'DWCONV2D'),
(256, 50176, 'CONV2D'),
(9, 50176 , 'DWCONV2D'),
(256, 50176, 'CONV2D'),
(9, 12544 , 'DWCONV2D'),
(256, 25088, 'CONV2D'),
(9, 25088 , 'DWCONV2D'),
(512, 25088, 'CONV2D'),
(512, 1001, 'CONV2D')
]
v1_075_128 = [
(27, 98304, 'CONV2D'),
(9, 98304 , 'DWCONV2D'),
(24, 196608, 'CONV2D'),
(9, 49152 , 'DWCONV2D'),
(48, 98304, 'CONV2D'),
(9, 98304 , 'DWCONV2D'),
(96, 98304, 'CONV2D'),
(9, 24576 , 'DWCONV2D'),
(96, 49152, 'CONV2D'),
(9, 49152 , 'DWCONV2D'),
(192, 49152, 'CONV2D'),
(9, 12288 , 'DWCONV2D'),
(192, 24576, 'CONV2D'),
(9, 24576 , 'DWCONV2D'),
(384, 24576, 'CONV2D'),
(9, 24576 , 'DWCONV2D'),
(384, 24576, 'CONV2D'),
(9, 24576 , 'DWCONV2D'),
(384, 24576, 'CONV2D'),
(9, 24576 , 'DWCONV2D'),
(384, 24576, 'CONV2D'),
(9, 24576 , 'DWCONV2D'),
(384, 24576, 'CONV2D'),
(9, 6144 , 'DWCONV2D'),
(384, 12288, 'CONV2D'),
(9, 12288 , 'DWCONV2D'),
(768, 12288, 'CONV2D'),
(768, 1001, 'CONV2D')
]
v1_075_160 = [
(27, 153600, 'CONV2D'),
(9, 153600 , 'DWCONV2D'),
(24, 307200, 'CONV2D'),
(9, 76800 , 'DWCONV2D'),
(48, 153600, 'CONV2D'),
(9, 153600 , 'DWCONV2D'),
(96, 153600, 'CONV2D'),
(9, 38400 , 'DWCONV2D'),
(96, 76800, 'CONV2D'),
(9, 76800 , 'DWCONV2D'),
(192, 76800, 'CONV2D'),
(9, 19200 , 'DWCONV2D'),
(192, 38400, 'CONV2D'),
(9, 38400 , 'DWCONV2D'),
(384, 38400, 'CONV2D'),
(9, 38400 , 'DWCONV2D'),
(384, 38400, 'CONV2D'),
(9, 38400 , 'DWCONV2D'),
(384, 38400, 'CONV2D'),
(9, 38400 , 'DWCONV2D'),
(384, 38400, 'CONV2D'),
(9, 38400 , 'DWCONV2D'),
(384, 38400, 'CONV2D'),
(9, 9600 , 'DWCONV2D'),
(384, 19200, 'CONV2D'),
(9, 19200 , 'DWCONV2D'),
(768, 19200, 'CONV2D'),
(768, 1001, 'CONV2D')
]
v1_075_192 = [
(27, 221184, 'CONV2D'),
(9, 221184 , 'DWCONV2D'),
(24, 442368, 'CONV2D'),
(9, 110592 , 'DWCONV2D'),
(48, 221184, 'CONV2D'),
(9, 221184 , 'DWCONV2D'),
(96, 221184, 'CONV2D'),
(9, 55296 , 'DWCONV2D'),
(96, 110592, 'CONV2D'),
(9, 110592 , 'DWCONV2D'),
(192, 110592, 'CONV2D'),
(9, 27648 , 'DWCONV2D'),
(192, 55296, 'CONV2D'),
(9, 55296 , 'DWCONV2D'),
(384, 55296, 'CONV2D'),
(9, 55296 , 'DWCONV2D'),
(384, 55296, 'CONV2D'),
(9, 55296 , 'DWCONV2D'),
(384, 55296, 'CONV2D'),
(9, 55296 , 'DWCONV2D'),
(384, 55296, 'CONV2D'),
(9, 55296 , 'DWCONV2D'),
(384, 55296, 'CONV2D'),
(9, 13824 , 'DWCONV2D'),
(384, 27648, 'CONV2D'),
(9, 27648 , 'DWCONV2D'),
(768, 27648, 'CONV2D'),
(768, 1001, 'CONV2D')
]
v1_075_224 = [
(27, 301056, 'CONV2D'),
(9, 301056 , 'DWCONV2D'),
(24, 602112, 'CONV2D'),
(9, 150528 , 'DWCONV2D'),
(48, 301056, 'CONV2D'),
(9, 301056 , 'DWCONV2D'),
(96, 301056, 'CONV2D'),
(9, 75264 , 'DWCONV2D'),
(96, 150528, 'CONV2D'),
(9, 150528 , 'DWCONV2D'),
(192, 150528, 'CONV2D'),
(9, 37632 , 'DWCONV2D'),
(192, 75264, 'CONV2D'),
(9, 75264 , 'DWCONV2D'),
(384, 75264, 'CONV2D'),
(9, 75264 , 'DWCONV2D'),
(384, 75264, 'CONV2D'),
(9, 75264 , 'DWCONV2D'),
(384, 75264, 'CONV2D'),
(9, 75264 , 'DWCONV2D'),
(384, 75264, 'CONV2D'),
(9, 75264 , 'DWCONV2D'),
(384, 75264, 'CONV2D'),
(9, 18816 , 'DWCONV2D'),
(384, 37632, 'CONV2D'),
(9, 37632 , 'DWCONV2D'),
(768, 37632, 'CONV2D'),
(768, 1001, 'CONV2D')
]
v1_1_128 = [
(27, 131072, 'CONV2D'),
(9, 131072 , 'DWCONV2D'),
(32, 262144, 'CONV2D'),
(9, 65536 , 'DWCONV2D'),
(64, 131072, 'CONV2D'),
(9, 131072 , 'DWCONV2D'),
(128, 131072, 'CONV2D'),
(9, 32768 , 'DWCONV2D'),
(128, 65536, 'CONV2D'),
(9, 65536 , 'DWCONV2D'),
(256, 65536, 'CONV2D'),
(9, 16384 , 'DWCONV2D'),
(256, 32768, 'CONV2D'),
(9, 32768 , 'DWCONV2D'),
(512, 32768, 'CONV2D'),
(9, 32768 , 'DWCONV2D'),
(512, 32768, 'CONV2D'),
(9, 32768 , 'DWCONV2D'),
(512, 32768, 'CONV2D'),
(9, 32768 , 'DWCONV2D'),
(512, 32768, 'CONV2D'),
(9, 32768 , 'DWCONV2D'),
(512, 32768, 'CONV2D'),
(9, 8192 , 'DWCONV2D'),
(512, 16384, 'CONV2D'),
(9, 16384 , 'DWCONV2D'),
(1024, 16384, 'CONV2D'),
(1024, 1001, 'CONV2D')
]
v1_1_160 = [
(27, 204800, 'CONV2D'),
(9, 204800 , 'DWCONV2D'),
(32, 409600, 'CONV2D'),
(9, 102400 , 'DWCONV2D'),
(64, 204800, 'CONV2D'),
(9, 204800 , 'DWCONV2D'),
(128, 204800, 'CONV2D'),
(9, 51200 , 'DWCONV2D'),
(128, 102400, 'CONV2D'),
(9, 102400 , 'DWCONV2D'),
(256, 102400, 'CONV2D'),
(9, 25600 , 'DWCONV2D'),
(256, 51200, 'CONV2D'),
(9, 51200 , 'DWCONV2D'),
(512, 51200, 'CONV2D'),
(9, 51200 , 'DWCONV2D'),
(512, 51200, 'CONV2D'),
(9, 51200 , 'DWCONV2D'),
(512, 51200, 'CONV2D'),
(9, 51200 , 'DWCONV2D'),
(512, 51200, 'CONV2D'),
(9, 51200 , 'DWCONV2D'),
(512, 51200, 'CONV2D'),
(9, 12800 , 'DWCONV2D'),
(512, 25600, 'CONV2D'),
(9, 25600 , 'DWCONV2D'),
(1024, 25600, 'CONV2D'),
(1024, 1001, 'CONV2D')
]
v1_1_192 = [
(27, 294912, 'CONV2D'),
(9, 294912 , 'DWCONV2D'),
(32, 589824, 'CONV2D'),
(9, 147456 , 'DWCONV2D'),
(64, 294912, 'CONV2D'),
(9, 294912 , 'DWCONV2D'),
(128, 294912, 'CONV2D'),
(9, 73728 , 'DWCONV2D'),
(128, 147456, 'CONV2D'),
(9, 147456 , 'DWCONV2D'),
(256, 147456, 'CONV2D'),
(9, 36864 , 'DWCONV2D'),
(256, 73728, 'CONV2D'),
(9, 73728 , 'DWCONV2D'),
(512, 73728, 'CONV2D'),
(9, 73728 , 'DWCONV2D'),
(512, 73728, 'CONV2D'),
(9, 73728 , 'DWCONV2D'),
(512, 73728, 'CONV2D'),
(9, 73728 , 'DWCONV2D'),
(512, 73728, 'CONV2D'),
(9, 73728 , 'DWCONV2D'),
(512, 73728, 'CONV2D'),
(9, 18432 , 'DWCONV2D'),
(512, 36864, 'CONV2D'),
(9, 36864 , 'DWCONV2D'),
(1024, 36864, 'CONV2D'),
(1024, 1001, 'CONV2D')
]
v1_1_224 = [
(27, 401408, 'CONV2D'),
(9, 401408 , 'DWCONV2D'),
(32, 802816, 'CONV2D'),
(9, 200704 , 'DWCONV2D'),
(64, 401408, 'CONV2D'),
(9, 401408 , 'DWCONV2D'),
(128, 401408, 'CONV2D'),
(9, 100352 , 'DWCONV2D'),
(128, 200704, 'CONV2D'),
(9, 200704 , 'DWCONV2D'),
(256, 200704, 'CONV2D'),
(9, 50176 , 'DWCONV2D'),
(256, 100352, 'CONV2D'),
(9, 100352 , 'DWCONV2D'),
(512, 100352, 'CONV2D'),
(9, 100352 , 'DWCONV2D'),
(512, 100352, 'CONV2D'),
(9, 100352 , 'DWCONV2D'),
(512, 100352, 'CONV2D'),
(9, 100352 , 'DWCONV2D'),
(512, 100352, 'CONV2D'),
(9, 100352 , 'DWCONV2D'),
(512, 100352, 'CONV2D'),
(9, 25088 , 'DWCONV2D'),
(512, 50176, 'CONV2D'),
(9, 50176 , 'DWCONV2D'),
(1024, 50176, 'CONV2D'),
(1024, 1001, 'CONV2D')
]
v2_1_224 = [
(27, 401408, 'CONV2D'),
(9, 401408, 'DWCONV2D' ),
(32, 200704, 'CONV2D'),
(16, 1204224, 'CONV2D'),
(9, 301056, 'DWCONV2D' ),
(96, 75264, 'CONV2D'),
(24, 451584, 'CONV2D'),
(9, 451584, 'DWCONV2D' ),
(144, 75264, 'CONV2D'),
(24, 451584, 'CONV2D'),
(9, 112896, 'DWCONV2D' ),
(144, 25088, 'CONV2D'),
(32, 150528, 'CONV2D'),
(9, 150528, 'DWCONV2D' ),
(192, 25088, 'CONV2D'),
(32, 150528, 'CONV2D'),
(9, 150528, 'DWCONV2D' ),
(192, 25088, 'CONV2D'),
(32, 150528, 'CONV2D'),
(9, 37632, 'DWCONV2D' ),
(192, 12544, 'CONV2D'),
(64, 75264, 'CONV2D'),
(9, 75264, 'DWCONV2D' ),
(384, 12544, 'CONV2D'),
(64, 75264, 'CONV2D'),
(9, 75264, 'DWCONV2D' ),
(384, 12544, 'CONV2D'),
(64, 75264, 'CONV2D'),
(9, 75264, 'DWCONV2D' ),
(384, 12544, 'CONV2D'),
(64, 75264, 'CONV2D'),
(9, 75264, 'DWCONV2D' ),
(384, 18816, 'CONV2D'),
(96, 112896, 'CONV2D'),
(9, 112896, 'DWCONV2D' ),
(576, 18816, 'CONV2D'),
(96, 112896, 'CONV2D'),
(9, 112896, 'DWCONV2D' ),
(576, 18816, 'CONV2D'),
(96, 112896, 'CONV2D'),
(9, 28224, 'DWCONV2D' ),
(576, 7840, 'CONV2D'),
(160, 47040, 'CONV2D'),
(9, 47040, 'DWCONV2D' ),
(960, 7840, 'CONV2D'),
(160, 47040, 'CONV2D'),
(9, 47040, 'DWCONV2D' ),
(960, 7840, 'CONV2D'),
(160, 47040, 'CONV2D'),
(9, 47040, 'DWCONV2D' ),
(960, 15680, 'CONV2D'),
(320, 62720, 'CONV2D'),
(1280, 1001, 'CONV2D')
]
network = program.args[1]
layers = globals()[network]
# c5.9xlarge has 36 cores
n_threads = 36
# S, Z, bit length
# using sfloat for secret floats and sint for secret int
p1 = squant_params(sfloat(.001), sint(1), 8)
p2 = squant_params(sfloat(.002), sint(2), 8)
p3 = squant_params(sfloat(.003), sint(3), 8)
# precompute multiplication of p1 and p2 to p3
p3.precompute(p1, p2)
# need to this to have arrays with specific parameters
class squant1(squant):
params = p1
class squant2(squant):
params = p2
zero1 = squant1(0)
zero2 = squant2(0)
for l, n, _ in layers:
a = Array(l, squant1)
b = Array(l, squant2)
a.assign_all(zero1)
b.assign_all(zero2)
# parallelization for optimization
@for_range_multithread(n_threads, 100, n)
def _(i):
# only for optimization
aa = a.get_vector()
bb = b.get_vector()
# store in memory to prevent dead code elimination
squant.dot_product(aa, bb, res_params=p3).store_in_mem(regint(0))

View File

@@ -0,0 +1,80 @@
NetworkA = [
(784, 128, 'FC'),
(128, 128, 'FC'),
(128, 10, 'FC')
]
NetworkD = [
(25, 980, 'CONV2D'),
(980, 100, 'FC'),
(100, 10, 'FC')
]
NetworkB = [
(25, 9216, 'CONV2D'),
(4, 2304, 'MAXP'),
(400, 1024, 'CONV2D'),
(4, 256, 'MAXP'),
(256, 100, 'FC'),
(100, 10, 'FC')
]
NetworkC = [
(25, 11520, 'CONV2D'),
(4, 2880, 'MAXP'),
(500, 3200, 'CONV2D'),
(4, 800, 'MAXP'),
(800, 500, 'FC'),
(500, 10, 'FC')
]
network = globals()['Network' + program.args[1]]
# c5.9xlarge has 36 cores
n_threads = 8
# S, Z, bit length
# using sfloat for secret floats and sint for secret int
p1 = squant_params(sfloat(.001), sint(1), 8)
p2 = squant_params(sfloat(.002), sint(2), 8)
p3 = squant_params(sfloat(.003), sint(3), 8)
# precompute multiplication of p1 and p2 to p3
p3.precompute(p1, p2)
# need to this to have arrays with specific parameters
class squant1(squant):
params = p1
class squant2(squant):
params = p2
program.set_bit_length(8)
import util
def maxpool(ln, num):
items = [sint(0, size=num)] * ln
util.tree_reduce(min, items).store_in_mem(0)
zero1 = squant1(0)
zero2 = squant2(0)
for l, n, typ in network:
if typ == 'MAXP':
maxpool(l, n)
else:
a = Array(l, squant1)
b = Array(l, squant2)
a.assign_all(zero1)
b.assign_all(zero2)
# parallelization for optimization
@for_range_multithread(n_threads, 1000, n)
def _(i):
# only for optimization
aa = a.get_vector()
bb = b.get_vector()
# store in memory to prevent dead code elimination
squant.dot_product(aa, bb, res_params=p3).store_in_mem(regint(0))

View File

@@ -0,0 +1,35 @@
# c5.9xlarge has 36 cores
n_threads = 36
# S, Z, bit length
# using sfloat for secret floats and sint for secret int
p1 = squant_params(sfloat(.001), sint(1), 8)
p2 = squant_params(sfloat(.002), sint(2), 8)
p3 = squant_params(sfloat(.003), sint(3), 8)
# precompute multiplication of p1 and p2 to p3
p3.precompute(p1, p2)
# need to this to have arrays with specific parameters
class squant1(squant):
params = p1
class squant2(squant):
params = p2
# fixed number of terms to be 512
a = Array(512, squant1)
b = Array(512, squant2)
a.assign_all(0)
b.assign_all(0)
# 50000, 100000, 150000, 200000
n = int(program.args[1])
# parallelization for optimization
@for_range_multithread(n_threads, 1000, n)
def _(i):
# only for optimization
aa = a.get_vector()
bb = b.get_vector()
# store in memory to prevent dead code elimination
squant.dot_product(aa, bb, res_params=p3).store_in_mem(regint(0))

View File

@@ -0,0 +1,38 @@
# c5.9xlarge has 36 cores
n_threads = 36
# S, Z, bit length
# using sfloat for secret floats and sint for secret int
p1 = squant_params(sfloat(.001), sint(1), 8)
p2 = squant_params(sfloat(.002), sint(2), 8)
p3 = squant_params(sfloat(.003), sint(3), 8)
# precompute multiplication of p1 and p2 to p3
p3.precompute(p1, p2)
# need to this to have arrays with specific parameters
class squant1(squant):
params = p1
class squant2(squant):
params = p2
# 256, 512, 768, 1024
ln = int(program.args[1])
a = Array(ln, squant1)
b = Array(ln, squant2)
a.assign_all(0)
b.assign_all(0)
# fix number of sops
n = 100000
# parallelization for optimization
@for_range_multithread(n_threads, 100, n)
def _(i):
# only for optimization
aa = a.get_vector()
bb = b.get_vector()
# store in memory to prevent dead code elimination
squant.dot_product(aa, bb, res_params=p3).store_in_mem(regint(0))