mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix bug in ENABLE_METHOD_CACHE and enable for llvm
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -92,7 +92,7 @@ jobs:
|
||||
- name: Install Dependencies
|
||||
run: pip install -e '.[llvm,testing]' --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
- name: Run Pytest
|
||||
run: LLVM=1 python -m pytest -s -v -n=auto
|
||||
run: ENABLE_METHOD_CACHE=1 LLVM=1 python -m pytest -s -v -n=auto
|
||||
|
||||
testtorch:
|
||||
name: Torch Tests
|
||||
|
||||
@@ -2,7 +2,7 @@ import itertools
|
||||
from enum import Enum, auto
|
||||
from typing import List, Tuple
|
||||
from tinygrad.helpers import prod, dedup, all_same, colored
|
||||
from tinygrad.ops import LazyOp, MovementOps, get_lazyop_info, get_buffers, ReduceOps, get_lazyops
|
||||
from tinygrad.ops import LazyOp, MovementOps, get_lazyop_info, get_buffers, ReduceOps, get_lazyops, map_buffers
|
||||
from tinygrad.shape import ShapeTracker, View, strides_for_shape
|
||||
|
||||
def get_first_reduce(shapes):
|
||||
@@ -60,7 +60,8 @@ class ASTKernel:
|
||||
|
||||
# key for lookup in cache (can change, str might not be right)
|
||||
# bufs are needed because kernels like f(x) = x + x and f(x, y) = x + y have the same str(ast), but are different kernels.
|
||||
self.key = f"ASTKernelKey ast={str(ast)} bufs={self.bufs}"
|
||||
# mapping the buffers to integers is required because a-b != b-a (and how would you tell a and b apart?)
|
||||
self.key = f"ASTKernelKey ast={str(map_buffers({x:i for i,x in enumerate(self.bufs)}, ast))} bufs={self.bufs}"
|
||||
|
||||
def process(self) -> None:
|
||||
if hasattr(self, "sts"): return # already processed
|
||||
|
||||
@@ -170,6 +170,7 @@ class CompiledBuffer(DeviceBuffer): # pylint: disable=abstract-method
|
||||
k = cls.codegen_type(ast, output_buffer)
|
||||
if getenv("ENABLE_METHOD_CACHE"): # TODO: this breaks the ops test!
|
||||
if k.key not in cls.method_cache: cls.method_cache[k.key] = k.codegen().build(cls.runtime_type)
|
||||
elif DEBUG >= 4: print(f"method cache hit : {k.key}")
|
||||
prg = cls.method_cache[k.key]
|
||||
else:
|
||||
prg = k.codegen().build(cls.runtime_type)
|
||||
|
||||
Reference in New Issue
Block a user