Files
MP-SPDZ/Compiler/compilerLib.py
Marcel Keller 2813c0ef0f Maintenance.
2023-08-14 18:29:46 +10:00

570 lines
19 KiB
Python

import inspect
import os
import re
import sys
import tempfile
import subprocess
from optparse import OptionParser
from Compiler.exceptions import CompilerError
from .GC import types as GC_types
from .program import Program, defaults
class Compiler:
def __init__(self, custom_args=None, usage=None, execute=False):
if usage:
self.usage = usage
else:
self.usage = "usage: %prog [options] filename [args]"
self.execute = execute
self.custom_args = custom_args
self.build_option_parser()
self.VARS = {}
self.root = os.path.dirname(__file__) + '/..'
def build_option_parser(self):
parser = OptionParser(usage=self.usage)
parser.add_option(
"-n",
"--nomerge",
action="store_false",
dest="merge_opens",
default=defaults.merge_opens,
help="don't attempt to merge open instructions",
)
parser.add_option("-o", "--output", dest="outfile", help="specify output file")
parser.add_option(
"-a",
"--asm-output",
dest="asmoutfile",
help="asm output file for debugging",
)
parser.add_option(
"-g",
"--galoissize",
dest="galois",
default=defaults.galois,
help="bit length of Galois field",
)
parser.add_option(
"-d",
"--debug",
action="store_true",
dest="debug",
help="keep track of trace for debugging",
)
parser.add_option(
"-c",
"--comparison",
dest="comparison",
default="log",
help="comparison variant: log|plain|inv|sinv",
)
parser.add_option(
"-M",
"--preserve-mem-order",
action="store_true",
dest="preserve_mem_order",
default=defaults.preserve_mem_order,
help="preserve order of memory instructions; possible efficiency loss",
)
parser.add_option(
"-O",
"--optimize-hard",
action="store_true",
dest="optimize_hard",
help="lower number of rounds at higher compilation cost "
"(disables -C and increases the budget to 100000)",
)
parser.add_option(
"-u",
"--noreallocate",
action="store_true",
dest="noreallocate",
default=defaults.noreallocate,
help="don't reallocate",
)
parser.add_option(
"-m",
"--max-parallel-open",
dest="max_parallel_open",
default=defaults.max_parallel_open,
help="restrict number of parallel opens",
)
parser.add_option(
"-D",
"--dead-code-elimination",
action="store_true",
dest="dead_code_elimination",
default=defaults.dead_code_elimination,
help="eliminate instructions with unused result",
)
parser.add_option(
"-p",
"--profile",
action="store_true",
dest="profile",
help="profile compilation",
)
parser.add_option(
"-s",
"--stop",
action="store_true",
dest="stop",
help="stop on register errors",
)
parser.add_option(
"-R",
"--ring",
dest="ring",
default=defaults.ring,
help="bit length of ring (default: 0 for field)",
)
parser.add_option(
"-B",
"--binary",
dest="binary",
default=defaults.binary,
help="bit length of sint in binary circuit (default: 0 for arithmetic)",
)
parser.add_option(
"-G",
"--garbled-circuit",
dest="garbled",
action="store_true",
help="compile for binary circuits only (default: false)",
)
parser.add_option(
"-F",
"--field",
dest="field",
default=defaults.field,
help="bit length of sint modulo prime (default: 64)",
)
parser.add_option(
"-P",
"--prime",
dest="prime",
default=defaults.prime,
help="prime modulus (default: not specified)",
)
parser.add_option(
"-I",
"--insecure",
action="store_true",
dest="insecure",
help="activate insecure functionality for benchmarking",
)
parser.add_option(
"-b",
"--budget",
dest="budget",
help="set budget for optimized loop unrolling (default: %d)" % \
defaults.budget,
)
parser.add_option(
"-X",
"--mixed",
action="store_true",
dest="mixed",
help="mixing arithmetic and binary computation",
)
parser.add_option(
"-Y",
"--edabit",
action="store_true",
dest="edabit",
help="mixing arithmetic and binary computation using edaBits",
)
parser.add_option(
"-Z",
"--split",
default=defaults.split,
dest="split",
help="mixing arithmetic and binary computation "
"using direct conversion if supported "
"(number of parties as argument)",
)
parser.add_option(
"--invperm",
action="store_true",
dest="invperm",
help="speedup inverse permutation (only use in two-party, "
"semi-honest environment)"
)
parser.add_option(
"-C",
"--CISC",
action="store_true",
dest="cisc",
help="faster CISC compilation mode "
"(used by default unless -O is given)",
)
parser.add_option(
"-K",
"--keep-cisc",
dest="keep_cisc",
help="don't translate CISC instructions",
)
parser.add_option(
"-l",
"--flow-optimization",
action="store_true",
dest="flow_optimization",
help="optimize control flow",
)
parser.add_option(
"-v",
"--verbose",
action="store_true",
dest="verbose",
help="more verbose output",
)
if self.execute:
parser.add_option(
"-E",
"--execute",
dest="execute",
help="protocol to execute with",
)
parser.add_option(
"-H",
"--hostfile",
dest="hostfile",
help="hosts to execute with",
)
self.parser = parser
def parse_args(self):
self.options, self.args = self.parser.parse_args(self.custom_args)
if self.execute:
if not self.options.execute:
raise CompilerError("must give name of protocol with '-E'")
protocol = self.options.execute
if protocol.find("ring") >= 0 or protocol.find("2k") >= 0 or \
protocol.find("brain") >= 0 or protocol == "emulate":
if not (self.options.ring or self.options.binary):
self.options.ring = "64"
if self.options.field:
raise CompilerError(
"field option not compatible with %s" % protocol)
else:
if protocol.find("bin") >= 0 or protocol.find("ccd") >= 0 or \
protocol.find("bmr") >= 0 or \
protocol in ("replicated", "tinier", "tiny", "yao"):
if not self.options.binary:
self.options.binary = "32"
if self.options.ring or self.options.field:
raise CompilerError(
"ring/field options not compatible with %s" %
protocol)
if self.options.ring:
raise CompilerError(
"ring option not compatible with %s" % protocol)
if protocol == "emulate":
self.options.keep_cisc = ''
def build_program(self, name=None):
self.prog = Program(self.args, self.options, name=name)
if self.execute:
if self.options.execute in \
("emulate", "ring", "rep-field", "rep4-ring"):
self.prog.use_trunc_pr = True
if self.options.execute in ("ring", "ps-rep-ring", "sy-rep-ring"):
self.prog.use_split(3)
if self.options.execute in ("semi2k",):
self.prog.use_split(2)
if self.options.execute in ("rep4-ring",):
self.prog.use_split(4)
def build_vars(self):
from . import comparison, floatingpoint, instructions, library, types
# add all instructions to the program VARS dictionary
instr_classes = [
t[1] for t in inspect.getmembers(instructions, inspect.isclass)
]
for mod in (types, GC_types):
instr_classes += [
t[1]
for t in inspect.getmembers(mod, inspect.isclass)
if t[1].__module__ == mod.__name__
]
instr_classes += [
t[1]
for t in inspect.getmembers(library, inspect.isfunction)
if t[1].__module__ == library.__name__
]
for op in instr_classes:
self.VARS[op.__name__] = op
# backward compatibility for deprecated classes
self.VARS["sbitint"] = GC_types.sbitintvec
self.VARS["sbitfix"] = GC_types.sbitfixvec
# add open and input separately due to name conflict
self.VARS["vopen"] = instructions.vasm_open
self.VARS["gopen"] = instructions.gasm_open
self.VARS["vgopen"] = instructions.vgasm_open
self.VARS["ginput"] = instructions.gasm_input
self.VARS["comparison"] = comparison
self.VARS["floatingpoint"] = floatingpoint
self.VARS["program"] = self.prog
if self.options.binary:
self.VARS["sint"] = GC_types.sbitintvec.get_type(int(self.options.binary))
self.VARS["sfix"] = GC_types.sbitfixvec
for i in [
"cint",
"cfix",
"cgf2n",
"sintbit",
"sgf2n",
"sgf2nint",
"sgf2nuint",
"sgf2nuint32",
"sgf2nfloat",
"cfloat",
"squant",
]:
del self.VARS[i]
def prep_compile(self, name=None, build=True):
self.parse_args()
if len(self.args) < 1 and name is None:
self.parser.print_help()
exit(1)
if build:
self.build(name=name)
def build(self, name=None):
self.build_program(name=name)
self.build_vars()
def compile_file(self):
"""Compile a file and output a Program object.
If options.merge_opens is set to True, will attempt to merge any
parallelisable open instructions."""
print("Compiling file", self.prog.infile)
with open(self.prog.infile, "r") as f:
changed = False
if self.options.flow_optimization:
output = []
if_stack = []
for line in f:
if if_stack and not re.match(if_stack[-1][0], line):
if_stack.pop()
m = re.match(
r"(\s*)for +([a-zA-Z_]+) +in " r"+range\(([0-9a-zA-Z_.]+)\):",
line,
)
if m:
output.append(
"%s@for_range_opt(%s)\n" % (m.group(1), m.group(3))
)
output.append("%sdef _(%s):\n" % (m.group(1), m.group(2)))
changed = True
continue
m = re.match(r"(\s*)if(\W.*):", line)
if m:
if_stack.append((m.group(1), len(output)))
output.append("%s@if_(%s)\n" % (m.group(1), m.group(2)))
output.append("%sdef _():\n" % (m.group(1)))
changed = True
continue
m = re.match(r"(\s*)elif\s+", line)
if m:
raise CompilerError("elif not supported")
if if_stack:
m = re.match("%selse:" % if_stack[-1][0], line)
if m:
start = if_stack[-1][1]
ws = if_stack[-1][0]
output[start] = re.sub(
r"^%s@if_\(" % ws, r"%s@if_e(" % ws, output[start]
)
output.append("%s@else_\n" % ws)
output.append("%sdef _():\n" % ws)
continue
output.append(line)
if changed:
infile = tempfile.NamedTemporaryFile("w+", delete=False)
for line in output:
infile.write(line)
infile.seek(0)
else:
infile = open(self.prog.infile)
else:
infile = open(self.prog.infile)
# make compiler modules directly accessible
sys.path.insert(0, "%s/Compiler" % self.root)
# create the tapes
exec(compile(infile.read(), infile.name, "exec"), self.VARS)
if changed and not self.options.debug:
os.unlink(infile.name)
return self.finalize_compile()
def register_function(self, name=None):
"""
To register a function to be compiled, use this as a decorator.
Example:
@compiler.register_function('test-mpc')
def test_mpc(compiler):
...
"""
def inner(func):
self.compile_name = name or func.__name__
self.compile_function = func
return func
return inner
def compile_func(self):
if not (hasattr(self, "compile_name") and hasattr(self, "compile_func")):
raise CompilerError(
"No function to compile. "
"Did you decorate a function with @register_fuction(name)?"
)
self.prep_compile(self.compile_name)
print(
"Compiling: {} from {}".format(self.compile_name, self.compile_func.__name__)
)
self.compile_function()
self.finalize_compile()
def finalize_compile(self):
self.prog.finalize()
if self.prog.req_num:
print("Program requires at most:")
for x in self.prog.req_num.pretty():
print(x)
if self.prog.verbose:
print("Program requires:", repr(self.prog.req_num))
print("Cost:", 0 if self.prog.req_num is None else self.prog.req_num.cost())
print("Memory size:", dict(self.prog.allocated_mem))
return self.prog
@staticmethod
def executable_from_protocol(protocol):
match = {
"ring": "replicated-ring",
"rep-field": "replicated-field",
"replicated": "replicated-bin"
}
if protocol in match:
protocol = match[protocol]
if protocol.find("bmr") == -1:
protocol = re.sub("^mal-", "malicious-", protocol)
if protocol == "emulate":
return protocol + ".x"
else:
return protocol + "-party.x"
def local_execution(self, args=[]):
executable = self.executable_from_protocol(self.options.execute)
if not os.path.exists("%s/%s" % (self.root, executable)):
print("Creating binary for virtual machine...")
try:
subprocess.run(["make", executable], check=True, cwd=self.root)
except:
raise CompilerError(
"Cannot produce %s. " % executable + \
"Note that compilation requires a few GB of RAM.")
vm = "%s/Scripts/%s.sh" % (self.root, self.options.execute)
sys.stdout.flush()
os.execl(vm, vm, self.prog.name, *args)
def remote_execution(self, args=[]):
vm = self.executable_from_protocol(self.options.execute)
hosts = list(x.strip()
for x in filter(None, open(self.options.hostfile)))
# test availability before compilation
from fabric import Connection
import subprocess
print("Creating static binary for virtual machine...")
subprocess.run(["make", "static/%s" % vm], check=True, cwd=self.root)
# transfer files
import glob
hostnames = []
destinations = []
for host in hosts:
split = host.split('/', maxsplit=1)
hostnames.append(split[0])
if len(split) > 1:
destinations.append(split[1])
else:
destinations.append('.')
connections = [Connection(hostname) for hostname in hostnames]
print("Setting up players...")
def run(i):
dest = destinations[i]
connection = connections[i]
connection.run(
"mkdir -p %s/{Player-Data,Programs/{Bytecode,Schedules}} " % \
dest)
# executable
connection.put("%s/static/%s" % (self.root, vm), dest)
# program
dest += "/"
connection.put("Programs/Schedules/%s.sch" % self.prog.name,
dest + "Programs/Schedules")
for filename in glob.glob(
"Programs/Bytecode/%s-*.bc" % self.prog.name):
connection.put(filename, dest + "Programs/Bytecode")
# inputs
for filename in glob.glob("Player-Data/Input*-P%d-*" % i):
connection.put(filename, dest + "Player-Data")
# key and certificates
for suffix in ('key', 'pem'):
connection.put("Player-Data/P%d.%s" % (i, suffix),
dest + "Player-Data")
for filename in glob.glob("Player-Data/*.0"):
connection.put(filename, dest + "Player-Data")
import threading
import random
threads = []
for i in range(len(hosts)):
threads.append(threading.Thread(target=run, args=(i,)))
for thread in threads:
thread.start()
for thread in threads:
thread.join()
# execution
threads = []
# random port numbers to avoid conflict
port = 10000 + random.randrange(40000)
if '@' in hostnames[0]:
party0 = hostnames[0].split('@')[1]
else:
party0 = hostnames[0]
for i in range(len(connections)):
run = lambda i: connections[i].run(
"cd %s; ./%s -p %d %s -h %s -pn %d %s" % \
(destinations[i], vm, i, self.prog.name, party0, port,
' '.join(args)))
threads.append(threading.Thread(target=run, args=(i,)))
for thread in threads:
thread.start()
for thread in threads:
thread.join()