applied_opts is in the optimized ast] (#10906)

This commit is contained in:
George Hotz
2025-06-20 18:56:23 -07:00
committed by GitHub
parent 2d9c61e39e
commit fa52bdb50f
3 changed files with 7 additions and 3 deletions

View File

@@ -458,7 +458,7 @@ class Kernel:
return ret.replace(src=(ret.src[0].replace(arg=st),)+ret.src[1:])
if op.op is Ops.SINK:
return ret.replace(arg = KernelInfo(to_function_name(self.name) if name_override is None else name_override,
self.local_dims, self.upcasted, self.dont_use_locals))
self.local_dims, self.upcasted, self.dont_use_locals, tuple(self.applied_opts)))
if op.op is Ops.REDUCE_AXIS:
reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
@@ -565,5 +565,5 @@ class Kernel:
self.linearize(name_override, ast_transform)
assert self.uops[-1].op is Ops.SINK, "last uop must be sink"
src = self.opts.render(self.uops)
return ProgramSpec(self.name if not name_override else name_override, src, self.opts.device, self.ast, self.uops, self.applied_opts,
return ProgramSpec(self.name if not name_override else name_override, src, self.opts.device, self.ast, self.uops,
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)

View File

@@ -87,7 +87,6 @@ class ProgramSpec:
device:str
ast:UOp # save the base ast (this is method cache key)
uops:Optional[list[UOp]]=None
applied_opts:Optional[list[Opt]]=None
# filled in from uops (if we have uops)
global_size:Optional[list[int]]=None
@@ -131,6 +130,10 @@ class ProgramSpec:
@functools.cached_property
def function_name(self) -> str: return to_function_name(self.name)
@property
def applied_opts(self) -> tuple[Opt, ...]|None: return self.uops[-1].arg.applied_opts if \
self.uops is not None and self.uops[-1].op is Ops.SINK and self.uops[-1].arg is not None else None
def launch_dims(self, var_vals:dict[Variable, int]):
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None

View File

@@ -518,6 +518,7 @@ class KernelInfo:
local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
upcasted: int = 0 # count that are upcasted (this is remapping RANGE to UNROLL)
dont_use_locals: bool = False # don't use local indexing
applied_opts: tuple = tuple()
# ******** ops in python ********