mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
minor ops/jit cleanups [pr] (#14543)
This commit is contained in:
@@ -310,7 +310,7 @@ class TinyJit(Generic[ReturnType]):
|
||||
assert self.fxn is not None
|
||||
with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
|
||||
ret = self.fxn(*args, **kwargs)
|
||||
if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
|
||||
if len(params:=get_parameters(ret)): Tensor.realize(*params)
|
||||
elif self.cnt == 1:
|
||||
# jit capture
|
||||
assert self.fxn is not None
|
||||
@@ -322,7 +322,7 @@ class TinyJit(Generic[ReturnType]):
|
||||
capturing.append(self)
|
||||
try:
|
||||
ret = self.fxn(*args, **kwargs)
|
||||
if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
|
||||
if len(params:=get_parameters(ret)): Tensor.realize(*params)
|
||||
finally: capturing.clear()
|
||||
jit_cache = self._jit_cache
|
||||
del self._buffer_replace, self._jit_cache
|
||||
|
||||
@@ -60,7 +60,7 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
|
||||
|
||||
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
|
||||
if len(arg) == 0: return UOp(Ops.VECTORIZE, dtypes.index.vec(0))
|
||||
elif all(isinstance(x, int) for x in arg): return UOp.const(dtypes.index.vec(len(arg)), cast(tuple[int, ...], arg))
|
||||
elif all_int(arg): return UOp.const(dtypes.index.vec(len(arg)), arg)
|
||||
else: return UOp(Ops.VECTORIZE, dtypes.index.vec(len(arg)), tuple(UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in arg))
|
||||
|
||||
def consumer_map_from_toposort(lst:Iterable[UOp]):
|
||||
@@ -636,7 +636,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
if self.op is Ops.MSTACK:
|
||||
ret = MultiBuffer.__new__(MultiBuffer)
|
||||
ret.bufs = [cast(Buffer, x.buffer) for x in self.src]
|
||||
assert all_same([x.size for x in ret.bufs]) and all_same([x.dtype for x in ret.bufs]), "multibuffers mismatch buffers"
|
||||
assert all_same([(x.size, x.dtype) for x in ret.bufs]), "multibuffers mismatch buffers"
|
||||
return ret
|
||||
assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
|
||||
assert self.src[0].op is Ops.UNIQUE, f"buffer src[0] must be UNIQUE, not {self.src[0].op}"
|
||||
|
||||
Reference in New Issue
Block a user