multi-device training with GPT2 [pr] (#10375)

* multi-device training with GPT2 [pr]

* Update grouper.py
Author: George Hotz
Date: 2025-05-17 15:33:56 -07:00
Committed by: GitHub
Parent: 6ec88d94df
Commit: 0b733ba75e

4 changed files with 29 additions and 13 deletions

View File

@@ -99,7 +99,7 @@ class GPT:
   def __call__(self, idx:Tensor, targets=None):
     b, t = idx.shape
-    pos = Tensor.arange(0, t)
+    pos = Tensor.arange(0, t, device=idx.device)
     tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd)
     pos_emb = self.wpe(pos) # position embeddings of shape (t, n_embd)
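The device= fix matters because idx may no longer live on the default device once the model is placed across GPUs: a bare Tensor.arange lands on Device.DEFAULT, and the embedding add below it would then mix devices. A minimal sketch of the difference, assuming a working tinygrad install (the ":1" device is illustrative):

  from tinygrad import Tensor, Device

  idx = Tensor([[1, 2, 3]], device=f"{Device.DEFAULT}:1")  # input on a non-default device
  bad = Tensor.arange(0, idx.shape[1])                     # created on Device.DEFAULT
  pos = Tensor.arange(0, idx.shape[1], device=idx.device)  # follows the input's device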
@@ -124,6 +124,7 @@ if __name__ == "__main__":
parser.add_argument("--batch_size", type=int, default=4, help="batch size")
parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
parser.add_argument("--skip_test", action="store_true", help="skip test")
parser.add_argument("--gpus", type=int, default=1, help="sequence length")
args = parser.parse_args()
B, T = args.batch_size, args.sequence_length
assert 1 <= T <= 1024
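With the new flag, a multi-GPU run is launched by passing --gpus; the invocation below is illustrative (the script path is not shown in this diff):

  python train_gpt2.py --gpus 4 --batch_size 4 --sequence_length 64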
@@ -131,6 +132,10 @@ if __name__ == "__main__":
   model = GPT(GPTConfig(n_layer=12, n_head=12, n_embd=768))
   model.load_pretrained()
+  if args.gpus > 1:
+    GPUS = tuple(f'{Device.DEFAULT}:{i}' for i in range(args.gpus))
+    for x in nn.state.get_parameters(model): x.to_(GPUS)  # we put a copy of the model on every GPU
   # init the tokenizer
   enc = tiktoken.get_encoding("gpt2")
   encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
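Note that to_ with a device tuple replicates rather than shards: every GPU receives a full copy of each weight tensor, and only the data (below) is split. A hedged sketch of the in-place move, assuming at least two visible devices (names illustrative):

  from tinygrad import Tensor, Device

  GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
  w = Tensor.randn(4, 4).realize()
  w.to_(GPUS)      # in-place: w is now backed by one buffer per device
  print(w.device)  # a device tuple, e.g. ('NV:0', 'NV:1')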
@@ -165,23 +170,32 @@ if __name__ == "__main__":
   x, y = next(data_iter) # we'll overfit this batch below
   optimizer = nn.optim.AdamW(nn.state.get_parameters(model), lr=1e-4, weight_decay=0)
   print(f"model state: {sum(x.nbytes() for x in nn.state.get_parameters(model))/1e9:.2f} GB")
   print(f"optimizer state: {sum(x.nbytes() for x in nn.state.get_parameters(optimizer))/1e9:.2f} GB")
+  # shard the data on axis 0
+  if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0)
   @TinyJit
-  def step(x, y):
+  @Tensor.train()
+  def step(x:Tensor, y:Tensor) -> Tensor:
     _, loss = model(x, y)
     optimizer.zero_grad()
     loss.backward()
     return loss.realize(*optimizer.schedule_step())
-  with Tensor.train():
-    for i in range(args.num_iterations):
-      GlobalCounters.reset()
-      t0 = time.time()
-      loss = step(x.contiguous(), y.contiguous())
-      Device[Device.DEFAULT].synchronize()
-      t1 = time.time()
-      print(f"iteration {i}, loss: {loss.item():.6f}, time: {(t1-t0)*1000:.3f}ms, {int(B*T/(t1-t0))} tok/s")
+  for i in range(args.num_iterations):
+    GlobalCounters.reset()
+    t0 = time.perf_counter()
+    loss = step(x.contiguous(), y.contiguous())
+    Device[Device.DEFAULT].synchronize()
+    t1 = time.perf_counter()
+    print(f"iteration {i}, loss: {loss.item():.6f}, time: {(t1-t0)*1000:.3f}ms, {int(B*T/(t1-t0))} tok/s, {GlobalCounters.global_mem/1e9:.2f} GB")
   if not args.skip_test:
+    # copy back to single gpu for test
+    if args.gpus > 1:
+      for x in nn.state.get_parameters(model): x.to_(Device.DEFAULT)
     start = "<|endoftext|>"
     start_ids = encode(start)
     x = (Tensor(start_ids)[None, ...])
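Two pieces of this hunk deserve a closer look. First, where the weights were replicated, the batch is sharded: shard(GPUS, axis=0) splits the leading dimension so each device steps on its own rows against its local weight copy. Second, Tensor.train() now decorates the jitted function instead of wrapping the loop, and optimizer.schedule_step() returns the update tensors so loss.realize(*...) materializes the loss and the weight updates in one captured graph. A hedged sketch of both on a toy model (the model, shapes, and device count are illustrative):

  from tinygrad import Tensor, TinyJit, Device, nn

  GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
  model = nn.Linear(16, 1)
  for p in nn.state.get_parameters(model): p.to_(GPUS)  # replicate weights
  opt = nn.optim.AdamW(nn.state.get_parameters(model), lr=1e-4)
  x = Tensor.randn(8, 16).shard(GPUS, axis=0)  # with 2 devices, 4 rows each
  y = Tensor.randn(8, 1).shard(GPUS, axis=0)

  @TinyJit
  @Tensor.train()
  def toy_step(x:Tensor, y:Tensor) -> Tensor:
    loss = (model(x) - y).square().mean()
    opt.zero_grad()
    loss.backward()
    return loss.realize(*opt.schedule_step())  # loss + optimizer updates together

  for _ in range(3): print(toy_step(x.contiguous(), y.contiguous()).item())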

View File

@@ -94,11 +94,12 @@ class BufferSpec:
 class MultiBuffer:
   def __init__(self, device:tuple[str, ...], size:int, dtype:DType):
     self.bufs = [Buffer(d, size, dtype) for d in device]
-    self.dtype = dtype
+    self.size, self.dtype = size, dtype
   def ref(self, cnt):
     for b in self.bufs: b.ref(cnt)
     return self
   def is_allocated(self): return all(x.is_allocated() for x in self.bufs)
+  def __repr__(self): return f"<multibuf real:{self.is_allocated()} device:{tuple(x.device for x in self.bufs)} size:{self.size} dtype:{self.dtype}>"
 class Buffer:
   def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferSpec]=None, initial_value:Optional[bytes]=None,
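The new size attribute and __repr__ make a MultiBuffer inspectable before its per-device buffers are allocated. A hedged sketch of what it prints, assuming MultiBuffer is importable from tinygrad.device (device names illustrative):

  from tinygrad import dtypes
  from tinygrad.device import MultiBuffer

  mb = MultiBuffer(("CPU:0", "CPU:1"), 16, dtypes.float32)
  print(mb)  # <multibuf real:False device:('CPU:0', 'CPU:1') size:16 dtype:dtypes.float>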

View File

@@ -398,7 +398,8 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
   ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right")
   # replace buffer with define_global + add load/store last
   ast = graph_rewrite(ast, merge_views+add_buffer_ops+fix_kernel_ops, bufs:=tuple(s.buf_uop for s in k.src), bottom_up=True, name="replace buffer")
-  if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
+  if ast.op is Ops.SINK and not all_same([x.device for x in bufs]):
+    raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buffer for b in bufs)}")
   return k.replace(arg=Kernel(ast, k.arg.metadata))
 create_ast = PatternMatcher([(UPat(Ops.KERNEL, name="k"), fix_kernel_ast),])
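The rewritten check also swaps the error payload: instead of the list of device strings it now prints the buffers themselves, which is more useful when hunting down a mis-placed tensor. all_same is tinygrad's small helper that tests whether every element of a list equals the first; a minimal sketch of the check in isolation (device strings illustrative):

  from tinygrad.helpers import all_same

  devices = ["NV:0", "NV:0", "NV:1"]
  if not all_same(devices):
    raise RuntimeError(f"all buffers must be on the same device: {tuple(devices)}")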

View File

@@ -224,7 +224,7 @@ class AMMemoryManager:
       if paddr is not None: paddrs += [(paddr, cont_seg_sz)]
       else:
         for paddr, _ in paddrs: self.pa_allocator.free(paddr)
-        raise MemoryError(f"Failed to allocate contigous a page. (allocation size={size:#x})")
+        raise MemoryError(f"Failed to allocate a contiguous page. (allocation size={size:#x})")
       rem_len, off = rem_len - cont_seg_sz, off + cont_seg_sz
     return self.map_range(va, size, paddrs, uncached=uncached)
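Beyond the typo fix, the surrounding logic is a standard allocate-or-roll-back loop: segments are grabbed one at a time, and on the first failure everything already grabbed is freed before raising. A hedged sketch of the pattern with a hypothetical allocator interface (alloc/free and their semantics are assumptions, not the AM driver's API):

  def alloc_segments(allocator, segment_sizes:list[int]) -> list[tuple[int, int]]:
    paddrs: list[tuple[int, int]] = []
    for seg_sz in segment_sizes:
      paddr = allocator.alloc(seg_sz)  # hypothetical: returns None on failure
      if paddr is None:
        for p, _ in paddrs: allocator.free(p)  # undo partial progress
        raise MemoryError(f"Failed to allocate a contiguous page. (allocation size={sum(segment_sizes):#x})")
      paddrs.append((paddr, seg_sz))
    return paddrs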