ops_ext to replace cpu import (#3409)

* ops_ext to replace cpu import

* don't allow zero copy with as_buffer

* memoryview(bytearray

* reenable test

* fix jit issue
George Hotz, 2024-02-15 13:03:42 +01:00, committed by GitHub
parent ede4fd4705, commit a40df14fef
11 changed files with 35 additions and 12 deletions
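
In short: host data no longer enters tinygrad through a special "CPU" numpy device. It now lives on a new "EXT" device, a Compiled device with no compiler and no runtime whose buffers are plain (memoryview, Any) pairs, the Any being whatever object has to stay alive for the memoryview to remain valid (typically the numpy array the data came from). A rough sketch of that buffer shape (toy code, not part of the diff):

import numpy as np

owner = np.array([1, 2, 3, 4], dtype=np.int32)
ext_buf = (memoryview(owner).cast("B"), owner)  # flat byte view plus the array that backs it
assert ext_buf[0].nbytes == owner.nbytes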


@@ -135,8 +135,8 @@ assert len(lazyop.srcs) == 2
 # the source is a LazyBuffer that is a "CPU" Tensor
 # again, a LazyOp AST is like a GPU kernel. you have to copy the data on the device first
 assert lazyop.srcs[0].op == LoadOps.COPY
-assert lazyop.srcs[0].srcs[0].device == "CPU"
-assert lazyop.srcs[0].srcs[0].realized._buf[0] == 2, "the src of the COPY LazyOP is a LazyBuffer on the CPU holding [2]"
+assert lazyop.srcs[0].srcs[0].device == "EXT"
+assert lazyop.srcs[0].srcs[0].realized._buf[0][0] == 2, "the src of the COPY LazyOP is a LazyBuffer on the CPU holding [2]"
 assert result.lazydata.base.realized is None, "the LazyBuffer is not realized yet"

 # now we realize the LazyBuffer
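
The extra [0] follows from the new buffer layout: realized._buf is now the (memoryview, Any) pair rather than a bare numpy array, so the first index selects the memoryview and the second reads from it. A toy illustration (not tinygrad code):

buf = (memoryview(bytearray([2, 0, 0, 0])), None)  # stand-in for an EXT _buf: (flat view, owner)
assert buf[0][0] == 2  # first [0] picks the memoryview, second [0] reads its element 0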


@@ -7,7 +7,7 @@ def multidevice_test(fxn):
   exclude_devices = getenv("EXCLUDE_DEVICES", "").split(",")
   def ret(self):
     for device in Device._devices:
-      if device in ["DISK", "FAKE"]: continue
+      if device in ["DISK", "EXT", "FAKE"]: continue
       if not CI: print(device)
       if device in exclude_devices:
         if not CI: print(f"WARNING: {device} test is excluded")


@@ -350,7 +350,7 @@ class TestSchedule(unittest.TestCase):
   def test_double_from(self):
     x = Tensor([1,2,3,4])
-    out = x.to('cpu')
+    out = x.to('ext')
     check_schedule(out, 0, filter_loadops=False)

   def test_pow_const_tensor_simplified(self):


@@ -68,7 +68,7 @@ class LinearizerOptions(NamedTuple):
 class Kernel:
   def __init__(self, ast:LazyOp, opts:Optional[LinearizerOptions]=None):
-    self.opts = opts or (device.compiler.linearizer_opts if isinstance(device:=Device[Device.DEFAULT], Compiled) else
+    self.opts = opts or (device.compiler.linearizer_opts if isinstance(device:=Device[Device.DEFAULT], Compiled) and device.compiler is not None else
                          LinearizerOptions(Device.DEFAULT))
     self.ast = ast
     assert ast.op == BufferOps.STORE, f"kernels must have a store as the output, got {ast.op}"


@@ -281,6 +281,7 @@ class CompiledASTRunner(JITRunner):
     if local_size is not None: local_size = local_size + [1]*(3-len(local_size))
     self.name, self.display_name, self.prg, self.device, self.global_size, self.local_size, self.first_run = \
       to_function_name(name), name, prg, device, global_size, local_size, True
+    assert self.device.compiler is not None, "compiler is required to make an AST kernel"
     lib:bytes = precompiled if precompiled is not None else self.device.compiler.compile_cached(prg)
     self.lib, self.clprg = lib, self.device.runtime(self.name, lib)
     self.vars: List[Variable] = []
@@ -312,15 +313,17 @@ class CompiledASTRunner(JITRunner):
     return et

 class Compiled:
-  def __init__(self, device:str, allocator:Allocator, compiler:Compiler, runtime, graph=None):
+  def __init__(self, device:str, allocator:Allocator, compiler:Optional[Compiler], runtime, graph=None):
     self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler, runtime, graph
   def synchronize(self): pass # override this in your device

   def to_program(self, k:Linearizer) -> CompiledASTRunner:
+    assert self.compiler is not None, "compiler is required to run AST"
     k.linearize()
     return CompiledASTRunner(k.ast, k.name, self.compiler.render(to_function_name(k.name), k.uops), self, k.global_size, k.local_size)

   def get_linearizer(self, ast:LazyOp) -> Linearizer:
+    assert self.compiler is not None, "compiler is required to build AST"
     if DEBUG >= 3:
       from tinygrad.features.graph import print_tree
       print_tree(ast)
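
With compiler now Optional, only the paths that actually build or run kernels assert on it; a device that just holds host memory can pass None. A small check of what that buys (assumes the new file below registers as the EXT device via tinygrad's ops_<name> runtime convention):

from tinygrad.device import Device, Compiled

ext = Device["EXT"]
assert isinstance(ext, Compiled) and ext.compiler is None  # a Compiled device with nothing to compile
# calling ext.to_program(...) would now fail fast with "compiler is required to run AST"
# instead of raising an AttributeError on None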


@@ -165,7 +165,9 @@ class _CacheCollector:
   def add(self, prg, rawbufs, var_vals):
     if self.cache is None: return
     for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}"
-    if len(rawbufs): self.placeholders[rawbufs[0]] = PlaceHolder(rawbufs[0]) # NOTE: this is making an assumption that 0 is special
+    # NOTE: this is making an assumption that 0 is special
+    # TODO: this is wrong for sync and wait
+    if len(rawbufs): self.placeholders[rawbufs[0]] = PlaceHolder(rawbufs[0])
     self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, Buffer) else x for x in rawbufs]))
   def finish(self) -> List[JitItem]:
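
The reshuffled NOTE and the new TODO record an assumption in the JIT cache collector: only the buffer at index 0 gets a PlaceHolder, i.e. it is treated as the kernel output to be rebound on later runs, and that is not the right treatment for SYNC/WAIT buffers. A toy sketch of the substitution (plain strings standing in for Buffer objects):

placeholders = {}
rawbufs = ["out_buf", "in_buf_a", "in_buf_b"]
if len(rawbufs): placeholders[rawbufs[0]] = f"PlaceHolder({rawbufs[0]})"  # only index 0 is special
cached = [placeholders.get(x, x) for x in rawbufs]
assert cached == ["PlaceHolder(out_buf)", "in_buf_a", "in_buf_b"]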


@@ -153,7 +153,7 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True,
   if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)

   dev = Device[lin.opts.device]
-  assert isinstance(dev, Compiled)
+  assert isinstance(dev, Compiled) and dev.compiler is not None

   var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
   lib, global_size, local_size = _compile_linearizer(dev.compiler, lin)


@@ -73,6 +73,9 @@ class LazyBuffer:
   def _copy(self, device:str) -> LazyBuffer:
     sync_size = 1 if self.device.startswith("HIP") else 0
+    if self.device.startswith("EXT"):
+      # EXT doesn't sync
+      return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self,), enable_cache=False)
     sync = LazyBuffer.loadop(LoadOps.SYNC, (sync_size,), dtypes.uint32, self.device, src=self, enable_cache=True)
     wait = LazyBuffer.loadop(LoadOps.WAIT, (0,), dtypes.uint32, device, src=sync, enable_cache=True)
     return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self, wait), enable_cache=False)
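
The new early return means a copy whose source is on EXT skips the SYNC/WAIT pair entirely: the backing host memoryview is always valid, so there is nothing to fence. A stripped-down, standalone sketch of the branch (tinygrad calls stubbed out with strings):

def _copy_sketch(src_device: str) -> str:
  # mirrors the code above: EXT sources copy directly, everything else fences with SYNC/WAIT
  if src_device.startswith("EXT"): return "COPY(src)"
  return "COPY(src, WAIT(SYNC(src)))"

assert _copy_sketch("EXT") == "COPY(src)"
assert _copy_sketch("HIP:0") == "COPY(src, WAIT(SYNC(src)))"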


@@ -0,0 +1,12 @@
+from typing import Tuple, Any
+from tinygrad.device import Compiled, Allocator
+
+# the Any is an arbitrary object that's kept in scope with the memoryview
+class ExtAllocator(Allocator):
+  # NOTE: this doesn't work with allow_zero_copy, it's read only somehow
+  #def as_buffer(self, src:Tuple[memoryview, Any]) -> memoryview: return src[0]
+  def copyin(self, dest:Tuple[memoryview, Any], src:memoryview): dest[0][:] = src
+  def copyout(self, dest:memoryview, src:Tuple[memoryview, Any]): dest[:] = src[0]
+
+class ExtDevice(Compiled):
+  def __init__(self, device:str): super().__init__(device, ExtAllocator(), None, None)
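
That is the whole backend: an allocator that copies into and out of a (memoryview, Any) pair, and a Compiled device built with no compiler and no runtime. Nothing is ever compiled or executed on EXT; it only gives host data a device to sit on until it is COPY'd somewhere real. Assuming the file lands as tinygrad/runtime/ops_ext.py (so Device["EXT"] resolves to ExtDevice through the usual ops_<name> lookup), the allocator can be exercised directly:

import numpy as np
from tinygrad.runtime.ops_ext import ExtAllocator  # import path assumed from the commit title

owner = np.array([1, 2, 3, 4], dtype=np.int32)
ext_buf = (memoryview(owner).cast("B"), owner)  # the (memoryview, Any) pair EXT buffers carry
out = memoryview(bytearray(owner.nbytes))
ExtAllocator().copyout(out, ext_buf)            # just a memoryview slice assignment
assert bytes(out) == owner.tobytes()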


@@ -93,7 +93,7 @@ class PythonProgram:
         ul[i] = [pbufs.pop(0).cast(dtype.fmt)] * warp_size
       elif uop is UOps.DEFINE_LOCAL:
         assert dtype.fmt is not None
-        lbuf = memoryview(bytearray(arg[1]*dtype.sz))
+        lbuf = memoryview(bytearray(arg[1]*dtype.itemsize))
         ul[i] = [lbuf.cast(dtype.fmt)] * warp_size
       elif uop is UOps.SPECIAL:
         if arg[1][0] == 'g':


@@ -7,7 +7,7 @@ from functools import partialmethod, reduce
 import numpy as np

 from tinygrad.dtype import DType, dtypes, ImageDType, Scalar, least_upper_float, least_upper_dtype
-from tinygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, WINO, flatten, prod, all_int, round_up, merge_dicts, fully_flatten
+from tinygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, WINO, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv
 from tinygrad.lazy import LazyBuffer
 from tinygrad.features.multi import MultiLazyBuffer
 from tinygrad.ops import LoadOps
@@ -42,8 +42,11 @@ def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str,
   return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None)

 def _fromcpu(x: np.ndarray) -> LazyBuffer:
-  ret = LazyBuffer.loadop(LoadOps.EMPTY, x.shape, dtypes.from_np(x.dtype), "CPU")
-  ret.realized = Buffer("CPU", prod(x.shape), dtypes.from_np(x.dtype), x.flatten())
+  ret = LazyBuffer.loadop(LoadOps.EMPTY, x.shape, dtypes.from_np(x.dtype), "EXT")
+  if x.size == 0:
+    ret.realized = Buffer("EXT", 0, dtypes.from_np(x.dtype), (memoryview(bytearray()), None))
+  else:
+    ret.realized = Buffer("EXT", prod(x.shape), dtypes.from_np(x.dtype), (flat_mv(np.require(x, requirements='C').data), x))
   return ret

 def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str, Tuple[str, ...]]) -> List[List[Tensor]]:
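
_fromcpu is where numpy data now enters the graph: the realized Buffer's opaque value is a (flat byte view, owner) pair, with np.require forcing a C-contiguous array and the array itself kept as the second element so the view stays valid; empty arrays get a fresh empty bytearray because there is nothing to view into. A rough standalone approximation of that wrapping (flat_mv approximated with a plain byte cast):

import numpy as np

def wrap_for_ext(x: np.ndarray):
  # mirrors the new branch: empty arrays own a fresh buffer, others view the numpy data directly
  if x.size == 0: return (memoryview(bytearray()), None)
  c = np.require(x, requirements='C')  # guarantee the view below is contiguous
  return (c.data.cast("B"), x)         # keep x itself as the owner so the view stays alive

mv, owner = wrap_for_ext(np.arange(4, dtype=np.float32))
assert mv.nbytes == 16 and owner is not None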