mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 22:08:08 -05:00
applied_opts is in the optimized ast] (#10906)
This commit is contained in:
@@ -458,7 +458,7 @@ class Kernel:
|
||||
return ret.replace(src=(ret.src[0].replace(arg=st),)+ret.src[1:])
|
||||
if op.op is Ops.SINK:
|
||||
return ret.replace(arg = KernelInfo(to_function_name(self.name) if name_override is None else name_override,
|
||||
self.local_dims, self.upcasted, self.dont_use_locals))
|
||||
self.local_dims, self.upcasted, self.dont_use_locals, tuple(self.applied_opts)))
|
||||
if op.op is Ops.REDUCE_AXIS:
|
||||
reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
|
||||
|
||||
@@ -565,5 +565,5 @@ class Kernel:
|
||||
self.linearize(name_override, ast_transform)
|
||||
assert self.uops[-1].op is Ops.SINK, "last uop must be sink"
|
||||
src = self.opts.render(self.uops)
|
||||
return ProgramSpec(self.name if not name_override else name_override, src, self.opts.device, self.ast, self.uops, self.applied_opts,
|
||||
return ProgramSpec(self.name if not name_override else name_override, src, self.opts.device, self.ast, self.uops,
|
||||
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
|
||||
|
||||
@@ -87,7 +87,6 @@ class ProgramSpec:
|
||||
device:str
|
||||
ast:UOp # save the base ast (this is method cache key)
|
||||
uops:Optional[list[UOp]]=None
|
||||
applied_opts:Optional[list[Opt]]=None
|
||||
|
||||
# filled in from uops (if we have uops)
|
||||
global_size:Optional[list[int]]=None
|
||||
@@ -131,6 +130,10 @@ class ProgramSpec:
|
||||
@functools.cached_property
|
||||
def function_name(self) -> str: return to_function_name(self.name)
|
||||
|
||||
@property
|
||||
def applied_opts(self) -> tuple[Opt, ...]|None: return self.uops[-1].arg.applied_opts if \
|
||||
self.uops is not None and self.uops[-1].op is Ops.SINK and self.uops[-1].arg is not None else None
|
||||
|
||||
def launch_dims(self, var_vals:dict[Variable, int]):
|
||||
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
|
||||
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
|
||||
|
||||
@@ -518,6 +518,7 @@ class KernelInfo:
|
||||
local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
|
||||
upcasted: int = 0 # count that are upcasted (this is remapping RANGE to UNROLL)
|
||||
dont_use_locals: bool = False # don't use local indexing
|
||||
applied_opts: tuple = tuple()
|
||||
|
||||
# ******** ops in python ********
|
||||
|
||||
|
||||
Reference in New Issue
Block a user