mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
new cloud is cloudy [pr] (#7631)
* new cloud is cloudy [pr] * waste lines to add security * safety, with speed and less lines * timing and del * lines * cleanups * restore CloudSession * bump to 3.10 * quotes * renderer security
This commit is contained in:
8
.github/workflows/test.yml
vendored
8
.github/workflows/test.yml
vendored
@@ -209,15 +209,15 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Python 3.8
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.8
|
||||
python-version: "3.10"
|
||||
- name: Cache python packages
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.8/site-packages
|
||||
key: linting-packages-${{ hashFiles('**/setup.py') }}-3.8
|
||||
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.10/site-packages
|
||||
key: linting-packages-${{ hashFiles('**/setup.py') }}-3.10
|
||||
- name: Install dependencies
|
||||
run: pip install -e '.[linting,testing]' --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
- name: Lint bad-indentation and trailing-whitespace with pylint
|
||||
|
||||
2
setup.py
2
setup.py
@@ -22,7 +22,7 @@ setup(name='tinygrad',
|
||||
"License :: OSI Approved :: MIT License"
|
||||
],
|
||||
install_requires=[],
|
||||
python_requires='>=3.8',
|
||||
python_requires='>=3.10',
|
||||
extras_require={
|
||||
'llvm': ["llvmlite"],
|
||||
'arm': ["unicorn"],
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
FROM ubuntu:20.04
|
||||
FROM ubuntu:22.04
|
||||
|
||||
# Install python3.8, and pip3
|
||||
# Install python3.10, and pip3
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.8 \
|
||||
python3.10 \
|
||||
python3-pip \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
@@ -5,14 +5,70 @@
|
||||
# it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Tuple, Optional, Dict, Any, DefaultDict
|
||||
from typing import Tuple, Optional, Dict, Any, DefaultDict, List
|
||||
from collections import defaultdict
|
||||
import multiprocessing, functools, http.client, hashlib, json, time, contextlib, os, binascii
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, prod
|
||||
from tinygrad.device import Compiled, Allocator, Compiler, Device, BufferOptions
|
||||
import multiprocessing, functools, http.client, hashlib, json, time, os, binascii, struct, ast, contextlib
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing
|
||||
from tinygrad.device import Compiled, Allocator, Compiler, Device, BufferOptions
|
||||
|
||||
# ***** API *****
|
||||
|
||||
class CloudRequest: pass
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BufferAlloc(CloudRequest): buffer_num: int; size: int; options: BufferOptions # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BufferFree(CloudRequest): buffer_num: int # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CopyIn(CloudRequest): buffer_num: int; datahash: str # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CopyOut(CloudRequest): buffer_num: int
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProgramAlloc(CloudRequest): name: str; datahash: str # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProgramFree(CloudRequest): name: str; datahash: str # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProgramExec(CloudRequest):
|
||||
name: str; datahash: str; bufs: Tuple[int, ...]; vals: Tuple[int, ...] # noqa: E702
|
||||
global_size: Optional[Tuple[int, ...]]; local_size: Optional[Tuple[int, ...]]; wait: bool # noqa: E702
|
||||
|
||||
# for safe deserialization
|
||||
whitelist = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, BufferOptions]}
|
||||
eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
|
||||
ast.Call: lambda x: safe_eval(x.func)(*[safe_eval(arg) for arg in x.args], **{kwarg.arg: safe_eval(kwarg.value) for kwarg in x.keywords}),
|
||||
ast.Name: lambda x: whitelist[x.id], ast.Attribute: lambda x: {"imagef": dtypes.imagef, "imageh": dtypes.imageh}[x.attr]}
|
||||
def safe_eval(node): return eval_fxns[node.__class__](node)
|
||||
|
||||
class BatchRequest:
|
||||
def __init__(self):
|
||||
self._q: List[CloudRequest] = []
|
||||
self._h: Dict[str, bytes] = {}
|
||||
def h(self, d:bytes) -> str:
|
||||
binhash = hashlib.sha256(d).digest()
|
||||
self._h[datahash:=binascii.hexlify(binhash).decode()] = binhash+struct.pack("<Q", len(d))+d
|
||||
return datahash
|
||||
def q(self, x:CloudRequest): self._q.append(x)
|
||||
def serialize(self) -> bytes:
|
||||
self.h(repr(self._q).encode())
|
||||
return b''.join(self._h.values())
|
||||
def deserialize(self, dat:bytes) -> BatchRequest:
|
||||
ptr = 0
|
||||
while ptr < len(dat):
|
||||
datahash, datalen = binascii.hexlify(dat[ptr:ptr+0x20]).decode(), struct.unpack("<Q", dat[ptr+0x20:ptr+0x28])[0]
|
||||
self._h[datahash] = dat[ptr+0x28:ptr+0x28+datalen]
|
||||
ptr += 0x28+datalen
|
||||
self._q = safe_eval(ast.parse(self._h[datahash], mode="eval").body)
|
||||
return self
|
||||
|
||||
# ***** backend *****
|
||||
|
||||
@@ -21,7 +77,6 @@ class CloudSession:
|
||||
programs: Dict[Tuple[str, str], Any] = field(default_factory=dict)
|
||||
# TODO: the buffer should track this internally
|
||||
buffers: Dict[int, Tuple[Any, int, Optional[BufferOptions]]] = field(default_factory=dict)
|
||||
buffer_num = 0
|
||||
|
||||
class CloudHandler(BaseHTTPRequestHandler):
|
||||
protocol_version = 'HTTP/1.1'
|
||||
@@ -32,66 +87,47 @@ class CloudHandler(BaseHTTPRequestHandler):
|
||||
super().setup()
|
||||
print(f"connection established with {self.client_address}, socket: {self.connection.fileno()}")
|
||||
|
||||
def get_data(self):
|
||||
content_len = self.headers.get('Content-Length')
|
||||
assert content_len is not None
|
||||
return self.rfile.read(int(content_len))
|
||||
def get_json(self): return json.loads(self.get_data())
|
||||
|
||||
def _fail(self):
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
return 0
|
||||
|
||||
def _do(self, method):
|
||||
session = CloudHandler.sessions[unwrap(self.headers.get("Cookie")).split("session=")[1]]
|
||||
ret = b""
|
||||
if self.path == "/renderer" and method == "GET":
|
||||
ret, status_code = b"", 200
|
||||
if self.path == "/batch" and method == "POST":
|
||||
# TODO: streaming deserialize?
|
||||
req = BatchRequest().deserialize(self.rfile.read(int(unwrap(self.headers.get('Content-Length')))))
|
||||
# the cmds are always last (currently in datahash)
|
||||
for c in req._q:
|
||||
if DEBUG >= 1: print(c)
|
||||
match c:
|
||||
case BufferAlloc():
|
||||
assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
|
||||
session.buffers[c.buffer_num] = (Device[CloudHandler.dname].allocator.alloc(c.size, c.options), c.size, c.options)
|
||||
case BufferFree():
|
||||
buf,sz,buffer_options = session.buffers[c.buffer_num]
|
||||
Device[CloudHandler.dname].allocator.free(buf,sz,buffer_options)
|
||||
del session.buffers[c.buffer_num]
|
||||
case CopyIn(): Device[CloudHandler.dname].allocator.copyin(session.buffers[c.buffer_num][0], memoryview(bytearray(req._h[c.datahash])))
|
||||
case CopyOut():
|
||||
buf,sz,_ = session.buffers[c.buffer_num]
|
||||
Device[CloudHandler.dname].allocator.copyout(memoryview(ret:=bytearray(sz)), buf)
|
||||
case ProgramAlloc():
|
||||
lib = Device[CloudHandler.dname].compiler.compile_cached(req._h[c.datahash].decode())
|
||||
session.programs[(c.name, c.datahash)] = Device[CloudHandler.dname].runtime(c.name, lib)
|
||||
case ProgramFree(): del session.programs[(c.name, c.datahash)]
|
||||
case ProgramExec():
|
||||
bufs = [session.buffers[x][0] for x in c.bufs]
|
||||
extra_args = {k:v for k,v in [("global_size", c.global_size), ("local_size", c.local_size)] if v is not None}
|
||||
r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args)
|
||||
if r is not None: ret = str(r).encode()
|
||||
elif self.path == "/renderer" and method == "GET":
|
||||
cls, args = Device[CloudHandler.dname].renderer.__reduce__()
|
||||
ret = json.dumps((cls.__module__, cls.__name__, args)).encode()
|
||||
elif self.path.startswith("/alloc") and method == "POST":
|
||||
size = int(self.path.split("=")[-1])
|
||||
buffer_options: Optional[BufferOptions] = None
|
||||
if 'image' in self.path:
|
||||
image_shape = tuple([int(x) for x in self.path.split("=")[-2].split("&")[0].split(",")])
|
||||
buffer_options = BufferOptions(image=dtypes.imageh(image_shape) if prod(image_shape)*2 == size else dtypes.imagef(image_shape))
|
||||
session.buffer_num += 1
|
||||
session.buffers[session.buffer_num] = (Device[CloudHandler.dname].allocator.alloc(size, buffer_options), size, buffer_options)
|
||||
ret = str(session.buffer_num).encode()
|
||||
elif self.path.startswith("/buffer"):
|
||||
key = int(self.path.split("/")[-1])
|
||||
buf,sz,buffer_options = session.buffers[key]
|
||||
if method == "GET": Device[CloudHandler.dname].allocator.copyout(memoryview(ret:=bytearray(sz)), buf)
|
||||
elif method == "PUT": Device[CloudHandler.dname].allocator.copyin(buf, memoryview(bytearray(self.get_data())))
|
||||
elif method == "DELETE":
|
||||
Device[CloudHandler.dname].allocator.free(buf,sz,buffer_options)
|
||||
del session.buffers[key]
|
||||
else: return self._fail()
|
||||
elif self.path.startswith("/program"):
|
||||
name, hsh = self.path.split("/")[-2:]
|
||||
if method == "PUT":
|
||||
src = self.get_data()
|
||||
assert hashlib.sha256(src).hexdigest() == hsh
|
||||
lib = Device[CloudHandler.dname].compiler.compile_cached(src.decode())
|
||||
session.programs[(name, hsh)] = Device[CloudHandler.dname].runtime(name, lib)
|
||||
elif method == "POST":
|
||||
j = self.get_json()
|
||||
bufs = [session.buffers[x][0] for x in j['bufs']]
|
||||
del j['bufs']
|
||||
r = session.programs[(name, hsh)](*bufs, **j)
|
||||
if r is not None: ret = str(r).encode()
|
||||
elif method == "DELETE": del session.programs[(name, hsh)]
|
||||
else: return self._fail()
|
||||
else: return self._fail()
|
||||
self.send_response(200)
|
||||
else: status_code = 404
|
||||
self.send_response(status_code)
|
||||
self.send_header('Content-Length', str(len(ret)))
|
||||
self.end_headers()
|
||||
return self.wfile.write(ret)
|
||||
|
||||
def do_GET(self): return self._do("GET")
|
||||
def do_POST(self): return self._do("POST")
|
||||
def do_PUT(self): return self._do("PUT")
|
||||
def do_DELETE(self): return self._do("DELETE")
|
||||
|
||||
def cloud_server(port:int):
|
||||
multiprocessing.current_process().name = "MainProcess"
|
||||
@@ -106,44 +142,46 @@ class CloudAllocator(Allocator):
|
||||
def __init__(self, device:CloudDevice):
|
||||
self.device = device
|
||||
super().__init__()
|
||||
def _alloc(self, size:int, options) -> int:
|
||||
# TODO: ideally we shouldn't have to deal with images here
|
||||
extra = ("image="+','.join([str(x) for x in options.image.shape])+"&") if options.image is not None else ""
|
||||
return int(self.device.send("POST", f"alloc?{extra}size={size}"))
|
||||
def _free(self, opaque, options):
|
||||
with contextlib.suppress(ConnectionRefusedError, http.client.CannotSendRequest, http.client.RemoteDisconnected):
|
||||
self.device.send("DELETE", f"buffer/{opaque}", data=b"")
|
||||
def copyin(self, dest:int, src:memoryview): self.device.send("PUT", f"buffer/{dest}", data=bytes(src))
|
||||
# TODO: ideally we shouldn't have to deal with images here
|
||||
def _alloc(self, size:int, options:BufferOptions) -> int:
|
||||
self.device.buffer_num += 1
|
||||
self.device.req.q(BufferAlloc(self.device.buffer_num, size, options))
|
||||
return self.device.buffer_num
|
||||
# TODO: options should not be here in any Allocator
|
||||
def _free(self, opaque:int, options): self.device.req.q(BufferFree(opaque))
|
||||
def copyin(self, dest:int, src:memoryview): self.device.req.q(CopyIn(dest, self.device.req.h(bytes(src))))
|
||||
def copyout(self, dest:memoryview, src:int):
|
||||
resp = self.device.send("GET", f"buffer/{src}")
|
||||
self.device.req.q(CopyOut(src))
|
||||
resp = self.device.batch_submit()
|
||||
assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}"
|
||||
dest[:] = resp
|
||||
|
||||
class CloudProgram:
|
||||
def __init__(self, device:CloudDevice, name:str, lib:bytes):
|
||||
self.device = device
|
||||
self.prgid = f"{name}/{hashlib.sha256(lib).hexdigest()}"
|
||||
self.device.send("PUT", "program/"+self.prgid, lib)
|
||||
self.device, self.name = device, name
|
||||
self.datahash = self.device.req.h(lib)
|
||||
self.device.req.q(ProgramAlloc(self.name, self.datahash))
|
||||
super().__init__()
|
||||
def __del__(self): self.device.send("DELETE", "program/"+self.prgid)
|
||||
def __del__(self): self.device.req.q(ProgramFree(self.name, self.datahash))
|
||||
|
||||
def __call__(self, *bufs, global_size=None, local_size=None, vals:Tuple[int, ...]=(), wait=False):
|
||||
args = {"bufs": bufs, "vals": vals, "wait": wait}
|
||||
if global_size is not None: args["global_size"] = global_size
|
||||
if local_size is not None: args["local_size"] = local_size
|
||||
ret = self.device.send("POST", "program/"+self.prgid, json.dumps(args).encode())
|
||||
if wait: return float(ret)
|
||||
self.device.req.q(ProgramExec(self.name, self.datahash, bufs, vals, global_size, local_size, wait))
|
||||
if wait: return float(self.device.batch_submit())
|
||||
|
||||
class CloudDevice(Compiled):
|
||||
def __init__(self, device:str):
|
||||
if (host:=getenv("HOST", "")) != "":
|
||||
self.host = host
|
||||
if (host:=getenv("HOST", "")) != "": self.host = host
|
||||
else:
|
||||
p = multiprocessing.Process(target=cloud_server, args=(6667,))
|
||||
p.daemon = True
|
||||
p.start()
|
||||
self.host = "127.0.0.1:6667"
|
||||
self.cookie = binascii.hexlify(os.urandom(0x10)).decode()
|
||||
|
||||
# state for the connection
|
||||
self.session = binascii.hexlify(os.urandom(0x10)).decode()
|
||||
self.buffer_num = 0
|
||||
self.req: BatchRequest = BatchRequest()
|
||||
|
||||
if DEBUG >= 1: print(f"cloud with host {self.host}")
|
||||
while 1:
|
||||
try:
|
||||
@@ -155,13 +193,26 @@ class CloudDevice(Compiled):
|
||||
time.sleep(0.1)
|
||||
if DEBUG >= 1: print(f"remote has device {clouddev}")
|
||||
# TODO: how do we have BEAM be cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer
|
||||
assert clouddev[0].startswith("tinygrad.renderer."), f"bad renderer {clouddev}"
|
||||
renderer = fromimport(clouddev[0], clouddev[1])(*clouddev[2])
|
||||
super().__init__(device, CloudAllocator(self), renderer, Compiler(), functools.partial(CloudProgram, self))
|
||||
if not clouddev[0].startswith("tinygrad.renderer.") or not clouddev[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {clouddev}")
|
||||
renderer_class = fromimport(clouddev[0], clouddev[1]) # TODO: is this secure?
|
||||
if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {clouddev}")
|
||||
super().__init__(device, CloudAllocator(self), renderer_class(*clouddev[2]), Compiler(), functools.partial(CloudProgram, self))
|
||||
|
||||
def __del__(self):
|
||||
# TODO: this is never being called
|
||||
# TODO: should close the whole session
|
||||
with contextlib.suppress(ConnectionRefusedError, http.client.CannotSendRequest, http.client.RemoteDisconnected): self.batch_submit()
|
||||
|
||||
def batch_submit(self):
|
||||
data = self.req.serialize()
|
||||
with Timing(f"*** send {len(self.req._q):-3d} requests {len(self.req._h):-3d} hashes with len {len(data)/1024:.2f} kB in ", enabled=DEBUG>=1):
|
||||
ret = self.send("POST", "batch", data)
|
||||
self.req = BatchRequest()
|
||||
return ret
|
||||
|
||||
def send(self, method, path, data:Optional[bytes]=None) -> bytes:
|
||||
# TODO: retry logic
|
||||
self.conn.request(method, "/"+path, data, headers={"Cookie": f"session={self.cookie}"})
|
||||
self.conn.request(method, "/"+path, data, headers={"Cookie": f"session={self.session}"})
|
||||
response = self.conn.getresponse()
|
||||
assert response.status == 200, f"failed on {method} {path}"
|
||||
return response.read()
|
||||
|
||||
Reference in New Issue
Block a user