multi device training with GPT2 [pr] (#10375)

* multi device training with GPT2 [pr]

* Update grouper.py
This commit is contained in:
George Hotz
2025-05-17 15:33:56 -07:00
committed by GitHub
parent 6ec88d94df
commit 0b733ba75e
4 changed files with 29 additions and 13 deletions

View File

@@ -99,7 +99,7 @@ class GPT:
def __call__(self, idx:Tensor, targets=None):
b, t = idx.shape
pos = Tensor.arange(0, t)
pos = Tensor.arange(0, t, device=idx.device)
tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd)
pos_emb = self.wpe(pos) # position embeddings of shape (t, n_embd)
@@ -124,6 +124,7 @@ if __name__ == "__main__":
parser.add_argument("--batch_size", type=int, default=4, help="batch size")
parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
parser.add_argument("--skip_test", action="store_true", help="skip test")
parser.add_argument("--gpus", type=int, default=1, help="sequence length")
args = parser.parse_args()
B, T = args.batch_size, args.sequence_length
assert 1 <= T <= 1024
@@ -131,6 +132,10 @@ if __name__ == "__main__":
model = GPT(GPTConfig(n_layer=12, n_head=12, n_embd=768))
model.load_pretrained()
if args.gpus > 1:
GPUS = tuple(f'{Device.DEFAULT}:{i}' for i in range(args.gpus))
for x in nn.state.get_parameters(model): x.to_(GPUS) # we put a copy of the model on every GPU
# init the tokenizer
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
@@ -165,23 +170,32 @@ if __name__ == "__main__":
x, y = next(data_iter) # we'll overfit this batch below
optimizer = nn.optim.AdamW(nn.state.get_parameters(model), lr=1e-4, weight_decay=0)
print(f"model state: {sum(x.nbytes() for x in nn.state.get_parameters(model))/1e9:.2f} GB")
print(f"optimizer state: {sum(x.nbytes() for x in nn.state.get_parameters(optimizer))/1e9:.2f} GB")
# shard the data on axis 0
if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0)
@TinyJit
def step(x, y):
@Tensor.train()
def step(x:Tensor, y:Tensor) -> Tensor:
_, loss = model(x, y)
optimizer.zero_grad()
loss.backward()
return loss.realize(*optimizer.schedule_step())
with Tensor.train():
for i in range(args.num_iterations):
GlobalCounters.reset()
t0 = time.time()
loss = step(x.contiguous(), y.contiguous())
Device[Device.DEFAULT].synchronize()
t1 = time.time()
print(f"iteration {i}, loss: {loss.item():.6f}, time: {(t1-t0)*1000:.3f}ms, {int(B*T/(t1-t0))} tok/s")
for i in range(args.num_iterations):
GlobalCounters.reset()
t0 = time.perf_counter()
loss = step(x.contiguous(), y.contiguous())
Device[Device.DEFAULT].synchronize()
t1 = time.perf_counter()
print(f"iteration {i}, loss: {loss.item():.6f}, time: {(t1-t0)*1000:.3f}ms, {int(B*T/(t1-t0))} tok/s, {GlobalCounters.global_mem/1e9:.2f} GB")
if not args.skip_test:
# copy back to single gpu for test
if args.gpus > 1:
for x in nn.state.get_parameters(model): x.to_(Device.DEFAULT)
start = "<|endoftext|>"
start_ids = encode(start)
x = (Tensor(start_ids)[None, ...])