From c32ea95d7daeef5e6fdc752d8e474f9fbc7a3ea7 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 8 Feb 2024 19:24:55 +0100 Subject: [PATCH] Python uop emulator (#3327) * start uop emu * tiny_add passes * more ops * emulate the whole warp * test_gemm passes * metal gemm test pass * works on big gemm * works on big gemm * more tests pass * touch ups * fix mypy * cleanups * exp2 mypy * arch is where it belongs * actually emulate tensor cores * fix test * new style --- .github/workflows/test.yml | 2 + test/test_linearizer.py | 9 +- test/test_ops.py | 6 ++ tinygrad/codegen/kernel.py | 12 +-- tinygrad/runtime/ops_hip.py | 61 ++++++------ tinygrad/runtime/ops_metal.py | 2 +- tinygrad/runtime/ops_python.py | 166 +++++++++++++++++++++++++++++++++ 7 files changed, 216 insertions(+), 42 deletions(-) create mode 100644 tinygrad/runtime/ops_python.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a1f8051a12..469467fc58 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -55,6 +55,8 @@ jobs: PYTHONPATH="." python test/external/fuzz_shapetracker_math.py - name: Test shapetracker to_movement_ops run: PYTHONPATH="." python extra/to_movement_ops.py + - name: Test emulated METAL tensor cores + run: DEBUG=2 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_big_gemm - name: Use as an external package run: | mkdir $HOME/test_external_dir diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 8ad3b4dd1b..fc6cfc0744 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -1,5 +1,5 @@ import numpy as np -import unittest, os +import unittest from tinygrad.codegen.kernel import Opt, OptOps, tensor_cores from tinygrad.codegen.linearizer import Linearizer, UOp, UOps, expand_node @@ -103,10 +103,10 @@ class TestLinearizer(unittest.TestCase): d, w = Tensor.rand(4, 8, 8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, 2, 2, dtype=tensor_dtype) helper_arg_acc_dtype(d.conv2d(w, acc_dtype=acc_dtype), expected_dtype) - @unittest.skipUnless(Device.DEFAULT in tensor_cores, "No tensor cores for device") def test_tensor_cores(self): + if not Device[Device.DEFAULT].compiler.linearizer_opts.has_tensor_cores: + self.skipTest("device doesn't have tensor cores") for tc in tensor_cores[Device.DEFAULT]: - if tc.arch is not None and tc.arch != os.uname().machine: continue a, b = Tensor.rand(tc.dims[1], tc.dims[2], dtype=tc.dtype_in), Tensor.rand(tc.dims[2], tc.dims[0], dtype=tc.dtype_in) np_a, np_b = a.numpy(), b.numpy() r = a.matmul(b, acc_dtype=tc.dtype_out) @@ -529,13 +529,14 @@ class TestLinearizerOpts(unittest.TestCase): def test_tensor_core_opts(self): if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local: self.skipTest("Only Compiled uses linearizer with locals") + if not Device[Device.DEFAULT].compiler.linearizer_opts.has_tensor_cores: + self.skipTest("device doesn't have tensor cores") if Device.DEFAULT not in tensor_cores: self.skipTest("No tensor cores for device") N = 128 Tensor.manual_seed(1552) for tc in tensor_cores[Device.DEFAULT]: - if tc.arch is not None and tc.arch != os.uname().machine: continue a, b = Tensor.rand(N, N, dtype=tc.dtype_in), Tensor.rand(N, N, dtype=tc.dtype_in) r = a.matmul(b, acc_dtype=tc.dtype_out) (atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4) diff --git a/test/test_ops.py b/test/test_ops.py index 81143dec4f..e48c95e9a9 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -286,6 +286,9 @@ class TestOps(unittest.TestCase): helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], True], forward_only=True) helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], [True, True, False]], forward_only=True) + def test_tiny_add(self): + helper_test_op([(3), (3)], lambda x,y: x+y, Tensor.add, forward_only=True) + def test_add(self): helper_test_op([(45,68), (45,68)], lambda x,y: x+y, Tensor.add) helper_test_op([(45,68), (45,68)], lambda x,y: x+y) @@ -631,6 +634,9 @@ class TestOps(unittest.TestCase): helper_test_op([(4,3), (1,3,3,5)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) def test_small_gemm(self): helper_test_op([(8,8), (8,8)], lambda x,y: x.matmul(y), lambda x,y: x@y, atol=1e-3) + def test_small_gemm_range(self): + helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, atol=1e-3, vals=[np.arange(0,64,dtype=np.float32).reshape(8,8), + np.arange(64,128,dtype=np.float32).reshape(8,8)]) def test_small_gemm_eye(self): helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, atol=1e-3, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)]) def test_gemm(self): diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index fafe356c27..450997482c 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -1,5 +1,5 @@ from __future__ import annotations -import os, math, itertools +import math, itertools from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps from tinygrad.device import Device, Compiled @@ -33,14 +33,13 @@ class TensorCore: thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501 thread_local_sizes: List[int] # in each thread, the number of elements stored in registers for each TC dim wmma_func: str # name of wmma function to call - arch: Optional[str] = None def __str__(self): return f"tensor_core<{self.dims}, {self.dtype_in}, {self.dtype_out}>" tensor_cores: Dict[str, List[TensorCore]] = { "METAL": [ - TensorCore(dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, wmma_func="__metal_wmma", upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501 - TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__metal_wmma", upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501 - TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__metal_wmma", upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501 + TensorCore(dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, wmma_func="__metal_wmma", upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501 + TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__metal_wmma", upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501 + TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__metal_wmma", upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501 ], "HIP": [ TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__builtin_amdgcn_wmma_f32_16x16x16_f16_w32", upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501 @@ -61,6 +60,7 @@ class LinearizerOptions(NamedTuple): supports_float4: bool = True has_local: bool = True has_shared: bool = True + has_tensor_cores: bool = False # NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered global_max: Optional[List[int]] = None local_max: Optional[List[int]] = None @@ -330,9 +330,9 @@ class Kernel: # ******************** high level optimizers ******************** def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None) -> bool: + if not self.opts.has_tensor_cores and use_tensor_cores != 2: return False if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op == ReduceOps.SUM and self.opts.device in tensor_cores: for tc in tensor_cores[self.opts.device]: - if not (use_tensor_cores==2 or (tc.arch is None or tc.arch == os.uname().machine)): continue has_cast = tc.dtype_in != tc.dtype_out if has_cast and not(self.reduceop.src[0].op == UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py index 27be2f4229..df0c7ed248 100644 --- a/tinygrad/runtime/ops_hip.py +++ b/tinygrad/runtime/ops_hip.py @@ -11,7 +11,7 @@ from tinygrad.runtime.compiler.hip_comgr import compile_hip class HIPCompiler(Compiler): - linearizer_opts = LinearizerOptions("HIP") + linearizer_opts = LinearizerOptions("HIP", has_tensor_cores=True) def __init__(self, arch:str): self.arch = arch super().__init__(f"compile_hip_{self.arch}") @@ -127,29 +127,6 @@ class HIPAllocator(LRUAllocator): hip_set_device(self.device.device) check(hip.hipMemcpyAsync(dest, src, sz, hip.hipMemcpyDeviceToDevice, None)) -class HIPDevice(Compiled): - def __init__(self, device:str=""): - self.device = int(device.split(":")[1]) if ":" in device else 0 - self.arch = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode() - self.pending_copyin: List[ctypes.c_void_p] = [] - self.track_cross_buffer: List[Any] = [] - self.peers: Set[int] = set() - - from tinygrad.runtime.graph.hip import HIPGraph - super().__init__(device, HIPAllocator(self), HIPCompiler(self.arch), - functools.partial(HIPProgram, self.device), HIPGraph) - def synchronize(self): - hip_set_device(self.device) - check(hip.hipDeviceSynchronize()) - for opaque in self.pending_copyin: check(hip.hipFree(opaque)) - self.track_cross_buffer.clear() - self.pending_copyin.clear() - def enable_peer(self, dnum): - if self.device == dnum or dnum in self.peers: return - hip_set_device(self.device) - check(hip.hipDeviceEnablePeerAccess(dnum, 0)) - self.peers.add(dnum) - class HIPSyncEvent(JITRunner): def __init__(self, lb): self.lb, self.device, self.dname = lb, cast(HIPDevice, Device[lb.device]), lb.device @@ -170,16 +147,38 @@ class HIPWaitEvent(JITRunner): update_stats(colored("wait", "RED"), 0, 0, {}, None, 1, jit, device=self.dname) if getenv("HIPCPU"): - hip = ctypes.CDLL("/usr/local/lib/libremu.so") # type: ignore[assignment] - - class HIPProgram: # type: ignore[no-redef] + rhip = ctypes.CDLL("/usr/local/lib/libremu.so") + class RHIPProgram: def __init__(self, name:str, lib:bytes): self.name, self.lib = name, lib def __call__(self, *args, global_size, local_size, vals=(), wait=False): args = (*args, *vals) - hip.hipModuleLaunchKernel(self.lib, len(self.lib), *global_size, *local_size, 0, None, None, + rhip.hipModuleLaunchKernel(self.lib, len(self.lib), *global_size, *local_size, 0, None, None, len(args), (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args])) - class HIPDevice(Compiled): # type: ignore[no-redef] - def __init__(self, device=""): - super().__init__(device, MallocAllocator, HIPCompiler("gfx1100"), HIPProgram) +class HIPDevice(Compiled): + def __init__(self, device:str=""): + self.device = int(device.split(":")[1]) if ":" in device else 0 + self.pending_copyin: List[ctypes.c_void_p] = [] + self.track_cross_buffer: List[Any] = [] + self.peers: Set[int] = set() + + if getenv("HIPCPU"): + super().__init__(device, MallocAllocator, HIPCompiler("gfx1100"), RHIPProgram) + else: + self.arch = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode() + from tinygrad.runtime.graph.hip import HIPGraph + super().__init__(device, HIPAllocator(self), HIPCompiler(self.arch), + functools.partial(HIPProgram, self.device), HIPGraph) + def synchronize(self): + if getenv("HIPCPU"): return + hip_set_device(self.device) + check(hip.hipDeviceSynchronize()) + for opaque in self.pending_copyin: check(hip.hipFree(opaque)) + self.track_cross_buffer.clear() + self.pending_copyin.clear() + def enable_peer(self, dnum): + if self.device == dnum or dnum in self.peers: return + hip_set_device(self.device) + check(hip.hipDeviceEnablePeerAccess(dnum, 0)) + self.peers.add(dnum) diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index ae6a186f95..f20173a873 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -8,7 +8,7 @@ from tinygrad.device import Compiled, LRUAllocator, Compiler from tinygrad.renderer.cstyle import MetalRenderer class MetalCompiler(Compiler): - linearizer_opts = LinearizerOptions("METAL") + linearizer_opts = LinearizerOptions("METAL", has_tensor_cores=os.uname().machine == "arm64") def __init__(self, device:Optional[MetalDevice]): self.device = device super().__init__("compile_metal") diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py new file mode 100644 index 0000000000..b086fb0f21 --- /dev/null +++ b/tinygrad/runtime/ops_python.py @@ -0,0 +1,166 @@ +# a python uops emulator +# works to test the tensor cores, and all the uops in general +# this is the (living) definition of uops +from typing import Tuple, List, Optional, Any, Dict +import pickle, base64, itertools, time, math +from tinygrad.dtype import DType, dtypes +from tinygrad.helpers import all_same, getenv +from tinygrad.device import Compiled, Allocator, Compiler +from tinygrad.codegen.uops import UOp, UOps +from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps +from tinygrad.codegen.kernel import LinearizerOptions + +def exec_alu(arg, dtype, p): + # TODO: make this complete and correctly honor the dtypes + # TODO: use this for constant folding + if arg == TernaryOps.MULACC: return p[0]*p[1]+p[2] + if arg == TernaryOps.WHERE: return p[1] if p[0] else p[2] + if arg == UnaryOps.LOG2: return math.log2(p[0]) if p[0] > 0 else math.nan + if arg == UnaryOps.EXP2: return math.exp(p[0]*math.log(2)) + if arg == UnaryOps.SQRT: return math.sqrt(p[0]) if p[0] > 0 else math.nan + if arg == UnaryOps.SIN: return math.sin(p[0]) + if arg == UnaryOps.NEG: return -p[0] + if arg == BinaryOps.MUL: return p[0]*p[1] + if arg == BinaryOps.ADD: return p[0]+p[1] + if arg == BinaryOps.SUB: return p[0]-p[1] + if arg == BinaryOps.XOR: return p[0]^p[1] + if arg == BinaryOps.MAX: return max(p[0], p[1]) + if arg == BinaryOps.CMPEQ: return p[0] == p[1] + if arg == BinaryOps.CMPLT: return p[0] < p[1] + if arg == BinaryOps.DIV: return p[0]//p[1] if dtypes.is_int(dtype) else (p[0]/p[1] if p[1] != 0 else math.nan) + if arg == BinaryOps.MOD: return p[0]%p[1] + raise NotImplementedError(f"no support for {arg}") + +class PythonProgram: + def __init__(self, name:str, lib:bytes): + self.uops: List[Tuple[UOps, Optional[DType], List[int], Any]] = pickle.loads(lib) + def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False): + st = time.perf_counter() + warp = list(itertools.product(*[range(x) for x in local_size[::-1]])) + warp_size = len(warp) + for idxs in itertools.product(*[range(x) for x in global_size[::-1]]): + ul: Dict[int, Any] = {} + dl: Dict[int, DType] = {} + pbufs: List[memoryview] = list(bufs) + i = 0 + loop_ends: Dict[int, int] = {} + while i < len(self.uops): + uop, dtype, idp, arg = self.uops[i] + inp = [ul[v] for v in idp] + dtp = [dl[v] for v in idp] + if uop is UOps.STORE: + if dtp[2].sz > 1: + for j,val in enumerate(inp[2]): + for m,o,v in zip(inp[0], inp[1], val): m[o+j] = v + else: + for m,o,v in zip(*inp): m[o] = v + i += 1 + continue + elif uop is UOps.END: + loop_ends[idp[0]] = i + i = idp[0] + continue + elif uop is UOps.BARRIER: + # in the python emulator, the warp is always in sync + i += 1 + continue + assert dtype is not None, f"{uop} is missing a dtype" + dl[i] = dtype + if uop is UOps.DEFINE_GLOBAL: + assert dtype.fmt is not None + ul[i] = [pbufs.pop(0).cast(dtype.fmt)] * warp_size + elif uop is UOps.DEFINE_LOCAL: + assert dtype.fmt is not None + lbuf = memoryview(bytearray(arg[1]*dtype.sz)) + ul[i] = [lbuf.cast(dtype.fmt)] * warp_size + elif uop is UOps.SPECIAL: + if arg[1][0] == 'g': + ul[i] = [idxs[2-arg[0]]] * warp_size + elif arg[1][0] == 'l': + ul[i] = [x[2-arg[0]] for x in warp] + elif uop is UOps.CONST: ul[i] = [int(arg) if dtypes.is_int(dtype) else float(arg)] * warp_size + elif uop is UOps.DEFINE_ACC: + if dtype.sz > 1: + ul[i] = [[arg] * warp_size for _ in range(dtype.sz)] + else: + ul[i] = [arg] * warp_size + elif uop is UOps.LOOP: + if i not in ul: + ul[i] = [0] * warp_size + else: + for j in range(len(ul[i])): + ul[i][j] += 1 + if ul[i][0] == inp[1][0]: + i = loop_ends[i] + 1 + continue + elif uop is UOps.CAST: + if dtype.sz > 1: + ul[i] = inp + else: + # TODO: add real cast + if dtypes.is_int(dtype): + ul[i] = [int(x) for x in inp[0]] + elif dtypes.is_float(dtype): + ul[i] = [float(x) for x in inp[0]] + else: + ul[i] = inp[0] + elif uop is UOps.LOAD: + if dtype.sz > 1: + ul[i] = [[m[x+j] for m,x in zip(inp[0], inp[1])] for j in range(dtype.sz)] + else: + ul[i] = [m[x] for m,x in zip(inp[0], inp[1])] + elif uop is UOps.PHI: + for j in range(len(inp[0])): + inp[0][j] = inp[1][j] + ul[i] = inp[0] + elif uop is UOps.GEP: + ul[i] = inp[0][arg] + elif uop is UOps.WMMA: + # here are the models for the WMMA instruction on the different hardware + if arg == '__metal_wmma': + order = [0, 32, 1, 33, 8, 40, 9, 41, + 2, 34, 3, 35, 10, 42, 11, 43, + 4, 36, 5, 37, 12, 44, 13, 45, + 6, 38, 7, 39, 14, 46, 15, 47, + 16, 48, 17, 49, 24, 56, 25, 57, + 18, 50, 19, 51, 26, 58, 27, 59, + 20, 52, 21, 53, 28, 60, 29, 61, + 22, 54, 23, 55, 30, 62, 31, 63] + def unswizzle(goff, x): return [x[0][goff+idx] if idx < 32 else + x[1][goff+idx-32] for idx in order] + out = inp[2][0][:], inp[2][1][:] + for goff in range(0, warp_size, 32): + m1,m2 = unswizzle(goff, inp[0]), unswizzle(goff, inp[1]) + for _i in range(8): + for _j in range(8): + oidx = order[_i*8 + _j] + nval = sum(m1[_i*8+_k] * m2[_k*8+_j] for _k in range(8)) + if oidx < 32: out[0][goff+oidx] += nval + else: out[1][goff+oidx-32] += nval + ul[i] = out + else: + raise Exception(f"unimplemented tensor core {arg}") + elif uop is UOps.ALU: + assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}" + assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPEQ, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}" + ul[i] = [exec_alu(arg, dtype, p) for p in zip(*inp)] + assert i in ul, (uop, dtype, idp, arg) + #print(i, uop, dtype, arg, ul[i] if i in ul else None) + i += 1 + return time.perf_counter() - st + +class PythonCompiler(Compiler): + linearizer_opts = LinearizerOptions("METAL", has_tensor_cores=True) if getenv("EMULATE_METAL") else LinearizerOptions() + def render(self, name:str, uops:List[UOp]) -> str: + lops = [(u.uop, u.dtype, [uops.index(v) for v in u.vin], u.arg) for u in uops] + return base64.b64encode(pickle.dumps(lops)).decode() + def compile(self, src:str) -> bytes: return base64.b64decode(src) + +class PythonAllocator(Allocator): + def _alloc(self, size): return memoryview(bytearray(size)) + def copyin(self, dest, src:memoryview): dest[:] = src + def copyout(self, dest:memoryview, src): dest[:] = src + +class PythonDevice(Compiled): + def __init__(self, device:str): + super().__init__(device, PythonAllocator(), PythonCompiler(), PythonProgram) \ No newline at end of file