mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
s/shape_spec/ast_spec [pr] (#10485)
This commit is contained in:
@@ -13,13 +13,11 @@ from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewrite, track_rewrites
|
||||
from tinygrad.codegen.symbolic import symbolic_simple
|
||||
from tinygrad.uop.spec import type_verify, shape_spec
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
|
||||
from tinygrad.engine.grouper import view_left, view_right, sym, get_kernelize_map, Kernel, create_ast, merge_views, create_kernels
|
||||
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
|
||||
|
||||
def verify_ast(sink:UOp): return type_verify(list(sink.toposort()), shape_spec)
|
||||
class KernelCountException(Exception): pass
|
||||
def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_sink=True):
|
||||
if to_prerealize:
|
||||
|
||||
@@ -5,7 +5,7 @@ from tinygrad import Tensor
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.uop.ops import UOp, Ops, print_uops
|
||||
from tinygrad.uop.spec import type_verify, shape_spec
|
||||
from tinygrad.uop.spec import type_verify, ast_spec
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.shape.view import View
|
||||
@@ -15,7 +15,7 @@ def helper_test_verify_ast(*stores:UOp) -> Kernel:
|
||||
sink = UOp(Ops.SINK, dtypes.void, stores)
|
||||
if DEBUG >= 3:
|
||||
for op in stores: print(op)
|
||||
try: type_verify(list(sink.toposort()), shape_spec)
|
||||
try: type_verify(list(sink.toposort()), ast_spec)
|
||||
except RuntimeError as e: raise InvalidASTException(e.args)
|
||||
k = Kernel(sink)
|
||||
k.linearize()
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Optional, cast, Final, Callable, Sequence
|
||||
|
||||
from tinygrad.uop.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, track_rewrites, print_uops
|
||||
from tinygrad.uop.ops import PatternMatcher, smax
|
||||
from tinygrad.uop.spec import type_verify, shape_spec
|
||||
from tinygrad.uop.spec import type_verify, ast_spec
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.renderer import Renderer, TensorCore, ProgramSpec, Opt, OptOps
|
||||
from tinygrad.dtype import ImageDType
|
||||
@@ -42,7 +42,7 @@ class Kernel:
|
||||
|
||||
self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
|
||||
# verify AST matches the spec
|
||||
if __debug__: type_verify(list(self.ast.toposort()), shape_spec)
|
||||
if __debug__: type_verify(list(self.ast.toposort()), ast_spec)
|
||||
|
||||
self.reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE_AXIS]
|
||||
|
||||
@@ -550,7 +550,7 @@ class Kernel:
|
||||
# verify AST matches the spec after applying opts
|
||||
if __debug__: type_verify(list(modified_ast.toposort()))
|
||||
# TODO: sadly modified_ast doesn't pass the shape spec because of how group_for_reduces constructs UOps, there's probably a way to fix this
|
||||
#if __debug__: type_verify(list(modified_ast.toposort()), shape_spec)
|
||||
#if __debug__: type_verify(list(modified_ast.toposort()), ast_spec)
|
||||
|
||||
try:
|
||||
self.uops:list[UOp] = full_rewrite(modified_ast, self.opts)
|
||||
|
||||
@@ -191,7 +191,7 @@ sched_spec = buffer_spec+assign_spec+PatternMatcher([
|
||||
(UPat(GroupOp.All-{Ops.SINK}), lambda: False),
|
||||
])
|
||||
|
||||
# *** this is the UOp shape spec ***
|
||||
# *** this is the UOp AST spec ***
|
||||
|
||||
def verify_sink_dims(sink:UOp):
|
||||
if not all_same([s.shape for s in sink.src]): return False
|
||||
@@ -200,7 +200,7 @@ def verify_sink_dims(sink:UOp):
|
||||
print(f"# INVALID KERNEL DIMS: can only have 1 or n in each dimension: {n_dims}")
|
||||
return False
|
||||
|
||||
shape_spec = PatternMatcher([
|
||||
ast_spec = PatternMatcher([
|
||||
# shapes must have either 1 or n in each dimension
|
||||
(UPat(Ops.SINK, src=UPat(Ops.STORE), name="sink"), verify_sink_dims),
|
||||
# VIEW can only exist in the edges
|
||||
|
||||
Reference in New Issue
Block a user