mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
do not always update float dims (#6699)
* do not always update float dims * linter * isinstance
This commit is contained in:
@@ -3,7 +3,7 @@ from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, O
|
||||
import functools, itertools, collections
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, colored, JIT, dedup, partition
|
||||
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, colored, JIT, dedup, partition
|
||||
from tinygrad.device import Buffer, Compiled, Device
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
@@ -80,9 +80,11 @@ class GraphRunner(Runner): # pylint: disable=abstract-method
|
||||
mem_estimate: sint = 0
|
||||
lds_estimate: sint = 0
|
||||
|
||||
def is_sym_dim(dim) -> bool:
  # A dimension is symbolic when at least one of its entries is not a plain
  # numeric literal (int or float) — i.e. it carries a symbolic variable.
  return any(not isinstance(d, (int, float)) for d in dim)
|
||||
|
||||
self.vars = sorted(var_vals.keys(), key=lambda v: v.expr)
|
||||
self.symbolic_dims = dedup([tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.local_size) and not all_int(d)] +
|
||||
[tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and not all_int(d)])
|
||||
self.symbolic_dims = dedup([tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.local_size) and is_sym_dim(d)] +
|
||||
[tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and is_sym_dim(d)])
|
||||
def find_symbolic_dim(dim):
  # Map a launch-dimension sequence to its position in self.symbolic_dims;
  # returns None when dim is None or not one of the recorded symbolic shapes.
  if dim is not None and tuple(dim) in self.symbolic_dims:
    return self.symbolic_dims.index(tuple(dim))
  return None
|
||||
|
||||
for j,ji in enumerate(jit_cache):
|
||||
|
||||
Reference in New Issue
Block a user