split out the three steps of exec_ast (#2446)

* split out the three steps of exec_ast

* clean up extra args

* cleanups, bugfix

* allocate is a more normal name

* get_optimized_linearizer is better
Author: George Hotz
Date: 2023-11-26 09:07:37 -08:00
Committed by: GitHub
Commit: f6f712e609 (parent 511310737e)
2 changed files with 22 additions and 15 deletions
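
The change splits the single exec_ast entry point into three explicit steps: allocate the output buffer, fetch (or build and cache) a runner for the AST, and execute that runner against the raw buffers. Below is a rough, self-contained sketch of that shape; ToyDevice and ToyRunner are toy classes invented for illustration, not tinygrad's real API.

from typing import Dict, List, Tuple

class ToyRunner:
  def __init__(self, name:str): self.name = name
  def exec(self, rawbuffers:List[list], var_vals:Dict[str, int]) -> None:
    # step 3: run the "kernel" against concrete buffers (elementwise sum of the inputs)
    rawbuffers[0][:] = [sum(vals) for vals in zip(*rawbuffers[1:])]

class ToyDevice:
  def __init__(self): self.method_cache: Dict[str, ToyRunner] = {}
  def allocate_output(self, ast:str, output:dict, inputs:Tuple[list, ...]) -> None:
    # step 1: give the output a concrete buffer (buffer-reuse logic omitted)
    output["realized"] = [0] * len(inputs[0])
  def get_runner(self, ast:str, rawbuffers:List[list]) -> ToyRunner:
    # step 2: build the runner once per AST, then serve it from the cache
    if ast not in self.method_cache: self.method_cache[ast] = ToyRunner(ast)
    return self.method_cache[ast]

dev, out, ins = ToyDevice(), {}, ([1, 2, 3], [10, 20, 30])
dev.allocate_output("add", out, ins)
rawbuffers = [out["realized"], *ins]
dev.get_runner("add", rawbuffers).exec(rawbuffers, {})
print(out["realized"])  # [11, 22, 33]

The point of the split is that the scheduler, not the device, now sequences allocation and execution, which is what the TODO about global memory management in run_schedule hints at.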

Changed file 1 of 2:

@@ -144,7 +144,7 @@ def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Option
   GlobalCounters.global_mem += mem_estimate
   if et is not None: GlobalCounters.time_sum_s += et
-# **************** shared AST runner ****************
+# **************** shared Runner that can go in the JIT ****************
 class JITRunner:
   def __init__(self):
@@ -182,10 +182,12 @@ class Interpreted:
     self.synchronize, self.codegen, self.graph = lambda: None, None, None
     self.method_cache: Dict[LazyOp, InterpretedASTRunner] = {}
-  def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
-    if ast not in self.method_cache: self.method_cache[ast] = get_interpreted_fxn(self.fxn_for_op, ast)
+  def allocate_output(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...]):
     output.realized = output.output_buffer if output.output_buffer is not None else self.buffer.__new__(self.buffer)
-    self.method_cache[ast].exec([output.realized] + [x.realized for x in inputs], var_vals)
+  def get_runner(self, ast:LazyOp, rawbuffers:List[RawBuffer]) -> InterpretedASTRunner:
+    if ast not in self.method_cache or getenv("DISABLE_METHOD_CACHE"): self.method_cache[ast] = get_interpreted_fxn(self.fxn_for_op, ast)
+    return self.method_cache[ast]
 def get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> InterpretedASTRunner:
   if DEBUG >= 3:
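
Both backends now expose get_runner, which memoizes one runner per AST in method_cache and rebuilds only when the DISABLE_METHOD_CACHE environment variable is set. A minimal sketch of that lookup pattern, using plain os.getenv and a counter in place of the real get_interpreted_fxn / to_program build step:

import os
from typing import Dict

build_count = 0
def build_runner(ast:str) -> str:
  # stand-in for the expensive build step (get_interpreted_fxn / to_program)
  global build_count
  build_count += 1
  return f"runner for {ast!r}"

method_cache: Dict[str, str] = {}
def get_runner(ast:str) -> str:
  # rebuild on every call if the cache is disabled, otherwise memoize per AST
  if ast not in method_cache or os.getenv("DISABLE_METHOD_CACHE"):
    method_cache[ast] = build_runner(ast)
  return method_cache[ast]

get_runner("mul"); get_runner("mul")
print(build_count)  # 1 normally; 2 if DISABLE_METHOD_CACHE=1 is set in the environment
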
@@ -277,11 +279,12 @@ class Compiled:
     src, runtime_args = self.renderer(to_function_name(k.name), k.uops)
     return CompiledASTRunner(k.ast, k.name, src, k.global_size, k.local_size, runtime_args).build(self.compiler, self.runtime)
-  def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
+  def allocate_output(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...]):
     # check if we can reuse the output buffer
     # if it's aliased, don't use it
     # TODO: this is pretty wrong actually, who knows where else this buffer is used?
     # TODO: what if an assign is required? this silently is wrong
+    # TODO: this logic just doesn't belong here
     output.realized = output.output_buffer
     if output.realized is not None:
       for i,a in enumerate(inputs):
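
The TODO comments at the top of Compiled.allocate_output spell out why reusing output.output_buffer is risky in general: if the output aliases a buffer the kernel still reads, in-place writes clobber inputs before they are consumed. A tiny pure-Python illustration of that aliasing hazard (no tinygrad objects, just a reversed read of the input):

def reverse_into(out:list, inp:list) -> None:
  # the "kernel": out[i] = inp[n-1-i], i.e. reading the input through a reversed view
  for i in range(len(inp)):
    out[i] = inp[len(inp) - 1 - i]

src, dst = [1, 2, 3, 4], [0, 0, 0, 0]
reverse_into(dst, src)
print(dst)      # [4, 3, 2, 1] -- correct with a separate output buffer

aliased = [1, 2, 3, 4]
reverse_into(aliased, aliased)
print(aliased)  # [4, 3, 3, 4] -- wrong: later reads see slots that were already overwritten
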
@@ -293,16 +296,14 @@ class Compiled:
     # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
     if output.realized is None:
-      output.realized = self.buffer(prod((s if isinstance(s, int) else s.max for s in output.shape)), output.dtype, **kwargs)
-    if output.realized.size == 0: return output.realized
+      output.realized = self.buffer(prod((s if isinstance(s, int) else s.max for s in output.shape)), output.dtype, **output._device_extra_args())
-    # all the rawbuffers
-    rawbuffers = [output.realized] + [x.realized for x in inputs]
   # TODO: the rawbuffers are only used for optimization, they should be removed and optimizer should realloc
+  def get_runner(self, ast:LazyOp, rawbuffers:List[RawBuffer]) -> CompiledASTRunner:
+    if ast not in self.method_cache or getenv("DISABLE_METHOD_CACHE"): self.method_cache[ast] = self.to_program(get_optimized_linearizer(ast, self.linearizer_opts, rawbuffers))
+    return self.method_cache[ast]
-    if ast not in self.method_cache or getenv("DISABLE_METHOD_CACHE"): self.method_cache[ast] = get_optimized_program(self.linearizer_opts, self.to_program, ast, rawbuffers)
-    self.method_cache[ast].exec(rawbuffers, var_vals)
-def get_optimized_program(linearizer_opts:LinearizerOptions, to_program, ast:LazyOp, rawbuffers:List[RawBuffer]) -> CompiledASTRunner:
+def get_optimized_linearizer(ast:LazyOp, linearizer_opts:LinearizerOptions, rawbuffers:List[RawBuffer]) -> Linearizer:
   if DEBUG >= 3:
     from tinygrad.graph import print_tree
     print_tree(ast)
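
For symbolic shapes, allocate_output sizes the buffer to the largest value each symbolic dimension can take, so a single allocation is valid for every binding of the variables. A small sketch of that sizing rule; Var here is a hypothetical stand-in for tinygrad's Variable, carrying only the max bound the expression needs:

from dataclasses import dataclass
from math import prod

@dataclass
class Var:  # hypothetical stand-in for tinygrad's Variable: just a name and bounds
  name: str
  min: int
  max: int

def max_buffer_size(shape:tuple) -> int:
  # same rule as the allocation above: ints count as-is, symbolic dims count at their max
  return prod(s if isinstance(s, int) else s.max for s in shape)

seqlen = Var("seqlen", 1, 128)
print(max_buffer_size((4, seqlen, 64)))  # 4 * 128 * 64 = 32768
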
@@ -327,4 +328,4 @@ def get_optimized_program(linearizer_opts:LinearizerOptions, to_program, ast:Laz
       k = timed[0][1]
   else:
     k.required_optimizations()
-  return to_program(k)
+  return k

Changed file 2 of 2:

@@ -21,7 +21,13 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
       LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out, *si.inputs)
     else:
       assert all(si.out.device == x.device for x in si.inputs), f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"
-      Device[si.out.device].exec_ast(si.ast, output=si.out, inputs=si.inputs, var_vals=si.var_vals, **si.out._device_extra_args())
+      # TODO: allocate_output should be at the top of this function for global memory management
+      Device[si.out.device].allocate_output(si.ast, si.out, si.inputs)
+      # TODO: should this be handled here? it probably just shouldn't be in the schedule
+      if not hasattr(si.out.realized, 'size') or si.out.realized.size != 0:
+        rawbuffers = [si.out.realized] + [x.realized for x in si.inputs]
+        # TODO: remove rawbuffers from get_runner, optimizer should reallocate them
+        Device[si.out.device].get_runner(si.ast, rawbuffers).exec(rawbuffers, si.var_vals)
     del si.out.op
     for v in si.out.views: del v.op
     assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
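
run_schedule now launches a kernel only when the realized output actually has elements; the hasattr part of the guard covers realized objects that don't expose a size attribute at all, which are treated as runnable. A hedged illustration of that guard with a hypothetical StubBuffer standing in for a realized raw buffer:

from dataclasses import dataclass

@dataclass
class StubBuffer:  # hypothetical stand-in for a realized raw buffer
  size: int

def should_launch(realized) -> bool:
  # mirrors the guard above: only skip the kernel for buffers that report size 0
  return not hasattr(realized, 'size') or realized.size != 0

print(should_launch(StubBuffer(size=6)))  # True  -> gather rawbuffers and exec the runner
print(should_launch(StubBuffer(size=0)))  # False -> empty output, nothing to compute
print(should_launch(object()))            # True  -> no size attribute, run anyway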