Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
hip multigpu training (#1878)
* feat: move to hip
* feat: special path for RawBufferTransfer
* feat: initial rawbuffertransfer
* feat: hip ipc
* feat: working hip ipc
* feat: need to base device without args
* feat: close mem handle
* feat: modified test
* feat: more multihip stuff
* clean: cleanup
* feat: cleaner
* feat: don't crash
* feat: test more
* clean: way cleaner hip wrapper
* feat: barrier
* feat: barrier
* feat: this breaks stuff
* feat: we can use empty here
* feat: maybe fix tests
* feat: maybe fix tests again?
* fix: probably fix tests
* feat: no waiting here
* feat: wait here
* feat: much larger test
* feat: need to sync here
* feat: make this async
* feat: no waiting!
* feat: cut here
* feat: sync copy
* feat: random imports
* feat: much cleaner world
* feat: restore this
* feat: restore this
* clean: cleanup
* feat: set this
@@ -426,8 +426,12 @@ if __name__ == "__main__":
   if not getenv("DIST"):
     train_cifar()
   else: # distributed
-    from tinygrad.runtime.ops_gpu import CL
-    devices = [f"gpu:{i}" for i in range(len(CL.devices))]
+    if getenv("HIP"):
+      from tinygrad.runtime.ops_hip import HIP
+      devices = [f"hip:{i}" for i in range(HIP.device_count)]
+    else:
+      from tinygrad.runtime.ops_gpu import CL
+      devices = [f"gpu:{i}" for i in range(len(CL.devices))]
     world_size = len(devices)

     # ensure that the batch size is divisible by the number of devices
extra/dist/__init__.py (vendored, 5 lines changed)
@@ -46,12 +46,11 @@ def _process_wrap(rank:int, device:str, oob:_OOB, fn:Callable, args=()):
     from tinygrad.runtime.ops_gpu import CL
     CL.post_init(device_num)
   elif "HIP" in device:
-    import extra.hip_wrapper as hip
-    hip.hipSetDevice(device_num)
+    os.environ["HIP_DEFAULT_DEVICE"] = os.environ["HIP_VISIBLE_DEVICES"] = str(device_num)
   if DEBUG >= 1: print(f"distributed process {rank} initialized runtime for device {device}")

   # convert device to be process specific
-  Device.DEFAULT = device.split(":")[0]
+  Device.DEFAULT = device.split(":")[0] if "GPU" in device else device

   fn(*args)
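The HIP branch above now pins each worker to its GPU through environment variables instead of an eager hipSetDevice, so whatever HIP state the child process creates later is already scoped to the right device. A minimal sketch of that per-process pinning pattern, assuming spawn-style workers; the worker body is illustrative, not the tinygrad _process_wrap code:

import os
from multiprocessing import Process

def worker(rank: int, device: str):
  # device looks like "hip:1"; export the index before any HIP runtime gets initialized
  device_num = int(device.split(":")[1]) if ":" in device else 0
  os.environ["HIP_DEFAULT_DEVICE"] = os.environ["HIP_VISIBLE_DEVICES"] = str(device_num)
  # ... import the runtime and run fn(*args) here, as _process_wrap does ...
  print(f"rank {rank} pinned to HIP device {device_num}")

if __name__ == "__main__":
  procs = [Process(target=worker, args=(i, f"hip:{i}")) for i in range(2)]
  for p in procs: p.start()
  for p in procs: p.join()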
extra/dist/collectives.py (vendored, 4 lines changed)
@@ -1,4 +1,4 @@
-from tinygrad.tensor import Tensor, Device
+from tinygrad.tensor import Tensor
 from tinygrad.helpers import getenv

 from extra.dist import world
@@ -12,7 +12,7 @@ def allreduce(t:Tensor, cache_id=None) -> Tensor:

   # pad to evenly divide
   if flattened.shape[0] % WORLD_SIZE != 0:
-    flattened = Tensor.cat(flattened, Tensor.zeros(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE)))
+    flattened = Tensor.cat(flattened, Tensor.empty(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE)))

   # chunk
   chunks = flattened.chunk(WORLD_SIZE, dim=0)
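Tensor.empty works here because the pad region only exists to make the flattened tensor divisible by WORLD_SIZE; its contents are presumably discarded again after the reduce, so there is no need to spend a kernel zero-filling it. A small sketch of the same padding arithmetic (the sizes are illustrative, not from the allreduce itself):

from tinygrad.tensor import Tensor

WORLD_SIZE = 4
flattened = Tensor.ones(10)  # length not divisible by WORLD_SIZE

# same arithmetic as allreduce: pad with 4 - (10 % 4) == 2 filler elements
if flattened.shape[0] % WORLD_SIZE != 0:
  flattened = Tensor.cat(flattened, Tensor.empty(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE)))

chunks = flattened.chunk(WORLD_SIZE, dim=0)
print([c.shape for c in chunks])  # four equal chunks of 3 elements each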
extra/dist/world.py (vendored, 118 lines changed)
@@ -1,59 +1,101 @@
-from typing import Any, Optional, Tuple
+from typing import Optional
 from extra import dist
 from multiprocessing import shared_memory
-from tinygrad.helpers import DEBUG, colored
+from tinygrad.helpers import DEBUG, colored, getenv
 from tinygrad.lazy import LazyBuffer
-from tinygrad.runtime.lib import RawBufferCopyIn, RawBufferCopyInOut
+from tinygrad.runtime.lib import RawBuffer, RawBufferCopyIn, RawBufferCopyInOut
+try: from tinygrad.runtime.ops_hip import RawHIPBuffer
+except: RawHIPBuffer = None
 from tinygrad.runtime.ops_shm import RawShmBuffer
 from tinygrad.jit import CacheCollector
 from tinygrad.tensor import Tensor, Function
+import extra.hip_wrapper as hip
 import numpy as np

 # fake the function signature of ASTRunner so we can put it in the cache
-def __send_rb(args:Tuple[RawBufferCopyInOut, RawShmBuffer, int, Any], variables=None, jit=False, force_wait=False):
-  args[0]._copyout(np.frombuffer(args[1]._buffer(), dtype=args[0].dtype.np))
-  dist.OOB.send(args[3], args[2])
-  if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} sent {args[0]} to rank {args[2]}")
+def __send_rb(args, variables=None, jit=False, force_wait=False):
+  x, target_rank, y = args[:3]
+  if RawHIPBuffer and x.__class__ is RawHIPBuffer:
+    hip.hipSetDevice(x._device)
+    hip.hipDeviceSynchronize()
+  else:
+    if isinstance(x, RawBufferCopyInOut): x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
+    else: y.fromCPU(x.toCPU())
+  dist.OOB.send(None, target_rank)
+  if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} sent {x} to rank {target_rank}")

-def __recv_rb(args:Tuple[RawBufferCopyIn, RawShmBuffer, int], variables=None, jit=False, force_wait=False):
-  dist.OOB.recv(args[2])
-  args[0]._copyin(args[1].toCPU())
-  if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} recv {args[0]} from rank {args[2]}")
+def __recv_rb(args, variables=None, jit=False, force_wait=False):
+  x, target_rank, y = args[:3]
+  dist.OOB.recv(target_rank)
+  if RawHIPBuffer and x.__class__ is RawHIPBuffer:
+    x._transfer(y)
+  elif isinstance(x, RawBufferCopyIn): x._copyin(y.toCPU())
+  else: x.fromCPU(y.toCPU())
+  if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} recv {x} from rank {target_rank}")

 # send a rawbuffer from out rank to the target rank
-def _send_rb(x:RawBufferCopyInOut, target_rank:int, cache_id:Optional[str]=None):
-  assert isinstance(x, RawBufferCopyInOut), "we only support RawBufferCopyInOut for now"
-  # cache the shared memory so we don't have to create it every time
-  if cache_id is not None and cache_id in _send_rb.shared_memory_cache:
-    shm_name = _send_rb.shared_memory_cache[cache_id]
-  else:
-    shm_name = (s := shared_memory.SharedMemory(create=True, size=x.size * x.dtype.itemsize)).name
-    s.close()
-    if cache_id is not None: _send_rb.shared_memory_cache[cache_id] = shm_name
-  # copy the buffer into shared memory
-  device = f"{shm_name},{cache_id}" if cache_id is not None else shm_name
-  rb = RawShmBuffer(x.size, x.dtype, device=device)
-  __send_rb((x, rb, target_rank, (shm_name, cache_id)))
-
-  # jit support
-  CacheCollector.add(__send_rb, [x, rb, target_rank, None], {})
+def _send_rb(x:RawBuffer, target_rank:int, cache_id:Optional[str]=None):
+  if RawHIPBuffer and x.__class__ is RawHIPBuffer:
+    # send ipc handle
+    hip.hipSetDevice(x._device)
+    hip.hipDeviceSynchronize()
+    handle = hip.hipIpcGetMemHandle(x._buf)
+    dist.OOB.send((handle, x._device), target_rank)
+
+    # jit support
+    x._allocator = None # need to disconnect allocator for sent buffers
+    CacheCollector.add(__send_rb, [x, target_rank, None], {})
+  else:
+    # cache the shared memory so we don't have to create it every time
+    if cache_id is not None and cache_id in _send_rb.shared_memory_cache:
+      shm_name = _send_rb.shared_memory_cache[cache_id]
+    else:
+      shm_name = (s := shared_memory.SharedMemory(create=True, size=x.size * x.dtype.itemsize)).name
+      s.close()
+      if cache_id is not None: _send_rb.shared_memory_cache[cache_id] = shm_name
+    # copy the buffer into shared memory
+    device = f"{shm_name},{cache_id}" if cache_id is not None else shm_name
+    y = RawShmBuffer(x.size, x.dtype, device=device)
+
+    # fast path when we can directly copyout
+    if isinstance(x, RawBufferCopyInOut): x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
+    else: y.fromCPU(x.toCPU())
+    dist.OOB.send((shm_name, cache_id), target_rank)
+
+    # jit support
+    CacheCollector.add(__send_rb, [x, target_rank, y], {})
 setattr(_send_rb, "shared_memory_cache", {})

 # receive a rawbuffer from the target rank
-def _recv_rb(x:RawBufferCopyIn, target_rank:int):
-  assert isinstance(x, RawBufferCopyIn), "we only support RawBufferCopyIn for now"
-  extra = dist.OOB.recv(target_rank)
-  device = f"{extra[0]},{extra[1]}" if extra[1] is not None else f"{extra[0]}"
-  rb = RawShmBuffer(x.size, x.dtype, device=device)
-  x._copyin(rb.toCPU())
-  if DEBUG >= 2: print(f"**** got {x} from rank {target_rank}")
-
-  if extra[1] is None:
-    (s := shared_memory.SharedMemory(name=extra[0])).close()
-    s.unlink()
-
-  # jit support
-  CacheCollector.add(__recv_rb, [x, rb, target_rank], {})
+def _recv_rb(x:RawBuffer, target_rank:int):
+  if RawHIPBuffer and isinstance(x, RawHIPBuffer):
+    # open ipc handle
+    handle, y_device = dist.OOB.recv(target_rank)
+    hip.hipSetDevice(y_device)
+    ptr = hip.hipIpcOpenMemHandle(handle, 0)
+
+    # build a new buffer
+    y = RawHIPBuffer(x.size, x.dtype, device=str(y_device), buf=ptr, allocator=None)
+    x._transfer(y)
+    if DEBUG >= 2: print(f"**** rank {getenv('RANK')} got {x} from rank {target_rank}")
+    CacheCollector.add(__recv_rb, [x, target_rank, y], {})
+  else:
+    extra = dist.OOB.recv(target_rank)
+    device = f"{extra[0]},{extra[1]}" if extra[1] is not None else f"{extra[0]}"
+    y = RawShmBuffer(x.size, x.dtype, device=device)
+
+    # fast path when we can directly copyin
+    if isinstance(x, RawBufferCopyIn): x._copyin(y.toCPU())
+    else: x.fromCPU(y.toCPU())
+    if DEBUG >= 2: print(f"**** rank {getenv('RANK')} got {x} from rank {target_rank}")
+
+    if extra[1] is None:
+      (s := shared_memory.SharedMemory(name=extra[0])).close()
+      s.unlink()
+
+    # jit support
+    CacheCollector.add(__recv_rb, [x, target_rank, y], {})

 # sends a lazybuffer from our rank to the target rank
 def _send_lb(x:LazyBuffer, target_rank:int, cache_id:Optional[str]=None) -> None:
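For orientation, the new HIP fast path amounts to exporting an IPC handle for the sender's device allocation, shipping that handle over the out-of-band channel, and mapping it on the receiver for a device-to-device copy. Below is a standalone sketch of that flow using only the extra.hip_wrapper calls that appear in this diff, with a multiprocessing.Pipe standing in for dist.OOB; the buffer size, device numbers, and the assumption that hipMalloc returns a plain device pointer are illustrative:

import multiprocessing as mp
import extra.hip_wrapper as hip

N_BYTES = 1024

def sender(conn):
  hip.hipSetDevice(0)
  buf = hip.hipMalloc(N_BYTES)           # device allocation to share (assumed to return a device pointer)
  hip.hipDeviceSynchronize()             # make sure any pending writes have landed
  handle = hip.hipIpcGetMemHandle(buf)   # export an IPC handle for the allocation
  conn.send((handle, 0))                 # out-of-band channel, like dist.OOB.send

def receiver(conn):
  handle, src_device = conn.recv()       # like dist.OOB.recv
  hip.hipSetDevice(src_device)           # open the handle on the owning device
  src_ptr = hip.hipIpcOpenMemHandle(handle, 0)
  hip.hipSetDevice(1)
  dst = hip.hipMalloc(N_BYTES)           # local buffer on this rank's device
  hip.hipSetDevice(src_device)           # _transfer in this diff copies with the source device current
  hip.hipMemcpy(dst, src_ptr, N_BYTES, hip.hipMemcpyDeviceToDevice)

if __name__ == "__main__":
  a, b = mp.Pipe()
  ps = [mp.Process(target=sender, args=(a,)), mp.Process(target=receiver, args=(b,))]
  for p in ps: p.start()
  for p in ps: p.join()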
@@ -244,7 +244,6 @@ try:
     status = _libhip.hipGraphExecKernelNodeSetParams(gexec, node, ctypes.byref(params.c_struct))
     hipCheckStatus(status)

-
   _libhip.hipMalloc.restype = int
   _libhip.hipMalloc.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]
   def hipMalloc(count):
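For context, the wrapper follows the usual ctypes pattern: declare restype and argtypes on the shared-library symbol, then expose a small Python helper that checks the returned status. A self-contained sketch of that pattern for hipMalloc; the library name, the status checker, and the helper body are assumptions, not copied from extra/hip_wrapper.py:

import ctypes

_libhip = ctypes.cdll.LoadLibrary("libamdhip64.so")  # assumed HIP runtime library name

def hipCheckStatus(status):
  # hypothetical minimal checker; the real wrapper maps status codes to messages
  if status != 0: raise RuntimeError(f"HIP error {status}")

_libhip.hipMalloc.restype = int
_libhip.hipMalloc.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]

def hipMalloc(count):
  # hipError_t hipMalloc(void** ptr, size_t size): allocate device memory, return the pointer value
  ptr = ctypes.c_void_p()
  hipCheckStatus(_libhip.hipMalloc(ctypes.byref(ptr), count))
  return ptr.value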
test/external/dist/test_collectives.py (vendored, 9 lines changed)
@@ -24,7 +24,7 @@ def run():
   # loop 3 times to make sure it works with the jit
   for _ in range(3):
     # create a tensor to send
-    t = Tensor.zeros(SIZE, SIZE) if rank == 0 else Tensor.ones(SIZE, SIZE)
+    t = Tensor.zeros(SIZE, SIZE) if rank != 0 else Tensor.ones(SIZE, SIZE)
     t2 = allreduce_jit(t.contiguous().realize(), cache_id="test")
     assert np.allclose(np.ones((SIZE, SIZE)), t2.numpy()), f"{t2.numpy()} wasn't ones"

@@ -42,7 +42,12 @@ def run():
   print(f"rank {rank} passed")

 if __name__ == "__main__":
-  devices = ["gpu:0", "gpu:1" if not CI else "gpu:0"]
+  if getenv("HIP"):
+    from tinygrad.runtime.ops_hip import HIP
+    devices = [f"hip:{i}" for i in range(HIP.device_count)]
+  else:
+    from tinygrad.runtime.ops_gpu import CL
+    devices = [f"gpu:{i}" for i in range(len(CL.devices))] if not CI else ["gpu:0", "gpu:0"]
   world_size = len(devices)

   dist.init_oob(world_size)
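The flipped condition simply makes rank 0 the rank that contributes ones while every other rank contributes zeros, so the summed result is still all ones for any world size. A quick plain-numpy check of the invariant the assert encodes (the world size here and the sum-reduce semantics are assumptions about the collective, not tinygrad code):

import numpy as np

SIZE, WORLD_SIZE = 4, 3
tensors = [np.ones((SIZE, SIZE)) if rank == 0 else np.zeros((SIZE, SIZE)) for rank in range(WORLD_SIZE)]
allreduced = sum(tensors)  # a sum-allreduce gives every rank the same total
assert np.allclose(np.ones((SIZE, SIZE)), allreduced)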
test/external/dist/test_world.py (vendored, 7 lines changed)
@@ -45,12 +45,15 @@ def run():
   # check that the received tensor is the same as the sent tensor
   if rank == 0:
-    assert np.allclose(t.numpy(), t2.numpy())
+    assert np.allclose(t.numpy(), t2.numpy()), f"{t2.numpy()} wasn't equal to {t.numpy()}"

   print(f"rank {rank} passed")

 if __name__ == "__main__":
-  devices = ["gpu:0", "gpu:1" if not CI else "gpu:0"]
+  if getenv("HIP"):
+    devices = ["hip:0", "hip:1"]
+  else:
+    devices = ["gpu:0", "gpu:1" if not CI else "gpu:0"]
   world_size = len(devices)

   dist.init_oob(world_size)
@@ -24,7 +24,7 @@ class TinyJit:
   def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)

   def __call__(self, *args, **kwargs) -> Any:
-    if Device.DEFAULT not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
+    if Device.DEFAULT.split(":")[0] not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
     # NOTE: this cast is needed since although we know realize will create a ".realized" RawBuffer, the type checker doesn't
     input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {cast(Union[int, str], k):(cast(RawBuffer, v.realize().lazydata.realized), v.lazydata.st) for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
     assert len(input_rawbuffers) != 0, "no inputs to JIT"
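Because Device.DEFAULT can now carry a device index (for example "HIP:1" after the extra/dist/__init__.py change above), the JIT guard compares only the backend prefix. A tiny illustration of the check; the JIT_SUPPORTED_DEVICE list here is an illustrative subset, not the real list from tinygrad.jit:

JIT_SUPPORTED_DEVICE = ["GPU", "HIP"]  # illustrative subset

for default in ["GPU", "HIP", "HIP:1"]:
  ok_old = default in JIT_SUPPORTED_DEVICE                 # old check: "HIP:1" would fall through to the non-jit path
  ok_new = default.split(":")[0] in JIT_SUPPORTED_DEVICE   # new check: only the backend prefix matters
  print(f"{default:6s} old={ok_old} new={ok_new}")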
@@ -23,9 +23,10 @@ class HIPAllocator(LRUAllocator):
   def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype.

 class _HIP:
-  def __init__(self):
+  def __init__(self, device=None):
+    self.default_device = device or getenv("HIP_DEFAULT_DEVICE")
+    hip.hipSetDevice(self.default_device)
     self.device_count = hip.hipGetDeviceCount()
-    self.default_device = getenv("HIP_DEFAULT_DEVICE")
     self.allocator = HIPAllocator(hip.hipGetDeviceProperties(self.default_device).totalGlobalMem)
 HIP = _HIP()

@@ -66,12 +67,16 @@ class HIPGraph(GraphBatchExecutor):
   def exec_instance(self, instid): hip.hipGraphLaunch(self.graphs[instid][0])

 class RawHIPBuffer(RawBufferCopyInOut, RawBufferTransfer):
-  def __init__(self, size, dtype, device=str(HIP.default_device)): super().__init__(size, dtype, allocator=HIP.allocator, **{'device': int(device)})
+  def __init__(self, size, dtype, device=HIP.default_device, buf=None, allocator=HIP.allocator): super().__init__(size, dtype, buf=buf, allocator=allocator, **{'device': int(device)})
   def _copyin(self, x:np.ndarray):
-    x = np.require(x, requirements='C')
-    hip.hipMemcpyAsync(self._buf, x.ctypes.data, self.size * self.dtype.itemsize, hip.hipMemcpyHostToDevice, 0)
-  def _copyout(self, x:np.ndarray): hip.hipMemcpy(x.ctypes.data, self._buf, self.size * self.dtype.itemsize, hip.hipMemcpyDeviceToHost)
-  def _transfer(self, x): hip.hipMemcpyAsync(self._buf, x._buf, self.size * self.dtype.itemsize, hip.hipMemcpyDeviceToDevice, 0)
+    hip.hipSetDevice(self._device)
+    hip.hipMemcpyAsync(self._buf, np.require(x, requirements='C').ctypes.data, self.size * self.dtype.itemsize, hip.hipMemcpyHostToDevice, 0)
+  def _copyout(self, x:np.ndarray):
+    hip.hipSetDevice(self._device)
+    hip.hipMemcpy(x.ctypes.data, self._buf, self.size * self.dtype.itemsize, hip.hipMemcpyDeviceToHost)
+  def _transfer(self, x):
+    hip.hipSetDevice(x._device)
+    hip.hipMemcpy(self._buf, x._buf, self.size * self.dtype.itemsize, hip.hipMemcpyDeviceToDevice)

 class HIPProgram:
   def __init__(self, name:str, prg:str, binary=False):
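With buffers from several GPUs living in one process, every RawHIPBuffer copy now selects its device first. A rough usage sketch under that assumption; it needs a working multi-GPU HIP install, and the sizes, dtype, and device strings are illustrative:

import numpy as np
from tinygrad.helpers import dtypes
from tinygrad.runtime.ops_hip import RawHIPBuffer

# two buffers on two different GPUs inside the same process
a = RawHIPBuffer(16, dtypes.float32, device="0")
b = RawHIPBuffer(16, dtypes.float32, device="1")

host = np.arange(16, dtype=np.float32)
a._copyin(host)   # hipSetDevice(0) happens inside before the H2D copy
b._transfer(a)    # hipSetDevice on the source device before the D2D copy
print(b.toCPU())  # hipSetDevice(1) inside _copyout before the D2H copy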