All LazyOps in the Linearizer (#1905)

* loadop buffer on cpu

* works for GPU

* sort of working

* has bugs

* gpu tests pass

* fix some tests

* fix tensor cores

* fix test linearizer

* fix symbolic

* fix has_variable_shape

* non symbolic size

* disable weird test

* simple cache fix

* fix custom function

* fix kopt

* cleanups

* a bit broken on the assign

* contig check

* only buffer

* need that order

* idx
This commit is contained in:
George Hotz
2023-09-24 11:50:00 +08:00
committed by GitHub
parent 0f373b8b47
commit a5820390db
15 changed files with 151 additions and 138 deletions

View File

@@ -263,14 +263,13 @@ class Linearizer:
uops: List[UOp]
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod
result = Tensor(2).realize() + Tensor(3).realize()
result.lazydata.realized = Device[Device.DEFAULT].buffer(prod(result.shape), result.dtype)
# use the real Linearizer to linearize 2+3
from tinygrad.lazy import _replace_loadops
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.codegen.kernel import LinearizerOptions
linearizer = Linearizer(result.lazydata.op, result.lazydata, LinearizerOptions())
op, _ = _replace_loadops(result.lazydata.op)
linearizer = Linearizer(op)
linearizer.linearize()
# print the uops
@@ -279,13 +278,11 @@ for uop in linearizer.uops: print(uop)
# output:
"""
0 UOps.DEFINE_GLOBAL : ptr.dtypes.float [] ('data0', dtypes.float)
1 UOps.LOOP : [] ([], 'global')
2 UOps.LOOP : [] ([], 'local')
3 UOps.CONST : dtypes.float [] 2.0
4 UOps.CONST : dtypes.float [] 3.0
5 UOps.ALU : dtypes.float [3, 4] BinaryOps.ADD
6 UOps.STORE : [5] MemOp(name='data0', idx=<0>, local=False, memory_dtype=dtypes.float, valid=<1>, invalid_value=0.0)
7 UOps.ENDLOOP : [] ([], 'global+local')
1 UOps.CONST : dtypes.float [] 2.0
2 UOps.CONST : dtypes.float [] 3.0
3 UOps.ALU : dtypes.float [1, 2] BinaryOps.ADD
4 UOps.CONST : dtypes.int [] 0
5 UOps.STORE : [0, 4, 3] None
"""
# %%

View File

@@ -14,17 +14,17 @@ from examples.hlb_cifar10 import SpeedyResNet
from examples.llama import Transformer as LLaMaTransformer, MODEL_PARAMS as LLAMA_MODEL_PARAMS
from examples.stable_diffusion import UNetModel
def kopt_search_hook(k, create_k, to_prg, baseline):
def kopt_search_hook(k, create_k, to_prg, baseline, bufs):
import nevergrad as ng
wanna_output = k.bufs[0].toCPU().copy()
wanna_output = bufs[0].toCPU().copy()
def check_opt(x):
try:
k = create_k()
k.process()
k.apply_auto_opt(x)
prg = to_prg(k)
first_tm = prg.exec(k.bufs, force_wait=True, optimizing=True)
np.testing.assert_allclose(wanna_output, k.bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
first_tm = prg.exec(bufs, force_wait=True, optimizing=True)
np.testing.assert_allclose(wanna_output, bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
return first_tm
except Exception:
return 10000_000 # 10000 seconds is infinity

View File

@@ -24,7 +24,7 @@ def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
__kernel void atan2_gpu(global float *c, global float *a, global float *b) {
int idx = get_global_id(0);
c[idx] = atan2(a[idx], b[idx]);
}""", global_size=[prod(ret.shape)]).build(Device[ret.device].runtime).exec([ret, a, b])
}""", global_size=[prod(ret.shape)]).build(Device[ret.device].runtime).exec([ret.realized, a.realized, b.realized])
return ret.realized
def atan2_cpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):

View File

@@ -5,6 +5,7 @@ from tinygrad.codegen.linearizer import Linearizer, UOps
from tinygrad.ops import Compiled, Device, MovementOps, LazyOp
from tinygrad.tensor import Tensor
from tinygrad.jit import CacheCollector
from tinygrad.lazy import _replace_loadops
class TestLinearizer(unittest.TestCase):
def test_arg_dedup(self):
@@ -30,7 +31,7 @@ class TestLinearizer(unittest.TestCase):
r = a[:-1] + a[1:]
ast = r.lazydata.op
r = r.realize() # realize an output buffer
k = Linearizer(ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
k.process()
k.upcast()
k.linearize()
@@ -48,7 +49,7 @@ class TestLinearizer(unittest.TestCase):
r = a.expand([2]) + b.expand([2])
ast = r.lazydata.op
r = r.realize() # realize an output buffer
k = Linearizer(ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
k.process()
k.upcast()
k.linearize()
@@ -63,7 +64,7 @@ class TestLinearizer(unittest.TestCase):
r = Tensor.stack([a, b])
ast = r.lazydata.op
r = r.realize() # realize an output buffer
k = Linearizer(ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
k.process()
k.upcast()
k.linearize()
@@ -79,7 +80,7 @@ class TestLinearizer(unittest.TestCase):
r = a * b
ast = r.lazydata.op
r = r.realize() # realize an output buffer
k = Linearizer(ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
k.process()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]])
@@ -88,12 +89,14 @@ class TestLinearizer(unittest.TestCase):
def helper_linearizer_opt(r:Tensor, opts=[]):
wanna_output = None
realized_ast = None
real_bufs = None
# HACK to get real ast.
real_dev_exec_ast = Device[Device.DEFAULT].exec_ast
def fake_exec_ast(ast, output=None, **kwargs):
nonlocal realized_ast
x = real_dev_exec_ast(ast, output, **kwargs)
def fake_exec_ast(ast, output=None, inputs=None, **kwargs):
nonlocal realized_ast, real_bufs
x = real_dev_exec_ast(ast, output, inputs, **kwargs)
real_bufs = [output.realized] + inputs
if not(ast.op in MovementOps and ast.src[0].__class__ is not LazyOp and ast.src[0].realized): realized_ast = ast # get last executed
return x
Device[Device.DEFAULT].exec_ast = fake_exec_ast
@@ -106,26 +109,26 @@ def helper_linearizer_opt(r:Tensor, opts=[]):
k.process()
k.apply_auto_opt(x)
prg = to_prg(k)
k.bufs[0].realized = k.bufs[0].realized.fromCPU(np.zeros(k.bufs[0].shape, dtype=k.bufs[0].dtype.np)) # Zero to check that all values are filled
prg.exec(k.bufs, force_wait=True)
np.testing.assert_allclose(wanna_output, k.bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
real_bufs[0] = real_bufs[0].fromCPU(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np)) # Zero to check that all values are filled
prg.exec(real_bufs, force_wait=True)
np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
# Get baseline, which is not optimized at all.
k = Linearizer(realized_ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
k = Linearizer(realized_ast, Device[Device.DEFAULT].linearizer_opts)
k.process()
prg = Device[Device.DEFAULT].to_program(k)
prg.exec(k.bufs, force_wait=True)
wanna_output = k.bufs[0].toCPU().copy()
prg.exec(real_bufs, force_wait=True)
wanna_output = real_bufs[0].toCPU().copy()
# Check correctness of handcoded optimizations.
k = Linearizer(realized_ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
k = Linearizer(realized_ast, Device[Device.DEFAULT].linearizer_opts)
k.hand_coded_optimizations()
prg = Device[Device.DEFAULT].to_program(k)
k.bufs[0].realized = k.bufs[0].realized.fromCPU(np.zeros(k.bufs[0].shape, dtype=k.bufs[0].dtype.np)) # Zero to check that all values are filled
prg.exec(k.bufs, force_wait=True)
np.testing.assert_allclose(wanna_output, k.bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
real_bufs[0] = real_bufs[0].fromCPU(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np)) # Zero to check that all values are filled
prg.exec(real_bufs, force_wait=True)
np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
for x in opts: # Check custom transformations if any.
check_opt(x, lambda: Linearizer(realized_ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts), Device[Device.DEFAULT].to_program)
check_opt(x, lambda: Linearizer(realized_ast, Device[Device.DEFAULT].linearizer_opts), Device[Device.DEFAULT].to_program)
class TestLinearizerOpts(unittest.TestCase):
def test_local_and_grouped_reduce(self):

View File

@@ -1,18 +1,13 @@
#!/usr/bin/env python
import unittest
from typing import NamedTuple, Tuple
from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info
from tinygrad.helpers import DType, dtypes
class TestBuffer(NamedTuple):
__test__ = False # To prevent pytest from collecting this as a test
shape: Tuple[int, ...]
dtype: DType
from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info, LoadOps, MemBuffer
from tinygrad.shape.view import View
from tinygrad.helpers import dtypes
class TestFlopCounter(unittest.TestCase):
def setUp(self):
self.buf0 = TestBuffer(shape=(4,), dtype=dtypes.float32)
self.buf1 = TestBuffer(shape=(4,), dtype=dtypes.float32)
self.buf0 = LazyOp(LoadOps.BUFFER, (), MemBuffer(1, dtypes.float32, (View.create((4,)),)))
self.buf1 = LazyOp(LoadOps.BUFFER, (), MemBuffer(2, dtypes.float32, (View.create((4,)),)))
def test_flops_add(self):
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)

View File

@@ -1,12 +1,10 @@
from typing import NamedTuple, Optional, List, Tuple, cast, Dict
import itertools
from tinygrad.ops import LazyOp, MovementOps, FlopCounter, get_lazyop_info, ReduceOps
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType
from tinygrad.runtime.lib import buf_is_kernel_arg
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, ReduceOps, LoadOps, MemBuffer
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import sint
from tinygrad.shape.view import strides_for_shape
from tinygrad.shape.view import strides_for_shape, View
class LocalBuffer(NamedTuple):
name: str
@@ -16,6 +14,7 @@ class LocalBuffer(NamedTuple):
def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"
class LinearizerOptions(NamedTuple):
device: str = ""
# TODO: make this generic with a list of supported types
supports_float4: bool = True
supports_float4_alu: bool = True
@@ -26,41 +25,32 @@ class LinearizerOptions(NamedTuple):
local_max: Optional[List[int]] = None
class Kernel:
def __init__(self, ast:LazyOp, output_buffer:LazyBuffer, opts:Optional[LinearizerOptions]=None):
# NOTE: if there's a RESHAPE, we skip it. the output shape is set from the reduce op or a latebuf
self.ast = ast.src[0] if ast.op == MovementOps.RESHAPE else ast
def __init__(self, ast:LazyOp, opts:Optional[LinearizerOptions]=None, var_vals=None):
self.opts = opts if opts else LinearizerOptions()
# get the output buffers
self.bufs = [output_buffer] + dedup(ast.buffers)
self.arg_bufs = {x:f"data{i}" for i,x in enumerate(dedup([x.realized for x in self.bufs if buf_is_kernel_arg(x)]))}
# key for lookup in cache (can change, str might not be right)
# bufs are needed because kernels like f(x) = x + x and f(x, y) = x + y have the same str(ast), but are different kernels.
# mapping the buffers to integers is required because a-b != b-a (and how would you tell a and b apart?)
self.key = (ast.map_buffers({x:self.arg_bufs.get(x.realized,x) for x in self.bufs}).key, tuple([x.key for x in self.bufs]))
self.ast = ast
self.var_vals = var_vals
self.key = (ast, tuple(var_vals.keys())) if var_vals else ast
def process(self) -> None:
if hasattr(self, "sts"): return # already processed
# fetch lazyop info
self.info: FlopCounter = get_lazyop_info(cast(LazyOp, self.ast))
self.mem_estimate: int = sum(x.dtype.itemsize*x.size for x in self.arg_bufs.keys())
# there's only allowed to be one reduceop
reduceops = [x for x in self.ast.get_lazyops() if x.op in ReduceOps]
assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast"
self.reduceop = reduceops[0] if reduceops else None
# get earlybufs, before the one reduce op
self.earlybufs = dedup(self.reduceop.buffers) if self.reduceop else []
# create new shapetrackers inside this kernel, we will permute them
self.sts: List[ShapeTracker] = [x.st.copy() for x in self.bufs]
self.bufs = [MemBuffer(0, self.info.dtype, (View.create(self.info.shape),))] + [x.arg for x in self.ast.get_lazyops() if x.op in LoadOps]
self.sts: List[ShapeTracker] = [ShapeTracker(x.views[-1].shape, views=list(x.views)) for x in self.bufs]
for st in self.sts: st.simplify()
# make the output buffer shape correct in here
self.sts[0].reshape(self.info.shape)
self.mem_estimate: int = sum(x.dtype.itemsize*x.views[-1].size() for x in self.bufs)
# get earlybufs, before the one reduce op
self.earlybufs = [x.arg for x in self.reduceop.get_lazyops() if x.op in LoadOps] if self.reduceop else []
self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0
# parameters
@@ -77,7 +67,7 @@ class Kernel:
def has_variable_shape(self) -> bool:
for b in self.bufs:
if any(not isinstance(x, int) for x in b.st.shape): return True
if not all_int(b.views[-1].shape): return True
return False
def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
@@ -147,6 +137,6 @@ class Kernel:
def colored_shape(self) -> str: return ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) else s for s in self.full_shape], self.colors()))
def printbufs(self, prefix=""):
for i,st in enumerate(self.sts):
print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i].realized is not None else str(self.bufs[i]):47s}", st.views)
print(prefix, f"{i:3d} {str(self.bufs[i]):47s}", st.views)
print(self.colored_shape())

