s/shape_spec/ast_spec [pr] (#10485)

This commit is contained in:
qazal
2025-05-23 15:43:54 +03:00
committed by GitHub
parent 127a7c8aee
commit 7a762f01ab
4 changed files with 7 additions and 9 deletions

View File

@@ -13,13 +13,11 @@ from tinygrad.dtype import DType, ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewrite, track_rewrites
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.uop.spec import type_verify, shape_spec
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
from tinygrad.engine.grouper import view_left, view_right, sym, get_kernelize_map, Kernel, create_ast, merge_views, create_kernels
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
def verify_ast(sink:UOp): return type_verify(list(sink.toposort()), shape_spec)
class KernelCountException(Exception): pass
def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_sink=True):
if to_prerealize:

View File

@@ -5,7 +5,7 @@ from tinygrad import Tensor
from tinygrad.codegen.kernel import Kernel
from tinygrad.helpers import DEBUG
from tinygrad.uop.ops import UOp, Ops, print_uops
from tinygrad.uop.spec import type_verify, shape_spec
from tinygrad.uop.spec import type_verify, ast_spec
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad import dtypes
from tinygrad.shape.view import View
@@ -15,7 +15,7 @@ def helper_test_verify_ast(*stores:UOp) -> Kernel:
sink = UOp(Ops.SINK, dtypes.void, stores)
if DEBUG >= 3:
for op in stores: print(op)
try: type_verify(list(sink.toposort()), shape_spec)
try: type_verify(list(sink.toposort()), ast_spec)
except RuntimeError as e: raise InvalidASTException(e.args)
k = Kernel(sink)
k.linearize()

View File

@@ -6,7 +6,7 @@ from typing import Optional, cast, Final, Callable, Sequence
from tinygrad.uop.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, track_rewrites, print_uops
from tinygrad.uop.ops import PatternMatcher, smax
from tinygrad.uop.spec import type_verify, shape_spec
from tinygrad.uop.spec import type_verify, ast_spec
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, ProgramSpec, Opt, OptOps
from tinygrad.dtype import ImageDType
@@ -42,7 +42,7 @@ class Kernel:
self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
# verify AST matches the spec
if __debug__: type_verify(list(self.ast.toposort()), shape_spec)
if __debug__: type_verify(list(self.ast.toposort()), ast_spec)
self.reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE_AXIS]
@@ -550,7 +550,7 @@ class Kernel:
# verify AST matches the spec after applying opts
if __debug__: type_verify(list(modified_ast.toposort()))
# TODO: sadly modified_ast doesn't pass the shape spec because of how group_for_reduces constructs UOps, there's probably a way to fix this
#if __debug__: type_verify(list(modified_ast.toposort()), shape_spec)
#if __debug__: type_verify(list(modified_ast.toposort()), ast_spec)
try:
self.uops:list[UOp] = full_rewrite(modified_ast, self.opts)

View File

@@ -191,7 +191,7 @@ sched_spec = buffer_spec+assign_spec+PatternMatcher([
(UPat(GroupOp.All-{Ops.SINK}), lambda: False),
])
# *** this is the UOp shape spec ***
# *** this is the UOp AST spec ***
def verify_sink_dims(sink:UOp):
if not all_same([s.shape for s in sink.src]): return False
@@ -200,7 +200,7 @@ def verify_sink_dims(sink:UOp):
print(f"# INVALID KERNEL DIMS: can only have 1 or n in each dimension: {n_dims}")
return False
shape_spec = PatternMatcher([
ast_spec = PatternMatcher([
# shapes must have either 1 or n in each dimension
(UPat(Ops.SINK, src=UPat(Ops.STORE), name="sink"), verify_sink_dims),
# VIEW can only exist in the edges