no shapetracker in ops [run_process_replay] (#6117)

Author: George Hotz
Date: 2024-08-16 17:23:27 -07:00
Committed by: GitHub
Commit: 89c7989659 (parent 74ee9febec)
8 changed files with 54 additions and 53 deletions

View File

@@ -51,11 +51,11 @@ b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struc
# describe the computation
buf_1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 1)
buf_2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 2)
-ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, *UOp.from_st(ShapeTracker.from_shape((1,)))))
-ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, *UOp.from_st(ShapeTracker.from_shape((1,)))))
+ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, *ShapeTracker.from_shape((1,)).to_uops()))
+ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, *ShapeTracker.from_shape((1,)).to_uops()))
alu = ld_1 + ld_2
output_buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0)
-idx, valid = UOp.from_st(ShapeTracker.from_shape((1,)))
+idx, valid = ShapeTracker.from_shape((1,)).to_uops()
st_0 = UOp(UOps.STORE, None, (output_buf, idx, alu, valid))
s = UOp(UOps.SINK, None, (st_0,))
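
The only change in this documentation example is the call that turns a ShapeTracker into its index/valid UOp pair: the UOp.from_st staticmethod removed later in this diff becomes ShapeTracker.to_uops. A minimal sketch of the migration (not part of the commit), assuming only the names already used in this example:

    from tinygrad.dtype import dtypes
    from tinygrad.shape.shapetracker import ShapeTracker

    st = ShapeTracker.from_shape((1,))
    # before this commit: idx, valid = UOp.from_st(st)
    # after this commit the same pair comes from the ShapeTracker itself
    idx, valid = st.to_uops()   # a UOps.ST_IDX and a UOps.ST_VALID UOp, both carrying st as their arg
    assert idx.dtype == dtypes.pyint and valid.dtype == dtypes.bool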

View File

@@ -4,7 +4,7 @@ import functools, hashlib
from enum import Enum, auto
from dataclasses import dataclass
from tinygrad.helpers import dedup, pretty_print, prod
-from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, reduce_st, UOp, UOps
+from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, UOp, UOps
from tinygrad.dtype import ImageDType, PtrDType, dtypes, DType, ConstType
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.shapetracker import ShapeTracker
@@ -86,7 +86,7 @@ def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
if op.op in ReduceOps:
  axis = op.arg[-1] if op.op is ReduceOps.WMMA else op.arg
  assert isinstance(axis, tuple) and all(isinstance(i, int) for i in axis), f"reduceop must have axis {op.arg}"
-  st = ShapeTracker.from_shape(reduce_st(sts[op.src[0]], axis))
+  st = ShapeTracker.from_shape(sts[op.src[0]].reduce(axis))
else:
  # movementops are pushed to the edges with LOAD
  # elementwise inherits shape
@@ -115,7 +115,7 @@ def to_uop(*a) -> UOp:
@functools.lru_cache(None) @functools.lru_cache(None)
def create_uop(lop:LazyOp) -> UOp: def create_uop(lop:LazyOp) -> UOp:
if lop.op in BufferOps: if lop.op in BufferOps:
idx, valid = UOp.from_st(lop.arg.st) idx, valid = lop.arg.st.to_uops()
membuf_dtype: DType = lop.arg.dtype membuf_dtype: DType = lop.arg.dtype
dtype = membuf_dtype.base if isinstance(membuf_dtype, ImageDType) else membuf_dtype dtype = membuf_dtype.base if isinstance(membuf_dtype, ImageDType) else membuf_dtype
if lop.op is BufferOps.CONST: if lop.op is BufferOps.CONST:

View File

@@ -8,9 +8,9 @@ from typing import List, Optional, Union, cast
from tinygrad import nn, dtypes
from tinygrad.device import Device
from tinygrad.tensor import Tensor
-from tinygrad.ops import BinaryOps, MetaOps, UnaryOps, UOps, verify_ast
+from tinygrad.ops import BinaryOps, MetaOps, UnaryOps, UOps
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP
-from tinygrad.codegen.kernel import Kernel
+from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
from test.helpers import is_dtype_supported, Context

View File

@@ -4,7 +4,7 @@ from dataclasses import dataclass, replace
from collections import defaultdict
from typing import Literal, Optional, List, Tuple, Union, cast, Dict, Final, DefaultDict
-from tinygrad.ops import BinaryOps, ReduceOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, verify_ast, print_uops
+from tinygrad.ops import BinaryOps, ReduceOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, print_uops
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, Program
from tinygrad.dtype import DType, ImageDType, PtrDType
@@ -638,7 +638,7 @@ class Kernel:
if op.op in BUFFER_UOPS:
  # for locals, we use the ShapeTracker that's in the srcs
  st = op.src[-1].arg if op.src[0].op is UOps.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
-  idx, valid = UOp.from_st(st if apply_to_st is None else apply_to_st(st))
+  idx, valid = (st if apply_to_st is None else apply_to_st(st)).to_uops()
  if op.op is UOps.CONST: return replace(op, src=(valid,))
  if op.op is UOps.STORE: return replace(op, src=(op.src[0], idx, fixup_ast(op.src[2], apply_to_st), valid))
  return replace(op, src=tuple(fixup_ast(x, apply_to_st) for x in op.src[:-2])+(idx, valid))
@@ -697,7 +697,7 @@ class Kernel:
for i,(src,fix_st_fxn) in enumerate(zip(rsrc.src, [fix_st1, fix_st2])):
  st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD]
  local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
-  idx, valid = UOp.from_st(ShapeTracker.from_shape(local_shape).expand(ex_shape))
+  idx, valid = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uops()
  membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in), (), (f"temp{-(-1-i)}", idx.arg.real_size()))
  local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, idx, src, valid)), fix_st_fxn)
  srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, local_store, idx, valid)))
@@ -714,7 +714,7 @@ class Kernel:
start = UOp(UOps.REDUCE_AXIS, op.dtype, (fixup_ast(op.src[0], apply_to_st),), arg=(op.arg[0], axis))
local_shape = (1,) * self.global_dims + self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces] + \
  (1,) * (self.first_upcast - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
-idx, valid = UOp.from_st(ShapeTracker.from_shape(local_shape))
+idx, valid = ShapeTracker.from_shape(local_shape).to_uops()
local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(cast(DType, op.dtype)), (), ("temp1", idx.arg.real_size()))
local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, UOp.store(local_buffer, idx, start, valid), idx, valid))
second_axis = tuple(range(self.first_reduce, self.first_reduce+self.group_for_reduces))
@@ -760,3 +760,30 @@ class Kernel:
    key=lambda x: (x.op, x.src[0].arg)))
  return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
    global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
+# the living definition of UOps.ST_IDX and UOps.ST_VALID
+def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
+  assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
+  sts: Dict[UOp, ShapeTracker] = {}
+  def assert_valid(op:UOp, st:ShapeTracker):
+    if op in sts or op.op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return
+    # restore globals from the two stage reduce
+    if op.op is UOps.LOAD and op.src[0].op is UOps.DEFINE_LOCAL:
+      assert_valid(local_reduce:=op.src[1].src[2], op.src[-1].arg)
+      return sts.setdefault(op, sts[local_reduce])
+    for x in op.src: assert_valid(x, st)
+    # only reduceop is allowed to change shape, limited to turning n to 1
+    if op.op is UOps.REDUCE_AXIS: st = ShapeTracker.from_shape(sts[op.src[0]].reduce(op.arg[1][-1] if op.arg[0] is ReduceOps.WMMA else op.arg[1]))
+    else:
+      # movementops are pushed to the edges with ST_IDX, ST_VALID
+      # elementwise inherits shape
+      st = op.arg if op.op in {UOps.ST_IDX, UOps.ST_VALID} else sts[op.src[-1]]
+      for x in (op.src[1:] if op.op in BUFFER_UOPS else op.src):
+        if sts[x].shape != st.shape:
+          if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {st.shape}")
+          raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op.op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
+    sts[op] = st
+  for out in ast.src: assert_valid(out, out.src[-1].arg)
+  shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
+  assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
+  return sts
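
verify_ast now lives next to the Kernel that consumes it and, per the test import change above, is imported from tinygrad.codegen.kernel. A hedged sketch (not part of the commit) of exercising it on the tiny add AST from the first file in this diff, assuming those constructors behave the same standalone:

    from tinygrad.codegen.kernel import verify_ast
    from tinygrad.dtype import PtrDType, dtypes
    from tinygrad.ops import UOp, UOps
    from tinygrad.shape.shapetracker import ShapeTracker

    buf_1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 1)
    buf_2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 2)
    ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, *ShapeTracker.from_shape((1,)).to_uops()))
    ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, *ShapeTracker.from_shape((1,)).to_uops()))
    output_buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0)
    idx, valid = ShapeTracker.from_shape((1,)).to_uops()
    sink = UOp(UOps.SINK, None, (UOp(UOps.STORE, None, (output_buf, idx, ld_1 + ld_2, valid)),))
    sts = verify_ast(sink)                       # raises AssertionError on implicit reshapes/expands
    assert all(st.shape == (1,) for st in sts.values())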

View File

@@ -2,7 +2,7 @@ import sys, pickle, atexit, importlib, contextlib
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Tuple, List, Dict, Optional, Set, DefaultDict, get_args
-from tinygrad.ops import MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, reduce_st, UOp, UOps
+from tinygrad.ops import MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, UOp, UOps
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, \
  GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
@@ -134,12 +134,12 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
  if buf.srcs[0] is not buf.srcs[0].base and buf.srcs[0].base is top_reduce[0] and buf.op is top_reduce[0].op:
    # merge this reduce with its parent
    new_st = top_reduce[1]+st
-    top_reduce = (top_reduce[0], new_st.reshape(reduce_st(top_reduce_input_st, new_axis:=axis+top_reduce_axes)))
+    top_reduce = (top_reduce[0], new_st.reshape(top_reduce_input_st.reduce(new_axis:=axis+top_reduce_axes)))
    reduce_info[top_reduce] = (top_reduce_input_st, new_axis)
    return None
  # reshape this reduceop based on the top reduce
  input_st = input_st.reshape(tuple(1 if i in top_reduce_axes else s for i,s in enumerate(top_reduce_input_st.shape)))
-  st = st.reshape(reduce_st(input_st, axis))
+  st = st.reshape(input_st.reduce(axis))
  reduce_info[(buf, st)] = (input_st, axis)
  return (buf, st)
return cache.setdefault((buf, st), top_reduce)
@@ -147,7 +147,7 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem:
  """describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
  if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
-    idx, valid = UOp.from_st(ShapeTracker.from_shape(out.arg))
+    idx, valid = ShapeTracker.from_shape(out.arg).to_uops()
    rd = UOp(UOps.LOAD, dtypes.uint8, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.uint8), (), 1), idx, valid))
    wr = UOp(UOps.STORE, None, (UOp(UOps.DEFINE_GLOBAL, PtrDType(out.dtype), (), 0), idx, rd, valid))
    return LBScheduleItem(UOp(UOps.SINK, None, (wr,)), outs, [x.base for x in out.srcs])
@@ -172,7 +172,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
ast: List[UOp] = []
inputs: Dict[LazyBuffer, int] = {}
for i, out in enumerate(outs):
-  output_st = ShapeTracker.from_shape(reduce_st(*deque(reduce_info.values(), 1).pop()) if reduce_info else out.shape)
+  output_st = ShapeTracker.from_shape(ShapeTracker.reduce(*deque(reduce_info.values(), 1).pop()) if reduce_info else out.shape)
  src = _recursive_uop(out, output_st, tuple(outs), var_vals, inputs, realizes, assign_targets, reduce_info, cache=cache)
  if out.op is MetaOps.ASSIGN and out.arg:
    assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
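
One subtlety in the output_st line above: reduce_info.values() stores (input_st, axis) tuples, so the new code calls ShapeTracker.reduce in unbound form and splats the tuple to supply both self and axis. A small sketch of the equivalence, with made-up values:

    from tinygrad.shape.shapetracker import ShapeTracker

    input_st, axis = ShapeTracker.from_shape((4, 8)), (1,)
    # unbound call as in _lower_lazybuffer vs. the ordinary bound call
    assert ShapeTracker.reduce(input_st, axis) == input_st.reduce(axis) == (4, 1)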

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Union, Optional, Any, Tuple, List, get_args
from tinygrad.dtype import dtypes, DType, DTypeLike, ConstType, to_dtype
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP
-from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, reduce_st
+from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
from tinygrad.shape.symbolic import sint, Variable
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer
@@ -171,10 +171,10 @@ class LazyBuffer:
  assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
  axis = tuple(sorted([x for x in axis if self.shape[x] != 1]))
  if len(axis) == 0: return self
-  return create_lazybuffer(self.device, ShapeTracker.from_shape(reduce_st(self.st, axis)), self.dtype, op, axis, (self,))
+  return create_lazybuffer(self.device, ShapeTracker.from_shape(self.st.reduce(axis)), self.dtype, op, axis, (self,))
def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
-  new_shape = reduce_st(self.st, axis)
+  new_shape = self.st.reduce(axis)
  # TODO: this logic should move to the scheduler
  if 0 in self.shape and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: dtypes.min(self.dtype)}[op], new_shape)

View File

@@ -5,9 +5,8 @@ import math, operator, ctypes, struct, functools, hashlib, itertools
from enum import Enum, auto
from dataclasses import dataclass
from tinygrad.dtype import ConstType, dtypes, DType
-from tinygrad.helpers import dedup, merge_dicts, pretty_print, prod
+from tinygrad.helpers import merge_dicts, pretty_print, prod
from tinygrad.shape.symbolic import Variable, sint
-from tinygrad.shape.shapetracker import ShapeTracker
# these are the llops your accelerator must implement, along with toCpu
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
@@ -74,8 +73,6 @@ truncate: Dict[DType, Callable] = {dtypes.bool: bool,
def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
-def reduce_st(st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(st.shape))
# the order of these UOps controls the order of the toposort
class UOps(Enum):
  # ops that aren't rendered
@@ -163,8 +160,6 @@ class UOp:
def parents(self) -> Dict[UOp, None]: return merge_dicts([{x:None for x in self.src}]+[x.parents for x in self.src])
@property # parents with self
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
-@staticmethod
-def from_st(st:ShapeTracker) -> Tuple[UOp, UOp]: return UOp(UOps.ST_IDX, dtypes.pyint, (), st), UOp(UOps.ST_VALID, dtypes.bool, (), st)
@functools.cached_property
def full_shape(self) -> Tuple[sint, ...]:
  if self.op in {UOps.ST_IDX, UOps.ST_VALID}: return self.arg.shape
@@ -374,30 +369,3 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
    assert u.arg[1] is not None
    flops += 2 * prod(u.arg[1]) // 32 * mults
  return flops, mem
-# the living definition of UOps.ST_IDX and UOps.ST_VALID
-def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
-  assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
-  sts: Dict[UOp, ShapeTracker] = {}
-  def assert_valid(op:UOp, st:ShapeTracker):
-    if op in sts or op.op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return
-    # restore globals from the two stage reduce
-    if op.op is UOps.LOAD and op.src[0].op is UOps.DEFINE_LOCAL:
-      assert_valid(local_reduce:=op.src[1].src[2], op.src[-1].arg)
-      return sts.setdefault(op, sts[local_reduce])
-    for x in op.src: assert_valid(x, st)
-    # only reduceop is allowed to change shape, limited to turning n to 1
-    if op.op is UOps.REDUCE_AXIS: st = ShapeTracker.from_shape(reduce_st(sts[op.src[0]], op.arg[1][-1] if op.arg[0] is ReduceOps.WMMA else op.arg[1]))
-    else:
-      # movementops are pushed to the edges with ST_IDX, ST_VALID
-      # elementwise inherits shape
-      st = op.arg if op.op in {UOps.ST_IDX, UOps.ST_VALID} else sts[op.src[-1]]
-      for x in (op.src[1:] if op.op in BUFFER_UOPS else op.src):
-        if sts[x].shape != st.shape:
-          if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {st.shape}")
-          raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op.op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
-    sts[op] = st
-  for out in ast.src: assert_valid(out, out.src[-1].arg)
-  shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
-  assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
-  return sts

View File

@@ -5,6 +5,8 @@ from typing import Tuple, List, Optional, Dict, Set, Iterable, cast
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, sint
from tinygrad.shape.view import View, strides_for_shape
+from tinygrad.dtype import dtypes
+from tinygrad.ops import UOp, UOps
@dataclass(frozen=True)
class ShapeTracker:
@@ -37,6 +39,10 @@ class ShapeTracker:
@property
def size(self) -> int: return self.views[-1].size()
+def reduce(self, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape))
+def to_uops(self) -> Tuple[UOp, UOp]: return UOp(UOps.ST_IDX, dtypes.pyint, (), self), UOp(UOps.ST_VALID, dtypes.bool, (), self)
def real_size(self) -> int:
  if 0 in self.shape: return 0
  idx, valid = self.expr_idxs()
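
The removed reduce_st helper and UOp.from_st land here as ShapeTracker.reduce and ShapeTracker.to_uops. A quick sketch (not part of the commit) of what the new methods return, assuming the UOps values shown elsewhere in this diff:

    from tinygrad.dtype import dtypes
    from tinygrad.ops import UOps
    from tinygrad.shape.shapetracker import ShapeTracker

    st = ShapeTracker.from_shape((3, 4))
    assert st.reduce((1,)) == (3, 1)             # reduced axes collapse to 1, others keep their size

    idx, valid = st.to_uops()
    assert (idx.op, idx.dtype, idx.arg) == (UOps.ST_IDX, dtypes.pyint, st)
    assert (valid.op, valid.dtype, valid.arg) == (UOps.ST_VALID, dtypes.bool, st)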