mirror of https://github.com/tinygrad/tinygrad.git
multi device training with GPT2 [pr] (#10375)
* multi device training with GPT2 [pr]
* Update grouper.py
@@ -99,7 +99,7 @@ class GPT:
 
   def __call__(self, idx:Tensor, targets=None):
     b, t = idx.shape
-    pos = Tensor.arange(0, t)
+    pos = Tensor.arange(0, t, device=idx.device)
 
     tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd)
     pos_emb = self.wpe(pos) # position embeddings of shape (t, n_embd)
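The only functional change here is `device=idx.device`: once the input ids can live on several devices, the position indices must be created with the same placement, or adding the position embeddings would mix devices. A minimal sketch of that behaviour, assuming a tinygrad install; the two device ids on the default backend and the toy shapes are illustrative, not from the PR:

```python
from tinygrad import Tensor, Device

GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))

idx = Tensor([[1, 2, 3], [4, 5, 6]]).shard(GPUS, axis=0)  # batch split across two devices
pos = Tensor.arange(0, idx.shape[1], device=idx.device)   # positions follow idx's placement
print((idx + pos).device)  # the elementwise op is legal: both operands share a placement
```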
@@ -124,6 +124,7 @@ if __name__ == "__main__":
   parser.add_argument("--batch_size", type=int, default=4, help="batch size")
   parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
   parser.add_argument("--skip_test", action="store_true", help="skip test")
+  parser.add_argument("--gpus", type=int, default=1, help="number of gpus")
   args = parser.parse_args()
   B, T = args.batch_size, args.sequence_length
   assert 1 <= T <= 1024
@@ -131,6 +132,10 @@ if __name__ == "__main__":
   model = GPT(GPTConfig(n_layer=12, n_head=12, n_embd=768))
   model.load_pretrained()
 
+  if args.gpus > 1:
+    GPUS = tuple(f'{Device.DEFAULT}:{i}' for i in range(args.gpus))
+    for x in nn.state.get_parameters(model): x.to_(GPUS) # we put a copy of the model on every GPU
+
   # init the tokenizer
   enc = tiktoken.get_encoding("gpt2")
   encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
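`to_` with a tuple of devices is what replicates the model: each parameter becomes a multi-device tensor with no shard axis, i.e. a full copy on every GPU, whereas `shard(..., axis=0)` (used for the batch later) splits a tensor along an axis. A hedged sketch of the difference, with illustrative device ids and shapes:

```python
from tinygrad import Tensor, Device

GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))

w = Tensor.randn(4, 4)
w.to_(GPUS)                                 # in place: a full copy of w now lives on every device
x = Tensor.randn(8, 4).shard(GPUS, axis=0)  # split along the batch axis instead

print(w.device, x.device)  # both report the same tuple of devices
print(w.shape, x.shape)    # global shapes are unchanged; the split is per-device internally
```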
@@ -165,23 +170,32 @@ if __name__ == "__main__":
   x, y = next(data_iter) # we'll overfit this batch below
   optimizer = nn.optim.AdamW(nn.state.get_parameters(model), lr=1e-4, weight_decay=0)
 
   print(f"model state: {sum(x.nbytes() for x in nn.state.get_parameters(model))/1e9:.2f} GB")
   print(f"optimizer state: {sum(x.nbytes() for x in nn.state.get_parameters(optimizer))/1e9:.2f} GB")
 
+  # shard the data on axis 0
+  if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0)
+
   @TinyJit
-  def step(x, y):
+  @Tensor.train()
+  def step(x:Tensor, y:Tensor) -> Tensor:
     _, loss = model(x, y)
     optimizer.zero_grad()
     loss.backward()
     return loss.realize(*optimizer.schedule_step())
 
-  with Tensor.train():
-    for i in range(args.num_iterations):
-      GlobalCounters.reset()
-      t0 = time.time()
-      loss = step(x.contiguous(), y.contiguous())
-      Device[Device.DEFAULT].synchronize()
-      t1 = time.time()
-      print(f"iteration {i}, loss: {loss.item():.6f}, time: {(t1-t0)*1000:.3f}ms, {int(B*T/(t1-t0))} tok/s")
+  for i in range(args.num_iterations):
+    GlobalCounters.reset()
+    t0 = time.perf_counter()
+    loss = step(x.contiguous(), y.contiguous())
+    Device[Device.DEFAULT].synchronize()
+    t1 = time.perf_counter()
+    print(f"iteration {i}, loss: {loss.item():.6f}, time: {(t1-t0)*1000:.3f}ms, {int(B*T/(t1-t0))} tok/s, {GlobalCounters.global_mem/1e9:.2f} GB")
 
   if not args.skip_test:
+    # copy back to single gpu for test
+    if args.gpus > 1:
+      for x in nn.state.get_parameters(model): x.to_(Device.DEFAULT)
     start = "<|endoftext|>"
     start_ids = encode(start)
     x = (Tensor(start_ids)[None, ...])
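Taken together, these hunks are plain data parallelism: weights replicated on every device, the batch sharded on axis 0, and a jitted step that realizes the loss together with the scheduled optimizer update. Below is a self-contained toy version of the same pattern, not the PR's code: the `nn.Linear` model, shapes, and `NUM_GPUS` are stand-ins for the GPT-2 example, and it assumes a tinygrad install with at least two usable device ids on the default backend:

```python
from tinygrad import Tensor, Device, TinyJit, nn

NUM_GPUS = 2  # illustrative
GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(NUM_GPUS))

model = nn.Linear(16, 16)  # toy stand-in for GPT
for p in nn.state.get_parameters(model): p.to_(GPUS)  # replicate the weights on every device

opt = nn.optim.AdamW(nn.state.get_parameters(model), lr=1e-4, weight_decay=0)

# a single batch, split along axis 0 so each device sees half of it
x = Tensor.randn(8, 16).shard(GPUS, axis=0)
y = Tensor.randn(8, 16).shard(GPUS, axis=0)

@TinyJit
@Tensor.train()
def step(x:Tensor, y:Tensor) -> Tensor:
  loss = (model(x) - y).square().mean()
  opt.zero_grad()
  loss.backward()
  return loss.realize(*opt.schedule_step())  # realize the loss and the optimizer update together

for i in range(5):
  print(f"step {i}: loss {step(x.contiguous(), y.contiguous()).item():.6f}")
```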
@@ -94,11 +94,12 @@ class BufferSpec:
 class MultiBuffer:
   def __init__(self, device:tuple[str, ...], size:int, dtype:DType):
     self.bufs = [Buffer(d, size, dtype) for d in device]
-    self.dtype = dtype
+    self.size, self.dtype = size, dtype
   def ref(self, cnt):
     for b in self.bufs: b.ref(cnt)
     return self
   def is_allocated(self): return all(x.is_allocated() for x in self.bufs)
+  def __repr__(self): return f"<multibuf real:{self.is_allocated()} device:{tuple(x.device for x in self.bufs)} size:{self.size} dtype:{self.dtype}>"
 
 class Buffer:
   def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferSpec]=None, initial_value:Optional[bytes]=None,
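`MultiBuffer` mirrors one logical buffer as one `Buffer` per device; storing `size` is what lets the new `__repr__` report it. A small hedged sketch of what that looks like interactively, with illustrative device ids (this is an internal class, so details may differ between versions):

```python
from tinygrad import Device, dtypes
from tinygrad.device import MultiBuffer

# one logical buffer backed by one Buffer per device
mb = MultiBuffer(tuple(f"{Device.DEFAULT}:{i}" for i in range(2)), size=16, dtype=dtypes.float32)
print(mb.is_allocated())  # False: none of the per-device Buffers is allocated yet
print(mb)                 # e.g. <multibuf real:False device:(...) size:16 dtype:dtypes.float>
```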
@@ -398,7 +398,8 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
   ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right")
   # replace buffer with define_global + add load/store last
   ast = graph_rewrite(ast, merge_views+add_buffer_ops+fix_kernel_ops, bufs:=tuple(s.buf_uop for s in k.src), bottom_up=True, name="replace buffer")
-  if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
+  if ast.op is Ops.SINK and not all_same([x.device for x in bufs]):
+    raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buffer for b in bufs)}")
   return k.replace(arg=Kernel(ast, k.arg.metadata))
 
 create_ast = PatternMatcher([(UPat(Ops.KERNEL, name="k"), fix_kernel_ast),])
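The check itself is simple: every buffer feeding a kernel must report the same device, otherwise scheduling aborts. A hedged sketch of that check in isolation, using the real `all_same` helper from `tinygrad.helpers` with made-up device strings:

```python
from tinygrad.helpers import all_same

devices = ["GPU:0", "GPU:0", "GPU:1"]  # illustrative: one buffer ended up elsewhere
if not all_same(devices):
  raise RuntimeError(f"all buffers must be on the same device: {tuple(devices)}")
```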
@@ -224,7 +224,7 @@ class AMMemoryManager:
         if paddr is not None: paddrs += [(paddr, cont_seg_sz)]
         else:
           for paddr, _ in paddrs: self.pa_allocator.free(paddr)
-          raise MemoryError(f"Failed to allocate contigous a page. (allocation size={size:#x})")
+          raise MemoryError(f"Failed to allocate a contiguous page. (allocation size={size:#x})")
         rem_len, off = rem_len - cont_seg_sz, off + cont_seg_sz
 
       return self.map_range(va, size, paddrs, uncached=uncached)
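Besides the typo fix, the hunk shows the driver's rollback-on-failure pattern: physical segments are grabbed one at a time, and if one cannot be satisfied, everything grabbed so far is freed before raising. A hedged, generic sketch of that pattern; the allocator object and segment sizes are hypothetical stand-ins, not the real AMD driver API:

```python
# hypothetical allocator with .alloc(size) -> paddr | None and .free(paddr)
def alloc_segments(pa_allocator, seg_sizes: list[int]) -> list[tuple[int, int]]:
  paddrs: list[tuple[int, int]] = []
  for sz in seg_sizes:
    paddr = pa_allocator.alloc(sz)
    if paddr is not None: paddrs += [(paddr, sz)]
    else:
      # roll back every segment grabbed so far before reporting the failure
      for paddr, _ in paddrs: pa_allocator.free(paddr)
      raise MemoryError(f"Failed to allocate a contiguous page. (allocation size={sz:#x})")
  return paddrs
```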