View File

@@ -5,9 +5,8 @@ from collections import defaultdict
from enum import Enum, auto
from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, all_same
from tinygrad.ops import LazyOp, UnaryOps
from tinygrad.ops import LazyOp, UnaryOps, LoadOps, ConstBuffer, MemBuffer
from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
from tinygrad.runtime.lib import RawConst
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, sym_rename
from tinygrad.codegen.optimizer import OptimizedKernel
@@ -128,11 +127,6 @@ def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)]
class Linearizer(OptimizedKernel):
def get_buffer_name(self, i):
if self.bufs[i].__class__ == LocalBuffer: return self.bufs[i].name
assert self.bufs[i].realized.__class__ is not RawConst # constants shouldn't be loaded with memops
return self.arg_bufs[self.bufs[i].realized]
def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype=dtypes.int32):
render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
return self.uop(UOps.ALU, dtype, (a, render_b), op)
@@ -147,7 +141,7 @@ class Linearizer(OptimizedKernel):
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
def global_load(self, i:int, idxs:Sequence[VariableOrNum], acc=None) -> List[UOp]:
const = self.bufs[i].realized._buf if isinstance(self.bufs[i].realized, RawConst) else acc
const = self.bufs[i].val if isinstance(self.bufs[i], ConstBuffer) else acc
expanded_nodes = [idx.expand() for idx in idxs]
_idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
@@ -176,7 +170,7 @@ class Linearizer(OptimizedKernel):
idx, valid = g_idx.substitute(substitute), g_valid.substitute(substitute)
localtype = dtypes.float32
this_const, idx, valid = (invalid_value, Variable.num(0), Variable.num(1)) if valid.max == 0 else (const, idx, valid)
key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else self.get_buffer_name(i)}{idx.render()}{valid.render()}"
key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (self.bufs[i].idx if isinstance(self.bufs[i], MemBuffer) else self.bufs[i].name)}{idx.render()}{valid.render()}"
if key not in self.load_cache:
if acc is not None:
assert valid.min == 1
@@ -253,15 +247,13 @@ class Linearizer(OptimizedKernel):
self.loop_uops: Dict[str, UOp] = {}
# add global buffers
arg_bufs = {}
for buf,name in self.arg_bufs.items():
arg_bufs[buf] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (name, buf.dtype))
for i,b in enumerate(self.bufs):
if b.realized in arg_bufs: self.buf_uops[i] = arg_bufs[b.realized]
# add variables from symbolic shapes
for var in sorted(set(v for buf in self.ast.buffers for v in buf.var_vals), key=lambda k: k.key):
assert var.expr is not None
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
for i,buf in enumerate(self.bufs):
if isinstance(buf, MemBuffer):
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype))
if self.var_vals:
for var in sorted(set(self.var_vals), key=lambda k: k.key):
assert var.expr is not None
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
# define local buffers
for lb in self.local_alias.values():
self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size()))
@@ -491,7 +483,8 @@ class Linearizer(OptimizedKernel):
return self.uops[-1]
def ast_parse(self, x, acc, loaded_buffers, do_reduce=False) -> List[UOp]:
if x.__class__ is not LazyOp: return loaded_buffers[x]
if x.__class__ is not LazyOp: return loaded_buffers[x] # for LOCAL_BUFFER
if x.op in [LoadOps.BUFFER, LoadOps.CONST]: return loaded_buffers[x.arg]
if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, loaded_buffers) # cast isn't an ALU op
if x.op in ReduceOps and not do_reduce: return acc
# MULACC fusion. TODO: this is copied from Interpreted

