diff --git a/tinygrad/realize.py b/tinygrad/realize.py index da68412ab2..63e63080ea 100644 --- a/tinygrad/realize.py +++ b/tinygrad/realize.py @@ -44,12 +44,11 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False): # *** zero op LoadOps *** def _realize_empty(buffer: LazyBuffer) -> None: - assert all_int(buffer.shape), "LoadOps do not support symbolic shape" if DEBUG >= 2: print(f"*** empty {buffer.device} shape {str(buffer.shape):23s} dtype {buffer.dtype}") # TODO: remove this and write the RNG in tinygrad def _realize_rand(buffer: LazyBuffer) -> None: - assert all_int(buffer.shape), "LoadOps do not support symbolic shape" + assert all_int(buffer.shape), "rand doesn't support symbolic shape" if DEBUG >= 2: print(f"*** rand {buffer.device} seed {buffer.op.arg:<10d} shape {str(buffer.shape):23s} dtype {buffer.dtype}") rng = np.random.default_rng(buffer.op.arg) buffer.realized._copyin(rng.random(size=prod(buffer.shape), dtype=np.float32).astype(dtype=buffer.dtype.np, copy=False), **buffer._device_extra_args()) @@ -59,19 +58,15 @@ def _realize_rand(buffer: LazyBuffer) -> None: from tinygrad.runtime.lib import RawBufferMapped, RawBufferTransfer from tinygrad.runtime.ops_disk import RawDiskBuffer def _realize_from(buffer: LazyBuffer, src: LazyBuffer) -> None: - assert all_int(buffer.shape), "LoadOps do not support symbolic shape" - assert src.realized.size == buffer.st.size(), f"size mismatch on FROM {src.realized.size} != {buffer.st.size()}" + assert src.realized.size == buffer.realized.size, f"size mismatch on FROM {src.realized.size=} != {buffer.realized.size=}" assert src.st.contiguous and buffer.st.contiguous, "all must be contiguous for from" if DEBUG >= 2: print(f"*** copy {buffer.device} <- {src.device} size {src.realized.size:<16d} shape {str(buffer.shape):23s} dtype {src.realized.dtype}") # TODO: make this generic - if isinstance(src.realized, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped): - buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args()) - src.realized.readinto(cast(RawBufferMapped, buffer.realized)._buffer()) - elif isinstance(src.realized, RawBufferTransfer) and issubclass(Device[buffer.device].buffer, RawBufferTransfer) and getenv("P2P", 0) >= 1: - buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args()) - cast(RawBufferTransfer, buffer.realized)._transfer(src.realized) + if isinstance(src.realized, RawDiskBuffer) and isinstance(buffer.realized, RawBufferMapped): + src.realized.readinto(buffer.realized._buffer()) + elif isinstance(src.realized, RawBufferTransfer) and isinstance(buffer.realized, RawBufferTransfer) and getenv("P2P", 0) >= 1: + buffer.realized._transfer(src.realized) else: - # TODO: schedule this as FROM to go to CPU, and a FROM to go to device buffer.realized._copyin(src.realized.toCPU()) # *** n op LoadOps ***