mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-08 21:18:03 -05:00
refactor to add Compiler class
This commit is contained in:
@@ -2,30 +2,3 @@ from . import compilerLib, program, instructions, types, library, floatingpoint
|
||||
from .GC import types as GC_types
|
||||
import inspect
|
||||
from .config import *
|
||||
from .compilerLib import run
|
||||
|
||||
|
||||
# add all instructions to the program VARS dictionary
|
||||
compilerLib.VARS = {}
|
||||
instr_classes = [t[1] for t in inspect.getmembers(instructions, inspect.isclass)]
|
||||
|
||||
for mod in (types, GC_types):
|
||||
instr_classes += [t[1] for t in inspect.getmembers(mod, inspect.isclass)\
|
||||
if t[1].__module__ == mod.__name__]
|
||||
|
||||
instr_classes += [t[1] for t in inspect.getmembers(library, inspect.isfunction)\
|
||||
if t[1].__module__ == library.__name__]
|
||||
|
||||
for op in instr_classes:
|
||||
compilerLib.VARS[op.__name__] = op
|
||||
|
||||
# add open and input separately due to name conflict
|
||||
compilerLib.VARS['open'] = instructions.asm_open
|
||||
compilerLib.VARS['vopen'] = instructions.vasm_open
|
||||
compilerLib.VARS['gopen'] = instructions.gasm_open
|
||||
compilerLib.VARS['vgopen'] = instructions.vgasm_open
|
||||
compilerLib.VARS['input'] = instructions.asm_input
|
||||
compilerLib.VARS['ginput'] = instructions.gasm_input
|
||||
|
||||
compilerLib.VARS['comparison'] = comparison
|
||||
compilerLib.VARS['floatingpoint'] = floatingpoint
|
||||
|
||||
@@ -1,94 +1,361 @@
|
||||
from Compiler.program import Program
|
||||
from .GC import types as GC_types
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import re, tempfile, os
|
||||
import tempfile
|
||||
from optparse import OptionParser
|
||||
|
||||
from .GC import types as GC_types
|
||||
from .program import Program, defaults
|
||||
|
||||
|
||||
def run(args, options):
|
||||
""" Compile a file and output a Program object.
|
||||
|
||||
If options.merge_opens is set to True, will attempt to merge any
|
||||
parallelisable open instructions. """
|
||||
|
||||
prog = Program(args, options)
|
||||
VARS['program'] = prog
|
||||
if options.binary:
|
||||
VARS['sint'] = GC_types.sbitintvec.get_type(int(options.binary))
|
||||
VARS['sfix'] = GC_types.sbitfixvec
|
||||
for i in 'cint', 'cfix', 'cgf2n', 'sintbit', 'sgf2n', 'sgf2nint', \
|
||||
'sgf2nuint', 'sgf2nuint32', 'sgf2nfloat', 'sfloat', 'cfloat', \
|
||||
'squant':
|
||||
del VARS[i]
|
||||
|
||||
print('Compiling file', prog.infile)
|
||||
f = open(prog.infile, 'rb')
|
||||
class Compiler:
|
||||
def __init__(self):
|
||||
self.usage = "usage: %prog [options] filename [args]"
|
||||
self.build_option_parser()
|
||||
self.VARS = {}
|
||||
|
||||
changed = False
|
||||
if options.flow_optimization:
|
||||
output = []
|
||||
if_stack = []
|
||||
for line in open(prog.infile):
|
||||
if if_stack and not re.match(if_stack[-1][0], line):
|
||||
if_stack.pop()
|
||||
m = re.match(
|
||||
'(\s*)for +([a-zA-Z_]+) +in +range\(([0-9a-zA-Z_]+)\):',
|
||||
line)
|
||||
if m:
|
||||
output.append('%s@for_range_opt(%s)\n' % (m.group(1),
|
||||
m.group(3)))
|
||||
output.append('%sdef _(%s):\n' % (m.group(1), m.group(2)))
|
||||
changed = True
|
||||
continue
|
||||
m = re.match('(\s*)if(\W.*):', line)
|
||||
if m:
|
||||
if_stack.append((m.group(1), len(output)))
|
||||
output.append('%s@if_(%s)\n' % (m.group(1), m.group(2)))
|
||||
output.append('%sdef _():\n' % (m.group(1)))
|
||||
changed = True
|
||||
continue
|
||||
m = re.match('(\s*)elif\s+', line)
|
||||
if m:
|
||||
raise CompilerError('elif not supported')
|
||||
if if_stack:
|
||||
m = re.match('%selse:' % if_stack[-1][0], line)
|
||||
if m:
|
||||
start = if_stack[-1][1]
|
||||
ws = if_stack[-1][0]
|
||||
output[start] = re.sub(r'^%s@if_\(' % ws, r'%s@if_e(' % ws,
|
||||
output[start])
|
||||
output.append('%s@else_\n' % ws)
|
||||
output.append('%sdef _():\n' % ws)
|
||||
continue
|
||||
output.append(line)
|
||||
if changed:
|
||||
infile = tempfile.NamedTemporaryFile('w+', delete=False)
|
||||
for line in output:
|
||||
infile.write(line)
|
||||
infile.seek(0)
|
||||
else:
|
||||
infile = open(prog.infile)
|
||||
else:
|
||||
infile = open(prog.infile)
|
||||
def build_option_parser(self):
|
||||
parser = OptionParser(usage=self.usage)
|
||||
parser.add_option(
|
||||
"-n",
|
||||
"--nomerge",
|
||||
action="store_false",
|
||||
dest="merge_opens",
|
||||
default=defaults.merge_opens,
|
||||
help="don't attempt to merge open instructions",
|
||||
)
|
||||
parser.add_option("-o", "--output", dest="outfile", help="specify output file")
|
||||
parser.add_option(
|
||||
"-a",
|
||||
"--asm-output",
|
||||
dest="asmoutfile",
|
||||
help="asm output file for debugging",
|
||||
)
|
||||
parser.add_option(
|
||||
"-g",
|
||||
"--galoissize",
|
||||
dest="galois",
|
||||
default=defaults.galois,
|
||||
help="bit length of Galois field",
|
||||
)
|
||||
parser.add_option(
|
||||
"-d",
|
||||
"--debug",
|
||||
action="store_true",
|
||||
dest="debug",
|
||||
help="keep track of trace for debugging",
|
||||
)
|
||||
parser.add_option(
|
||||
"-c",
|
||||
"--comparison",
|
||||
dest="comparison",
|
||||
default="log",
|
||||
help="comparison variant: log|plain|inv|sinv",
|
||||
)
|
||||
parser.add_option(
|
||||
"-M",
|
||||
"--preserve-mem-order",
|
||||
action="store_true",
|
||||
dest="preserve_mem_order",
|
||||
default=defaults.preserve_mem_order,
|
||||
help="preserve order of memory instructions; possible efficiency loss",
|
||||
)
|
||||
parser.add_option(
|
||||
"-O",
|
||||
"--optimize-hard",
|
||||
action="store_true",
|
||||
dest="optimize_hard",
|
||||
help="currently not in use",
|
||||
)
|
||||
parser.add_option(
|
||||
"-u",
|
||||
"--noreallocate",
|
||||
action="store_true",
|
||||
dest="noreallocate",
|
||||
default=defaults.noreallocate,
|
||||
help="don't reallocate",
|
||||
)
|
||||
parser.add_option(
|
||||
"-m",
|
||||
"--max-parallel-open",
|
||||
dest="max_parallel_open",
|
||||
default=defaults.max_parallel_open,
|
||||
help="restrict number of parallel opens",
|
||||
)
|
||||
parser.add_option(
|
||||
"-D",
|
||||
"--dead-code-elimination",
|
||||
action="store_true",
|
||||
dest="dead_code_elimination",
|
||||
default=defaults.dead_code_elimination,
|
||||
help="eliminate instructions with unused result",
|
||||
)
|
||||
parser.add_option(
|
||||
"-p",
|
||||
"--profile",
|
||||
action="store_true",
|
||||
dest="profile",
|
||||
help="profile compilation",
|
||||
)
|
||||
parser.add_option(
|
||||
"-s",
|
||||
"--stop",
|
||||
action="store_true",
|
||||
dest="stop",
|
||||
help="stop on register errors",
|
||||
)
|
||||
parser.add_option(
|
||||
"-R",
|
||||
"--ring",
|
||||
dest="ring",
|
||||
default=defaults.ring,
|
||||
help="bit length of ring (default: 0 for field)",
|
||||
)
|
||||
parser.add_option(
|
||||
"-B",
|
||||
"--binary",
|
||||
dest="binary",
|
||||
default=defaults.binary,
|
||||
help="bit length of sint in binary circuit (default: 0 for arithmetic)",
|
||||
)
|
||||
parser.add_option(
|
||||
"-F",
|
||||
"--field",
|
||||
dest="field",
|
||||
default=defaults.field,
|
||||
help="bit length of sint modulo prime (default: 64)",
|
||||
)
|
||||
parser.add_option(
|
||||
"-P",
|
||||
"--prime",
|
||||
dest="prime",
|
||||
default=defaults.prime,
|
||||
help="prime modulus (default: not specified)",
|
||||
)
|
||||
parser.add_option(
|
||||
"-I",
|
||||
"--insecure",
|
||||
action="store_true",
|
||||
dest="insecure",
|
||||
help="activate insecure functionality for benchmarking",
|
||||
)
|
||||
parser.add_option(
|
||||
"-b",
|
||||
"--budget",
|
||||
dest="budget",
|
||||
default=defaults.budget,
|
||||
help="set budget for optimized loop unrolling " "(default: 100000)",
|
||||
)
|
||||
parser.add_option(
|
||||
"-X",
|
||||
"--mixed",
|
||||
action="store_true",
|
||||
dest="mixed",
|
||||
help="mixing arithmetic and binary computation",
|
||||
)
|
||||
parser.add_option(
|
||||
"-Y",
|
||||
"--edabit",
|
||||
action="store_true",
|
||||
dest="edabit",
|
||||
help="mixing arithmetic and binary computation using edaBits",
|
||||
)
|
||||
parser.add_option(
|
||||
"-Z",
|
||||
"--split",
|
||||
default=defaults.split,
|
||||
dest="split",
|
||||
help="mixing arithmetic and binary computation "
|
||||
"using direct conversion if supported "
|
||||
"(number of parties as argument)",
|
||||
)
|
||||
parser.add_option(
|
||||
"-C",
|
||||
"--CISC",
|
||||
action="store_true",
|
||||
dest="cisc",
|
||||
help="faster CISC compilation mode",
|
||||
)
|
||||
parser.add_option(
|
||||
"-K",
|
||||
"--keep-cisc",
|
||||
dest="keep_cisc",
|
||||
help="don't translate CISC instructions",
|
||||
)
|
||||
parser.add_option(
|
||||
"-l",
|
||||
"--flow-optimization",
|
||||
action="store_true",
|
||||
dest="flow_optimization",
|
||||
help="optimize control flow",
|
||||
)
|
||||
parser.add_option(
|
||||
"-v",
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
dest="verbose",
|
||||
help="more verbose output",
|
||||
)
|
||||
self.parser = parser
|
||||
|
||||
# make compiler modules directly accessible
|
||||
sys.path.insert(0, 'Compiler')
|
||||
# create the tapes
|
||||
exec(compile(infile.read(), infile.name, 'exec'), VARS)
|
||||
def parse_args(self):
|
||||
self.options, self.args = self.parser.parse_args()
|
||||
if len(self.args) < 1:
|
||||
self.parser.print_help()
|
||||
return
|
||||
|
||||
if changed and not options.debug:
|
||||
os.unlink(infile.name)
|
||||
if self.options.optimize_hard:
|
||||
print("Note that -O/--optimize-hard currently has no effect")
|
||||
|
||||
prog.finalize()
|
||||
def build_program(self):
|
||||
self.prog = Program(self.args, self.options)
|
||||
|
||||
if prog.req_num:
|
||||
print('Program requires at most:')
|
||||
for x in prog.req_num.pretty():
|
||||
print(x)
|
||||
def build_vars(self):
|
||||
from . import comparison, floatingpoint, instructions, library, types
|
||||
|
||||
if prog.verbose:
|
||||
print('Program requires:', repr(prog.req_num))
|
||||
print('Cost:', 0 if prog.req_num is None else prog.req_num.cost())
|
||||
print('Memory size:', dict(prog.allocated_mem))
|
||||
# add all instructions to the program VARS dictionary
|
||||
instr_classes = [
|
||||
t[1] for t in inspect.getmembers(instructions, inspect.isclass)
|
||||
]
|
||||
|
||||
return prog
|
||||
for mod in (types, GC_types):
|
||||
instr_classes += [
|
||||
t[1]
|
||||
for t in inspect.getmembers(mod, inspect.isclass)
|
||||
if t[1].__module__ == mod.__name__
|
||||
]
|
||||
|
||||
instr_classes += [
|
||||
t[1]
|
||||
for t in inspect.getmembers(library, inspect.isfunction)
|
||||
if t[1].__module__ == library.__name__
|
||||
]
|
||||
|
||||
for op in instr_classes:
|
||||
self.VARS[op.__name__] = op
|
||||
|
||||
# add open and input separately due to name conflict
|
||||
self.VARS["open"] = instructions.asm_open
|
||||
self.VARS["vopen"] = instructions.vasm_open
|
||||
self.VARS["gopen"] = instructions.gasm_open
|
||||
self.VARS["vgopen"] = instructions.vgasm_open
|
||||
self.VARS["input"] = instructions.asm_input
|
||||
self.VARS["ginput"] = instructions.gasm_input
|
||||
|
||||
self.VARS["comparison"] = comparison
|
||||
self.VARS["floatingpoint"] = floatingpoint
|
||||
|
||||
self.VARS["program"] = self.prog
|
||||
if self.options.binary:
|
||||
self.VARS["sint"] = GC_types.sbitintvec.get_type(int(self.options.binary))
|
||||
self.VARS["sfix"] = GC_types.sbitfixvec
|
||||
for i in [
|
||||
"cint",
|
||||
"cfix",
|
||||
"cgf2n",
|
||||
"sintbit",
|
||||
"sgf2n",
|
||||
"sgf2nint",
|
||||
"sgf2nuint",
|
||||
"sgf2nuint32",
|
||||
"sgf2nfloat",
|
||||
"sfloat",
|
||||
"cfloat",
|
||||
"squant",
|
||||
]:
|
||||
del self.VARS[i]
|
||||
|
||||
def prep_compile(self):
|
||||
self.parse_args()
|
||||
self.build_program()
|
||||
self.build_vars()
|
||||
|
||||
def compile_file(self):
|
||||
"""Compile a file and output a Program object.
|
||||
|
||||
If options.merge_opens is set to True, will attempt to merge any
|
||||
parallelisable open instructions."""
|
||||
print("Compiling file", self.prog.infile)
|
||||
|
||||
with open(self.prog.infile, "rb") as f:
|
||||
changed = False
|
||||
if self.options.flow_optimization:
|
||||
output = []
|
||||
if_stack = []
|
||||
for line in f:
|
||||
if if_stack and not re.match(if_stack[-1][0], line):
|
||||
if_stack.pop()
|
||||
m = re.match(
|
||||
r"(\s*)for +([a-zA-Z_]+) +in " r"+range\(([0-9a-zA-Z_]+)\):",
|
||||
line,
|
||||
)
|
||||
if m:
|
||||
output.append(
|
||||
"%s@for_range_opt(%s)\n" % (m.group(1), m.group(3))
|
||||
)
|
||||
output.append("%sdef _(%s):\n" % (m.group(1), m.group(2)))
|
||||
changed = True
|
||||
continue
|
||||
m = re.match(r"(\s*)if(\W.*):", line)
|
||||
if m:
|
||||
if_stack.append((m.group(1), len(output)))
|
||||
output.append("%s@if_(%s)\n" % (m.group(1), m.group(2)))
|
||||
output.append("%sdef _():\n" % (m.group(1)))
|
||||
changed = True
|
||||
continue
|
||||
m = re.match(r"(\s*)elif\s+", line)
|
||||
if m:
|
||||
raise CompilerError("elif not supported")
|
||||
if if_stack:
|
||||
m = re.match("%selse:" % if_stack[-1][0], line)
|
||||
if m:
|
||||
start = if_stack[-1][1]
|
||||
ws = if_stack[-1][0]
|
||||
output[start] = re.sub(
|
||||
r"^%s@if_\(" % ws, r"%s@if_e(" % ws, output[start]
|
||||
)
|
||||
output.append("%s@else_\n" % ws)
|
||||
output.append("%sdef _():\n" % ws)
|
||||
continue
|
||||
output.append(line)
|
||||
if changed:
|
||||
infile = tempfile.NamedTemporaryFile("w+", delete=False)
|
||||
for line in output:
|
||||
infile.write(line)
|
||||
infile.seek(0)
|
||||
else:
|
||||
infile = open(self.prog.infile)
|
||||
else:
|
||||
infile = open(self.prog.infile)
|
||||
|
||||
# make compiler modules directly accessible
|
||||
sys.path.insert(0, "Compiler")
|
||||
# create the tapes
|
||||
exec(compile(infile.read(), infile.name, "exec"), self.VARS)
|
||||
|
||||
if changed and not self.options.debug:
|
||||
os.unlink(infile.name)
|
||||
|
||||
return self.finalize_compile()
|
||||
|
||||
def compile_func(self, f):
|
||||
self.prep_compile()
|
||||
print(f"Compiling function: {f.__name__}")
|
||||
f(self.VARS)
|
||||
self.finalize_compile()
|
||||
|
||||
def finalize_compile(self):
|
||||
self.prog.finalize()
|
||||
|
||||
if self.prog.req_num:
|
||||
print("Program requires at most:")
|
||||
for x in self.prog.req_num.pretty():
|
||||
print(x)
|
||||
|
||||
if self.prog.verbose:
|
||||
print("Program requires:", repr(self.prog.req_num))
|
||||
print("Cost:", 0 if self.prog.req_num is None else self.prog.req_num.cost())
|
||||
print("Memory size:", dict(self.prog.allocated_mem))
|
||||
|
||||
return self.prog
|
||||
|
||||
106
compile.py
106
compile.py
@@ -12,100 +12,30 @@
|
||||
#
|
||||
# See the compiler documentation at https://mp-spdz.readthedocs.io
|
||||
# for details on the Compiler package
|
||||
from Compiler.compilerLib import Compiler
|
||||
|
||||
|
||||
from optparse import OptionParser
|
||||
from Compiler.program import defaults
|
||||
import Compiler
|
||||
def compilation(compiler):
|
||||
prog = compiler.compile_file()
|
||||
|
||||
def main():
|
||||
usage = "usage: %prog [options] filename [args]"
|
||||
parser = OptionParser(usage=usage)
|
||||
parser.add_option("-n", "--nomerge",
|
||||
action="store_false", dest="merge_opens",
|
||||
default=defaults.merge_opens,
|
||||
help="don't attempt to merge open instructions")
|
||||
parser.add_option("-o", "--output", dest="outfile",
|
||||
help="specify output file")
|
||||
parser.add_option("-a", "--asm-output", dest="asmoutfile",
|
||||
help="asm output file for debugging")
|
||||
parser.add_option("-g", "--galoissize", dest="galois",
|
||||
default=defaults.galois,
|
||||
help="bit length of Galois field")
|
||||
parser.add_option("-d", "--debug", action="store_true", dest="debug",
|
||||
help="keep track of trace for debugging")
|
||||
parser.add_option("-c", "--comparison", dest="comparison", default="log",
|
||||
help="comparison variant: log|plain|inv|sinv")
|
||||
parser.add_option("-M", "--preserve-mem-order", action="store_true",
|
||||
dest="preserve_mem_order",
|
||||
default=defaults.preserve_mem_order,
|
||||
help="preserve order of memory instructions; possible efficiency loss")
|
||||
parser.add_option("-O", "--optimize-hard", action="store_true",
|
||||
dest="optimize_hard", help="currently not in use")
|
||||
parser.add_option("-u", "--noreallocate", action="store_true", dest="noreallocate",
|
||||
default=defaults.noreallocate, help="don't reallocate")
|
||||
parser.add_option("-m", "--max-parallel-open", dest="max_parallel_open",
|
||||
default=defaults.max_parallel_open,
|
||||
help="restrict number of parallel opens")
|
||||
parser.add_option("-D", "--dead-code-elimination", action="store_true",
|
||||
dest="dead_code_elimination",
|
||||
default=defaults.dead_code_elimination,
|
||||
help="eliminate instructions with unused result")
|
||||
parser.add_option("-p", "--profile", action="store_true", dest="profile",
|
||||
help="profile compilation")
|
||||
parser.add_option("-s", "--stop", action="store_true", dest="stop",
|
||||
help="stop on register errors")
|
||||
parser.add_option("-R", "--ring", dest="ring", default=defaults.ring,
|
||||
help="bit length of ring (default: 0 for field)")
|
||||
parser.add_option("-B", "--binary", dest="binary", default=defaults.binary,
|
||||
help="bit length of sint in binary circuit (default: 0 for arithmetic)")
|
||||
parser.add_option("-F", "--field", dest="field", default=defaults.field,
|
||||
help="bit length of sint modulo prime (default: 64)")
|
||||
parser.add_option("-P", "--prime", dest="prime", default=defaults.prime,
|
||||
help="prime modulus (default: not specified)")
|
||||
parser.add_option("-I", "--insecure", action="store_true", dest="insecure",
|
||||
help="activate insecure functionality for benchmarking")
|
||||
parser.add_option("-b", "--budget", dest="budget", default=defaults.budget,
|
||||
help="set budget for optimized loop unrolling "
|
||||
"(default: 100000)")
|
||||
parser.add_option("-X", "--mixed", action="store_true", dest="mixed",
|
||||
help="mixing arithmetic and binary computation")
|
||||
parser.add_option("-Y", "--edabit", action="store_true", dest="edabit",
|
||||
help="mixing arithmetic and binary computation using edaBits")
|
||||
parser.add_option("-Z", "--split", default=defaults.split, dest="split",
|
||||
help="mixing arithmetic and binary computation "
|
||||
"using direct conversion if supported "
|
||||
"(number of parties as argument)")
|
||||
parser.add_option("-C", "--CISC", action="store_true", dest="cisc",
|
||||
help="faster CISC compilation mode")
|
||||
parser.add_option("-K", "--keep-cisc", dest="keep_cisc",
|
||||
help="don't translate CISC instructions")
|
||||
parser.add_option("-l", "--flow-optimization", action="store_true",
|
||||
dest="flow_optimization", help="optimize control flow")
|
||||
parser.add_option("-v", "--verbose", action="store_true", dest="verbose",
|
||||
help="more verbose output")
|
||||
options,args = parser.parse_args()
|
||||
if len(args) < 1:
|
||||
parser.print_help()
|
||||
return
|
||||
if prog.public_input_file is not None:
|
||||
print(
|
||||
"WARNING: %s is required to run the program" % prog.public_input_file.name
|
||||
)
|
||||
|
||||
if options.optimize_hard:
|
||||
print('Note that -O/--optimize-hard currently has no effect')
|
||||
|
||||
def compilation():
|
||||
prog = Compiler.run(args, options)
|
||||
|
||||
if prog.public_input_file is not None:
|
||||
print('WARNING: %s is required to run the program' % \
|
||||
prog.public_input_file.name)
|
||||
|
||||
if options.profile:
|
||||
def main(compiler):
|
||||
compiler.prep_compile()
|
||||
if compiler.options.profile:
|
||||
import cProfile
|
||||
p = cProfile.Profile().runctx('compilation()', globals(), locals())
|
||||
p.dump_stats(args[0] + '.prof')
|
||||
|
||||
p = cProfile.Profile().runctx("compilation(compiler)", globals(), locals())
|
||||
p.dump_stats(compiler.args[0] + ".prof")
|
||||
p.print_stats(2)
|
||||
else:
|
||||
compilation()
|
||||
compilation(compiler)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
if __name__ == "__main__":
|
||||
compiler = Compiler()
|
||||
main(compiler)
|
||||
|
||||
Reference in New Issue
Block a user