mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
no shapetracker in ops [run_process_replay] (#6117)
This commit is contained in:
@@ -51,11 +51,11 @@ b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struc
|
||||
# describe the computation
|
||||
buf_1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 1)
|
||||
buf_2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 2)
|
||||
ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, *UOp.from_st(ShapeTracker.from_shape((1,)))))
|
||||
ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, *UOp.from_st(ShapeTracker.from_shape((1,)))))
|
||||
ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, *ShapeTracker.from_shape((1,)).to_uops()))
|
||||
ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, *ShapeTracker.from_shape((1,)).to_uops()))
|
||||
alu = ld_1 + ld_2
|
||||
output_buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0)
|
||||
idx, valid = UOp.from_st(ShapeTracker.from_shape((1,)))
|
||||
idx, valid = ShapeTracker.from_shape((1,)).to_uops()
|
||||
st_0 = UOp(UOps.STORE, None, (output_buf, idx, alu, valid))
|
||||
s = UOp(UOps.SINK, None, (st_0,))
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import functools, hashlib
|
||||
from enum import Enum, auto
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import dedup, pretty_print, prod
|
||||
from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, reduce_st, UOp, UOps
|
||||
from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, UOp, UOps
|
||||
from tinygrad.dtype import ImageDType, PtrDType, dtypes, DType, ConstType
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
@@ -86,7 +86,7 @@ def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
|
||||
if op.op in ReduceOps:
|
||||
axis = op.arg[-1] if op.op is ReduceOps.WMMA else op.arg
|
||||
assert isinstance(axis, tuple) and all(isinstance(i, int) for i in axis), f"reduceop must have axis {op.arg}"
|
||||
st = ShapeTracker.from_shape(reduce_st(sts[op.src[0]], axis))
|
||||
st = ShapeTracker.from_shape(sts[op.src[0]].reduce(axis))
|
||||
else:
|
||||
# movementops are pushed to the edges with LOAD
|
||||
# elementwise inherits shape
|
||||
@@ -115,7 +115,7 @@ def to_uop(*a) -> UOp:
|
||||
@functools.lru_cache(None)
|
||||
def create_uop(lop:LazyOp) -> UOp:
|
||||
if lop.op in BufferOps:
|
||||
idx, valid = UOp.from_st(lop.arg.st)
|
||||
idx, valid = lop.arg.st.to_uops()
|
||||
membuf_dtype: DType = lop.arg.dtype
|
||||
dtype = membuf_dtype.base if isinstance(membuf_dtype, ImageDType) else membuf_dtype
|
||||
if lop.op is BufferOps.CONST:
|
||||
|
||||
@@ -8,9 +8,9 @@ from typing import List, Optional, Union, cast
|
||||
from tinygrad import nn, dtypes
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import BinaryOps, MetaOps, UnaryOps, UOps, verify_ast
|
||||
from tinygrad.ops import BinaryOps, MetaOps, UnaryOps, UOps
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.codegen.kernel import Kernel, verify_ast
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from test.helpers import is_dtype_supported, Context
|
||||
|
||||
@@ -4,7 +4,7 @@ from dataclasses import dataclass, replace
|
||||
from collections import defaultdict
|
||||
from typing import Literal, Optional, List, Tuple, Union, cast, Dict, Final, DefaultDict
|
||||
|
||||
from tinygrad.ops import BinaryOps, ReduceOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, verify_ast, print_uops
|
||||
from tinygrad.ops import BinaryOps, ReduceOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, print_uops
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.renderer import Renderer, TensorCore, Program
|
||||
from tinygrad.dtype import DType, ImageDType, PtrDType
|
||||
@@ -638,7 +638,7 @@ class Kernel:
|
||||
if op.op in BUFFER_UOPS:
|
||||
# for locals, we use the ShapeTracker that's in the srcs
|
||||
st = op.src[-1].arg if op.src[0].op is UOps.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
|
||||
idx, valid = UOp.from_st(st if apply_to_st is None else apply_to_st(st))
|
||||
idx, valid = (st if apply_to_st is None else apply_to_st(st)).to_uops()
|
||||
if op.op is UOps.CONST: return replace(op, src=(valid,))
|
||||
if op.op is UOps.STORE: return replace(op, src=(op.src[0], idx, fixup_ast(op.src[2], apply_to_st), valid))
|
||||
return replace(op, src=tuple(fixup_ast(x, apply_to_st) for x in op.src[:-2])+(idx, valid))
|
||||
@@ -697,7 +697,7 @@ class Kernel:
|
||||
for i,(src,fix_st_fxn) in enumerate(zip(rsrc.src, [fix_st1, fix_st2])):
|
||||
st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD]
|
||||
local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
|
||||
idx, valid = UOp.from_st(ShapeTracker.from_shape(local_shape).expand(ex_shape))
|
||||
idx, valid = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uops()
|
||||
membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in), (), (f"temp{-(-1-i)}", idx.arg.real_size()))
|
||||
local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, idx, src, valid)), fix_st_fxn)
|
||||
srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, local_store, idx, valid)))
|
||||
@@ -714,7 +714,7 @@ class Kernel:
|
||||
start = UOp(UOps.REDUCE_AXIS, op.dtype, (fixup_ast(op.src[0], apply_to_st),), arg=(op.arg[0], axis))
|
||||
local_shape = (1,) * self.global_dims + self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces] + \
|
||||
(1,) * (self.first_upcast - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
|
||||
idx, valid = UOp.from_st(ShapeTracker.from_shape(local_shape))
|
||||
idx, valid = ShapeTracker.from_shape(local_shape).to_uops()
|
||||
local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(cast(DType, op.dtype)), (), ("temp1", idx.arg.real_size()))
|
||||
local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, UOp.store(local_buffer, idx, start, valid), idx, valid))
|
||||
second_axis = tuple(range(self.first_reduce, self.first_reduce+self.group_for_reduces))
|
||||
@@ -760,3 +760,30 @@ class Kernel:
|
||||
key=lambda x: (x.op, x.src[0].arg)))
|
||||
return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
|
||||
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
|
||||
|
||||
# the living definition of UOps.ST_IDX and UOps.ST_VALID
|
||||
def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
  """Check shape consistency of a SINK-rooted UOp graph.

  `ast` must be a SINK whose sources are all STOREs. Each STORE is walked recursively, starting from the
  ShapeTracker held in the arg of its last source. Every visited UOp except DEFINE_LOCAL / DEFINE_GLOBAL
  gets a ShapeTracker recorded for it.

  Returns the UOp -> ShapeTracker mapping that was built.
  Raises AssertionError if the root is malformed, if a source's shape differs from its consumer's shape
  outside of a REDUCE_AXIS, or if some dimension takes values other than a single n (optionally mixed with 1).
  """
  assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
  tracked: Dict[UOp, ShapeTracker] = {}

  def visit(node:UOp, incoming:ShapeTracker):
    # nothing to do for nodes we have seen, nor for buffer definitions (they never get an entry)
    if node in tracked: return
    if node.op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return
    # a LOAD from a local buffer (two stage reduce): validate the value that was stored into the
    # local buffer and reuse its ShapeTracker for the LOAD itself
    if node.op is UOps.LOAD and node.src[0].op is UOps.DEFINE_LOCAL:
      local_reduce = node.src[1].src[2]
      visit(local_reduce, node.src[-1].arg)
      return tracked.setdefault(node, tracked[local_reduce])
    for parent in node.src: visit(parent, incoming)
    if node.op is UOps.REDUCE_AXIS:
      # the reduce is the one op allowed to change shape: reduced axes become 1
      src_st = tracked[node.src[0]]
      axis = node.arg[1][-1] if node.arg[0] is ReduceOps.WMMA else node.arg[1]
      node_st = ShapeTracker.from_shape(src_st.reduce(axis))
    else:
      # ST_IDX / ST_VALID carry their ShapeTracker in arg, everything else takes the one of its last source
      node_st = node.arg if node.op in {UOps.ST_IDX, UOps.ST_VALID} else tracked[node.src[-1]]
      # buffer ops skip their first source when comparing shapes
      to_check = node.src[1:] if node.op in BUFFER_UOPS else node.src
      for parent in to_check:
        if tracked[parent].shape != node_st.shape:
          # same element count means a reshape slipped in, otherwise it is an expand
          if prod(tracked[parent].shape) == prod(node_st.shape):
            raise AssertionError(f"found implicit reshape {parent.op} {node.op} {tracked[parent].shape} != {node_st.shape}")
          raise AssertionError(f"found implicit expand {parent.op} {tracked[parent].shape} != {node.op} {node_st.shape} {prod(tracked[parent].shape)} != {prod(node_st.shape)}")
    tracked[node] = node_st

  for store in ast.src: visit(store, store.src[-1].arg)
  # collect, per dimension, the sorted distinct sizes seen across all recorded ShapeTrackers
  per_dim = zip(*(st.shape for st in tracked.values()))
  shape_dims = [sorted(dedup(sizes)) for sizes in per_dim]
  assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
  return tracked
