Remove Interpreted device & remaining CPU/TORCH ref (#3423)

* Remove Interpreted device & remaining CPU/TORCH ref

* Oops

* supports_device was useful

* Fix doc wording

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
Author: xarkes
Date: 2024-02-16 06:30:21 +01:00
Committed by: GitHub
Parent: 6efa68f97b
Commit: 28a8b72024
15 changed files with 31 additions and 146 deletions

View File

@@ -150,18 +150,10 @@ out = result.lazydata.base.realized.as_buffer().cast('I')
assert out[0] == 5, "when put in numpy, it's 5"
# %%
# == Union[Interpreted, Compiled] (in tinygrad/device.py, code 6/10) ==
# == Compiled (in tinygrad/device.py, code 6/10) ==
# Now you have a choice, you can either write a "Interpreted" backend or "Compiled" backend
# Now you can write a Compiled backend (example: GPU, LLVM or PYTHON)
# Interpreted backends are very simple (example: CPU and TORCH)
class Interpreted:
# and they have a lookup table to functions for the Ops
fxn_for_op: Dict[Op, Callable] = {
UnaryOps.EXP2: lambda x: np.exp2(x),
BinaryOps.ADD: lambda x,y: x+y}
# Compiled backends take a little more (example: GPU and LLVM)
class Compiled:
# a code generator, which compiles the AST
codegen: Type[Linearizer]
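
Aside: as a rough, standalone illustration of the distinction the old doc text drew (plain Python and numpy, not tinygrad's actual classes), an "Interpreted" backend dispatches each op through a lookup table while walking the op tree at run time, whereas a Compiled backend hands the whole tree to a code generator (the codegen: Type[Linearizer] line above) and runs the compiled program.

import numpy as np

# interpreter style: a lookup table from op to function, applied per node at run time
FXN_FOR_OP = {"EXP2": np.exp2, "ADD": lambda x, y: x + y}

def interpret(node):
    op, srcs = node
    return FXN_FOR_OP[op](*[s if isinstance(s, np.ndarray) else interpret(s) for s in srcs])

out = interpret(("ADD", [("EXP2", [np.ones(4)]), np.ones(4)]))
print(out)  # [3. 3. 3. 3.]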

View File

@@ -35,7 +35,7 @@ class TinygradBackend(Backend):
@classmethod
def supports_device(cls, device: str) -> bool:
return device == "CPU"
return device == "CLANG"
backend_test = onnx.backend.test.BackendTest(TinygradBackend, __name__)
@@ -48,11 +48,6 @@ backend_test.exclude('test_adam_multiple_cpu')
backend_test.exclude('test_nesterov_momentum_cpu')
# about different dtypes
if Device.DEFAULT in ["TORCH"]:
backend_test.exclude('uint16')
backend_test.exclude('uint32')
backend_test.exclude('uint64')
if Device.DEFAULT in ["METAL"] or (OSX and Device.DEFAULT == "GPU"):
backend_test.exclude('float64')
backend_test.exclude('DOUBLE')
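
For context on the supports_device change: onnx's backend test harness asks the backend which device strings it supports and only schedules cases for those devices. A minimal sketch of that contract, assuming only the onnx package; SketchBackend and its device name are illustrative, not tinygrad's real backend:

import onnx.backend.base
import onnx.backend.test

class SketchBackend(onnx.backend.base.Backend):
    @classmethod
    def supports_device(cls, device: str) -> bool:
        # the harness skips any test case whose target device returns False here
        return device == "CLANG"

backend_test = onnx.backend.test.BackendTest(SketchBackend, __name__)
backend_test.exclude('test_adam_multiple_cpu')   # excludes work the same way as in the diff above
globals().update(backend_test.test_cases)        # expose the generated unittest cases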

View File

@@ -8,7 +8,7 @@ from tinygrad.features.search import get_linearizer_actions, bufs_from_lin
from tinygrad.tensor import Tensor
from tinygrad.features.graph import print_tree
from tinygrad.helpers import getenv, from_mv, Context
from tinygrad.device import Device, Compiled, Interpreted
from tinygrad.device import Device, Compiled
from tinygrad.codegen.linearizer import UOp
def tuplize_uops(uops:List[UOp]) -> Tuple: return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops])
@@ -134,7 +134,6 @@ if __name__ == "__main__":
tested = 0
failures = defaultdict(list)
for i, ast in enumerate(ast_strs[:getenv("FUZZ_N", len(ast_strs))]):
if "Variable" in ast and isinstance(device, Interpreted): continue # no symbolic shape for Interpreted
if "dtypes.image" in ast and Device.DEFAULT != "GPU": continue # IMAGE is only for GPU
print(f"testing ast {i}")
tested += 1
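
An aside on the tuplize_uops helper visible above: it turns a UOp list into a hashable key by replacing references to other UOps with their list indices, so two linearizations can be compared for equality. The same idea in a generic, self-contained form (plain dicts standing in for UOps):

def tuplize(nodes):
    # encode each node's references as indices into the list, yielding a hashable tuple
    idx = {id(n): i for i, n in enumerate(nodes)}
    return tuple((n["op"], tuple(idx[id(s)] for s in n["src"]), n.get("arg")) for n in nodes)

a = {"op": "CONST", "src": [], "arg": 1.0}
b = {"op": "ADD", "src": [a, a], "arg": None}
assert tuplize([a, b]) == (("CONST", (), 1.0), ("ADD", (0, 0), None))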

View File

@@ -56,7 +56,6 @@ class TestRealWorld(unittest.TestCase):
def test(t, t2): return model(t, 801, t2).realize()
helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 18.0, 953)
@unittest.skipIf(Device.DEFAULT in ["CPU", "TORCH"], "tons of ram with interpreted")
def test_mini_stable_diffusion(self):
model = [ResBlock(16, 24, 16) for _ in range(4)]
derandomize_model(model)
@@ -151,4 +150,4 @@ class TestRealWorld(unittest.TestCase):
dtypes.default_float = dtypes.float32
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@@ -2,7 +2,7 @@
import unittest
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad import Device, dtypes
from tinygrad import dtypes
N = 200 # has to be bigger than the cache to fail
@@ -20,7 +20,6 @@ class TestAssign(unittest.TestCase):
assert ba1 == ba2 and ba1 != bb1
np.testing.assert_allclose(a.numpy(), (np.arange(N*N)*2).reshape((N,N)))
@unittest.skipIf(Device.DEFAULT == "CPU" or Device.DEFAULT == "TORCH", "questionable tests")
def test_permuted_assignment(self):
a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)

View File

@@ -1,6 +1,6 @@
import unittest
import numpy as np
from tinygrad.tensor import Tensor, Device
from tinygrad.tensor import Tensor
from tinygrad.helpers import Context
class TestConv(unittest.TestCase):
@@ -62,7 +62,7 @@ class TestConv(unittest.TestCase):
np.testing.assert_allclose(r2.numpy(), np.where(out.numpy() > 0, out.numpy(), (np.exp(out.numpy()) - 1)), atol=1e-5)
Tensor.no_grad = False
@unittest.skipIf(Device.DEFAULT != "TORCH", "Takes too long to compile for Compiled backends")
@unittest.skip("Takes too long to compile for Compiled backends")
def test_two_overlapping_binops_no_rerun_wino(self):
Tensor.no_grad = True
with Context(WINO=1):
@@ -140,4 +140,4 @@ class TestConv(unittest.TestCase):
x.numpy()
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@@ -17,7 +17,6 @@ floats = [dt for dt in core_dtypes if dtypes.is_float(dt)]
def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
if dtype == dtypes.bfloat16: return False # numpy doesn't support bf16, tested separately in TestBFloat16DType
if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
if device == "TORCH": return dtype not in [dtypes.uint16, dtypes.uint32, dtypes.uint64]
# for CI GPU, cl_khr_fp16 isn't supported
# for CI LLVM, it segfaults because it can't link to the casting function
# CUDA in CI uses CUDACPU that does not support half
@@ -127,7 +126,7 @@ def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
_assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]])
_assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy())
@unittest.skipUnless(Device.DEFAULT in ["LLVM", "TORCH"], "bfloat16 not supported")
@unittest.skipUnless(Device.DEFAULT == "LLVM", "bfloat16 not supported")
class TestBFloat16DType(unittest.TestCase):
def test_bf16_to_float(self):
with self.assertRaises(AssertionError):

View File

@@ -2,7 +2,6 @@ import unittest
import time
import numpy as np
from tinygrad import Tensor, dtypes
from tinygrad.device import InterpretedASTRunner
from tinygrad.realize import run_schedule, create_schedule, lower_schedule_item
class TestFusionOp(unittest.TestCase):
@@ -29,7 +28,7 @@ class TestFusionOp(unittest.TestCase):
sched = create_schedule([a.lazydata], None)
ji = lower_schedule_item(sched[-1])
self.assertLess(time.perf_counter()-st, 1.0)
assert isinstance(ji, InterpretedASTRunner) or len(ji.prg.splitlines()) < 250
assert len(ji.prg.splitlines()) < 250
def test_recursive_add_cmp(self):
st = time.perf_counter()

View File

@@ -55,7 +55,7 @@ class TestLinearizerFailures(unittest.TestCase):
ast = LazyOp(op=ReduceOps.MAX, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 2, 111, 27), strides=(6160, 3080, 28, 1), offset=0, mask=((0, 32), (0, 2), (0, 110), (0, 27)), contiguous=False), View(shape=(32, 2, 37, 9, 2, 2), strides=(5994, 2997, 81, 3, 27, 1), offset=0, mask=None, contiguous=False))))),), arg=(32, 2, 37, 9, 1, 1))
opts = [Opt(op=OptOps.LOCAL, axis=0, amt=32)]
ast = helper_add_store(ast)
helper_test_lin(Linearizer(ast), opts, failed_platforms=["CPU", "TORCH"])
helper_test_lin(Linearizer(ast), opts, failed_platforms=[])
@unittest.skipIf(CI and Device.DEFAULT=="METAL", "behaves differently on METAL CI")
def test_failure_3(self):

View File

@@ -123,7 +123,7 @@ class TestNN(unittest.TestCase):
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
@unittest.skipIf(Device.DEFAULT not in {"CPU", "TORCH"}, "Takes too long to compile for Compiled backends")
@unittest.skip("Takes too long to compile for Compiled backends")
def test_conv2d_winograd(self):
BS, C1, H, W = 2, 8, 16, 16
C2, K, S, P = 8, 3, 1, 1

View File

@@ -394,15 +394,13 @@ class TestZeroShapeTensor(unittest.TestCase):
assert t.shape == (3, 2, 2)
np.testing.assert_equal(t.numpy(), np.ones((3, 2, 2)))
if Device.DEFAULT != "TORCH":
# torch does not support padding non-zero dim with 0-size. torch.nn.functional.pad(torch.zeros(3,2,0), [0,0,0,4,0,0])
t = Tensor.rand(3, 2, 0).pad((None, (1, 1), None), 1)
assert t.shape == (3, 4, 0)
np.testing.assert_equal(t.numpy(), np.ones((3, 4, 0)))
t = Tensor.rand(3, 2, 0).pad((None, (1, 1), None), 1)
assert t.shape == (3, 4, 0)
np.testing.assert_equal(t.numpy(), np.ones((3, 4, 0)))
t = Tensor.rand(3, 2, 0).pad(((1, 1), None, None), 1)
assert t.shape == (5, 2, 0)
np.testing.assert_equal(t.numpy(), np.ones((5, 2, 0)))
t = Tensor.rand(3, 2, 0).pad(((1, 1), None, None), 1)
assert t.shape == (5, 2, 0)
np.testing.assert_equal(t.numpy(), np.ones((5, 2, 0)))
def test_shrink_into_zero(self):
t = Tensor.rand(3, 4).realize()
@@ -416,12 +414,11 @@ class TestZeroShapeTensor(unittest.TestCase):
assert t.shape == (3, 2, 2)
np.testing.assert_equal(t.numpy(), s.numpy())
if Device.DEFAULT != "TORCH":
# torch does not support padding non-zero dim with 0-size. torch.nn.functional.pad(torch.zeros(3,2,0), [0,0,0,4,0,0])
s = Tensor.rand(3, 4, 0)
t = Tensor.rand(3, 2, 0).cat(s, dim=1)
assert t.shape == (3, 6, 0)
np.testing.assert_equal(t.numpy(), np.zeros((3, 6, 0)))
# torch does not support padding non-zero dim with 0-size. torch.nn.functional.pad(torch.zeros(3,2,0), [0,0,0,4,0,0])
s = Tensor.rand(3, 4, 0)
t = Tensor.rand(3, 2, 0).cat(s, dim=1)
assert t.shape == (3, 6, 0)
np.testing.assert_equal(t.numpy(), np.zeros((3, 6, 0)))
def test_elementwise(self):
a = Tensor.rand(3, 2, 0)
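
With the TORCH guard removed, these zero-size cases now run unconditionally. For reference, the semantics the tests expect match numpy's (a quick standalone check, not tinygrad code):

import numpy as np

t = np.pad(np.zeros((3, 2, 0)), ((0, 0), (1, 1), (0, 0)), constant_values=1)
assert t.shape == (3, 4, 0)  # padding a non-zero dim of a 0-size array is fine

c = np.concatenate([np.zeros((3, 2, 0)), np.zeros((3, 4, 0))], axis=1)
assert c.shape == (3, 6, 0)  # so is concatenating along a non-zero dim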

View File

@@ -29,7 +29,6 @@ class TestTorchLoad(unittest.TestCase):
# for LLVM, it segfaults because it can't link to the casting function
# CUDACPU architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
@unittest.skipIf(Device.DEFAULT in ["GPU", "LLVM", "CUDA"] and CI, "fp16 broken in some backends")
@unittest.skipIf(Device.DEFAULT == "TORCH", "torch doesn't support the way we load bfloat (cast to uint32)")
def test_load_llama2bfloat(self): compare_weights_both("https://huggingface.co/qazalin/bf16-lightweight/resolve/main/consolidated.00.pth?download=true")
# pytorch tar format
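
The removed skip referenced loading bfloat16 weights by viewing them as integer bits. As background, bfloat16 is the upper 16 bits of a float32, so widening is a shift into the high half; a small numpy sketch of that idea (illustrative, not the loader's actual code):

import numpy as np

def bf16_bits_to_float32(bits: np.ndarray) -> np.ndarray:
    # place the 16 bfloat16 bits into the top half of a uint32, then reinterpret as float32
    return (bits.astype(np.uint32) << 16).view(np.float32)

print(bf16_bits_to_float32(np.array([0x3F80, 0x4000, 0xC040], dtype=np.uint16)))  # [ 1.  2. -3.]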

View File

@@ -485,7 +485,7 @@ class Linearizer(Kernel):
if x.op in ReduceOps and not do_reduce:
assert offs is None, "not available if we aren't doing reduce"
return acc
# MULACC fusion. TODO: this is copied from Interpreted
# MULACC fusion.
if x.op == ReduceOps.SUM:
if x.src[0].op == BinaryOps.MUL: x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg)
if (castop:=x.src[0]).op == UnaryOps.CAST and (mulop:=castop.src[0]).op == BinaryOps.MUL:
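
The fusion kept here (only the comment's attribution to Interpreted is dropped) rewrites a SUM whose source is a MUL into a single MULACC, and hoists a CAST between them so the accumulation happens in the cast dtype. A toy, self-contained version of the same rewrite, with strings standing in for tinygrad's op enums:

from collections import namedtuple

Op = namedtuple("Op", ["op", "src", "arg"], defaults=((), None))

def fuse_mulacc(x):
    if x.op != "SUM": return x
    if x.src[0].op == "MUL":
        # SUM(MUL(a, b)) -> MULACC(a, b)
        return Op("MULACC", x.src[0].src, x.arg)
    if x.src[0].op == "CAST" and x.src[0].src[0].op == "MUL":
        # SUM(CAST(MUL(a, b))) -> MULACC(CAST(a), CAST(b))
        cast, mul = x.src[0], x.src[0].src[0]
        return Op("MULACC", tuple(Op("CAST", (s,), cast.arg) for s in mul.src), x.arg)
    return x

a, b = Op("LOAD", arg="a"), Op("LOAD", arg="b")
print(fuse_mulacc(Op("SUM", (Op("MUL", (a, b)),), (1,))))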

View File

@@ -1,12 +1,11 @@
from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING, Union, Any, List, Optional, Dict, Callable, Tuple, cast, ClassVar
import importlib, inspect, functools, pathlib, time, re, ctypes
from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar
import importlib, inspect, functools, pathlib, time, ctypes
from tinygrad.dtype import DType, ImageDType
from tinygrad.helpers import ansilen, DEBUG, getenv, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, sym_infer, sint
from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, GlobalCounters, MovementOps
from tinygrad.ops import LazyOp, get_lazyop_info, GlobalCounters
from dataclasses import dataclass
if TYPE_CHECKING:
@@ -21,9 +20,9 @@ class _Device:
def _canonicalize(self, device:str) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") # noqa: E501
# NOTE: you can't cache canonicalize in case Device.DEFAULT changes
def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device) if device is not None else Device.DEFAULT
def __getitem__(self, ix:str) -> Union[Interpreted, Compiled]: return self.__get_canonicalized_item(self.canonicalize(ix))
def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def __get_canonicalized_item(self, ix:str) -> Union[Interpreted, Compiled]:
def __get_canonicalized_item(self, ix:str) -> Compiled:
x = ix.split(":")[0].upper()
return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
@functools.cached_property
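
Aside on the hunk above: Device.__getitem__ now always resolves to a Compiled instance, and canonicalize is what maps user-facing device strings onto backend names. A standalone re-implementation of just that string logic, to show the behavior (not the _Device class itself):

def canonicalize(device: str) -> str:
    # upper-case the backend name, keep any ":N" suffix, and treat ":0" as the default instance
    head, sep, tail = device.partition(":")
    return (head.upper() + (sep + tail if sep else "")).replace(":0", "")

assert canonicalize("clang") == "CLANG"
assert canonicalize("gpu:0") == "GPU"      # ":0" is dropped
assert canonicalize("cuda:1") == "CUDA:1"  # other instance suffixes are kept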
@@ -168,98 +167,6 @@ class _MallocAllocator(LRUAllocator):
def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
MallocAllocator = _MallocAllocator()
# **************** for Interpreted Devices ****************
class InterpretedASTRunner(JITRunner):
def __init__(self, ast:LazyOp, fxn:Callable):
super().__init__()
self.fxn = fxn
info = get_lazyop_info(ast)
self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> float:
st = time.perf_counter()
rawbufs[0]._buf = self.fxn([x._buf for x in rawbufs[1:]], var_vals)
et = time.perf_counter() - st
update_stats(f"<interpreted {rawbufs[0].size}>", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, device=rawbufs[0].device)
return et
class Interpreted:
def __init__(self, device:str, allocator: Allocator, fxn_for_op:Dict[Op, Callable]):
self.dname, self.allocator, self.fxn_for_op = device, allocator, fxn_for_op
self.synchronize, self.codegen, self.graph = lambda: None, None, None
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def get_runner(self, ast:LazyOp) -> InterpretedASTRunner: return _get_interpreted_fxn(self.fxn_for_op, ast)
def _get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> InterpretedASTRunner:
if DEBUG >= 3:
from tinygrad.features.graph import print_tree
print_tree(ast)
tglob: Dict[str, Any] = {"Variable": Variable}
@functools.lru_cache(None)
def gstr(x:Any, nm=None) -> str:
if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
# TODO: (Variable - Variable) might create NumNode. can we remove it?
return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
tglob[ret] = x
return ret
lines: List[str] = []
@functools.lru_cache(None)
def _interpret_ast(ast:LazyOp) -> str:
# TODO: shortcutted store won't work with strides
if ast.op == BufferOps.STORE: return _interpret_ast(ast.src[0])
if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM:
if ast.src[0].op == BinaryOps.MUL: ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
if (castop:=ast.src[0]).op == UnaryOps.CAST and (mulop:=castop.src[0]).op == BinaryOps.MUL:
# MULACC with acc cast rewrite: MUL -> CAST -> SUM => CAST -> MULACC
ast = LazyOp(TernaryOps.MULACC, tuple(LazyOp(UnaryOps.CAST, (s, ), castop.arg) for s in mulop.src), ast.arg)
if ast.op in BufferOps:
if ast.op == BufferOps.CONST: tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})"
else: tmp = f"{gstr(fxn_for_op[UnaryOps.CAST], UnaryOps.CAST)}(inputs[{ast.arg.idx-1}], ({gstr(ast.arg.dtype)}, True))"
# convert ShapeTracker to MovementOps
to_apply:List[Tuple[MovementOps, Tuple]] = []
for v in cast(ShapeTracker, ast.arg.st).views:
real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
real_offset = 0 if 0 in real_shape else (v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0))
# first, we apply the offset
# then, we make it the correct shape
# then, we apply permutations
to_apply.append((MovementOps.AS_STRIDED, (tuple([s if st != 0 else 1 for s,st in zip(real_shape, v.strides)]), v.strides, real_offset)))
# then, we apply pre expand pads
if v.mask is not None:
pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
if any(x != (0,0) for x in pre_expand_pads):
to_apply.append((MovementOps.PAD, pre_expand_pads))
real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads))
# then, we do any expands
# NOTE: this is a good idea even without masks, since torch doesn't support negative strides and has to make a copy
if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): to_apply.append((MovementOps.EXPAND, real_shape))
# lastly, we apply post expand pads
if v.mask is not None and any(x != (0,0) for x in post_expand_pads): to_apply.append((MovementOps.PAD, post_expand_pads))
# apply those MovementOps
for mop,arg in to_apply: tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})"
else:
tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join([_interpret_ast(src) for src in ast.src] + ([gstr(ast.arg)] if ast.arg else []))})"
ret = f"a{len(lines)}"
lines.append(f" {ret} = {tmp}")
return ret
ret = _interpret_ast(ast)
src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {ret}"])
if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
return InterpretedASTRunner(ast, tglob['run'])
# **************** for Compiled Devices ****************
class Compiler:
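
For readers wondering what the deleted path did: _get_interpreted_fxn stitched together Python source, one assignment per AST node, then exec'd it once to get a run(inputs, var_vals) function. That pattern, reduced to a self-contained sketch with numpy calls standing in for the fxn_for_op table:

import numpy as np

def build_runner(lines, result):
    # emit a tiny Python function whose body is the per-node assignments, then exec it once
    src = "\n".join(["def run(inputs, var_vals):"] + [f"  {ln}" for ln in lines] + [f"  return {result}"])
    glbls = {"np": np}
    exec(compile(src, "<ast>", "exec"), glbls)  # pylint: disable=exec-used
    return glbls["run"]

run = build_runner(["a0 = np.exp2(inputs[0])", "a1 = a0 + inputs[1]"], "a1")
print(run([np.ones(3), np.ones(3)], {}))  # [3. 3. 3.]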

View File

@@ -112,7 +112,7 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
while not exiting:
acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
timed_lins: List[Tuple[Linearizer, float]] = []
_compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=cast(Compiled, Device[lin.opts.device]).compiler)
_compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=Device[lin.opts.device].compiler)
for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
if proc is None: continue
lib, global_size, local_size = proc