From 5dc227dba613be7b366c164397b0b45fedce5abf Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 6 Mar 2023 07:43:40 -0800 Subject: [PATCH] fix bug in ENABLE_METHOD_CACHE and enable for llvm --- .github/workflows/test.yml | 2 +- tinygrad/codegen/ast.py | 5 +++-- tinygrad/ops.py | 1 + 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9390b3bc32..8802742e8c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -92,7 +92,7 @@ jobs: - name: Install Dependencies run: pip install -e '.[llvm,testing]' --extra-index-url https://download.pytorch.org/whl/cpu - name: Run Pytest - run: LLVM=1 python -m pytest -s -v -n=auto + run: ENABLE_METHOD_CACHE=1 LLVM=1 python -m pytest -s -v -n=auto testtorch: name: Torch Tests diff --git a/tinygrad/codegen/ast.py b/tinygrad/codegen/ast.py index ab22a8c40f..270bc97915 100644 --- a/tinygrad/codegen/ast.py +++ b/tinygrad/codegen/ast.py @@ -2,7 +2,7 @@ import itertools from enum import Enum, auto from typing import List, Tuple from tinygrad.helpers import prod, dedup, all_same, colored -from tinygrad.ops import LazyOp, MovementOps, get_lazyop_info, get_buffers, ReduceOps, get_lazyops +from tinygrad.ops import LazyOp, MovementOps, get_lazyop_info, get_buffers, ReduceOps, get_lazyops, map_buffers from tinygrad.shape import ShapeTracker, View, strides_for_shape def get_first_reduce(shapes): @@ -60,7 +60,8 @@ class ASTKernel: # key for lookup in cache (can change, str might not be right) # bufs are needed because kernels like f(x) = x + x and f(x, y) = x + y have the same str(ast), but are different kernels. - self.key = f"ASTKernelKey ast={str(ast)} bufs={self.bufs}" + # mapping the buffers to integers is required because a-b != b-a (and how would you tell a and b apart?) + self.key = f"ASTKernelKey ast={str(map_buffers({x:i for i,x in enumerate(self.bufs)}, ast))} bufs={self.bufs}" def process(self) -> None: if hasattr(self, "sts"): return # already processed diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 8ea11e81ff..6c0eb9bb07 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -170,6 +170,7 @@ class CompiledBuffer(DeviceBuffer): # pylint: disable=abstract-method k = cls.codegen_type(ast, output_buffer) if getenv("ENABLE_METHOD_CACHE"): # TODO: this breaks the ops test! if k.key not in cls.method_cache: cls.method_cache[k.key] = k.codegen().build(cls.runtime_type) + elif DEBUG >= 4: print(f"method cache hit : {k.key}") prg = cls.method_cache[k.key] else: prg = k.codegen().build(cls.runtime_type)