mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
manually handle OSX
This commit is contained in:
@@ -5,7 +5,7 @@ from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tinygrad import dtypes, Tensor
|
from tinygrad import dtypes, Tensor
|
||||||
from tinygrad.helpers import getenv, prod, Context, round_up, tqdm
|
from tinygrad.helpers import getenv, prod, Context, round_up, tqdm, OSX
|
||||||
|
|
||||||
### ResNet
|
### ResNet
|
||||||
|
|
||||||
@@ -129,14 +129,15 @@ def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None, pad_fir
|
|||||||
q_in, q_out = Queue(), Queue()
|
q_in, q_out = Queue(), Queue()
|
||||||
|
|
||||||
sz = (batch_size*BATCH_COUNT, 224, 224, 3)
|
sz = (batch_size*BATCH_COUNT, 224, 224, 3)
|
||||||
if os.path.exists("/dev/shm/resnet_X"): os.unlink("/dev/shm/resnet_X")
|
shm_name = "resnet_X_val" if val else "resnet_X_train",
|
||||||
shm = shared_memory.SharedMemory(name="resnet_X", create=True, size=prod(sz))
|
if not OSX and os.path.exists(f"/dev/shm/{shm_name}"): os.unlink(f"/dev/shm/{shm_name}")
|
||||||
|
shm = shared_memory.SharedMemory(name=shm_name, create=True, size=prod(sz))
|
||||||
procs = []
|
procs = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# disk:shm is slower
|
# disk:shm is slower
|
||||||
#X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:shm:{shm.name}")
|
if OSX: X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:shm:{shm.name}")
|
||||||
X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:/dev/shm/resnet_X")
|
else: X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:/dev/shm/{shm_name}")
|
||||||
Y = [None] * (batch_size*BATCH_COUNT)
|
Y = [None] * (batch_size*BATCH_COUNT)
|
||||||
|
|
||||||
for _ in range(cpu_count()):
|
for _ in range(cpu_count()):
|
||||||
@@ -312,7 +313,7 @@ def batch_load_unet3d(preprocessed_dataset_dir:Path, batch_size:int=6, val:bool=
|
|||||||
proc = Process(target=load_unet3d_data, args=(preprocessed_dataset_dir, seed, queue_in, queue_out, X, Y))
|
proc = Process(target=load_unet3d_data, args=(preprocessed_dataset_dir, seed, queue_in, queue_out, X, Y))
|
||||||
proc.daemon = True
|
proc.daemon = True
|
||||||
proc.start()
|
proc.start()
|
||||||
|
|
||||||
procs.append(proc)
|
procs.append(proc)
|
||||||
|
|
||||||
for bc in range(batch_count):
|
for bc in range(batch_count):
|
||||||
|
|||||||
Reference in New Issue
Block a user