do not always update float dims (#6699)

* do not always update float dims

* linter

* isinsatcen
This commit is contained in:
nimlgen
2024-09-24 14:40:45 +08:00
committed by GitHub
parent 048483ee0b
commit a473bf4ba9

View File

@@ -3,7 +3,7 @@ from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, O
import functools, itertools, collections
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, colored, JIT, dedup, partition
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, colored, JIT, dedup, partition
from tinygrad.device import Buffer, Compiled, Device
from tinygrad.dtype import DType
from tinygrad.shape.shapetracker import ShapeTracker
@@ -80,9 +80,11 @@ class GraphRunner(Runner): # pylint: disable=abstract-method
mem_estimate: sint = 0
lds_estimate: sint = 0
def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)
self.vars = sorted(var_vals.keys(), key=lambda v: v.expr)
self.symbolic_dims = dedup([tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.local_size) and not all_int(d)] +
[tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and not all_int(d)])
self.symbolic_dims = dedup([tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.local_size) and is_sym_dim(d)] +
[tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and is_sym_dim(d)])
def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None
for j,ji in enumerate(jit_cache):