mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
Fix kernel cache key (#570)
This commit is contained in:
@@ -32,9 +32,7 @@ class Token:
|
||||
# ast kernel can contain one ReduceOp with arbitrary Binary/Unary ops
|
||||
class ASTKernel:
|
||||
def __init__(self, ast:LazyOp, output_buffer=None):
|
||||
# key for lookup in cache (can change, str might not be right)
|
||||
self.input_ast = ast
|
||||
self.key = str(ast)
|
||||
|
||||
# if the AST ends with a RESHAPE, we remove it and create the buffer accordingly
|
||||
if ast.op == MovementOps.RESHAPE:
|
||||
@@ -59,6 +57,10 @@ class ASTKernel:
|
||||
self.ret = output_buffer if output_buffer else type(self.bufs[0])(output_shape if output_shape else self.info.shape, force_create=True)
|
||||
self.bufs = ([type(self.ret)(self.info.shape, hostbuf=self.ret)] if output_shape else [self.ret]) + self.bufs
|
||||
|
||||
# key for lookup in cache (can change, str might not be right)
|
||||
# bufs are needed because kernels like f(x) = x + x and f(x, y) = x + y have the same str(ast), but are different kernels.
|
||||
self.key = f"ASTKernelKey ast={str(ast)} bufs={self.bufs}"
|
||||
|
||||
def process(self) -> None:
|
||||
if hasattr(self, "sts"): return # already processed
|
||||
|
||||
|
||||
Reference in New Issue
Block a user