tinygrad (mirror of https://github.com/tinygrad/tinygrad.git)
Commit: typed
@@ -214,11 +214,11 @@ class ExecItem:
     et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
     if do_update_stats:
       GlobalCounters.kernel_count += 1
-      GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.estimates.ops, var_vals))
-      GlobalCounters.global_mem += (mem_est:=sym_infer(self.prg.estimates.mem, var_vals))
+      GlobalCounters.global_ops += (op_est:=int(sym_infer(self.prg.estimates.ops, var_vals)))
+      GlobalCounters.global_mem += (mem_est:=int(sym_infer(self.prg.estimates.mem, var_vals)))
       if et is not None: GlobalCounters.time_sum_s += et
       if DEBUG >= 2:
-        lds_est = sym_infer(self.prg.estimates.lds, var_vals)
+        lds_est = int(sym_infer(self.prg.estimates.lds, var_vals))
         mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
         header_color = 'magenta' if jit else ('green' if self.prg.first_run else None)
         ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
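Why the `int(...)` casts: `sym_infer` evaluates a possibly-symbolic estimate against concrete variable values, and the result is not guaranteed to be a plain Python `int`, while the `GlobalCounters` fields are ints (and, given the commit title "typed", the cast also pins the type for the checker). A minimal sketch of the failure mode, using sympy as a stand-in for tinygrad's own symbolic engine (illustrative only, not tinygrad code):

```python
# Sketch: evaluating a symbolic op-count estimate can yield a numeric type
# that is not builtins.int, so the counter update casts explicitly.
import sympy

def sym_infer_sketch(expr: sympy.Expr, var_vals: dict) -> sympy.Expr:
    # substitute concrete values into a symbolic estimate (stand-in for sym_infer)
    return expr.subs(var_vals)

N = sympy.Symbol("N")
ops_est = 2 * N**3                      # a matmul-style FLOP estimate in one shape variable
val = sym_infer_sketch(ops_est, {N: 64})
print(type(val))                        # <class 'sympy.core.numbers.Integer'>, not builtins.int
global_ops = 0
global_ops += int(val)                  # the int(...) cast keeps the counter a plain int
```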
@@ -1300,9 +1300,10 @@ class Tensor(OpMixin):
     print(t0.cat(t1, t2, dim=1).numpy())
     ```
     """
+    if isinstance(self, (tuple, list)): self, args = self[0], tuple(self[1:]) + args # type: ignore[arg-type]
     dim = self._resolve_dim(dim)
     for arg in args: assert arg.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i!=dim)
-    tensors = [self, *args]
+    tensors:list[Tensor] = [self, *args]
     dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0))
     for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)])
     return functools.reduce(Tensor.add, tensors)
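Note how `cat` never calls a concatenate primitive: each input is zero-padded out to the full output extent along `dim` (using the running offsets in `dim_cumsum`), and the padded tensors are then summed with `functools.reduce(Tensor.add, ...)`. A minimal NumPy sketch of the same pad-then-add trick (illustrative, not tinygrad code):

```python
import numpy as np

def cat_by_padding(tensors: list[np.ndarray], dim: int) -> np.ndarray:
    # running offsets along `dim`, the dim_cumsum of the diff
    offsets = [0]
    for t in tensors: offsets.append(offsets[-1] + t.shape[dim])
    out = []
    for i, t in enumerate(tensors):
        pads = [(0, 0)] * t.ndim
        # zero-pad so tensor i occupies only its own slice of the output
        pads[dim] = (offsets[i], offsets[-1] - offsets[i + 1])
        out.append(np.pad(t, pads))
    # elementwise sum of the padded tensors is the reduce(add) step
    return sum(out)

a, b = np.ones((2, 2)), 2 * np.ones((2, 3))
assert np.array_equal(cat_by_padding([a, b], dim=1), np.concatenate([a, b], axis=1))
```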
@@ -2202,7 +2203,8 @@ class Tensor(OpMixin):
     idx = m * idx.pad(pads, value=dtypes.min(idx.dtype))._pool(k_, stride if stride is not None else k_, dilation)
     return pooled.max(axis), spatial_sz - idx.max(axis)

-  def max_unpool2d(self, indices:Tensor, kernel_size:int|tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]|list[int]=0, output_size=None):
+  def max_unpool2d(self, indices:Tensor, kernel_size:int|tuple[int, ...]=(2,2), stride=None, dilation=1,
+                   padding:int|tuple[int, ...]|list[int]=0, output_size=None):
     """
     Performs a partial inverse of `max_pool2d` using the indices from the argmax.
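For context, a usage sketch of the reflowed signature. It assumes `max_pool2d` exposes a `return_indices` path, which is what the `return pooled.max(axis), spatial_sz - idx.max(axis)` context above appears to implement:

```python
from tinygrad import Tensor

x = Tensor.arange(16).reshape(1, 1, 4, 4).float()
# pool and keep the argmax positions (assumed return_indices flag)
pooled, indices = x.max_pool2d(kernel_size=(2, 2), return_indices=True)
# scatter each max back to the position recorded in `indices`;
# every non-max position comes back as zero
restored = pooled.max_unpool2d(indices, kernel_size=(2, 2))
print(restored.numpy())
```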