Files
MP-SPDZ/Compiler/compilerLib.py
2024-11-21 13:14:54 +11:00

690 lines
24 KiB
Python

import inspect
import os
import re
import sys
import tempfile
import subprocess
from optparse import OptionParser
from Compiler.exceptions import CompilerError
from .GC import types as GC_types
from .program import Program, defaults
class Compiler:
singleton = None
def __init__(self, custom_args=None, usage=None, execute=False,
split_args=False):
if Compiler.singleton:
raise CompilerError(
"Cannot have more than one compiler instance. "
"It's not possible to run direct compilation programs with "
"compile.py or compile-run.py.")
else:
Compiler.singleton = self
if usage:
self.usage = usage
else:
self.usage = "usage: %prog [options] filename [args]"
self.execute = execute
self.runtime_args = []
if split_args:
if custom_args is None:
args = sys.argv
else:
args = custom_args
try:
split = args.index('--')
except ValueError:
split = len(args)
custom_args = args[1:split]
self.runtime_args = args[split + 1:]
self.custom_args = custom_args
self.build_option_parser()
self.VARS = {}
self.root = os.path.dirname(__file__) + '/..'
def build_option_parser(self):
parser = OptionParser(usage=self.usage)
parser.add_option(
"-n",
"--nomerge",
action="store_false",
dest="merge_opens",
default=defaults.merge_opens,
help="don't attempt to merge open instructions",
)
parser.add_option("-o", "--output", dest="outfile", help="specify output file")
parser.add_option(
"-a",
"--asm-output",
dest="asmoutfile",
help="asm output file for debugging",
)
parser.add_option(
"-g",
"--galoissize",
dest="galois",
default=defaults.galois,
help="bit length of Galois field",
)
parser.add_option(
"-d",
"--debug",
action="store_true",
dest="debug",
help="keep track of trace for debugging",
)
parser.add_option(
"-c",
"--comparison",
dest="comparison",
default="log",
help="comparison variant: log|plain|inv|sinv",
)
parser.add_option(
"-M",
"--preserve-mem-order",
action="store_true",
dest="preserve_mem_order",
default=defaults.preserve_mem_order,
help="preserve order of memory instructions; possible efficiency loss",
)
parser.add_option(
"-O",
"--optimize-hard",
action="store_true",
dest="optimize_hard",
help="lower number of rounds at higher compilation cost "
"(disables -C and increases the budget to 100000)",
)
parser.add_option(
"-u",
"--noreallocate",
action="store_true",
dest="noreallocate",
default=defaults.noreallocate,
help="don't reallocate",
)
parser.add_option(
"-m",
"--max-parallel-open",
dest="max_parallel_open",
default=defaults.max_parallel_open,
help="restrict number of parallel opens",
)
parser.add_option(
"-D",
"--dead-code-elimination",
action="store_true",
dest="dead_code_elimination",
default=defaults.dead_code_elimination,
help="eliminate instructions with unused result",
)
parser.add_option(
"-p",
"--profile",
action="store_true",
dest="profile",
help="profile compilation",
)
parser.add_option(
"-s",
"--stop",
action="store_true",
dest="stop",
help="stop on register errors",
)
parser.add_option(
"-R",
"--ring",
dest="ring",
default=defaults.ring,
help="bit length of ring (default: 0 for field)",
)
parser.add_option(
"-B",
"--binary",
dest="binary",
default=defaults.binary,
help="bit length of sint in binary circuit (default: 0 for arithmetic)",
)
parser.add_option(
"-G",
"--garbled-circuit",
dest="garbled",
action="store_true",
help="compile for binary circuits only (default: false)",
)
parser.add_option(
"-F",
"--field",
dest="field",
default=defaults.field,
help="bit length of sint modulo prime (default: 64)",
)
parser.add_option(
"-P",
"--prime",
dest="prime",
default=defaults.prime,
help="use bit decomposition with a specifed prime modulus "
"for non-linear computation (default: use the masking approach). "
"Don't use this unless you're certain that you need it.",
)
parser.add_option(
"-I",
"--insecure",
action="store_true",
dest="insecure",
help="activate insecure functionality for benchmarking",
)
parser.add_option(
"-b",
"--budget",
dest="budget",
help="set budget for optimized loop unrolling (default: %d)" % \
defaults.budget,
)
parser.add_option(
"-X",
"--mixed",
action="store_true",
dest="mixed",
help="mixing arithmetic and binary computation",
)
parser.add_option(
"-Y",
"--edabit",
action="store_true",
dest="edabit",
help="mixing arithmetic and binary computation using edaBits",
)
parser.add_option(
"-Z",
"--split",
default=defaults.split,
dest="split",
help="mixing arithmetic and binary computation "
"using direct conversion if supported "
"(number of parties as argument)",
)
parser.add_option(
"--invperm",
action="store_true",
dest="invperm",
help="speedup inverse permutation (only use in two-party, "
"semi-honest environment)"
)
parser.add_option(
"-C",
"--CISC",
action="store_true",
dest="cisc",
help="faster CISC compilation mode "
"(used by default unless -O is given)",
)
parser.add_option(
"-K",
"--keep-cisc",
dest="keep_cisc",
help="don't translate CISC instructions",
)
parser.add_option(
"-l",
"--flow-optimization",
action="store_true",
dest="flow_optimization",
help="optimize control flow",
)
parser.add_option(
"-v",
"--verbose",
action="store_true",
dest="verbose",
help="more verbose output",
)
if self.execute:
parser.add_option(
"-E",
"--execute",
dest="execute",
help="protocol to execute with",
)
parser.add_option(
"-H",
"--hostfile",
dest="hostfile",
help="hosts to execute with",
)
parser.add_option(
"-t",
"--tidy_output",
action="store_true",
dest="tidy_output",
help="make output prints tidy and grouped by party (note: delays the prints)",
)
else:
parser.add_option(
"-E",
"--execute",
dest="execute",
help="protocol to optimize for",
)
self.parser = parser
def parse_args(self):
self.options, self.args = self.parser.parse_args(self.custom_args)
if self.options.verbose:
self.runtime_args += ["--verbose"]
if self.options.execute:
self.options.execute = re.sub("-party.x$", "", self.options.execute)
for s, l in self.match.items():
if self.options.execute == l:
self.options.execute = s
break
if self.execute:
if not self.options.execute:
if len(self.args) > 1:
self.options.execute = self.args.pop(0)
else:
self.parser.error("missing protocol name")
if self.options.hostfile:
try:
open(self.options.hostfile)
except:
print('hostfile %s not found' % self.options.hostfile,
file=sys.stderr)
exit(1)
if self.options.execute:
protocol = self.options.execute
if protocol.find("ring") >= 0 or protocol.find("2k") >= 0 or \
protocol.find("brain") >= 0 or protocol == "emulate":
if not (self.options.ring or self.options.binary):
self.options.ring = "64"
if self.options.field:
raise CompilerError(
"field option not compatible with %s" % protocol)
else:
if protocol.find("bin") >= 0 or protocol.find("ccd") >= 0 or \
protocol.find("bmr") >= 0 or \
protocol in ("replicated", "tinier", "tiny", "yao"):
if not self.options.binary:
self.options.binary = "32"
if self.options.ring or self.options.field:
raise CompilerError(
"ring/field options not compatible with %s" %
protocol)
if self.options.ring:
raise CompilerError(
"ring option not compatible with %s" % protocol)
if protocol == "emulate":
self.options.keep_cisc = ''
def build_program(self, name=None):
self.prog = Program(self.args, self.options, name=name)
if self.options.execute:
if self.options.execute in \
("emulate", "ring", "rep-field", "rep4-ring"):
self.prog.use_trunc_pr = True
if self.options.execute in ("ring", "ps-rep-ring", "sy-rep-ring"):
self.prog.use_split(3)
if self.options.execute in ("semi2k",):
self.prog.use_split(int(os.getenv("PLAYERS", 2)))
if self.options.execute in ("rep4-ring",):
self.prog.use_split(4)
if self.options.execute.find("dealer") >= 0:
self.prog.use_edabit(True)
def build_vars(self):
from . import comparison, floatingpoint, instructions, library, types
# add all instructions to the program VARS dictionary
instr_classes = [
t[1] for t in inspect.getmembers(instructions, inspect.isclass)
]
for mod in (types, GC_types):
instr_classes += [
t[1]
for t in inspect.getmembers(mod, inspect.isclass)
if t[1].__module__ == mod.__name__
]
instr_classes += [
t[1]
for t in inspect.getmembers(library, inspect.isfunction)
if t[1].__module__ == library.__name__
]
for op in instr_classes:
self.VARS[op.__name__] = op
# backward compatibility for deprecated classes
self.VARS["sbitint"] = GC_types.sbitintvec
self.VARS["sbitfix"] = GC_types.sbitfixvec
# add open and input separately due to name conflict
self.VARS["vopen"] = instructions.vasm_open
self.VARS["gopen"] = instructions.gasm_open
self.VARS["vgopen"] = instructions.vgasm_open
self.VARS["ginput"] = instructions.gasm_input
self.VARS["comparison"] = comparison
self.VARS["floatingpoint"] = floatingpoint
self.VARS["program"] = self.prog
if self.options.binary:
self.sint = GC_types.sbitintvec.get_type(int(self.options.binary))
self.sfix = GC_types.sbitfixvec
for i in [
"cint",
"cfix",
"cgf2n",
"sintbit",
"sgf2n",
"sgf2nint",
"sgf2nuint",
"sgf2nuint32",
"sgf2nfloat",
"cfloat",
"squant",
]:
class dummy:
def __init__(self, *args):
raise CompilerError(self.error)
dummy.error = i + " not availabe with binary circuits"
if i in ("cint", "cfix"):
dummy.error += ". See https://mp-spdz.readthedocs.io/en/" \
"latest/Compiler.html#Compiler.types." + i
self.VARS[i] = dummy
else:
self.sint = types.sint
self.sfix = types.sfix
self.VARS["sint"] = self.sint
self.VARS["sfix"] = self.sfix
def prep_compile(self, name=None, build=True):
self.parse_args()
if len(self.args) < 1 and name is None:
self.parser.print_help()
exit(1)
if build:
self.build(name=name)
def build(self, name=None):
self.build_program(name=name)
self.build_vars()
def compile_file(self):
"""Compile a file and output a Program object.
If options.merge_opens is set to True, will attempt to merge any
parallelisable open instructions."""
print("Compiling file", self.prog.infile)
self.prog.sint = self.sint
self.prog.sfix = self.sfix
with open(self.prog.infile, "r") as f:
changed = False
if self.options.flow_optimization:
output = []
if_stack = []
for line in f:
if if_stack and not re.match(if_stack[-1][0], line):
if_stack.pop()
m = re.match(
r"(\s*)for +([a-zA-Z_]+) +in " r"+range\(([0-9a-zA-Z_.]+)\):",
line,
)
if m:
output.append(
"%s@for_range_opt(%s)\n" % (m.group(1), m.group(3))
)
output.append("%sdef _(%s):\n" % (m.group(1), m.group(2)))
changed = True
continue
m = re.match(r"(\s*)if(\W.*):", line)
if m:
while if_stack and if_stack[-1][0] == m.group(1):
if_stack.pop()
if_stack.append((m.group(1), len(output)))
output.append("%s@if_(%s)\n" % (m.group(1), m.group(2)))
output.append("%sdef _():\n" % (m.group(1)))
changed = True
continue
m = re.match(r"(\s*)elif\s+", line)
if m:
raise CompilerError("elif not supported")
if if_stack:
m = re.match("%selse:" % if_stack[-1][0], line)
if m:
start = if_stack[-1][1]
ws = if_stack[-1][0]
output[start] = re.sub(
r"^%s@if_\(" % ws, r"%s@if_e(" % ws, output[start]
)
output.append("%s@else_\n" % ws)
output.append("%sdef _():\n" % ws)
continue
output.append(line)
if changed:
infile = tempfile.NamedTemporaryFile("w+", delete=False)
for line in output:
infile.write(line)
infile.seek(0)
else:
infile = open(self.prog.infile)
else:
infile = open(self.prog.infile)
# make compiler modules directly accessible
sys.path.insert(0, "%s/Compiler" % self.root)
# create the tapes
exec(compile(infile.read(), infile.name, "exec"), self.VARS)
if changed and not self.options.debug:
os.unlink(infile.name)
return self.finalize_compile()
def register_function(self, name=None):
"""
To register a function to be compiled, use this as a decorator.
Example:
@compiler.register_function('test-mpc')
def test_mpc(compiler):
...
"""
def inner(func):
self.compile_name = name or func.__name__
self.compile_function = func
return func
return inner
def compile_func(self):
if not (hasattr(self, "compile_name") and hasattr(self, "compile_func")):
raise CompilerError(
"No function to compile. "
"Did you decorate a function with @register_fuction(name)?"
)
self.prep_compile(self.compile_name)
print(
"Compiling: {} from {}".format(self.compile_name, self.compile_func.__name__)
)
self.compile_function()
self.finalize_compile()
def finalize_compile(self):
self.prog.finalize()
if self.prog.req_num:
print("Program requires at most:")
for x in self.prog.req_num.pretty():
print(x)
if self.prog.verbose:
print("Program requires:", repr(self.prog.req_num))
print("Cost:", 0 if self.prog.req_num is None else self.prog.req_num.cost())
print("Memory size:", dict(self.prog.allocated_mem))
return self.prog
match = {
"ring": "replicated-ring",
"rep-field": "replicated-field",
"replicated": "replicated-bin"
}
@classmethod
def executable_from_protocol(cls, protocol):
match = cls.match
if protocol in match:
protocol = match[protocol]
if protocol.find("bmr") == -1:
protocol = re.sub("^mal-", "malicious-", protocol)
if protocol == "emulate":
return protocol + ".x"
else:
return protocol + "-party.x"
def local_execution(self, args=None):
if args is None:
args = self.runtime_args
executable = self.executable_from_protocol(self.options.execute)
if not os.path.exists("%s/%s" % (self.root, executable)):
print("Creating binary for virtual machine...")
try:
subprocess.run(["make", executable], check=True, cwd=self.root)
except:
raise CompilerError(
"Cannot produce %s. " % executable + \
"Note that compilation requires a few GB of RAM.")
vm = "%s/Scripts/%s.sh" % (self.root, self.options.execute)
sys.stdout.flush()
print("Compilation finished, running program...", file=sys.stderr)
sys.stderr.flush()
os.execl(vm, vm, self.prog.name, *args)
def remote_execution(self, args=None):
if args is None:
args = self.runtime_args
vm = self.executable_from_protocol(self.options.execute)
hosts = list(x.strip()
for x in filter(None, open(self.options.hostfile)))
# test availability before compilation
from fabric import Connection
import subprocess
print("Creating static binary for virtual machine...")
subprocess.run(["make", "static/%s" % vm], check=True, cwd=self.root)
# transfer files
import glob
hostnames = []
destinations = []
for host in hosts:
split = host.split('/', maxsplit=1)
hostnames.append(split[0])
if len(split) > 1:
destinations.append(split[1])
else:
destinations.append('.')
connections = [Connection(hostname) for hostname in hostnames]
print("Setting up players...")
def run(i):
dest = destinations[i]
connection = connections[i]
connection.run(
"mkdir -p %s/{Player-Data,Programs/{Bytecode,Schedules}} " % \
dest)
# executable
connection.put("%s/static/%s" % (self.root, vm), dest)
# program
dest += "/"
connection.put("Programs/Schedules/%s.sch" % self.prog.name,
dest + "Programs/Schedules")
for filename in glob.glob(
"Programs/Bytecode/%s-*.bc" % self.prog.name):
connection.put(filename, dest + "Programs/Bytecode")
# inputs
for filename in glob.glob("Player-Data/Input*-P%d-*" % i):
connection.put(filename, dest + "Player-Data")
# key and certificates
for suffix in ('key', 'pem'):
connection.put("Player-Data/P%d.%s" % (i, suffix),
dest + "Player-Data")
for filename in glob.glob("Player-Data/*.0"):
connection.put(filename, dest + "Player-Data")
def run_with_error(i):
try:
run(i)
except IOError:
print('IO error when copying files, does %s have enough space?' %
hostnames[i])
raise
import threading
import random
import io
def run_and_capture_outputs(outputs, fn, i):
out = fn(i)
outputs[i] = out
threads = []
for i in range(len(hosts)):
threads.append(threading.Thread(target=run_with_error, args=(i,)))
for thread in threads:
thread.start()
for thread in threads:
thread.join()
# execution
threads = []
# tidy up output prints
hide_option = False
if self.options.tidy_output:
outputs = []
for i in range(len(connections)):
outputs += [""]
hide_option = True
# random port numbers to avoid conflict
port = 10000 + random.randrange(40000)
if '@' in hostnames[0]:
party0 = hostnames[0].split('@')[1]
else:
party0 = hostnames[0]
if 'rep' not in vm and 'yao' not in vm:
N = ['-N', str(len(connections))]
else:
N = []
for i in range(len(connections)):
run = lambda i: connections[i].run(
"cd %s; ./%s -p %d %s -h %s -pn %d %s" % \
(destinations[i], vm, i, self.prog.name, party0, port,
' '.join(args + N)), hide=hide_option)
if self.options.tidy_output:
threads.append(threading.Thread(target=run_and_capture_outputs, args=(outputs, run, i,)))
else:
threads.append(threading.Thread(target=run, args=(i,)))
for thread in threads:
thread.start()
for thread in threads:
thread.join()
if self.options.tidy_output:
for out in outputs:
print(out)