JIT support in Interpreted (#2314)

* factor that out

* jit is supported everywhere

* fix some tests

* there's no jit supported device, the jit is everywhere

* fix test uops
Author: George Hotz
Date: 2023-11-15 11:13:38 -08:00
Committed by: GitHub
Parent: 9a20bc08d6
Commit: 70a65c201e
16 changed files with 181 additions and 158 deletions

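The diff below (tinygrad's numpy CPU backend, presumably tinygrad/runtime/ops_cpu.py) shows the mechanical side of the change: the standalone interpret_ast helper and its functools.partial plumbing disappear, because that logic now lives inside Interpreted itself. From the user side, the point of the PR is that TinyJit no longer needs a "jit supported device". A rough sketch of that, not taken from this diff and assuming TinyJit is importable from tinygrad.jit in this version of the codebase:

# Rough user-side sketch of what "JIT support in Interpreted" means; run with
# CPU=1 (or TORCH=1) in the environment to force an interpreted backend.
# Assumes TinyJit lives at tinygrad.jit here; the import path and the
# convention of returning realized tensors may differ in other versions.
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x * 2 + 1).realize()  # jitted functions return realized tensors

x = Tensor(np.arange(4, dtype=np.float32)).realize()  # inputs should be realized too
for _ in range(3):  # early calls record the graph, later calls replay it
  print(step(x).numpy())
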
@@ -1,10 +1,9 @@
 import numpy as np
-import operator, functools
+import operator
 from typing import Callable, Dict, Tuple, Optional
 from tinygrad.helpers import dtypes, DType
 from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op, Interpreted
 from tinygrad.runtime.lib import RawBuffer
-from tinygrad.runtime.interpreted import interpret_ast
 
 def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
   assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
@@ -52,4 +51,4 @@ class RawNumpyBuffer(RawBuffer):
   @classmethod
   def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x)
   def toCPU(self): return self._buf
-CPUBuffer = Interpreted(RawNumpyBuffer, functools.partial(interpret_ast, numpy_fxn_for_op, RawNumpyBuffer.fromCPU))
+CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op, RawNumpyBuffer.fromCPU)
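The second hunk shows the new Interpreted signature: the backend now hands over its raw buffer class, its op-to-function table, and a fromCPU converter, and Interpreted does the AST evaluation itself, presumably so the JIT can treat interpreted devices the same way as compiled ones. A simplified, self-contained illustration of that op-table-plus-generic-interpreter pattern (BinOp, ASTNode, and interpret are illustrative names, not tinygrad's actual classes):

# Simplified sketch of the "op table + generic interpreter" pattern used by
# Interpreted backends; the names here are illustrative only.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Callable, Dict, Tuple
import numpy as np

class BinOp(Enum):
  ADD = auto()
  MUL = auto()
  MAX = auto()

# the backend only has to supply this table (analogous to numpy_fxn_for_op)...
numpy_fxn: Dict[BinOp, Callable] = {BinOp.ADD: np.add, BinOp.MUL: np.multiply, BinOp.MAX: np.maximum}

@dataclass(frozen=True)
class ASTNode:
  op: BinOp
  srcs: Tuple[Any, ...]  # each source is another ASTNode or a concrete buffer

# ...while the generic interpreter (analogous to what this PR folds into Interpreted)
# recursively evaluates the AST against that table
def interpret(node: Any, table: Dict[BinOp, Callable]) -> Any:
  if not isinstance(node, ASTNode): return node  # leaf: already a concrete buffer
  return table[node.op](*[interpret(s, table) for s in node.srcs])

a, b = np.arange(4.0), np.ones(4)
ast = ASTNode(BinOp.MUL, (ASTNode(BinOp.ADD, (a, b)), 2.0))
print(interpret(ast, numpy_fxn))  # -> [2. 4. 6. 8.]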