mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
move Kernel dataclass to rangeify (#12510)
This commit is contained in:
@@ -1,12 +1,12 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve
|
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve
|
||||||
from tinygrad.uop.ops import track_rewrites, _substitute, KernelInfo
|
from tinygrad.uop.ops import track_rewrites, _substitute, KernelInfo
|
||||||
from tinygrad.uop.spec import type_verify, tensor_uop_spec
|
from tinygrad.uop.spec import type_verify, tensor_uop_spec
|
||||||
from tinygrad.uop.symbolic import symbolic_simple
|
from tinygrad.uop.symbolic import symbolic_simple
|
||||||
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
|
from tinygrad.helpers import all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
|
||||||
from tinygrad.dtype import ImageDType
|
from tinygrad.dtype import ImageDType
|
||||||
from tinygrad.schedule.multi import multi_pm
|
from tinygrad.schedule.multi import multi_pm
|
||||||
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
|
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
|
||||||
|
from tinygrad.schedule.rangeify import Kernel
|
||||||
from tinygrad.codegen.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop
|
from tinygrad.codegen.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop
|
||||||
from tinygrad.codegen.opt import Opt
|
from tinygrad.codegen.opt import Opt
|
||||||
|
|
||||||
@@ -108,14 +108,6 @@ replace_contiguous = PatternMatcher([
|
|||||||
|
|
||||||
# **** create kernels
|
# **** create kernels
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class Kernel:
|
|
||||||
ast: UOp
|
|
||||||
metadata: tuple[Metadata, ...] = ()
|
|
||||||
def __repr__(self):
|
|
||||||
ast_rep = f"SINK{tuple(s.op for s in self.ast.src)}" if self.ast.op is Ops.SINK else repr(self.ast.op)
|
|
||||||
return f"<Kernel {len(list(self.ast.toposort()))} {ast_rep} {self.metadata}>"
|
|
||||||
|
|
||||||
def create_kernel(x:UOp, b:UOp|None=None):
|
def create_kernel(x:UOp, b:UOp|None=None):
|
||||||
if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype)
|
if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype)
|
||||||
kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), m if (m:=x.metadata) else ()))
|
kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), m if (m:=x.metadata) else ()))
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
|||||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify, KernelInfo
|
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify, KernelInfo
|
||||||
from tinygrad.uop.symbolic import sym, symbolic_simple
|
from tinygrad.uop.symbolic import sym, symbolic_simple
|
||||||
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup, unwrap, all_int, DEBUG, SPLIT_REDUCEOP
|
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup, unwrap, all_int, DEBUG, SPLIT_REDUCEOP
|
||||||
from tinygrad.schedule.kernelize import Kernel
|
from tinygrad.helpers import Metadata
|
||||||
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
|
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
|
||||||
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented
|
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented
|
||||||
from tinygrad.codegen.opt import Opt
|
from tinygrad.codegen.opt import Opt
|
||||||
@@ -676,6 +676,14 @@ pm_remove_tags = PatternMatcher([
|
|||||||
(UPat(GroupOp.All, name="x"), remove_metadata_tags),
|
(UPat(GroupOp.All, name="x"), remove_metadata_tags),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Kernel:
|
||||||
|
ast: UOp
|
||||||
|
metadata: tuple[Metadata, ...] = ()
|
||||||
|
def __repr__(self):
|
||||||
|
ast_rep = f"SINK{tuple(s.op for s in self.ast.src)}" if self.ast.op is Ops.SINK else repr(self.ast.op)
|
||||||
|
return f"<Kernel {len(list(self.ast.toposort()))} {ast_rep} {self.metadata}>"
|
||||||
|
|
||||||
def split_store(ctx:list[UOp], x:UOp):
|
def split_store(ctx:list[UOp], x:UOp):
|
||||||
if len(x.ranges): return None
|
if len(x.ranges): return None
|
||||||
if x.src[0].ptrdtype.addrspace is AddrSpace.LOCAL: return None
|
if x.src[0].ptrdtype.addrspace is AddrSpace.LOCAL: return None
|
||||||
|
|||||||
Reference in New Issue
Block a user