mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
All LazyOps in the Linearizer (#1905)
* loadop buffer on cpu
* works for GPU
* sort of working
* has bugs
* gpu tests pass
* fix some tests
* fix tensor cores
* fix test linearizer
* fix symbolic
* fix has_variable_shape
* non symbolic size
* disable weird test
* simple cache fix
* fix custom function
* fix kopt
* cleanups
* a bit broken on the assign
* contig check
* only buffer
* need that order
* idx
This commit is contained in:
@@ -263,14 +263,13 @@ class Linearizer:
|
||||
uops: List[UOp]
|
||||
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import prod
|
||||
result = Tensor(2).realize() + Tensor(3).realize()
|
||||
result.lazydata.realized = Device[Device.DEFAULT].buffer(prod(result.shape), result.dtype)
|
||||
|
||||
# use the real Linearizer to linearize 2+3
|
||||
from tinygrad.lazy import _replace_loadops
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
linearizer = Linearizer(result.lazydata.op, result.lazydata, LinearizerOptions())
|
||||
op, _ = _replace_loadops(result.lazydata.op)
|
||||
linearizer = Linearizer(op)
|
||||
linearizer.linearize()
|
||||
|
||||
# print the uops
|
||||
@@ -279,13 +278,11 @@ for uop in linearizer.uops: print(uop)
|
||||
# output:
|
||||
"""
|
||||
0 UOps.DEFINE_GLOBAL : ptr.dtypes.float [] ('data0', dtypes.float)
|
||||
1 UOps.LOOP : [] ([], 'global')
|
||||
2 UOps.LOOP : [] ([], 'local')
|
||||
3 UOps.CONST : dtypes.float [] 2.0
|
||||
4 UOps.CONST : dtypes.float [] 3.0
|
||||
5 UOps.ALU : dtypes.float [3, 4] BinaryOps.ADD
|
||||
6 UOps.STORE : [5] MemOp(name='data0', idx=<0>, local=False, memory_dtype=dtypes.float, valid=<1>, invalid_value=0.0)
|
||||
7 UOps.ENDLOOP : [] ([], 'global+local')
|
||||
1 UOps.CONST : dtypes.float [] 2.0
|
||||
2 UOps.CONST : dtypes.float [] 3.0
|
||||
3 UOps.ALU : dtypes.float [1, 2] BinaryOps.ADD
|
||||
4 UOps.CONST : dtypes.int [] 0
|
||||
5 UOps.STORE : [0, 4, 3] None
|
||||
"""
|
||||
|
||||
# %%
|
||||
|
||||
@@ -14,17 +14,17 @@ from examples.hlb_cifar10 import SpeedyResNet
|
||||
from examples.llama import Transformer as LLaMaTransformer, MODEL_PARAMS as LLAMA_MODEL_PARAMS
|
||||
from examples.stable_diffusion import UNetModel
|
||||
|
||||
def kopt_search_hook(k, create_k, to_prg, baseline):
|
||||
def kopt_search_hook(k, create_k, to_prg, baseline, bufs):
|
||||
import nevergrad as ng
|
||||
wanna_output = k.bufs[0].toCPU().copy()
|
||||
wanna_output = bufs[0].toCPU().copy()
|
||||
def check_opt(x):
|
||||
try:
|
||||
k = create_k()
|
||||
k.process()
|
||||
k.apply_auto_opt(x)
|
||||
prg = to_prg(k)
|
||||
first_tm = prg.exec(k.bufs, force_wait=True, optimizing=True)
|
||||
np.testing.assert_allclose(wanna_output, k.bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
|
||||
first_tm = prg.exec(bufs, force_wait=True, optimizing=True)
|
||||
np.testing.assert_allclose(wanna_output, bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
|
||||
return first_tm
|
||||
except Exception:
|
||||
return 10000_000 # 10000 seconds is infinity
|
||||
|
||||
@@ -24,7 +24,7 @@ def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
|
||||
__kernel void atan2_gpu(global float *c, global float *a, global float *b) {
|
||||
int idx = get_global_id(0);
|
||||
c[idx] = atan2(a[idx], b[idx]);
|
||||
}""", global_size=[prod(ret.shape)]).build(Device[ret.device].runtime).exec([ret, a, b])
|
||||
}""", global_size=[prod(ret.shape)]).build(Device[ret.device].runtime).exec([ret.realized, a.realized, b.realized])
|
||||
return ret.realized
|
||||
|
||||
def atan2_cpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
|
||||
|
||||
@@ -5,6 +5,7 @@ from tinygrad.codegen.linearizer import Linearizer, UOps
|
||||
from tinygrad.ops import Compiled, Device, MovementOps, LazyOp
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.jit import CacheCollector
|
||||
from tinygrad.lazy import _replace_loadops
|
||||
|
||||
class TestLinearizer(unittest.TestCase):
|
||||
def test_arg_dedup(self):
|
||||
@@ -30,7 +31,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
r = a[:-1] + a[1:]
|
||||
ast = r.lazydata.op
|
||||
r = r.realize() # realize an output buffer
|
||||
k = Linearizer(ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
|
||||
k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
|
||||
k.process()
|
||||
k.upcast()
|
||||
k.linearize()
|
||||
@@ -48,7 +49,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
r = a.expand([2]) + b.expand([2])
|
||||
ast = r.lazydata.op
|
||||
r = r.realize() # realize an output buffer
|
||||
k = Linearizer(ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
|
||||
k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
|
||||
k.process()
|
||||
k.upcast()
|
||||
k.linearize()
|
||||
@@ -63,7 +64,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
r = Tensor.stack([a, b])
|
||||
ast = r.lazydata.op
|
||||
r = r.realize() # realize an output buffer
|
||||
k = Linearizer(ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
|
||||
k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
|
||||
k.process()
|
||||
k.upcast()
|
||||
k.linearize()
|
||||
@@ -79,7 +80,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
r = a * b
|
||||
ast = r.lazydata.op
|
||||
r = r.realize() # realize an output buffer
|
||||
k = Linearizer(ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
|
||||
k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
|
||||
k.process()
|
||||
k.linearize()
|
||||
num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]])
|
||||
@@ -88,12 +89,14 @@ class TestLinearizer(unittest.TestCase):
|
||||
def helper_linearizer_opt(r:Tensor, opts=[]):
|
||||
wanna_output = None
|
||||
realized_ast = None
|
||||
real_bufs = None
|
||||
|
||||
# HACK to get real ast.
|
||||
real_dev_exec_ast = Device[Device.DEFAULT].exec_ast
|
||||
def fake_exec_ast(ast, output=None, **kwargs):
|
||||
nonlocal realized_ast
|
||||
x = real_dev_exec_ast(ast, output, **kwargs)
|
||||
def fake_exec_ast(ast, output=None, inputs=None, **kwargs):
|
||||
nonlocal realized_ast, real_bufs
|
||||
x = real_dev_exec_ast(ast, output, inputs, **kwargs)
|
||||
real_bufs = [output.realized] + inputs
|
||||
if not(ast.op in MovementOps and ast.src[0].__class__ is not LazyOp and ast.src[0].realized): realized_ast = ast # get last executed
|
||||
return x
|
||||
Device[Device.DEFAULT].exec_ast = fake_exec_ast
|
||||
@@ -106,26 +109,26 @@ def helper_linearizer_opt(r:Tensor, opts=[]):
|
||||
k.process()
|
||||
k.apply_auto_opt(x)
|
||||
prg = to_prg(k)
|
||||
k.bufs[0].realized = k.bufs[0].realized.fromCPU(np.zeros(k.bufs[0].shape, dtype=k.bufs[0].dtype.np)) # Zero to check that all values are filled
|
||||
prg.exec(k.bufs, force_wait=True)
|
||||
np.testing.assert_allclose(wanna_output, k.bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
|
||||
real_bufs[0] = real_bufs[0].fromCPU(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np)) # Zero to check that all values are filled
|
||||
prg.exec(real_bufs, force_wait=True)
|
||||
np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
|
||||
|
||||
# Get baseline, which is not optimized at all.
|
||||
k = Linearizer(realized_ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
|
||||
k = Linearizer(realized_ast, Device[Device.DEFAULT].linearizer_opts)
|
||||
k.process()
|
||||
prg = Device[Device.DEFAULT].to_program(k)
|
||||
prg.exec(k.bufs, force_wait=True)
|
||||
wanna_output = k.bufs[0].toCPU().copy()
|
||||
prg.exec(real_bufs, force_wait=True)
|
||||
wanna_output = real_bufs[0].toCPU().copy()
|
||||
|
||||
# Check correctness of handcoded optimizations.
|
||||
k = Linearizer(realized_ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
|
||||
k = Linearizer(realized_ast, Device[Device.DEFAULT].linearizer_opts)
|
||||
k.hand_coded_optimizations()
|
||||
prg = Device[Device.DEFAULT].to_program(k)
|
||||
k.bufs[0].realized = k.bufs[0].realized.fromCPU(np.zeros(k.bufs[0].shape, dtype=k.bufs[0].dtype.np)) # Zero to check that all values are filled
|
||||
prg.exec(k.bufs, force_wait=True)
|
||||
np.testing.assert_allclose(wanna_output, k.bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
|
||||
real_bufs[0] = real_bufs[0].fromCPU(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np)) # Zero to check that all values are filled
|
||||
prg.exec(real_bufs, force_wait=True)
|
||||
np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
|
||||
for x in opts: # Check custom transformations if any.
|
||||
check_opt(x, lambda: Linearizer(realized_ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts), Device[Device.DEFAULT].to_program)
|
||||
check_opt(x, lambda: Linearizer(realized_ast, Device[Device.DEFAULT].linearizer_opts), Device[Device.DEFAULT].to_program)
|
||||
|
||||
class TestLinearizerOpts(unittest.TestCase):
|
||||
def test_local_and_grouped_reduce(self):
|
||||
|
||||
@@ -1,18 +1,13 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest
|
||||
from typing import NamedTuple, Tuple
|
||||
from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info
|
||||
from tinygrad.helpers import DType, dtypes
|
||||
|
||||
class TestBuffer(NamedTuple):
|
||||
__test__ = False # To prevent pytest from collecting this as a test
|
||||
shape: Tuple[int, ...]
|
||||
dtype: DType
|
||||
from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info, LoadOps, MemBuffer
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.helpers import dtypes
|
||||
|
||||
class TestFlopCounter(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.buf0 = TestBuffer(shape=(4,), dtype=dtypes.float32)
|
||||
self.buf1 = TestBuffer(shape=(4,), dtype=dtypes.float32)
|
||||
self.buf0 = LazyOp(LoadOps.BUFFER, (), MemBuffer(1, dtypes.float32, (View.create((4,)),)))
|
||||
self.buf1 = LazyOp(LoadOps.BUFFER, (), MemBuffer(2, dtypes.float32, (View.create((4,)),)))
|
||||
|
||||
def test_flops_add(self):
|
||||
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
from typing import NamedTuple, Optional, List, Tuple, cast, Dict
|
||||
import itertools
|
||||
from tinygrad.ops import LazyOp, MovementOps, FlopCounter, get_lazyop_info, ReduceOps
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType
|
||||
from tinygrad.runtime.lib import buf_is_kernel_arg
|
||||
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, ReduceOps, LoadOps, MemBuffer
|
||||
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import sint
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
from tinygrad.shape.view import strides_for_shape, View
|
||||
|
||||
class LocalBuffer(NamedTuple):
|
||||
name: str
|
||||
@@ -16,6 +14,7 @@ class LocalBuffer(NamedTuple):
|
||||
def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"
|
||||
|
||||
class LinearizerOptions(NamedTuple):
|
||||
device: str = ""
|
||||
# TODO: make this generic with a list of supported types
|
||||
supports_float4: bool = True
|
||||
supports_float4_alu: bool = True
|
||||
@@ -26,41 +25,32 @@ class LinearizerOptions(NamedTuple):
|
||||
local_max: Optional[List[int]] = None
|
||||
|
||||
class Kernel:
|
||||
def __init__(self, ast:LazyOp, output_buffer:LazyBuffer, opts:Optional[LinearizerOptions]=None):
|
||||
# NOTE: if there's a RESHAPE, we skip it. the output shape is set from the reduce op or a latebuf
|
||||
self.ast = ast.src[0] if ast.op == MovementOps.RESHAPE else ast
|
||||
def __init__(self, ast:LazyOp, opts:Optional[LinearizerOptions]=None, var_vals=None):
|
||||
self.opts = opts if opts else LinearizerOptions()
|
||||
|
||||
# get the output buffers
|
||||
self.bufs = [output_buffer] + dedup(ast.buffers)
|
||||
self.arg_bufs = {x:f"data{i}" for i,x in enumerate(dedup([x.realized for x in self.bufs if buf_is_kernel_arg(x)]))}
|
||||
|
||||
# key for lookup in cache (can change, str might not be right)
|
||||
# bufs are needed because kernels like f(x) = x + x and f(x, y) = x + y have the same str(ast), but are different kernels.
|
||||
# mapping the buffers to integers is required because a-b != b-a (and how would you tell a and b apart?)
|
||||
self.key = (ast.map_buffers({x:self.arg_bufs.get(x.realized,x) for x in self.bufs}).key, tuple([x.key for x in self.bufs]))
|
||||
self.ast = ast
|
||||
self.var_vals = var_vals
|
||||
self.key = (ast, tuple(var_vals.keys())) if var_vals else ast
|
||||
|
||||
def process(self) -> None:
|
||||
if hasattr(self, "sts"): return # already processed
|
||||
|
||||
# fetch lazyop info
|
||||
self.info: FlopCounter = get_lazyop_info(cast(LazyOp, self.ast))
|
||||
self.mem_estimate: int = sum(x.dtype.itemsize*x.size for x in self.arg_bufs.keys())
|
||||
|
||||
# there's only allowed to be one reduceop
|
||||
reduceops = [x for x in self.ast.get_lazyops() if x.op in ReduceOps]
|
||||
assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast"
|
||||
self.reduceop = reduceops[0] if reduceops else None
|
||||
|
||||
# get earlybufs, before the one reduce op
|
||||
self.earlybufs = dedup(self.reduceop.buffers) if self.reduceop else []
|
||||
|
||||
# create new shapetrackers inside this kernel, we will permute them
|
||||
self.sts: List[ShapeTracker] = [x.st.copy() for x in self.bufs]
|
||||
self.bufs = [MemBuffer(0, self.info.dtype, (View.create(self.info.shape),))] + [x.arg for x in self.ast.get_lazyops() if x.op in LoadOps]
|
||||
self.sts: List[ShapeTracker] = [ShapeTracker(x.views[-1].shape, views=list(x.views)) for x in self.bufs]
|
||||
for st in self.sts: st.simplify()
|
||||
|
||||
# make the output buffer shape correct in here
|
||||
self.sts[0].reshape(self.info.shape)
|
||||
self.mem_estimate: int = sum(x.dtype.itemsize*x.views[-1].size() for x in self.bufs)
|
||||
|
||||
# get earlybufs, before the one reduce op
|
||||
self.earlybufs = [x.arg for x in self.reduceop.get_lazyops() if x.op in LoadOps] if self.reduceop else []
|
||||
self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0
|
||||
|
||||
# parameters
|
||||
@@ -77,7 +67,7 @@ class Kernel:
|
||||
|
||||
def has_variable_shape(self) -> bool:
|
||||
for b in self.bufs:
|
||||
if any(not isinstance(x, int) for x in b.st.shape): return True
|
||||
if not all_int(b.views[-1].shape): return True
|
||||
return False
|
||||
|
||||
def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
|
||||
@@ -147,6 +137,6 @@ class Kernel:
|
||||
def colored_shape(self) -> str: return ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) else s for s in self.full_shape], self.colors()))
|
||||
def printbufs(self, prefix=""):
|
||||
for i,st in enumerate(self.sts):
|
||||
print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i].realized is not None else str(self.bufs[i]):47s}", st.views)
|
||||
print(prefix, f"{i:3d} {str(self.bufs[i]):47s}", st.views)
|
||||
print(self.colored_shape())
|
||||
|
||||
|
||||
@@ -5,9 +5,8 @@ from collections import defaultdict
|
||||
from enum import Enum, auto
|
||||
|
||||
from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, all_same
|
||||
from tinygrad.ops import LazyOp, UnaryOps
|
||||
from tinygrad.ops import LazyOp, UnaryOps, LoadOps, ConstBuffer, MemBuffer
|
||||
from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
|
||||
from tinygrad.runtime.lib import RawConst
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, sym_rename
|
||||
from tinygrad.codegen.optimizer import OptimizedKernel
|
||||
@@ -128,11 +127,6 @@ def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
|
||||
return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)]
|
||||
|
||||
class Linearizer(OptimizedKernel):
|
||||
def get_buffer_name(self, i):
|
||||
if self.bufs[i].__class__ == LocalBuffer: return self.bufs[i].name
|
||||
assert self.bufs[i].realized.__class__ is not RawConst # constants shouldn't be loaded with memops
|
||||
return self.arg_bufs[self.bufs[i].realized]
|
||||
|
||||
def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype=dtypes.int32):
|
||||
render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
|
||||
return self.uop(UOps.ALU, dtype, (a, render_b), op)
|
||||
@@ -147,7 +141,7 @@ class Linearizer(OptimizedKernel):
|
||||
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
|
||||
|
||||
def global_load(self, i:int, idxs:Sequence[VariableOrNum], acc=None) -> List[UOp]:
|
||||
const = self.bufs[i].realized._buf if isinstance(self.bufs[i].realized, RawConst) else acc
|
||||
const = self.bufs[i].val if isinstance(self.bufs[i], ConstBuffer) else acc
|
||||
|
||||
expanded_nodes = [idx.expand() for idx in idxs]
|
||||
_idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
|
||||
@@ -176,7 +170,7 @@ class Linearizer(OptimizedKernel):
|
||||
idx, valid = g_idx.substitute(substitute), g_valid.substitute(substitute)
|
||||
localtype = dtypes.float32
|
||||
this_const, idx, valid = (invalid_value, Variable.num(0), Variable.num(1)) if valid.max == 0 else (const, idx, valid)
|
||||
key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else self.get_buffer_name(i)}{idx.render()}{valid.render()}"
|
||||
key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (self.bufs[i].idx if isinstance(self.bufs[i], MemBuffer) else self.bufs[i].name)}{idx.render()}{valid.render()}"
|
||||
if key not in self.load_cache:
|
||||
if acc is not None:
|
||||
assert valid.min == 1
|
||||
@@ -253,15 +247,13 @@ class Linearizer(OptimizedKernel):
|
||||
self.loop_uops: Dict[str, UOp] = {}
|
||||
|
||||
# add global buffers
|
||||
arg_bufs = {}
|
||||
for buf,name in self.arg_bufs.items():
|
||||
arg_bufs[buf] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (name, buf.dtype))
|
||||
for i,b in enumerate(self.bufs):
|
||||
if b.realized in arg_bufs: self.buf_uops[i] = arg_bufs[b.realized]
|
||||
# add variables from symbolic shapes
|
||||
for var in sorted(set(v for buf in self.ast.buffers for v in buf.var_vals), key=lambda k: k.key):
|
||||
assert var.expr is not None
|
||||
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
|
||||
for i,buf in enumerate(self.bufs):
|
||||
if isinstance(buf, MemBuffer):
|
||||
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype))
|
||||
if self.var_vals:
|
||||
for var in sorted(set(self.var_vals), key=lambda k: k.key):
|
||||
assert var.expr is not None
|
||||
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
|
||||
# define local buffers
|
||||
for lb in self.local_alias.values():
|
||||
self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size()))
|
||||
@@ -491,7 +483,8 @@ class Linearizer(OptimizedKernel):
|
||||
return self.uops[-1]
|
||||
|
||||
def ast_parse(self, x, acc, loaded_buffers, do_reduce=False) -> List[UOp]:
|
||||
if x.__class__ is not LazyOp: return loaded_buffers[x]
|
||||
if x.__class__ is not LazyOp: return loaded_buffers[x] # for LOCAL_BUFFER
|
||||
if x.op in [LoadOps.BUFFER, LoadOps.CONST]: return loaded_buffers[x.arg]
|
||||
if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, loaded_buffers) # cast isn't an ALU op
|
||||
if x.op in ReduceOps and not do_reduce: return acc
|
||||
# MULACC fusion. TODO: this is copied from Interpreted
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from typing import Tuple, List, cast, Optional
|
||||
import itertools, math, os
|
||||
from tinygrad.helpers import DEBUG, prod, getenv, ImageDType, dtypes
|
||||
from tinygrad.ops import ReduceOps, BinaryOps, UnaryOps, LazyOp
|
||||
from tinygrad.ops import ReduceOps, BinaryOps, UnaryOps, LazyOp, LoadOps
|
||||
from tinygrad.codegen.kernel import Kernel, LocalBuffer
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
|
||||
@@ -167,11 +166,11 @@ class OptimizedKernel(Kernel):
|
||||
self.simplify_ones()
|
||||
|
||||
# should use HIP tensor cores?
|
||||
if getenv("TC", 1) != 0 and self.bufs[0].device == "HIP" and self.reduceop and self.reduceop.op == ReduceOps.SUM and \
|
||||
if getenv("TC", 1) != 0 and self.opts.device == "HIP" and self.reduceop and self.reduceop.op == ReduceOps.SUM and \
|
||||
isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and \
|
||||
isinstance(self.reduceop.src[0].src[0], LazyOp) and self.reduceop.src[0].src[0].op == BinaryOps.MUL and \
|
||||
isinstance(self.reduceop.src[0].src[0].src[0], LazyBuffer) and isinstance(self.reduceop.src[0].src[0].src[1], LazyBuffer) and self.opts.has_local and \
|
||||
self.reduceop.src[0].src[0].src[0].dtype == dtypes.half and self.reduceop.src[0].src[0].src[1].dtype == dtypes.half:
|
||||
self.reduceop.src[0].src[0].src[0].op == LoadOps.BUFFER and self.reduceop.src[0].src[0].src[1].op == LoadOps.BUFFER and self.opts.has_local and \
|
||||
cast(LazyOp, self.reduceop.src[0].src[0].src[0]).arg.dtype == dtypes.half and cast(LazyOp, self.reduceop.src[0].src[0].src[1]).arg.dtype == dtypes.half:
|
||||
# HIP tensor cores are 16x16x16
|
||||
buf0 = self.bufs.index(self.reduceop.src[0].src[0].src[0])
|
||||
buf1 = self.bufs.index(self.reduceop.src[0].src[0].src[1])
|
||||
@@ -227,10 +226,10 @@ class OptimizedKernel(Kernel):
|
||||
|
||||
# should use METAL tensor cores?
|
||||
# first, confirm it's a straightforward mulacc on a device with real locals
|
||||
tensor_cores_allowed = getenv("TC", 1) != 0 and (getenv("TC", 1) == 2 or (self.bufs[0].device == "METAL" and os.uname().machine == "arm64"))
|
||||
tensor_cores_allowed = getenv("TC", 1) != 0 and (getenv("TC", 1) == 2 or (self.opts.device == "METAL" and os.uname().machine == "arm64"))
|
||||
if tensor_cores_allowed and self.reduceop and self.reduceop.op == ReduceOps.SUM and \
|
||||
isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == BinaryOps.MUL and \
|
||||
isinstance(self.reduceop.src[0].src[0], LazyBuffer) and isinstance(self.reduceop.src[0].src[1], LazyBuffer) and self.opts.has_local:
|
||||
self.reduceop.src[0].src[0].op == LoadOps.BUFFER and self.reduceop.src[0].src[1] == LoadOps.BUFFER and self.opts.has_local:
|
||||
# METAL tensor cores are 8x8x8, with 2 elements per thread in the 32 thread warp
|
||||
buf0 = self.bufs.index(self.reduceop.src[0].src[0])
|
||||
buf1 = self.bufs.index(self.reduceop.src[0].src[1])
|
||||
|
||||
@@ -21,7 +21,7 @@ def kernel_optimize_opts(k:Linearizer):
|
||||
opts.append(ng.p.TransitionChoice([(i,s,"G") for s in get_divisors(k.full_shape[k.first_reduce+i], min_div=4) if all(st.shape[k.first_reduce+i] % s == 0 or st.shape[k.first_reduce+i] == 1 for st in k.sts)]))
|
||||
return opts
|
||||
|
||||
def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], to_prg, baseline):
|
||||
def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], to_prg, baseline, bufs):
|
||||
import nevergrad as ng
|
||||
def opt(x):
|
||||
try:
|
||||
@@ -29,9 +29,9 @@ def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], to_p
|
||||
k.process()
|
||||
k.apply_auto_opt(x)
|
||||
prg = to_prg(k)
|
||||
first_tm = prg.exec(k.bufs, force_wait=True, optimizing=True)
|
||||
first_tm = prg.exec(bufs, force_wait=True, optimizing=True)
|
||||
if baseline*5 < first_tm*1000: return first_tm*1000 # very slow
|
||||
tm = min([first_tm]+[prg.exec(k.bufs, force_wait=True, optimizing=True) for _ in range(2)])*1000
|
||||
tm = min([first_tm]+[prg.exec(bufs, force_wait=True, optimizing=True) for _ in range(2)])*1000
|
||||
return tm
|
||||
except Exception:
|
||||
if DEBUG >= 3:
|
||||
@@ -51,7 +51,7 @@ def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], to_p
|
||||
|
||||
# optimization
|
||||
global_db = None
|
||||
def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], to_prg):
|
||||
def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], to_prg, bufs):
|
||||
global global_db
|
||||
|
||||
k.process()
|
||||
@@ -72,8 +72,8 @@ def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], to_prg):
|
||||
k = create_k()
|
||||
k.hand_coded_optimizations()
|
||||
prg = to_prg(k)
|
||||
return min([prg.exec(k.bufs, force_wait=True, optimizing=True) for _ in range(5)])*1000
|
||||
choice = kernel_optimize_search(k, create_k, to_prg, get_baseline())
|
||||
return min([prg.exec(bufs, force_wait=True, optimizing=True) for _ in range(5)])*1000
|
||||
choice = kernel_optimize_search(k, create_k, to_prg, get_baseline(), bufs)
|
||||
if global_db is not None:
|
||||
global_db[skey] = choice
|
||||
global_db.sync()
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
from __future__ import annotations
|
||||
import sys, operator, math
|
||||
from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast
|
||||
from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapping
|
||||
from weakref import ref, WeakSet, WeakValueDictionary
|
||||
|
||||
import numpy as np
|
||||
from tinygrad.graph import log_op
|
||||
from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition, all_int
|
||||
from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp
|
||||
from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition, all_int, dedup, merge_dicts
|
||||
from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View, get_contraction
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
|
||||
from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped, RawBufferTransfer
|
||||
from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped, RawBufferTransfer, buf_is_kernel_arg
|
||||
from tinygrad.runtime.ops_cpu import RawNumpyBuffer
|
||||
from tinygrad.runtime.ops_disk import RawDiskBuffer
|
||||
|
||||
@@ -90,6 +90,20 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp:
|
||||
ast = self.op.map_buffers(cast(Dict[LazyBuffer, Union[LazyOp, LazyBuffer, str]], real_srcs))
|
||||
return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape) if intermediate_shape != self.shape else ast
|
||||
|
||||
def _replace_loadops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
  """Rewrite the buffer leaves of `op` into buffer-free load ops.

  Every buffer in `op.buffers` is replaced by a `LazyOp` leaf:
    * `LoadOps.CONST` carrying a `ConstBuffer` when its realized buffer is a `RawConst`;
    * `LoadOps.BUFFER` carrying a `MemBuffer` when its realized buffer is a kernel argument.
  Anything else raises `NotImplementedError`.

  A top-level `MovementOps.RESHAPE` is dropped: the mapping is applied to `op.src[0]` instead.

  Returns the rewritten ast together with the deduplicated realized kernel-argument buffers.
  A `MemBuffer` with `idx == k` refers to entry `k-1` of that list.

  NOTE(review): the second element holds the `.realized` objects, not the LazyBuffers
  themselves, so the `List[LazyBuffer]` annotation looks inaccurate -- confirm.
  """
  # realized buffers that are passed to the kernel, each listed once
  realized_bufs = dedup([buf.realized for buf in op.buffers if buf_is_kernel_arg(buf)])

  replacements:Dict[LazyBuffer, LazyOp] = {}
  for buf in op.buffers:
    assert buf.realized, "buffer isn't realized"
    # simplify in place first so the captured views are the simplified ones
    buf.st.simplify()
    raw, views = buf.realized, tuple(buf.st.views)
    if isinstance(raw, RawConst):
      leaf = LazyOp(LoadOps.CONST, (), ConstBuffer(raw._buf, raw.dtype, views))
    elif raw in realized_bufs:
      # indices are shifted by one (presumably slot 0 is kept for the output buffer -- confirm)
      leaf = LazyOp(LoadOps.BUFFER, (), MemBuffer(realized_bufs.index(raw)+1, raw.dtype, views))
    else:
      raise NotImplementedError(f"not handled {buf}")
    replacements[buf] = leaf

  root = op.src[0] if op.op == MovementOps.RESHAPE else op
  return root.map_buffers(replacements), realized_bufs
