This commit is contained in:
George Hotz
2025-12-22 18:40:53 +00:00
parent 703ab8c63e
commit 9e0a42ec0e
2 changed files with 7 additions and 5 deletions

View File

@@ -214,11 +214,11 @@ class ExecItem:
et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
if do_update_stats:
GlobalCounters.kernel_count += 1
GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.estimates.ops, var_vals))
GlobalCounters.global_mem += (mem_est:=sym_infer(self.prg.estimates.mem, var_vals))
GlobalCounters.global_ops += (op_est:=int(sym_infer(self.prg.estimates.ops, var_vals)))
GlobalCounters.global_mem += (mem_est:=int(sym_infer(self.prg.estimates.mem, var_vals)))
if et is not None: GlobalCounters.time_sum_s += et
if DEBUG >= 2:
lds_est = sym_infer(self.prg.estimates.lds, var_vals)
lds_est = int(sym_infer(self.prg.estimates.lds, var_vals))
mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
header_color = 'magenta' if jit else ('green' if self.prg.first_run else None)
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""

View File

@@ -1300,9 +1300,10 @@ class Tensor(OpMixin):
print(t0.cat(t1, t2, dim=1).numpy())
```
"""
if isinstance(self, (tuple, list)): self, args = self[0], tuple(self[1:]) + args # type: ignore[arg-type]
dim = self._resolve_dim(dim)
for arg in args: assert arg.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i!=dim)
tensors = [self, *args]
tensors:list[Tensor] = [self, *args]
dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0))
for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)])
return functools.reduce(Tensor.add, tensors)
@@ -2202,7 +2203,8 @@ class Tensor(OpMixin):
idx = m * idx.pad(pads, value=dtypes.min(idx.dtype))._pool(k_, stride if stride is not None else k_, dilation)
return pooled.max(axis), spatial_sz - idx.max(axis)
def max_unpool2d(self, indices:Tensor, kernel_size:int|tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]|list[int]=0, output_size=None):
def max_unpool2d(self, indices:Tensor, kernel_size:int|tuple[int, ...]=(2,2), stride=None, dilation=1,
padding:int|tuple[int, ...]|list[int]=0, output_size=None):
"""
Performs a partial inverse of `max_pool2d` using the indices from the argmax.