hip multigpu training (#1878)

* feat: move to hip

* feat: special path for RawBufferTransfer

* feat: initial rawbuffertransfer

* feat: hip ipc

* feat: working hip ipc

* feat: need to base device without args

* feat: close mem handle

* feat: modified test

* feat: more multihip stuff

* clean: cleanup

* feat: cleaner

* feat: don't crash

* feat: test more

* clean: way cleaner hip wrapper

* feat: barrier

* feat: barrier

* feat: this breaks stuff

* feat: we can use empty here

* feat: maybe fix tests

* feat: maybe fix tests again?

* fix: probably fix tests

* feat: no waiting here

* feat: wait here

* feat: much larger test

* feat: need to sync here

* feat: make this async

* feat: no waiting!

* feat: cut here

* feat: sync copy

* feat: random imports

* feat: much cleaner world

* feat: restore this

* feat: restore this

* clean: cleanup

* feat: set this
This commit is contained in:
wozeparrot
2023-10-24 17:35:53 -04:00
committed by GitHub
parent 2e89fd264f
commit c29653605e
9 changed files with 115 additions and 58 deletions

View File

@@ -426,8 +426,12 @@ if __name__ == "__main__":
if not getenv("DIST"):
train_cifar()
else: # distributed
from tinygrad.runtime.ops_gpu import CL
devices = [f"gpu:{i}" for i in range(len(CL.devices))]
if getenv("HIP"):
from tinygrad.runtime.ops_hip import HIP
devices = [f"hip:{i}" for i in range(HIP.device_count)]
else:
from tinygrad.runtime.ops_gpu import CL
devices = [f"gpu:{i}" for i in range(len(CL.devices))]
world_size = len(devices)
# ensure that the batch size is divisible by the number of devices

View File

@@ -46,12 +46,11 @@ def _process_wrap(rank:int, device:str, oob:_OOB, fn:Callable, args=()):
from tinygrad.runtime.ops_gpu import CL
CL.post_init(device_num)
elif "HIP" in device:
import extra.hip_wrapper as hip
hip.hipSetDevice(device_num)
os.environ["HIP_DEFAULT_DEVICE"] = os.environ["HIP_VISIBLE_DEVICES"] = str(device_num)
if DEBUG >= 1: print(f"distributed process {rank} initialized runtime for device {device}")
# convert device to be process specific
Device.DEFAULT = device.split(":")[0]
Device.DEFAULT = device.split(":")[0] if "GPU" in device else device
fn(*args)

View File

@@ -1,4 +1,4 @@
from tinygrad.tensor import Tensor, Device
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
from extra.dist import world
@@ -12,7 +12,7 @@ def allreduce(t:Tensor, cache_id=None) -> Tensor:
# pad to evenly divide
if flattened.shape[0] % WORLD_SIZE != 0:
flattened = Tensor.cat(flattened, Tensor.zeros(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE)))
flattened = Tensor.cat(flattened, Tensor.empty(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE)))
# chunk
chunks = flattened.chunk(WORLD_SIZE, dim=0)

118
extra/dist/world.py vendored
View File