|
||||
|
||||
# **** lazy operations ****
|
||||
|
||||
def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(cast(LazyBuffer, root.op.src[0])) if getattr(root, 'op', None) and len(root.op.src) == 1 else root
|
||||
@@ -162,7 +176,9 @@ class LazyBuffer:
|
||||
else:
|
||||
self.op = LazyOp(UnaryOps.CAST, (self.op,), (dtypes.float32, False))
|
||||
self.dtype = dtypes.float32
|
||||
self.realized = Device[self.device].exec_ast(self.op, output=self, **self._device_extra_args())
|
||||
self.var_vals = dict(sorted(merge_dicts([buf.var_vals for buf in self.op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
|
||||
op, realized_bufs = _replace_loadops(self.op)
|
||||
self.realized = Device[self.device].exec_ast(op, output=self, inputs=realized_bufs, var_vals=self.var_vals, **self._device_extra_args())
|
||||
|
||||
assert self.realized and isinstance(self.realized, (RawConst, Device[self.device].buffer)), f"device mismatch on realized got {type(self.realized)} expected {self.device}"
|
||||
# HACK: allow hot casting of images
|
||||
@@ -315,7 +331,7 @@ class LazyBuffer:
|
||||
|
||||
@property
|
||||
def buffers(self) -> Tuple[LazyBuffer, ...]: return (self,)
|
||||
def map_buffers(self, real_srcs: Dict[LazyBuffer, Union[LazyBuffer, LazyOp, str]]): return real_srcs.get(self, self)
|
||||
def map_buffers(self, real_srcs: Mapping[LazyBuffer, Union[LazyBuffer, LazyOp, str]]): return real_srcs.get(self, self)
|
||||
def get_lazyops(self) -> List[LazyOp]: return []
|
||||
def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
|
||||
y = self
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
import time, importlib, inspect, functools, pathlib
|
||||
from enum import Enum, auto
|
||||
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast
|
||||
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, dedup, merge_dicts
|
||||
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast, Mapping
|
||||
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored
|
||||
from tinygrad.shape.view import View
|
||||
from dataclasses import dataclass
|
||||
if TYPE_CHECKING: from tinygrad.lazy import LazyBuffer
|
||||
|
||||
# these are the llops your accelerator must implement, along with toCpu
|
||||
@@ -14,11 +16,23 @@ class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); M
|
||||
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
|
||||
class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
|
||||
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
|
||||
class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
|
||||
class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); BUFFER = auto() # noqa: E702
|
||||
|
||||
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps]
|
||||
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps]]
|
||||
|
||||
@dataclass(frozen=True)
class MemBuffer:
  """Immutable description of a memory-backed buffer, used as the `arg` of a `LoadOps.BUFFER` LazyOp."""
  # position of the buffer among the kernel's buffers
  idx: int
  # element type of the buffer
  dtype: DType
  # the views through which the buffer is accessed
  views: Tuple[View, ...]
