Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
Devicebufferless (#708)
* runs one metal kernel
* conv2d works
* ops tests are passing
* const folding
* all ops work
* pre commit always passes
* torch works
* working still
* fix graph test
* tests passing
* image almost works
* image conv works
* most images
* fix custom
* fix assignment
* fix compile enet
* clean up comments
* fix realize return value
* include shapetracker in LB repr
* copy should make a copy
* reenable method cache
* fix lna
* dtypes in graph
* forward only for IMAGE=2
* simple realize
* getting close
* fixup new api, it's good except the kernel count
* back to 197 kernels
* tests should pass
* go to a real float
* no type_on_cpu
* fix the docs
* put shapetracker back in it's proper place
.github/workflows/test.yml (vendored): 6 changed lines
@@ -79,7 +79,7 @@ jobs:
run: curl https://media.istockphoto.com/photos/hen-picture-id831791190 | ./recognize | grep hen

testllvm:
name: LLVM Tests
name: LLVM Tests (w method cache)
runs-on: ubuntu-latest

steps:
@@ -160,7 +160,9 @@ jobs:
- name: Install Dependencies
run: pip install -e '.[gpu,testing]' --extra-index-url https://download.pytorch.org/whl/cpu
- name: Test GPU IMAGE ops
run: GPU=1 IMAGE=2 python3 test/test_ops.py
run: |
GPU=1 IMAGE=1 python3 test/test_ops.py
FORWARD_ONLY=1 GPU=1 IMAGE=2 python3 test/test_ops.py
- name: Test openpilot model
run: |
ALLOWED_KERNEL_COUNT=197 FLOAT16=1 VALIDHACKS=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py

@@ -1,6 +1,12 @@
repos:
- repo: local
hooks:
- id: docs
name: docs
entry: python3 docs/abstractions.py
language: system
always_run: true
pass_filenames: false
- id: flake8
name: flake8
entry: flake8 tinygrad/ --indent-size=2 --select=F,E112,E113,E203,E304,E502,E702,E703,E71,E72,E731,W191,W6 --statistics -j4

@@ -469,4 +469,4 @@ check-str-concat-over-line-jumps=yes

# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=Exception
overgeneral-exceptions=builtins.Exception

@@ -77,14 +77,20 @@ class LazyBuffer:
shape: Tuple[int, ...]
dtype: DType

# a ShapeTracker is used to track things like reshapes and permutes
# all MovementOps are zero copy in tinygrad!
# the ShapeTracker specifies how the data in the RawBuffer matches to the shape
# we'll come back to this later
st: ShapeTracker

# if the LazyBuffer is realized, it has a RawBuffer
# we will come back to RawBuffers later
realized: Optional[RawBuffer]

# if the lazybuffer is unrealized, it has a LazyOp
# this LazyOp describes the computation needed to realize this LazyBuffer
op: Optional[LazyOp]

# if the LazyBuffer is realized, it has a DeviceBuffer
# we will come back to DeviceBuffers later, first we'll explore the LazyOp
realized: Optional[DeviceBuffer]

# LazyOp (in tinygrad/ops.py, code 4/10)
# in a tree they form an Abstract Syntax Tree for a single GPU kernel
class LazyOp:
@@ -128,81 +134,60 @@ assert len(lazyop.src) == 2
# again, a LazyOp AST is like a GPU kernel. you have to copy the data on the device first
print(lazyop.src[0].op)
assert lazyop.src[0].op.op == LoadOps.FROMCPU
assert lazyop.src[0].op.arg[0] == [2], "the arg of the FROMCPU LazyOP is the [2.]"
assert lazyop.src[0].op.arg.fxn == [2], "the arg of the FROMCPU LazyOP is the [2.]"
assert result.lazydata.realized is None, "the LazyBuffer is not realized yet"

# now we realize the LazyBuffer
result.lazydata.realize()
assert result.lazydata.realized is not None, "the LazyBuffer is realized!"
# this brings us nicely to DeviceBuffer, of which the realized ClangBuffer is a subclass
assert 'ClangBuffer' in str(type(result.lazydata.realized))
assert 'RawMallocBuffer' in str(type(result.lazydata.realized))
# getting ahead of ourselves, but we can copy the DeviceBuffer toCPU
assert result.lazydata.realized.toCPU()[0] == 5, "when put in numpy with toCPU, it's 5"
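
The same flow, seen from the Tensor API: a minimal usage sketch (assuming the CLANG backend, as in the asserts above) that only restates what those assertions check.

from tinygrad.tensor import Tensor

a, b = Tensor([2.0]), Tensor([3.0])          # run with CLANG=1 to match the RawMallocBuffer assert above
result = a + b                               # nothing has run yet
assert result.lazydata.realized is None      # still just a LazyOp graph
result.realize()                             # schedules and runs one kernel
assert result.lazydata.realized is not None  # now backed by a RawBuffer
print(result.numpy())                        # [5.]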

# %%
# == DeviceBuffer (in tinygrad/ops.py, code 4/10) ==
# == Union[Interpreted, Compiled] (in tinygrad/ops.py, code 5/10) ==

# DeviceBuffer is an abstract class to be implemented for each Device backend
class DeviceBuffer(ABC):
# these two are straightforward.
# unlike LazyBuffer, there's no need for device, since that's contained in the concrete type
shape: Tuple[int, ...]
dtype: DType
# Now you have a choice, you can either write a "Interpreted" backend or "Compiled" backend

# this is the magic method that "fills" a DeviceBuffer and does all the math in tinygrad
# NOTE: fromCPU no longer exists here, it's just a one LoadOps AST, LoadOps.FROMCPU
def exec_ast(self, ast:LazyOp): raise NotImplementedError("must be implemented")
# Interpreted backends are very simple (example: CPU and TORCH)
class Interpreted:
# they have a backing RawBuffer
buffer: Type[RawBuffer]

# however, toCPU still exists. it will raise a RuntimeException if exec_ast has never been called
# it copies out the underlying to the CPU, and will do any sync operations
def toCPU(self) -> np.ndarray: raise NotImplementedError("must be implemented")
# and they have a lookup table to functions for the Ops
fxn_for_op: Dict[Op, Callable] = {
UnaryOps.EXP: lambda x: np.exp(x),
BinaryOps.ADD: lambda x,y: x+y}

# DeviceBuffers come in two flavors, InterpretedBuffer and CompiledBuffer
# InterpretedBuffers are a lot simpler than CompiledBuffers
# they are used to implement the CPU(numpy) and TORCH(torch) backends
# it's worth reading CPUBuffer (in tinygrad/runtime/ops_cpu.py, code 8/10)
import numpy as np
import torch
class InterpretedBuffer(DeviceBuffer):
# this is where the data actually lives
# finally some classes you recognize!
_buf: Union[np.ndarray, torch.Tensor]
# Compiled backends take a little more (example: GPU and LLVM)
class Compiled:
# they also have a backing RawBuffer
buffer: Type[RawBuffer]

# the compute itself is defined here. these functions are called with _buf
# here's a UnaryOp and BinaryOp from CPUBuffer(InterpretedBuffer)
fxn_for_op: ClassVar[Dict[Op, Callable]] = {UnaryOps.EXP: lambda x: np.exp(x), BinaryOps.ADD: lambda x,y: x+y}

# NOTE: exec_ast should not need to be overridden!
# The actual method lives in tinygrad/ops.py
# it walks the LazyOp tree and calls fxn_for_op as appropriate

# ********** NOTE: for the CPU and TORCH backends, we are done and you can stop reading here **********
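
To make the Interpreted pattern concrete, here is a tiny self-contained sketch (not tinygrad code) of an interpreter that walks an op tree with a fxn_for_op table, the same dispatch shape as above.

from enum import Enum, auto
from typing import NamedTuple, Tuple, Any
import numpy as np

class BinaryOps(Enum): ADD = auto(); MUL = auto()
class UnaryOps(Enum): EXP = auto()

class ToyOp(NamedTuple):   # stand-in for LazyOp: an op, plus sources that are ops or plain arrays
  op: Any
  src: Tuple[Any, ...]

fxn_for_op = {UnaryOps.EXP: lambda x: np.exp(x),
              BinaryOps.ADD: lambda x,y: x+y, BinaryOps.MUL: lambda x,y: x*y}

def exec_ast(ast):         # recursively realize the sources, then apply the table entry
  srcs = [exec_ast(s) if isinstance(s, ToyOp) else s for s in ast.src]
  return fxn_for_op[ast.op](*srcs)

out = exec_ast(ToyOp(BinaryOps.ADD, (np.array([2.0]), ToyOp(UnaryOps.EXP, (np.array([0.0]),)))))
print(out)  # [3.]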

# %%
# == CompiledBuffer (in tinygrad/ops.py, code 4/10) ==

# however, all the magic of tinygrad will come from CompiledBuffer
# this is used for the GPU(opencl), CUDA, METAL, CLANG, and LLVM backends
class CompiledBuffer(DeviceBuffer):
# this is where the data actually lives, same as InterpretedBuffer
# a RawBuffer is just raw (typed) memory on the Device in question
_buf: RawBuffer

# introducing...ShapeTracker! all MovementOps are zero copy in tinygrad
# the ShapeTracker specifies how the data in the RawBuffer matches to the shape
# we'll come back to this later
st: ShapeTracker

# NOTE: exec_ast should not need to be overridden!
# instead you need three classes, explained below
raw_buffer: Type[RawBuffer]
runtime: Type[Runtime]
# a code generator, which compiles the AST
codegen: Type[ASTKernel]

# for completeness, we include RawBuffer. it's very boring and exactly what you expect
# and a runtime, which runs the generated code
runtime: Type[Runtime]

# Runtime is what actually runs the kernels for a compiled backend
class Runtime(ABC):
# `name` is the name of the function, and `prg` is the code
# the constructor compiles the code
def __init__(self, name:str, prg:str): pass
# call runs the code on the bufs. NOTE: the output is always bufs[0], but this is just a convention
def __call__(self, global_size:Optional[List[int]], local_size:Optional[List[int]], *bufs:List[RawBuffer]): pass
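
For a Compiled backend, the Runtime is the only piece that talks to a compiler or driver. A rough standalone sketch of the idea (this is not tinygrad's clang runtime, and it assumes clang is on PATH):

import ctypes, subprocess, tempfile, pathlib
import numpy as np

class ToyClangRuntime:
  def __init__(self, name, prg):                       # the constructor compiles the code
    self.fxn_name = name
    self.lib_path = pathlib.Path(tempfile.mkdtemp()) / "kernel.so"
    subprocess.run(["clang", "-shared", "-fPIC", "-O2", "-x", "c", "-", "-o", str(self.lib_path)],
                   input=prg.encode(), check=True)
  def __call__(self, global_size, local_size, *bufs):  # bufs[0] is the output, by convention
    fxn = getattr(ctypes.CDLL(str(self.lib_path)), self.fxn_name)
    fxn(*[b.ctypes.data_as(ctypes.POINTER(ctypes.c_float)) for b in bufs], ctypes.c_int(bufs[0].size))

prg = "void add(float *out, float *a, float *b, int n) { for (int i = 0; i < n; i++) out[i] = a[i]+b[i]; }"
out, a, b = np.zeros(4, np.float32), np.arange(4, dtype=np.float32), np.ones(4, np.float32)
ToyClangRuntime("add", prg)(None, None, out, a, b)
print(out)  # [1. 2. 3. 4.]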

# %%
# == RawBuffer (in tinygrad/runtime/lib.py, code 5/10) ==
import numpy as np

# RawBuffer is where the data is actually held. it's pretty close to just memory
class RawBuffer(ABC):
# create an empty rawbuffer that holds `size` elements of type `dtype`
def __init__(self, size:int, dtype:DType): raise NotImplementedError("must be implemented")
# `buf` is an opaque container class
def __init__(self, size:int, dtype:DType, buf:Any): raise NotImplementedError("must be implemented")

# fromCPU is a classmethod that creates a RawBuffer, it's a classmethod since some runtimes are 0 copy
@classmethod
@@ -211,13 +196,14 @@ class RawBuffer(ABC):
# toCPU converts the RawBuffer to a numpy array with shape (size,). many backends are 0 copy here
def toCPU(self) -> np.ndarray: raise NotImplementedError("must be implemented")

# Runtime is what actually runs the kernels
class Runtime(ABC):
# `name` is the name of the function, and `prg` is the code
# the constructor compiles the code
def __init__(self, name:str, prg:str): pass
# call runs the code on the bufs. NOTE: the output is always bufs[0], but this is just a convention
def __call__(self, global_size:Optional[List[int]], local_size:Optional[List[int]], *bufs:List[RawBuffer]): pass
# RawNumpyBuffer is a RawBuffer example for numpy. It's very simple
class RawNumpyBuffer(RawBuffer):
# NOTE: the "np.ndarray" is stored in the opaque container
def __init__(self, buf:np.ndarray):
super().__init__(buf.size, dtypes.from_np(buf.dtype), buf)
@classmethod
def fromCPU(cls, x): return cls(x)
def toCPU(self): return self._buf
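
The RawMallocBuffer mentioned earlier follows this same interface. A simplified standalone sketch of such a buffer (using numpy dtypes instead of tinygrad's DType), with fromCPU/toCPU round-tripping through ctypes memory:

import ctypes
import numpy as np

class ToyRawMallocBuffer:
  def __init__(self, size, np_dtype):
    self.size, self.np_dtype = size, np_dtype
    self._buf = (ctypes.c_uint8 * (size * np.dtype(np_dtype).itemsize))()   # the opaque container
  @classmethod
  def fromCPU(cls, x):
    ret = cls(x.size, x.dtype)
    ctypes.memmove(ret._buf, x.ctypes.data, x.nbytes)                       # copy host data in
    return ret
  def toCPU(self):
    return np.frombuffer(self._buf, dtype=self.np_dtype).copy()             # shape (size,)

print(ToyRawMallocBuffer.fromCPU(np.array([[1.,2.],[3.,4.]], dtype=np.float32)).toCPU())  # [1. 2. 3. 4.]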

# %%
# == Example: 2+3 in raw clang ==
@@ -262,11 +248,11 @@ class ASTKernel:
def __init__(self, ast:LazyOp): pass
def codegen(self) -> ASTRunner: pass

# we return a class that runs code on CompiledBuffers
# we return a class that runs code on LazyBuffers, which are all expected to be realized
class ASTRunner: # (from tinygrad/ops.py)
def __init__(self, name, prg, global_size:Optional[List[int]], local_size:Optional[List[int]]): pass
def build(self, runtime:Runtime): pass
def exec(self, bufs:List[CompiledBuffer]): pass
def exec(self, bufs:List[LazyBuffer]): pass

# that hides a lot of complexity that will be refactored, but that's the basic idea of code generation
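
Putting the pieces together: codegen returns an ASTRunner, build hands it a Runtime, and exec runs it with the output in bufs[0]. A toy, tinygrad-independent sketch of that object shape, where the "program" is Python source and the "runtime" is exec:

import numpy as np

class ToyASTRunner:
  def __init__(self, name, prg, global_size=None, local_size=None):
    self.name, self.prg, self.global_size, self.local_size = name, prg, global_size, local_size
  def build(self, runtime):                 # the runtime turns source text into something callable
    self.clprg = runtime(self.name, self.prg)
    return self
  def exec(self, bufs):                     # by convention bufs[0] is the output
    return self.clprg(self.global_size, self.local_size, *bufs)

def toy_python_runtime(name, prg):          # "compiles" the program text and returns the named function
  scope = {}
  exec(compile(prg, name, "exec"), scope)
  return lambda global_size, local_size, *bufs: scope[name](*bufs)

runner = ToyASTRunner("add", "def add(out, a, b): out[:] = a + b").build(toy_python_runtime)
out = np.zeros(3, np.float32)
runner.exec([out, np.array([1.,2.,3.], np.float32), np.array([4.,5.,6.], np.float32)])
print(out)  # [5. 7. 9.]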

@@ -43,10 +43,10 @@ if __name__ == "__main__":
# hack to put the inputs back
assert len(run.input_replace) == 1, f"didn't get one input to replace {run.input_replace}"
for (j,i),idx in run.input_replace.items():
run.jit_cache[j][1][i] = the_input.lazydata.realized.raw()
run.jit_cache[j][1][i] = the_input.lazydata.realized

# TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
special_names = {id(the_input.lazydata.realized.raw()): "input", id(the_output.lazydata.realized.raw()): "outputs"}
special_names = {id(the_input.lazydata.realized): "input", id(the_output.lazydata.realized): "outputs"}

functions, statements, bufs, bufs_to_save = compile_net(run, special_names)

@@ -109,4 +109,5 @@ int main(int argc, char* argv[]) {
}"""]

# CLANG=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/stable_diffusion_by_tinygrad.jpg
# category : 281 (tabby, tabby cat) with 9.452788
print('\n'.join(cprog))

@@ -3,7 +3,7 @@ import gc
from tinygrad.helpers import prod
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
from tinygrad.runtime.ops_gpu import GPUBuffer
from tinygrad.runtime.ops_gpu import CLBuffer
from tinygrad.ops import GlobalCounters

def print_objects():
@@ -11,7 +11,7 @@ def print_objects():
tensors = [x for x in gc.get_objects() if isinstance(x, Tensor)]
tensor_ram_used = sum([prod(x.shape)*4 for x in tensors])
lazybuffers = [x for x in gc.get_objects() if isinstance(x, LazyBuffer)]
gpubuffers = [x for x in gc.get_objects() if isinstance(x, GPUBuffer)]
gpubuffers = [x for x in gc.get_objects() if isinstance(x, CLBuffer)]
realized_buffers = [x.realized for x in lazybuffers if x.realized]
gpubuffers_orphaned = [x for x in gpubuffers if x not in realized_buffers]

@@ -109,8 +109,8 @@ def load_single_weight(t:Tensor, myfile, shape, strides, dtype, storage_offset,
# this needs real APIs
if t.device in ["METAL", "CLANG", "LLVM"]:
del t.lazydata.op
t.lazydata.realized = t.lazydata.dbuffer(t.shape, dtype=t.dtype)
myfile.readinto(t.lazydata.realized.raw()._buffer())
t.lazydata.realized = t.lazydata.dbuffer.buffer(prod(t.shape), dtype=t.dtype)
myfile.readinto(t.lazydata.realized._buffer())
else:
def _mmap(lna):
assert myfile._compress_type == 0, "compressed data can't be mmaped"

@@ -9,7 +9,7 @@ if os.getenv("GPU", None) is None:
if os.getenv("IMAGE", None) is None:
os.environ['IMAGE'] = '2'

from tinygrad.helpers import getenv
from tinygrad.helpers import getenv, dtypes
ALLOWED_KERNEL_COUNT = getenv("ALLOWED_KERNEL_COUNT", 0)
DEBUGCL = getenv("DEBUGCL", 0)

@@ -38,7 +38,7 @@ from tinygrad.jit import TinyJit

@TinyJit
def model_exec(run_onnx, using_graph, **inputs):
ret = next(iter(run_onnx(inputs).values()))
ret = next(iter(run_onnx(inputs).values())).cast(dtypes.float32)
GlobalCounters.reset()
GlobalCounters.cache = [] # don't cache pre-realize
if using_graph: graph.GRAPH = True
@@ -49,7 +49,7 @@ def compile(dat, output_fn):
Tensor.manual_seed(1337)
Tensor.no_grad = True
using_graph = graph.GRAPH
graph.GRAPH = False
if getenv("GRAPH") < 2: graph.GRAPH = False

onnx_model = onnx.load(io.BytesIO(dat))
run_onnx = get_run_onnx(onnx_model)
@@ -63,7 +63,7 @@ def compile(dat, output_fn):
assert len(model_exec.jit_cache) <= ALLOWED_KERNEL_COUNT or ALLOWED_KERNEL_COUNT == 0, "too many kernels!"

# pull out inputs and put them in the jit cache
input_rawbuffers = {k:inputs[k].lazydata.realized.raw() for k in inputs.keys()}
input_rawbuffers = {k:inputs[k].lazydata.realized for k in inputs.keys()}
for (j,i),idx in model_exec.input_replace.items(): model_exec.jit_cache[j][1][i] = input_rawbuffers[idx]

# transform to CL.CACHE
@@ -73,11 +73,11 @@ def compile(dat, output_fn):
# pass these to thneed
setattr(prg.clprg, 'op_estimate', prg.op_estimate)
setattr(prg.clprg, 'prg', prg.prg)
cl_cache.append((prg.clprg, [prg.global_size, prg.local_size, *[x._cl for x in args]]))
cl_cache.append((prg.clprg, [prg.global_size, prg.local_size, *[x._buf for x in args]]))
used_ops += prg.op_estimate

from extra.thneed import Thneed
t = Thneed(cl_cache, {k:v._cl for k,v in input_rawbuffers.items()})
t = Thneed(cl_cache, {k:v._buf for k,v in input_rawbuffers.items()})

# save thneed (before run)
t.save(output_fn)

test/external/external_test_opt.py (vendored): 8 changed lines
@@ -226,6 +226,14 @@ class TestOpt(unittest.TestCase):
np.testing.assert_allclose(c.numpy(), d.numpy().transpose(1,0), rtol=1e-3, atol=1e-5)
assert cache_len == 1, "reduceop was rerun!"

def test_fold_with_contiguous(self):
a = Tensor.randn(16, 16, 16)
b = Tensor.randn(16, 16)
with CLCache():
c = (a.sum(2).contiguous() + b).contiguous()
c.realize()
cache_len = len(GlobalCounters.cache)
assert cache_len == 1, "contiguous wasn't folded"

if __name__ == '__main__':
unittest.main()

@@ -4,12 +4,13 @@ import unittest
from extra.utils import fetch, fake_torch_load_zipped
from PIL import Image

class TestUtils(unittest.TestCase):
class TestUtils(unittest.TestCase):
@unittest.skip("hangs sometimes")
def test_fetch_bad_http(self):
self.assertRaises(AssertionError, fetch, 'http://httpstat.us/500')
self.assertRaises(AssertionError, fetch, 'http://httpstat.us/404')
self.assertRaises(AssertionError, fetch, 'http://httpstat.us/400')

def test_fetch_small(self):
assert(len(fetch('https://google.com'))>0)

@@ -8,25 +8,23 @@ from tinygrad.helpers import prod, dtypes

# *** first, we implement the atan2 op at the lowest level ***
# `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers

from tinygrad.ops import ASTRunner, CompiledBuffer
from tinygrad.runtime.ops_cpu import CPUBuffer
from tinygrad.lazy import LazyBuffer, Device
from tinygrad.ops import ASTRunner

# we don't always have GPU support, so the type signature is the abstract CompiledBuffer instead of GPUBuffer
def atan2_gpu(a:CompiledBuffer, b:CompiledBuffer) -> CompiledBuffer:
from tinygrad.runtime.ops_gpu import GPUBuffer
assert type(a) == GPUBuffer and type(b) == GPUBuffer, "gpu function requires GPUBuffers"
def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
assert a.device == "GPU" and b.device == "GPU", "gpu function requires GPUBuffers"
assert a.dtype == b.dtype and a.dtype == dtypes.float32, "gpu function only supports float32"
ret = GPUBuffer(a.shape)
ret.realized = Device[ret.device].buffer(prod(ret.shape), ret.dtype)
ASTRunner("atan2", """
__kernel void atan2(global float *c, global float *a, global float *b) {
int idx = get_global_id(0);
c[idx] = atan2(a[idx], b[idx]);
}""", global_size=[prod(ret.shape)]).build(GPUBuffer.spec.runtime).exec([ret, a.contiguous(), b.contiguous()])
return ret
}""", global_size=[prod(ret.shape)]).build(Device[ret.device].runtime).exec([ret, a, b])
return ret.realized

def atan2_cpu(a:CPUBuffer, b:CPUBuffer) -> CPUBuffer:
return CPUBuffer(np.arctan2(a._buf, b._buf))
def atan2_cpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
return Device[ret.device].buffer(np.arctan2(a.realized._buf, b.realized._buf))

# *** second, we write the ATan2 mlop ***
# NOTE: The derivative of atan2 doesn't need a custom op! https://www.liquisearch.com/atan2/derivative
@@ -40,7 +38,7 @@ class ATan2(Function):
def forward(self, a:LazyBuffer, b:LazyBuffer) -> LazyBuffer:
assert prod(a.shape) == prod(b.shape) and a.device == b.device, "shape or device mismatch"
self.a, self.b = a, b
ast = LazyOp(LoadOps.CUSTOM, (a, b), {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device])
ast = LazyOp(LoadOps.CUSTOM, (a.contiguous(), b.contiguous()), {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device])
return LazyBuffer(a.device, a.shape, LoadOps, ast, max(a.dtype, b.dtype))
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
denom = (self.a.binary_op(BinaryOps.MUL, self.a)).binary_op(BinaryOps.ADD, self.b.binary_op(BinaryOps.MUL, self.b))
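
As the NOTE above says, atan2's backward pass needs no custom op; the denominator built here is a*a + b*b. A quick numpy finite-difference check of that identity (independent of tinygrad):

import numpy as np

a, b, eps = 0.7, -1.3, 1e-6
analytic_da = b / (a*a + b*b)         # d/da atan2(a, b)
analytic_db = -a / (a*a + b*b)        # d/db atan2(a, b)
numeric_da = (np.arctan2(a+eps, b) - np.arctan2(a-eps, b)) / (2*eps)
numeric_db = (np.arctan2(a, b+eps) - np.arctan2(a, b-eps)) / (2*eps)
assert abs(analytic_da - numeric_da) < 1e-5 and abs(analytic_db - numeric_db) < 1e-5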

@@ -3,7 +3,7 @@ import time
import numpy as np
import unittest
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, IMAGE
from tinygrad.helpers import getenv

FORWARD_ONLY = getenv("FORWARD_ONLY", 0)
def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3, forward_only=False, vals=None, a=-0.5, b=3):
@@ -300,7 +300,7 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv2d(x, w),
lambda x,w: x.conv2d(w), atol=1e-2)

@unittest.skip("not supported with IMAGE=1")
@unittest.skip("slow")
def test_large_bs_conv(self):
# large batch size can cause OpenCL image to exceed max image height on macOS
# (or cause the conv kernel to overflow short sampling coords)
@@ -308,7 +308,7 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv2d(x, w),
lambda x,w: x.conv2d(w), atol=1e-4, rtol=1e-2)

@unittest.skip("not supported with IMAGE=1")
@unittest.skip("slow")
def test_large_ic_conv(self):
# large input channel count can cause OpenCL image to exceed max image width on macOS
helper_test_op([(1,2048,3,3), (1,2048,3,3)],
@@ -377,7 +377,7 @@ class TestOps(unittest.TestCase):
cin = 2
helper_test_op([(bs,groups*cin,1,1), (groups*rcout,cin,1,1)],
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5, forward_only=IMAGE>=2)
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)

def test_medium_grouped_conv2d(self):
bs = 1
@@ -386,7 +386,7 @@ class TestOps(unittest.TestCase):
cin = 2
helper_test_op([(bs,groups*cin,1,1), (groups*rcout,cin,1,1)],
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5, forward_only=IMAGE>=2)
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)

def test_depthwise_conv2d(self):
bs = 1

@@ -17,6 +17,13 @@ def multidevice_test(fxn):
return ret

class TestExample(unittest.TestCase):
@multidevice_test
def test_convert_to_cpu(self, device):
a = Tensor([[1,2],[3,4]], device=device)
assert a.numpy().shape == (2,2)
b = a.cpu()
assert b.numpy().shape == (2,2)

@multidevice_test
def test_2_plus_3(self, device):
a = Tensor([2], device=device)

@@ -1,37 +1,38 @@
#!/usr/bin/env python
import unittest
from tinygrad.ops import LazyOp, BinaryOps
from tinygrad.interpreted import get_lazyop_info, InterpretedBuffer, GenericShape
from typing import NamedTuple, Tuple
from tinygrad.ops import LazyOp, BinaryOps, get_lazyop_info
from tinygrad.helpers import DType, dtypes

class TestBuffer(NamedTuple):
shape: Tuple[int, ...]
dtype: DType

class TestFlopCounter(unittest.TestCase):
def setUp(self):
self.buf0 = TestBuffer(shape=(4,), dtype=dtypes.float32)
self.buf1 = TestBuffer(shape=(4,), dtype=dtypes.float32)

def test_flops_add(self):
buf0 = InterpretedBuffer(GenericShape((4,)))
buf1 = InterpretedBuffer(GenericShape((4,)))
op0 = LazyOp(BinaryOps.ADD, (buf0,buf1,), None)
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
info = get_lazyop_info(op0)
self.assertEqual(info.flops, 4)

def test_flops_add_twice(self):
buf0 = InterpretedBuffer(GenericShape((4,)))
buf1 = InterpretedBuffer(GenericShape((4,)))
op0 = LazyOp(BinaryOps.ADD, (buf0,buf1,), None)
op1 = LazyOp(BinaryOps.ADD, (op0,buf1,), None)
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
op1 = LazyOp(BinaryOps.ADD, (op0,self.buf1,), None)
info = get_lazyop_info(op1)
self.assertEqual(info.flops, 8)

def test_flops_add_self(self):
buf0 = InterpretedBuffer(GenericShape((4,)))
buf1 = InterpretedBuffer(GenericShape((4,)))
op0 = LazyOp(BinaryOps.ADD, (buf0,buf1,), None)
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
op1 = LazyOp(BinaryOps.ADD, (op0,op0,), None)
info = get_lazyop_info(op1)
self.assertEqual(info.flops, 8)

def test_flops_add_roundabout_self(self):
buf0 = InterpretedBuffer(GenericShape((4,)))
buf1 = InterpretedBuffer(GenericShape((4,)))
op0 = LazyOp(BinaryOps.ADD, (buf0,buf1,), None)
op1 = LazyOp(BinaryOps.ADD, (op0,buf1,), None)
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
op1 = LazyOp(BinaryOps.ADD, (op0,self.buf1,), None)
op2 = LazyOp(BinaryOps.ADD, (op0,op1,), None)
info = get_lazyop_info(op2)
self.assertEqual(info.flops, 12)
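
One way to reproduce the numbers these tests expect: every distinct LazyOp node costs prod(shape) flops, and a node reused by two parents (like op0 here) is counted once. A standalone sketch of that counting rule; it models these four tests, not the real get_lazyop_info:

from math import prod

def count_flops(node, shape=(4,), seen=None):
  # node is a (op_name, srcs) tuple; leaves are plain buffer placeholders
  seen = set() if seen is None else seen
  if not isinstance(node, tuple) or id(node) in seen: return 0
  seen.add(id(node))
  return prod(shape) + sum(count_flops(s, shape, seen) for s in node[1])

buf0, buf1 = "buf0", "buf1"
op0 = ("ADD", (buf0, buf1))
op1 = ("ADD", (op0, buf1))
op2 = ("ADD", (op0, op1))
assert count_flops(op0) == 4 and count_flops(op1) == 8 and count_flops(op2) == 12
assert count_flops(("ADD", (op0, op0))) == 8   # reusing op0 twice still counts it once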

@@ -1,11 +1,12 @@
#!/usr/bin/env python
import unittest
import networkx as nx # type: ignore
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.graph import G, log_op, prune_graph
from tinygrad.runtime.ops_cpu import CPUBuffer
from tinygrad.ops import BinaryOps, LazyOp, MovementOps, ReduceOps

def buf(*shp): return Tensor.ones(*shp, device="CPU").lazydata

class TestGraph(unittest.TestCase):
def setUp(self):
G.clear()
@@ -14,10 +15,10 @@ class TestGraph(unittest.TestCase):
assert nx.is_isomorphic(G, RG, node_match=lambda x,y: x["label"] == y["label"], edge_match=lambda x,y: x["label"] == y["label"] if "label" in y else True)

def test_add_graph(self):
a = CPUBuffer(np.ones((4,4), dtype=np.float32))
b = CPUBuffer(np.ones((4,4), dtype=np.float32))
a = buf(4,4)
b = buf(4,4)
ast = LazyOp(BinaryOps.ADD, (a,b))
ret = CPUBuffer(np.ones((4,4), dtype=np.float32))
ret = buf(4,4)

RG = nx.DiGraph()
RG.add_node(0, label="(4, 4)")
@@ -30,12 +31,12 @@ class TestGraph(unittest.TestCase):
self.helper_compare_graph(RG)

def test_add_sum_graph(self):
a = CPUBuffer(np.ones((4,4), dtype=np.float32))
b = CPUBuffer(np.ones((1,1), dtype=np.float32))
a = buf(4,4)
b = buf(1,1)
op0 = LazyOp(MovementOps.RESHAPE, (b,), (4, 4))
op1 = LazyOp(BinaryOps.ADD, (a,op0))
ast = LazyOp(ReduceOps.SUM, (op1,), (1,1))
ret = CPUBuffer(np.ones((1,1), dtype=np.float32))
ret = buf(1,1)

RG = nx.DiGraph()
RG.add_node(0, label="(4, 4)")
@@ -48,14 +49,14 @@ class TestGraph(unittest.TestCase):
self.helper_compare_graph(RG)

def test_add_graph_prune(self):
a = CPUBuffer(np.ones((1,1), dtype=np.float32))
a = buf(1,1)
ast = LazyOp(MovementOps.RESHAPE, (a,), (4, 4))
ret = CPUBuffer(np.ones((4,4), dtype=np.float32))
ret = buf(4,4)
log_op(ret, ast, show_graph=True)

b = CPUBuffer(np.ones((4,4), dtype=np.float32))
b = buf(4,4)
ast = LazyOp(BinaryOps.ADD, (ret,b))
ret = CPUBuffer(np.ones((4,4), dtype=np.float32))
ret = buf(4,4)
log_op(ret, ast, show_graph=True)
prune_graph()

@@ -180,6 +180,9 @@ class TestSymbolic(unittest.TestCase):
def test_div_numerator_negative(self):
self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "((((idx*-10)+99)//11)+-9)")

def test_div_into_mod(self):
self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)")
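
The simplifications asserted in these two tests are easy to sanity-check by brute force over the variable ranges:

for idx in range(17):                # Variable("idx", 0, 16) ranges over 0..16 inclusive
  assert (idx * 4) % 8 // 4 == idx % 2
for idx in range(10):                # and the rewrite in test_div_numerator_negative above
  assert (idx * -10) // 11 == ((idx * -10 + 99) // 11) - 9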

class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):
# TODO: why are the negative tests broken? (even if we did support negative variables)

@@ -2,9 +2,8 @@ import itertools
from enum import Enum, auto
from typing import List, Tuple
from tinygrad.helpers import prod, dedup, all_same, colored, DType
from tinygrad.ops import LazyOp, MovementOps, get_buffers, ReduceOps, get_lazyops, map_buffers, ASTRunner
from tinygrad.ops import LazyOp, MovementOps, get_buffers, ReduceOps, get_lazyops, map_buffers, ASTRunner, get_lazyop_info, FlopCounter
from tinygrad.shape.shapetracker import ShapeTracker, View, strides_for_shape
from tinygrad.interpreted import get_lazyop_info, GenericShape

def get_first_reduce(shapes):
for i in range(len(shapes[0])):
@@ -35,31 +34,14 @@ class ASTKernel:
def __init__(self, ast:LazyOp, output_buffer=None):
self.input_ast = ast

# if the AST ends with a RESHAPE, we remove it and create the buffer accordingly
if ast.op == MovementOps.RESHAPE:
output_shape = ast.arg
ast = ast.src[0]
else:
output_shape = None
# NOTE: if there's a RESHAPE, we skip it. the output shape is set from the reduce op or a latebuf
if ast.op == MovementOps.RESHAPE: ast = ast.src[0]

self.bufs = dedup(get_buffers(ast))
self.bufs = [output_buffer] + dedup(get_buffers(ast))
self.ast = ast

# check if the output buffer is allowed to be used
# if it's aliased, don't use it
if output_buffer is not None:
for a in self.bufs:
if a._buf == output_buffer._buf and not a.st.contiguous:
output_buffer = None
break

# fetch lazyop info (this can be cached!)
self.info: GenericShape = get_lazyop_info(ast)

# create the buffer we are returning (as the same type as the input buffers) and add it as the first buffer
self.ret = output_buffer if output_buffer else type(self.bufs[0])(output_shape if output_shape else self.info.shape, force_create=True, dtype=self.info.dtype)
assert self.ret.dtype == self.info.dtype, f"return dtype {self.ret.dtype} != {self.info.dtype}"
self.bufs = ([type(self.ret)(self.info.shape, hostbuf=self.ret, dtype=self.info.dtype)] if output_shape else [self.ret]) + self.bufs
self.info: FlopCounter = get_lazyop_info(ast)

# key for lookup in cache (can change, str might not be right)
# bufs are needed because kernels like f(x) = x + x and f(x, y) = x + y have the same str(ast), but are different kernels.
@@ -79,8 +61,8 @@ class ASTKernel:

# check valid AST kernel
assert all_same([x.shape for x in self.earlybufs]), "all earlybufs must have the same shape"
assert all_same([x.shape for x in self.bufs if x not in self.earlybufs]), "all latebufs must have the same shape"
assert all_same([len(x.shape) for x in self.bufs]), "all bufs must have the same shape size"
assert all_same([x.shape for x in self.bufs[1:] if x not in self.earlybufs]), "all latebufs must have the same shape"
assert all_same([len(x.shape) for x in self.bufs[1:]]), "all bufs must have the same shape size"

# get full shape buf index (earlybufs if there are any, otherwise output)
self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if len(self.earlybufs) > 0 else 0
@@ -89,6 +71,10 @@ class ASTKernel:
self.sts: List[ShapeTracker] = [x.st.copy() for x in self.bufs] # create new shapetrackers inside this kernel
for st in self.sts: st.simplify()

# make the output buffer shape correct in here
if self.reduceop is not None: self.sts[0].reshape(self.reduceop.arg)
else: self.sts[0].reshape([x.shape for x in self.bufs[1:] if x not in self.earlybufs][0])

# move all reduce axes to the end
reduce = list(enumerate(zip(self.full_shape, self.sts[0].shape)))
permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
@@ -125,7 +111,7 @@ class ASTKernel:
if print_shapetrackers:
for st in self.sts: print(st)
for i in range(len(self.sts)):
print(prefix, self.bufs[i].dtype if self.bufs[i] is not None else None, self.buftokens[i], f"early:{'T' if i < len(self.bufs) and self.bufs[i] in self.earlybufs else 'F'}", self.sts[i].shape, self.sts[i].views[-1].strides, len(self.sts[i].views), type(self.bufs[i]._buf) if self.bufs[i] is not None else "FAKE")
print(prefix, self.bufs[i].dtype if self.bufs[i] is not None else None, self.buftokens[i], f"early:{'T' if i < len(self.bufs) and self.bufs[i] in self.earlybufs else 'F'}", self.sts[i].shape, self.sts[i].views[-1].strides, len(self.sts[i].views), self.bufs[i].realized if self.bufs[i] is not None else "FAKE")

def codegen(self) -> ASTRunner: raise NotImplementedError("need a codegen")

@@ -3,9 +3,10 @@ from collections import defaultdict
from typing import Optional, List, Tuple, Dict, Set, Final, NamedTuple, ClassVar, DefaultDict
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, ASTRunner
from tinygrad.codegen.ast import ASTKernel, Token, Types
from tinygrad.shape.symbolic import Node, MulNode, DivNode, SumNode, AndNode, Variable, render_python
from tinygrad.shape.symbolic import Node, MulNode, DivNode, SumNode, AndNode, ModNode, Variable, render_python
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad.helpers import getenv, DEBUG, prod, partition, mnum, all_same, dedup, dtypes
from tinygrad.runtime.lib import RawConst

# div is different in cl than python
render_cl = render_python.copy()
@@ -55,7 +56,7 @@ class GPUCodegen(ASTKernel):
kernel_name_cache: Final[Dict[str, str]] = {}

code_for_op: Final[Dict[Op, str]] = {
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.NOT: "(1.0f-A)",
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.NOT: "(1.0f-A)", UnaryOps.CAST: "(A)",
UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)",
UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)",
BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)",
@@ -73,7 +74,7 @@ class GPUCodegen(ASTKernel):
assert len(self.sts[buf_index].views) == 1, "store has more than one view"

# all stores can merge, since they have one view and are valid
should_upcast = self.lang.float4 and self.buftokens[buf_index].can_float4() and (self.bufs[buf_index] is None or self.bufs[buf_index].dtype != dtypes.float16 or hasattr(self.bufs[buf_index]._buf, "IMAGE"))
should_upcast = self.lang.float4 and self.buftokens[buf_index].can_float4() and (self.bufs[buf_index] is None or self.bufs[buf_index].dtype != dtypes.float16 or self.bufs[buf_index].dtype.name.startswith('image'))

to_store = {o:v for o,v in zip(self.buftokens[buf_index].offsets(), value)}
did_store = set()
@@ -84,10 +85,10 @@ class GPUCodegen(ASTKernel):
if should_upcast:
for j in range(4): did_store.add(o+j)
v = self.group_float4([to_store[o+j] for j in range(4)])
if self.bufs[buf_index] is not None and hasattr(self.bufs[buf_index]._buf, "IMAGE"):
if self.bufs[buf_index] is not None and self.bufs[buf_index].dtype.name.startswith('image'):
assert v.typ == Types.FLOAT4, "Image requires upcasting to FLOAT4"
idx, idy = to_image_idx(self.bufs[buf_index]._base_shape, idxy, valid)
self.kernel.append(f"write_imagef({self.buftokens[buf_index].tok}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {v.tok}); /* {self.bufs[buf_index]._base_shape} */\n")
idx, idy = to_image_idx(self.bufs[buf_index].dtype.shape, idxy, valid)
self.kernel.append(f"write_imagef({self.buftokens[buf_index].tok}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {v.tok}); /* {self.bufs[buf_index].dtype.shape} */\n")
elif v.typ == Types.FLOAT4:
self.kernel.append(f"(({self.lang.buffer_prefix if self.bufs[buf_index] is not None else self.lang.smem_prefix}float4*){self.buftokens[buf_index].tok})[{(idxy//4).render(render_cl)}] = {v.tok};\n")
else:
@@ -96,12 +97,21 @@ class GPUCodegen(ASTKernel):
def load(self, buf_index:int, idx_override:Optional[str]=None) -> List[Token]:
# constant folding
const = None
if self.bufs[buf_index] is not None and self.bufs[buf_index]._base_shape == (1,) and self.bufs[buf_index]._backing is not None:
if self.bufs[buf_index] is not None and isinstance(self.bufs[buf_index].realized, RawConst):
# bufs_to_delete can be removed, just ignore RawConst at runtime
if buf_index != 0: self.bufs_to_delete.add(buf_index)
val = self.bufs[buf_index]._backing[0]
val = self.bufs[buf_index].realized._buf
assert not math.isnan(val)
const = Token(f"({val}f)", Types.FLOAT)
should_upcast = self.lang.float4 and const is None and self.buftokens[buf_index].can_float4() and (self.bufs[buf_index] is None or self.bufs[buf_index].dtype != dtypes.float16 or hasattr(self.bufs[buf_index]._buf, "IMAGE"))

def check_no_mul(test, var):
if test == var: return True
if isinstance(test, SumNode): return any(check_no_mul(x, var) for x in test.nodes) # in a sum is okay
if isinstance(test, ModNode) and test.b%4 == 0: return check_no_mul(test.a, var) # removing a mod is okay
return False

is_image = self.bufs[buf_index] is not None and self.bufs[buf_index].dtype.name.startswith('image')
should_upcast = self.lang.float4 and const is None and self.buftokens[buf_index].can_float4() and (self.bufs[buf_index] is None or self.bufs[buf_index].dtype != dtypes.float16 or self.bufs[buf_index].dtype.name.startswith('image'))
tokens = []
test_idy = []
for o in self.buftokens[buf_index].offsets():
@@ -111,21 +121,22 @@ class GPUCodegen(ASTKernel):
if should_upcast:
float4_index = Variable("FLOAT4_INDEX", 0, 3)
idxy_test, valid_test = self.sts[buf_index].expr_idxs(float4_index+o) if idx_override is None else self.sts[buf_index].expr_node(idx_override, float4_index+o)
can_merge = idxy_test == float4_index or (isinstance(idxy_test, SumNode) and any(x == float4_index for x in idxy_test.nodes)) # float4_index must be in there without a multiply
can_merge = can_merge and "FLOAT4_INDEX" not in (idxy_test//4).render() and "FLOAT4_INDEX" not in valid_test.render() # float4_index must not be in after divide or in valid (TODO: don't check render)
can_merge = check_no_mul(idxy_test, float4_index)
# NOTE: valid_test.render() can contain a FLOAT4_INDEX that can't affect the result: example <(((idx0<0,511>*4)+FLOAT4_INDEX<0,3>)<1024)>
can_merge = can_merge and "FLOAT4_INDEX" not in (idxy_test//4).render() and ("FLOAT4_INDEX" not in valid_test.render() or True) # float4_index must not be in after divide or in valid (TODO: don't check render)
if const is not None:
ldr = const
elif self.bufs[buf_index] is not None and hasattr(self.bufs[buf_index]._buf, "IMAGE"):
assert should_upcast and can_merge, f"Image requires upcasting to FLOAT4 {self.buftokens[buf_index]}"
idx, idy = to_image_idx(self.bufs[buf_index]._base_shape, idxy, valid, VALIDHACKS)
ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)})) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4)
elif self.bufs[buf_index] is not None and is_image:
assert should_upcast and can_merge, f"Image requires upcasting to FLOAT4 {self.buftokens[buf_index]} should_upcast:{should_upcast} can_merge:{can_merge}"
idx, idy = to_image_idx(self.bufs[buf_index].dtype.shape, idxy, valid, VALIDHACKS)
ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)})) /* {self.bufs[buf_index].dtype.shape} */", Types.FLOAT4)
test_idy.append(idy.render(render_cl))
elif should_upcast and can_merge:
ldr = Token(f"(({self.lang.buffer_prefix if self.bufs[buf_index] is not None else self.lang.smem_prefix}float4*){self.buftokens[buf_index].tok})[{(idxy//4).render(render_cl)}]", Types.FLOAT4)
else:
ldr = Token(f"{self.buftokens[buf_index].tok}[{idxy.render(render_cl)}]", Types.FLOAT)
invalid = self.group_float4([Token("0.0f", Types.FLOAT)]*4) if ldr.typ == Types.FLOAT4 else Token("0.0f", Types.FLOAT)
ldr = ldr if valid.min == 1 or (VALIDHACKS and hasattr(self.bufs[buf_index]._buf, "IMAGE")) else (Token(f"({valid.render(render_cl)} ? {ldr.tok} : {invalid.tok})", ldr.typ) if valid.max == 1 else invalid)
ldr = ldr if valid.min == 1 or (VALIDHACKS and is_image) else (Token(f"({valid.render(render_cl)} ? {ldr.tok} : {invalid.tok})", ldr.typ) if valid.max == 1 else invalid)
if const is not None:
self.loaded_keys[(buf_index,o)] = ldr
else:
@@ -154,9 +165,10 @@ class GPUCodegen(ASTKernel):
def required_optimizations(self, early_only=False):
for buf_index,buf in enumerate(self.bufs):
upcast_strides = [self.sts[buf_index].strides[i] for i in self.upcast_in_mid_reduce_axes]
if (not early_only or buf in self.earlybufs) and hasattr(buf._buf, "IMAGE") and not (self.buftokens[buf_index].can_float4() or (buf not in self.earlybufs and (1 in upcast_strides))):
if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.name.startswith('image') and not (self.buftokens[buf_index].can_float4() or (buf not in self.earlybufs and (1 in upcast_strides))):
axes = [i for i,x in enumerate(self.sts[buf_index].strides) if x == 1]
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes} on buf_index {buf_index}, {self.sts[buf_index]}"
assert self.sts[buf_index].shape[axes[0]]%4 == 0, f"axis:{axes[0]} in buffer {buf_index} is not a multiple of 4, {self.sts[buf_index].shape}"
self.shift_to(axes[0], 4)
self.upcast()
assert self.buftokens[buf_index].can_float4()
@@ -177,12 +189,13 @@ class GPUCodegen(ASTKernel):
self.group_for_reduce.append(sz)
break

# are we upcasting in mid reduce?
if hasattr(self.bufs[0]._buf, "IMAGE") and not self.buftokens[0].can_float4() and self.group_for_reduce and self.first_reduce <= 2:
# are we upcasting in mid reduce? (only for images)
if self.bufs[0].dtype.name.startswith('image') and not self.buftokens[0].can_float4() and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1:
axes = [i for i,x in enumerate(self.sts[0].strides) if x == 1]
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
self.shift_to(axes[0], 4, insert_before=self.first_reduce + len(self.group_for_reduce)) # insert at the end of the grouped axis
self.group_for_reduce.append(4)
if self.sts[0].shape[axes[0]]%4 == 0:
self.shift_to(axes[0], 4, insert_before=self.first_reduce + len(self.group_for_reduce)) # insert at the end of the grouped axis
self.group_for_reduce.append(4)

# now do everything required
self.required_optimizations()
@@ -191,8 +204,8 @@ class GPUCodegen(ASTKernel):
self.simplify_ones()

# use more opencl indexing if the output buffer is an image and we have room
if hasattr(self.bufs[0]._buf, "IMAGE") and self.first_reduce+len(self.group_for_reduce) < 3:
base_shape = self.bufs[0]._base_shape
if self.bufs[0].dtype.name.startswith('image') and self.first_reduce+len(self.group_for_reduce) < 3:
base_shape = self.bufs[0].dtype.shape
if (base_shape[0]*base_shape[1]) % self.sts[0].shape[0] == 0 and self.sts[0].shape[0]//base_shape[0] != 0:
if DEBUG >= 4: print("split opencl", base_shape, self.sts[0].shape)
self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None)
@@ -266,7 +279,7 @@ class GPUCodegen(ASTKernel):
self.bufs_to_delete: Set[int] = set()
self.loaded_keys: Dict[Tuple[int,int], Token] = {}
self.prekernel: Set[str] = set()
self.kernel: List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(hasattr(buf._buf, "IMAGE") for buf in self.bufs if buf is not None) else []
self.kernel: List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(buf.dtype.name.startswith("image") for buf in self.bufs if buf is not None) else []

if self.lang.half_prekernel and any(x.dtype == dtypes.float16 for x in self.bufs if x is not None): self.prekernel.add(self.lang.half_prekernel+"\n")

@@ -307,7 +320,7 @@ class GPUCodegen(ASTKernel):
for j in self.upcast_in_mid_reduce_axes:
self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
self.upcast()
if DEBUG >= 4: print("upcast", self.colorshape()) # NOTE: colorshape is wrong here
#if DEBUG >= 4: print("upcast", self.colorshape()) # NOTE: colorshape is wrong here, you have to remove it from group_for_reduce before calling

self.kernel.append(f"if ({lidx.render(render_cl)} == 0) {{\n") # lidx.max works here too

@@ -324,7 +337,7 @@ class GPUCodegen(ASTKernel):
self.kernel.append("\n}")

# concat kernel into prg
buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if hasattr(x._buf, "IMAGE") else self.lang.buffer_prefix+self.buftokens[i].decltype(self.bufs[i].dtype)+self.lang.buffer_suffix for i,x in enumerate(self.bufs) if x is not None]
buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else self.lang.buffer_prefix+self.buftokens[i].decltype(self.bufs[i].dtype)+self.lang.buffer_suffix for i,x in enumerate(self.bufs) if x is not None]
prg = ' '.join(list(self.prekernel) + [f"{self.lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] +
[', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete] + self.lang.extra_args)] +
[") {\n"] + self.kernel)
@@ -342,4 +355,5 @@ class GPUCodegen(ASTKernel):
return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), self.bufs_to_delete,
list(self.output_shape[::-1]) if len(self.output_shape) > 0 else [1],
(self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None,
op_estimate=self.info.flops, mem_estimate=sum(x.dtype.itemsize*prod(x._base_shape) for x in self.bufs if x is not None))
op_estimate=self.info.flops,
mem_estimate=sum(x.dtype.itemsize*(x.realized.size if x.realized is not None else prod(x.shape)) for x in self.bufs if x is not None))

@@ -212,4 +212,6 @@ class LLVMCodegen(ASTKernel):
loop_entry[-1].branch(loop_exit[-1]._block)
loop_exit[0].ret_void()

return ASTRunner('exec', str(module), op_estimate=self.info.flops, mem_estimate=sum(x.dtype.itemsize*prod(x._base_shape) for x in self.bufs))
# TODO: mem_estimate is copied from GPU
return ASTRunner('exec', str(module), op_estimate=self.info.flops,
mem_estimate=sum(x.dtype.itemsize*(x.realized.size if x.realized is not None else prod(x.shape)) for x in self.bufs if x is not None))

@@ -5,7 +5,8 @@ except ImportError:
nx = None # graph won't work
from collections import defaultdict
from typing import Dict, List, Optional
from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps, Op, OpType, LazyOp, get_buffers, get_lazyops
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps, Op, OpType, LazyOp, get_buffers, get_lazyops
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import getenv, DEBUG

GRAPH, PRUNEGRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("PRUNEGRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
@@ -37,18 +38,22 @@ def get_sop(op: List[Op]):
if len(op) <= 4: return '.'.join([str(y).split(".")[1][0:3] for y in op][::-1])
return str(len(op))

def log_op(ret: DeviceBuffer, ast: LazyOp, show_graph: Optional[bool] = None):
def str_dtype(dtyp):
ret = str(dtyp)[7:]
return "" if ret == 'float' else f"\n{ret}"

def log_op(ret: LazyBuffer, ast: LazyOp, show_graph: Optional[bool] = None):
if show_graph is None: show_graph = bool(GRAPH)
if not DEBUG and not show_graph: return
op: List[Op] = [x.op for x in get_lazyops(ast)]
inp: List[DeviceBuffer] = get_buffers(ast)
inp: List[LazyBuffer] = get_buffers(ast)
if len(inp) == 1 and inp[0] == ret:
if show_graph and nm(ret) in G.nodes: G.nodes[nm(ret)]['style'] += ', bold'
return # don't log self loops
oporder = [LoadOps, FusedOps, ReduceOps, BinaryOps, UnaryOps, MovementOps]
optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
cnts[optype] += 1
if DEBUG >= 4: print(f"{op} : {', '.join([f'{x.shape}-<{nm(x)}>' for x in inp])} -> {ret.shape}-<{nm(ret)}>")
if DEBUG >= 6: print(f"{op} : {', '.join([f'{x.shape}-<{nm(x)}>' for x in inp])} -> {ret.shape}-<{nm(ret)}>")
if show_graph:
top_colors = {LoadOps: '#FFFF80', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", FusedOps: "#ff8080"}
dashed = (optype == LoadOps and hasattr(ret, "_backing")) or (hasattr(ret, "st") and not ret.st.contiguous) # type: ignore
@@ -56,10 +61,10 @@ def log_op(ret: DeviceBuffer, ast: LazyOp, show_graph: Optional[bool] = None):
for x in inp:
G.add_edge(nm(x), nm(ret), label=get_sop(op))
if 'label' not in G.nodes[nm(x)]:
G.nodes[nm(x)]['label'] = str(x.shape)
G.nodes[nm(x)]['label'] = str(x.shape)+str_dtype(ret.dtype)
if nm(ret) not in G.nodes: G.add_node(nm(ret))

G.nodes[nm(ret)]['label'] = str(set(x.shape for x in inp))+"\n"+str(ret.shape) if optype == ReduceOps else str(ret.shape)
G.nodes[nm(ret)]['label'] = (str(set(x.shape for x in inp))+"\n"+str(ret.shape) if optype == ReduceOps else str(ret.shape))+str_dtype(ret.dtype)
G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('80' if dashed else str())) if optype in top_colors else "#ffffff"
G.nodes[nm(ret)]['style'] = 'filled, dashed' if dashed else 'filled'
G.nodes[nm(ret)]['prunable'] = optype in [LoadOps, MovementOps]

@@ -1,3 +1,4 @@
from __future__ import annotations
import os, math, functools
import numpy as np
from typing import Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any
@@ -23,23 +24,33 @@ DEBUG, IMAGE = getenv("DEBUG", 0), getenv("IMAGE", 0)
# **** tinygrad now supports dtypes! *****

class DType(NamedTuple):
priority: int # this determines when things get upcasted
itemsize: int
name: str
np: type # TODO: someday this will be removed with the "remove numpy" project
def __repr__(self): return f"dtypes.{self.name}"

# dependent typing?
class ImageDType(DType):
def __new__(cls, priority, itemsize, name, np, shape):
return super().__new__(cls, priority, itemsize, name, np)
def __init__(self, priority, itemsize, name, np, shape):
self.shape: Tuple[int, ...] = shape # arbitrary arg for the dtype, used in image for the shape
super().__init__()
def __repr__(self): return f"dtypes.{self.name}({self.shape})"

class LazyNumpyArray:
def __init__(self, fxn, shape, dtype): self.fxn, self.shape, self.dtype = fxn, shape, dtype
def __call__(self): return self.fxn(self)
def __call__(self) -> np.ndarray: return np.ascontiguousarray(self.fxn(self) if callable(self.fxn) else self.fxn).reshape(self.shape).astype(self.dtype)
def reshape(self, new_shape): return LazyNumpyArray(self.fxn, new_shape, self.dtype)
def copy(self): return self
def astype(self, typ): return self
def copy(self): return self if callable(self.fxn) else LazyNumpyArray(self.fxn.copy(), self.shape, self.dtype)
def astype(self, typ): return LazyNumpyArray(self.fxn, self.shape, typ)

class dtypes:
float16: Final[DType] = DType(2, "half", np.float16)
float32: Final[DType] = DType(4, "float", np.float32)
float16: Final[DType] = DType(0, 2, "half", np.float16)
float32: Final[DType] = DType(1, 4, "float", np.float32)
@staticmethod
def from_np(x:Union[LazyNumpyArray, np.ndarray]) -> DType: return {np.dtype(np.float16): dtypes.float16, np.dtype(np.float32): dtypes.float32}[np.dtype(x.dtype)]
def from_np(x) -> DType: return {np.dtype(np.float16): dtypes.float16, np.dtype(np.float32): dtypes.float32}[np.dtype(x)]

class GlobalCounters:
global_ops: ClassVar[int] = 0
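
The new priority field is what lets max(a.dtype, b.dtype) (used for upcasting in the ATan2 mlop earlier) pick float32 over float16: NamedTuples compare field by field, so priority is compared first. A small standalone illustration of that, plus the same __new__/__init__ trick ImageDType uses to carry an extra shape; the concrete numbers here are made up for the example.

from typing import NamedTuple, Tuple
import numpy as np

class DType(NamedTuple):
  priority: int
  itemsize: int
  name: str
  np: type

float16 = DType(0, 2, "half", np.float16)
float32 = DType(1, 4, "float", np.float32)
assert max(float16, float32) is float32        # priority compares first, so float32 wins upcasts

class ImageDType(DType):                       # NamedTuple subclass carrying an extra attribute
  def __new__(cls, priority, itemsize, name, np, shape): return super().__new__(cls, priority, itemsize, name, np)
  def __init__(self, priority, itemsize, name, np, shape): self.shape: Tuple[int, ...] = shape

img = ImageDType(100, 2, "imageh", np.float16, (64, 32, 4))
assert img.shape == (64, 32, 4) and img.itemsize == 2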

@@ -1,56 +0,0 @@
from __future__ import annotations
from typing import Tuple, Any, ClassVar, Optional, Callable, Dict
import functools
from tinygrad.helpers import DType, dtypes, prod, GlobalCounters, DEBUG
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.ops import DeviceBuffer, LazyOp, get_buffers, map_buffers, Op, FusedOps, UnaryOps, MovementOps, ReduceOps, BinaryOps

# this is a quick "buffer" class for flop tracking and getting the output shape
class GenericShape:
def __init__(self, shape:Tuple[int, ...], dtype:DType=dtypes.float32, flops:int=0): self.shape, self.dtype, self.flops = shape, dtype, flops
def consume_flops(self):
self.flops, ret = 0, self.flops
return ret
shape_fxn_for_op: Dict[Op, Callable] = {
**{op:lambda self: GenericShape(self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps},
**{op:lambda self,y: GenericShape(self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},
**{op:lambda self,new_shape: GenericShape(new_shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in ReduceOps},
**{op:functools.partial(lambda mop,self,arg: GenericShape(ShapeTracker(self.shape).movement_op(mop, arg).shape, self.dtype, self.consume_flops()), op) for op in MovementOps}}

# this runs the LazyOp and gives you the output shape/dtype and flop count
def get_lazyop_info(ast:LazyOp) -> GenericShape: return InterpretedBuffer.exec_ast(map_buffers({x:InterpretedBuffer(GenericShape(x.shape, x.dtype)) for x in get_buffers(ast)}, ast))._buf

# used in CPUBuffer and TorchBuffer
class InterpretedBuffer(DeviceBuffer): # pylint: disable=abstract-method
fxn_for_op: ClassVar = shape_fxn_for_op
def __init__(self, lbuf:Any):
self._buf: Any = lbuf
self.shape: Tuple[int, ...] = tuple(lbuf.shape)
self.dtype: DType = self.to_tinygrad_dtype() if hasattr(self, 'to_tinygrad_dtype') else lbuf.dtype
# NOTE: this is overcounting the memory used, as reshapes and stuff are aliases
self._memsz = (prod(self.shape) * self.dtype.itemsize) if not isinstance(self, InterpretedBuffer) else 0
GlobalCounters.mem_used += self._memsz
def __del__(self): GlobalCounters.mem_used -= self._memsz
def contiguous(self): return type(self).exec_ast(LazyOp(op=UnaryOps.NOOP, src=(self,)))
def movement_op(self, op:MovementOps, arg=None): return type(self)(self.fxn_for_op[op](self._buf, arg)) if op in self.fxn_for_op else type(self)(getattr(self._buf, op.name.lower())(arg))
@classmethod
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[InterpretedBuffer]=None, context=None):
if FusedOps.MULACC in cls.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
ast = LazyOp(FusedOps.MULACC, ast.src[0].src, ast.arg)
created_context = context is None
if context is None: context = dict()
if not created_context and ast in context: return context[ast]
srcs = [cls.exec_ast(x, context=context) if isinstance(x, LazyOp) else x for x in ast.src]
if ast.op in BinaryOps: assert srcs[0].shape == srcs[1].shape, f"BinaryOps shape mismatch {srcs[0].shape} != {srcs[1].shape}"
if ast.op in ReduceOps: assert all(r == n or n == 1 for r,n in zip(srcs[0].shape, ast.arg)), f"ReduceOps can't reduce {srcs[0].shape} -> {ast.arg}"
if ast.op in MovementOps: ret = srcs[0].movement_op(ast.op, ast.arg)
else: ret = cls(cls.fxn_for_op[ast.op](*([x._buf for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
if DEBUG >= 4 or (not isinstance(cls, InterpretedBuffer) and DEBUG >= 3):
print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB op: {ast.op:20s} out({ret.dtype.name}): {str(ret.shape):30s} in({len(srcs)}):", list(set(x.shape for x in srcs)), ast.arg if ast.arg is not None else "")
if not created_context: context[ast] = ret
if output_buffer is not None:
assert output_buffer.shape == ret.shape, output_buffer.dtype == ret.dtype
output_buffer._buf = ret._buf
return output_buffer
else:
return ret
@@ -4,7 +4,7 @@ from tinygrad.helpers import DEBUG, colored

from tinygrad.lazy import Device
from tinygrad.tensor import Tensor
from tinygrad.ops import GlobalCounters, CompiledBuffer, RawBuffer
from tinygrad.ops import GlobalCounters, RawBuffer

class TinyJit:
def __init__(self, fxn:Callable):
@@ -20,9 +20,7 @@ class TinyJit:
def __call__(self, *args, **kwargs) -> Any:
if Device.DEFAULT not in ["GPU", "CLANG", "METAL", "CUDA"]: return self.fxn(*args, **kwargs) # only jit on the GPU codegen
# NOTE: this cast is needed since although we know realize will create a ".realized" DeviceBuffer, the type checker doesn't
input_buffers: Dict[Union[int, str], CompiledBuffer] = {cast(Union[int, str], k):cast(CompiledBuffer, v.realize().lazydata.realized) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
# TODO: check the shapetrackers on the CompiledBuffers are what we jitted
input_rawbuffers: Dict[Union[int, str], RawBuffer] = {k:v.raw() for k,v in input_buffers.items()}
input_rawbuffers: Dict[Union[int, str], RawBuffer] = {cast(Union[int, str], k):cast(RawBuffer, v.realize().lazydata.realized) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
assert len(input_rawbuffers) != 0, "no inputs to JIT"
if self.cnt >= 2:
for (j,i),idx in self.input_replace.items(): self.jit_cache[j][1][i] = input_rawbuffers[idx]
|
||||
|
||||
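The jit hunk above captures input RawBuffers by position and, on later calls, only patches those buffers into the recorded kernel calls. A toy standalone version of that replacement scheme (not the real TinyJit; the names and keys here are made up):

# jit_cache: recorded (program, buffer-list) pairs; None marks a slot that was a jitted input
jit_cache = [("kernel_a", ["W1", None]), ("kernel_b", [None, "W2"])]
# (cache_index, buffer_index) -> which user input goes in that slot
input_replace = {(0, 1): "x", (1, 0): "x"}

def run(inputs):
    for (j, i), name in input_replace.items():
        jit_cache[j][1][i] = inputs[name]   # same shape as self.jit_cache[j][1][i] = input_rawbuffers[idx]
    return [(prg, list(bufs)) for prg, bufs in jit_cache]

print(run({"x": "batch0"}))   # [('kernel_a', ['W1', 'batch0']), ('kernel_b', ['batch0', 'W2'])]
print(run({"x": "batch1"}))   # the cached calls are reused, only the input buffers change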
117
tinygrad/lazy.py
@@ -1,12 +1,11 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional, Tuple, Union, List, Dict, Any, ClassVar, Type
|
||||
from typing import Optional, Tuple, Union, List, Dict, Any, ClassVar, cast
|
||||
import sys, weakref, importlib, inspect, functools, pathlib
|
||||
from weakref import WeakValueDictionary
|
||||
from tinygrad.helpers import prod, getenv, DType, dtypes, LazyNumpyArray, flatten
|
||||
from tinygrad.helpers import prod, getenv, DType, dtypes, LazyNumpyArray, flatten, ImageDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
|
||||
from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops, map_buffers
|
||||
from tinygrad.runtime.ops_cpu import CPUBuffer
|
||||
from tinygrad.graph import log_op
|
||||
from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops, map_buffers
|
||||
from tinygrad.runtime.lib import RawConst, RawBuffer
|
||||
|
||||
# lazy can recurse a lot
|
||||
sys.setrecursionlimit(10000)
|
||||
@@ -19,7 +18,7 @@ class _Device:
|
||||
self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
|
||||
self.DEFAULT: str = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, "CPU")
|
||||
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
|
||||
def __getitem__(self, x:str) -> Type[DeviceBuffer]: return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}'), inspect.isclass) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
|
||||
def __getitem__(self, x:str) -> Union[Interpreted, Compiled]: return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
|
||||
Device = _Device()
|
||||
|
||||
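The _Device hunk above now hands back an Interpreted or Compiled instance per backend instead of a DeviceBuffer subclass. A minimal sketch of that registry pattern (paths and names are illustrative, not the exact tinygrad layout):

import importlib, inspect, os, pathlib, functools

class _DeviceRegistry:
    def __init__(self, runtime_dir: str, package: str):
        self._package = package
        self._buffers = [p.stem[len("ops_"):].upper()
                         for p in pathlib.Path(runtime_dir).iterdir() if p.stem.startswith("ops_")]
        # an env var like GPU=1 or METAL=1 overrides the CPU default
        self.DEFAULT = functools.reduce(lambda val, ele: ele if os.getenv(ele) == "1" else val, self._buffers, "CPU")

    @functools.lru_cache(maxsize=None)   # cache the import per device name (the registry is a singleton)
    def __getitem__(self, name: str):
        mod = importlib.import_module(f"{self._package}.ops_{name.lower()}")
        # after this commit the exported object is an Interpreted/Compiled *instance*,
        # so the lookup matches by name instead of filtering with inspect.isclass
        return [obj for objname, obj in inspect.getmembers(mod) if objname.lower() == name.lower() + "buffer"][0]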
# TODO: movement ops that only change shape are really nops. treat them as such
|
||||
@@ -70,7 +69,7 @@ def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movement
|
||||
def replace_with_movement_op(y:Union[LazyOp, LazyBuffer], op:MovementOps, arg:Tuple[Any, ...]) -> LazyBuffer:
|
||||
if isinstance(y, LazyBuffer): return y.movement_op(op, arg)
|
||||
assert y.op in BinaryOps or y.op in UnaryOps
|
||||
return elementwise_op(y.op, *[replace_with_movement_op(z, op, arg) for z in y.src]) # type: ignore
|
||||
return elementwise_op(y.op, *[replace_with_movement_op(z, op, arg) for z in y.src], arg=y.arg) # type: ignore
|
||||
|
||||
def support_weakref(x): return x
|
||||
@support_weakref # needed for mypyc, this prevents LazyBuffer from becoming a native class
|
||||
@@ -91,9 +90,10 @@ class LazyBuffer:
|
||||
if hasattr(self, 'device'):
|
||||
return # cache hit, we return and don't reinit
|
||||
self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
|
||||
self.shape, self.optype, self.op, self.dtype = self.st.shape, optype, op, dtype
|
||||
self.realized: Optional[DeviceBuffer] = None
|
||||
self.output_buffer: Optional[DeviceBuffer] = None
|
||||
self.shape, self.optype, self.dtype = self.st.shape, optype, dtype
|
||||
self.op: LazyOp = op
|
||||
self.realized: Optional[RawBuffer] = None
|
||||
self.output_buffer: Optional[RawBuffer] = None
|
||||
self.device, self.dbuffer = device, Device[device]
|
||||
# TODO: does children have to be a ref count instead of a set? can a Buffer be a double child?
|
||||
self.children: weakref.WeakSet[LazyBuffer] = weakref.WeakSet()
|
||||
@@ -101,67 +101,72 @@ class LazyBuffer:
|
||||
for x in get_buffers(op): x.children.add(self)
|
||||
if not LAZY: self.realize()
|
||||
|
||||
def __repr__(self): return f"<LB {self.shape} {self.dtype} op:{self.op.op if self.realized is None else 'realized'}>"
|
||||
def __repr__(self): return f"<LB {self.shape} {self.dtype} op:{self.op.op if self.realized is None else 'realized'} st:{self.st}>"
|
||||
|
||||
# this produces a device buffer
|
||||
def realize(self:LazyBuffer, required_device=None) -> DeviceBuffer:
|
||||
def realize(self:LazyBuffer, required_device=None) -> LazyBuffer:
|
||||
assert required_device is None or required_device == self.device
|
||||
if self.realized is None:
|
||||
# get real ops first
|
||||
if self.op.op == LoadOps.FROMCPU:
|
||||
# resolve LazyNumpyArray
|
||||
ast = LazyOp(self.op.op, tuple(), self.op.arg() if isinstance(self.op.arg, LazyNumpyArray) else self.op.arg)
|
||||
self.realized = Device[self.device].buffer.fromCPU(self.op.arg())
|
||||
elif self.op.op == LoadOps.CONTIGUOUS:
|
||||
real_src = self.op.src[0].realize(self.device)
|
||||
self.realized = real_src.contiguous()
|
||||
ast = LazyOp(self.op.op, (real_src, ))
|
||||
elif self.op.op == LoadOps.CUSTOM:
|
||||
real_srcs = tuple(x.realize(self.device) for x in self.op.src)
|
||||
self.realized = self.op.arg(*real_srcs)
|
||||
ast = LazyOp(self.op.op, real_srcs)
|
||||
elif self.optype == MovementOps:
|
||||
src = self.op.src[0]
|
||||
|
||||
# fuse RESHAPE and ReduceOps
|
||||
# NOTE: this is sort of a hack for IMAGE, otherwise it shouldn't matter
|
||||
if src.realized is None and src.optype == ReduceOps and self.op.op == MovementOps.RESHAPE and len(src.children) <= 1:
|
||||
# it's okay to add a RESHAPE to the ast here
|
||||
ast = LazyOp(MovementOps.RESHAPE, (_ast_reduceops(src), ), self.op.arg)
|
||||
realized = self.op.src[0].realize(self.device).realized
|
||||
if self.op.src[0].st.contiguous and not isinstance(realized, RawConst) and realized.size == prod(self.shape):
|
||||
# no need to run an AST, this is already contiguous
|
||||
self.realized = realized
|
||||
else:
|
||||
# movement ops aren't an AST, just run them
|
||||
real_src = src.realize(self.device)
|
||||
self.realized = real_src.movement_op(self.op.op, self.op.arg)
|
||||
ast = LazyOp(self.op.op, (real_src, ))
|
||||
elif self.optype == ReduceOps: ast = _ast_reduceops(self)
|
||||
elif self.optype == BinaryOps: ast = _ast_binaryops(self)
|
||||
# TODO: remove UnaryOps.NOOP, replace with LoadOps.CONTIGUOUS. confusing with Compiled though
|
||||
self.op = LazyOp(UnaryOps.NOOP, self.op.src)
|
||||
elif self.op.op == LoadOps.CUSTOM:
|
||||
# this needs to immediately realize
|
||||
self.realized = self.op.arg(self, *[x.realize(self.device) for x in self.op.src])
|
||||
# these can be late folded and change the op to go further back in the graph
|
||||
elif self.optype == ReduceOps: self.op = _ast_reduceops(self)
|
||||
elif self.optype == BinaryOps: self.op = _ast_binaryops(self) # ISSUE: this can include a reshape
|
||||
|
||||
# run the ast if we still have to, and log the op
|
||||
if self.realized is None:
|
||||
for x in get_buffers(self.op): x.realize(self.device)
|
||||
|
||||
# HACK: image shape can be wrong, hot cast it back to a normal float
|
||||
if self.optype != MovementOps and isinstance(self.dtype, ImageDType) and (prod(self.shape) != prod(self.dtype.shape) or self.shape[self.st.strides.index(1)]%4 != 0):
|
||||
upcasted = LazyOp(UnaryOps.CAST, (self.op,), dtypes.float32)
|
||||
if self.op.op == MovementOps.RESHAPE: self.op = LazyOp(MovementOps.RESHAPE, upcasted, self.op.arg)
|
||||
else: self.op = upcasted
|
||||
self.dtype = dtypes.float32
|
||||
|
||||
self.realized = Device[self.device].exec_ast(self.op, output=self)
|
||||
|
||||
# log to the graph
|
||||
from tinygrad.graph import log_op
|
||||
log_op(self, self.op)
|
||||
|
||||
# no need to keep the op after realization
|
||||
del self.op
|
||||
|
||||
# run the ast if we still have to, and log the op
|
||||
if self.realized is None:
|
||||
ast = map_buffers({x:x.realize(self.device) for x in get_buffers(ast)}, ast)
|
||||
self.realized = self.dbuffer.exec_ast(ast, output_buffer=self.output_buffer)
|
||||
log_op(self.realized, ast)
|
||||
|
||||
assert self.realized.shape == self.shape, f"shape mismatch on realize got {self.realized.shape} expected {self.shape}"
|
||||
assert isinstance(self.realized, Device[self.device]), f"device mismatch on realized got {type(self.realized)} expected {self.device}"
|
||||
assert self.realized.dtype == self.dtype, f"dtype mismatch on realize got {self.realized.dtype} expected {self.dtype}"
|
||||
return self.realized
|
||||
assert isinstance(self.realized, (RawConst, Device[self.device].buffer)), f"device mismatch on realized got {type(self.realized)} expected {self.device}"
|
||||
# HACK: allow hot casting of images
|
||||
assert self.realized.dtype == self.dtype or self.dtype.name.startswith("image"), f"dtype mismatch on realize got {self.realized.dtype} expected {self.dtype}"
|
||||
self.dtype = self.realized.dtype
|
||||
return self
|
||||
|
||||
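With this hunk, realize() returns the LazyBuffer itself and .realized holds a RawBuffer (or RawConst), which is exactly how the TinyJit hunk grabs its inputs. A short usage sketch, assuming a tinygrad checkout at this commit:

import numpy as np
from tinygrad.tensor import Tensor

t = Tensor(np.array([[1., 2.], [3., 4.]], dtype=np.float32)) + 1   # builds a lazy BinaryOps.ADD
lb = t.lazydata            # unrealized LazyBuffer: holds an op/AST, no RawBuffer yet
out = lb.realize()         # runs the AST on Device.DEFAULT
assert out is lb           # realize() now returns the same LazyBuffer, not a device buffer
print(type(lb.realized))   # some RawBuffer subclass, e.g. RawNumpyBuffer on the CPU backend
print(t.numpy())           # the toCPU() path: cast to a plain float, contiguous, then RawBuffer.toCPU()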
# NOTE: we have to make a copy of the numpy array here in case the user changes it. expose this? LazyNumpyArray doesn't have this problem
|
||||
@staticmethod
|
||||
def fromCPU(x, device) -> LazyBuffer: return LazyBuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x.copy()), dtypes.from_np(x))
|
||||
def fromCPU(x:LazyNumpyArray, device) -> LazyBuffer:
|
||||
return LazyBuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x.copy()), dtypes.from_np(x.dtype))
|
||||
|
||||
# NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this?
|
||||
def toCPU(self):
|
||||
ret = self.realize().toCPU()
|
||||
log_op(CPUBuffer(ret), LazyOp(LoadOps.TOCPU, (self.realized,), None))
|
||||
realized = self.cast(dtypes.from_np(self.dtype.np)).contiguous().realize().realized
|
||||
ret = cast(RawBuffer, realized).toCPU().reshape(self.shape)
|
||||
return ret.copy()
|
||||
|
||||
def cast(self:LazyBuffer, arg:DType) -> LazyBuffer: return elementwise_op(UnaryOps.CAST, self, arg=arg) if self.dtype != arg else self
|
||||
def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self)
|
||||
def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y)
|
||||
def contiguous(self:LazyBuffer) -> LazyBuffer: return LazyBuffer(self.device, self.shape, LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,)), self.dtype)
|
||||
def contiguous(self:LazyBuffer) -> LazyBuffer:
|
||||
if self.realized is None and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
|
||||
return LazyBuffer(self.device, self.shape, LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,)), self.dtype)
|
||||
|
||||
def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
|
||||
if self.shape == tuple(new_shape): return self
|
||||
@@ -199,8 +204,8 @@ class LazyBuffer:
|
||||
# some permutes are actually just reshapes
|
||||
if op == MovementOps.PERMUTE and local_st.contiguous: return self.movement_op(MovementOps.RESHAPE, tuple(self.shape[i] for i in arg))
|
||||
|
||||
# move permutes before expands
|
||||
if op == MovementOps.PERMUTE and PUSH_PERMUTES and self.realized is None and self.op.op == MovementOps.EXPAND:
|
||||
# move permutes before expands (always, this is safe)
|
||||
if op == MovementOps.PERMUTE and self.realized is None and self.op.op == MovementOps.EXPAND:
|
||||
self.op.src[0].children.discard(self)
|
||||
return self.op.src[0].movement_op(MovementOps.PERMUTE, arg).movement_op(MovementOps.EXPAND, tuple(self.op.arg[a] for a in arg))
|
||||
|
||||
@@ -228,8 +233,8 @@ class LazyBuffer:
|
||||
|
||||
return ret
|
||||
|
||||
def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer) -> LazyBuffer:
|
||||
out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max(x.dtype for x in srcs)
|
||||
def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
|
||||
out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max(x.dtype for x in srcs) if op != UnaryOps.CAST else cast(DType, arg)
|
||||
|
||||
# push all contiguous to the end of BinaryOps. kernels 198 -> 196
|
||||
if PUSH_CONTIGUOUS and any(x.realized is None and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs):
|
||||
@@ -240,10 +245,10 @@ def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer) -> LazyBuffe
|
||||
new_srcs.append(x.op.src[0])
|
||||
else:
|
||||
new_srcs.append(x)
|
||||
return elementwise_op(op, *new_srcs).contiguous()
|
||||
return elementwise_op(op, *new_srcs, arg=arg).contiguous()
|
||||
|
||||
if MERGE_ELEMENTWISE_OPS or (MERGE_UNARY_OPS and len(set(srcs)) == 1):
|
||||
# remove the buffers from any (childless) BinaryOps that feed into this
|
||||
srcs = tuple(x.op if x.optype == BinaryOps and len(x.children) == 0 and x.realized is None else x for x in srcs) # type: ignore
|
||||
|
||||
return LazyBuffer(out_device, out_shape, BinaryOps, LazyOp(op, srcs), out_dtype)
|
||||
return LazyBuffer(out_device, out_shape, BinaryOps, LazyOp(op, srcs, arg), out_dtype)
|
||||
|
||||
@@ -8,6 +8,13 @@ class Contiguous(Function):
|
||||
def forward(self, x): return x.contiguous()
|
||||
def backward(self, grad_output): return grad_output
|
||||
|
||||
class Cast(Function):
|
||||
def forward(self, x, dtype):
|
||||
self.input_dtype = x.dtype
|
||||
return x.cast(dtype)
|
||||
def backward(self, grad_output):
|
||||
return grad_output.cast(self.input_dtype)
|
||||
|
||||
# ************* unary ops *************
|
||||
|
||||
class Log(Function):
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
from tinygrad.helpers import prod, IMAGE
|
||||
import numpy as np
|
||||
from tinygrad.helpers import prod, IMAGE, ImageDType, getenv, dtypes
|
||||
from tinygrad.lazy import get_single_root
|
||||
|
||||
FLOAT16 = getenv("FLOAT16", 0)
|
||||
base_image_type = (100, 2, "image_half", np.float16) if FLOAT16 else (100, 4, "image_float", np.float32)
|
||||
|
||||
def image_dot(self, w):
|
||||
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
|
||||
bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
|
||||
@@ -49,6 +53,7 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
|
||||
else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1).reshape(cout//4, H*cin//4*W*4, 4)
|
||||
|
||||
# contiguous creates the image, and early realize static weights (TODO: test for the static weight)
|
||||
if IMAGE >= 2: x,w = x.cast(ImageDType(*base_image_type, shape=x.shape)), w.cast(ImageDType(*base_image_type, shape=w.shape))
|
||||
x, w = x.contiguous(), w.contiguous()
|
||||
if get_single_root(w.lazydata).realized: w.realize()
|
||||
|
||||
@@ -71,10 +76,14 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
|
||||
|
||||
# prepare weights
|
||||
w = w.permute(0,4,2,5,1,3)
|
||||
w = w.reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
|
||||
w = w.reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W)).expand(x.shape)
|
||||
|
||||
# the conv!
|
||||
ret = (x*w).sum((-4, -3, -2, -1)).reshape(bs*oy, ox*cout//4, 4)
|
||||
# the conv! (+ the bias)
|
||||
ret = (x*w).cast(dtypes.float32).sum((-4, -3, -2, -1))
|
||||
|
||||
# reshape to image and cast back to image
|
||||
ret = ret.reshape(bs*oy, ox*cout//4, 4)
|
||||
if IMAGE >= 2: ret = ret.cast(ImageDType(*base_image_type, shape=ret.shape))
|
||||
if IMAGE >= 3: ret = ret.contiguous()
|
||||
|
||||
# undo hack for non multiples of 4 on C.rcout
|
||||
|
||||
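The IMAGE=2 path above round-trips through an ImageDType that carries the tensor's shape: the inputs are cast to the image dtype, the multiply-accumulate runs in float32, and the result is cast back so the GPU backend can allocate a cl.Image for it. A small sketch of constructing that dtype, assuming a checkout at this commit:

import numpy as np
from tinygrad.helpers import ImageDType

base_image_type = (100, 4, "image_float", np.float32)    # the FLOAT16=0 variant from this hunk
idt = ImageDType(*base_image_type, shape=(32, 64, 4))    # priority, itemsize, name, np dtype, plus shape
print(idt.name, idt.itemsize, idt.shape)                 # image_float 4 (32, 64, 4)
# in image_conv2d: x.cast(idt), then (x*w).cast(dtypes.float32).sum(...), then ret.cast(ImageDType(..., shape=ret.shape))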
155
tinygrad/ops.py
@@ -1,15 +1,14 @@
|
||||
from __future__ import annotations
|
||||
import functools, itertools, operator, random
|
||||
import numpy as np
|
||||
from enum import Enum, auto
|
||||
from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar, Optional, Dict, Set, Final
|
||||
from tinygrad.helpers import prod, DEBUG, getenv, DType, dtypes, GlobalCounters
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, MovementOps
|
||||
from typing import Union, Type, NamedTuple, Tuple, Any, List, Optional, Dict, Set, Callable
|
||||
from tinygrad.helpers import prod, DEBUG, getenv, GlobalCounters
|
||||
from tinygrad.shape.shapetracker import MovementOps
|
||||
from tinygrad.runtime.lib import RawBuffer
|
||||
|
||||
# these are the llops your accelerator must implement, along with toCpu
|
||||
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
|
||||
class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); NEG = auto(); NOT = auto() # noqa: E702
|
||||
class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); NEG = auto(); NOT = auto(); CAST = auto() # noqa: E702
|
||||
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702
|
||||
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
|
||||
class FusedOps(Enum): MULACC = auto() # noqa: E702
|
||||
@@ -32,15 +31,6 @@ def map_buffers(real_srcs:Dict[Any, Any], x:Any) -> LazyOp:
|
||||
if len(real_srcs) and x in real_srcs: return map_buffers(real_srcs, real_srcs[x]) if isinstance(real_srcs[x], LazyOp) else real_srcs[x]
|
||||
return LazyOp(x.op, tuple((map_buffers(real_srcs, y) if isinstance(y, LazyOp) else real_srcs[y]) for y in x.src), x.arg)
|
||||
|
||||
# a placeholder class to extend by the exec classes
|
||||
class DeviceBuffer:
|
||||
_buf: Any # underlying buffer
|
||||
shape: Tuple[int, ...]
|
||||
dtype: DType
|
||||
@classmethod
|
||||
def exec_ast(cls, ast:LazyOp, output_buffer=None): raise NotImplementedError("must be implemented")
|
||||
def toCPU(self) -> np.ndarray: raise NotImplementedError("must be implemented")
|
||||
|
||||
class ASTRunner:
|
||||
def __init__(self, name, prg, bufs_to_delete:Optional[Set[int]]=None, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0):
|
||||
if DEBUG >= 4: print(prg)
|
||||
@@ -50,8 +40,9 @@ class ASTRunner:
|
||||
self.clprg = runtime(self.name, self.prg)
|
||||
return self
|
||||
|
||||
def exec(self, bufs:List[Optional[CompiledBuffer]]) -> Optional[float]:
|
||||
rawbufs = [x.raw() for i,x in enumerate(bufs) if x is not None and i not in self.bufs_to_delete]
|
||||
def exec(self, bufs) -> Optional[float]:
|
||||
rawbufs = [x.realized for i,x in enumerate(bufs) if x is not None and i not in self.bufs_to_delete]
|
||||
assert all(x is not None for x in rawbufs), "some rawbufs are None, you probably didn't realize them"
|
||||
if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs)
|
||||
if GlobalCounters.cache is not None: GlobalCounters.cache.append((self, rawbufs))
|
||||
return self(rawbufs)
|
||||
@@ -81,66 +72,84 @@ class ASTRunner:
|
||||
local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
|
||||
return min([(self.timeit(rawbufs, local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])[1]
|
||||
|
||||
from tinygrad.codegen.ast import ASTKernel
|
||||
class Specialized(NamedTuple):
|
||||
raw_buffer: Type[RawBuffer]
|
||||
codegen: Type[ASTKernel]
|
||||
runtime: Type
|
||||
class Interpreted:
|
||||
def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_lazybuffer=lambda x: x.realized, to_underlying=lambda x: x._buf):
|
||||
self.buffer = buffer
|
||||
self.fxn_for_op = fxn_for_op
|
||||
self.from_lazybuffer = from_lazybuffer
|
||||
self.to_underlying = to_underlying
|
||||
|
||||
# assumes you are using ShapeTracker
|
||||
# used in GPUBuffer and LLVMBuffer
|
||||
class CompiledBuffer(DeviceBuffer): # pylint: disable=abstract-method
|
||||
spec: ClassVar[Specialized]
|
||||
|
||||
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf:Optional[CompiledBuffer]=None, backing:Optional[np.ndarray]=None, force_create=False, dtype:DType=dtypes.float32):
|
||||
self.st: ShapeTracker = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
|
||||
self.shape: Tuple[int, ...] = self.st.shape
|
||||
self.dtype: DType = dtype
|
||||
assert hostbuf is None or hostbuf.dtype == dtype, f"hostbuf dtype {hostbuf.dtype} != {dtype}"
|
||||
self._base_shape: Tuple[int, ...] = hostbuf._base_shape if hostbuf is not None else self.shape
|
||||
self._buf = hostbuf._buf if hostbuf is not None else None
|
||||
self._backing: Optional[np.ndarray] = hostbuf._backing if hostbuf is not None else backing
|
||||
assert self._backing is None or dtypes.from_np(self._backing) == dtype, f"backing dtype {dtypes.from_np(self._backing)} != {dtype}"
|
||||
if (self._backing is not None and self._backing.shape != (1,)) or force_create: self.raw()
|
||||
|
||||
def __repr__(self): return f"{type(self).__name__}(shape={self.st}, hostbuf={type(self).__name__}(shape={self._base_shape}" + (f", backing=np.array({self._backing}, dtype=np.{self.dtype.np.__name__}), dtype={self.dtype}), dtype={self.dtype})" if self._backing is not None else f", force_create=True, dtype={self.dtype}), dtype={self.dtype})")
|
||||
|
||||
def create_raw_buffer(self, shape:Tuple[int, ...], backing:Optional[np.ndarray], dtype:DType) -> RawBuffer:
|
||||
assert backing is None or prod(shape) == prod(backing.shape), "backing has the wrong shape"
|
||||
assert backing is None or GlobalCounters.cache is None, f"can't copy in {backing.shape} while caching"
|
||||
if DEBUG >= 4: print(f"create raw buffer {shape} {dtype} backed:{backing is not None}")
|
||||
return self.spec.raw_buffer(prod(shape), dtype) if backing is None else self.spec.raw_buffer.fromCPU(backing)
|
||||
|
||||
def raw(self) -> RawBuffer:
|
||||
if self._buf is None:
|
||||
if DEBUG >= 4 and self._backing is not None: print(f"**** copy in {self._backing.shape} to {type(self)}")
|
||||
self._buf = self.create_raw_buffer(self._base_shape, self._backing, self.dtype)
|
||||
self._backing = None
|
||||
return self._buf
|
||||
|
||||
def toCPU(self) -> np.ndarray:
|
||||
assert GlobalCounters.cache is None, f"can't copy out {self} while caching"
|
||||
if DEBUG >= 3: print(f"**** copy out {self.shape}")
|
||||
return self.contiguous().raw().toCPU().reshape(self.shape)
|
||||
|
||||
method_cache: Final[Dict[str, ASTRunner]] = {}
|
||||
@classmethod
|
||||
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[CompiledBuffer]=None):
|
||||
if ast.op == LoadOps.FROMCPU: return cls(ast.arg.shape, backing=ast.arg.ravel(), dtype=dtypes.from_np(ast.arg))
|
||||
k = cls.spec.codegen(ast, output_buffer)
|
||||
if getenv("ENABLE_METHOD_CACHE", 1): # this is the default now
|
||||
if k.key not in cls.method_cache: cls.method_cache[k.key] = k.codegen().build(cls.spec.runtime)
|
||||
elif DEBUG >= 4: print(f"method cache hit : {k.key}")
|
||||
prg = cls.method_cache[k.key]
|
||||
def exec_ast(self, ast:LazyOp, output=None, context=None):
|
||||
if FusedOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
|
||||
ast = LazyOp(FusedOps.MULACC, ast.src[0].src, ast.arg)
|
||||
created_context = context is None
|
||||
if context is None: context = dict()
|
||||
if not created_context and ast in context: return context[ast]
|
||||
srcs = [self.exec_ast(x, context=context) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src]
|
||||
ret = self.buffer(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
|
||||
if DEBUG >= 4: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB op: {ast.op:20s} out({ret.dtype.name}): {str(ret.shape):30s} in({len(srcs)}):", list(set(x.shape for x in srcs)), ast.arg if ast.arg is not None else "")
|
||||
if not created_context: context[ast] = ret
|
||||
if output is not None and output.output_buffer is not None:
|
||||
assert output.output_buffer.size == ret.size, output.output_buffer.dtype == ret.dtype
|
||||
output.output_buffer._buf = ret._buf
|
||||
return output.output_buffer
|
||||
else:
|
||||
prg = k.codegen().build(cls.spec.runtime)
|
||||
return ret
|
||||
|
||||
class FlopCounter:
|
||||
def __init__(self, tup:Tuple[Tuple[int, ...], DType, int]): self.shape, self.dtype, self.flops = tup
|
||||
def consume_flops(self):
|
||||
self.flops, ret = 0, self.flops
|
||||
return ret
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
shape_fxn_for_op: Dict[Op, Callable] = {
|
||||
UnaryOps.CAST: lambda self,dtype: (self.shape, dtype, self.consume_flops() + prod(self.shape)),
|
||||
**{op:lambda self: (self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps if op != UnaryOps.CAST},
|
||||
**{op:lambda self,y: (self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},
|
||||
**{op:lambda self,new_shape: (new_shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in ReduceOps},
|
||||
**{op:functools.partial(lambda mop,self,arg: (ShapeTracker(self.shape).movement_op(mop, arg).shape, self.dtype, self.consume_flops()), op) for op in MovementOps}}
|
||||
InterpretedFlopCounter = Interpreted(FlopCounter, shape_fxn_for_op, lambda x: FlopCounter((x.shape, x.dtype, 0)), lambda x: x)
|
||||
def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.exec_ast(ast)
|
||||
|
||||
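FlopCounter reuses the interpreted exec_ast walker, but every op only propagates shape/dtype and accumulates flops instead of touching data. A toy standalone version of that trick (not tinygrad code):

from math import prod

class FlopCounter:
    def __init__(self, shape, dtype="float", flops=0): self.shape, self.dtype, self.flops = shape, dtype, flops
    def consume_flops(self):
        self.flops, ret = 0, self.flops   # hand the accumulated count downstream exactly once
        return ret

def binary_op(a, b):
    # an elementwise binary op costs one flop per output element, plus whatever its inputs cost
    return FlopCounter(a.shape, a.dtype, a.consume_flops() + b.consume_flops() + prod(a.shape))

def reduce_op(a, new_shape):
    return FlopCounter(new_shape, a.dtype, a.consume_flops() + prod(a.shape))

x, w = FlopCounter((64, 64)), FlopCounter((64, 64))
out = reduce_op(binary_op(x, w), (64, 1))   # MUL then SUM, like the MULACC pattern
print(out.shape, out.flops)                 # (64, 1) 8192 -> 64*64 for the MUL plus 64*64 for the SUM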
from tinygrad.codegen.ast import ASTKernel
|
||||
from tinygrad.helpers import DType
|
||||
|
||||
class Compiled:
|
||||
def __init__(self, buffer: Type[RawBuffer], codegen: Type[ASTKernel], runtime):
|
||||
self.buffer, self.codegen, self.runtime = buffer, codegen, runtime
|
||||
self.method_cache: Dict[str, ASTRunner] = {}
|
||||
|
||||
def exec_ast(self, ast:LazyOp, output):
|
||||
# all movementops do nothing in a Compiled buffer!
|
||||
if ast.op in MovementOps and not isinstance(ast.src[0], LazyOp) and ast.src[0].realized is not None: return ast.src[0].realized
|
||||
|
||||
k = self.codegen(ast, output)
|
||||
|
||||
# this is the default now
|
||||
if getenv("ENABLE_METHOD_CACHE", 1):
|
||||
if k.key not in self.method_cache: self.method_cache[k.key] = k.codegen().build(self.runtime)
|
||||
elif DEBUG >= 4: print(f"method cache hit : {k.key}")
|
||||
prg = self.method_cache[k.key]
|
||||
else:
|
||||
prg = k.codegen().build(self.runtime)
|
||||
|
||||
if getenv("PRINT_AST", "") == prg.name or getenv("PRINT_AST", "") == "1":
|
||||
k.print()
|
||||
print(prg.prg)
|
||||
|
||||
# check if we can reuse the output buffer
|
||||
# if it's aliased, don't use it
|
||||
# NOTE: this is pretty wrong actually, who knows where else this buffer is used?
|
||||
output.realized = output.output_buffer
|
||||
if output.realized is not None:
|
||||
for a in get_buffers(ast):
|
||||
if a.realized == output.realized and not a.st.contiguous:
|
||||
output.realized = None
|
||||
break
|
||||
|
||||
# we don't have an output buffer, we have to create it
|
||||
if output.realized is None:
|
||||
output.realized = self.buffer(prod(output.shape), output.dtype)
|
||||
|
||||
prg.exec(k.bufs)
|
||||
return k.ret
|
||||
|
||||
# universal for shape tracked
|
||||
def contiguous(self): return self if self.st.contiguous and prod(self._base_shape) == prod(self.shape) else type(self).exec_ast(LazyOp(op=UnaryOps.NOOP, src=(self,)))
|
||||
def movement_op(self, op:MovementOps, arg): return type(self)(ShapeTracker(self.st).movement_op(op, arg), hostbuf=self, dtype=self.dtype)
|
||||
|
||||
return output.realized
|
||||
|
||||
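Compiled.exec_ast above compiles a kernel only the first time its codegen key is seen; later ASTs with the same key reuse the built program. A minimal stand-in for that method cache (illustrative only; the key string is made up):

from typing import Callable, Dict

class MethodCache:
    def __init__(self, compile_fn: Callable[[str], Callable]):
        self.cache: Dict[str, Callable] = {}   # kernel key -> built program
        self.compile_fn, self.compiles = compile_fn, 0
    def get(self, key: str) -> Callable:
        if key not in self.cache:
            self.cache[key] = self.compile_fn(key)   # k.codegen().build(runtime) in the real code
            self.compiles += 1
        return self.cache[key]

mc = MethodCache(lambda key: (lambda bufs: f"ran {key} on {len(bufs)} bufs"))
prg = mc.get("E_64_4")                # hypothetical kernel key
print(prg(["a", "b"]), mc.compiles)   # ran E_64_4 on 2 bufs 1
print(mc.get("E_64_4") is prg)        # True: cache hit, no recompile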
@@ -1,13 +1,14 @@
|
||||
import ctypes
|
||||
import numpy as np
|
||||
from typing import TypeVar, Type
|
||||
from typing import TypeVar, Type, Any
|
||||
from tinygrad.helpers import DType, dtypes, prod, GlobalCounters
|
||||
|
||||
_T = TypeVar("_T")
|
||||
class RawBuffer: # pylint: disable=abstract-method
|
||||
def __init__(self, size:int, dtype:DType):
|
||||
def __init__(self, size:int, dtype:DType, buf:Any=None):
|
||||
self.size: int = size
|
||||
self.dtype: DType = dtype
|
||||
self._buf = buf
|
||||
self._memsz: int = size*dtype.itemsize
|
||||
GlobalCounters.mem_used += self._memsz
|
||||
def __del__(self): GlobalCounters.mem_used -= self._memsz
|
||||
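The new RawBuffer base class owns size, dtype, and the underlying handle, and keeps GlobalCounters.mem_used balanced between __init__ and __del__. An illustrative standalone sketch of that bookkeeping (not the real classes):

import numpy as np

class GlobalCounters:
    mem_used = 0

class RawBuffer:
    def __init__(self, size: int, dtype: np.dtype, buf=None):
        self.size, self.dtype, self._buf = size, dtype, buf
        self._memsz = size * dtype.itemsize
        GlobalCounters.mem_used += self._memsz
    def __del__(self):
        GlobalCounters.mem_used -= self._memsz

a = RawBuffer(1024, np.dtype(np.float32), bytearray(4096))
print(GlobalCounters.mem_used)   # 4096
del a
print(GlobalCounters.mem_used)   # back to 0 once the buffer is collected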
@@ -22,7 +23,7 @@ class RawBufferCopyIn(RawBuffer):
|
||||
|
||||
@classmethod
|
||||
def fromCPU(cls, x:np.ndarray):
|
||||
ret = cls(prod(x.shape), dtypes.from_np(x))
|
||||
ret = cls(prod(x.shape), dtypes.from_np(x.dtype))
|
||||
ret._copyin(x)
|
||||
return ret
|
||||
|
||||
@@ -33,9 +34,7 @@ class RawBufferMapped(RawBufferCopyIn):
|
||||
|
||||
# this one is simple enough that i moved it out of the runtimes
|
||||
class RawMallocBuffer(RawBufferMapped):
|
||||
def __init__(self, size, dtype: DType):
|
||||
super().__init__(size, dtype)
|
||||
self._buf = ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16}[dtype] * size)()
|
||||
def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16}[dtype] * size)())
|
||||
def _buffer(self): return memoryview(self._buf)
|
||||
|
||||
class RawBufferCopyInOut(RawBufferCopyIn):
|
||||
@@ -45,3 +44,5 @@ class RawBufferCopyInOut(RawBufferCopyIn):
|
||||
x: np.ndarray = np.empty(self.size, dtype=self.dtype.np)
|
||||
self._copyout(x)
|
||||
return x
|
||||
|
||||
class RawConst(RawBuffer): pass # pylint: disable=abstract-method
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os, time, ctypes, hashlib, subprocess, platform
|
||||
from tinygrad.ops import CompiledBuffer, Specialized
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.runtime.lib import RawMallocBuffer
|
||||
from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
|
||||
|
||||
@@ -23,5 +23,4 @@ class ClangProgram:
|
||||
class ClangCodegen(GPUCodegen):
|
||||
lang = GPULanguage(buffer_suffix=" restrict")
|
||||
|
||||
class ClangBuffer(CompiledBuffer):
|
||||
spec = Specialized(RawMallocBuffer, ClangCodegen, ClangProgram)
|
||||
ClangBuffer = Compiled(RawMallocBuffer, ClangCodegen, ClangProgram)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import numpy as np
|
||||
import operator
|
||||
from typing import ClassVar, Callable, Dict, Tuple
|
||||
from typing import Callable, Dict, Tuple
|
||||
from tinygrad.helpers import dtypes
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, FusedOps, LoadOps, Op
|
||||
from tinygrad.interpreted import InterpretedBuffer
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, FusedOps, Op, Interpreted
|
||||
from tinygrad.runtime.lib import RawBuffer
|
||||
|
||||
def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
|
||||
@@ -14,7 +14,7 @@ base_fxn_for_op: Dict[Op, Callable] = {
|
||||
BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow,
|
||||
ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
|
||||
ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
|
||||
MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
|
||||
MovementOps.RESHAPE: lambda x, arg: x.reshape(arg), MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
|
||||
}
|
||||
|
||||
def einsum_mulacc(einsum, get_strides, expand):
|
||||
@@ -33,10 +33,11 @@ numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
|
||||
MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
|
||||
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
|
||||
FusedOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy()), lambda x: x.strides, np.broadcast_to),
|
||||
LoadOps.FROMCPU: lambda arg: arg
|
||||
}}
|
||||
|
||||
class CPUBuffer(InterpretedBuffer):
|
||||
fxn_for_op: ClassVar[Dict[Op, Callable]] = numpy_fxn_for_op
|
||||
def to_tinygrad_dtype(self): return dtypes.from_np(self._buf)
|
||||
class RawNumpyBuffer(RawBuffer):
|
||||
def __init__(self, buf:np.ndarray): super().__init__(buf.size, dtypes.from_np(buf.dtype), buf)
|
||||
@classmethod
|
||||
def fromCPU(cls, x): return cls(x)
|
||||
def toCPU(self): return self._buf
|
||||
CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op)
|
||||
|
||||
@@ -4,16 +4,14 @@ import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # no
|
||||
import pycuda.driver as cuda # type: ignore
|
||||
from pycuda.compiler import compile as cuda_compile # type: ignore
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import CompiledBuffer, Specialized
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut
|
||||
from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
|
||||
|
||||
class RawCUDABuffer(RawBufferCopyInOut):
|
||||
def __init__(self, size, dtype):
|
||||
super().__init__(size, dtype)
|
||||
self._cl = cuda.mem_alloc(self._memsz)
|
||||
def _copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._cl, x, stream)
|
||||
def _copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._cl)
|
||||
def __init__(self, size, dtype): super().__init__(size, dtype, cuda.mem_alloc(size * dtype.itemsize))
|
||||
def _copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._buf, x, stream)
|
||||
def _copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._buf)
|
||||
|
||||
class CUDAProgram:
|
||||
def __init__(self, name:str, prg:str, binary=False):
|
||||
@@ -47,5 +45,4 @@ class CUDACodegen(GPUCodegen):
|
||||
gid = [f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' for i in range(3)],
|
||||
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)])
|
||||
|
||||
class CUDABuffer(CompiledBuffer):
|
||||
spec = Specialized(RawCUDABuffer, CUDACodegen, CUDAProgram)
|
||||
CUDABuffer = Compiled(RawCUDABuffer, CUDACodegen, CUDAProgram)
|
||||
|
||||
@@ -2,10 +2,10 @@ from __future__ import annotations
|
||||
import platform
|
||||
import numpy as np
|
||||
import pyopencl as cl # type: ignore
|
||||
from typing import Optional, List, Final
|
||||
from tinygrad.helpers import IMAGE, DEBUG, getenv, dtypes
|
||||
from tinygrad.ops import CompiledBuffer, GlobalCounters, Specialized
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut, RawBuffer
|
||||
from typing import Optional, List
|
||||
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut
|
||||
from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
|
||||
|
||||
OSX = platform.system() == "Darwin"
|
||||
@@ -21,13 +21,25 @@ class _CL:
|
||||
self.cl_queue: cl.CommandQueue = cl.CommandQueue(self.cl_ctx, properties=cl.command_queue_properties.PROFILING_ENABLE) # this is an in-order command queue
|
||||
CL = _CL()
|
||||
|
||||
# TODO: merge CLImage in here
|
||||
class CLBuffer(RawBufferCopyInOut):
|
||||
def __init__(self, size, dtype):
|
||||
super().__init__(size, dtype)
|
||||
self._cl = cl.Buffer(CL.cl_ctx, cl.mem_flags.READ_WRITE, self._memsz)
|
||||
def _copyin(self, x:np.ndarray): cl.enqueue_copy(CL.cl_queue, self._cl, x, is_blocking=False)
|
||||
def _copyout(self, x:np.ndarray): cl.enqueue_copy(CL.cl_queue, x, self._cl, is_blocking=True)
|
||||
if isinstance(dtype, ImageDType):
|
||||
fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
|
||||
buf = cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
|
||||
assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}"
|
||||
# NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
|
||||
else:
|
||||
buf = cl.Buffer(CL.cl_ctx, cl.mem_flags.READ_WRITE, size * dtype.itemsize)
|
||||
super().__init__(size, dtype, buf)
|
||||
def _copyin(self, x:np.ndarray):
|
||||
assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"
|
||||
cl.enqueue_copy(CL.cl_queue, self._buf, x, is_blocking=False)
|
||||
def _copyout(self, x:np.ndarray):
|
||||
assert not self.dtype.name.startswith("image"), f"can't copyout images {self.dtype}"
|
||||
cl.enqueue_copy(CL.cl_queue, x, self._buf, is_blocking=True)
|
||||
|
||||
"""
|
||||
class CLImage(RawBuffer): # pylint: disable=abstract-method
|
||||
IMAGE: Final = True
|
||||
def __init__(self, shape, dtype=dtypes.float16 if getenv("FLOAT16") else dtypes.float32): # pylint: disable=super-init-not-called
|
||||
@@ -35,6 +47,7 @@ class CLImage(RawBuffer): # pylint: disable=abstract-method
|
||||
self.size, self.dtype, self._cl = shape, dtype, cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, fmt, shape=(shape[1], shape[0]))
|
||||
GlobalCounters.mem_used += self._cl.row_pitch * self._cl.height
|
||||
def __del__(self): GlobalCounters.mem_used -= self._cl.row_pitch * self._cl.height
|
||||
"""
|
||||
|
||||
class CLProgram:
|
||||
def __init__(self, name:str, prg:str, binary=False, argdtypes=None):
|
||||
@@ -59,7 +72,7 @@ class CLProgram:
|
||||
def max_work_group_size(): return CL.cl_ctx.devices[0].max_work_group_size
|
||||
|
||||
def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
|
||||
e = self.clprg(CL.cl_queue, global_size, local_size, *[x._cl if isinstance(x, (CLBuffer, CLImage)) else x for x in bufs])
|
||||
e = self.clprg(CL.cl_queue, global_size, local_size, *[x._buf if isinstance(x, CLBuffer) else x for x in bufs])
|
||||
if wait:
|
||||
CL.cl_queue.finish()
|
||||
return ((e.profile.end - e.profile.start) * OSX_TIMING_RATIO) * 1e-9
|
||||
@@ -72,9 +85,13 @@ class CLCodegen(GPUCodegen):
|
||||
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
|
||||
gid = [f'get_global_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)])
|
||||
|
||||
GPUBuffer = Compiled(CLBuffer, CLCodegen, CLProgram)
|
||||
|
||||
"""
|
||||
class GPUBuffer(CompiledBuffer):
|
||||
spec = Specialized(CLBuffer, CLCodegen, CLProgram)
|
||||
# override this method for image
|
||||
def create_raw_buffer(self, shape, backing, dtype) -> RawBuffer:
|
||||
if len(shape) == 3 and shape[2] == 4 and IMAGE >= 2 and backing is None: return CLImage(shape) # NOTE: this is a hack. we don't pass in the dtype here, it's controlled by the FLOAT16 env var
|
||||
else: return super().create_raw_buffer(shape, backing, dtype)
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import time, hashlib, ctypes
|
||||
from typing import ClassVar
|
||||
from tinygrad.ops import CompiledBuffer, Specialized
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.helpers import getenv, DEBUG
|
||||
from ctypes import CFUNCTYPE
|
||||
from tinygrad.codegen.llvm import LLVMCodegen
|
||||
@@ -62,5 +62,4 @@ class LLVMProgram:
|
||||
cfunc(*[x._buf for x in bufs])
|
||||
if wait: return time.monotonic()-st
|
||||
|
||||
class LLVMBuffer(CompiledBuffer):
|
||||
spec = Specialized(RawMallocBuffer, LLVMCodegen, LLVMProgram)
|
||||
LLVMBuffer = Compiled(RawMallocBuffer, LLVMCodegen, LLVMProgram)
|
||||
|
||||
@@ -4,7 +4,8 @@ import Metal, Cocoa, libdispatch # type: ignore
|
||||
from typing import List, Any
|
||||
from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
|
||||
from tinygrad.helpers import prod, getenv, DEBUG, DType
|
||||
from tinygrad.ops import CompiledBuffer, Specialized
|
||||
#from tinygrad.ops import CompiledBuffer, Specialized
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.runtime.lib import RawBufferMapped
|
||||
|
||||
METAL_XCODE = getenv("METAL_XCODE")
|
||||
@@ -17,16 +18,14 @@ class _METAL:
|
||||
METAL = _METAL()
|
||||
|
||||
class RawMetalBuffer(RawBufferMapped):
|
||||
def __init__(self, size:int, dtype:DType):
|
||||
super().__init__(size, dtype)
|
||||
self._cl = METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)
|
||||
def __init__(self, size:int, dtype:DType): super().__init__(size, dtype, METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared))
|
||||
def __del__(self):
|
||||
self._cl.release()
|
||||
self._buf.release()
|
||||
super().__del__()
|
||||
def _buffer(self):
|
||||
for cbuf in METAL.mtl_buffers_in_flight: cbuf.waitUntilCompleted()
|
||||
METAL.mtl_buffers_in_flight.clear()
|
||||
return self._cl.contents().as_buffer(self._cl.length())
|
||||
return self._buf.contents().as_buffer(self._buf.length())
|
||||
|
||||
def unwrap(x):
|
||||
ret, err = x
|
||||
@@ -65,7 +64,7 @@ class MetalProgram:
|
||||
command_buffer = METAL.mtl_queue.commandBuffer()
|
||||
encoder = command_buffer.computeCommandEncoder()
|
||||
encoder.setComputePipelineState_(self.pipeline_state)
|
||||
for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex_(a._cl, 0, i)
|
||||
for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex_(a._buf, 0, i)
|
||||
encoder.dispatchThreads_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
|
||||
encoder.endEncoding()
|
||||
command_buffer.commit()
|
||||
@@ -82,5 +81,4 @@ class MetalCodegen(GPUCodegen):
|
||||
gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)],
|
||||
extra_args = ['uint3 gid [[thread_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]'])
|
||||
|
||||
class MetalBuffer(CompiledBuffer):
|
||||
spec = Specialized(RawMetalBuffer, MetalCodegen, MetalProgram)
|
||||
MetalBuffer = Compiled(RawMetalBuffer, MetalCodegen, MetalProgram)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import torch
|
||||
from typing import ClassVar, Dict, Callable
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, FusedOps, LoadOps, Op
|
||||
from tinygrad.helpers import getenv, dtypes
|
||||
from tinygrad.interpreted import InterpretedBuffer
|
||||
from typing import Dict, Callable
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, FusedOps, Op, Interpreted
|
||||
from tinygrad.helpers import getenv, dtypes, prod
|
||||
from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
|
||||
from tinygrad.runtime.lib import RawBuffer
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
|
||||
|
||||
@@ -13,10 +13,12 @@ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
|
||||
MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
|
||||
FusedOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(a.dtype), lambda x: x.stride(), lambda x,s: x.expand(s)),
|
||||
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, abs(i)) for i in arg)].flip([i for i,a in enumerate(arg) if a < 0]),
|
||||
LoadOps.FROMCPU: lambda arg: torch.from_numpy(arg).requires_grad_(False).to(device)
|
||||
MovementOps.EXPAND: lambda x, arg: x.expand(arg), MovementOps.PERMUTE: lambda x, arg: x.permute(arg)
|
||||
}}
|
||||
|
||||
class TorchBuffer(InterpretedBuffer):
|
||||
fxn_for_op: ClassVar = torch_fxn_for_op
|
||||
def to_tinygrad_dtype(self): return {torch.float16: dtypes.float16, torch.float32: dtypes.float32}[self._buf.dtype]
|
||||
class RawTorchBuffer(RawBuffer):
|
||||
def __init__(self, buf:torch.Tensor): super().__init__(prod(buf.shape), {torch.float16: dtypes.float16, torch.float32: dtypes.float32}[buf.dtype], buf)
|
||||
@classmethod
|
||||
def fromCPU(cls, x): return cls(torch.from_numpy(x).requires_grad_(False).to(device))
|
||||
def toCPU(self): return self._buf.cpu().numpy()
|
||||
TorchBuffer = Interpreted(RawTorchBuffer, torch_fxn_for_op)
|
||||
|
||||
@@ -38,6 +38,7 @@ class Node:
|
||||
assert b != 0
|
||||
if b < 0: return (self//-b)*-1
|
||||
if b == 1: return self
|
||||
if isinstance(self, ModNode) and self.b % b == 0: return (self.a//b) % (self.b//b) # put the div inside mod
|
||||
if isinstance(self, DivNode): return self.a//(self.b*b) # two divs is one div
|
||||
if isinstance(self, MulNode) and self.b % b == 0: return self.a*(self.b//b)
|
||||
if isinstance(self, MulNode) and b % self.b == 0: return self.a//(b//self.b)
|
||||
|
||||
@@ -40,7 +40,12 @@ class Tensor:
|
||||
# TODO: this has to realize, it shouldn't have to
|
||||
data = data.realize().toCPU()
|
||||
|
||||
if isinstance(data, (np.ndarray, LazyNumpyArray)):
|
||||
# all ndarrays are lazy now
|
||||
if isinstance(data, np.ndarray): data = LazyNumpyArray(data, data.shape, data.dtype)
|
||||
|
||||
# by here, it's either LazyNumpyArray or LazyBuffer
|
||||
# TODO: it should all be LazyBuffer I think
|
||||
if isinstance(data, LazyNumpyArray):
|
||||
data = data if data.shape else data.reshape((1,))
|
||||
lazydata = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None else data, device)
|
||||
elif isinstance(data, LazyBuffer):
|
||||
@@ -455,6 +460,7 @@ class Tensor:
|
||||
|
||||
# TODO: this is a hack, but if we add float(0), it will become a float. need real casting support
|
||||
def float(self) -> Tensor: return self.add(Tensor([0], device=self.device, dtype=dtypes.float32, requires_grad=self.requires_grad))
|
||||
def cast(self, dtype:DType) -> Tensor: return mlops.Cast.apply(self, dtype=dtype) if self.dtype != dtype else self
|
||||
|
||||
# register functions to move between devices
|
||||
for device in Device._buffers:
|
||||
|
||||