diff --git a/Compiler/__init__.py b/Compiler/__init__.py index 9a22da46..6a0d6b1d 100644 --- a/Compiler/__init__.py +++ b/Compiler/__init__.py @@ -2,30 +2,3 @@ from . import compilerLib, program, instructions, types, library, floatingpoint from .GC import types as GC_types import inspect from .config import * -from .compilerLib import run - - -# add all instructions to the program VARS dictionary -compilerLib.VARS = {} -instr_classes = [t[1] for t in inspect.getmembers(instructions, inspect.isclass)] - -for mod in (types, GC_types): - instr_classes += [t[1] for t in inspect.getmembers(mod, inspect.isclass)\ - if t[1].__module__ == mod.__name__] - -instr_classes += [t[1] for t in inspect.getmembers(library, inspect.isfunction)\ - if t[1].__module__ == library.__name__] - -for op in instr_classes: - compilerLib.VARS[op.__name__] = op - -# add open and input separately due to name conflict -compilerLib.VARS['open'] = instructions.asm_open -compilerLib.VARS['vopen'] = instructions.vasm_open -compilerLib.VARS['gopen'] = instructions.gasm_open -compilerLib.VARS['vgopen'] = instructions.vgasm_open -compilerLib.VARS['input'] = instructions.asm_input -compilerLib.VARS['ginput'] = instructions.gasm_input - -compilerLib.VARS['comparison'] = comparison -compilerLib.VARS['floatingpoint'] = floatingpoint diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index b2898e21..591700c1 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -1,94 +1,361 @@ -from Compiler.program import Program -from .GC import types as GC_types - +import inspect +import os +import re import sys -import re, tempfile, os +import tempfile +from optparse import OptionParser + +from .GC import types as GC_types +from .program import Program, defaults -def run(args, options): - """ Compile a file and output a Program object. - - If options.merge_opens is set to True, will attempt to merge any - parallelisable open instructions. """ - - prog = Program(args, options) - VARS['program'] = prog - if options.binary: - VARS['sint'] = GC_types.sbitintvec.get_type(int(options.binary)) - VARS['sfix'] = GC_types.sbitfixvec - for i in 'cint', 'cfix', 'cgf2n', 'sintbit', 'sgf2n', 'sgf2nint', \ - 'sgf2nuint', 'sgf2nuint32', 'sgf2nfloat', 'sfloat', 'cfloat', \ - 'squant': - del VARS[i] - - print('Compiling file', prog.infile) - f = open(prog.infile, 'rb') +class Compiler: + def __init__(self): + self.usage = "usage: %prog [options] filename [args]" + self.build_option_parser() + self.VARS = {} - changed = False - if options.flow_optimization: - output = [] - if_stack = [] - for line in open(prog.infile): - if if_stack and not re.match(if_stack[-1][0], line): - if_stack.pop() - m = re.match( - '(\s*)for +([a-zA-Z_]+) +in +range\(([0-9a-zA-Z_]+)\):', - line) - if m: - output.append('%s@for_range_opt(%s)\n' % (m.group(1), - m.group(3))) - output.append('%sdef _(%s):\n' % (m.group(1), m.group(2))) - changed = True - continue - m = re.match('(\s*)if(\W.*):', line) - if m: - if_stack.append((m.group(1), len(output))) - output.append('%s@if_(%s)\n' % (m.group(1), m.group(2))) - output.append('%sdef _():\n' % (m.group(1))) - changed = True - continue - m = re.match('(\s*)elif\s+', line) - if m: - raise CompilerError('elif not supported') - if if_stack: - m = re.match('%selse:' % if_stack[-1][0], line) - if m: - start = if_stack[-1][1] - ws = if_stack[-1][0] - output[start] = re.sub(r'^%s@if_\(' % ws, r'%s@if_e(' % ws, - output[start]) - output.append('%s@else_\n' % ws) - output.append('%sdef _():\n' % ws) - continue - output.append(line) - if changed: - infile = tempfile.NamedTemporaryFile('w+', delete=False) - for line in output: - infile.write(line) - infile.seek(0) - else: - infile = open(prog.infile) - else: - infile = open(prog.infile) + def build_option_parser(self): + parser = OptionParser(usage=self.usage) + parser.add_option( + "-n", + "--nomerge", + action="store_false", + dest="merge_opens", + default=defaults.merge_opens, + help="don't attempt to merge open instructions", + ) + parser.add_option("-o", "--output", dest="outfile", help="specify output file") + parser.add_option( + "-a", + "--asm-output", + dest="asmoutfile", + help="asm output file for debugging", + ) + parser.add_option( + "-g", + "--galoissize", + dest="galois", + default=defaults.galois, + help="bit length of Galois field", + ) + parser.add_option( + "-d", + "--debug", + action="store_true", + dest="debug", + help="keep track of trace for debugging", + ) + parser.add_option( + "-c", + "--comparison", + dest="comparison", + default="log", + help="comparison variant: log|plain|inv|sinv", + ) + parser.add_option( + "-M", + "--preserve-mem-order", + action="store_true", + dest="preserve_mem_order", + default=defaults.preserve_mem_order, + help="preserve order of memory instructions; possible efficiency loss", + ) + parser.add_option( + "-O", + "--optimize-hard", + action="store_true", + dest="optimize_hard", + help="currently not in use", + ) + parser.add_option( + "-u", + "--noreallocate", + action="store_true", + dest="noreallocate", + default=defaults.noreallocate, + help="don't reallocate", + ) + parser.add_option( + "-m", + "--max-parallel-open", + dest="max_parallel_open", + default=defaults.max_parallel_open, + help="restrict number of parallel opens", + ) + parser.add_option( + "-D", + "--dead-code-elimination", + action="store_true", + dest="dead_code_elimination", + default=defaults.dead_code_elimination, + help="eliminate instructions with unused result", + ) + parser.add_option( + "-p", + "--profile", + action="store_true", + dest="profile", + help="profile compilation", + ) + parser.add_option( + "-s", + "--stop", + action="store_true", + dest="stop", + help="stop on register errors", + ) + parser.add_option( + "-R", + "--ring", + dest="ring", + default=defaults.ring, + help="bit length of ring (default: 0 for field)", + ) + parser.add_option( + "-B", + "--binary", + dest="binary", + default=defaults.binary, + help="bit length of sint in binary circuit (default: 0 for arithmetic)", + ) + parser.add_option( + "-F", + "--field", + dest="field", + default=defaults.field, + help="bit length of sint modulo prime (default: 64)", + ) + parser.add_option( + "-P", + "--prime", + dest="prime", + default=defaults.prime, + help="prime modulus (default: not specified)", + ) + parser.add_option( + "-I", + "--insecure", + action="store_true", + dest="insecure", + help="activate insecure functionality for benchmarking", + ) + parser.add_option( + "-b", + "--budget", + dest="budget", + default=defaults.budget, + help="set budget for optimized loop unrolling " "(default: 100000)", + ) + parser.add_option( + "-X", + "--mixed", + action="store_true", + dest="mixed", + help="mixing arithmetic and binary computation", + ) + parser.add_option( + "-Y", + "--edabit", + action="store_true", + dest="edabit", + help="mixing arithmetic and binary computation using edaBits", + ) + parser.add_option( + "-Z", + "--split", + default=defaults.split, + dest="split", + help="mixing arithmetic and binary computation " + "using direct conversion if supported " + "(number of parties as argument)", + ) + parser.add_option( + "-C", + "--CISC", + action="store_true", + dest="cisc", + help="faster CISC compilation mode", + ) + parser.add_option( + "-K", + "--keep-cisc", + dest="keep_cisc", + help="don't translate CISC instructions", + ) + parser.add_option( + "-l", + "--flow-optimization", + action="store_true", + dest="flow_optimization", + help="optimize control flow", + ) + parser.add_option( + "-v", + "--verbose", + action="store_true", + dest="verbose", + help="more verbose output", + ) + self.parser = parser - # make compiler modules directly accessible - sys.path.insert(0, 'Compiler') - # create the tapes - exec(compile(infile.read(), infile.name, 'exec'), VARS) + def parse_args(self): + self.options, self.args = self.parser.parse_args() + if len(self.args) < 1: + self.parser.print_help() + return - if changed and not options.debug: - os.unlink(infile.name) + if self.options.optimize_hard: + print("Note that -O/--optimize-hard currently has no effect") - prog.finalize() + def build_program(self): + self.prog = Program(self.args, self.options) - if prog.req_num: - print('Program requires at most:') - for x in prog.req_num.pretty(): - print(x) + def build_vars(self): + from . import comparison, floatingpoint, instructions, library, types - if prog.verbose: - print('Program requires:', repr(prog.req_num)) - print('Cost:', 0 if prog.req_num is None else prog.req_num.cost()) - print('Memory size:', dict(prog.allocated_mem)) + # add all instructions to the program VARS dictionary + instr_classes = [ + t[1] for t in inspect.getmembers(instructions, inspect.isclass) + ] - return prog + for mod in (types, GC_types): + instr_classes += [ + t[1] + for t in inspect.getmembers(mod, inspect.isclass) + if t[1].__module__ == mod.__name__ + ] + + instr_classes += [ + t[1] + for t in inspect.getmembers(library, inspect.isfunction) + if t[1].__module__ == library.__name__ + ] + + for op in instr_classes: + self.VARS[op.__name__] = op + + # add open and input separately due to name conflict + self.VARS["open"] = instructions.asm_open + self.VARS["vopen"] = instructions.vasm_open + self.VARS["gopen"] = instructions.gasm_open + self.VARS["vgopen"] = instructions.vgasm_open + self.VARS["input"] = instructions.asm_input + self.VARS["ginput"] = instructions.gasm_input + + self.VARS["comparison"] = comparison + self.VARS["floatingpoint"] = floatingpoint + + self.VARS["program"] = self.prog + if self.options.binary: + self.VARS["sint"] = GC_types.sbitintvec.get_type(int(self.options.binary)) + self.VARS["sfix"] = GC_types.sbitfixvec + for i in [ + "cint", + "cfix", + "cgf2n", + "sintbit", + "sgf2n", + "sgf2nint", + "sgf2nuint", + "sgf2nuint32", + "sgf2nfloat", + "sfloat", + "cfloat", + "squant", + ]: + del self.VARS[i] + + def prep_compile(self): + self.parse_args() + self.build_program() + self.build_vars() + + def compile_file(self): + """Compile a file and output a Program object. + + If options.merge_opens is set to True, will attempt to merge any + parallelisable open instructions.""" + print("Compiling file", self.prog.infile) + + with open(self.prog.infile, "rb") as f: + changed = False + if self.options.flow_optimization: + output = [] + if_stack = [] + for line in f: + if if_stack and not re.match(if_stack[-1][0], line): + if_stack.pop() + m = re.match( + r"(\s*)for +([a-zA-Z_]+) +in " r"+range\(([0-9a-zA-Z_]+)\):", + line, + ) + if m: + output.append( + "%s@for_range_opt(%s)\n" % (m.group(1), m.group(3)) + ) + output.append("%sdef _(%s):\n" % (m.group(1), m.group(2))) + changed = True + continue + m = re.match(r"(\s*)if(\W.*):", line) + if m: + if_stack.append((m.group(1), len(output))) + output.append("%s@if_(%s)\n" % (m.group(1), m.group(2))) + output.append("%sdef _():\n" % (m.group(1))) + changed = True + continue + m = re.match(r"(\s*)elif\s+", line) + if m: + raise CompilerError("elif not supported") + if if_stack: + m = re.match("%selse:" % if_stack[-1][0], line) + if m: + start = if_stack[-1][1] + ws = if_stack[-1][0] + output[start] = re.sub( + r"^%s@if_\(" % ws, r"%s@if_e(" % ws, output[start] + ) + output.append("%s@else_\n" % ws) + output.append("%sdef _():\n" % ws) + continue + output.append(line) + if changed: + infile = tempfile.NamedTemporaryFile("w+", delete=False) + for line in output: + infile.write(line) + infile.seek(0) + else: + infile = open(self.prog.infile) + else: + infile = open(self.prog.infile) + + # make compiler modules directly accessible + sys.path.insert(0, "Compiler") + # create the tapes + exec(compile(infile.read(), infile.name, "exec"), self.VARS) + + if changed and not self.options.debug: + os.unlink(infile.name) + + return self.finalize_compile() + + def compile_func(self, f): + self.prep_compile() + print(f"Compiling function: {f.__name__}") + f(self.VARS) + self.finalize_compile() + + def finalize_compile(self): + self.prog.finalize() + + if self.prog.req_num: + print("Program requires at most:") + for x in self.prog.req_num.pretty(): + print(x) + + if self.prog.verbose: + print("Program requires:", repr(self.prog.req_num)) + print("Cost:", 0 if self.prog.req_num is None else self.prog.req_num.cost()) + print("Memory size:", dict(self.prog.allocated_mem)) + + return self.prog diff --git a/compile.py b/compile.py index da1b69ee..50671a04 100755 --- a/compile.py +++ b/compile.py @@ -12,100 +12,30 @@ # # See the compiler documentation at https://mp-spdz.readthedocs.io # for details on the Compiler package +from Compiler.compilerLib import Compiler -from optparse import OptionParser -from Compiler.program import defaults -import Compiler +def compilation(compiler): + prog = compiler.compile_file() -def main(): - usage = "usage: %prog [options] filename [args]" - parser = OptionParser(usage=usage) - parser.add_option("-n", "--nomerge", - action="store_false", dest="merge_opens", - default=defaults.merge_opens, - help="don't attempt to merge open instructions") - parser.add_option("-o", "--output", dest="outfile", - help="specify output file") - parser.add_option("-a", "--asm-output", dest="asmoutfile", - help="asm output file for debugging") - parser.add_option("-g", "--galoissize", dest="galois", - default=defaults.galois, - help="bit length of Galois field") - parser.add_option("-d", "--debug", action="store_true", dest="debug", - help="keep track of trace for debugging") - parser.add_option("-c", "--comparison", dest="comparison", default="log", - help="comparison variant: log|plain|inv|sinv") - parser.add_option("-M", "--preserve-mem-order", action="store_true", - dest="preserve_mem_order", - default=defaults.preserve_mem_order, - help="preserve order of memory instructions; possible efficiency loss") - parser.add_option("-O", "--optimize-hard", action="store_true", - dest="optimize_hard", help="currently not in use") - parser.add_option("-u", "--noreallocate", action="store_true", dest="noreallocate", - default=defaults.noreallocate, help="don't reallocate") - parser.add_option("-m", "--max-parallel-open", dest="max_parallel_open", - default=defaults.max_parallel_open, - help="restrict number of parallel opens") - parser.add_option("-D", "--dead-code-elimination", action="store_true", - dest="dead_code_elimination", - default=defaults.dead_code_elimination, - help="eliminate instructions with unused result") - parser.add_option("-p", "--profile", action="store_true", dest="profile", - help="profile compilation") - parser.add_option("-s", "--stop", action="store_true", dest="stop", - help="stop on register errors") - parser.add_option("-R", "--ring", dest="ring", default=defaults.ring, - help="bit length of ring (default: 0 for field)") - parser.add_option("-B", "--binary", dest="binary", default=defaults.binary, - help="bit length of sint in binary circuit (default: 0 for arithmetic)") - parser.add_option("-F", "--field", dest="field", default=defaults.field, - help="bit length of sint modulo prime (default: 64)") - parser.add_option("-P", "--prime", dest="prime", default=defaults.prime, - help="prime modulus (default: not specified)") - parser.add_option("-I", "--insecure", action="store_true", dest="insecure", - help="activate insecure functionality for benchmarking") - parser.add_option("-b", "--budget", dest="budget", default=defaults.budget, - help="set budget for optimized loop unrolling " - "(default: 100000)") - parser.add_option("-X", "--mixed", action="store_true", dest="mixed", - help="mixing arithmetic and binary computation") - parser.add_option("-Y", "--edabit", action="store_true", dest="edabit", - help="mixing arithmetic and binary computation using edaBits") - parser.add_option("-Z", "--split", default=defaults.split, dest="split", - help="mixing arithmetic and binary computation " - "using direct conversion if supported " - "(number of parties as argument)") - parser.add_option("-C", "--CISC", action="store_true", dest="cisc", - help="faster CISC compilation mode") - parser.add_option("-K", "--keep-cisc", dest="keep_cisc", - help="don't translate CISC instructions") - parser.add_option("-l", "--flow-optimization", action="store_true", - dest="flow_optimization", help="optimize control flow") - parser.add_option("-v", "--verbose", action="store_true", dest="verbose", - help="more verbose output") - options,args = parser.parse_args() - if len(args) < 1: - parser.print_help() - return + if prog.public_input_file is not None: + print( + "WARNING: %s is required to run the program" % prog.public_input_file.name + ) - if options.optimize_hard: - print('Note that -O/--optimize-hard currently has no effect') - def compilation(): - prog = Compiler.run(args, options) - - if prog.public_input_file is not None: - print('WARNING: %s is required to run the program' % \ - prog.public_input_file.name) - - if options.profile: +def main(compiler): + compiler.prep_compile() + if compiler.options.profile: import cProfile - p = cProfile.Profile().runctx('compilation()', globals(), locals()) - p.dump_stats(args[0] + '.prof') + + p = cProfile.Profile().runctx("compilation(compiler)", globals(), locals()) + p.dump_stats(compiler.args[0] + ".prof") p.print_stats(2) else: - compilation() + compilation(compiler) -if __name__ == '__main__': - main() + +if __name__ == "__main__": + compiler = Compiler() + main(compiler)