no shapetracker in ops [run_process_replay] (#6117)

George Hotz
2024-08-16 17:23:27 -07:00
committed by GitHub
parent 74ee9febec
commit 89c7989659
8 changed files with 54 additions and 53 deletions

View File

@@ -51,11 +51,11 @@ b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struc
# describe the computation
buf_1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 1)
buf_2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 2)
-ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, *UOp.from_st(ShapeTracker.from_shape((1,)))))
-ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, *UOp.from_st(ShapeTracker.from_shape((1,)))))
+ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, *ShapeTracker.from_shape((1,)).to_uops()))
+ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, *ShapeTracker.from_shape((1,)).to_uops()))
alu = ld_1 + ld_2
output_buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0)
-idx, valid = UOp.from_st(ShapeTracker.from_shape((1,)))
+idx, valid = ShapeTracker.from_shape((1,)).to_uops()
st_0 = UOp(UOps.STORE, None, (output_buf, idx, alu, valid))
s = UOp(UOps.SINK, None, (st_0,))
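A minimal sketch (not part of the diff) of what the new call form hands back at this call site: the ShapeTracker-to-UOp conversion now lives on ShapeTracker itself, and it still produces the (ST_IDX, ST_VALID) pair that the old UOp.from_st built.

# sketch, not part of the commit: the return value of ShapeTracker.to_uops()
from tinygrad.dtype import dtypes
from tinygrad.ops import UOps
from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((1,))
idx, valid = st.to_uops()   # replaces the removed UOp.from_st(st)
assert idx.op is UOps.ST_IDX and idx.dtype is dtypes.pyint and idx.arg is st
assert valid.op is UOps.ST_VALID and valid.dtype is dtypes.bool and valid.arg is st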

View File

@@ -4,7 +4,7 @@ import functools, hashlib
from enum import Enum, auto
from dataclasses import dataclass
from tinygrad.helpers import dedup, pretty_print, prod
-from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, reduce_st, UOp, UOps
+from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, UOp, UOps
from tinygrad.dtype import ImageDType, PtrDType, dtypes, DType, ConstType
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.shapetracker import ShapeTracker
@@ -86,7 +86,7 @@ def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
if op.op in ReduceOps:
axis = op.arg[-1] if op.op is ReduceOps.WMMA else op.arg
assert isinstance(axis, tuple) and all(isinstance(i, int) for i in axis), f"reduceop must have axis {op.arg}"
-st = ShapeTracker.from_shape(reduce_st(sts[op.src[0]], axis))
+st = ShapeTracker.from_shape(sts[op.src[0]].reduce(axis))
else:
# movementops are pushed to the edges with LOAD
# elementwise inherits shape
@@ -115,7 +115,7 @@ def to_uop(*a) -> UOp:
@functools.lru_cache(None)
def create_uop(lop:LazyOp) -> UOp:
if lop.op in BufferOps:
-idx, valid = UOp.from_st(lop.arg.st)
+idx, valid = lop.arg.st.to_uops()
membuf_dtype: DType = lop.arg.dtype
dtype = membuf_dtype.base if isinstance(membuf_dtype, ImageDType) else membuf_dtype
if lop.op is BufferOps.CONST:
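The .reduce(axis) calls in this file only compute the post-reduce shape. A small sketch of the behavior, matching the one-line implementation added to shapetracker.py at the bottom of this commit:

# sketch of ShapeTracker.reduce: reduced axes collapse to 1, the result is a plain shape tuple
from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((3, 4, 5))
assert st.reduce((1,)) == (3, 1, 5)
assert st.reduce((0, 2)) == (1, 4, 1)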

View File

@@ -8,9 +8,9 @@ from typing import List, Optional, Union, cast
from tinygrad import nn, dtypes
from tinygrad.device import Device
from tinygrad.tensor import Tensor
-from tinygrad.ops import BinaryOps, MetaOps, UnaryOps, UOps, verify_ast
+from tinygrad.ops import BinaryOps, MetaOps, UnaryOps, UOps
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP
-from tinygrad.codegen.kernel import Kernel
+from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
from test.helpers import is_dtype_supported, Context

View File

@@ -4,7 +4,7 @@ from dataclasses import dataclass, replace
from collections import defaultdict
from typing import Literal, Optional, List, Tuple, Union, cast, Dict, Final, DefaultDict
-from tinygrad.ops import BinaryOps, ReduceOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, verify_ast, print_uops
+from tinygrad.ops import BinaryOps, ReduceOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, print_uops
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, Program
from tinygrad.dtype import DType, ImageDType, PtrDType
@@ -638,7 +638,7 @@ class Kernel:
if op.op in BUFFER_UOPS:
# for locals, we use the ShapeTracker that's in the srcs
st = op.src[-1].arg if op.src[0].op is UOps.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
-idx, valid = UOp.from_st(st if apply_to_st is None else apply_to_st(st))
+idx, valid = (st if apply_to_st is None else apply_to_st(st)).to_uops()
if op.op is UOps.CONST: return replace(op, src=(valid,))
if op.op is UOps.STORE: return replace(op, src=(op.src[0], idx, fixup_ast(op.src[2], apply_to_st), valid))
return replace(op, src=tuple(fixup_ast(x, apply_to_st) for x in op.src[:-2])+(idx, valid))
@@ -697,7 +697,7 @@ class Kernel:
for i,(src,fix_st_fxn) in enumerate(zip(rsrc.src, [fix_st1, fix_st2])):
st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD]
local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
-idx, valid = UOp.from_st(ShapeTracker.from_shape(local_shape).expand(ex_shape))
+idx, valid = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uops()
membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in), (), (f"temp{-(-1-i)}", idx.arg.real_size()))
local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, idx, src, valid)), fix_st_fxn)
srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, local_store, idx, valid)))
@@ -714,7 +714,7 @@ class Kernel:
start = UOp(UOps.REDUCE_AXIS, op.dtype, (fixup_ast(op.src[0], apply_to_st),), arg=(op.arg[0], axis))
local_shape = (1,) * self.global_dims + self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces] + \
(1,) * (self.first_upcast - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
-idx, valid = UOp.from_st(ShapeTracker.from_shape(local_shape))
+idx, valid = ShapeTracker.from_shape(local_shape).to_uops()
local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(cast(DType, op.dtype)), (), ("temp1", idx.arg.real_size()))
local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, UOp.store(local_buffer, idx, start, valid), idx, valid))
second_axis = tuple(range(self.first_reduce, self.first_reduce+self.group_for_reduces))
@@ -760,3 +760,30 @@ class Kernel:
key=lambda x: (x.op, x.src[0].arg)))
return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
+# the living definition of UOps.ST_IDX and UOps.ST_VALID
+def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
+assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
+sts: Dict[UOp, ShapeTracker] = {}
+def assert_valid(op:UOp, st:ShapeTracker):
+if op in sts or op.op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return
+# restore globals from the two stage reduce
+if op.op is UOps.LOAD and op.src[0].op is UOps.DEFINE_LOCAL:
+assert_valid(local_reduce:=op.src[1].src[2], op.src[-1].arg)
+return sts.setdefault(op, sts[local_reduce])
+for x in op.src: assert_valid(x, st)
+# only reduceop is allowed to change shape, limited to turning n to 1
+if op.op is UOps.REDUCE_AXIS: st = ShapeTracker.from_shape(sts[op.src[0]].reduce(op.arg[1][-1] if op.arg[0] is ReduceOps.WMMA else op.arg[1]))
+else:
+# movementops are pushed to the edges with ST_IDX, ST_VALID
+# elementwise inherits shape
+st = op.arg if op.op in {UOps.ST_IDX, UOps.ST_VALID} else sts[op.src[-1]]
+for x in (op.src[1:] if op.op in BUFFER_UOPS else op.src):
+if sts[x].shape != st.shape:
+if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {st.shape}")
+raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op.op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
+sts[op] = st
+for out in ast.src: assert_valid(out, out.src[-1].arg)
+shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
+assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
+return sts
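A minimal sketch (not in the commit) of calling the relocated verify_ast. It reuses the UOp graph from the abstractions2.py hunk at the top of this commit and imports verify_ast from its new home, matching the test import change above.

# sketch: verify_ast now lives in tinygrad.codegen.kernel; check the SINK built in the first hunk
from tinygrad.codegen.kernel import verify_ast
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.ops import UOp, UOps
from tinygrad.shape.shapetracker import ShapeTracker

buf_1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 1)
buf_2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 2)
ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, *ShapeTracker.from_shape((1,)).to_uops()))
ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, *ShapeTracker.from_shape((1,)).to_uops()))
output_buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0)
idx, valid = ShapeTracker.from_shape((1,)).to_uops()
sink = UOp(UOps.SINK, None, (UOp(UOps.STORE, None, (output_buf, idx, ld_1 + ld_2, valid)),))
sts = verify_ast(sink)   # Dict[UOp, ShapeTracker]; raises on implicit reshape/expand
assert all(st.shape == (1,) for st in sts.values())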

View File

@@ -2,7 +2,7 @@ import sys, pickle, atexit, importlib, contextlib
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Tuple, List, Dict, Optional, Set, DefaultDict, get_args
-from tinygrad.ops import MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, reduce_st, UOp, UOps
+from tinygrad.ops import MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, UOp, UOps
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, \
GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
@@ -134,12 +134,12 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
if buf.srcs[0] is not buf.srcs[0].base and buf.srcs[0].base is top_reduce[0] and buf.op is top_reduce[0].op:
# merge this reduce with its parent
new_st = top_reduce[1]+st
-top_reduce = (top_reduce[0], new_st.reshape(reduce_st(top_reduce_input_st, new_axis:=axis+top_reduce_axes)))
+top_reduce = (top_reduce[0], new_st.reshape(top_reduce_input_st.reduce(new_axis:=axis+top_reduce_axes)))
reduce_info[top_reduce] = (top_reduce_input_st, new_axis)
return None
# reshape this reduceop based on the top reduce
input_st = input_st.reshape(tuple(1 if i in top_reduce_axes else s for i,s in enumerate(top_reduce_input_st.shape)))
-st = st.reshape(reduce_st(input_st, axis))
+st = st.reshape(input_st.reduce(axis))
reduce_info[(buf, st)] = (input_st, axis)
return (buf, st)
return cache.setdefault((buf, st), top_reduce)
@@ -147,7 +147,7 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem:
"""describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
-idx, valid = UOp.from_st(ShapeTracker.from_shape(out.arg))
+idx, valid = ShapeTracker.from_shape(out.arg).to_uops()
rd = UOp(UOps.LOAD, dtypes.uint8, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.uint8), (), 1), idx, valid))
wr = UOp(UOps.STORE, None, (UOp(UOps.DEFINE_GLOBAL, PtrDType(out.dtype), (), 0), idx, rd, valid))
return LBScheduleItem(UOp(UOps.SINK, None, (wr,)), outs, [x.base for x in out.srcs])
@@ -172,7 +172,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
ast: List[UOp] = []
inputs: Dict[LazyBuffer, int] = {}
for i, out in enumerate(outs):
-output_st = ShapeTracker.from_shape(reduce_st(*deque(reduce_info.values(), 1).pop()) if reduce_info else out.shape)
+output_st = ShapeTracker.from_shape(ShapeTracker.reduce(*deque(reduce_info.values(), 1).pop()) if reduce_info else out.shape)
src = _recursive_uop(out, output_st, tuple(outs), var_vals, inputs, realizes, assign_targets, reduce_info, cache=cache)
if out.op is MetaOps.ASSIGN and out.arg:
assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
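The last hunk in this file calls reduce as an unbound method because reduce_info stores (input_st, axis) pairs; unpacking a pair into ShapeTracker.reduce is the same call as input_st.reduce(axis). A tiny sketch, not part of the commit:

# sketch: ShapeTracker.reduce(*pair) is equivalent to pair[0].reduce(pair[1])
from tinygrad.shape.shapetracker import ShapeTracker

pair = (ShapeTracker.from_shape((8, 16)), (1,))
assert ShapeTracker.reduce(*pair) == pair[0].reduce(pair[1]) == (8, 1)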

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Union, Optional, Any, Tuple, List, get_args
from tinygrad.dtype import dtypes, DType, DTypeLike, ConstType, to_dtype
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP
-from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, reduce_st
+from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
from tinygrad.shape.symbolic import sint, Variable
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer
@@ -171,10 +171,10 @@ class LazyBuffer:
assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
axis = tuple(sorted([x for x in axis if self.shape[x] != 1]))
if len(axis) == 0: return self
-return create_lazybuffer(self.device, ShapeTracker.from_shape(reduce_st(self.st, axis)), self.dtype, op, axis, (self,))
+return create_lazybuffer(self.device, ShapeTracker.from_shape(self.st.reduce(axis)), self.dtype, op, axis, (self,))
def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
-new_shape = reduce_st(self.st, axis)
+new_shape = self.st.reduce(axis)
# TODO: this logic should move to the scheduler
if 0 in self.shape and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: dtypes.min(self.dtype)}[op], new_shape)
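A tiny end-to-end check, as a sketch rather than part of the commit: a Tensor reduce goes through LazyBuffer.r above, where new_shape now comes from self.st.reduce(axis).

# sketch: Tensor.sum reaches LazyBuffer.r; the reduce shape is (4, 1) there before the Tensor-level reshape to (4,)
from tinygrad import Tensor

t = Tensor.ones(4, 5).sum(axis=1)
assert t.shape == (4,)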

View File

@@ -5,9 +5,8 @@ import math, operator, ctypes, struct, functools, hashlib, itertools
from enum import Enum, auto
from dataclasses import dataclass
from tinygrad.dtype import ConstType, dtypes, DType
-from tinygrad.helpers import dedup, merge_dicts, pretty_print, prod
+from tinygrad.helpers import merge_dicts, pretty_print, prod
from tinygrad.shape.symbolic import Variable, sint
-from tinygrad.shape.shapetracker import ShapeTracker
# these are the llops your accelerator must implement, along with toCpu
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
@@ -74,8 +73,6 @@ truncate: Dict[DType, Callable] = {dtypes.bool: bool,
def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
-def reduce_st(st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(st.shape))
# the order of these UOps controls the order of the toposort
class UOps(Enum):
# ops that aren't rendered
@@ -163,8 +160,6 @@ class UOp:
def parents(self) -> Dict[UOp, None]: return merge_dicts([{x:None for x in self.src}]+[x.parents for x in self.src])
@property # parents with self
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
-@staticmethod
-def from_st(st:ShapeTracker) -> Tuple[UOp, UOp]: return UOp(UOps.ST_IDX, dtypes.pyint, (), st), UOp(UOps.ST_VALID, dtypes.bool, (), st)
@functools.cached_property
def full_shape(self) -> Tuple[sint, ...]:
if self.op in {UOps.ST_IDX, UOps.ST_VALID}: return self.arg.shape
@@ -374,30 +369,3 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
assert u.arg[1] is not None
flops += 2 * prod(u.arg[1]) // 32 * mults
return flops, mem
-# the living definition of UOps.ST_IDX and UOps.ST_VALID
-def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
-assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
-sts: Dict[UOp, ShapeTracker] = {}
-def assert_valid(op:UOp, st:ShapeTracker):
-if op in sts or op.op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return
-# restore globals from the two stage reduce
-if op.op is UOps.LOAD and op.src[0].op is UOps.DEFINE_LOCAL:
-assert_valid(local_reduce:=op.src[1].src[2], op.src[-1].arg)
-return sts.setdefault(op, sts[local_reduce])
-for x in op.src: assert_valid(x, st)
-# only reduceop is allowed to change shape, limited to turning n to 1
-if op.op is UOps.REDUCE_AXIS: st = ShapeTracker.from_shape(reduce_st(sts[op.src[0]], op.arg[1][-1] if op.arg[0] is ReduceOps.WMMA else op.arg[1]))
-else:
-# movementops are pushed to the edges with ST_IDX, ST_VALID
-# elementwise inherits shape
-st = op.arg if op.op in {UOps.ST_IDX, UOps.ST_VALID} else sts[op.src[-1]]
-for x in (op.src[1:] if op.op in BUFFER_UOPS else op.src):
-if sts[x].shape != st.shape:
-if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {st.shape}")
-raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op.op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
-sts[op] = st
-for out in ast.src: assert_valid(out, out.src[-1].arg)
-shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
-assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
-return sts

View File

@@ -5,6 +5,8 @@ from typing import Tuple, List, Optional, Dict, Set, Iterable, cast
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, sint
from tinygrad.shape.view import View, strides_for_shape
+from tinygrad.dtype import dtypes
+from tinygrad.ops import UOp, UOps
@dataclass(frozen=True)
class ShapeTracker:
@@ -37,6 +39,10 @@ class ShapeTracker:
@property
def size(self) -> int: return self.views[-1].size()
+def reduce(self, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape))
+def to_uops(self) -> Tuple[UOp, UOp]: return UOp(UOps.ST_IDX, dtypes.pyint, (), self), UOp(UOps.ST_VALID, dtypes.bool, (), self)
def real_size(self) -> int:
if 0 in self.shape: return 0
idx, valid = self.expr_idxs()
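Taken together, the two one-liners added here are the replacements for the helpers deleted from ops.py: reduce_st(st, axis) becomes st.reduce(axis) and UOp.from_st(st) becomes st.to_uops(), while ST_IDX/ST_VALID UOps still expose the tracker's shape through full_shape. A short sketch, not part of the commit:

# sketch: the new ShapeTracker methods, and full_shape reading the tracker back off an ST_IDX UOp
from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((2, 3))
idx, valid = st.to_uops()
assert (idx.arg, valid.arg) == (st, st)   # both UOps carry the ShapeTracker as their arg
assert idx.full_shape == (2, 3)           # full_shape of ST_IDX/ST_VALID is arg.shape (see the ops.py hunk above)
assert st.reduce((0,)) == (1, 3)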