diff --git a/docs/abstractions.py b/docs/abstractions.py
index ba230b8066..5982963fa9 100644
--- a/docs/abstractions.py
+++ b/docs/abstractions.py
@@ -278,7 +278,7 @@ result = Tensor([2]) + Tensor([3])
 
 # we have a global cache used by the JIT
 # from there, we can see the generated clang code
-from tinygrad.ops import GlobalCounters
+from tinygrad.helpers import GlobalCounters
 GlobalCounters.cache = []  # enables the cache
 result.realize()  # create the program and runs it
 cache_saved = GlobalCounters.cache
diff --git a/test/unit/test_flopcounter.py b/test/unit/test_flopcounter.py
index 6ad0f5b397..bb0e123b0f 100644
--- a/test/unit/test_flopcounter.py
+++ b/test/unit/test_flopcounter.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 import unittest
-from tinygrad.ops import InterpretedBuffer, GenericShape, LazyOp, BinaryOps, get_lazyop_info
+from tinygrad.ops import LazyOp, BinaryOps
+from tinygrad.interpreted import get_lazyop_info, InterpretedBuffer, GenericShape
 
 class TestFlopCounter(unittest.TestCase):
   def test_flops_add(self):
@@ -34,6 +35,6 @@ class TestFlopCounter(unittest.TestCase):
     op2 = LazyOp(BinaryOps.ADD, (op0,op1,), None)
     info = get_lazyop_info(op2)
     self.assertEqual(info.flops, 12)
-    
+
 if __name__ == '__main__':
   unittest.main()
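
For reference, a minimal sketch of driving the relocated flop counter directly with the new import path. This is not part of the patch, and the buffers here are built straight from GenericShape, which may differ from how the test's setup constructs them:

    from tinygrad.ops import LazyOp, BinaryOps
    from tinygrad.interpreted import get_lazyop_info, InterpretedBuffer, GenericShape
    from tinygrad.helpers import dtypes

    # a fake 4-element float32 buffer; GenericShape only tracks shape/dtype/flops
    buf = InterpretedBuffer(GenericShape((4,), dtypes.float32))
    op = LazyOp(BinaryOps.ADD, (buf, buf), None)
    print(get_lazyop_info(op).flops)  # 4, one add per output element
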
diff --git a/test/unit/test_graph.py b/test/unit/test_graph.py
index ca480efbf9..a2e09f57e7 100644
--- a/test/unit/test_graph.py
+++ b/test/unit/test_graph.py
@@ -14,8 +14,8 @@ class TestGraph(unittest.TestCase):
     assert nx.is_isomorphic(G, RG, node_match=lambda x,y: x["label"] == y["label"], edge_match=lambda x,y: x["label"] == y["label"] if "label" in y else True)
 
   def test_add_graph(self):
-    a = CPUBuffer.fromCPU(np.ones((4,4), dtype=np.float32))
-    b = CPUBuffer.fromCPU(np.ones((4,4), dtype=np.float32))
+    a = CPUBuffer(np.ones((4,4), dtype=np.float32))
+    b = CPUBuffer(np.ones((4,4), dtype=np.float32))
     ast = LazyOp(BinaryOps.ADD, (a,b))
     ret = CPUBuffer(np.ones((4,4), dtype=np.float32))
 
@@ -30,8 +30,8 @@ class TestGraph(unittest.TestCase):
     self.helper_compare_graph(RG)
 
   def test_add_sum_graph(self):
-    a = CPUBuffer.fromCPU(np.ones((4,4), dtype=np.float32))
-    b = CPUBuffer.fromCPU(np.ones((1,1), dtype=np.float32))
+    a = CPUBuffer(np.ones((4,4), dtype=np.float32))
+    b = CPUBuffer(np.ones((1,1), dtype=np.float32))
     op0 = LazyOp(MovementOps.RESHAPE, (b,), (4, 4))
     op1 = LazyOp(BinaryOps.ADD, (a,op0))
     ast = LazyOp(ReduceOps.SUM, (op1,), (1,1))
@@ -48,12 +48,12 @@ class TestGraph(unittest.TestCase):
     self.helper_compare_graph(RG)
 
   def test_add_graph_prune(self):
-    a = CPUBuffer.fromCPU(np.ones((1,1), dtype=np.float32))
+    a = CPUBuffer(np.ones((1,1), dtype=np.float32))
     ast = LazyOp(MovementOps.RESHAPE, (a,), (4, 4))
     ret = CPUBuffer(np.ones((4,4), dtype=np.float32))
     log_op(ret, ast, show_graph=True)
 
-    b = CPUBuffer.fromCPU(np.ones((4,4), dtype=np.float32))
+    b = CPUBuffer(np.ones((4,4), dtype=np.float32))
     ast = LazyOp(BinaryOps.ADD, (ret,b))
     ret = CPUBuffer(np.ones((4,4), dtype=np.float32))
     log_op(ret, ast, show_graph=True)
diff --git a/tinygrad/codegen/ast.py b/tinygrad/codegen/ast.py
index 8902b28ab0..7728bf98ae 100644
--- a/tinygrad/codegen/ast.py
+++ b/tinygrad/codegen/ast.py
@@ -2,8 +2,9 @@ import itertools
 from enum import Enum, auto
 from typing import List, Tuple
 from tinygrad.helpers import prod, dedup, all_same, colored, DType
-from tinygrad.ops import LazyOp, MovementOps, get_lazyop_info, get_buffers, ReduceOps, get_lazyops, map_buffers, GenericShape, ASTRunner
+from tinygrad.ops import LazyOp, MovementOps, get_buffers, ReduceOps, get_lazyops, map_buffers, ASTRunner
 from tinygrad.shape.shapetracker import ShapeTracker, View, strides_for_shape
+from tinygrad.interpreted import get_lazyop_info, GenericShape
 
 def get_first_reduce(shapes):
   for i in range(len(shapes[0])):
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 20c58278fe..3a9c9717e5 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -1,7 +1,8 @@
 import os, math, functools
 import numpy as np
-from typing import Tuple, Union, List, NamedTuple, Final, Iterator
+from typing import Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any
 ShapeType = Tuple[int, ...]
+# NOTE: helpers is not allowed to import from anything else in tinygrad
 
 def dedup(x): return list(dict.fromkeys(x))  # retains list order
 def prod(x:Union[List[int], Tuple[int, ...]]) -> int: return math.prod(x)
@@ -39,3 +40,13 @@ class dtypes:
   float32: Final[DType] = DType(4, "float", np.float32)
   @staticmethod
   def from_np(x:Union[LazyNumpyArray, np.ndarray]) -> DType: return {np.dtype(np.float16): dtypes.float16, np.dtype(np.float32): dtypes.float32}[np.dtype(x.dtype)]
+
+class GlobalCounters:
+  global_ops: ClassVar[int] = 0
+  global_mem: ClassVar[int] = 0
+  time_sum_s: ClassVar[float] = 0.0
+  kernel_count: ClassVar[int] = 0
+  mem_used: ClassVar[int] = 0  # NOTE: this is not reset
+  cache: ClassVar[Optional[List[Tuple[Callable, Any]]]] = None
+  @staticmethod
+  def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0.0,0,None
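
A small usage sketch for the relocated GlobalCounters (illustrative, not part of the patch; which counters a given backend actually updates is unchanged here):

    from tinygrad.helpers import GlobalCounters
    from tinygrad.tensor import Tensor

    GlobalCounters.reset()            # zeroes ops/mem/time/kernel_count and clears the cache
    (Tensor([2.0]) + Tensor([3.0])).realize()
    # read back whatever the active backend recorded
    print(GlobalCounters.kernel_count, GlobalCounters.global_ops, GlobalCounters.mem_used)
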
diff --git a/tinygrad/interpreted.py b/tinygrad/interpreted.py
new file mode 100644
index 0000000000..2c9a5fb8ca
--- /dev/null
+++ b/tinygrad/interpreted.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+from typing import Tuple, Any, ClassVar, Optional, Callable, Dict
+import functools
+from tinygrad.helpers import DType, dtypes, prod, GlobalCounters, DEBUG
+from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.ops import DeviceBuffer, LazyOp, get_buffers, map_buffers, Op, FusedOps, UnaryOps, MovementOps, ReduceOps, BinaryOps
+
+# this is a quick "buffer" class for flop tracking and getting the output shape
+class GenericShape:
+  def __init__(self, shape:Tuple[int, ...], dtype:DType=dtypes.float32, flops:int=0): self.shape, self.dtype, self.flops = shape, dtype, flops
+  def consume_flops(self):
+    self.flops, ret = 0, self.flops
+    return ret
+shape_fxn_for_op: Dict[Op, Callable] = {
+  **{op:lambda self: GenericShape(self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps},
+  **{op:lambda self,y: GenericShape(self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},
+  **{op:lambda self,new_shape: GenericShape(new_shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in ReduceOps},
+  **{op:functools.partial(lambda mop,self,arg: GenericShape(ShapeTracker(self.shape).movement_op(mop, arg).shape, self.dtype, self.consume_flops()), op) for op in MovementOps}}
+
+# this runs the LazyOp and gives you the output shape/dtype and flop count
+def get_lazyop_info(ast:LazyOp) -> GenericShape: return InterpretedBuffer.exec_ast(map_buffers({x:InterpretedBuffer(GenericShape(x.shape, x.dtype)) for x in get_buffers(ast)}, ast))._buf
+
+# used in CPUBuffer and TorchBuffer
+class InterpretedBuffer(DeviceBuffer):  # pylint: disable=abstract-method
+  fxn_for_op: ClassVar = shape_fxn_for_op
+  def __init__(self, lbuf:Any):
+    self._buf: Any = lbuf
+    self.shape: Tuple[int, ...] = tuple(lbuf.shape)
+    self.dtype: DType = self.to_tinygrad_dtype() if hasattr(self, 'to_tinygrad_dtype') else lbuf.dtype
+    # NOTE: this is overcounting the memory used, as reshapes and stuff are aliases
+    self._memsz = (prod(self.shape) * self.dtype.itemsize) if not isinstance(lbuf, GenericShape) else 0
+    GlobalCounters.mem_used += self._memsz
+  def __del__(self): GlobalCounters.mem_used -= self._memsz
+  def contiguous(self): return type(self).exec_ast(LazyOp(op=UnaryOps.NOOP, src=(self,)))
+  def movement_op(self, op:MovementOps, arg=None): return type(self)(self.fxn_for_op[op](self._buf, arg)) if op in self.fxn_for_op else type(self)(getattr(self._buf, op.name.lower())(arg))
+  @classmethod
+  def exec_ast(cls, ast:LazyOp, output_buffer:Optional[InterpretedBuffer]=None, context=None):
+    if FusedOps.MULACC in cls.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
+      ast = LazyOp(FusedOps.MULACC, ast.src[0].src, ast.arg)
+    created_context = context is None
+    if context is None: context = dict()
+    if not created_context and ast in context: return context[ast]
+    srcs = [cls.exec_ast(x, context=context) if isinstance(x, LazyOp) else x for x in ast.src]
+    if ast.op in BinaryOps: assert srcs[0].shape == srcs[1].shape, f"BinaryOps shape mismatch {srcs[0].shape} != {srcs[1].shape}"
+    if ast.op in ReduceOps: assert all(r == n or n == 1 for r,n in zip(srcs[0].shape, ast.arg)), f"ReduceOps can't reduce {srcs[0].shape} -> {ast.arg}"
+    if ast.op in MovementOps: ret = srcs[0].movement_op(ast.op, ast.arg)
+    else: ret = cls(cls.fxn_for_op[ast.op](*([x._buf for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
+    if DEBUG >= 4 or (not isinstance(ret._buf, GenericShape) and DEBUG >= 3):
+      print(f"*** {'exec' if created_context else '    '} {GlobalCounters.mem_used/1e9:5.2f} GB op: {ast.op:20s} out({ret.dtype.name}): {str(ret.shape):30s} in({len(srcs)}):", list(set(x.shape for x in srcs)), ast.arg if ast.arg is not None else "")
+    if not created_context: context[ast] = ret
+    if output_buffer is not None:
+      assert output_buffer.shape == ret.shape, output_buffer.dtype == ret.dtype
+      output_buffer._buf = ret._buf
+      return output_buffer
+    else:
+      return ret
\ No newline at end of file
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 41c338ca37..9cbe484c6f 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -109,8 +109,8 @@ class LazyBuffer:
     if self.realized is None:
       # get real ops first
       if self.op.op == LoadOps.FROMCPU:
-        self.realized = Device[self.device].fromCPU(self.op.arg() if isinstance(self.op.arg, LazyNumpyArray) else self.op.arg)
-        ast = LazyOp(self.op.op, tuple())
+        # resolve LazyNumpyArray
+        ast = LazyOp(self.op.op, tuple(), self.op.arg() if isinstance(self.op.arg, LazyNumpyArray) else self.op.arg)
       elif self.op.op == LoadOps.CONTIGUOUS:
         real_src = self.op.src[0].realize(self.device)
         self.realized = real_src.contiguous()
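
With this change, realize() no longer calls a backend fromCPU; the numpy array rides along as the arg of a LoadOps.FROMCPU LazyOp, and each backend's exec_ast consumes it (interpreted backends through their fxn_for_op table, CompiledBuffer through the early return added in tinygrad/ops.py below). A sketch of the interpreted path, not part of the patch:

    import numpy as np
    from tinygrad.ops import LazyOp, LoadOps
    from tinygrad.runtime.ops_cpu import CPUBuffer

    # FROMCPU is now just another op: no srcs, the host array is the arg
    ast = LazyOp(LoadOps.FROMCPU, tuple(), np.ones((2, 2), dtype=np.float32))
    buf = CPUBuffer.exec_ast(ast)     # numpy_fxn_for_op[LoadOps.FROMCPU] is the identity
    print(buf.shape, buf.toCPU())
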
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 073e8a9688..1c21b341eb 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -1,10 +1,11 @@
 from __future__ import annotations
-import functools, itertools, operator, random, ctypes
+import functools, itertools, operator, random
 import numpy as np
 from enum import Enum, auto
-from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar, Optional, Callable, Dict, TypeVar, Set, Final
-from tinygrad.helpers import prod, DEBUG, getenv, DType, dtypes
+from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar, Optional, Dict, Set, Final
+from tinygrad.helpers import prod, DEBUG, getenv, DType, dtypes, GlobalCounters
 from tinygrad.shape.shapetracker import ShapeTracker, MovementOps
+from tinygrad.runtime.lib import RawBuffer
 
 # these are the llops your accelerator must implement, along with toCpu
 # the Enum class doesn't work with mypy, this is static. sorry it's ugly
@@ -28,107 +29,17 @@ class LazyOp(NamedTuple):
 def get_buffers(op:LazyOp) -> List[Any]: return functools.reduce(operator.add, [get_buffers(x) if isinstance(x, LazyOp) else [x] for x in op.src], [])
 def get_lazyops(op:LazyOp) -> List[LazyOp]: return functools.reduce(operator.add, [get_lazyops(x) for x in op.src if isinstance(x, LazyOp)], [op])
 def map_buffers(real_srcs:Dict[Any, Any], x:Any) -> LazyOp:
-  if x in real_srcs: return map_buffers(real_srcs, real_srcs[x]) if isinstance(real_srcs[x], LazyOp) else real_srcs[x]
+  if len(real_srcs) and x in real_srcs: return map_buffers(real_srcs, real_srcs[x]) if isinstance(real_srcs[x], LazyOp) else real_srcs[x]
   return LazyOp(x.op, tuple((map_buffers(real_srcs, y) if isinstance(y, LazyOp) else real_srcs[y]) for y in x.src), x.arg)
 
-_T = TypeVar("_T")
-class Copyable:
-  @classmethod
-  def fromCPU(cls:Type[_T], x:np.ndarray) -> _T: raise NotImplementedError("must be implemented")
-  def toCPU(self:Copyable) -> np.ndarray: raise NotImplementedError("must be implemented")
-
-class RawBuffer(Copyable):  # pylint: disable=abstract-method
-  def __init__(self, size:int, dtype:DType):
-    self.size: int = size
-    self.dtype: DType = dtype
-    self._memsz: int = size*dtype.itemsize
-    GlobalCounters.mem_used += self._memsz
-  def __del__(self): GlobalCounters.mem_used -= self._memsz
-
-class RawBufferCopyIn(RawBuffer):
-  def copyin(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
-
-  @classmethod
-  def fromCPU(cls, x:np.ndarray):
-    ret = cls(prod(x.shape), dtypes.from_np(x))
-    ret.copyin(x)
-    return ret
-
-class RawBufferMapped(RawBufferCopyIn):
-  def _buffer(self) -> memoryview: raise NotImplementedError("must be implemented")
-  def toCPU(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=self.dtype.np)
-  def copyin(self, x:np.ndarray) -> None: np.copyto(self.toCPU(), x.reshape(-1))
-
-# this one is simple enough that i moved it out of the runtimes
-class RawMallocBuffer(RawBufferMapped):
-  def __init__(self, size, dtype: DType):
-    super().__init__(size, dtype)
-    self._buf = ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16}[dtype] * size)()
-  def _buffer(self): return memoryview(self._buf)
-
-class RawBufferCopyInOut(RawBufferCopyIn):
-  def copyout(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
-
-  def toCPU(self) -> np.ndarray:
-    x: np.ndarray = np.empty(self.size, dtype=self.dtype.np)
-    self.copyout(x)
-    return x
-
 # a placeholder class to extend by the exec classes
-class DeviceBuffer(Copyable):
+class DeviceBuffer:
   _buf: Any  # underlying buffer
   shape: Tuple[int, ...]
   dtype: DType
   @classmethod
   def exec_ast(cls, ast:LazyOp, output_buffer=None): raise NotImplementedError("must be implemented")
-
-# this is a quick "buffer" class for flop tracking and getting the output shape
-class GenericShape:
-  def __init__(self, shape:Tuple[int, ...], dtype:DType=dtypes.float32, flops:int=0): self.shape, self.dtype, self.flops = shape, dtype, flops
-  def consume_flops(self):
-    self.flops, ret = 0, self.flops
-    return ret
-shape_fxn_for_op: Dict[Op, Callable] = {
-  **{op:lambda self: GenericShape(self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps},
-  **{op:lambda self,y: GenericShape(self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},
-  **{op:lambda self,new_shape: GenericShape(new_shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in ReduceOps},
-  **{op:functools.partial(lambda mop,self,arg: GenericShape(ShapeTracker(self.shape).movement_op(mop, arg).shape, self.dtype, self.consume_flops()), op) for op in MovementOps}}
-def get_lazyop_info(ast:LazyOp) -> GenericShape: return InterpretedBuffer.exec_ast(map_buffers({x:InterpretedBuffer(GenericShape(x.shape, x.dtype)) for x in get_buffers(ast)}, ast))._buf
-
-# used in CPUBuffer and TorchBuffer
-class InterpretedBuffer(DeviceBuffer):  # pylint: disable=abstract-method
-  fxn_for_op: ClassVar = shape_fxn_for_op
-  def __init__(self, lbuf:Any):
-    self._buf: Any = lbuf
-    self.shape: Tuple[int, ...] = tuple(lbuf.shape)
-    self.dtype: DType = self.to_tinygrad_dtype() if hasattr(self, 'to_tinygrad_dtype') else lbuf.dtype
-    # NOTE: this is overcounting the memory used, as reshapes and stuff are aliases
-    self._memsz = (prod(self.shape) * self.dtype.itemsize) if not isinstance(lbuf, GenericShape) else 0
-    GlobalCounters.mem_used += self._memsz
-  def __del__(self): GlobalCounters.mem_used -= self._memsz
-  def contiguous(self): return type(self).exec_ast(LazyOp(op=UnaryOps.NOOP, src=(self,)))
-  def movement_op(self, op:MovementOps, arg=None): return type(self)(self.fxn_for_op[op](self._buf, arg)) if op in self.fxn_for_op else type(self)(getattr(self._buf, op.name.lower())(arg))
-  @classmethod
-  def exec_ast(cls, ast:LazyOp, output_buffer:Optional[InterpretedBuffer]=None, context=None):
-    if FusedOps.MULACC in cls.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
-      ast = LazyOp(FusedOps.MULACC, ast.src[0].src, ast.arg)
-    created_context = context is None
-    if context is None: context = dict()
-    if ast in context: return context[ast]
-    srcs = [cls.exec_ast(x, context=context) if isinstance(x, LazyOp) else x for x in ast.src]
-    if ast.op in BinaryOps: assert srcs[0].shape == srcs[1].shape, f"BinaryOps shape mismatch {srcs[0].shape} != {srcs[1].shape}"
-    if ast.op in ReduceOps: assert all(r == n or n == 1 for r,n in zip(srcs[0].shape, ast.arg)), f"ReduceOps can't reduce {srcs[0].shape} -> {ast.arg}"
-    if ast.op in MovementOps: ret = srcs[0].movement_op(ast.op, ast.arg)
-    else: ret = cls(cls.fxn_for_op[ast.op](*([x._buf for x in srcs] + ([ast.arg] if ast.arg else []))))
-    if DEBUG >= 4 or (not isinstance(srcs[0]._buf, GenericShape) and DEBUG >= 3):
-      print(f"*** {'exec' if created_context else '    '} {GlobalCounters.mem_used/1e9:5.2f} GB op: {ast.op:20s} out({ret.dtype.name}): {str(ret.shape):30s} in({len(srcs)}):", list(set(x.shape for x in srcs)), ast.arg if ast.arg is not None else "")
-    context[ast] = ret
-    if output_buffer is not None:
-      assert output_buffer.shape == ret.shape, output_buffer.dtype == ret.dtype
-      output_buffer._buf = ret._buf
-      return output_buffer
-    else:
-      return ret
+  def toCPU(self) -> np.ndarray: raise NotImplementedError("must be implemented")
 
 class ASTRunner:
   def __init__(self, name, prg, bufs_to_delete:Optional[Set[int]]=None, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0):
@@ -207,8 +118,6 @@ class CompiledBuffer(DeviceBuffer):  # pylint: disable=abstract-method
       self._backing = None
     return self._buf
 
-  @classmethod
-  def fromCPU(cls, x:np.ndarray) -> CompiledBuffer: return cls(x.shape, backing=x.ravel(), dtype=dtypes.from_np(x))
   def toCPU(self) -> np.ndarray:
     assert GlobalCounters.cache is None, f"can't copy out {self} while caching"
     if DEBUG >= 3: print(f"**** copy out {self.shape}")
@@ -217,6 +126,7 @@ class CompiledBuffer(DeviceBuffer):  # pylint: disable=abstract-method
   method_cache: Final[Dict[str, ASTRunner]] = {}
   @classmethod
   def exec_ast(cls, ast:LazyOp, output_buffer:Optional[CompiledBuffer]=None):
+    if ast.op == LoadOps.FROMCPU: return cls(ast.arg.shape, backing=ast.arg.ravel(), dtype=dtypes.from_np(ast.arg))
     k = cls.spec.codegen(ast, output_buffer)
     if getenv("ENABLE_METHOD_CACHE", 1):  # this is the default now
       if k.key not in cls.method_cache: cls.method_cache[k.key] = k.codegen().build(cls.spec.runtime)
@@ -234,12 +144,3 @@ class CompiledBuffer(DeviceBuffer):  # pylint: disable=abstract-method
 
   def contiguous(self): return self if self.st.contiguous and prod(self._base_shape) == prod(self.shape) else type(self).exec_ast(LazyOp(op=UnaryOps.NOOP, src=(self,)))
   def movement_op(self, op:MovementOps, arg): return type(self)(ShapeTracker(self.st).movement_op(op, arg), hostbuf=self, dtype=self.dtype)
-
-class GlobalCounters:
-  global_ops: ClassVar[int] = 0
-  global_mem: ClassVar[int] = 0
-  time_sum_s: ClassVar[float] = 0.0
-  kernel_count: ClassVar[int] = 0
-  mem_used: ClassVar[int] = 0  # NOTE: this is not reset
-  cache: ClassVar[Optional[List[Tuple[Callable, Any]]]] = None
-  @staticmethod
-  def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0.0,0,None
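
The DeviceBuffer contract is now just exec_ast plus toCPU; host-to-device transfer is no longer a separate classmethod. A hedged sketch of the surface area a backend stub implements under this interface (ToyBuffer is hypothetical, not part of the patch):

    import numpy as np
    from tinygrad.helpers import dtypes
    from tinygrad.ops import DeviceBuffer, LazyOp, LoadOps

    class ToyBuffer(DeviceBuffer):
      def __init__(self, buf:np.ndarray): self._buf, self.shape, self.dtype = buf, tuple(buf.shape), dtypes.from_np(buf)
      @classmethod
      def exec_ast(cls, ast:LazyOp, output_buffer=None):
        if ast.op == LoadOps.FROMCPU: return cls(ast.arg.copy())  # the host array now arrives as ast.arg
        raise NotImplementedError(f"ToyBuffer doesn't implement {ast.op}")
      def toCPU(self) -> np.ndarray: return self._buf
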
diff --git a/tinygrad/runtime/lib.py b/tinygrad/runtime/lib.py
new file mode 100644
index 0000000000..a4c4171322
--- /dev/null
+++ b/tinygrad/runtime/lib.py
@@ -0,0 +1,47 @@
+import ctypes
+import numpy as np
+from typing import TypeVar, Type
+from tinygrad.helpers import DType, dtypes, prod, GlobalCounters
+
+_T = TypeVar("_T")
+class RawBuffer:  # pylint: disable=abstract-method
+  def __init__(self, size:int, dtype:DType):
+    self.size: int = size
+    self.dtype: DType = dtype
+    self._memsz: int = size*dtype.itemsize
+    GlobalCounters.mem_used += self._memsz
+  def __del__(self): GlobalCounters.mem_used -= self._memsz
+
+  # NOTE: this interface allows for 0 copy
+  @classmethod
+  def fromCPU(cls:Type[_T], x:np.ndarray) -> _T: raise NotImplementedError("must be implemented")
+  def toCPU(self) -> np.ndarray: raise NotImplementedError("must be implemented")
+
+class RawBufferCopyIn(RawBuffer):
+  def _copyin(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
+
+  @classmethod
+  def fromCPU(cls, x:np.ndarray):
+    ret = cls(prod(x.shape), dtypes.from_np(x))
+    ret._copyin(x)
+    return ret
+
+class RawBufferMapped(RawBufferCopyIn):
+  def _buffer(self) -> memoryview: raise NotImplementedError("must be implemented")
+  def toCPU(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=self.dtype.np)
+  def _copyin(self, x:np.ndarray) -> None: np.copyto(self.toCPU(), x.reshape(-1))
+
+# this one is simple enough that i moved it out of the runtimes
+class RawMallocBuffer(RawBufferMapped):
+  def __init__(self, size, dtype: DType):
+    super().__init__(size, dtype)
+    self._buf = ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16}[dtype] * size)()
+  def _buffer(self): return memoryview(self._buf)
+
+class RawBufferCopyInOut(RawBufferCopyIn):
+  def _copyout(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
+
+  def toCPU(self) -> np.ndarray:
+    x: np.ndarray = np.empty(self.size, dtype=self.dtype.np)
+    self._copyout(x)
+    return x
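
A quick round-trip sketch of the new raw buffer interface (illustrative, not part of the patch): the public API stays fromCPU/toCPU, while subclasses only supply the private _copyin/_copyout/_buffer hooks.

    import numpy as np
    from tinygrad.runtime.lib import RawMallocBuffer

    x = np.arange(4, dtype=np.float32)
    buf = RawMallocBuffer.fromCPU(x)   # RawBufferCopyIn.fromCPU allocates, then calls _copyin
    print(buf.toCPU())                 # RawBufferMapped.toCPU views the malloc'd memory directly
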
diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py
index 429ef7c5d2..3ea8c4c04c 100644
--- a/tinygrad/runtime/ops_clang.py
+++ b/tinygrad/runtime/ops_clang.py
@@ -1,5 +1,6 @@
 import os, time, ctypes, hashlib, subprocess, platform
-from tinygrad.ops import CompiledBuffer, Specialized, RawMallocBuffer
+from tinygrad.ops import CompiledBuffer, Specialized
+from tinygrad.runtime.lib import RawMallocBuffer
 from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
 
 class ClangProgram:
@@ -13,6 +14,7 @@ class ClangProgram:
       os.rename(fn+".tmp", fn)
     self.lib = ctypes.CDLL(fn)
     self.fxn = self.lib[name]
+
   def __call__(self, global_size, local_size, *args, wait=False):
     if wait: st = time.monotonic()
     self.fxn(*[x._buf for x in args])
diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py
index e955b7b855..ec5d687b0a 100644
--- a/tinygrad/runtime/ops_cpu.py
+++ b/tinygrad/runtime/ops_cpu.py
@@ -2,7 +2,8 @@ import numpy as np
 import operator
 from typing import ClassVar, Callable, Dict, Tuple
 from tinygrad.helpers import dtypes
-from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, FusedOps, InterpretedBuffer, Op
+from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, FusedOps, LoadOps, Op
+from tinygrad.interpreted import InterpretedBuffer
 
 def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
   assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
@@ -32,12 +33,10 @@ numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
   MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
   MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
   FusedOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy()), lambda x: x.strides, np.broadcast_to),
+  LoadOps.FROMCPU: lambda arg: arg
 }}
 
 class CPUBuffer(InterpretedBuffer):
   fxn_for_op: ClassVar[Dict[Op, Callable]] = numpy_fxn_for_op
   def to_tinygrad_dtype(self): return dtypes.from_np(self._buf)
-
-  @staticmethod
-  def fromCPU(x): return CPUBuffer(x)
   def toCPU(self): return self._buf
diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index 20c7c5bf13..625c18cdf0 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -4,15 +4,16 @@ import pycuda.autoprimaryctx  # type: ignore # pylint: disable=unused-import # noqa: F401
 import pycuda.driver as cuda  # type: ignore
 from pycuda.compiler import compile as cuda_compile  # type: ignore
 from tinygrad.helpers import DEBUG
-from tinygrad.ops import CompiledBuffer, RawBufferCopyInOut, Specialized
+from tinygrad.ops import CompiledBuffer, Specialized
+from tinygrad.runtime.lib import RawBufferCopyInOut
 from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
 
 class RawCUDABuffer(RawBufferCopyInOut):
   def __init__(self, size, dtype):
     super().__init__(size, dtype)
     self._cl = cuda.mem_alloc(self._memsz)
-  def copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._cl, x, stream)
-  def copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._cl)
+  def _copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._cl, x, stream)
+  def _copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._cl)
 
 class CUDAProgram:
   def __init__(self, name:str, prg:str, binary=False):
diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py
index e86c3c20a9..0c1b5dd7ec 100644
--- a/tinygrad/runtime/ops_gpu.py
+++ b/tinygrad/runtime/ops_gpu.py
@@ -4,7 +4,8 @@ import numpy as np
 import pyopencl as cl  # type: ignore
 from typing import Optional, List, Final
 from tinygrad.helpers import IMAGE, DEBUG, getenv, dtypes
-from tinygrad.ops import CompiledBuffer, GlobalCounters, RawBufferCopyInOut, RawBuffer, Specialized
+from tinygrad.ops import CompiledBuffer, GlobalCounters, Specialized
+from tinygrad.runtime.lib import RawBufferCopyInOut, RawBuffer
 from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
 
 OSX = platform.system() == "Darwin"
@@ -24,8 +25,8 @@ class CLBuffer(RawBufferCopyInOut):
   def __init__(self, size, dtype):
     super().__init__(size, dtype)
     self._cl = cl.Buffer(CL.cl_ctx, cl.mem_flags.READ_WRITE, self._memsz)
-  def copyin(self, x:np.ndarray): cl.enqueue_copy(CL.cl_queue, self._cl, x, is_blocking=False)
-  def copyout(self, x:np.ndarray): cl.enqueue_copy(CL.cl_queue, x, self._cl, is_blocking=True)
+  def _copyin(self, x:np.ndarray): cl.enqueue_copy(CL.cl_queue, self._cl, x, is_blocking=False)
+  def _copyout(self, x:np.ndarray): cl.enqueue_copy(CL.cl_queue, x, self._cl, is_blocking=True)
 
 class CLImage(RawBuffer):  # pylint: disable=abstract-method
   IMAGE: Final = True
@@ -35,7 +36,6 @@ class CLImage(RawBuffer):  # pylint: disable=abstract-method
     GlobalCounters.mem_used += self._cl.row_pitch * self._cl.height
   def __del__(self): GlobalCounters.mem_used -= self._cl.row_pitch * self._cl.height
 
-#@functools.lru_cache(maxsize=None)
 class CLProgram:
   def __init__(self, name:str, prg:str, binary=False, argdtypes=None):
     self.name, self.argdtypes, self.clprogram = name, argdtypes, cl.Program(CL.cl_ctx, CL.cl_ctx.devices, [prg]) if binary else cl.Program(CL.cl_ctx, prg)  # type: ignore
diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py
index a357b1aacf..22404142f7 100644
--- a/tinygrad/runtime/ops_llvm.py
+++ b/tinygrad/runtime/ops_llvm.py
@@ -1,9 +1,10 @@
 import time, hashlib, ctypes
 from typing import ClassVar
-from tinygrad.ops import CompiledBuffer, Specialized, RawMallocBuffer
+from tinygrad.ops import CompiledBuffer, Specialized
 from tinygrad.helpers import getenv, DEBUG
 from ctypes import CFUNCTYPE
 from tinygrad.codegen.llvm import LLVMCodegen
+from tinygrad.runtime.lib import RawMallocBuffer
 
 import llvmlite.binding as llvm  # type: ignore
 
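
Across the runtimes the pattern is the same: the device buffer derives from one of the RawBuffer helpers and overrides only the private hooks, while fromCPU/toCPU stay inherited. An illustrative sketch, not part of the patch (MyRawBuffer and its numpy "backing store" are hypothetical stand-ins for a real device allocation):

    import numpy as np
    from tinygrad.runtime.lib import RawBufferCopyInOut

    class MyRawBuffer(RawBufferCopyInOut):
      def __init__(self, size, dtype):
        super().__init__(size, dtype)                    # records size/dtype, bumps GlobalCounters.mem_used
        self._backing = np.empty(size, dtype=dtype.np)   # stand-in for cuda.mem_alloc / cl.Buffer
      def _copyin(self, x:np.ndarray): np.copyto(self._backing, x.reshape(-1))
      def _copyout(self, x:np.ndarray): np.copyto(x, self._backing)

    buf = MyRawBuffer.fromCPU(np.arange(6, dtype=np.float32))
    assert (buf.toCPU() == np.arange(6, dtype=np.float32)).all()
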
diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py
index d048fd2f4b..2c96fca67e 100644
--- a/tinygrad/runtime/ops_metal.py
+++ b/tinygrad/runtime/ops_metal.py
@@ -4,7 +4,8 @@ import Metal, Cocoa, libdispatch  # type: ignore
 from typing import List, Any
 from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
 from tinygrad.helpers import prod, getenv, DEBUG, DType
-from tinygrad.ops import CompiledBuffer, RawBufferMapped, Specialized
+from tinygrad.ops import CompiledBuffer, Specialized
+from tinygrad.runtime.lib import RawBufferMapped
 
 METAL_XCODE = getenv("METAL_XCODE")
 
diff --git a/tinygrad/runtime/ops_torch.py b/tinygrad/runtime/ops_torch.py
index 713e55e573..cd2d2c0890 100644
--- a/tinygrad/runtime/ops_torch.py
+++ b/tinygrad/runtime/ops_torch.py
@@ -1,22 +1,22 @@
 import torch
 from typing import ClassVar, Dict, Callable
-from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, FusedOps, InterpretedBuffer, Op
+from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, FusedOps, LoadOps, Op
 from tinygrad.helpers import getenv, dtypes
+from tinygrad.interpreted import InterpretedBuffer
 from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
 
+device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
+
 torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
   UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(),
   BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
   MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
   FusedOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(a.dtype), lambda x: x.stride(), lambda x,s: x.expand(s)),
-  MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, abs(i)) for i in arg)].flip([i for i,a in enumerate(arg) if a < 0])
+  MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, abs(i)) for i in arg)].flip([i for i,a in enumerate(arg) if a < 0]),
+  LoadOps.FROMCPU: lambda arg: torch.from_numpy(arg).requires_grad_(False).to(device)
 }}
 
-device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
 class TorchBuffer(InterpretedBuffer):
   fxn_for_op: ClassVar = torch_fxn_for_op
   def to_tinygrad_dtype(self): return {torch.float16: dtypes.float16, torch.float32: dtypes.float32}[self._buf.dtype]
-
-  @staticmethod
-  def fromCPU(x): return TorchBuffer(torch.from_numpy(x).requires_grad_(False).to(device))
   def toCPU(self): return self._buf.cpu().numpy()
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 84cf3f2322..f7bbb6c491 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -216,7 +216,7 @@ class Tensor:
   # - Both i and j will be clamped to the range (-N, N], where N in the length of the sequence
   # - Indexing with np.newaxis or None on a given axis will add a new dimension of size one before that axis
   # - Empty slices are not allowed
-  # - Strides other than 1 are not allowedå
+  # - Strides other than 1 are not allowed
   def __getitem__(self, val):
     def slcfix(i, sz, default): return default if i is None else max(0, min(sz, sz+i if i < 0 else i))  # Fix negative idxs, clamp to [0,N]
     new_slice, new_shape = [], []