Fix kernel cache key (#570)

This commit is contained in:
Martin Loretz
2023-02-21 12:53:07 +01:00
committed by GitHub
parent 66b4b3bdd3
commit 8550b3e168

View File

@@ -32,9 +32,7 @@ class Token:
# ast kernel can contain one ReduceOp with arbitrary Binary/Unary ops
class ASTKernel:
def __init__(self, ast:LazyOp, output_buffer=None):
# key for lookup in cache (can change, str might not be right)
self.input_ast = ast
self.key = str(ast)
# if the AST ends with a RESHAPE, we remove it and create the buffer accordingly
if ast.op == MovementOps.RESHAPE:
@@ -59,6 +57,10 @@ class ASTKernel:
self.ret = output_buffer if output_buffer else type(self.bufs[0])(output_shape if output_shape else self.info.shape, force_create=True)
self.bufs = ([type(self.ret)(self.info.shape, hostbuf=self.ret)] if output_shape else [self.ret]) + self.bufs
# key for lookup in cache (can change, str might not be right)
# bufs are needed because kernels like f(x) = x + x and f(x, y) = x + y have the same str(ast), but are different kernels.
self.key = f"ASTKernelKey ast={str(ast)} bufs={self.bufs}"
def process(self) -> None:
if hasattr(self, "sts"): return # already processed