|
||||
@@ -2,7 +2,7 @@ import sys, pickle, atexit, importlib, contextlib
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Tuple, List, Dict, Optional, Set, DefaultDict, get_args
|
||||
from tinygrad.ops import MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, reduce_st, UOp, UOps
|
||||
from tinygrad.ops import MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, UOp, UOps
|
||||
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
|
||||
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, \
|
||||
GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
|
||||
@@ -134,12 +134,12 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
|
||||
if buf.srcs[0] is not buf.srcs[0].base and buf.srcs[0].base is top_reduce[0] and buf.op is top_reduce[0].op:
|
||||
# merge this reduce with its parent
|
||||
new_st = top_reduce[1]+st
|
||||
top_reduce = (top_reduce[0], new_st.reshape(reduce_st(top_reduce_input_st, new_axis:=axis+top_reduce_axes)))
|
||||
top_reduce = (top_reduce[0], new_st.reshape(top_reduce_input_st.reduce(new_axis:=axis+top_reduce_axes)))
|
||||
reduce_info[top_reduce] = (top_reduce_input_st, new_axis)
|
||||
return None
|
||||
# reshape this reduceop based on the top reduce
|
||||
input_st = input_st.reshape(tuple(1 if i in top_reduce_axes else s for i,s in enumerate(top_reduce_input_st.shape)))
|
||||
st = st.reshape(reduce_st(input_st, axis))
|
||||
st = st.reshape(input_st.reduce(axis))
|
||||
reduce_info[(buf, st)] = (input_st, axis)
|
||||
return (buf, st)
|
||||
return cache.setdefault((buf, st), top_reduce)
|
||||
@@ -147,7 +147,7 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
|
||||
def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem:
|
||||
"""describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
|
||||
if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
|
||||
idx, valid = UOp.from_st(ShapeTracker.from_shape(out.arg))
|
||||
idx, valid = ShapeTracker.from_shape(out.arg).to_uops()
|
||||
rd = UOp(UOps.LOAD, dtypes.uint8, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.uint8), (), 1), idx, valid))
|
||||
wr = UOp(UOps.STORE, None, (UOp(UOps.DEFINE_GLOBAL, PtrDType(out.dtype), (), 0), idx, rd, valid))
|
||||
return LBScheduleItem(UOp(UOps.SINK, None, (wr,)), outs, [x.base for x in out.srcs])
|
||||
@@ -172,7 +172,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
|
||||
ast: List[UOp] = []
|
||||
inputs: Dict[LazyBuffer, int] = {}
|
||||
for i, out in enumerate(outs):
|
||||
output_st = ShapeTracker.from_shape(reduce_st(*deque(reduce_info.values(), 1).pop()) if reduce_info else out.shape)
|
||||
output_st = ShapeTracker.from_shape(ShapeTracker.reduce(*deque(reduce_info.values(), 1).pop()) if reduce_info else out.shape)
|
||||
src = _recursive_uop(out, output_st, tuple(outs), var_vals, inputs, realizes, assign_targets, reduce_info, cache=cache)
|
||||
if out.op is MetaOps.ASSIGN and out.arg:
|
||||
assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
from typing import Union, Optional, Any, Tuple, List, get_args
|
||||
from tinygrad.dtype import dtypes, DType, DTypeLike, ConstType, to_dtype
|
||||
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP
|
||||
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, reduce_st
|
||||
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
|
||||
from tinygrad.shape.symbolic import sint, Variable
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.device import Buffer
|
||||
@@ -171,10 +171,10 @@ class LazyBuffer:
|
||||
assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
|
||||
axis = tuple(sorted([x for x in axis if self.shape[x] != 1]))
|
||||
if len(axis) == 0: return self
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(reduce_st(self.st, axis)), self.dtype, op, axis, (self,))
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.st.reduce(axis)), self.dtype, op, axis, (self,))
|
||||
|
||||
def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
|
||||
new_shape = reduce_st(self.st, axis)
|
||||
new_shape = self.st.reduce(axis)
|
||||
# TODO: this logic should move to the scheduler
|
||||
if 0 in self.shape and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: dtypes.min(self.dtype)}[op], new_shape)
|
||||
|
||||
|
||||
@@ -5,9 +5,8 @@ import math, operator, ctypes, struct, functools, hashlib, itertools
|
||||
from enum import Enum, auto
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.dtype import ConstType, dtypes, DType
|
||||
from tinygrad.helpers import dedup, merge_dicts, pretty_print, prod
|
||||
from tinygrad.helpers import merge_dicts, pretty_print, prod
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
|
||||
# these are the llops your accelerator must implement, along with toCpu
|
||||
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
|
||||
@@ -74,8 +73,6 @@ truncate: Dict[DType, Callable] = {dtypes.bool: bool,
|
||||
|
||||
def exec_alu(op:Op, dtype:DType, operands):
  """Run the python_alu implementation of `op` on `operands` and pass the result through the
  `truncate` entry for `dtype` (the result is returned unchanged when `dtype` has no entry)."""
  # look up the truncation first, exactly as the one-line form did
  truncate_fn = truncate.get(dtype, lambda x: x)
  raw = python_alu[op](*operands)
  return truncate_fn(raw)
|
||||
|
||||
def reduce_st(st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[sint, ...]:
  """Return `st.shape` with every dimension whose index appears in `axis` replaced by 1."""
  reduced = []
  for i, s in enumerate(st.shape):
    if i in axis:
      reduced.append(1)
    else:
      reduced.append(s)
  return tuple(reduced)
|
||||
|
||||
# the order of these UOps controls the order of the toposort
|
||||
class UOps(Enum):
|
||||
# ops that aren't rendered
|
||||
@@ -163,8 +160,6 @@ class UOp:
|
||||
def parents(self) -> Dict[UOp, None]: return merge_dicts([{x:None for x in self.src}]+[x.parents for x in self.src])
|
||||
@property # parents with self
def sparents(self) -> Dict[UOp, None]:
  """Return a new dict holding every key of `self.parents` plus `self`, all mapped to None."""
  with_self = dict(self.parents)
  with_self[self] = None
  return with_self
|
||||
@staticmethod
def from_st(st:ShapeTracker) -> Tuple[UOp, UOp]:
  """Build the (ST_IDX, ST_VALID) UOp pair that carries `st` as its arg."""
  idx = UOp(UOps.ST_IDX, dtypes.pyint, (), st)
  valid = UOp(UOps.ST_VALID, dtypes.bool, (), st)
  return idx, valid
|
||||
@functools.cached_property
|
||||
def full_shape(self) -> Tuple[sint, ...]:
|
||||
if self.op in {UOps.ST_IDX, UOps.ST_VALID}: return self.arg.shape
|
||||
@@ -374,30 +369,3 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
|
||||
assert u.arg[1] is not None
|
||||
flops += 2 * prod(u.arg[1]) // 32 * mults
|
||||
return flops, mem
|
||||
|
||||
# the living definition of UOps.ST_IDX and UOps.ST_VALID
|
||||
def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
  """Check shape consistency of a SINK-rooted UOp graph.

  `ast` must be a SINK whose sources are all STOREs. Each STORE is walked recursively, starting from the
  ShapeTracker held in the arg of its last source. Every visited UOp except DEFINE_LOCAL / DEFINE_GLOBAL
  gets a ShapeTracker recorded for it.

  Returns the UOp -> ShapeTracker mapping that was built.
  Raises AssertionError if the root is malformed, if a source's shape differs from its consumer's shape
  outside of a REDUCE_AXIS, or if some dimension takes values other than a single n (optionally mixed with 1).
  """
  assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
  tracked: Dict[UOp, ShapeTracker] = {}

  def visit(node:UOp, incoming:ShapeTracker):
    # nothing to do for nodes we have seen, nor for buffer definitions (they never get an entry)
    if node in tracked: return
    if node.op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return
    # a LOAD from a local buffer (two stage reduce): validate the value that was stored into the
    # local buffer and reuse its ShapeTracker for the LOAD itself
    if node.op is UOps.LOAD and node.src[0].op is UOps.DEFINE_LOCAL:
      local_reduce = node.src[1].src[2]
      visit(local_reduce, node.src[-1].arg)
      return tracked.setdefault(node, tracked[local_reduce])
    for parent in node.src: visit(parent, incoming)
    if node.op is UOps.REDUCE_AXIS:
      # the reduce is the one op allowed to change shape: reduced axes become 1
      reducer = reduce_st
      src_st = tracked[node.src[0]]
      axis = node.arg[1][-1] if node.arg[0] is ReduceOps.WMMA else node.arg[1]
      node_st = ShapeTracker.from_shape(reducer(src_st, axis))
    else:
      # ST_IDX / ST_VALID carry their ShapeTracker in arg, everything else takes the one of its last source
      node_st = node.arg if node.op in {UOps.ST_IDX, UOps.ST_VALID} else tracked[node.src[-1]]
      # buffer ops skip their first source when comparing shapes
      to_check = node.src[1:] if node.op in BUFFER_UOPS else node.src
      for parent in to_check:
        if tracked[parent].shape != node_st.shape:
          # same element count means a reshape slipped in, otherwise it is an expand
          if prod(tracked[parent].shape) == prod(node_st.shape):
            raise AssertionError(f"found implicit reshape {parent.op} {node.op} {tracked[parent].shape} != {node_st.shape}")
          raise AssertionError(f"found implicit expand {parent.op} {tracked[parent].shape} != {node.op} {node_st.shape} {prod(tracked[parent].shape)} != {prod(node_st.shape)}")
    tracked[node] = node_st

  for store in ast.src: visit(store, store.src[-1].arg)
  # collect, per dimension, the sorted distinct sizes seen across all recorded ShapeTrackers
  per_dim = zip(*(st.shape for st in tracked.values()))
  shape_dims = [sorted(dedup(sizes)) for sizes in per_dim]
  assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
  return tracked
|
||||
|
||||
@@ -5,6 +5,8 @@ from typing import Tuple, List, Optional, Dict, Set, Iterable, cast
|
||||
from tinygrad.helpers import merge_dicts, getenv
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, sint
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp, UOps
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ShapeTracker:
|
||||
@@ -37,6 +39,10 @@ class ShapeTracker:
|
||||
@property
def size(self) -> int:
  """Return whatever the last view's `size()` reports."""
  last_view = self.views[-1]
  return last_view.size()
|
||||
|
||||
def reduce(self, axis:Tuple[int, ...]) -> Tuple[sint, ...]:
  """Return `self.shape` with every dimension whose index appears in `axis` replaced by 1."""
  reduced = []
  for i, s in enumerate(self.shape):
    if i in axis:
      reduced.append(1)
    else:
      reduced.append(s)
  return tuple(reduced)
|
||||
|
||||
def to_uops(self) -> Tuple[UOp, UOp]:
  """Build the (ST_IDX, ST_VALID) UOp pair that carries this object as its arg."""
  idx = UOp(UOps.ST_IDX, dtypes.pyint, (), self)
  valid = UOp(UOps.ST_VALID, dtypes.bool, (), self)
  return idx, valid
|
||||
|
||||
def real_size(self) -> int:
|
||||
if 0 in self.shape: return 0
|
||||
idx, valid = self.expr_idxs()
|
||||
|
||||
Reference in New Issue
Block a user