@@ -1,59 +1,101 @@
from typing import Any, Optional, Tuple
from typing import Optional
from extra import dist
from multiprocessing import shared_memory
from tinygrad.helpers import DEBUG, colored
from tinygrad.helpers import DEBUG, colored, getenv
from tinygrad.lazy import LazyBuffer
from tinygrad.runtime.lib import RawBufferCopyIn, RawBufferCopyInOut
from tinygrad.runtime.lib import RawBuffer, RawBufferCopyIn, RawBufferCopyInOut
try: from tinygrad.runtime.ops_hip import RawHIPBuffer
except: RawHIPBuffer = None
from tinygrad.runtime.ops_shm import RawShmBuffer
from tinygrad.jit import CacheCollector
from tinygrad.tensor import Tensor, Function
import extra.hip_wrapper as hip
import numpy as np
# fake the function signature of ASTRunner so we can put it in the cache
def __send_rb(args:Tuple[RawBufferCopyInOut, RawShmBuffer, int, Any], variables=None, jit=False, force_wait=False):
args[0]._copyout(np.frombuffer(args[1]._buffer(), dtype=args[0].dtype.np))
dist.OOB.send(args[3], args[2])
if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} sent {args[0]} to rank {args[2]}")
def __send_rb(args, variables=None, jit=False, force_wait=False):
x, target_rank, y = args[:3]
if RawHIPBuffer and x.__class__ is RawHIPBuffer:
hip.hipSetDevice(x._device)
hip.hipDeviceSynchronize()
else:
if isinstance(x, RawBufferCopyInOut): x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
else: y.fromCPU(x.toCPU())
dist.OOB.send(None, target_rank)
if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} sent {x} to rank {target_rank}")
def __recv_rb(args:Tuple[RawBufferCopyIn, RawShmBuffer, int], variables=None, jit=False, force_wait=False):
dist.OOB.recv(args[2])
args[0]._copyin(args[1].toCPU())
if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} recv {args[0]} from rank {args[2]}")
def __recv_rb(args, variables=None, jit=False, force_wait=False):
x, target_rank, y = args[:3]
dist.OOB.recv(target_rank)
if RawHIPBuffer and x.__class__ is RawHIPBuffer:
x._transfer(y)
elif isinstance(x, RawBufferCopyIn): x._copyin(y.toCPU())
else: x.fromCPU(y.toCPU())
if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} recv {x} from rank {target_rank}")
# send a rawbuffer from out rank to the target rank
def _send_rb(x:RawBufferCopyInOut, target_rank:int, cache_id:Optional[str]=None):
assert isinstance(x, RawBufferCopyInOut), "we only support RawBufferCopyInOut for now"
# cache the shared memory so we don't have to create it every time
if cache_id is not None and cache_id in _send_rb.shared_memory_cache:
shm_name = _send_rb.shared_memory_cache[cache_id]
else:
shm_name = (s := shared_memory.SharedMemory(create=True, size=x.size * x.dtype.itemsize)).name
s.close()
if cache_id is not None: _send_rb.shared_memory_cache[cache_id] = shm_name
# copy the buffer into shared memory
device = f"{shm_name},{cache_id}" if cache_id is not None else shm_name
rb = RawShmBuffer(x.size, x.dtype, device=device)
__send_rb((x, rb, target_rank, (shm_name, cache_id)))
def _send_rb(x:RawBuffer, target_rank:int, cache_id:Optional[str]=None):
if RawHIPBuffer and x.__class__ is RawHIPBuffer:
# send ipc handle
hip.hipSetDevice(x._device)
hip.hipDeviceSynchronize()
handle = hip.hipIpcGetMemHandle(x._buf)
dist.OOB.send((handle, x._device), target_rank)
# jit support
CacheCollector.add(__send_rb, [x, rb, target_rank, None], {})
# jit support
x._allocator = None # need to disconnect allocator for sent buffers
CacheCollector.add(__send_rb, [x, target_rank, None], {})
else:
# cache the shared memory so we don't have to create it every time
if cache_id is not None and cache_id in _send_rb.shared_memory_cache:
shm_name = _send_rb.shared_memory_cache[cache_id]
else:
shm_name = (s := shared_memory.SharedMemory(create=True, size=x.size * x.dtype.itemsize)).name
s.close()
if cache_id is not None: _send_rb.shared_memory_cache[cache_id] = shm_name
# copy the buffer into shared memory
device = f"{shm_name},{cache_id}" if cache_id is not None else shm_name
y = RawShmBuffer(x.size, x.dtype, device=device)
# fast path when we can directly copyout
if isinstance(x, RawBufferCopyInOut): x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
else: y.fromCPU(x.toCPU())
dist.OOB.send((shm_name, cache_id), target_rank)
# jit support
CacheCollector.add(__send_rb, [x, target_rank, y], {})
setattr(_send_rb, "shared_memory_cache", {})
# receive a rawbuffer from the target rank
def _recv_rb(x:RawBufferCopyIn, target_rank:int):
assert isinstance(x, RawBufferCopyIn), "we only support RawBufferCopyIn for now"
extra = dist.OOB.recv(target_rank)
device = f"{extra[0]},{extra[1]}" if extra[1] is not None else f"{extra[0]}"
rb = RawShmBuffer(x.size, x.dtype, device=device)
x._copyin(rb.toCPU())
if DEBUG >= 2: print(f"**** got {x} from rank {target_rank}")
def _recv_rb(x:RawBuffer, target_rank:int):
if RawHIPBuffer and isinstance(x, RawHIPBuffer):
# open ipc handle
handle, y_device = dist.OOB.recv(target_rank)
hip.hipSetDevice(y_device)
ptr = hip.hipIpcOpenMemHandle(handle, 0)
if extra[1] is None:
(s := shared_memory.SharedMemory(name=extra[0])).close()
s.unlink()
# build a new buffer
y = RawHIPBuffer(x.size, x.dtype, device=str(y_device), buf=ptr, allocator=None)
x._transfer(y)
# jit support
CacheCollector.add(__recv_rb, [x, rb, target_rank], {})
if DEBUG >= 2: print(f"**** rank {getenv('RANK')} got {x} from rank {target_rank}")
CacheCollector.add(__recv_rb, [x, target_rank, y], {})
else:
extra = dist.OOB.recv(target_rank)
device = f"{extra[0]},{extra[1]}" if extra[1] is not None else f"{extra[0]}"
y = RawShmBuffer(x.size, x.dtype, device=device)
# fast path when we can directly copyin
if isinstance(x, RawBufferCopyIn): x._copyin(y.toCPU())
else: x.fromCPU(y.toCPU())
if DEBUG >= 2: print(f"**** rank {getenv('RANK')} got {x} from rank {target_rank}")
if extra[1] is None:
(s := shared_memory.SharedMemory(name=extra[0])).close()
s.unlink()
# jit support
CacheCollector.add(__recv_rb, [x, target_rank, y], {})
# sends a lazybuffer from our rank to the target rank
def _send_lb(x:LazyBuffer, target_rank:int, cache_id:Optional[str]=None) -> None:

View File

