diff --git a/tinygrad/schedule/kernelize.py b/tinygrad/schedule/kernelize.py index c487fb5c9f..2fcb892332 100644 --- a/tinygrad/schedule/kernelize.py +++ b/tinygrad/schedule/kernelize.py @@ -1,12 +1,12 @@ -from dataclasses import dataclass from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve from tinygrad.uop.ops import track_rewrites, _substitute, KernelInfo from tinygrad.uop.spec import type_verify, tensor_uop_spec from tinygrad.uop.symbolic import symbolic_simple -from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP +from tinygrad.helpers import all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP from tinygrad.dtype import ImageDType from tinygrad.schedule.multi import multi_pm from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS +from tinygrad.schedule.rangeify import Kernel from tinygrad.codegen.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop from tinygrad.codegen.opt import Opt @@ -108,14 +108,6 @@ replace_contiguous = PatternMatcher([ # **** create kernels -@dataclass(frozen=True) -class Kernel: - ast: UOp - metadata: tuple[Metadata, ...] = () - def __repr__(self): - ast_rep = f"SINK{tuple(s.op for s in self.ast.src)}" if self.ast.op is Ops.SINK else repr(self.ast.op) - return f"" - def create_kernel(x:UOp, b:UOp|None=None): if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype) kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), m if (m:=x.metadata) else ())) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 53b8c03cfb..06109b3ff4 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -5,7 +5,7 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify, KernelInfo from tinygrad.uop.symbolic import sym, symbolic_simple from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup, unwrap, all_int, DEBUG, SPLIT_REDUCEOP -from tinygrad.schedule.kernelize import Kernel +from tinygrad.helpers import Metadata from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented from tinygrad.codegen.opt import Opt @@ -676,6 +676,14 @@ pm_remove_tags = PatternMatcher([ (UPat(GroupOp.All, name="x"), remove_metadata_tags), ]) +@dataclass(frozen=True) +class Kernel: + ast: UOp + metadata: tuple[Metadata, ...] = () + def __repr__(self): + ast_rep = f"SINK{tuple(s.op for s in self.ast.src)}" if self.ast.op is Ops.SINK else repr(self.ast.op) + return f"" + def split_store(ctx:list[UOp], x:UOp): if len(x.ranges): return None if x.src[0].ptrdtype.addrspace is AddrSpace.LOCAL: return None