Files
circ/aws_benchmark/aws_controller.py
Edward Chen 6805789405 update
2022-07-31 01:02:05 -04:00

271 lines
9.7 KiB
Python

import boto3
import json
import os
import multiprocessing
import paramiko
import subprocess
import sys
import time
# instance_type = "t2.micro"
instance_type = "c5.large"
ec2_resource = boto3.resource("ec2",
aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
region_name="us-east-2")
def create_instances(num):
instances = ec2_resource.create_instances(ImageId="ami-05b63781e32145c7f",
InstanceType=instance_type,
KeyName="the-key-to-her-heart",
MinCount=1,
MaxCount=num,
Monitoring={"Enabled": False},
SecurityGroups=[
"circ4mpc"]
)
print("Created {} instances".format(num))
def start_instances(num):
stopped_instances = list(ec2_resource.instances.filter(
Filters=[{"Name": "instance-state-name", "Values": ["stopped"]}]))
count = 0
num = min(num, len(stopped_instances))
for i in range(num):
instance = stopped_instances[i]
ec2_resource.instances.filter(InstanceIds=[instance.id]).start()
count += 1
print("Started {} instances".format(count))
def stop_instances(num):
running_instances = list(ec2_resource.instances.filter(
Filters=[{"Name": "instance-state-name", "Values": ["running"]}]))
count = 0
num = min(num, len(running_instances))
for i in range(num):
instance = running_instances[i]
print("Stopping", instance.public_dns_name)
ec2_resource.instances.filter(InstanceIds=[instance.id]).stop()
count += 1
print("Stopped {} instances".format(count))
def terminate_instances(num):
instances = list(ec2_resource.instances.filter(
Filters=[{"Name": "instance-state-name", "Values": ["running", "stopped"]}]))
count = 0
num = min(num, len(instances))
for i in range(num):
instance = instances[i]
count += 1
ec2_resource.instances.filter(InstanceIds=[instance.id]).terminate()
print("Terminated {} instances".format(count))
def stats():
stats = {}
stats["total"] = len(list(ec2_resource.instances.filter(Filters=[
{"Name": "instance-state-name", "Values": ["running", "stopped", "pending", "stopping"]}])))
stats["running"] = len(list(ec2_resource.instances.filter(
Filters=[{"Name": "instance-state-name", "Values": ["running"]}])))
stats["stopped"] = len(list(ec2_resource.instances.filter(
Filters=[{"Name": "instance-state-name", "Values": ["stopped"]}])))
stats["pending"] = len(list(ec2_resource.instances.filter(
Filters=[{"Name": "instance-state-name", "Values": ["pending"]}])))
stats["stopping"] = len(list(ec2_resource.instances.filter(
Filters=[{"Name": "instance-state-name", "Values": ["stopping"]}])))
print(json.dumps(stats, indent=4))
def hosts():
running_instances = list(ec2_resource.instances.filter(
Filters=[{"Name": "instance-state-name", "Values": ["running"]}]))
running_instance_ips = [
instance.public_dns_name for instance in running_instances]
for ip in running_instance_ips:
print("ssh -i \"the-key-to-her-heart.pem\" ubuntu@{}".format(ip))
def setup_instances(num):
running_instances = list(ec2_resource.instances.filter(
Filters=[{"Name": "instance-state-name", "Values": ["running"]}]))
if len(running_instances) < num:
print("Not all instances are up yet!")
return
running_instance_ips = [
instance.public_dns_name for instance in running_instances]
pool = multiprocessing.Pool(len(running_instance_ips))
pool.map(setup_worker, running_instance_ips)
def setup_worker(ip):
print("Setting up", ip)
key = paramiko.RSAKey.from_private_key_file("the-key-to-her-heart.pem")
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect(hostname=ip, username="ubuntu", pkey=key)
_, stdout, _ = client.exec_command("cd ~/circ")
if stdout.channel.recv_exit_status():
_, stdout, _ = client.exec_command(
"cd ~ && git clone https://github.com/circify/circ.git && cd ~/circ && git checkout mpc_aws && python3 driver.py -F aby c lp bench && python3 driver.py -c && cd ~ && ./circ/aws_benchmark/setup.sh")
if stdout.channel.recv_exit_status():
print(ip, " failed setup")
else:
_, stdout, _ = client.exec_command(
"cd ~ && ./circ/aws_benchmark/setup.sh")
if stdout.channel.recv_exit_status():
print(ip, " failed setup 2")
print("Set up:", ip)
client.close()
def compile_benchmarks(num):
# TODO
pass
def run_benchmarks(num):
assert(num == 2)
running_instances = list(ec2_resource.instances.filter(
Filters=[{"Name": "instance-state-name", "Values": ["running"]}]))
if len(running_instances) < num:
print("Not all instances are up yet!")
return
running_instance_ips = [
instance.public_dns_name for instance in running_instances]
running_instance_private_ips = [
running_instances[0].private_ip_address for _ in running_instances]
roles = [0, 1]
pool = multiprocessing.Pool(len(running_instance_ips))
pool.starmap(benchmark_worker, zip(running_instance_ips,
running_instance_private_ips, roles))
def benchmark_worker(ip, connect_ip, role):
print("Running benchmark:\nip: {}\nconnect: {}\nrole: {}\n".format(
ip, connect_ip, role))
key = paramiko.RSAKey.from_private_key_file("the-key-to-her-heart.pem")
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect(hostname=ip, username="ubuntu", pkey=key)
_, stdout, _ = client.exec_command(
"cd ~ && ./circ/aws_benchmark/benchmark.sh {} {} > benchmark.log".format(connect_ip, role))
if stdout.channel.recv_exit_status():
print(ip, " failed running benchmark")
client.close()
def refresh_instances():
running_instances = list(ec2_resource.instances.filter(
Filters=[{"Name": "instance-state-name", "Values": ["running"]}]))
running_instance_ips = [
instance.public_dns_name for instance in running_instances]
pool = multiprocessing.Pool(len(running_instance_ips))
pool.map(refresh_worker, running_instance_ips)
def refresh_worker(ip):
print("Refreshing: {}".format(ip))
key = paramiko.RSAKey.from_private_key_file("the-key-to-her-heart.pem")
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect(hostname=ip, username="ubuntu", pkey=key)
_, stdout, _ = client.exec_command(
"cd ~/ABY && git pull && cd ~/ABY/build && make && cd ~/circ && git pull && cd ~/HyCC && git pull && make")
if stdout.channel.recv_exit_status():
print(ip, " failed to refresh")
client.close()
def logs():
running_instances = list(ec2_resource.instances.filter(
Filters=[{"Name": "instance-state-name", "Values": ["running"]}]))
running_instance_ips = [
instance.public_dns_name for instance in running_instances]
for dns_name in running_instance_ips:
if not os.path.exists("./logs/"):
os.mkdir("./logs/")
if not os.path.exists("./logs/"+dns_name):
os.mkdir("./logs/"+dns_name)
subprocess.call("scp -o StrictHostKeyChecking=no -i the-key-to-her-heart.pem ubuntu@" +
dns_name+":~/*.log ./logs/"+dns_name, shell=True)
key = paramiko.RSAKey.from_private_key_file("the-key-to-her-heart.pem")
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect(hostname=dns_name, username="ubuntu", pkey=key)
_, stdout, _ = client.exec_command(
"rm -rf ./*.log")
if stdout.channel.recv_exit_status():
print(dns_name, " failed to remove logs.")
client.close()
if __name__ == "__main__":
last_cmd = ""
while True:
cmds = input("> ").split(" ")
cmd_type = cmds[0]
# press enter to redo
if cmd_type == "" and last_cmd != "":
cmd_type = last_cmd
else:
last_cmd = cmd_type
if cmd_type == "help":
print("Not again... oh well here you go\n")
print("EC2: \tcreate start stop terminate stats")
print("Setup: \tsetup refresh")
print("Build: \tbuild")
print("Run: \trun")
print("Logs: \tlogs")
print("Misc: \tstats hosts")
print("Quit: \tquit q")
elif cmd_type == "create":
create_instances(2)
elif cmd_type == "start":
start_instances(2)
elif cmd_type == "setup":
print("=== will stop instances after setup ===")
setup_instances(2)
stop_instances(2)
elif cmd_type == "run":
run_benchmarks(2)
elif cmd_type == "stop":
stop_instances(2)
elif cmd_type == "terminate":
terminate_instances(2)
elif cmd_type == "stats":
stats()
elif cmd_type == "hosts":
hosts()
elif cmd_type == "refresh":
refresh_instances()
elif cmd_type == "logs":
logs()
elif cmd_type in ["quit", "q", "exit"]:
sys.exit(0)
else:
print("unlucky, not a cmd")