@@ -244,7 +244,6 @@ try:
status = _libhip.hipGraphExecKernelNodeSetParams(gexec, node, ctypes.byref(params.c_struct))
hipCheckStatus(status)
_libhip.hipMalloc.restype = int
_libhip.hipMalloc.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]
def hipMalloc(count):

View File

@@ -24,7 +24,7 @@ def run():
# loop 3 times to make sure it works with the jit
for _ in range(3):
# create a tensor to send
t = Tensor.zeros(SIZE, SIZE) if rank == 0 else Tensor.ones(SIZE, SIZE)
t = Tensor.zeros(SIZE, SIZE) if rank != 0 else Tensor.ones(SIZE, SIZE)
t2 = allreduce_jit(t.contiguous().realize(), cache_id="test")
assert np.allclose(np.ones((SIZE, SIZE)), t2.numpy()), f"{t2.numpy()} wasn't ones"
@@ -42,7 +42,12 @@ def run():
print(f"rank {rank} passed")
if __name__ == "__main__":
devices = ["gpu:0", "gpu:1" if not CI else "gpu:0"]
if getenv("HIP"):
from tinygrad.runtime.ops_hip import HIP
devices = [f"hip:{i}" for i in range(HIP.device_count)]
else:
from tinygrad.runtime.ops_gpu import CL
devices = [f"gpu:{i}" for i in range(len(CL.devices))] if not CI else ["gpu:0", "gpu:0"]
world_size = len(devices)
dist.init_oob(world_size)

View File

@@ -45,12 +45,15 @@ def run():
# check that the received tensor is the same as the sent tensor
if rank == 0:
assert np.allclose(t.numpy(), t2.numpy())
assert np.allclose(t.numpy(), t2.numpy()), f"{t2.numpy()} wasn't equal to {t.numpy()}"
print(f"rank {rank} passed")
if __name__ == "__main__":
devices = ["gpu:0", "gpu:1" if not CI else "gpu:0"]
if getenv("HIP"):
devices = ["hip:0", "hip:1"]
else:
devices = ["gpu:0", "gpu:1" if not CI else "gpu:0"]
world_size = len(devices)
dist.init_oob(world_size)

View File

@@ -24,7 +24,7 @@ class TinyJit:
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
def __call__(self, *args, **kwargs) -> Any:
if Device.DEFAULT not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
if Device.DEFAULT.split(":")[0] not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
# NOTE: this cast is needed since although we know realize will create a ".realized" RawBuffer, the type checker doesn't
input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {cast(Union[int, str], k):(cast(RawBuffer, v.realize().lazydata.realized), v.lazydata.st) for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
assert len(input_rawbuffers) != 0, "no inputs to JIT"

View File

@@ -23,9 +23,10 @@ class HIPAllocator(LRUAllocator):
def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype.
class _HIP:
def __init__(self):
def __init__(self, device=None):
self.default_device = device or getenv("HIP_DEFAULT_DEVICE")
hip.hipSetDevice(self.default_device)
self.device_count = hip.hipGetDeviceCount()
self.default_device = getenv("HIP_DEFAULT_DEVICE")
self.allocator = HIPAllocator(hip.hipGetDeviceProperties(self.default_device).totalGlobalMem)
HIP = _HIP()
@@ -66,12 +67,16 @@ class HIPGraph(GraphBatchExecutor):
def exec_instance(self, instid): hip.hipGraphLaunch(self.graphs[instid][0])
class RawHIPBuffer(RawBufferCopyInOut, RawBufferTransfer):
def __init__(self, size, dtype, device=str(HIP.default_device)): super().__init__(size, dtype, allocator=HIP.allocator, **{'device': int(device)})
def __init__(self, size, dtype, device=HIP.default_device, buf=None, allocator=HIP.allocator): super().__init__(size, dtype, buf=buf, allocator=allocator, **{'device': int(device)})
def _copyin(self, x:np.ndarray):
x = np.require(x, requirements='C')
hip.hipMemcpyAsync(self._buf, x.ctypes.data, self.size * self.dtype.itemsize, hip.hipMemcpyHostToDevice, 0)
def _copyout(self, x:np.ndarray): hip.hipMemcpy(x.ctypes.data, self._buf, self.size * self.dtype.itemsize, hip.hipMemcpyDeviceToHost)
def _transfer(self, x): hip.hipMemcpyAsync(self._buf, x._buf, self.size * self.dtype.itemsize, hip.hipMemcpyDeviceToDevice, 0)
hip.hipSetDevice(self._device)
hip.hipMemcpyAsync(self._buf, np.require(x, requirements='C').ctypes.data, self.size * self.dtype.itemsize, hip.hipMemcpyHostToDevice, 0)
def _copyout(self, x:np.ndarray):
hip.hipSetDevice(self._device)
hip.hipMemcpy(x.ctypes.data, self._buf, self.size * self.dtype.itemsize, hip.hipMemcpyDeviceToHost)
def _transfer(self, x):
hip.hipSetDevice(x._device)
hip.hipMemcpy(self._buf, x._buf, self.size * self.dtype.itemsize, hip.hipMemcpyDeviceToDevice)
class HIPProgram:
def __init__(self, name:str, prg:str, binary=False):