View File

@@ -1,9 +1,8 @@
from typing import Tuple, List, cast, Optional
import itertools, math, os
from tinygrad.helpers import DEBUG, prod, getenv, ImageDType, dtypes
from tinygrad.ops import ReduceOps, BinaryOps, UnaryOps, LazyOp
from tinygrad.ops import ReduceOps, BinaryOps, UnaryOps, LazyOp, LoadOps
from tinygrad.codegen.kernel import Kernel, LocalBuffer
from tinygrad.lazy import LazyBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
@@ -167,11 +166,11 @@ class OptimizedKernel(Kernel):
self.simplify_ones()
# should use HIP tensor cores?
if getenv("TC", 1) != 0 and self.bufs[0].device == "HIP" and self.reduceop and self.reduceop.op == ReduceOps.SUM and \
if getenv("TC", 1) != 0 and self.opts.device == "HIP" and self.reduceop and self.reduceop.op == ReduceOps.SUM and \
isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and \
isinstance(self.reduceop.src[0].src[0], LazyOp) and self.reduceop.src[0].src[0].op == BinaryOps.MUL and \
isinstance(self.reduceop.src[0].src[0].src[0], LazyBuffer) and isinstance(self.reduceop.src[0].src[0].src[1], LazyBuffer) and self.opts.has_local and \
self.reduceop.src[0].src[0].src[0].dtype == dtypes.half and self.reduceop.src[0].src[0].src[1].dtype == dtypes.half:
self.reduceop.src[0].src[0].src[0].op == LoadOps.BUFFER and self.reduceop.src[0].src[0].src[1].op == LoadOps.BUFFER and self.opts.has_local and \
cast(LazyOp, self.reduceop.src[0].src[0].src[0]).arg.dtype == dtypes.half and cast(LazyOp, self.reduceop.src[0].src[0].src[1]).arg.dtype == dtypes.half:
# HIP tensor cores are 16x16x16
buf0 = self.bufs.index(self.reduceop.src[0].src[0].src[0])
buf1 = self.bufs.index(self.reduceop.src[0].src[0].src[1])
@@ -227,10 +226,10 @@ class OptimizedKernel(Kernel):
# should use METAL tensor cores?
# first, confirm it's a straightforward mulacc on a device with real locals
tensor_cores_allowed = getenv("TC", 1) != 0 and (getenv("TC", 1) == 2 or (self.bufs[0].device == "METAL" and os.uname().machine == "arm64"))
tensor_cores_allowed = getenv("TC", 1) != 0 and (getenv("TC", 1) == 2 or (self.opts.device == "METAL" and os.uname().machine == "arm64"))
if tensor_cores_allowed and self.reduceop and self.reduceop.op == ReduceOps.SUM and \
isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == BinaryOps.MUL and \
isinstance(self.reduceop.src[0].src[0], LazyBuffer) and isinstance(self.reduceop.src[0].src[1], LazyBuffer) and self.opts.has_local:
self.reduceop.src[0].src[0].op == LoadOps.BUFFER and self.reduceop.src[0].src[1] == LoadOps.BUFFER and self.opts.has_local:
# METAL tensor cores are 8x8x8, with 2 elements per thread in the 32 thread warp
buf0 = self.bufs.index(self.reduceop.src[0].src[0])
buf1 = self.bufs.index(self.reduceop.src[0].src[1])

