mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
cleanup realize (#2505)
* delete reallocs * cleaner * that's real * less lines
This commit is contained in:
@@ -44,12 +44,11 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
|
||||
# *** zero op LoadOps ***
|
||||
|
||||
def _realize_empty(buffer: LazyBuffer) -> None:
|
||||
assert all_int(buffer.shape), "LoadOps do not support symbolic shape"
|
||||
if DEBUG >= 2: print(f"*** empty {buffer.device} shape {str(buffer.shape):23s} dtype {buffer.dtype}")
|
||||
|
||||
# TODO: remove this and write the RNG in tinygrad
|
||||
def _realize_rand(buffer: LazyBuffer) -> None:
|
||||
assert all_int(buffer.shape), "LoadOps do not support symbolic shape"
|
||||
assert all_int(buffer.shape), "rand doesn't support symbolic shape"
|
||||
if DEBUG >= 2: print(f"*** rand {buffer.device} seed {buffer.op.arg:<10d} shape {str(buffer.shape):23s} dtype {buffer.dtype}")
|
||||
rng = np.random.default_rng(buffer.op.arg)
|
||||
buffer.realized._copyin(rng.random(size=prod(buffer.shape), dtype=np.float32).astype(dtype=buffer.dtype.np, copy=False), **buffer._device_extra_args())
|
||||
@@ -59,19 +58,15 @@ def _realize_rand(buffer: LazyBuffer) -> None:
|
||||
from tinygrad.runtime.lib import RawBufferMapped, RawBufferTransfer
|
||||
from tinygrad.runtime.ops_disk import RawDiskBuffer
|
||||
def _realize_from(buffer: LazyBuffer, src: LazyBuffer) -> None:
|
||||
assert all_int(buffer.shape), "LoadOps do not support symbolic shape"
|
||||
assert src.realized.size == buffer.st.size(), f"size mismatch on FROM {src.realized.size} != {buffer.st.size()}"
|
||||
assert src.realized.size == buffer.realized.size, f"size mismatch on FROM {src.realized.size=} != {buffer.realized.size=}"
|
||||
assert src.st.contiguous and buffer.st.contiguous, "all must be contiguous for from"
|
||||
if DEBUG >= 2: print(f"*** copy {buffer.device} <- {src.device} size {src.realized.size:<16d} shape {str(buffer.shape):23s} dtype {src.realized.dtype}")
|
||||
# TODO: make this generic
|
||||
if isinstance(src.realized, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped):
|
||||
buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
|
||||
src.realized.readinto(cast(RawBufferMapped, buffer.realized)._buffer())
|
||||
elif isinstance(src.realized, RawBufferTransfer) and issubclass(Device[buffer.device].buffer, RawBufferTransfer) and getenv("P2P", 0) >= 1:
|
||||
buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
|
||||
cast(RawBufferTransfer, buffer.realized)._transfer(src.realized)
|
||||
if isinstance(src.realized, RawDiskBuffer) and isinstance(buffer.realized, RawBufferMapped):
|
||||
src.realized.readinto(buffer.realized._buffer())
|
||||
elif isinstance(src.realized, RawBufferTransfer) and isinstance(buffer.realized, RawBufferTransfer) and getenv("P2P", 0) >= 1:
|
||||
buffer.realized._transfer(src.realized)
|
||||
else:
|
||||
# TODO: schedule this as FROM to go to CPU, and a FROM to go to device
|
||||
buffer.realized._copyin(src.realized.toCPU())
|
||||
|
||||
# *** n op LoadOps ***
|
||||
|
||||
Reference in New Issue
Block a user