From 9e0a42ec0e291ded1d775cc45289ae44f09e1892 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 22 Dec 2025 18:40:53 +0000 Subject: [PATCH] typed --- tinygrad/engine/realize.py | 6 +++--- tinygrad/tensor.py | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index dcdb9c3965..976a233da0 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -214,11 +214,11 @@ class ExecItem: et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2) if do_update_stats: GlobalCounters.kernel_count += 1 - GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.estimates.ops, var_vals)) - GlobalCounters.global_mem += (mem_est:=sym_infer(self.prg.estimates.mem, var_vals)) + GlobalCounters.global_ops += (op_est:=int(sym_infer(self.prg.estimates.ops, var_vals))) + GlobalCounters.global_mem += (mem_est:=int(sym_infer(self.prg.estimates.mem, var_vals))) if et is not None: GlobalCounters.time_sum_s += et if DEBUG >= 2: - lds_est = sym_infer(self.prg.estimates.lds, var_vals) + lds_est = int(sym_infer(self.prg.estimates.lds, var_vals)) mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed header_color = 'magenta' if jit else ('green' if self.prg.first_run else None) ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else "" diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index b17fe54d22..e0da25f0f5 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1300,9 +1300,10 @@ class Tensor(OpMixin): print(t0.cat(t1, t2, dim=1).numpy()) ``` """ + if isinstance(self, (tuple, list)): self, args = self[0], tuple(self[1:]) + args # type: ignore[arg-type] dim = self._resolve_dim(dim) for arg in args: assert arg.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i!=dim) - tensors = [self, *args] + tensors:list[Tensor] = [self, *args] dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0)) for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)]) return functools.reduce(Tensor.add, tensors) @@ -2202,7 +2203,8 @@ class Tensor(OpMixin): idx = m * idx.pad(pads, value=dtypes.min(idx.dtype))._pool(k_, stride if stride is not None else k_, dilation) return pooled.max(axis), spatial_sz - idx.max(axis) - def max_unpool2d(self, indices:Tensor, kernel_size:int|tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]|list[int]=0, output_size=None): + def max_unpool2d(self, indices:Tensor, kernel_size:int|tuple[int, ...]=(2,2), stride=None, dilation=1, + padding:int|tuple[int, ...]|list[int]=0, output_size=None): """ Performs a partial inverse of `max_pool2d` using the indices from the argmax.