that had bugs, force an order (#2411)

This commit is contained in:
George Hotz
2023-11-23 15:52:16 -08:00
committed by GitHub
parent 65f4e6971b
commit 193be14b6c
5 changed files with 11 additions and 14 deletions

View File

@@ -55,7 +55,7 @@ class TestSymbolicJit(unittest.TestCase):
symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
expected = f(a, b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 2
assert len(jf.jit_cache) == 2 or getattr(Device[Device.DEFAULT], "graph", None)
def test_attention(self):
def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
@@ -68,7 +68,7 @@ class TestSymbolicJit(unittest.TestCase):
symbolic = jf(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy()
expected = f(q, k, v).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 6
assert len(jf.jit_cache) == 6 or getattr(Device[Device.DEFAULT], "graph", None)
def test_cat_dim0(self):
def f(a, b): return a.cat(b, dim=0).realize()

View File

@@ -169,7 +169,7 @@ class Linearizer(Kernel):
if isinstance(buf, MemBuffer):
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype))
# add var vals
for var in sorted(vars_from_ast(self.ast)):
for var in vars_from_ast(self.ast):
assert var.expr is not None
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
# define local buffers

View File

@@ -79,7 +79,8 @@ def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer:
  """Follow the first-source chain upward from `root` and return where it stops.

  A buffer is stepped over (replaced by `op.src[0]`) only while it is not
  realized and either:
    * its `optype` equals `MovementOps`, or
    * `allow_contiguous` is truthy, its op is `LoadOps.CONTIGUOUS`, and its
      first source's `st.contiguous` is truthy.
  The first buffer that fails these conditions is returned (possibly `root` itself).
  """
  node = root
  while not node.realized:
    if not (node.optype == MovementOps):
      # not a movement op: only a CONTIGUOUS load over a contiguous source may be skipped
      skippable = node.op.op == LoadOps.CONTIGUOUS and allow_contiguous and node.op.src[0].st.contiguous
      if not skippable: break
    node = cast(LazyBuffer, node.op.src[0])
  return node
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer:
  """Strip leading unrealized CONTIGUOUS loads, then resolve the movement root.

  While the current buffer is not realized and its op is `LoadOps.CONTIGUOUS`,
  move to its first source. Once that stops, if the buffer's `optype` equals
  `MovementOps` and its `st.contiguous` is truthy, return
  `get_movementroot(buf, True)`; otherwise return the buffer unchanged.
  """
  buf = x
  # peel off the chain of unrealized CONTIGUOUS ops
  while not buf.realized and buf.op.op == LoadOps.CONTIGUOUS:
    buf = cast(LazyBuffer, buf.op.src[0])
  if buf.optype == MovementOps and buf.st.contiguous:
    return get_movementroot(buf, True)
  return buf
def vars_from_ast(ast:LazyOp) -> Set[Variable]: return set.union(*[x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], set())
# NOTE: this is the canonical order
def vars_from_ast(ast:LazyOp) -> List[Variable]:
  """Return every variable used by the buffer ops of `ast`, in a fixed order.

  The `st.vars()` sets of all lazyops whose `op` is in `BufferOps` are unioned
  (an AST with no such ops yields an empty list), and the result is sorted by
  the string form of each variable's `expr`.
  """
  per_buffer = [lop.arg.st.vars() for lop in ast.get_lazyops() if lop.op in BufferOps]
  # the trailing empty set keeps set.union valid when there are no buffer ops
  merged = set.union(*per_buffer, set())
  return sorted(merged, key=lambda v: str(v.expr))
lazycache: WeakValueDictionary = WeakValueDictionary()
def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, base:Optional[LazyBuffer]=None):

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import importlib, inspect, functools, pathlib, time, re
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, Set
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT, dedup, all_int
from tinygrad.runtime.lib import RawBuffer
from tinygrad.shape.symbolic import Variable, sym_infer, sint
@@ -184,10 +184,8 @@ class Interpreted:
def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
if ast not in self.method_cache: self.method_cache[ast] = get_interpreted_fxn(self.fxn_for_op, ast)
rawbufs = [output.realized if output.realized is not None else output.output_buffer] + [x.realized for x in inputs]
if rawbufs[0] is None: rawbufs[0] = self.buffer.__new__(self.buffer)
self.method_cache[ast].exec(rawbufs, var_vals)
output.realized = rawbufs[0]
output.realized = output.output_buffer if output.output_buffer is not None else self.buffer.__new__(self.buffer)
self.method_cache[ast].exec([output.realized] + [x.realized for x in inputs], var_vals)
def get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> InterpretedASTRunner:
if DEBUG >= 3:
@@ -236,7 +234,7 @@ class CompiledASTRunner(JITRunner):
if DEBUG >= 4: print(prg)
self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = \
name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
self.vars: Set[Variable] = set()
self.vars: List[Variable] = []
if ast:
info = get_lazyop_info(ast)
self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
@@ -255,8 +253,6 @@ class CompiledASTRunner(JITRunner):
return global_size, local_size
def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
# filter the var_vals
var_vals = {k:var_vals[k] for k in sorted(self.vars)}
global_size, local_size = self.launch_dims(var_vals)
if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type]
# TODO: this is copied from get_program
@@ -266,7 +262,7 @@ class CompiledASTRunner(JITRunner):
lra = self.runtime_args.copy()
if global_size: lra['global_size'] = global_size
if local_size and 'local_size' not in lra: lra['local_size'] = local_size
et = self.clprg(*rawbufs, *var_vals.values(), **lra, wait=wait or DEBUG>=2)
et = self.clprg(*rawbufs, *[var_vals[k] for k in self.vars], **lra, wait=wait or DEBUG>=2)
update_stats(self.display_name if self.display_name is not None else self.name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra)
return et

View File

@@ -114,7 +114,7 @@ class MetalGraph:
icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i)
if i == 0: write_resources.append(b._buf)
else: read_resources.append(b._buf)
var_vals_keys = sorted(var_vals.keys())
var_vals_keys = list(var_vals.keys())
for i,v in enumerate(prg.vars):
icb_command.setKernelBuffer_offset_atIndex_(self.int_buf._buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
global_size, local_size = prg.launch_dims(var_vals)