Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
multi device training with GPT2 [pr] (#10375)
* multi device training with GPT2 [pr]
* Update grouper.py
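The change to the GPT-2 training example is plain data parallelism: keep a full copy of the weights on every device, shard the batch along axis 0, and let the jitted step handle the cross-device work. Below is a minimal sketch of that pattern on a toy two-layer model; the Toy model, NUM, and GPUS names are illustrative and not part of this commit, and it assumes at least NUM devices of the default backend are available.

# minimal data-parallel training sketch (illustrative, not part of this commit)
from tinygrad import Tensor, Device, TinyJit
from tinygrad.nn import Linear
from tinygrad.nn.state import get_parameters
from tinygrad.nn.optim import AdamW

NUM = 2  # assumed number of devices
GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(NUM))

class Toy:
  def __init__(self): self.l1, self.l2 = Linear(16, 64), Linear(64, 1)
  def __call__(self, x:Tensor) -> Tensor: return self.l2(self.l1(x).relu())

model = Toy()
for p in get_parameters(model): p.to_(GPUS)      # replicate every weight on every device
opt = AdamW(get_parameters(model), lr=1e-3)

x = Tensor.randn(32, 16).shard(GPUS, axis=0)     # shard the batch on axis 0
y = Tensor.randn(32, 1).shard(GPUS, axis=0)

@TinyJit
@Tensor.train()
def step(x:Tensor, y:Tensor) -> Tensor:
  loss = ((model(x) - y) ** 2).mean()
  opt.zero_grad()
  loss.backward()
  return loss.realize(*opt.schedule_step())      # realize the loss together with the optimizer updates

for i in range(10): print(i, step(x.contiguous(), y.contiguous()).item())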
@@ -99,7 +99,7 @@ class GPT:

   def __call__(self, idx:Tensor, targets=None):
     b, t = idx.shape
-    pos = Tensor.arange(0, t)
+    pos = Tensor.arange(0, t, device=idx.device)

     tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd)
     pos_emb = self.wpe(pos) # position embeddings of shape (t, n_embd)
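Creating pos on idx.device matters once the inputs can be sharded: Tensor.arange(0, t) would land on Device.DEFAULT alone, while the token embeddings are computed across all of GPUS, so the addition would mix devices. A small sketch of the difference (the two-device GPUS tuple and the device names in the comments are assumptions):

# why `device=idx.device` matters once the input is sharded (illustrative)
from tinygrad import Tensor, Device

GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))   # assumed device list
idx = Tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).shard(GPUS, axis=0)

pos_default = Tensor.arange(0, idx.shape[1])                    # lives only on Device.DEFAULT
pos_follow  = Tensor.arange(0, idx.shape[1], device=idx.device) # follows idx onto all of GPUS

print(pos_default.device)  # e.g. "CUDA"
print(pos_follow.device)   # e.g. ("CUDA:0", "CUDA:1")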
@@ -124,6 +124,7 @@ if __name__ == "__main__":
   parser.add_argument("--batch_size", type=int, default=4, help="batch size")
   parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
   parser.add_argument("--skip_test", action="store_true", help="skip test")
+  parser.add_argument("--gpus", type=int, default=1, help="number of gpus")
   args = parser.parse_args()
   B, T = args.batch_size, args.sequence_length
   assert 1 <= T <= 1024
@@ -131,6 +132,10 @@ if __name__ == "__main__":
   model = GPT(GPTConfig(n_layer=12, n_head=12, n_embd=768))
   model.load_pretrained()

+  if args.gpus > 1:
+    GPUS = tuple(f'{Device.DEFAULT}:{i}' for i in range(args.gpus))
+    for x in nn.state.get_parameters(model): x.to_(GPUS) # we put a copy of the model on every GPU
+
   # init the tokenizer
   enc = tiktoken.get_encoding("gpt2")
   encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
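The hunk above replicates the parameters with to_(GPUS); that is the opposite of sharding, which splits a tensor along an axis while keeping its global shape. A small sketch of the two placements (the two-device GPUS tuple and the device names in the comments are assumptions):

# replication vs. sharding across devices (illustrative)
from tinygrad import Tensor, Device

GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))   # assumed device list

w = Tensor.randn(4, 4)
w.to_(GPUS)                    # in-place: a full copy of w now lives on every device
print(w.device)                # e.g. ("CUDA:0", "CUDA:1")

x = Tensor.randn(8, 4)
xs = x.shard(GPUS, axis=0)     # split along axis 0: each device holds 4 of the 8 rows
print(xs.shape, xs.device)     # (8, 4) ("CUDA:0", "CUDA:1"), the global shape is unchanged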
@@ -165,23 +170,32 @@ if __name__ == "__main__":
   x, y = next(data_iter) # we'll overfit this batch below
   optimizer = nn.optim.AdamW(nn.state.get_parameters(model), lr=1e-4, weight_decay=0)

+  print(f"model state: {sum(x.nbytes() for x in nn.state.get_parameters(model))/1e9:.2f} GB")
+  print(f"optimizer state: {sum(x.nbytes() for x in nn.state.get_parameters(optimizer))/1e9:.2f} GB")
+
+  # shard the data on axis 0
+  if args.gpus > 1: x, y = x.shard(GPUS, axis=0), y.shard(GPUS, axis=0)
+
   @TinyJit
-  def step(x, y):
+  @Tensor.train()
+  def step(x:Tensor, y:Tensor) -> Tensor:
     _, loss = model(x, y)
     optimizer.zero_grad()
     loss.backward()
     return loss.realize(*optimizer.schedule_step())

-  with Tensor.train():
-    for i in range(args.num_iterations):
-      GlobalCounters.reset()
-      t0 = time.time()
-      loss = step(x.contiguous(), y.contiguous())
-      Device[Device.DEFAULT].synchronize()
-      t1 = time.time()
-      print(f"iteration {i}, loss: {loss.item():.6f}, time: {(t1-t0)*1000:.3f}ms, {int(B*T/(t1-t0))} tok/s")
+  for i in range(args.num_iterations):
+    GlobalCounters.reset()
+    t0 = time.perf_counter()
+    loss = step(x.contiguous(), y.contiguous())
+    Device[Device.DEFAULT].synchronize()
+    t1 = time.perf_counter()
+    print(f"iteration {i}, loss: {loss.item():.6f}, time: {(t1-t0)*1000:.3f}ms, {int(B*T/(t1-t0))} tok/s, {GlobalCounters.global_mem/1e9:.2f} GB")

   if not args.skip_test:
+    # copy back to single gpu for test
+    if args.gpus > 1:
+      for x in nn.state.get_parameters(model): x.to_(Device.DEFAULT)
     start = "<|endoftext|>"
     start_ids = encode(start)
     x = (Tensor(start_ids)[None, ...])
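In the rewritten loop, time.perf_counter replaces time.time and the device is synchronized before the clock is stopped, since kernel launches return before the work finishes; GlobalCounters.global_mem is, roughly, the bytes the kernels of that step read and wrote. A stripped-down version of that measurement, with step, x, and y as in the earlier sketch:

# timing one jitted step (stripped-down; `step`, `x`, `y` as in the earlier sketch)
import time
from tinygrad import Device, GlobalCounters

GlobalCounters.reset()                        # zero the per-step op/memory counters
t0 = time.perf_counter()
loss = step(x.contiguous(), y.contiguous())
Device[Device.DEFAULT].synchronize()          # wait for the device before stopping the clock
t1 = time.perf_counter()
print(f"{(t1-t0)*1000:.3f} ms, {GlobalCounters.global_mem/1e9:.2f} GB moved, loss {loss.item():.4f}")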