mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-15 09:05:40 -05:00
ops_ext to replace cpu import (#3409)
* ops_ext to replace cpu import * don't allow zero copy with as buffer * memoryview(bytearray * reenable test * fix jit issue
This commit is contained in:
12
tinygrad/runtime/ops_ext.py
Normal file
12
tinygrad/runtime/ops_ext.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from typing import Tuple, Any
|
||||
from tinygrad.device import Compiled, Allocator
|
||||
|
||||
# the Any is an arbitrary object that's kept in scope with the memoryview
|
||||
class ExtAllocator(Allocator):
|
||||
# NOTE: this doesn't work with allow_zero_copy, it's read only somehow
|
||||
#def as_buffer(self, src:Tuple[memoryview, Any]) -> memoryview: return src[0]
|
||||
def copyin(self, dest:Tuple[memoryview, Any], src:memoryview): dest[0][:] = src
|
||||
def copyout(self, dest:memoryview, src:Tuple[memoryview, Any]): dest[:] = src[0]
|
||||
|
||||
class ExtDevice(Compiled):
|
||||
def __init__(self, device:str): super().__init__(device, ExtAllocator(), None, None)
|
||||
@@ -93,7 +93,7 @@ class PythonProgram:
|
||||
ul[i] = [pbufs.pop(0).cast(dtype.fmt)] * warp_size
|
||||
elif uop is UOps.DEFINE_LOCAL:
|
||||
assert dtype.fmt is not None
|
||||
lbuf = memoryview(bytearray(arg[1]*dtype.sz))
|
||||
lbuf = memoryview(bytearray(arg[1]*dtype.itemsize))
|
||||
ul[i] = [lbuf.cast(dtype.fmt)] * warp_size
|
||||
elif uop is UOps.SPECIAL:
|
||||
if arg[1][0] == 'g':
|
||||
|
||||
Reference in New Issue
Block a user