From 664f1bf76d88dd2a5f321c2bbd41302187b88492 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 4 Feb 2026 17:21:34 -0500 Subject: [PATCH] minor ops/jit cleanups [pr] (#14543) --- tinygrad/engine/jit.py | 4 ++-- tinygrad/uop/ops.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 5c93cb7cf8..79fe034d39 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -310,7 +310,7 @@ class TinyJit(Generic[ReturnType]): assert self.fxn is not None with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value): ret = self.fxn(*args, **kwargs) - if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:]) + if len(params:=get_parameters(ret)): Tensor.realize(*params) elif self.cnt == 1: # jit capture assert self.fxn is not None @@ -322,7 +322,7 @@ class TinyJit(Generic[ReturnType]): capturing.append(self) try: ret = self.fxn(*args, **kwargs) - if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:]) + if len(params:=get_parameters(ret)): Tensor.realize(*params) finally: capturing.clear() jit_cache = self._jit_cache del self._buffer_replace, self._jit_cache diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index b5e86bc860..b857a02502 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -60,7 +60,7 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str: def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp: if len(arg) == 0: return UOp(Ops.VECTORIZE, dtypes.index.vec(0)) - elif all(isinstance(x, int) for x in arg): return UOp.const(dtypes.index.vec(len(arg)), cast(tuple[int, ...], arg)) + elif all_int(arg): return UOp.const(dtypes.index.vec(len(arg)), arg) else: return UOp(Ops.VECTORIZE, dtypes.index.vec(len(arg)), tuple(UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in arg)) def consumer_map_from_toposort(lst:Iterable[UOp]): @@ -636,7 +636,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): if self.op is Ops.MSTACK: ret = MultiBuffer.__new__(MultiBuffer) ret.bufs = [cast(Buffer, x.buffer) for x in self.src] - assert all_same([x.size for x in ret.bufs]) and all_same([x.dtype for x in ret.bufs]), "multibuffers mismatch buffers" + assert all_same([(x.size, x.dtype) for x in ret.bufs]), "multibuffers mismatch buffers" return ret assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}" assert self.src[0].op is Ops.UNIQUE, f"buffer src[0] must be UNIQUE, not {self.src[0].op}"