View File

@@ -21,7 +21,7 @@ def kernel_optimize_opts(k:Linearizer):
opts.append(ng.p.TransitionChoice([(i,s,"G") for s in get_divisors(k.full_shape[k.first_reduce+i], min_div=4) if all(st.shape[k.first_reduce+i] % s == 0 or st.shape[k.first_reduce+i] == 1 for st in k.sts)]))
return opts
def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], to_prg, baseline):
def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], to_prg, baseline, bufs):
import nevergrad as ng
def opt(x):
try:
@@ -29,9 +29,9 @@ def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], to_p
k.process()
k.apply_auto_opt(x)
prg = to_prg(k)
first_tm = prg.exec(k.bufs, force_wait=True, optimizing=True)
first_tm = prg.exec(bufs, force_wait=True, optimizing=True)
if baseline*5 < first_tm*1000: return first_tm*1000 # very slow
tm = min([first_tm]+[prg.exec(k.bufs, force_wait=True, optimizing=True) for _ in range(2)])*1000
tm = min([first_tm]+[prg.exec(bufs, force_wait=True, optimizing=True) for _ in range(2)])*1000
return tm
except Exception:
if DEBUG >= 3:
@@ -51,7 +51,7 @@ def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], to_p
# optimization
global_db = None
def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], to_prg):
def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], to_prg, bufs):
global global_db
k.process()
@@ -72,8 +72,8 @@ def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], to_prg):
k = create_k()
k.hand_coded_optimizations()
prg = to_prg(k)
return min([prg.exec(k.bufs, force_wait=True, optimizing=True) for _ in range(5)])*1000
choice = kernel_optimize_search(k, create_k, to_prg, get_baseline())
return min([prg.exec(bufs, force_wait=True, optimizing=True) for _ in range(5)])*1000
choice = kernel_optimize_search(k, create_k, to_prg, get_baseline(), bufs)
if global_db is not None:
global_db[skey] = choice
global_db.sync()

