new cloud is cloudy [pr] (#7631)

* new cloud is cloudy [pr]

* waste lines to add security

* safety, with speed and less lines

* timing and del

* lines

* cleanups

* restore CloudSession

* bump to 3.10

* quotes

* renderer security
This commit is contained in:
George Hotz
2024-11-11 20:18:04 +08:00
committed by GitHub
parent 766a680588
commit d40673505f
4 changed files with 140 additions and 89 deletions

View File

@@ -209,15 +209,15 @@ jobs:
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Set up Python 3.8
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: 3.8
python-version: "3.10"
- name: Cache python packages
uses: actions/cache@v4
with:
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.8/site-packages
key: linting-packages-${{ hashFiles('**/setup.py') }}-3.8
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.10/site-packages
key: linting-packages-${{ hashFiles('**/setup.py') }}-3.10
- name: Install dependencies
run: pip install -e '.[linting,testing]' --extra-index-url https://download.pytorch.org/whl/cpu
- name: Lint bad-indentation and trailing-whitespace with pylint

View File

@@ -22,7 +22,7 @@ setup(name='tinygrad',
"License :: OSI Approved :: MIT License"
],
install_requires=[],
python_requires='>=3.8',
python_requires='>=3.10',
extras_require={
'llvm': ["llvmlite"],
'arm': ["unicorn"],

View File

@@ -1,8 +1,8 @@
FROM ubuntu:20.04
FROM ubuntu:22.04
# Install python3.8, and pip3
# Install python3.10, and pip3
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.8 \
python3.10 \
python3-pip \
&& rm -rf /var/lib/apt/lists/*

View File

@@ -5,14 +5,70 @@
# it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC
from __future__ import annotations
from typing import Tuple, Optional, Dict, Any, DefaultDict
from typing import Tuple, Optional, Dict, Any, DefaultDict, List
from collections import defaultdict
import multiprocessing, functools, http.client, hashlib, json, time, contextlib, os, binascii
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, prod
from tinygrad.device import Compiled, Allocator, Compiler, Device, BufferOptions
import multiprocessing, functools, http.client, hashlib, json, time, os, binascii, struct, ast, contextlib
from http.server import HTTPServer, BaseHTTPRequestHandler
from tinygrad.renderer import Renderer
from tinygrad.dtype import dtypes
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing
from tinygrad.device import Compiled, Allocator, Compiler, Device, BufferOptions
# ***** API *****
# Every command below is an immutable record. The client ships a list of them as repr() text,
# so the field order and names here are the wire format.
class CloudRequest:
  pass

# allocate `size` bytes on the backend under the client-chosen id `buffer_num`
@dataclass(frozen=True)
class BufferAlloc(CloudRequest):
  buffer_num: int
  size: int
  options: BufferOptions

# release the backend buffer registered under `buffer_num`
@dataclass(frozen=True)
class BufferFree(CloudRequest):
  buffer_num: int

# fill a buffer with the payload that travels in the same batch under `datahash`
@dataclass(frozen=True)
class CopyIn(CloudRequest):
  buffer_num: int
  datahash: str

# read a buffer back; its bytes become the body of the batch response
@dataclass(frozen=True)
class CopyOut(CloudRequest):
  buffer_num: int

# compile the source stored under `datahash` and register it as (name, datahash)
@dataclass(frozen=True)
class ProgramAlloc(CloudRequest):
  name: str
  datahash: str

# drop the program registered as (name, datahash)
@dataclass(frozen=True)
class ProgramFree(CloudRequest):
  name: str
  datahash: str

# launch a registered program; `bufs` are buffer ids, the launch sizes are only forwarded when not None
@dataclass(frozen=True)
class ProgramExec(CloudRequest):
  name: str
  datahash: str
  bufs: Tuple[int, ...]
  vals: Tuple[int, ...]
  global_size: Optional[Tuple[int, ...]]
  local_size: Optional[Tuple[int, ...]]
  wait: bool
# for safe deserialization
# the only callables a serialized request is allowed to name
whitelist = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, BufferOptions]}
def _eval_negation(node):
  # repr() writes a negative number as "-1", which ast parses as UnaryOp(USub, Constant(1)) rather than a bare Constant.
  # Only numeric negation is accepted so no other operator becomes reachable from untrusted input.
  operand = safe_eval(node.operand)
  if not isinstance(node.op, ast.USub) or isinstance(operand, bool) or not isinstance(operand, (int, float)):
    raise ValueError(f"unsupported unary expression {ast.dump(node)}")
  return -operand
# one evaluator per permitted AST node type; any other node type is rejected with a KeyError in safe_eval
eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
  ast.Call: lambda x: safe_eval(x.func)(*[safe_eval(arg) for arg in x.args], **{kwarg.arg: safe_eval(kwarg.value) for kwarg in x.keywords}),
  ast.Name: lambda x: whitelist[x.id], ast.Attribute: lambda x: {"imagef": dtypes.imagef, "imageh": dtypes.imageh}[x.attr],
  ast.UnaryOp: _eval_negation}
def safe_eval(node):
  """Evaluate a parsed expression node using only the constructs listed in eval_fxns.

  Raises KeyError for a node type, name or attribute that is not whitelisted, and ValueError for a
  unary expression other than the negation of a number.
  """
  return eval_fxns[node.__class__](node)
class BatchRequest:
def __init__(self):
self._q: List[CloudRequest] = []
self._h: Dict[str, bytes] = {}
def h(self, d:bytes) -> str:
binhash = hashlib.sha256(d).digest()
self._h[datahash:=binascii.hexlify(binhash).decode()] = binhash+struct.pack("<Q", len(d))+d
return datahash
def q(self, x:CloudRequest): self._q.append(x)
def serialize(self) -> bytes:
self.h(repr(self._q).encode())
return b''.join(self._h.values())
def deserialize(self, dat:bytes) -> BatchRequest:
ptr = 0
while ptr < len(dat):
datahash, datalen = binascii.hexlify(dat[ptr:ptr+0x20]).decode(), struct.unpack("<Q", dat[ptr+0x20:ptr+0x28])[0]
self._h[datahash] = dat[ptr+0x28:ptr+0x28+datalen]
ptr += 0x28+datalen
self._q = safe_eval(ast.parse(self._h[datahash], mode="eval").body)
return self
# ***** backend *****
@@ -21,7 +77,6 @@ class CloudSession:
programs: Dict[Tuple[str, str], Any] = field(default_factory=dict)
# TODO: the buffer should track this internally
buffers: Dict[int, Tuple[Any, int, Optional[BufferOptions]]] = field(default_factory=dict)
buffer_num = 0
class CloudHandler(BaseHTTPRequestHandler):
protocol_version = 'HTTP/1.1'
@@ -32,66 +87,47 @@ class CloudHandler(BaseHTTPRequestHandler):
super().setup()
print(f"connection established with {self.client_address}, socket: {self.connection.fileno()}")
def get_data(self):
content_len = self.headers.get('Content-Length')
assert content_len is not None
return self.rfile.read(int(content_len))
def get_json(self): return json.loads(self.get_data())
def _fail(self):
self.send_response(404)
self.end_headers()
return 0
def _do(self, method):
session = CloudHandler.sessions[unwrap(self.headers.get("Cookie")).split("session=")[1]]
ret = b""
if self.path == "/renderer" and method == "GET":
ret, status_code = b"", 200
if self.path == "/batch" and method == "POST":
# TODO: streaming deserialize?
req = BatchRequest().deserialize(self.rfile.read(int(unwrap(self.headers.get('Content-Length')))))
# the cmds are always last (currently in datahash)
for c in req._q:
if DEBUG >= 1: print(c)
match c:
case BufferAlloc():
assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
session.buffers[c.buffer_num] = (Device[CloudHandler.dname].allocator.alloc(c.size, c.options), c.size, c.options)
case BufferFree():
buf,sz,buffer_options = session.buffers[c.buffer_num]
Device[CloudHandler.dname].allocator.free(buf,sz,buffer_options)
del session.buffers[c.buffer_num]
case CopyIn(): Device[CloudHandler.dname].allocator.copyin(session.buffers[c.buffer_num][0], memoryview(bytearray(req._h[c.datahash])))
case CopyOut():
buf,sz,_ = session.buffers[c.buffer_num]
Device[CloudHandler.dname].allocator.copyout(memoryview(ret:=bytearray(sz)), buf)
case ProgramAlloc():
lib = Device[CloudHandler.dname].compiler.compile_cached(req._h[c.datahash].decode())
session.programs[(c.name, c.datahash)] = Device[CloudHandler.dname].runtime(c.name, lib)
case ProgramFree(): del session.programs[(c.name, c.datahash)]
case ProgramExec():
bufs = [session.buffers[x][0] for x in c.bufs]
extra_args = {k:v for k,v in [("global_size", c.global_size), ("local_size", c.local_size)] if v is not None}
r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args)
if r is not None: ret = str(r).encode()
elif self.path == "/renderer" and method == "GET":
cls, args = Device[CloudHandler.dname].renderer.__reduce__()
ret = json.dumps((cls.__module__, cls.__name__, args)).encode()
elif self.path.startswith("/alloc") and method == "POST":
size = int(self.path.split("=")[-1])
buffer_options: Optional[BufferOptions] = None
if 'image' in self.path:
image_shape = tuple([int(x) for x in self.path.split("=")[-2].split("&")[0].split(",")])
buffer_options = BufferOptions(image=dtypes.imageh(image_shape) if prod(image_shape)*2 == size else dtypes.imagef(image_shape))
session.buffer_num += 1
session.buffers[session.buffer_num] = (Device[CloudHandler.dname].allocator.alloc(size, buffer_options), size, buffer_options)
ret = str(session.buffer_num).encode()
elif self.path.startswith("/buffer"):
key = int(self.path.split("/")[-1])
buf,sz,buffer_options = session.buffers[key]
if method == "GET": Device[CloudHandler.dname].allocator.copyout(memoryview(ret:=bytearray(sz)), buf)
elif method == "PUT": Device[CloudHandler.dname].allocator.copyin(buf, memoryview(bytearray(self.get_data())))
elif method == "DELETE":
Device[CloudHandler.dname].allocator.free(buf,sz,buffer_options)
del session.buffers[key]
else: return self._fail()
elif self.path.startswith("/program"):
name, hsh = self.path.split("/")[-2:]
if method == "PUT":
src = self.get_data()
assert hashlib.sha256(src).hexdigest() == hsh
lib = Device[CloudHandler.dname].compiler.compile_cached(src.decode())
session.programs[(name, hsh)] = Device[CloudHandler.dname].runtime(name, lib)
elif method == "POST":
j = self.get_json()
bufs = [session.buffers[x][0] for x in j['bufs']]
del j['bufs']
r = session.programs[(name, hsh)](*bufs, **j)
if r is not None: ret = str(r).encode()
elif method == "DELETE": del session.programs[(name, hsh)]
else: return self._fail()
else: return self._fail()
self.send_response(200)
else: status_code = 404
self.send_response(status_code)
self.send_header('Content-Length', str(len(ret)))
self.end_headers()
return self.wfile.write(ret)
def do_GET(self): return self._do("GET")
def do_POST(self): return self._do("POST")
def do_PUT(self): return self._do("PUT")
def do_DELETE(self): return self._do("DELETE")
def cloud_server(port:int):
multiprocessing.current_process().name = "MainProcess"
@@ -106,44 +142,46 @@ class CloudAllocator(Allocator):
def __init__(self, device:CloudDevice):
self.device = device
super().__init__()
def _alloc(self, size:int, options) -> int:
# TODO: ideally we shouldn't have to deal with images here
extra = ("image="+','.join([str(x) for x in options.image.shape])+"&") if options.image is not None else ""
return int(self.device.send("POST", f"alloc?{extra}size={size}"))
def _free(self, opaque, options):
with contextlib.suppress(ConnectionRefusedError, http.client.CannotSendRequest, http.client.RemoteDisconnected):
self.device.send("DELETE", f"buffer/{opaque}", data=b"")
def copyin(self, dest:int, src:memoryview): self.device.send("PUT", f"buffer/{dest}", data=bytes(src))
# TODO: ideally we shouldn't have to deal with images here
def _alloc(self, size:int, options:BufferOptions) -> int:
self.device.buffer_num += 1
self.device.req.q(BufferAlloc(self.device.buffer_num, size, options))
return self.device.buffer_num
# TODO: options should not be here in any Allocator
def _free(self, opaque:int, options): self.device.req.q(BufferFree(opaque))
def copyin(self, dest:int, src:memoryview): self.device.req.q(CopyIn(dest, self.device.req.h(bytes(src))))
def copyout(self, dest:memoryview, src:int):
resp = self.device.send("GET", f"buffer/{src}")
self.device.req.q(CopyOut(src))
resp = self.device.batch_submit()
assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}"
dest[:] = resp
class CloudProgram:
def __init__(self, device:CloudDevice, name:str, lib:bytes):
self.device = device
self.prgid = f"{name}/{hashlib.sha256(lib).hexdigest()}"
self.device.send("PUT", "program/"+self.prgid, lib)
self.device, self.name = device, name
self.datahash = self.device.req.h(lib)
self.device.req.q(ProgramAlloc(self.name, self.datahash))
super().__init__()
def __del__(self): self.device.send("DELETE", "program/"+self.prgid)
def __del__(self): self.device.req.q(ProgramFree(self.name, self.datahash))
def __call__(self, *bufs, global_size=None, local_size=None, vals:Tuple[int, ...]=(), wait=False):
args = {"bufs": bufs, "vals": vals, "wait": wait}
if global_size is not None: args["global_size"] = global_size
if local_size is not None: args["local_size"] = local_size
ret = self.device.send("POST", "program/"+self.prgid, json.dumps(args).encode())
if wait: return float(ret)
self.device.req.q(ProgramExec(self.name, self.datahash, bufs, vals, global_size, local_size, wait))
if wait: return float(self.device.batch_submit())
class CloudDevice(Compiled):
def __init__(self, device:str):
if (host:=getenv("HOST", "")) != "":
self.host = host
if (host:=getenv("HOST", "")) != "": self.host = host
else:
p = multiprocessing.Process(target=cloud_server, args=(6667,))
p.daemon = True
p.start()
self.host = "127.0.0.1:6667"
self.cookie = binascii.hexlify(os.urandom(0x10)).decode()
# state for the connection
self.session = binascii.hexlify(os.urandom(0x10)).decode()
self.buffer_num = 0
self.req: BatchRequest = BatchRequest()
if DEBUG >= 1: print(f"cloud with host {self.host}")
while 1:
try:
@@ -155,13 +193,26 @@ class CloudDevice(Compiled):
time.sleep(0.1)
if DEBUG >= 1: print(f"remote has device {clouddev}")
# TODO: how do we have BEAM be cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer
assert clouddev[0].startswith("tinygrad.renderer."), f"bad renderer {clouddev}"
renderer = fromimport(clouddev[0], clouddev[1])(*clouddev[2])
super().__init__(device, CloudAllocator(self), renderer, Compiler(), functools.partial(CloudProgram, self))
if not clouddev[0].startswith("tinygrad.renderer.") or not clouddev[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {clouddev}")
renderer_class = fromimport(clouddev[0], clouddev[1]) # TODO: is this secure?
if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {clouddev}")
super().__init__(device, CloudAllocator(self), renderer_class(*clouddev[2]), Compiler(), functools.partial(CloudProgram, self))
def __del__(self):
  # Best-effort flush of whatever is still queued; a connection that is already gone is ignored.
  # TODO: this is never being called
  # TODO: should close the whole session
  ignorable = (ConnectionRefusedError, http.client.CannotSendRequest, http.client.RemoteDisconnected)
  with contextlib.suppress(*ignorable): self.batch_submit()
def batch_submit(self):
  """POST every queued request (and the payloads they reference) to /batch and return the response body."""
  payload = self.req.serialize()
  banner = f"*** send {len(self.req._q):-3d} requests {len(self.req._h):-3d} hashes with len {len(payload)/1024:.2f} kB in "
  with Timing(banner, enabled=DEBUG>=1): response = self.send("POST", "batch", payload)
  # reached only when send() did not raise: start the next batch from an empty queue
  self.req = BatchRequest()
  return response
def send(self, method, path, data:Optional[bytes]=None) -> bytes:
# TODO: retry logic
self.conn.request(method, "/"+path, data, headers={"Cookie": f"session={self.cookie}"})
self.conn.request(method, "/"+path, data, headers={"Cookie": f"session={self.session}"})
response = self.conn.getresponse()
assert response.status == 200, f"failed on {method} {path}"
return response.read()