Files
MP-SPDZ/Scripts/memory-usage.py
Marcel Keller 78fe3d8bad Maintenance.
2024-07-09 12:19:52 +10:00

74 lines
2.0 KiB
Python
Executable File

#!/usr/bin/env python3
import sys, os
import collections
sys.path.append('.')
from Compiler.program import *
from Compiler.instructions_base import *
if len(sys.argv) <= 1:
    # Without a program name there is nothing to analyse; stop here instead of
    # continuing and failing later with an IndexError on sys.argv[1].
    print('Usage: %s <program>' % sys.argv[0], file=sys.stderr)
    sys.exit(1)
# High-water marks keyed by argument-format class, all starting at 0:
# res         - memory cells needed (shared by all tapes),
# regs        - registers needed by the first (main-thread) tape,
# thread_regs - registers needed by the remaining tapes.
res = collections.defaultdict(int)
regs = collections.defaultdict(int)
thread_regs = collections.defaultdict(int)
def process(tapename, res, regs):
    """Scan one tape and record how much memory and how many registers it uses.

    :param tapename: tape name as yielded by ``Program.read_tapes``.
    :param res: mapping from register argument-format class to the number of
        memory cells of that kind needed; updated in place.
    :param regs: mapping from register argument-format class to the number of
        registers of that kind needed; updated in place.
    """
    for inst in Tape.read_instructions(tapename):
        if issubclass(inst.type, DirectMemoryInstruction):
            # args[0] determines the kind of memory; args[1] is presumably the
            # start address.  An access covering inst.size cells therefore needs
            # cells up to address + size.  Keep only the maximum: the previous
            # code added 1 *after* max(), so the count grew with every memory
            # instruction instead of tracking the highest address used.
            kind = type(inst.args[0])
            res[kind] = max(res[kind], inst.args[1].i + inst.size)
        for arg in inst.args:
            if isinstance(arg, RegisterArgFormat):
                # Highest register index touched (plus vector size) of each kind.
                regs[type(arg)] = max(regs[type(arg)], arg.i + inst.size)
# Everything below is derived from the compiled program named on the command line.
schedule = sys.argv[1]
tapes = Program.read_tapes(schedule)
n_threads = Program.read_n_threads(schedule)
# Falls back to 8 whenever read_domain_size returns a falsy value.
domain_size = Program.read_domain_size(schedule) or 8

# The first tape's registers go into `regs`, every later tape's into
# `thread_regs`; memory usage (`res`) is accumulated across all tapes.
main_tape = next(tapes)
process(main_tape, res, regs)
for thread_tape in tapes:
    process(thread_tape, res, thread_regs)
# Map each argument-format class back to a name under which it appears in
# ArgFormats (when several names share a class, the last one listed wins).
reverse_formats = {fmt: name for name, fmt in ArgFormats.items()}


def regout(regs):
    """Return a copy of *regs* keyed by format name instead of format class."""
    return {reverse_formats[fmt]: n for fmt, n in regs.items()}
def output(data):
    """Print one ``count name`` line for every non-zero entry of *data*.

    :param data: mapping from an ArgFormats key (format name) to a count,
        as produced by ``regout``.
    """
    for t, n in data.items():
        if not n:
            continue
        # Names ending in 'w' (presumably write variants of a format) are
        # reported under the name without the suffix.  This is spelled out
        # instead of str.removesuffix so it also works before Python 3.9,
        # where the old bare except silently hid the AttributeError.
        base = t[:-1] if t.endswith('w') else t
        try:
            name = ArgFormats[base].name
        except (KeyError, AttributeError):
            # Best effort, as before: skip formats that have no plain
            # counterpart or no display name.
            continue
        print('%10d %s' % (n, name))
# Memory cells plus main-thread registers; threads' registers are counted
# separately because the upper estimate scales them by the thread count.
total = sum(res.values()) + sum(regs.values())
thread_total = sum(thread_regs.values())
# Per-format breakdown: memory first, then main-thread registers, then (only
# if any other tapes exist) the registers of the remaining threads.
for heading, counts in (('Memory:', res), ('Registers in main thread:', regs)):
    print(heading)
    output(regout(counts))
if thread_regs:
    print('Registers in other threads:')
    output(regout(thread_regs))
# Bytes per stored value.  The result below is multiplied by 1e-9 and printed
# as GB, so these factors act as bytes per entry.  An optional second argument
# gives an exact multiple of domain_size; otherwise assume 1x to 3x.
# (Named low/high_factor rather than min/max so the builtins stay usable.)
if len(sys.argv) > 2:
    low_factor = high_factor = int(sys.argv[2]) * domain_size
else:
    low_factor = 1 * domain_size
    high_factor = 3 * domain_size

# Lower bound counts one other thread's registers; upper bound counts them
# for each of the n_threads - 1 threads besides the main one.
print('The program requires at least an estimated %f-%f GB of RAM per party.'
      % (low_factor * (total + thread_total) * 1e-9,
         high_factor * ((total + (n_threads - 1) * thread_total) * 1e-9)))