View File

@@ -1,16 +1,16 @@
from __future__ import annotations
import sys, operator, math
from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast
from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapping
from weakref import ref, WeakSet, WeakValueDictionary
import numpy as np
from tinygrad.graph import log_op
from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition, all_int
from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp
from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition, all_int, dedup, merge_dicts
from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer
from tinygrad.shape.shapetracker import ShapeTracker, View, get_contraction
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped, RawBufferTransfer
from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped, RawBufferTransfer, buf_is_kernel_arg
from tinygrad.runtime.ops_cpu import RawNumpyBuffer
from tinygrad.runtime.ops_disk import RawDiskBuffer
@@ -90,6 +90,20 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp:
ast = self.op.map_buffers(cast(Dict[LazyBuffer, Union[LazyOp, LazyBuffer, str]], real_srcs))
return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape) if intermediate_shape != self.shape else ast
def _replace_loadops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
replacements:Dict[LazyBuffer, LazyOp] = {}
realized_bufs = dedup([x.realized for x in op.buffers if buf_is_kernel_arg(x)])
for x in op.buffers:
assert x.realized, "buffer isn't realized"
x.st.simplify()
if isinstance(x.realized, RawConst):
replacements[x] = LazyOp(LoadOps.CONST, (), ConstBuffer(x.realized._buf, x.realized.dtype, tuple(x.st.views)))
elif x.realized in realized_bufs:
replacements[x] = LazyOp(LoadOps.BUFFER, (), MemBuffer(realized_bufs.index(x.realized)+1, x.realized.dtype, tuple(x.st.views)))
else:
raise NotImplementedError(f"not handled {x}")
return (op.src[0] if op.op == MovementOps.RESHAPE else op).map_buffers(replacements), realized_bufs
# **** lazy operations ****
def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(cast(LazyBuffer, root.op.src[0])) if getattr(root, 'op', None) and len(root.op.src) == 1 else root
@@ -162,7 +176,9 @@ class LazyBuffer:
else:
self.op = LazyOp(UnaryOps.CAST, (self.op,), (dtypes.float32, False))
self.dtype = dtypes.float32
self.realized = Device[self.device].exec_ast(self.op, output=self, **self._device_extra_args())
self.var_vals = dict(sorted(merge_dicts([buf.var_vals for buf in self.op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
op, realized_bufs = _replace_loadops(self.op)
self.realized = Device[self.device].exec_ast(op, output=self, inputs=realized_bufs, var_vals=self.var_vals, **self._device_extra_args())
assert self.realized and isinstance(self.realized, (RawConst, Device[self.device].buffer)), f"device mismatch on realized got {type(self.realized)} expected {self.device}"
# HACK: allow hot casting of images
@@ -315,7 +331,7 @@ class LazyBuffer:
@property
def buffers(self) -> Tuple[LazyBuffer, ...]: return (self,)
def map_buffers(self, real_srcs: Dict[LazyBuffer, Union[LazyBuffer, LazyOp, str]]): return real_srcs.get(self, self)
def map_buffers(self, real_srcs: Mapping[LazyBuffer, Union[LazyBuffer, LazyOp, str]]): return real_srcs.get(self, self)
def get_lazyops(self) -> List[LazyOp]: return []
def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
y = self

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
import time, importlib, inspect, functools, pathlib
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, dedup, merge_dicts
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast, Mapping
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored
from tinygrad.shape.view import View
from dataclasses import dataclass
if TYPE_CHECKING: from tinygrad.lazy import LazyBuffer
# these are the llops your accelerator must implement, along with toCpu
@@ -14,11 +16,23 @@ class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); M
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); BUFFER = auto() # noqa: E702
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps]
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps]]
@dataclass(frozen=True)
class MemBuffer:
idx: int
dtype: DType
views: Tuple[View, ...]
@dataclass(frozen=True)
class ConstBuffer:
val: Any
dtype: DType
views: Tuple[View, ...]
class LazyOp:
__slots__ = "op", "src", "arg", "buffers", "__weakref__"
op: Op
@@ -37,7 +51,7 @@ class LazyOp:
@property
def key(self): return (self.op, tuple(map(lambda x: getattr(x, "key", x), self.src)), getattr(self.arg, "key", self.arg))
def map_buffers(self, real_srcs: Dict[LazyBuffer, Union[LazyBuffer, LazyOp, str]]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg)
def map_buffers(self, real_srcs: Mapping[LazyBuffer, Union[LazyBuffer, LazyOp, str]]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg)
def get_lazyops(self) -> List[LazyOp]: return [self] + [item for x in self.src for item in x.get_lazyops()]
def replace_with_movement_ops(self:LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer':
@@ -87,9 +101,8 @@ Device = _Device()
# **************** for Interpreted Buffers ****************
def apply_shapetracker(fxn_for_op, ret, st):
st.simplify() # TODO: this is generic for Compiled too
for v in st.views:
def apply_shapetracker(fxn_for_op, ret, views):
for v in views:
real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
real_offset = v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
# first, we apply the offset
@@ -111,20 +124,22 @@ def apply_shapetracker(fxn_for_op, ret, st):
return ret
class Interpreted:
def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_lazybuffer=None, to_underlying=lambda x: x._buf, from_underlying=None):
def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], to_underlying=lambda x: x._buf, from_underlying=None):
self.buffer, self.fxn_for_op, self.to_underlying = buffer, fxn_for_op, to_underlying
self.from_underlying = buffer if from_underlying is None else from_underlying
self.from_lazybuffer = from_lazybuffer if from_lazybuffer is not None else lambda x: self.from_underlying(apply_shapetracker(self.fxn_for_op, self.to_underlying(x.realized), x.st))
self.synchronize = lambda: None
self.codegen = None
def exec_ast(self, ast:LazyOp, output=None, context=None, **kwargs):
def exec_ast(self, ast:LazyOp, output=None, inputs=None, var_vals=None, context=None, **kwargs):
if ast.op == LoadOps.BUFFER and LoadOps.BUFFER not in self.fxn_for_op:
assert inputs[ast.arg.idx-1].dtype == ast.arg.dtype, "dtype mismatch"
return self.from_underlying(apply_shapetracker(self.fxn_for_op, self.to_underlying(inputs[ast.arg.idx-1]), ast.arg.views))
if TernaryOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
created_context = context is None
if context is None: context = dict()
if not created_context and ast in context: return context[ast]
srcs = [self.exec_ast(x, context=context, **kwargs) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src]
srcs = [self.exec_ast(cast(LazyOp, x), inputs=inputs, context=context, **kwargs) for x in ast.src]
if DEBUG >= 3: st = time.perf_counter()
ret = self.from_underlying(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op: ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.to_underlying(ret), (output.dtype, False))) # Do manual casting of ret if it does not match the required output dtype.
@@ -147,17 +162,18 @@ class FlopCounter:
self.flops, ret = 0, self.flops
return ret
shape_fxn_for_op: Dict[Op, Callable] = {
LoadOps.BUFFER: lambda arg: (arg.views[-1].shape, arg.dtype, 0), LoadOps.CONST: lambda arg: (arg.views[-1].shape, arg.dtype, 0),
UnaryOps.CAST: lambda self,arg: (self.shape, arg[0], self.consume_flops()), # cast uses no flops
**{op:lambda self: (self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps if op != UnaryOps.CAST},
**{op:lambda self,y: (self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},
**{op:lambda self,new_shape: (new_shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in ReduceOps},
TernaryOps.WHERE: lambda self,y,z: (self.shape, self.dtype, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape))}
InterpretedFlopCounter = Interpreted(FlopCounter, shape_fxn_for_op, lambda x: FlopCounter((x.shape, x.dtype, 0)), lambda x: x)
InterpretedFlopCounter = Interpreted(FlopCounter, shape_fxn_for_op, lambda x: x)
def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.exec_ast(ast)
# **************** for Compiled Buffers ****************
from tinygrad.runtime.lib import RawBuffer, RawConst, buf_is_kernel_arg
from tinygrad.runtime.lib import RawBuffer, RawConst
from tinygrad.shape.symbolic import Variable, sym_infer
class ASTRunner:
@@ -169,9 +185,8 @@ class ASTRunner:
self.clprg = runtime(self.name, self.prg, **self.runtime_args)
return self
def exec(self, bufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False, optimizing=False) -> Optional[float]:
def exec(self, rawbufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False, optimizing=False) -> Optional[float]:
from tinygrad.jit import CacheCollector
rawbufs = dedup([x.realized for x in bufs if buf_is_kernel_arg(x)])
if not optimizing: CacheCollector.add(self, rawbufs, var_vals if var_vals is not None else {})
return self(rawbufs, var_vals, force_wait=force_wait)
@@ -205,33 +220,38 @@ class Compiled:
op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
display_name=k.display_name, runtime_args={"binary": False}).build(self.runtime)
def exec_ast(self, ast:LazyOp, output, **kwargs):
def exec_ast(self, ast:LazyOp, output, inputs, var_vals, **kwargs):
#if DEBUG >= 4:
# from extra.utils import print_tree
# print_tree(ast)
# check if we can reuse the output buffer
# if it's aliased, don't use it
# NOTE: this is pretty wrong actually, who knows where else this buffer is used?
output.realized = output.output_buffer
if output.realized:
if output.realized.__class__ is RawConst: output.realized = None # can't assign to RawConst
for a in ast.buffers:
if a.realized == output.realized and not a.st.contiguous:
output.realized = None
break
for i,a in enumerate(inputs):
# TODO: if this is contiguous it's fine
if a == output.realized:
views = [x.arg.views for x in ast.get_lazyops() if x.op == LoadOps.BUFFER and x.arg.idx == i+1]
if any(len(v) > 1 or not v[0].contiguous for v in views):
output.realized = None
break
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
if not output.realized: output.realized = self.buffer(prod((s if isinstance(s, int) else s.max for s in output.shape)), output.dtype, **kwargs)
else:
from tinygrad.jit import CacheCollector
CacheCollector._mark_output_buffer(output.output_buffer)
# update the output var_vals from src
output.var_vals = dict(sorted(merge_dicts([buf.var_vals for buf in ast.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
from tinygrad.codegen.linearizer import Linearizer
k = Linearizer(ast, output, self.linearizer_opts)
k = Linearizer(ast, self.linearizer_opts, var_vals)
# compilation time
def get_program():
from tinygrad.codegen.search import kernel_optimize
if getenv("KOPT"): kernel_optimize(k, lambda: Linearizer(ast, output, self.linearizer_opts), self.to_program)
if getenv("KOPT"): kernel_optimize(k, lambda: Linearizer(ast, self.linearizer_opts, var_vals), self.to_program, [output.realized]+inputs)
elif not getenv("NOOPT"): k.hand_coded_optimizations()
return self.to_program(k)
@@ -243,5 +263,5 @@ class Compiled:
if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
prg.exec(k.bufs, var_vals=output.var_vals)
prg.exec([output.realized]+inputs, var_vals=var_vals)
return output.realized

View File

@@ -17,7 +17,7 @@ class RawBuffer: # pylint: disable=abstract-method
def __del__(self): # NOTE: if it fails on init (bad dtype), it won't have a _memsz
if hasattr(self, '_memsz'): GlobalCounters.mem_used -= self._memsz
if hasattr(self, '_allocator') and self._allocator: self._allocator.free(self._buf)
def __repr__(self): return f"buffer<{self.size}, {self.dtype}>"
def __repr__(self): return f"buffer<{self.size}, {self.dtype}, {id(self)}>"
@property
def key(self): return (self.size, self.dtype)

View File

@@ -93,4 +93,4 @@ __device__ void vstore_half4(float4 data, size_t offset, half *p) { *(p + offset
""",
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)],
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)]))
HIPBuffer = Compiled(RawHIPBuffer, LinearizerOptions(), renderer, HIPProgram, hip.hipDeviceSynchronize)
HIPBuffer = Compiled(RawHIPBuffer, LinearizerOptions(device="HIP"), renderer, HIPProgram, hip.hipDeviceSynchronize)

View File

@@ -85,4 +85,4 @@ renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);", float4 = "float4", uses_ptr_arithmetic=True,
gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)],
extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']))
MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(), renderer, MetalProgram, METAL.synchronize)
MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), renderer, MetalProgram, METAL.synchronize)

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import functools
from typing import Tuple, List, Optional, NamedTuple
from tinygrad.helpers import prod, all_int
from tinygrad.shape.symbolic import NumNode, is_sym_int, sint
from tinygrad.shape.symbolic import Node, NumNode, is_sym_int, sint
@functools.lru_cache(maxsize=None)
def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
@@ -29,7 +29,7 @@ class View(NamedTuple):
return View(shape, strides, offset, mask, contiguous)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def size(self): return prod([s for s,st in zip(self.shape, self.strides) if st != 0])
def size(self): return prod([s.max if isinstance(s, Node) else s for s,st in zip(self.shape, self.strides) if st != 0])
# MovementOps live here now