mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
move Kernel dataclass to rangeify (#12510)
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve
|
||||
from tinygrad.uop.ops import track_rewrites, _substitute, KernelInfo
|
||||
from tinygrad.uop.spec import type_verify, tensor_uop_spec
|
||||
from tinygrad.uop.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
|
||||
from tinygrad.helpers import all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.schedule.multi import multi_pm
|
||||
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
|
||||
from tinygrad.schedule.rangeify import Kernel
|
||||
from tinygrad.codegen.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop
|
||||
from tinygrad.codegen.opt import Opt
|
||||
|
||||
@@ -108,14 +108,6 @@ replace_contiguous = PatternMatcher([
|
||||
|
||||
# **** create kernels
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Kernel:
|
||||
ast: UOp
|
||||
metadata: tuple[Metadata, ...] = ()
|
||||
def __repr__(self):
|
||||
ast_rep = f"SINK{tuple(s.op for s in self.ast.src)}" if self.ast.op is Ops.SINK else repr(self.ast.op)
|
||||
return f"<Kernel {len(list(self.ast.toposort()))} {ast_rep} {self.metadata}>"
|
||||
|
||||
def create_kernel(x:UOp, b:UOp|None=None):
|
||||
if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype)
|
||||
kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), m if (m:=x.metadata) else ()))
|
||||
|
||||
@@ -5,7 +5,7 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify, KernelInfo
|
||||
from tinygrad.uop.symbolic import sym, symbolic_simple
|
||||
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup, unwrap, all_int, DEBUG, SPLIT_REDUCEOP
|
||||
from tinygrad.schedule.kernelize import Kernel
|
||||
from tinygrad.helpers import Metadata
|
||||
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
|
||||
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented
|
||||
from tinygrad.codegen.opt import Opt
|
||||
@@ -676,6 +676,14 @@ pm_remove_tags = PatternMatcher([
|
||||
(UPat(GroupOp.All, name="x"), remove_metadata_tags),
|
||||
])
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Kernel:
|
||||
ast: UOp
|
||||
metadata: tuple[Metadata, ...] = ()
|
||||
def __repr__(self):
|
||||
ast_rep = f"SINK{tuple(s.op for s in self.ast.src)}" if self.ast.op is Ops.SINK else repr(self.ast.op)
|
||||
return f"<Kernel {len(list(self.ast.toposort()))} {ast_rep} {self.metadata}>"
|
||||
|
||||
def split_store(ctx:list[UOp], x:UOp):
|
||||
if len(x.ranges): return None
|
||||
if x.src[0].ptrdtype.addrspace is AddrSpace.LOCAL: return None
|
||||
|
||||
Reference in New Issue
Block a user