minor ops/jit cleanups [pr] (#14543)

This commit is contained in:
chenyu
2026-02-04 17:21:34 -05:00
committed by GitHub
parent 03d0fa9c3f
commit 664f1bf76d
2 changed files with 4 additions and 4 deletions

View File

@@ -310,7 +310,7 @@ class TinyJit(Generic[ReturnType]):
assert self.fxn is not None
with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
ret = self.fxn(*args, **kwargs)
if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
if len(params:=get_parameters(ret)): Tensor.realize(*params)
elif self.cnt == 1:
# jit capture
assert self.fxn is not None
@@ -322,7 +322,7 @@ class TinyJit(Generic[ReturnType]):
capturing.append(self)
try:
ret = self.fxn(*args, **kwargs)
if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
if len(params:=get_parameters(ret)): Tensor.realize(*params)
finally: capturing.clear()
jit_cache = self._jit_cache
del self._buffer_replace, self._jit_cache

View File

@@ -60,7 +60,7 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
if len(arg) == 0: return UOp(Ops.VECTORIZE, dtypes.index.vec(0))
elif all(isinstance(x, int) for x in arg): return UOp.const(dtypes.index.vec(len(arg)), cast(tuple[int, ...], arg))
elif all_int(arg): return UOp.const(dtypes.index.vec(len(arg)), arg)
else: return UOp(Ops.VECTORIZE, dtypes.index.vec(len(arg)), tuple(UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in arg))
def consumer_map_from_toposort(lst:Iterable[UOp]):
@@ -636,7 +636,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.op is Ops.MSTACK:
ret = MultiBuffer.__new__(MultiBuffer)
ret.bufs = [cast(Buffer, x.buffer) for x in self.src]
assert all_same([x.size for x in ret.bufs]) and all_same([x.dtype for x in ret.bufs]), "multibuffers mismatch buffers"
assert all_same([(x.size, x.dtype) for x in ret.bufs]), "multibuffers mismatch buffers"
return ret
assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
assert self.src[0].op is Ops.UNIQUE, f"buffer src[0] must be UNIQUE, not {self.src[0].op}"