|
||||
|
||||
@dataclass(frozen=True)
class ConstBuffer:
  """Immutable description of a constant-valued buffer, used as the `arg` of a `LoadOps.CONST` LazyOp."""
  # the constant value itself
  val: Any
  # element type of the constant
  dtype: DType
  # the views through which the constant is accessed
  views: Tuple[View, ...]
|
||||
|
||||
class LazyOp:
|
||||
__slots__ = "op", "src", "arg", "buffers", "__weakref__"
|
||||
op: Op
|
||||
@@ -37,7 +51,7 @@ class LazyOp:
|
||||
# Structural key for this op: (op, keys of the sources, key of the arg).
# A source or arg that has no `key` attribute is used directly as its own key.
@property
|
||||
def key(self): return (self.op, tuple(map(lambda x: getattr(x, "key", x), self.src)), getattr(self.arg, "key", self.arg))
|
||||
|
||||
# Rebuild this node with map_buffers applied to every source; op and arg are kept as they are.
# real_srcs is only forwarded here, never read directly.
# NOTE(review): the definition appears twice; the two lines differ only in annotating real_srcs
# as Dict vs Mapping. They look like the before/after rows of a diff - confirm which is intended.
def map_buffers(self, real_srcs: Dict[LazyBuffer, Union[LazyBuffer, LazyOp, str]]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg)
|
||||
def map_buffers(self, real_srcs: Mapping[LazyBuffer, Union[LazyBuffer, LazyOp, str]]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg)
|
||||
# Flatten the tree: this node first, followed by whatever each source's get_lazyops() returns, in source order.
def get_lazyops(self) -> List[LazyOp]: return [self] + [item for x in self.src for item in x.get_lazyops()]
|
||||
|
||||
def replace_with_movement_ops(self:LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer':
|
||||
@@ -87,9 +101,8 @@ Device = _Device()
|
||||
|
||||
# **************** for Interpreted Buffers ****************
|
||||
|
||||
def apply_shapetracker(fxn_for_op, ret, st):
|
||||
st.simplify() # TODO: this is generic for Compiled too
|
||||
for v in st.views:
|
||||
def apply_shapetracker(fxn_for_op, ret, views):
|
||||
for v in views:
|
||||
real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
|
||||
real_offset = v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
|
||||
# first, we apply the offset
|
||||
@@ -111,20 +124,22 @@ def apply_shapetracker(fxn_for_op, ret, st):
|
||||
return ret
|
||||
|
||||
class Interpreted:
|
||||
# Store the backend hooks of an interpreted device: the buffer type, the op -> function table,
# and the converters between buffers and the values the functions operate on.
# NOTE(review): two signatures appear back to back; the second drops the from_lazybuffer parameter.
# They look like the before/after rows of a diff - confirm which one is intended.
def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_lazybuffer=None, to_underlying=lambda x: x._buf, from_underlying=None):
|
||||
def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], to_underlying=lambda x: x._buf, from_underlying=None):
|
||||
self.buffer, self.fxn_for_op, self.to_underlying = buffer, fxn_for_op, to_underlying
|
||||
# with no explicit from_underlying, results are wrapped by calling `buffer` itself
self.from_underlying = buffer if from_underlying is None else from_underlying
|
||||
# NOTE(review): this line reads from_lazybuffer, which only the first signature above defines
self.from_lazybuffer = from_lazybuffer if from_lazybuffer is not None else lambda x: self.from_underlying(apply_shapetracker(self.fxn_for_op, self.to_underlying(x.realized), x.st))
|
||||
# synchronize is a no-op and codegen is left unset
self.synchronize = lambda: None
|
||||
self.codegen = None
|
||||
|
||||
def exec_ast(self, ast:LazyOp, output=None, context=None, **kwargs):
|
||||
def exec_ast(self, ast:LazyOp, output=None, inputs=None, var_vals=None, context=None, **kwargs):
|
||||
if ast.op == LoadOps.BUFFER and LoadOps.BUFFER not in self.fxn_for_op:
|
||||
assert inputs[ast.arg.idx-1].dtype == ast.arg.dtype, "dtype mismatch"
|
||||
return self.from_underlying(apply_shapetracker(self.fxn_for_op, self.to_underlying(inputs[ast.arg.idx-1]), ast.arg.views))
|
||||
if TernaryOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
|
||||
ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
|
||||
created_context = context is None
|
||||
if context is None: context = dict()
|
||||
if not created_context and ast in context: return context[ast]
|
||||
srcs = [self.exec_ast(x, context=context, **kwargs) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src]
|
||||
srcs = [self.exec_ast(cast(LazyOp, x), inputs=inputs, context=context, **kwargs) for x in ast.src]
|
||||
if DEBUG >= 3: st = time.perf_counter()
|
||||
ret = self.from_underlying(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
|
||||
if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op: ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.to_underlying(ret), (output.dtype, False))) # Do manual casting of ret if it does not match the required output dtype.
|
||||
@@ -147,17 +162,18 @@ class FlopCounter:
|
||||
self.flops, ret = 0, self.flops
|
||||
return ret
|
||||
# Handler table used to run an AST through Interpreted with FlopCounter as the buffer type.
# Every handler returns a (shape, dtype, flops) tuple.
shape_fxn_for_op: Dict[Op, Callable] = {
|
||||
# leaves: the shape comes from the last view of the arg, and they start with 0 flops
LoadOps.BUFFER: lambda arg: (arg.views[-1].shape, arg.dtype, 0), LoadOps.CONST: lambda arg: (arg.views[-1].shape, arg.dtype, 0),
|
||||
# the result dtype is arg[0]; only the flops already accumulated by the input are carried over
UnaryOps.CAST: lambda self,arg: (self.shape, arg[0], self.consume_flops()), # cast uses no flops
|
||||
# elementwise ops add prod(self.shape) flops on top of what their inputs accumulated
**{op:lambda self: (self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps if op != UnaryOps.CAST},
|
||||
# binary ops take the larger of the two input dtypes
**{op:lambda self,y: (self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},
|
||||
# reduces take the result shape from the arg but are charged prod of the input shape
**{op:lambda self,new_shape: (new_shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in ReduceOps},
|
||||
TernaryOps.WHERE: lambda self,y,z: (self.shape, self.dtype, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape))}
|
||||
# NOTE(review): InterpretedFlopCounter is bound twice in a row. The first call passes four positional
# arguments (matching an __init__ that still has from_lazybuffer), the second passes three, so the
# identity lambda lands on to_underlying. They look like before/after rows of a diff - confirm which is intended.
InterpretedFlopCounter = Interpreted(FlopCounter, shape_fxn_for_op, lambda x: FlopCounter((x.shape, x.dtype, 0)), lambda x: x)
|
||||
InterpretedFlopCounter = Interpreted(FlopCounter, shape_fxn_for_op, lambda x: x)
|
||||
# Run the AST through the flop-counting interpreter and return the resulting FlopCounter.
# exec_ast is called with the AST only (no output, inputs or context).
def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.exec_ast(ast)
|
||||
|
||||
# **************** for Compiled Buffers ****************
|
||||
|
||||
from tinygrad.runtime.lib import RawBuffer, RawConst, buf_is_kernel_arg
|
||||
from tinygrad.runtime.lib import RawBuffer, RawConst
|
||||
from tinygrad.shape.symbolic import Variable, sym_infer
|
||||
|
||||
class ASTRunner:
|
||||
@@ -169,9 +185,8 @@ class ASTRunner:
|
||||
self.clprg = runtime(self.name, self.prg, **self.runtime_args)
|
||||
return self
|
||||
|
||||
# Record this launch with the CacheCollector (unless optimizing) and then run the program
# by calling self(rawbufs, var_vals, force_wait=...); returns whatever that call returns.
# NOTE(review): two signatures appear back to back; the first takes `bufs`, the second takes
# `rawbufs` directly. They look like the before/after rows of a diff - confirm which is intended.
def exec(self, bufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False, optimizing=False) -> Optional[float]:
|
||||
def exec(self, rawbufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False, optimizing=False) -> Optional[float]:
|
||||
# imported inside the function, presumably to avoid a circular import with tinygrad.jit - TODO confirm
from tinygrad.jit import CacheCollector
|
||||
# NOTE(review): this line reads `bufs`, which only the first signature above defines
rawbufs = dedup([x.realized for x in bufs if buf_is_kernel_arg(x)])
|
||||
# launches made while optimizing are not recorded; a missing var_vals is recorded as {}
if not optimizing: CacheCollector.add(self, rawbufs, var_vals if var_vals is not None else {})
|
||||
return self(rawbufs, var_vals, force_wait=force_wait)
|
||||
|
||||
@@ -205,33 +220,38 @@ class Compiled:
|
||||
op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
|
||||
display_name=k.display_name, runtime_args={"binary": False}).build(self.runtime)
|
||||
|
||||
def exec_ast(self, ast:LazyOp, output, **kwargs):
|
||||
def exec_ast(self, ast:LazyOp, output, inputs, var_vals, **kwargs):
|
||||
#if DEBUG >= 4:
|
||||
# from extra.utils import print_tree
|
||||
# print_tree(ast)
|
||||
|
||||
# check if we can reuse the output buffer
|
||||
# if it's aliased, don't use it
|
||||
# NOTE: this is pretty wrong actually, who knows where else this buffer is used?
|
||||
output.realized = output.output_buffer
|
||||
if output.realized:
|
||||
if output.realized.__class__ is RawConst: output.realized = None # can't assign to RawConst
|
||||
for a in ast.buffers:
|
||||
if a.realized == output.realized and not a.st.contiguous:
|
||||
output.realized = None
|
||||
break
|
||||
for i,a in enumerate(inputs):
|
||||
# TODO: if this is contiguous it's fine
|
||||
if a == output.realized:
|
||||
views = [x.arg.views for x in ast.get_lazyops() if x.op == LoadOps.BUFFER and x.arg.idx == i+1]
|
||||
if any(len(v) > 1 or not v[0].contiguous for v in views):
|
||||
output.realized = None
|
||||
break
|
||||
|
||||
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
|
||||
if not output.realized: output.realized = self.buffer(prod((s if isinstance(s, int) else s.max for s in output.shape)), output.dtype, **kwargs)
|
||||
else:
|
||||
from tinygrad.jit import CacheCollector
|
||||
CacheCollector._mark_output_buffer(output.output_buffer)
|
||||
# update the output var_vals from src
|
||||
output.var_vals = dict(sorted(merge_dicts([buf.var_vals for buf in ast.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
|
||||
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
k = Linearizer(ast, output, self.linearizer_opts)
|
||||
k = Linearizer(ast, self.linearizer_opts, var_vals)
|
||||
|
||||
# compilation time
|
||||
def get_program():
|
||||
from tinygrad.codegen.search import kernel_optimize
|
||||
if getenv("KOPT"): kernel_optimize(k, lambda: Linearizer(ast, output, self.linearizer_opts), self.to_program)
|
||||
if getenv("KOPT"): kernel_optimize(k, lambda: Linearizer(ast, self.linearizer_opts, var_vals), self.to_program, [output.realized]+inputs)
|
||||
elif not getenv("NOOPT"): k.hand_coded_optimizations()
|
||||
return self.to_program(k)
|
||||
|
||||
@@ -243,5 +263,5 @@ class Compiled:
|
||||
|
||||
if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
|
||||
|
||||
prg.exec(k.bufs, var_vals=output.var_vals)
|
||||
prg.exec([output.realized]+inputs, var_vals=var_vals)
|
||||
return output.realized
|
||||
|
||||
@@ -17,7 +17,7 @@ class RawBuffer: # pylint: disable=abstract-method
|
||||
# Finalizer: subtract this buffer's _memsz from GlobalCounters.mem_used and, if an allocator
# is attached, hand the underlying _buf back to it.
def __del__(self): # NOTE: if it fails on init (bad dtype), it won't have a _memsz
|
||||
if hasattr(self, '_memsz'): GlobalCounters.mem_used -= self._memsz
|
||||
# guarded with hasattr as well, and skipped when _allocator is falsy
if hasattr(self, '_allocator') and self._allocator: self._allocator.free(self._buf)
|
||||
# Debug representation showing size and dtype.
# NOTE(review): defined twice in a row; the second form additionally prints id(self), which
# distinguishes buffers of equal size and dtype. They look like before/after rows of a diff - confirm which is intended.
def __repr__(self): return f"buffer<{self.size}, {self.dtype}>"
|
||||
def __repr__(self): return f"buffer<{self.size}, {self.dtype}, {id(self)}>"
|
||||
# Key made of (size, dtype) only, so two distinct buffers with the same size and dtype share a key.
@property
|
||||
def key(self): return (self.size, self.dtype)
|
||||
|
||||
|
||||
@@ -93,4 +93,4 @@ __device__ void vstore_half4(float4 data, size_t offset, half *p) { *(p + offset
|
||||
""",
|
||||
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)],
|
||||
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)]))
|
||||
HIPBuffer = Compiled(RawHIPBuffer, LinearizerOptions(), renderer, HIPProgram, hip.hipDeviceSynchronize)
|
||||
HIPBuffer = Compiled(RawHIPBuffer, LinearizerOptions(device="HIP"), renderer, HIPProgram, hip.hipDeviceSynchronize)
|
||||
|
||||
@@ -85,4 +85,4 @@ renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
|
||||
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);", float4 = "float4", uses_ptr_arithmetic=True,
|
||||
gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)],
|
||||
extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']))
|
||||
MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(), renderer, MetalProgram, METAL.synchronize)
|
||||
MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), renderer, MetalProgram, METAL.synchronize)
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
import functools
|
||||
from typing import Tuple, List, Optional, NamedTuple
|
||||
from tinygrad.helpers import prod, all_int
|
||||
from tinygrad.shape.symbolic import NumNode, is_sym_int, sint
|
||||
from tinygrad.shape.symbolic import Node, NumNode, is_sym_int, sint
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
@@ -29,7 +29,7 @@ class View(NamedTuple):
|
||||
return View(shape, strides, offset, mask, contiguous)
|
||||
|
||||
# Product of the dimensions whose stride is non-zero (stride-0 dimensions are left out of the product).
# The result is memoized with an unbounded lru_cache keyed on the View itself.
# NOTE(review): defined twice in a row; the second form substitutes s.max for any dimension that is
# a symbolic Node. They look like before/after rows of a diff - confirm which is intended.
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def size(self): return prod([s for s,st in zip(self.shape, self.strides) if st != 0])
|
||||
def size(self): return prod([s.max if isinstance(s, Node) else s for s,st in zip(self.shape, self.strides) if st != 0])
|
||||
|
||||
# MovementOps live here now
|
||||
|
||||
|
||||
Reference in New Issue
Block a user