Devicebufferless (#708)

* runs one metal kernel

* conv2d works

* ops tests are passing

* const folding

* all ops work

* pre commit always passes

* torch works

* working still

* fix graph test

* tests passing

* image almost works

* image conv works

* most images

* fix custom

* fix assignment

* fix compile enet

* clean up comments

* fix realize return value

* include shapetracker in LB repr

* copy should make a copy

* reenable method cache

* fix lna

* dtypes in graph

* forward only for IMAGE=2

* simple realize

* getting close

* fixup new api, it's good except the kernel count

* back to 197 kernels

* tests should pass

* go to a real float

* no type_on_cpu

* fix the docs

* put shapetracker back in its proper place
This commit is contained in:
George Hotz
2023-03-18 14:40:23 -07:00
committed by GitHub
parent 26a3888ab8
commit f5467cfedc
37 changed files with 471 additions and 446 deletions

View File

@@ -79,7 +79,7 @@ jobs:
run: curl https://media.istockphoto.com/photos/hen-picture-id831791190 | ./recognize | grep hen
testllvm:
name: LLVM Tests
name: LLVM Tests (w method cache)
runs-on: ubuntu-latest
steps:
@@ -160,7 +160,9 @@ jobs:
- name: Install Dependencies
run: pip install -e '.[gpu,testing]' --extra-index-url https://download.pytorch.org/whl/cpu
- name: Test GPU IMAGE ops
run: GPU=1 IMAGE=2 python3 test/test_ops.py
run: |
GPU=1 IMAGE=1 python3 test/test_ops.py
FORWARD_ONLY=1 GPU=1 IMAGE=2 python3 test/test_ops.py
- name: Test openpilot model
run: |
ALLOWED_KERNEL_COUNT=197 FLOAT16=1 VALIDHACKS=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py

View File

@@ -1,6 +1,12 @@
repos:
- repo: local
hooks:
- id: docs
name: docs
entry: python3 docs/abstractions.py
language: system
always_run: true
pass_filenames: false
- id: flake8
name: flake8
entry: flake8 tinygrad/ --indent-size=2 --select=F,E112,E113,E203,E304,E502,E702,E703,E71,E72,E731,W191,W6 --statistics -j4

View File

@@ -469,4 +469,4 @@ check-str-concat-over-line-jumps=yes
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=Exception
overgeneral-exceptions=builtins.Exception

View File

@@ -77,14 +77,20 @@ class LazyBuffer:
shape: Tuple[int, ...]
dtype: DType
# a ShapeTracker is used to track things like reshapes and permutes
# all MovementOps are zero copy in tinygrad!
# the ShapeTracker specifies how the data in the RawBuffer matches to the shape
# we'll come back to this later
st: ShapeTracker
# if the LazyBuffer is realized, it has a RawBuffer
# we will come back to RawBuffers later
realized: Optional[RawBuffer]
# if the lazybuffer is unrealized, it has a LazyOp
# this LazyOp describes the computation needed to realize this LazyBuffer
op: Optional[LazyOp]
# if the LazyBuffer is realized, it has a DeviceBuffer
# we will come back to DeviceBuffers later, first we'll explore the LazyOp
realized: Optional[DeviceBuffer]
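# aside: the "zero copy" claim for MovementOps above is the same trick numpy uses
# for views: movement ops only rewrite shape/stride bookkeeping, never the data.
# a quick illustration with numpy standing in for a ShapeTracker (a sketch, not tinygrad code):
import numpy as np
a = np.arange(6, dtype=np.float32)     # one underlying buffer of 6 floats
b = a.reshape(2, 3)                    # RESHAPE: new shape, same memory
c = b.transpose(1, 0)                  # PERMUTE: new strides, same memory
assert not b.flags["OWNDATA"] and not c.flags["OWNDATA"]   # both are views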
# LazyOp (in tinygrad/ops.py, code 4/10)
# in a tree they form an Abstract Syntax Tree for a single GPU kernel
class LazyOp:
@@ -128,81 +134,60 @@ assert len(lazyop.src) == 2
# again, a LazyOp AST is like a GPU kernel. you have to copy the data on the device first
print(lazyop.src[0].op)
assert lazyop.src[0].op.op == LoadOps.FROMCPU
assert lazyop.src[0].op.arg[0] == [2], "the arg of the FROMCPU LazyOP is the [2.]"
assert lazyop.src[0].op.arg.fxn == [2], "the arg of the FROMCPU LazyOP is the [2.]"
assert result.lazydata.realized is None, "the LazyBuffer is not realized yet"
# now we realize the LazyBuffer
result.lazydata.realize()
assert result.lazydata.realized is not None, "the LazyBuffer is realized!"
# this brings us nicely to DeviceBuffer, of which the realized ClangBuffer is a subclass
assert 'ClangBuffer' in str(type(result.lazydata.realized))
assert 'RawMallocBuffer' in str(type(result.lazydata.realized))
# getting ahead of ourselves, but we can copy the DeviceBuffer toCPU
assert result.lazydata.realized.toCPU()[0] == 5, "when put in numpy with toCPU, it's 5"
# %%
# == DeviceBuffer (in tinygrad/ops.py, code 4/10) ==
# == Union[Interpreted, Compiled] (in tinygrad/ops.py, code 5/10) ==
# DeviceBuffer is an abstract class to be implemented for each Device backend
class DeviceBuffer(ABC):
# these two are straightforward.
# unlike LazyBuffer, there's no need for device, since that's contained in the concrete type
shape: Tuple[int, ...]
dtype: DType
# Now you have a choice, you can either write a "Interpreted" backend or "Compiled" backend
# this is the magic method that "fills" a DeviceBuffer and does all the math in tinygrad
# NOTE: fromCPU no longer exists here, it's just a one LoadOps AST, LoadOps.FROMCPU
def exec_ast(self, ast:LazyOp): raise NotImplementedError("must be implemented")
# Interpreted backends are very simple (example: CPU and TORCH)
class Interpreted:
# they have a backing RawBuffer
buffer: Type[RawBuffer]
# however, toCPU still exists. it will raise a RuntimeException if exec_ast has never been called
# it copies out the underlying to the CPU, and will do any sync operations
def toCPU(self) -> np.ndarray: raise NotImplementedError("must be implemented")
# and they have a lookup table to functions for the Ops
fxn_for_op: Dict[Op, Callable] = {
UnaryOps.EXP: lambda x: np.exp(x),
BinaryOps.ADD: lambda x,y: x+y}
# DeviceBuffers come in two flavors, InterpretedBuffer and CompiledBuffer
# InterpretedBuffers are a lot simpler than CompiledBuffers
# they are used to implement the CPU(numpy) and TORCH(torch) backends
# it's worth reading CPUBuffer (in tinygrad/runtime/ops_cpu.py, code 8/10)
import numpy as np
import torch
class InterpretedBuffer(DeviceBuffer):
# this is where the data actually lives
# finally some classes you recognize!
_buf: Union[np.ndarray, torch.Tensor]
# Compiled backends take a little more (example: GPU and LLVM)
class Compiled:
# they also have a backing RawBuffer
buffer: Type[RawBuffer]
# the compute itself is defined here. these functions are called with _buf
# here's a UnaryOp and BinaryOp from CPUBuffer(InterpretedBuffer)
fxn_for_op: ClassVar[Dict[Op, Callable]] = {UnaryOps.EXP: lambda x: np.exp(x), BinaryOps.ADD: lambda x,y: x+y}
# NOTE: exec_ast should not need to be overridden!
# The actual method lives in tinygrad/ops.py
# it walks the LazyOp tree and calls fxn_for_op as appropriate
# ********** NOTE: for the CPU and TORCH backends, we are done and you can stop reading here **********
# %%
# == CompiledBuffer (in tinygrad/ops.py, code 4/10) ==
# however, all the magic of tinygrad will come from CompiledBuffer
# this is used for the GPU(opencl), CUDA, METAL, CLANG, and LLVM backends
class CompiledBuffer(DeviceBuffer):
# this is where the data actually lives, same as InterpretedBuffer
# a RawBuffer is just raw (typed) memory on the Device in question
_buf: RawBuffer
# introducing...ShapeTracker! all MovementOps are zero copy in tinygrad
# the ShapeTracker specifies how the data in the RawBuffer matches to the shape
# we'll come back to this later
st: ShapeTracker
# NOTE: exec_ast should not need to be overridden!
# instead you need three classes, explained below
raw_buffer: Type[RawBuffer]
runtime: Type[Runtime]
# a code generator, which compiles the AST
codegen: Type[ASTKernel]
# for completeness, we include RawBuffer. it's very boring and exactly what you expect
# and a runtime, which runs the generated code
runtime: Type[Runtime]
# Runtime is what actually runs the kernels for a compiled backend
class Runtime(ABC):
# `name` is the name of the function, and `prg` is the code
# the constructor compiles the code
def __init__(self, name:str, prg:str): pass
# call runs the code on the bufs. NOTE: the output is always bufs[0], but this is just a convention
def __call__(self, global_size:Optional[List[int]], local_size:Optional[List[int]], *bufs:List[RawBuffer]): pass
# %%
# == RawBuffer (in tinygrad/runtime/lib.py, code 5/10) ==
import numpy as np
# RawBuffer is where the data is actually held. it's pretty close to just memory
class RawBuffer(ABC):
# create an empty rawbuffer that holds `size` elements of type `dtype`
def __init__(self, size:int, dtype:DType): raise NotImplementedError("must be implemented")
# `buf` is an opaque container class
def __init__(self, size:int, dtype:DType, buf:Any): raise NotImplementedError("must be implemented")
# fromCPU is a classmethod that creates a RawBuffer; it's a classmethod since some runtimes are 0 copy
@classmethod
@@ -211,13 +196,14 @@ class RawBuffer(ABC):
# toCPU converts the RawBuffer to a numpy array with shape (size,). many backends are 0 copy here
def toCPU(self) -> np.ndarray: raise NotImplementedError("must be implemented")
# Runtime is what actually runs the kernels
class Runtime(ABC):
# `name` is the name of the function, and `prg` is the code
# the constructor compiles the code
def __init__(self, name:str, prg:str): pass
# call runs the code on the bufs. NOTE: the output is always bufs[0], but this is just a convention
def __call__(self, global_size:Optional[List[int]], local_size:Optional[List[int]], *bufs:List[RawBuffer]): pass
# RawNumpyBuffer is a RawBuffer example for numpy. It's very simple
class RawNumpyBuffer(RawBuffer):
# NOTE: the "np.ndarray" is stored in the opaque container
def __init__(self, buf:np.ndarray):
super().__init__(buf.size, dtypes.from_np(buf.dtype), buf)
@classmethod
def fromCPU(cls, x): return cls(x)
def toCPU(self): return self._buf
# %%
# == Example: 2+3 in raw clang ==
@@ -262,11 +248,11 @@ class ASTKernel:
def __init__(self, ast:LazyOp): pass
def codegen(self) -> ASTRunner: pass
# we return a class that runs code on CompiledBuffers
# we return a class that runs code on LazyBuffers, which are all expected to be realized
class ASTRunner: # (from tinygrad/ops.py)
def __init__(self, name, prg, global_size:Optional[List[int]], local_size:Optional[List[int]]): pass
def build(self, runtime:Runtime): pass
def exec(self, bufs:List[CompiledBuffer]): pass
def exec(self, bufs:List[LazyBuffer]): pass
# that hides a lot of complexity that will be refactored, but that's the basic idea of code generation
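To make the Interpreted half of this concrete, here is a self-contained sketch of the tree walk that exec_ast performs, with string ops and bare numpy arrays standing in for tinygrad's Op enums and RawBuffers (TinyOp and the op names are invented for the sketch):

from typing import NamedTuple, Tuple, Any, Dict, Callable
import numpy as np
class TinyOp(NamedTuple):
  op: str
  src: Tuple[Any, ...]   # nested TinyOps or plain numpy buffers
fxn_for_op: Dict[str, Callable] = {"EXP": lambda x: np.exp(x), "ADD": lambda x,y: x+y}
def exec_ast(ast: TinyOp) -> np.ndarray:
  # realize sources first: recurse into nested ops, pass buffers through
  srcs = [exec_ast(s) if isinstance(s, TinyOp) else s for s in ast.src]
  return fxn_for_op[ast.op](*srcs)
print(exec_ast(TinyOp("ADD", (np.array([2.0]), np.array([3.0])))))   # [5.]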

View File

@@ -43,10 +43,10 @@ if __name__ == "__main__":
# hack to put the inputs back
assert len(run.input_replace) == 1, f"didn't get one input to replace {run.input_replace}"
for (j,i),idx in run.input_replace.items():
run.jit_cache[j][1][i] = the_input.lazydata.realized.raw()
run.jit_cache[j][1][i] = the_input.lazydata.realized
# TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
special_names = {id(the_input.lazydata.realized.raw()): "input", id(the_output.lazydata.realized.raw()): "outputs"}
special_names = {id(the_input.lazydata.realized): "input", id(the_output.lazydata.realized): "outputs"}
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
@@ -109,4 +109,5 @@ int main(int argc, char* argv[]) {
}"""]
# CLANG=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/stable_diffusion_by_tinygrad.jpg
# category : 281 (tabby, tabby cat) with 9.452788
print('\n'.join(cprog))

View File

@@ -3,7 +3,7 @@ import gc
from tinygrad.helpers import prod
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
from tinygrad.runtime.ops_gpu import GPUBuffer
from tinygrad.runtime.ops_gpu import CLBuffer
from tinygrad.ops import GlobalCounters
def print_objects():
@@ -11,7 +11,7 @@ def print_objects():
tensors = [x for x in gc.get_objects() if isinstance(x, Tensor)]
tensor_ram_used = sum([prod(x.shape)*4 for x in tensors])
lazybuffers = [x for x in gc.get_objects() if isinstance(x, LazyBuffer)]
gpubuffers = [x for x in gc.get_objects() if isinstance(x, GPUBuffer)]
gpubuffers = [x for x in gc.get_objects() if isinstance(x, CLBuffer)]
realized_buffers = [x.realized for x in lazybuffers if x.realized]
gpubuffers_orphaned = [x for x in gpubuffers if x not in realized_buffers]

View File

@@ -109,8 +109,8 @@ def load_single_weight(t:Tensor, myfile, shape, strides, dtype, storage_offset,
# this needs real APIs
if t.device in ["METAL", "CLANG", "LLVM"]:
del t.lazydata.op
t.lazydata.realized = t.lazydata.dbuffer(t.shape, dtype=t.dtype)
myfile.readinto(t.lazydata.realized.raw()._buffer())
t.lazydata.realized = t.lazydata.dbuffer.buffer(prod(t.shape), dtype=t.dtype)
myfile.readinto(t.lazydata.realized._buffer())
else:
def _mmap(lna):
assert myfile._compress_type == 0, "compressed data can't be mmaped"
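The readinto path above skips the intermediate numpy copy by streaming the file's bytes straight into the preallocated buffer's memory. A standalone sketch of the same mechanism, with a bytearray standing in for the RawBuffer's _buffer():

import io
import numpy as np
raw = bytearray(6 * 4)                                       # room for 6 float32s
src = io.BytesIO(np.arange(6, dtype=np.float32).tobytes())   # stand-in weight file
src.readinto(raw)                                            # fill the buffer in place
print(np.frombuffer(raw, dtype=np.float32))                  # [0. 1. 2. 3. 4. 5.]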

View File

@@ -9,7 +9,7 @@ if os.getenv("GPU", None) is None:
if os.getenv("IMAGE", None) is None:
os.environ['IMAGE'] = '2'
from tinygrad.helpers import getenv
from tinygrad.helpers import getenv, dtypes
ALLOWED_KERNEL_COUNT = getenv("ALLOWED_KERNEL_COUNT", 0)
DEBUGCL = getenv("DEBUGCL", 0)
@@ -38,7 +38,7 @@ from tinygrad.jit import TinyJit
@TinyJit
def model_exec(run_onnx, using_graph, **inputs):
ret = next(iter(run_onnx(inputs).values()))
ret = next(iter(run_onnx(inputs).values())).cast(dtypes.float32)
GlobalCounters.reset()
GlobalCounters.cache = [] # don't cache pre-realize
if using_graph: graph.GRAPH = True
@@ -49,7 +49,7 @@ def compile(dat, output_fn):
Tensor.manual_seed(1337)
Tensor.no_grad = True
using_graph = graph.GRAPH
graph.GRAPH = False
if getenv("GRAPH") < 2: graph.GRAPH = False
onnx_model = onnx.load(io.BytesIO(dat))
run_onnx = get_run_onnx(onnx_model)
@@ -63,7 +63,7 @@ def compile(dat, output_fn):
assert len(model_exec.jit_cache) <= ALLOWED_KERNEL_COUNT or ALLOWED_KERNEL_COUNT == 0, "too many kernels!"
# pull out inputs and put them in the jit cache
input_rawbuffers = {k:inputs[k].lazydata.realized.raw() for k in inputs.keys()}
input_rawbuffers = {k:inputs[k].lazydata.realized for k in inputs.keys()}
for (j,i),idx in model_exec.input_replace.items(): model_exec.jit_cache[j][1][i] = input_rawbuffers[idx]
# transform to CL.CACHE
@@ -73,11 +73,11 @@ def compile(dat, output_fn):
# pass these to thneed
setattr(prg.clprg, 'op_estimate', prg.op_estimate)
setattr(prg.clprg, 'prg', prg.prg)
cl_cache.append((prg.clprg, [prg.global_size, prg.local_size, *[x._cl for x in args]]))
cl_cache.append((prg.clprg, [prg.global_size, prg.local_size, *[x._buf for x in args]]))
used_ops += prg.op_estimate
from extra.thneed import Thneed
t = Thneed(cl_cache, {k:v._cl for k,v in input_rawbuffers.items()})
t = Thneed(cl_cache, {k:v._buf for k,v in input_rawbuffers.items()})
# save thneed (before run)
t.save(output_fn)

View File

@@ -226,6 +226,14 @@ class TestOpt(unittest.TestCase):
np.testing.assert_allclose(c.numpy(), d.numpy().transpose(1,0), rtol=1e-3, atol=1e-5)
assert cache_len == 1, "reduceop was rerun!"
def test_fold_with_contiguous(self):
a = Tensor.randn(16, 16, 16)
b = Tensor.randn(16, 16)
with CLCache():
c = (a.sum(2).contiguous() + b).contiguous()
c.realize()
cache_len = len(GlobalCounters.cache)
assert cache_len == 1, "contiguous wasn't folded"
if __name__ == '__main__':
unittest.main()

View File

@@ -4,12 +4,13 @@ import unittest
from extra.utils import fetch, fake_torch_load_zipped
from PIL import Image
class TestUtils(unittest.TestCase):
@unittest.skip("hangs sometimes")
def test_fetch_bad_http(self):
self.assertRaises(AssertionError, fetch, 'http://httpstat.us/500')
self.assertRaises(AssertionError, fetch, 'http://httpstat.us/404')
self.assertRaises(AssertionError, fetch, 'http://httpstat.us/400')
def test_fetch_small(self):
assert(len(fetch('https://google.com'))>0)

View File

@@ -8,25 +8,23 @@ from tinygrad.helpers import prod, dtypes
# *** first, we implement the atan2 op at the lowest level ***
# `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers
from tinygrad.ops import ASTRunner, CompiledBuffer
from tinygrad.runtime.ops_cpu import CPUBuffer
from tinygrad.lazy import LazyBuffer, Device
from tinygrad.ops import ASTRunner
# we don't always have GPU support, so the type signature is the abstract CompiledBuffer instead of GPUBuffer
def atan2_gpu(a:CompiledBuffer, b:CompiledBuffer) -> CompiledBuffer:
from tinygrad.runtime.ops_gpu import GPUBuffer
assert type(a) == GPUBuffer and type(b) == GPUBuffer, "gpu function requires GPUBuffers"
def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
assert a.device == "GPU" and b.device == "GPU", "gpu function requires GPUBuffers"
assert a.dtype == b.dtype and a.dtype == dtypes.float32, "gpu function only supports float32"
ret = GPUBuffer(a.shape)
ret.realized = Device[ret.device].buffer(prod(ret.shape), ret.dtype)
ASTRunner("atan2", """
__kernel void atan2(global float *c, global float *a, global float *b) {
int idx = get_global_id(0);
c[idx] = atan2(a[idx], b[idx]);
}""", global_size=[prod(ret.shape)]).build(GPUBuffer.spec.runtime).exec([ret, a.contiguous(), b.contiguous()])
return ret
}""", global_size=[prod(ret.shape)]).build(Device[ret.device].runtime).exec([ret, a, b])
return ret.realized
def atan2_cpu(a:CPUBuffer, b:CPUBuffer) -> CPUBuffer:
return CPUBuffer(np.arctan2(a._buf, b._buf))
def atan2_cpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
return Device[ret.device].buffer(np.arctan2(a.realized._buf, b.realized._buf))
# *** second, we write the ATan2 mlop ***
# NOTE: The derivative of atan2 doesn't need a custom op! https://www.liquisearch.com/atan2/derivative
@@ -40,7 +38,7 @@ class ATan2(Function):
def forward(self, a:LazyBuffer, b:LazyBuffer) -> LazyBuffer:
assert prod(a.shape) == prod(b.shape) and a.device == b.device, "shape or device mismatch"
self.a, self.b = a, b
ast = LazyOp(LoadOps.CUSTOM, (a, b), {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device])
ast = LazyOp(LoadOps.CUSTOM, (a.contiguous(), b.contiguous()), {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device])
return LazyBuffer(a.device, a.shape, LoadOps, ast, max(a.dtype, b.dtype))
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
denom = (self.a.binary_op(BinaryOps.MUL, self.a)).binary_op(BinaryOps.ADD, self.b.binary_op(BinaryOps.MUL, self.b))
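A hypothetical usage sketch for the mlop above, assuming the Function.apply dispatch tinygrad uses for its built-in mlops (the exact wiring in this test file may differ):

from tinygrad.tensor import Tensor
def atan2(a: Tensor, b: Tensor) -> Tensor: return ATan2.apply(a, b)
a = Tensor.randn(4, 4, requires_grad=True)
b = Tensor.randn(4, 4, requires_grad=True)
c = atan2(a, b)        # forward: dispatches the LoadOps.CUSTOM kernel
c.mean().backward()    # backward: built from regular ops, no custom kernel needed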

View File

@@ -3,7 +3,7 @@ import time
import numpy as np
import unittest
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, IMAGE
from tinygrad.helpers import getenv
FORWARD_ONLY = getenv("FORWARD_ONLY", 0)
def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3, forward_only=False, vals=None, a=-0.5, b=3):
@@ -300,7 +300,7 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv2d(x, w),
lambda x,w: x.conv2d(w), atol=1e-2)
@unittest.skip("not supported with IMAGE=1")
@unittest.skip("slow")
def test_large_bs_conv(self):
# large batch size can cause OpenCL image to exceed max image height on macOS
# (or cause the conv kernel to overflow short sampling coords)
@@ -308,7 +308,7 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv2d(x, w),
lambda x,w: x.conv2d(w), atol=1e-4, rtol=1e-2)
@unittest.skip("not supported with IMAGE=1")
@unittest.skip("slow")
def test_large_ic_conv(self):
# large input channel count can cause OpenCL image to exceed max image width on macOS
helper_test_op([(1,2048,3,3), (1,2048,3,3)],
@@ -377,7 +377,7 @@ class TestOps(unittest.TestCase):
cin = 2
helper_test_op([(bs,groups*cin,1,1), (groups*rcout,cin,1,1)],
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5, forward_only=IMAGE>=2)
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
def test_medium_grouped_conv2d(self):
bs = 1
@@ -386,7 +386,7 @@ class TestOps(unittest.TestCase):
cin = 2
helper_test_op([(bs,groups*cin,1,1), (groups*rcout,cin,1,1)],
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5, forward_only=IMAGE>=2)
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
def test_depthwise_conv2d(self):
bs = 1

View File

@@ -17,6 +17,13 @@ def multidevice_test(fxn):
return ret
class TestExample(unittest.TestCase):
@multidevice_test
def test_convert_to_cpu(self, device):
a = Tensor([[1,2],[3,4]], device=device)
assert a.numpy().shape == (2,2)
b = a.cpu()
assert b.numpy().shape == (2,2)
@multidevice_test
def test_2_plus_3(self, device):
a = Tensor([2], device=device)

View File

@@ -1,37 +1,38 @@
#!/usr/bin/env python
import unittest
from tinygrad.ops import LazyOp, BinaryOps
from tinygrad.interpreted import get_lazyop_info, InterpretedBuffer, GenericShape
from typing import NamedTuple, Tuple
from tinygrad.ops import LazyOp, BinaryOps, get_lazyop_info
from tinygrad.helpers import DType, dtypes
class TestBuffer(NamedTuple):
shape: Tuple[int, ...]
dtype: DType
class TestFlopCounter(unittest.TestCase):
def setUp(self):
self.buf0 = TestBuffer(shape=(4,), dtype=dtypes.float32)
self.buf1 = TestBuffer(shape=(4,), dtype=dtypes.float32)
def test_flops_add(self):
buf0 = InterpretedBuffer(GenericShape((4,)))
buf1 = InterpretedBuffer(GenericShape((4,)))
op0 = LazyOp(BinaryOps.ADD, (buf0,buf1,), None)
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
info = get_lazyop_info(op0)
self.assertEqual(info.flops, 4)
def test_flops_add_twice(self):
buf0 = InterpretedBuffer(GenericShape((4,)))
buf1 = InterpretedBuffer(GenericShape((4,)))
op0 = LazyOp(BinaryOps.ADD, (buf0,buf1,), None)
op1 = LazyOp(BinaryOps.ADD, (op0,buf1,), None)
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
op1 = LazyOp(BinaryOps.ADD, (op0,self.buf1,), None)
info = get_lazyop_info(op1)
self.assertEqual(info.flops, 8)
def test_flops_add_self(self):
buf0 = InterpretedBuffer(GenericShape((4,)))
buf1 = InterpretedBuffer(GenericShape((4,)))
op0 = LazyOp(BinaryOps.ADD, (buf0,buf1,), None)
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
op1 = LazyOp(BinaryOps.ADD, (op0,op0,), None)
info = get_lazyop_info(op1)
self.assertEqual(info.flops, 8)
def test_flops_add_roundabout_self(self):
buf0 = InterpretedBuffer(GenericShape((4,)))
buf1 = InterpretedBuffer(GenericShape((4,)))
op0 = LazyOp(BinaryOps.ADD, (buf0,buf1,), None)
op1 = LazyOp(BinaryOps.ADD, (op0,buf1,), None)
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
op1 = LazyOp(BinaryOps.ADD, (op0,self.buf1,), None)
op2 = LazyOp(BinaryOps.ADD, (op0,op1,), None)
info = get_lazyop_info(op2)
self.assertEqual(info.flops, 12)
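The counting trick these tests exercise, in miniature: interpret the AST with shape-only buffers that carry a running flop total, zeroing it on consumption so shared subtrees are billed once (a sketch of the GenericShape/consume_flops idea from the deleted tinygrad/interpreted.py, not the current API):

from math import prod
class FlopBuf:
  def __init__(self, shape, flops=0): self.shape, self.flops = shape, flops
  def consume(self):                        # hand over accumulated flops exactly once
    ret, self.flops = self.flops, 0
    return ret
def count(ast, memo=None):                  # ast is ("ADD", (src, src)) or a FlopBuf
  memo = {} if memo is None else memo
  if isinstance(ast, FlopBuf): return ast
  if id(ast) in memo: return memo[id(ast)]  # shared subtree: reuse the same buffer
  srcs = [count(s, memo) for s in ast[1]]
  out = FlopBuf(srcs[0].shape, sum(s.consume() for s in srcs) + prod(srcs[0].shape))
  memo[id(ast)] = out
  return out
buf0, buf1 = FlopBuf((4,)), FlopBuf((4,))
op0 = ("ADD", (buf0, buf1))
op1 = ("ADD", (op0, buf1))
print(count(("ADD", (op0, op1))).flops)     # 12, matching the roundabout test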

View File

@@ -1,11 +1,12 @@
#!/usr/bin/env python
import unittest
import networkx as nx # type: ignore
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.graph import G, log_op, prune_graph
from tinygrad.runtime.ops_cpu import CPUBuffer
from tinygrad.ops import BinaryOps, LazyOp, MovementOps, ReduceOps
def buf(*shp): return Tensor.ones(*shp, device="CPU").lazydata
class TestGraph(unittest.TestCase):
def setUp(self):
G.clear()
@@ -14,10 +15,10 @@ class TestGraph(unittest.TestCase):
assert nx.is_isomorphic(G, RG, node_match=lambda x,y: x["label"] == y["label"], edge_match=lambda x,y: x["label"] == y["label"] if "label" in y else True)
def test_add_graph(self):
a = CPUBuffer(np.ones((4,4), dtype=np.float32))
b = CPUBuffer(np.ones((4,4), dtype=np.float32))
a = buf(4,4)
b = buf(4,4)
ast = LazyOp(BinaryOps.ADD, (a,b))
ret = CPUBuffer(np.ones((4,4), dtype=np.float32))
ret = buf(4,4)
RG = nx.DiGraph()
RG.add_node(0, label="(4, 4)")
@@ -30,12 +31,12 @@ class TestGraph(unittest.TestCase):
self.helper_compare_graph(RG)
def test_add_sum_graph(self):
a = CPUBuffer(np.ones((4,4), dtype=np.float32))
b = CPUBuffer(np.ones((1,1), dtype=np.float32))
a = buf(4,4)
b = buf(1,1)
op0 = LazyOp(MovementOps.RESHAPE, (b,), (4, 4))
op1 = LazyOp(BinaryOps.ADD, (a,op0))
ast = LazyOp(ReduceOps.SUM, (op1,), (1,1))
ret = CPUBuffer(np.ones((1,1), dtype=np.float32))
ret = buf(1,1)
RG = nx.DiGraph()
RG.add_node(0, label="(4, 4)")
@@ -48,14 +49,14 @@ class TestGraph(unittest.TestCase):
self.helper_compare_graph(RG)
def test_add_graph_prune(self):
a = CPUBuffer(np.ones((1,1), dtype=np.float32))
a = buf(1,1)
ast = LazyOp(MovementOps.RESHAPE, (a,), (4, 4))
ret = CPUBuffer(np.ones((4,4), dtype=np.float32))
ret = buf(4,4)
log_op(ret, ast, show_graph=True)
b = CPUBuffer(np.ones((4,4), dtype=np.float32))
b = buf(4,4)
ast = LazyOp(BinaryOps.ADD, (ret,b))
ret = CPUBuffer(np.ones((4,4), dtype=np.float32))
ret = buf(4,4)
log_op(ret, ast, show_graph=True)
prune_graph()

View File

@@ -180,6 +180,9 @@ class TestSymbolic(unittest.TestCase):
def test_div_numerator_negative(self):
self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "((((idx*-10)+99)//11)+-9)")
def test_div_into_mod(self):
self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)")
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):
# TODO: why are the negative tests broken? (even if we did support negative variables)
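The new test_div_into_mod identity is easy to spot-check by brute force over the variable's range:

for idx in range(17):                     # Variable("idx", 0, 16)
  assert (idx * 4) % 8 // 4 == idx % 2    # (idx*4)%8 alternates 0,4 with idx's parity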

View File

@@ -2,9 +2,8 @@ import itertools
from enum import Enum, auto
from typing import List, Tuple
from tinygrad.helpers import prod, dedup, all_same, colored, DType
from tinygrad.ops import LazyOp, MovementOps, get_buffers, ReduceOps, get_lazyops, map_buffers, ASTRunner
from tinygrad.ops import LazyOp, MovementOps, get_buffers, ReduceOps, get_lazyops, map_buffers, ASTRunner, get_lazyop_info, FlopCounter
from tinygrad.shape.shapetracker import ShapeTracker, View, strides_for_shape
from tinygrad.interpreted import get_lazyop_info, GenericShape
def get_first_reduce(shapes):
for i in range(len(shapes[0])):
@@ -35,31 +34,14 @@ class ASTKernel:
def __init__(self, ast:LazyOp, output_buffer=None):
self.input_ast = ast
# if the AST ends with a RESHAPE, we remove it and create the buffer accordingly
if ast.op == MovementOps.RESHAPE:
output_shape = ast.arg
ast = ast.src[0]
else:
output_shape = None
# NOTE: if there's a RESHAPE, we skip it. the output shape is set from the reduce op or a latebuf
if ast.op == MovementOps.RESHAPE: ast = ast.src[0]
self.bufs = dedup(get_buffers(ast))
self.bufs = [output_buffer] + dedup(get_buffers(ast))
self.ast = ast
# check if the output buffer is allowed to be used
# if it's aliased, don't use it
if output_buffer is not None:
for a in self.bufs:
if a._buf == output_buffer._buf and not a.st.contiguous:
output_buffer = None
break
# fetch lazyop info (this can be cached!)
self.info: GenericShape = get_lazyop_info(ast)
# create the buffer we are returning (as the same type as the input buffers) and add it as the first buffer
self.ret = output_buffer if output_buffer else type(self.bufs[0])(output_shape if output_shape else self.info.shape, force_create=True, dtype=self.info.dtype)
assert self.ret.dtype == self.info.dtype, f"return dtype {self.ret.dtype} != {self.info.dtype}"
self.bufs = ([type(self.ret)(self.info.shape, hostbuf=self.ret, dtype=self.info.dtype)] if output_shape else [self.ret]) + self.bufs
self.info: FlopCounter = get_lazyop_info(ast)
# key for lookup in cache (can change, str might not be right)
# bufs are needed because kernels like f(x) = x + x and f(x, y) = x + y have the same str(ast), but are different kernels.
@@ -79,8 +61,8 @@ class ASTKernel:
# check valid AST kernel
assert all_same([x.shape for x in self.earlybufs]), "all earlybufs must have the same shape"
assert all_same([x.shape for x in self.bufs if x not in self.earlybufs]), "all latebufs must have the same shape"
assert all_same([len(x.shape) for x in self.bufs]), "all bufs must have the same shape size"
assert all_same([x.shape for x in self.bufs[1:] if x not in self.earlybufs]), "all latebufs must have the same shape"
assert all_same([len(x.shape) for x in self.bufs[1:]]), "all bufs must have the same shape size"
# get full shape buf index (earlybufs if there are any, otherwise output)
self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if len(self.earlybufs) > 0 else 0
@@ -89,6 +71,10 @@ class ASTKernel:
self.sts: List[ShapeTracker] = [x.st.copy() for x in self.bufs] # create new shapetrackers inside this kernel
for st in self.sts: st.simplify()
# make the output buffer shape correct in here
if self.reduceop is not None: self.sts[0].reshape(self.reduceop.arg)
else: self.sts[0].reshape([x.shape for x in self.bufs[1:] if x not in self.earlybufs][0])
# move all reduce axes to the end
reduce = list(enumerate(zip(self.full_shape, self.sts[0].shape)))
permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
@@ -125,7 +111,7 @@ class ASTKernel:
if print_shapetrackers:
for st in self.sts: print(st)
for i in range(len(self.sts)):
print(prefix, self.bufs[i].dtype if self.bufs[i] is not None else None, self.buftokens[i], f"early:{'T' if i < len(self.bufs) and self.bufs[i] in self.earlybufs else 'F'}", self.sts[i].shape, self.sts[i].views[-1].strides, len(self.sts[i].views), type(self.bufs[i]._buf) if self.bufs[i] is not None else "FAKE")
print(prefix, self.bufs[i].dtype if self.bufs[i] is not None else None, self.buftokens[i], f"early:{'T' if i < len(self.bufs) and self.bufs[i] in self.earlybufs else 'F'}", self.sts[i].shape, self.sts[i].views[-1].strides, len(self.sts[i].views), self.bufs[i].realized if self.bufs[i] is not None else "FAKE")
def codegen(self) -> ASTRunner: raise NotImplementedError("need a codegen")

View File

@@ -3,9 +3,10 @@ from collections import defaultdict
from typing import Optional, List, Tuple, Dict, Set, Final, NamedTuple, ClassVar, DefaultDict
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, ASTRunner
from tinygrad.codegen.ast import ASTKernel, Token, Types
from tinygrad.shape.symbolic import Node, MulNode, DivNode, SumNode, AndNode, Variable, render_python
from tinygrad.shape.symbolic import Node, MulNode, DivNode, SumNode, AndNode, ModNode, Variable, render_python
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad.helpers import getenv, DEBUG, prod, partition, mnum, all_same, dedup, dtypes
from tinygrad.runtime.lib import RawConst
# div is different in cl than python
render_cl = render_python.copy()
@@ -55,7 +56,7 @@ class GPUCodegen(ASTKernel):
kernel_name_cache: Final[Dict[str, str]] = {}
code_for_op: Final[Dict[Op, str]] = {
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.NOT: "(1.0f-A)",
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.NOT: "(1.0f-A)", UnaryOps.CAST: "(A)",
UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)",
UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)",
BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)",
@@ -73,7 +74,7 @@ class GPUCodegen(ASTKernel):
assert len(self.sts[buf_index].views) == 1, "store has more than one view"
# all stores can merge, since they have one view and are valid
should_upcast = self.lang.float4 and self.buftokens[buf_index].can_float4() and (self.bufs[buf_index] is None or self.bufs[buf_index].dtype != dtypes.float16 or hasattr(self.bufs[buf_index]._buf, "IMAGE"))
should_upcast = self.lang.float4 and self.buftokens[buf_index].can_float4() and (self.bufs[buf_index] is None or self.bufs[buf_index].dtype != dtypes.float16 or self.bufs[buf_index].dtype.name.startswith('image'))
to_store = {o:v for o,v in zip(self.buftokens[buf_index].offsets(), value)}
did_store = set()
@@ -84,10 +85,10 @@ class GPUCodegen(ASTKernel):
if should_upcast:
for j in range(4): did_store.add(o+j)
v = self.group_float4([to_store[o+j] for j in range(4)])
if self.bufs[buf_index] is not None and hasattr(self.bufs[buf_index]._buf, "IMAGE"):
if self.bufs[buf_index] is not None and self.bufs[buf_index].dtype.name.startswith('image'):
assert v.typ == Types.FLOAT4, "Image requires upcasting to FLOAT4"
idx, idy = to_image_idx(self.bufs[buf_index]._base_shape, idxy, valid)
self.kernel.append(f"write_imagef({self.buftokens[buf_index].tok}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {v.tok}); /* {self.bufs[buf_index]._base_shape} */\n")
idx, idy = to_image_idx(self.bufs[buf_index].dtype.shape, idxy, valid)
self.kernel.append(f"write_imagef({self.buftokens[buf_index].tok}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {v.tok}); /* {self.bufs[buf_index].dtype.shape} */\n")
elif v.typ == Types.FLOAT4:
self.kernel.append(f"(({self.lang.buffer_prefix if self.bufs[buf_index] is not None else self.lang.smem_prefix}float4*){self.buftokens[buf_index].tok})[{(idxy//4).render(render_cl)}] = {v.tok};\n")
else:
@@ -96,12 +97,21 @@ class GPUCodegen(ASTKernel):
def load(self, buf_index:int, idx_override:Optional[str]=None) -> List[Token]:
# constant folding
const = None
if self.bufs[buf_index] is not None and self.bufs[buf_index]._base_shape == (1,) and self.bufs[buf_index]._backing is not None:
if self.bufs[buf_index] is not None and isinstance(self.bufs[buf_index].realized, RawConst):
# bufs_to_delete can be removed, just ignore RawConst at runtime
if buf_index != 0: self.bufs_to_delete.add(buf_index)
val = self.bufs[buf_index]._backing[0]
val = self.bufs[buf_index].realized._buf
assert not math.isnan(val)
const = Token(f"({val}f)", Types.FLOAT)
should_upcast = self.lang.float4 and const is None and self.buftokens[buf_index].can_float4() and (self.bufs[buf_index] is None or self.bufs[buf_index].dtype != dtypes.float16 or hasattr(self.bufs[buf_index]._buf, "IMAGE"))
def check_no_mul(test, var):
if test == var: return True
if isinstance(test, SumNode): return any(check_no_mul(x, var) for x in test.nodes) # in a sum is okay
if isinstance(test, ModNode) and test.b%4 == 0: return check_no_mul(test.a, var) # removing a mod is okay
return False
is_image = self.bufs[buf_index] is not None and self.bufs[buf_index].dtype.name.startswith('image')
should_upcast = self.lang.float4 and const is None and self.buftokens[buf_index].can_float4() and (self.bufs[buf_index] is None or self.bufs[buf_index].dtype != dtypes.float16 or self.bufs[buf_index].dtype.name.startswith('image'))
tokens = []
test_idy = []
for o in self.buftokens[buf_index].offsets():
@@ -111,21 +121,22 @@ class GPUCodegen(ASTKernel):
if should_upcast:
float4_index = Variable("FLOAT4_INDEX", 0, 3)
idxy_test, valid_test = self.sts[buf_index].expr_idxs(float4_index+o) if idx_override is None else self.sts[buf_index].expr_node(idx_override, float4_index+o)
can_merge = idxy_test == float4_index or (isinstance(idxy_test, SumNode) and any(x == float4_index for x in idxy_test.nodes)) # float4_index must be in there without a multiply
can_merge = can_merge and "FLOAT4_INDEX" not in (idxy_test//4).render() and "FLOAT4_INDEX" not in valid_test.render() # float4_index must not be in after divide or in valid (TODO: don't check render)
can_merge = check_no_mul(idxy_test, float4_index)
# NOTE: valid_test.render() can contain a FLOAT4_INDEX that can't affect the result: example <(((idx0<0,511>*4)+FLOAT4_INDEX<0,3>)<1024)>
can_merge = can_merge and "FLOAT4_INDEX" not in (idxy_test//4).render() and ("FLOAT4_INDEX" not in valid_test.render() or True) # float4_index must not be in after divide or in valid (TODO: don't check render)
if const is not None:
ldr = const
elif self.bufs[buf_index] is not None and hasattr(self.bufs[buf_index]._buf, "IMAGE"):
assert should_upcast and can_merge, f"Image requires upcasting to FLOAT4 {self.buftokens[buf_index]}"
idx, idy = to_image_idx(self.bufs[buf_index]._base_shape, idxy, valid, VALIDHACKS)
ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)})) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4)
elif self.bufs[buf_index] is not None and is_image:
assert should_upcast and can_merge, f"Image requires upcasting to FLOAT4 {self.buftokens[buf_index]} should_upcast:{should_upcast} can_merge:{can_merge}"
idx, idy = to_image_idx(self.bufs[buf_index].dtype.shape, idxy, valid, VALIDHACKS)
ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)})) /* {self.bufs[buf_index].dtype.shape} */", Types.FLOAT4)
test_idy.append(idy.render(render_cl))
elif should_upcast and can_merge:
ldr = Token(f"(({self.lang.buffer_prefix if self.bufs[buf_index] is not None else self.lang.smem_prefix}float4*){self.buftokens[buf_index].tok})[{(idxy//4).render(render_cl)}]", Types.FLOAT4)
else:
ldr = Token(f"{self.buftokens[buf_index].tok}[{idxy.render(render_cl)}]", Types.FLOAT)
invalid = self.group_float4([Token("0.0f", Types.FLOAT)]*4) if ldr.typ == Types.FLOAT4 else Token("0.0f", Types.FLOAT)
ldr = ldr if valid.min == 1 or (VALIDHACKS and hasattr(self.bufs[buf_index]._buf, "IMAGE")) else (Token(f"({valid.render(render_cl)} ? {ldr.tok} : {invalid.tok})", ldr.typ) if valid.max == 1 else invalid)
ldr = ldr if valid.min == 1 or (VALIDHACKS and is_image) else (Token(f"({valid.render(render_cl)} ? {ldr.tok} : {invalid.tok})", ldr.typ) if valid.max == 1 else invalid)
if const is not None:
self.loaded_keys[(buf_index,o)] = ldr
else:
@@ -154,9 +165,10 @@ class GPUCodegen(ASTKernel):
def required_optimizations(self, early_only=False):
for buf_index,buf in enumerate(self.bufs):
upcast_strides = [self.sts[buf_index].strides[i] for i in self.upcast_in_mid_reduce_axes]
if (not early_only or buf in self.earlybufs) and hasattr(buf._buf, "IMAGE") and not (self.buftokens[buf_index].can_float4() or (buf not in self.earlybufs and (1 in upcast_strides))):
if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.name.startswith('image') and not (self.buftokens[buf_index].can_float4() or (buf not in self.earlybufs and (1 in upcast_strides))):
axes = [i for i,x in enumerate(self.sts[buf_index].strides) if x == 1]
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes} on buf_index {buf_index}, {self.sts[buf_index]}"
assert self.sts[buf_index].shape[axes[0]]%4 == 0, f"axis:{axes[0]} in buffer {buf_index} is not a multiple of 4, {self.sts[buf_index].shape}"
self.shift_to(axes[0], 4)
self.upcast()
assert self.buftokens[buf_index].can_float4()
@@ -177,12 +189,13 @@ class GPUCodegen(ASTKernel):
self.group_for_reduce.append(sz)
break
# are we upcasting in mid reduce?
if hasattr(self.bufs[0]._buf, "IMAGE") and not self.buftokens[0].can_float4() and self.group_for_reduce and self.first_reduce <= 2:
# are we upcasting in mid reduce? (only for images)
if self.bufs[0].dtype.name.startswith('image') and not self.buftokens[0].can_float4() and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1:
axes = [i for i,x in enumerate(self.sts[0].strides) if x == 1]
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
self.shift_to(axes[0], 4, insert_before=self.first_reduce + len(self.group_for_reduce)) # insert at the end of the grouped axis
self.group_for_reduce.append(4)
if self.sts[0].shape[axes[0]]%4 == 0:
self.shift_to(axes[0], 4, insert_before=self.first_reduce + len(self.group_for_reduce)) # insert at the end of the grouped axis
self.group_for_reduce.append(4)
# now do everything required
self.required_optimizations()
@@ -191,8 +204,8 @@ class GPUCodegen(ASTKernel):
self.simplify_ones()
# use more opencl indexing if the output buffer is an image and we have room
if hasattr(self.bufs[0]._buf, "IMAGE") and self.first_reduce+len(self.group_for_reduce) < 3:
base_shape = self.bufs[0]._base_shape
if self.bufs[0].dtype.name.startswith('image') and self.first_reduce+len(self.group_for_reduce) < 3:
base_shape = self.bufs[0].dtype.shape
if (base_shape[0]*base_shape[1]) % self.sts[0].shape[0] == 0 and self.sts[0].shape[0]//base_shape[0] != 0:
if DEBUG >= 4: print("split opencl", base_shape, self.sts[0].shape)
self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None)
@@ -266,7 +279,7 @@ class GPUCodegen(ASTKernel):
self.bufs_to_delete: Set[int] = set()
self.loaded_keys: Dict[Tuple[int,int], Token] = {}
self.prekernel: Set[str] = set()
self.kernel: List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(hasattr(buf._buf, "IMAGE") for buf in self.bufs if buf is not None) else []
self.kernel: List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(buf.dtype.name.startswith("image") for buf in self.bufs if buf is not None) else []
if self.lang.half_prekernel and any(x.dtype == dtypes.float16 for x in self.bufs if x is not None): self.prekernel.add(self.lang.half_prekernel+"\n")
@@ -307,7 +320,7 @@ class GPUCodegen(ASTKernel):
for j in self.upcast_in_mid_reduce_axes:
self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
self.upcast()
if DEBUG >= 4: print("upcast", self.colorshape()) # NOTE: colorshape is wrong here
#if DEBUG >= 4: print("upcast", self.colorshape()) # NOTE: colorshape is wrong here, you have to remove it from group_for_reduce before calling
self.kernel.append(f"if ({lidx.render(render_cl)} == 0) {{\n") # lidx.max works here too
@@ -324,7 +337,7 @@ class GPUCodegen(ASTKernel):
self.kernel.append("\n}")
# concat kernel into prg
buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if hasattr(x._buf, "IMAGE") else self.lang.buffer_prefix+self.buftokens[i].decltype(self.bufs[i].dtype)+self.lang.buffer_suffix for i,x in enumerate(self.bufs) if x is not None]
buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else self.lang.buffer_prefix+self.buftokens[i].decltype(self.bufs[i].dtype)+self.lang.buffer_suffix for i,x in enumerate(self.bufs) if x is not None]
prg = ' '.join(list(self.prekernel) + [f"{self.lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] +
[', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete] + self.lang.extra_args)] +
[") {\n"] + self.kernel)
@@ -342,4 +355,5 @@ class GPUCodegen(ASTKernel):
return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), self.bufs_to_delete,
list(self.output_shape[::-1]) if len(self.output_shape) > 0 else [1],
(self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None,
op_estimate=self.info.flops, mem_estimate=sum(x.dtype.itemsize*prod(x._base_shape) for x in self.bufs if x is not None))
op_estimate=self.info.flops,
mem_estimate=sum(x.dtype.itemsize*(x.realized.size if x.realized is not None else prod(x.shape)) for x in self.bufs if x is not None))

View File

@@ -212,4 +212,6 @@ class LLVMCodegen(ASTKernel):
loop_entry[-1].branch(loop_exit[-1]._block)
loop_exit[0].ret_void()
return ASTRunner('exec', str(module), op_estimate=self.info.flops, mem_estimate=sum(x.dtype.itemsize*prod(x._base_shape) for x in self.bufs))
# TODO: mem_estimate is copied from GPU
return ASTRunner('exec', str(module), op_estimate=self.info.flops,
mem_estimate=sum(x.dtype.itemsize*(x.realized.size if x.realized is not None else prod(x.shape)) for x in self.bufs if x is not None))

View File

@@ -5,7 +5,8 @@ except ImportError:
nx = None # graph won't work
from collections import defaultdict
from typing import Dict, List, Optional
from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps, Op, OpType, LazyOp, get_buffers, get_lazyops
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps, Op, OpType, LazyOp, get_buffers, get_lazyops
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import getenv, DEBUG
GRAPH, PRUNEGRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("PRUNEGRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
@@ -37,18 +38,22 @@ def get_sop(op: List[Op]):
if len(op) <= 4: return '.'.join([str(y).split(".")[1][0:3] for y in op][::-1])
return str(len(op))
def log_op(ret: DeviceBuffer, ast: LazyOp, show_graph: Optional[bool] = None):
def str_dtype(dtyp):
ret = str(dtyp)[7:]
return "" if ret == 'float' else f"\n{ret}"
def log_op(ret: LazyBuffer, ast: LazyOp, show_graph: Optional[bool] = None):
if show_graph is None: show_graph = bool(GRAPH)
if not DEBUG and not show_graph: return
op: List[Op] = [x.op for x in get_lazyops(ast)]
inp: List[DeviceBuffer] = get_buffers(ast)
inp: List[LazyBuffer] = get_buffers(ast)
if len(inp) == 1 and inp[0] == ret:
if show_graph and nm(ret) in G.nodes: G.nodes[nm(ret)]['style'] += ', bold'
return # don't log self loops
oporder = [LoadOps, FusedOps, ReduceOps, BinaryOps, UnaryOps, MovementOps]
optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
cnts[optype] += 1
if DEBUG >= 4: print(f"{op} : {', '.join([f'{x.shape}-<{nm(x)}>' for x in inp])} -> {ret.shape}-<{nm(ret)}>")
if DEBUG >= 6: print(f"{op} : {', '.join([f'{x.shape}-<{nm(x)}>' for x in inp])} -> {ret.shape}-<{nm(ret)}>")
if show_graph:
top_colors = {LoadOps: '#FFFF80', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", FusedOps: "#ff8080"}
dashed = (optype == LoadOps and hasattr(ret, "_backing")) or (hasattr(ret, "st") and not ret.st.contiguous) # type: ignore
@@ -56,10 +61,10 @@ def log_op(ret: DeviceBuffer, ast: LazyOp, show_graph: Optional[bool] = None):
for x in inp:
G.add_edge(nm(x), nm(ret), label=get_sop(op))
if 'label' not in G.nodes[nm(x)]:
G.nodes[nm(x)]['label'] = str(x.shape)
G.nodes[nm(x)]['label'] = str(x.shape)+str_dtype(ret.dtype)
if nm(ret) not in G.nodes: G.add_node(nm(ret))
G.nodes[nm(ret)]['label'] = str(set(x.shape for x in inp))+"\n"+str(ret.shape) if optype == ReduceOps else str(ret.shape)
G.nodes[nm(ret)]['label'] = (str(set(x.shape for x in inp))+"\n"+str(ret.shape) if optype == ReduceOps else str(ret.shape))+str_dtype(ret.dtype)
G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('80' if dashed else str())) if optype in top_colors else "#ffffff"
G.nodes[nm(ret)]['style'] = 'filled, dashed' if dashed else 'filled'
G.nodes[nm(ret)]['prunable'] = optype in [LoadOps, MovementOps]

View File

@@ -1,3 +1,4 @@
from __future__ import annotations
import os, math, functools
import numpy as np
from typing import Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any
@@ -23,23 +24,33 @@ DEBUG, IMAGE = getenv("DEBUG", 0), getenv("IMAGE", 0)
# **** tinygrad now supports dtypes! *****
class DType(NamedTuple):
priority: int # this determines when things get upcasted
itemsize: int
name: str
np: type # TODO: someday this will be removed with the "remove numpy" project
def __repr__(self): return f"dtypes.{self.name}"
# dependent typing?
class ImageDType(DType):
def __new__(cls, priority, itemsize, name, np, shape):
return super().__new__(cls, priority, itemsize, name, np)
def __init__(self, priority, itemsize, name, np, shape):
self.shape: Tuple[int, ...] = shape # arbitrary arg for the dtype, used in image for the shape
super().__init__()
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
class LazyNumpyArray:
def __init__(self, fxn, shape, dtype): self.fxn, self.shape, self.dtype = fxn, shape, dtype
def __call__(self): return self.fxn(self)
def __call__(self) -> np.ndarray: return np.ascontiguousarray(self.fxn(self) if callable(self.fxn) else self.fxn).reshape(self.shape).astype(self.dtype)
def reshape(self, new_shape): return LazyNumpyArray(self.fxn, new_shape, self.dtype)
def copy(self): return self
def astype(self, typ): return self
def copy(self): return self if callable(self.fxn) else LazyNumpyArray(self.fxn.copy(), self.shape, self.dtype)
def astype(self, typ): return LazyNumpyArray(self.fxn, self.shape, typ)
class dtypes:
float16: Final[DType] = DType(2, "half", np.float16)
float32: Final[DType] = DType(4, "float", np.float32)
float16: Final[DType] = DType(0, 2, "half", np.float16)
float32: Final[DType] = DType(1, 4, "float", np.float32)
@staticmethod
def from_np(x:Union[LazyNumpyArray, np.ndarray]) -> DType: return {np.dtype(np.float16): dtypes.float16, np.dtype(np.float32): dtypes.float32}[np.dtype(x.dtype)]
def from_np(x) -> DType: return {np.dtype(np.float16): dtypes.float16, np.dtype(np.float32): dtypes.float32}[np.dtype(x)]
class GlobalCounters:
global_ops: ClassVar[int] = 0
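With the new __call__ above, a LazyNumpyArray defers all numpy work until it is invoked, and reshape/astype just return fresh lazy wrappers. A small usage sketch (constructor signature as shown in this diff):

import numpy as np
from tinygrad.helpers import LazyNumpyArray
lna = LazyNumpyArray(lambda s: np.arange(np.prod(s.shape)), (2, 3), np.float32)
lna = lna.astype(np.float16)   # still lazy: only the target dtype changed
arr = lna()                    # now: build, reshape to (2, 3), cast to float16
assert arr.shape == (2, 3) and arr.dtype == np.float16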

View File

@@ -1,56 +0,0 @@
from __future__ import annotations
from typing import Tuple, Any, ClassVar, Optional, Callable, Dict
import functools
from tinygrad.helpers import DType, dtypes, prod, GlobalCounters, DEBUG
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.ops import DeviceBuffer, LazyOp, get_buffers, map_buffers, Op, FusedOps, UnaryOps, MovementOps, ReduceOps, BinaryOps
# this is a quick "buffer" class for flop tracking and getting the output shape
class GenericShape:
def __init__(self, shape:Tuple[int, ...], dtype:DType=dtypes.float32, flops:int=0): self.shape, self.dtype, self.flops = shape, dtype, flops
def consume_flops(self):
self.flops, ret = 0, self.flops
return ret
shape_fxn_for_op: Dict[Op, Callable] = {
**{op:lambda self: GenericShape(self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps},
**{op:lambda self,y: GenericShape(self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},
**{op:lambda self,new_shape: GenericShape(new_shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in ReduceOps},
**{op:functools.partial(lambda mop,self,arg: GenericShape(ShapeTracker(self.shape).movement_op(mop, arg).shape, self.dtype, self.consume_flops()), op) for op in MovementOps}}
# this runs the LazyOp and gives you the output shape/dtype and flop count
def get_lazyop_info(ast:LazyOp) -> GenericShape: return InterpretedBuffer.exec_ast(map_buffers({x:InterpretedBuffer(GenericShape(x.shape, x.dtype)) for x in get_buffers(ast)}, ast))._buf
# used in CPUBuffer and TorchBuffer
class InterpretedBuffer(DeviceBuffer): # pylint: disable=abstract-method
fxn_for_op: ClassVar = shape_fxn_for_op
def __init__(self, lbuf:Any):
self._buf: Any = lbuf
self.shape: Tuple[int, ...] = tuple(lbuf.shape)
self.dtype: DType = self.to_tinygrad_dtype() if hasattr(self, 'to_tinygrad_dtype') else lbuf.dtype
# NOTE: this is overcounting the memory used, as reshapes and stuff are aliases
self._memsz = (prod(self.shape) * self.dtype.itemsize) if not isinstance(self, InterpretedBuffer) else 0
GlobalCounters.mem_used += self._memsz
def __del__(self): GlobalCounters.mem_used -= self._memsz
def contiguous(self): return type(self).exec_ast(LazyOp(op=UnaryOps.NOOP, src=(self,)))
def movement_op(self, op:MovementOps, arg=None): return type(self)(self.fxn_for_op[op](self._buf, arg)) if op in self.fxn_for_op else type(self)(getattr(self._buf, op.name.lower())(arg))
@classmethod
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[InterpretedBuffer]=None, context=None):
if FusedOps.MULACC in cls.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
ast = LazyOp(FusedOps.MULACC, ast.src[0].src, ast.arg)
created_context = context is None
if context is None: context = dict()
if not created_context and ast in context: return context[ast]
srcs = [cls.exec_ast(x, context=context) if isinstance(x, LazyOp) else x for x in ast.src]
if ast.op in BinaryOps: assert srcs[0].shape == srcs[1].shape, f"BinaryOps shape mismatch {srcs[0].shape} != {srcs[1].shape}"
if ast.op in ReduceOps: assert all(r == n or n == 1 for r,n in zip(srcs[0].shape, ast.arg)), f"ReduceOps can't reduce {srcs[0].shape} -> {ast.arg}"
if ast.op in MovementOps: ret = srcs[0].movement_op(ast.op, ast.arg)
else: ret = cls(cls.fxn_for_op[ast.op](*([x._buf for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
if DEBUG >= 4 or (not isinstance(cls, InterpretedBuffer) and DEBUG >= 3):
print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB op: {ast.op:20s} out({ret.dtype.name}): {str(ret.shape):30s} in({len(srcs)}):", list(set(x.shape for x in srcs)), ast.arg if ast.arg is not None else "")
if not created_context: context[ast] = ret
if output_buffer is not None:
assert output_buffer.shape == ret.shape, output_buffer.dtype == ret.dtype
output_buffer._buf = ret._buf
return output_buffer
else:
return ret

View File

@@ -4,7 +4,7 @@ from tinygrad.helpers import DEBUG, colored
from tinygrad.lazy import Device
from tinygrad.tensor import Tensor
from tinygrad.ops import GlobalCounters, CompiledBuffer, RawBuffer
from tinygrad.ops import GlobalCounters, RawBuffer
class TinyJit:
def __init__(self, fxn:Callable):
@@ -20,9 +20,7 @@ class TinyJit:
def __call__(self, *args, **kwargs) -> Any:
if Device.DEFAULT not in ["GPU", "CLANG", "METAL", "CUDA"]: return self.fxn(*args, **kwargs) # only jit on the GPU codegen
# NOTE: this cast is needed since although we know realize will create a ".realized" DeviceBuffer, the type checker doesn't
input_buffers: Dict[Union[int, str], CompiledBuffer] = {cast(Union[int, str], k):cast(CompiledBuffer, v.realize().lazydata.realized) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
# TODO: check the shapetrackers on the CompiledBuffers are what we jitted
input_rawbuffers: Dict[Union[int, str], RawBuffer] = {k:v.raw() for k,v in input_buffers.items()}
input_rawbuffers: Dict[Union[int, str], RawBuffer] = {cast(Union[int, str], k):cast(RawBuffer, v.realize().lazydata.realized) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
assert len(input_rawbuffers) != 0, "no inputs to JIT"
if self.cnt >= 2:
for (j,i),idx in self.input_replace.items(): self.jit_cache[j][1][i] = input_rawbuffers[idx]
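The input_replace bookkeeping in one picture: jit_cache holds (program, buffer_list) pairs, and input_replace maps a (cache_index, arg_index) slot to the name of the input that must be spliced in on every call. A sketch with strings standing in for programs and RawBuffers (not TinyJit itself):

jit_cache = [("kernel_a", ["outbuf", "weights", None])]   # None: slot for input "x"
input_replace = {(0, 2): "x"}
def patch(input_rawbuffers):
  for (j, i), idx in input_replace.items():
    jit_cache[j][1][i] = input_rawbuffers[idx]
patch({"x": "fresh_rawbuffer"})
assert jit_cache[0][1][2] == "fresh_rawbuffer"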

View File

@@ -1,12 +1,11 @@
from __future__ import annotations
from typing import Optional, Tuple, Union, List, Dict, Any, ClassVar, Type
from typing import Optional, Tuple, Union, List, Dict, Any, ClassVar, cast
import sys, weakref, importlib, inspect, functools, pathlib
from weakref import WeakValueDictionary
from tinygrad.helpers import prod, getenv, DType, dtypes, LazyNumpyArray, flatten
from tinygrad.helpers import prod, getenv, DType, dtypes, LazyNumpyArray, flatten, ImageDType
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops, map_buffers
from tinygrad.runtime.ops_cpu import CPUBuffer
from tinygrad.graph import log_op
from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops, map_buffers
from tinygrad.runtime.lib import RawConst, RawBuffer
# lazy can recurse a lot
sys.setrecursionlimit(10000)
@@ -19,7 +18,7 @@ class _Device:
self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
self.DEFAULT: str = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, "CPU")
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def __getitem__(self, x:str) -> Type[DeviceBuffer]: return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}'), inspect.isclass) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
def __getitem__(self, x:str) -> Union[Interpreted, Compiled]: return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
Device = _Device()
# TODO: movement ops that only change shape are really nops. treat them as such
@@ -70,7 +69,7 @@ def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movement
def replace_with_movement_op(y:Union[LazyOp, LazyBuffer], op:MovementOps, arg:Tuple[Any, ...]) -> LazyBuffer:
if isinstance(y, LazyBuffer): return y.movement_op(op, arg)
assert y.op in BinaryOps or y.op in UnaryOps
return elementwise_op(y.op, *[replace_with_movement_op(z, op, arg) for z in y.src]) # type: ignore
return elementwise_op(y.op, *[replace_with_movement_op(z, op, arg) for z in y.src], arg=y.arg) # type: ignore
def support_weakref(x): return x
@support_weakref # needed for mypyc, this prevents LazyBuffer from becoming a native class
@@ -91,9 +90,10 @@ class LazyBuffer:
if hasattr(self, 'device'):
return # cache hit, we return and don't reinit
self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
self.shape, self.optype, self.op, self.dtype = self.st.shape, optype, op, dtype
self.realized: Optional[DeviceBuffer] = None
self.output_buffer: Optional[DeviceBuffer] = None
self.shape, self.optype, self.dtype = self.st.shape, optype, dtype
self.op: LazyOp = op
self.realized: Optional[RawBuffer] = None
self.output_buffer: Optional[RawBuffer] = None
self.device, self.dbuffer = device, Device[device]
# TODO: does children have to be a ref count instead of a set? can a Buffer be a double child?
self.children: weakref.WeakSet[LazyBuffer] = weakref.WeakSet()
@@ -101,67 +101,72 @@ class LazyBuffer:
for x in get_buffers(op): x.children.add(self)
if not LAZY: self.realize()
def __repr__(self): return f"<LB {self.shape} {self.dtype} op:{self.op.op if self.realized is None else 'realized'}>"
def __repr__(self): return f"<LB {self.shape} {self.dtype} op:{self.op.op if self.realized is None else 'realized'} st:{self.st}>"
# this produces a device buffer
def realize(self:LazyBuffer, required_device=None) -> DeviceBuffer:
def realize(self:LazyBuffer, required_device=None) -> LazyBuffer:
assert required_device is None or required_device == self.device
if self.realized is None:
# get real ops first
if self.op.op == LoadOps.FROMCPU:
# resolve LazyNumpyArray
ast = LazyOp(self.op.op, tuple(), self.op.arg() if isinstance(self.op.arg, LazyNumpyArray) else self.op.arg)
self.realized = Device[self.device].buffer.fromCPU(self.op.arg())
elif self.op.op == LoadOps.CONTIGUOUS:
real_src = self.op.src[0].realize(self.device)
self.realized = real_src.contiguous()
ast = LazyOp(self.op.op, (real_src, ))
elif self.op.op == LoadOps.CUSTOM:
real_srcs = tuple(x.realize(self.device) for x in self.op.src)
self.realized = self.op.arg(*real_srcs)
ast = LazyOp(self.op.op, real_srcs)
elif self.optype == MovementOps:
src = self.op.src[0]
# fuse RESHAPE and ReduceOps
# NOTE: this is sort of a hack for IMAGE, otherwise it shouldn't matter
if src.realized is None and src.optype == ReduceOps and self.op.op == MovementOps.RESHAPE and len(src.children) <= 1:
# it's okay to add a RESHAPE to the ast here
ast = LazyOp(MovementOps.RESHAPE, (_ast_reduceops(src), ), self.op.arg)
realized = self.op.src[0].realize(self.device).realized
if self.op.src[0].st.contiguous and not isinstance(realized, RawConst) and realized.size == prod(self.shape):
# no need to run an AST, this is already contiguous
self.realized = realized
else:
# movement ops aren't an AST, just run them
real_src = src.realize(self.device)
self.realized = real_src.movement_op(self.op.op, self.op.arg)
ast = LazyOp(self.op.op, (real_src, ))
elif self.optype == ReduceOps: ast = _ast_reduceops(self)
elif self.optype == BinaryOps: ast = _ast_binaryops(self)
# TODO: remove UnaryOps.NOOP, replace with LoadOps.CONTIGUOUS. confusing with Compiled though
self.op = LazyOp(UnaryOps.NOOP, self.op.src)
elif self.op.op == LoadOps.CUSTOM:
# this needs to immediately realize
self.realized = self.op.arg(self, *[x.realize(self.device) for x in self.op.src])
# these can be late folded and change the op to go further back in the graph
elif self.optype == ReduceOps: self.op = _ast_reduceops(self)
elif self.optype == BinaryOps: self.op = _ast_binaryops(self) # ISSUE: this can include a reshape
# run the ast if we still have to, and log the op
if self.realized is None:
for x in get_buffers(self.op): x.realize(self.device)
# HACK: image shape can be wrong, hot cast it back to a normal float
if self.optype != MovementOps and isinstance(self.dtype, ImageDType) and (prod(self.shape) != prod(self.dtype.shape) or self.shape[self.st.strides.index(1)]%4 != 0):
# put the CAST underneath a final RESHAPE so the output keeps its shape
if self.op.op == MovementOps.RESHAPE: self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, dtypes.float32),), self.op.arg)
else: self.op = LazyOp(UnaryOps.CAST, (self.op,), dtypes.float32)
self.dtype = dtypes.float32
self.realized = Device[self.device].exec_ast(self.op, output=self)
# log to the graph
from tinygrad.graph import log_op
log_op(self, self.op)
# no need to keep the op after realization
del self.op
# run the ast if we still have to, and log the op
if self.realized is None:
ast = map_buffers({x:x.realize(self.device) for x in get_buffers(ast)}, ast)
self.realized = self.dbuffer.exec_ast(ast, output_buffer=self.output_buffer)
log_op(self.realized, ast)
assert self.realized.shape == self.shape, f"shape mismatch on realize got {self.realized.shape} expected {self.shape}"
assert isinstance(self.realized, Device[self.device]), f"device mismatch on realized got {type(self.realized)} expected {self.device}"
assert self.realized.dtype == self.dtype, f"dtype mismatch on realize got {self.realized.dtype} expected {self.dtype}"
return self.realized
assert isinstance(self.realized, (RawConst, Device[self.device].buffer)), f"device mismatch on realized got {type(self.realized)} expected {self.device}"
# HACK: allow hot casting of images
assert self.realized.dtype == self.dtype or self.dtype.name.startswith("image"), f"dtype mismatch on realize got {self.realized.dtype} expected {self.dtype}"
self.dtype = self.realized.dtype
return self
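A sketch of the new contract (hypothetical usage, assuming the Tensor API at this commit): realize() now returns the LazyBuffer itself, with .realized holding a RawBuffer instead of the old DeviceBuffer.
from tinygrad.tensor import Tensor
t = Tensor([1.0, 2.0]) + Tensor([3.0, 4.0])
lb = t.lazydata.realize()          # returns the same LazyBuffer, not a device buffer
assert lb is t.lazydata and lb.realized is not None
print(type(lb.realized).__name__)  # a RawBuffer subclass, e.g. RawNumpyBuffer on CPU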
# NOTE: we have to make a copy of the numpy array here in case the user changes it. expose this? LazyNumpyArray doesn't have this problem
@staticmethod
def fromCPU(x, device) -> LazyBuffer: return LazyBuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x.copy()), dtypes.from_np(x))
def fromCPU(x:LazyNumpyArray, device) -> LazyBuffer:
return LazyBuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x.copy()), dtypes.from_np(x.dtype))
# NOTE: we also have to copy the numpy array on the way out... otherwise the underlying Tensor could be freed, causing a use-after-free. improve this?
def toCPU(self):
ret = self.realize().toCPU()
log_op(CPUBuffer(ret), LazyOp(LoadOps.TOCPU, (self.realized,), None))
realized = self.cast(dtypes.from_np(self.dtype.np)).contiguous().realize().realized
ret = cast(RawBuffer, realized).toCPU().reshape(self.shape)
return ret.copy()
def cast(self:LazyBuffer, arg:DType) -> LazyBuffer: return elementwise_op(UnaryOps.CAST, self, arg=arg) if self.dtype != arg else self
def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self)
def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y)
def contiguous(self:LazyBuffer) -> LazyBuffer: return LazyBuffer(self.device, self.shape, LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,)), self.dtype)
def contiguous(self:LazyBuffer) -> LazyBuffer:
if self.realized is None and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
return LazyBuffer(self.device, self.shape, LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,)), self.dtype)
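The dedup is observable directly (a small sketch, assuming the default LAZY=1):
from tinygrad.tensor import Tensor
lb = Tensor.randn(4, 4).lazydata
c1 = lb.contiguous()
c2 = c1.contiguous()  # c1 is unrealized and already a CONTIGUOUS op...
assert c1 is c2       # ...so the second call returns the same LazyBuffer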
def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
if self.shape == tuple(new_shape): return self
@@ -199,8 +204,8 @@ class LazyBuffer:
# some permutes are actually just reshapes
if op == MovementOps.PERMUTE and local_st.contiguous: return self.movement_op(MovementOps.RESHAPE, tuple(self.shape[i] for i in arg))
# move permutes before expands
if op == MovementOps.PERMUTE and PUSH_PERMUTES and self.realized is None and self.op.op == MovementOps.EXPAND:
# move permutes before expands (always, this is safe)
if op == MovementOps.PERMUTE and self.realized is None and self.op.op == MovementOps.EXPAND:
self.op.src[0].children.discard(self)
return self.op.src[0].movement_op(MovementOps.PERMUTE, arg).movement_op(MovementOps.EXPAND, tuple(self.op.arg[a] for a in arg))
@@ -228,8 +233,8 @@ class LazyBuffer:
return ret
def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer) -> LazyBuffer:
out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max(x.dtype for x in srcs)
def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max(x.dtype for x in srcs) if op != UnaryOps.CAST else cast(DType, arg)
# push all contiguous to the end of BinaryOps. kernels 198 -> 196
if PUSH_CONTIGUOUS and any(x.realized is None and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs):
@@ -240,10 +245,10 @@ def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer) -> LazyBuffe
new_srcs.append(x.op.src[0])
else:
new_srcs.append(x)
return elementwise_op(op, *new_srcs).contiguous()
return elementwise_op(op, *new_srcs, arg=arg).contiguous()
if MERGE_ELEMENTWISE_OPS or (MERGE_UNARY_OPS and len(set(srcs)) == 1):
# remove the buffers from any (childless) BinaryOps that feed into this
srcs = tuple(x.op if x.optype == BinaryOps and len(x.children) == 0 and x.realized is None else x for x in srcs) # type: ignore
return LazyBuffer(out_device, out_shape, BinaryOps, LazyOp(op, srcs), out_dtype)
return LazyBuffer(out_device, out_shape, BinaryOps, LazyOp(op, srcs, arg), out_dtype)
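Since childless elementwise sources get inlined, chained ops build a single AST and therefore a single kernel. A sketch of observing that (assuming MERGE_ELEMENTWISE_OPS stays at its default of 1):
from tinygrad.tensor import Tensor
from tinygrad.ops import get_lazyops
out = Tensor([1.0, 2.0]).exp().log()
# both ops live in one unrealized AST, so one kernel at realize time
assert sorted(x.op.name for x in get_lazyops(out.lazydata.op)) == ["EXP", "LOG"]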

View File

@@ -8,6 +8,13 @@ class Contiguous(Function):
def forward(self, x): return x.contiguous()
def backward(self, grad_output): return grad_output
class Cast(Function):
def forward(self, x, dtype):
self.input_dtype = x.dtype
return x.cast(dtype)
def backward(self, grad_output):
return grad_output.cast(self.input_dtype)
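A short sketch of the new Cast mlop (hypothetical usage; whether float16 realizes depends on the backend):
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
t = Tensor([1.0, 2.0], requires_grad=True)
h = t.cast(dtypes.float16)
assert h.dtype == dtypes.float16  # the dtype changes lazily, before any realize
# on backward, Cast.backward casts the incoming gradient back to t's float32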
# ************* unary ops *************
class Log(Function):

View File

@@ -1,6 +1,10 @@
from tinygrad.helpers import prod, IMAGE
import numpy as np
from tinygrad.helpers import prod, IMAGE, ImageDType, getenv, dtypes
from tinygrad.lazy import get_single_root
FLOAT16 = getenv("FLOAT16", 0)
base_image_type = (100, 2, "image_half", np.float16) if FLOAT16 else (100, 4, "image_float", np.float32)
def image_dot(self, w):
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
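A numpy check of the comment's shape trick (illustration only, not tinygrad code): a 1x1 conv over a (1, k, m, 1) input with n output channels computes exactly the (m,k) @ (k,n) matmul.
import numpy as np
m, k, n = 3, 4, 5
A, B = np.random.randn(m, k), np.random.randn(k, n)
x = A.T.reshape(1, k, m, 1)               # (batch=1, cin=k, H=m, W=1)
w = B.T.reshape(n, k, 1, 1)               # (cout=n, cin=k, kH=1, kW=1)
out = np.einsum('bchw,ocij->bohw', x, w)  # what a 1x1 conv2d computes
assert np.allclose(out.reshape(n, m).T, A @ B)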
@@ -49,6 +53,7 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1).reshape(cout//4, H*cin//4*W*4, 4)
# contiguous creates the image, and early-realizes static weights (TODO: add a test for the static-weight path)
if IMAGE >= 2: x,w = x.cast(ImageDType(*base_image_type, shape=x.shape)), w.cast(ImageDType(*base_image_type, shape=w.shape))
x, w = x.contiguous(), w.contiguous()
if get_single_root(w.lazydata).realized: w.realize()
@@ -71,10 +76,14 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
# prepare weights
w = w.permute(0,4,2,5,1,3)
w = w.reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
w = w.reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W)).expand(x.shape)
# the conv!
ret = (x*w).sum((-4, -3, -2, -1)).reshape(bs*oy, ox*cout//4, 4)
# the conv! (+ the bias)
ret = (x*w).cast(dtypes.float32).sum((-4, -3, -2, -1))
# reshape to image and cast back to image
ret = ret.reshape(bs*oy, ox*cout//4, 4)
if IMAGE >= 2: ret = ret.cast(ImageDType(*base_image_type, shape=ret.shape))
if IMAGE >= 3: ret = ret.contiguous()
# undo hack for non multiples of 4 on C.rcout

View File

@@ -1,15 +1,14 @@
from __future__ import annotations
import functools, itertools, operator, random
import numpy as np
from enum import Enum, auto
from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar, Optional, Dict, Set, Final
from tinygrad.helpers import prod, DEBUG, getenv, DType, dtypes, GlobalCounters
from tinygrad.shape.shapetracker import ShapeTracker, MovementOps
from typing import Union, Type, NamedTuple, Tuple, Any, List, Optional, Dict, Set, Callable
from tinygrad.helpers import prod, DEBUG, getenv, GlobalCounters
from tinygrad.shape.shapetracker import MovementOps
from tinygrad.runtime.lib import RawBuffer
# these are the llops your accelerator must implement, along with toCPU
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); NEG = auto(); NOT = auto() # noqa: E702
class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); NEG = auto(); NOT = auto(); CAST = auto() # noqa: E702
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class FusedOps(Enum): MULACC = auto() # noqa: E702
@@ -32,15 +31,6 @@ def map_buffers(real_srcs:Dict[Any, Any], x:Any) -> LazyOp:
if len(real_srcs) and x in real_srcs: return map_buffers(real_srcs, real_srcs[x]) if isinstance(real_srcs[x], LazyOp) else real_srcs[x]
return LazyOp(x.op, tuple((map_buffers(real_srcs, y) if isinstance(y, LazyOp) else real_srcs[y]) for y in x.src), x.arg)
# a placeholder class to extend by the exec classes
class DeviceBuffer:
_buf: Any # underlying buffer
shape: Tuple[int, ...]
dtype: DType
@classmethod
def exec_ast(cls, ast:LazyOp, output_buffer=None): raise NotImplementedError("must be implemented")
def toCPU(self) -> np.ndarray: raise NotImplementedError("must be implemented")
class ASTRunner:
def __init__(self, name, prg, bufs_to_delete:Optional[Set[int]]=None, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0):
if DEBUG >= 4: print(prg)
@@ -50,8 +40,9 @@ class ASTRunner:
self.clprg = runtime(self.name, self.prg)
return self
def exec(self, bufs:List[Optional[CompiledBuffer]]) -> Optional[float]:
rawbufs = [x.raw() for i,x in enumerate(bufs) if x is not None and i not in self.bufs_to_delete]
def exec(self, bufs) -> Optional[float]:
rawbufs = [x.realized for i,x in enumerate(bufs) if x is not None and i not in self.bufs_to_delete]
assert all(x is not None for x in rawbufs), "some rawbufs are None, you probably didn't realize them"
if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs)
if GlobalCounters.cache is not None: GlobalCounters.cache.append((self, rawbufs))
return self(rawbufs)
@@ -81,66 +72,84 @@ class ASTRunner:
local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
return min([(self.timeit(rawbufs, local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])[1]
from tinygrad.codegen.ast import ASTKernel
class Specialized(NamedTuple):
raw_buffer: Type[RawBuffer]
codegen: Type[ASTKernel]
runtime: Type
class Interpreted:
def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_lazybuffer=lambda x: x.realized, to_underlying=lambda x: x._buf):
self.buffer = buffer
self.fxn_for_op = fxn_for_op
self.from_lazybuffer = from_lazybuffer
self.to_underlying = to_underlying
# assumes you are using ShapeTracker
# used in GPUBuffer and LLVMBuffer
class CompiledBuffer(DeviceBuffer): # pylint: disable=abstract-method
spec: ClassVar[Specialized]
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf:Optional[CompiledBuffer]=None, backing:Optional[np.ndarray]=None, force_create=False, dtype:DType=dtypes.float32):
self.st: ShapeTracker = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
self.shape: Tuple[int, ...] = self.st.shape
self.dtype: DType = dtype
assert hostbuf is None or hostbuf.dtype == dtype, f"hostbuf dtype {hostbuf.dtype} != {dtype}"
self._base_shape: Tuple[int, ...] = hostbuf._base_shape if hostbuf is not None else self.shape
self._buf = hostbuf._buf if hostbuf is not None else None
self._backing: Optional[np.ndarray] = hostbuf._backing if hostbuf is not None else backing
assert self._backing is None or dtypes.from_np(self._backing) == dtype, f"backing dtype {dtypes.from_np(self._backing)} != {dtype}"
if (self._backing is not None and self._backing.shape != (1,)) or force_create: self.raw()
def __repr__(self): return f"{type(self).__name__}(shape={self.st}, hostbuf={type(self).__name__}(shape={self._base_shape}" + (f", backing=np.array({self._backing}, dtype=np.{self.dtype.np.__name__}), dtype={self.dtype}), dtype={self.dtype})" if self._backing is not None else f", force_create=True, dtype={self.dtype}), dtype={self.dtype})")
def create_raw_buffer(self, shape:Tuple[int, ...], backing:Optional[np.ndarray], dtype:DType) -> RawBuffer:
assert backing is None or prod(shape) == prod(backing.shape), "backing has the wrong shape"
assert backing is None or GlobalCounters.cache is None, f"can't copy in {backing.shape} while caching"
if DEBUG >= 4: print(f"create raw buffer {shape} {dtype} backed:{backing is not None}")
return self.spec.raw_buffer(prod(shape), dtype) if backing is None else self.spec.raw_buffer.fromCPU(backing)
def raw(self) -> RawBuffer:
if self._buf is None:
if DEBUG >= 4 and self._backing is not None: print(f"**** copy in {self._backing.shape} to {type(self)}")
self._buf = self.create_raw_buffer(self._base_shape, self._backing, self.dtype)
self._backing = None
return self._buf
def toCPU(self) -> np.ndarray:
assert GlobalCounters.cache is None, f"can't copy out {self} while caching"
if DEBUG >= 3: print(f"**** copy out {self.shape}")
return self.contiguous().raw().toCPU().reshape(self.shape)
method_cache: Final[Dict[str, ASTRunner]] = {}
@classmethod
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[CompiledBuffer]=None):
if ast.op == LoadOps.FROMCPU: return cls(ast.arg.shape, backing=ast.arg.ravel(), dtype=dtypes.from_np(ast.arg))
k = cls.spec.codegen(ast, output_buffer)
if getenv("ENABLE_METHOD_CACHE", 1): # this is the default now
if k.key not in cls.method_cache: cls.method_cache[k.key] = k.codegen().build(cls.spec.runtime)
elif DEBUG >= 4: print(f"method cache hit : {k.key}")
prg = cls.method_cache[k.key]
def exec_ast(self, ast:LazyOp, output=None, context=None):
if FusedOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
ast = LazyOp(FusedOps.MULACC, ast.src[0].src, ast.arg)
created_context = context is None
if context is None: context = dict()
if not created_context and ast in context: return context[ast]
srcs = [self.exec_ast(x, context=context) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src]
ret = self.buffer(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
if DEBUG >= 4: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB op: {ast.op:20s} out({ret.dtype.name}): {str(ret.shape):30s} in({len(srcs)}):", list(set(x.shape for x in srcs)), ast.arg if ast.arg is not None else "")
if not created_context: context[ast] = ret
if output is not None and output.output_buffer is not None:
assert output.output_buffer.size == ret.size and output.output_buffer.dtype == ret.dtype, f"size/dtype mismatch {output.output_buffer.size}/{output.output_buffer.dtype} != {ret.size}/{ret.dtype}"
output.output_buffer._buf = ret._buf
return output.output_buffer
else:
prg = k.codegen().build(cls.spec.runtime)
return ret
class FlopCounter:
def __init__(self, tup:Tuple[Tuple[int, ...], DType, int]): self.shape, self.dtype, self.flops = tup
def consume_flops(self):
self.flops, ret = 0, self.flops
return ret
from tinygrad.shape.shapetracker import ShapeTracker
shape_fxn_for_op: Dict[Op, Callable] = {
UnaryOps.CAST: lambda self,dtype: (self.shape, dtype, self.consume_flops() + prod(self.shape)),
**{op:lambda self: (self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps if op != UnaryOps.CAST},
**{op:lambda self,y: (self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},
**{op:lambda self,new_shape: (new_shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in ReduceOps},
**{op:functools.partial(lambda mop,self,arg: (ShapeTracker(self.shape).movement_op(mop, arg).shape, self.dtype, self.consume_flops()), op) for op in MovementOps}}
InterpretedFlopCounter = Interpreted(FlopCounter, shape_fxn_for_op, lambda x: FlopCounter((x.shape, x.dtype, 0)), lambda x: x)
def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.exec_ast(ast)
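The flop counter can be driven with stand-in leaves, since from_lazybuffer only reads .shape and .dtype (a sketch; Leaf is a hypothetical stand-in, not tinygrad code):
from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info
from tinygrad.helpers import dtypes
class Leaf:  # just enough surface for from_lazybuffer
  def __init__(self, shape): self.shape, self.dtype = shape, dtypes.float32
ast = LazyOp(ReduceOps.SUM, (LazyOp(BinaryOps.MUL, (Leaf((4,4)), Leaf((4,4)))),), (1,1))
info = get_lazyop_info(ast)
print(info.shape, info.flops)  # (1, 1) 32 -- 16 muls plus 16 adds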
from tinygrad.codegen.ast import ASTKernel
from tinygrad.helpers import DType
class Compiled:
def __init__(self, buffer: Type[RawBuffer], codegen: Type[ASTKernel], runtime):
self.buffer, self.codegen, self.runtime = buffer, codegen, runtime
self.method_cache: Dict[str, ASTRunner] = {}
def exec_ast(self, ast:LazyOp, output):
# all movementops do nothing in a Compiled buffer!
if ast.op in MovementOps and not isinstance(ast.src[0], LazyOp) and ast.src[0].realized is not None: return ast.src[0].realized
k = self.codegen(ast, output)
# this is the default now
if getenv("ENABLE_METHOD_CACHE", 1):
if k.key not in self.method_cache: self.method_cache[k.key] = k.codegen().build(self.runtime)
elif DEBUG >= 4: print(f"method cache hit : {k.key}")
prg = self.method_cache[k.key]
else:
prg = k.codegen().build(self.runtime)
if getenv("PRINT_AST", "") == prg.name or getenv("PRINT_AST", "") == "1":
k.print()
print(prg.prg)
# check if we can reuse the output buffer
# if it's aliased, don't use it
# NOTE: this is pretty wrong actually, who knows where else this buffer is used?
output.realized = output.output_buffer
if output.realized is not None:
for a in get_buffers(ast):
if a.realized == output.realized and not a.st.contiguous:
output.realized = None
break
# we don't have an output buffer, we have to create it
if output.realized is None:
output.realized = self.buffer(prod(output.shape), output.dtype)
prg.exec(k.bufs)
return k.ret
# universal for shape tracked
def contiguous(self): return self if self.st.contiguous and prod(self._base_shape) == prod(self.shape) else type(self).exec_ast(LazyOp(op=UnaryOps.NOOP, src=(self,)))
def movement_op(self, op:MovementOps, arg): return type(self)(ShapeTracker(self.st).movement_op(op, arg), hostbuf=self, dtype=self.dtype)
return output.realized
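Two env knobs are visible in this path (a usage note, not new behavior): ENABLE_METHOD_CACHE=0 rebuilds the kernel on every exec_ast instead of reusing it by k.key, and PRINT_AST=1 (or PRINT_AST=<kernel name>) dumps the kernel and its generated program.
import os  # set these before tinygrad runs any kernels
os.environ["ENABLE_METHOD_CACHE"] = "0"  # bypass the per-process method cache
os.environ["PRINT_AST"] = "1"            # print every generated kernel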

View File

@@ -1,13 +1,14 @@
import ctypes
import numpy as np
from typing import TypeVar, Type
from typing import TypeVar, Type, Any
from tinygrad.helpers import DType, dtypes, prod, GlobalCounters
_T = TypeVar("_T")
class RawBuffer: # pylint: disable=abstract-method
def __init__(self, size:int, dtype:DType):
def __init__(self, size:int, dtype:DType, buf:Any=None):
self.size: int = size
self.dtype: DType = dtype
self._buf = buf
self._memsz: int = size*dtype.itemsize
GlobalCounters.mem_used += self._memsz
def __del__(self): GlobalCounters.mem_used -= self._memsz
@@ -22,7 +23,7 @@ class RawBufferCopyIn(RawBuffer):
@classmethod
def fromCPU(cls, x:np.ndarray):
ret = cls(prod(x.shape), dtypes.from_np(x))
ret = cls(prod(x.shape), dtypes.from_np(x.dtype))
ret._copyin(x)
return ret
@@ -33,9 +34,7 @@ class RawBufferMapped(RawBufferCopyIn):
# this one is simple enough that I moved it out of the runtimes
class RawMallocBuffer(RawBufferMapped):
def __init__(self, size, dtype: DType):
super().__init__(size, dtype)
self._buf = ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16}[dtype] * size)()
def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16}[dtype] * size)())
def _buffer(self): return memoryview(self._buf)
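A roundtrip sketch for the consolidated constructor (assuming RawBufferMapped implements toCPU via _buffer(), as elsewhere in this commit):
import numpy as np
from tinygrad.runtime.lib import RawMallocBuffer
src = np.array([1.0, 2.0, 3.0], dtype=np.float32)
buf = RawMallocBuffer.fromCPU(src)  # the ctypes allocation now happens in the one-liner __init__
assert (buf.toCPU() == src).all()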
class RawBufferCopyInOut(RawBufferCopyIn):
@@ -45,3 +44,5 @@ class RawBufferCopyInOut(RawBufferCopyIn):
x: np.ndarray = np.empty(self.size, dtype=self.dtype.np)
self._copyout(x)
return x
class RawConst(RawBuffer): pass # pylint: disable=abstract-method

View File

@@ -1,5 +1,5 @@
import os, time, ctypes, hashlib, subprocess, platform
from tinygrad.ops import CompiledBuffer, Specialized
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawMallocBuffer
from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
@@ -23,5 +23,4 @@ class ClangProgram:
class ClangCodegen(GPUCodegen):
lang = GPULanguage(buffer_suffix=" restrict")
class ClangBuffer(CompiledBuffer):
spec = Specialized(RawMallocBuffer, ClangCodegen, ClangProgram)
ClangBuffer = Compiled(RawMallocBuffer, ClangCodegen, ClangProgram)

View File

@@ -1,9 +1,9 @@
import numpy as np
import operator
from typing import ClassVar, Callable, Dict, Tuple
from typing import Callable, Dict, Tuple
from tinygrad.helpers import dtypes
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, FusedOps, LoadOps, Op
from tinygrad.interpreted import InterpretedBuffer
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, FusedOps, Op, Interpreted
from tinygrad.runtime.lib import RawBuffer
def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
@@ -14,7 +14,7 @@ base_fxn_for_op: Dict[Op, Callable] = {
BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow,
ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
MovementOps.RESHAPE: lambda x, arg: x.reshape(arg), MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
}
def einsum_mulacc(einsum, get_strides, expand):
@@ -33,10 +33,11 @@ numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
FusedOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy()), lambda x: x.strides, np.broadcast_to),
LoadOps.FROMCPU: lambda arg: arg
}}
class CPUBuffer(InterpretedBuffer):
fxn_for_op: ClassVar[Dict[Op, Callable]] = numpy_fxn_for_op
def to_tinygrad_dtype(self): return dtypes.from_np(self._buf)
class RawNumpyBuffer(RawBuffer):
def __init__(self, buf:np.ndarray): super().__init__(buf.size, dtypes.from_np(buf.dtype), buf)
@classmethod
def fromCPU(cls, x): return cls(x)
def toCPU(self): return self._buf
CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op)
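RawNumpyBuffer just wraps the ndarray, so both directions are free (a sketch):
import numpy as np
from tinygrad.runtime.ops_cpu import RawNumpyBuffer
x = np.arange(4, dtype=np.float32)
rb = RawNumpyBuffer.fromCPU(x)
assert rb.toCPU() is x  # zero copy: the same ndarray comes back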

View File

@@ -4,16 +4,14 @@ import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # no
import pycuda.driver as cuda # type: ignore
from pycuda.compiler import compile as cuda_compile # type: ignore
from tinygrad.helpers import DEBUG
from tinygrad.ops import CompiledBuffer, Specialized
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawBufferCopyInOut
from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
class RawCUDABuffer(RawBufferCopyInOut):
def __init__(self, size, dtype):
super().__init__(size, dtype)
self._cl = cuda.mem_alloc(self._memsz)
def _copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._cl, x, stream)
def _copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._cl)
def __init__(self, size, dtype): super().__init__(size, dtype, cuda.mem_alloc(size * dtype.itemsize))
def _copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._buf, x, stream)
def _copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._buf)
class CUDAProgram:
def __init__(self, name:str, prg:str, binary=False):
@@ -47,5 +45,4 @@ class CUDACodegen(GPUCodegen):
gid = [f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' for i in range(3)],
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)])
class CUDABuffer(CompiledBuffer):
spec = Specialized(RawCUDABuffer, CUDACodegen, CUDAProgram)
CUDABuffer = Compiled(RawCUDABuffer, CUDACodegen, CUDAProgram)

View File

@@ -2,10 +2,10 @@ from __future__ import annotations
import platform
import numpy as np
import pyopencl as cl # type: ignore
from typing import Optional, List, Final
from tinygrad.helpers import IMAGE, DEBUG, getenv, dtypes
from tinygrad.ops import CompiledBuffer, GlobalCounters, Specialized
from tinygrad.runtime.lib import RawBufferCopyInOut, RawBuffer
from typing import Optional, List
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawBufferCopyInOut
from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
OSX = platform.system() == "Darwin"
@@ -21,13 +21,25 @@ class _CL:
self.cl_queue: cl.CommandQueue = cl.CommandQueue(self.cl_ctx, properties=cl.command_queue_properties.PROFILING_ENABLE) # this is an in-order command queue
CL = _CL()
# TODO: merge CLImage in here
class CLBuffer(RawBufferCopyInOut):
def __init__(self, size, dtype):
super().__init__(size, dtype)
self._cl = cl.Buffer(CL.cl_ctx, cl.mem_flags.READ_WRITE, self._memsz)
def _copyin(self, x:np.ndarray): cl.enqueue_copy(CL.cl_queue, self._cl, x, is_blocking=False)
def _copyout(self, x:np.ndarray): cl.enqueue_copy(CL.cl_queue, x, self._cl, is_blocking=True)
if isinstance(dtype, ImageDType):
fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
buf = cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}"
# NOTE: the memory accounting is a bit off here due to padding; the real footprint is buf.row_pitch * buf.height * 4 * dtype.itemsize
else:
buf = cl.Buffer(CL.cl_ctx, cl.mem_flags.READ_WRITE, size * dtype.itemsize)
super().__init__(size, dtype, buf)
def _copyin(self, x:np.ndarray):
assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"
cl.enqueue_copy(CL.cl_queue, self._buf, x, is_blocking=False)
def _copyout(self, x:np.ndarray):
assert not self.dtype.name.startswith("image"), f"can't copyout images {self.dtype}"
cl.enqueue_copy(CL.cl_queue, x, self._buf, is_blocking=True)
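A sketch of what the ImageDType branch builds (hypothetical values; needs a working OpenCL context, and the base_image_type fields come from image.py above):
import numpy as np
from tinygrad.helpers import ImageDType
from tinygrad.runtime.ops_gpu import CLBuffer
idt = ImageDType(100, 4, "image_float", np.float32, shape=(8, 32, 4))  # priority, itemsize, name, np
buf = CLBuffer(8*32*4, idt)  # backed by a cl.Image of shape (32, 8), RGBA float channels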
"""
class CLImage(RawBuffer): # pylint: disable=abstract-method
IMAGE: Final = True
def __init__(self, shape, dtype=dtypes.float16 if getenv("FLOAT16") else dtypes.float32): # pylint: disable=super-init-not-called
@@ -35,6 +47,7 @@ class CLImage(RawBuffer): # pylint: disable=abstract-method
self.size, self.dtype, self._cl = shape, dtype, cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, fmt, shape=(shape[1], shape[0]))
GlobalCounters.mem_used += self._cl.row_pitch * self._cl.height
def __del__(self): GlobalCounters.mem_used -= self._cl.row_pitch * self._cl.height
"""
class CLProgram:
def __init__(self, name:str, prg:str, binary=False, argdtypes=None):
@@ -59,7 +72,7 @@ class CLProgram:
def max_work_group_size(): return CL.cl_ctx.devices[0].max_work_group_size
def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
e = self.clprg(CL.cl_queue, global_size, local_size, *[x._cl if isinstance(x, (CLBuffer, CLImage)) else x for x in bufs])
e = self.clprg(CL.cl_queue, global_size, local_size, *[x._buf if isinstance(x, CLBuffer) else x for x in bufs])
if wait:
CL.cl_queue.finish()
return ((e.profile.end - e.profile.start) * OSX_TIMING_RATIO) * 1e-9
@@ -72,9 +85,13 @@ class CLCodegen(GPUCodegen):
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
gid = [f'get_global_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)])
GPUBuffer = Compiled(CLBuffer, CLCodegen, CLProgram)
"""
class GPUBuffer(CompiledBuffer):
spec = Specialized(CLBuffer, CLCodegen, CLProgram)
# override this method for image
def create_raw_buffer(self, shape, backing, dtype) -> RawBuffer:
if len(shape) == 3 and shape[2] == 4 and IMAGE >= 2 and backing is None: return CLImage(shape) # NOTE: this is a hack. we don't pass in the dtype here, it's controlled by the FLOAT16 env var
else: return super().create_raw_buffer(shape, backing, dtype)
"""

View File

@@ -1,6 +1,6 @@
import time, hashlib, ctypes
from typing import ClassVar
from tinygrad.ops import CompiledBuffer, Specialized
from tinygrad.ops import Compiled
from tinygrad.helpers import getenv, DEBUG
from ctypes import CFUNCTYPE
from tinygrad.codegen.llvm import LLVMCodegen
@@ -62,5 +62,4 @@ class LLVMProgram:
cfunc(*[x._buf for x in bufs])
if wait: return time.monotonic()-st
class LLVMBuffer(CompiledBuffer):
spec = Specialized(RawMallocBuffer, LLVMCodegen, LLVMProgram)
LLVMBuffer = Compiled(RawMallocBuffer, LLVMCodegen, LLVMProgram)

View File

@@ -4,7 +4,8 @@ import Metal, Cocoa, libdispatch # type: ignore
from typing import List, Any
from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
from tinygrad.helpers import prod, getenv, DEBUG, DType
from tinygrad.ops import CompiledBuffer, Specialized
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawBufferMapped
METAL_XCODE = getenv("METAL_XCODE")
@@ -17,16 +18,14 @@ class _METAL:
METAL = _METAL()
class RawMetalBuffer(RawBufferMapped):
def __init__(self, size:int, dtype:DType):
super().__init__(size, dtype)
self._cl = METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)
def __init__(self, size:int, dtype:DType): super().__init__(size, dtype, METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared))
def __del__(self):
self._cl.release()
self._buf.release()
super().__del__()
def _buffer(self):
for cbuf in METAL.mtl_buffers_in_flight: cbuf.waitUntilCompleted()
METAL.mtl_buffers_in_flight.clear()
return self._cl.contents().as_buffer(self._cl.length())
return self._buf.contents().as_buffer(self._buf.length())
def unwrap(x):
ret, err = x
@@ -65,7 +64,7 @@ class MetalProgram:
command_buffer = METAL.mtl_queue.commandBuffer()
encoder = command_buffer.computeCommandEncoder()
encoder.setComputePipelineState_(self.pipeline_state)
for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex_(a._cl, 0, i)
for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex_(a._buf, 0, i)
encoder.dispatchThreads_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
encoder.endEncoding()
command_buffer.commit()
@@ -82,5 +81,4 @@ class MetalCodegen(GPUCodegen):
gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)],
extra_args = ['uint3 gid [[thread_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]'])
class MetalBuffer(CompiledBuffer):
spec = Specialized(RawMetalBuffer, MetalCodegen, MetalProgram)
MetalBuffer = Compiled(RawMetalBuffer, MetalCodegen, MetalProgram)

View File

@@ -1,9 +1,9 @@
import torch
from typing import ClassVar, Dict, Callable
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, FusedOps, LoadOps, Op
from tinygrad.helpers import getenv, dtypes
from tinygrad.interpreted import InterpretedBuffer
from typing import Dict, Callable
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, FusedOps, Op, Interpreted
from tinygrad.helpers import getenv, dtypes, prod
from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
from tinygrad.runtime.lib import RawBuffer
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
@@ -13,10 +13,12 @@ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
FusedOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(a.dtype), lambda x: x.stride(), lambda x,s: x.expand(s)),
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, abs(i)) for i in arg)].flip([i for i,a in enumerate(arg) if a < 0]),
LoadOps.FROMCPU: lambda arg: torch.from_numpy(arg).requires_grad_(False).to(device)
MovementOps.EXPAND: lambda x, arg: x.expand(arg), MovementOps.PERMUTE: lambda x, arg: x.permute(arg)
}}
class TorchBuffer(InterpretedBuffer):
fxn_for_op: ClassVar = torch_fxn_for_op
def to_tinygrad_dtype(self): return {torch.float16: dtypes.float16, torch.float32: dtypes.float32}[self._buf.dtype]
class RawTorchBuffer(RawBuffer):
def __init__(self, buf:torch.Tensor): super().__init__(prod(buf.shape), {torch.float16: dtypes.float16, torch.float32: dtypes.float32}[buf.dtype], buf)
@classmethod
def fromCPU(cls, x): return cls(torch.from_numpy(x).requires_grad_(False).to(device))
def toCPU(self): return self._buf.cpu().numpy()
TorchBuffer = Interpreted(RawTorchBuffer, torch_fxn_for_op)

View File

@@ -38,6 +38,7 @@ class Node:
assert b != 0
if b < 0: return (self//-b)*-1
if b == 1: return self
if isinstance(self, ModNode) and self.b % b == 0: return (self.a//b) % (self.b//b) # put the div inside mod
if isinstance(self, DivNode): return self.a//(self.b*b) # two divs is one div
if isinstance(self, MulNode) and self.b % b == 0: return self.a*(self.b//b)
if isinstance(self, MulNode) and b % self.b == 0: return self.a//(b//self.b)
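These rewrites hold for the nonnegative operands shape indexing produces; a quick numeric check of each rule:
for a in range(1000):
  assert (a % 8) // 2 == (a // 2) % 4  # put the div inside mod
  assert (a // 3) // 4 == a // 12      # two divs is one div
  assert (a * 6) // 2 == a * 3         # MulNode, self.b % b == 0
  assert (a * 2) // 6 == a // 3        # MulNode, b % self.b == 0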

View File

@@ -40,7 +40,12 @@ class Tensor:
# TODO: this has to realize, it shouldn't have to
data = data.realize().toCPU()
if isinstance(data, (np.ndarray, LazyNumpyArray)):
# all ndarrays are lazy now
if isinstance(data, np.ndarray): data = LazyNumpyArray(data, data.shape, data.dtype)
# by here, it's either LazyNumpyArray or LazyBuffer
# TODO: it should all be LazyBuffer I think
if isinstance(data, LazyNumpyArray):
data = data if data.shape else data.reshape((1,))
lazydata = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None else data, device)
elif isinstance(data, LazyBuffer):
@@ -455,6 +460,7 @@ class Tensor:
# TODO: this is a hack, but if we add float(0), it will become a float. need real casting support
def float(self) -> Tensor: return self.add(Tensor([0], device=self.device, dtype=dtypes.float32, requires_grad=self.requires_grad))
def cast(self, dtype:DType) -> Tensor: return mlops.Cast.apply(self, dtype=dtype) if self.dtype != dtype else self
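Note the short-circuit: casting to the current dtype returns self, so no Cast mlop is built (a tiny sketch):
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
t = Tensor([1.0, 2.0])
assert t.cast(dtypes.float32) is t  # dtype already matches, nothing is created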
# register functions to move between devices
for device in Device._buffers: