test for device mismatch [pr] (#9250)

* test for device mismatch [pr]
* fix bert

George Hotz, 2025-02-26 13:06:33 +08:00 (committed by GitHub)
commit 3f4eb9006a, parent 9c4d9d9f10
3 changed files with 19 additions and 11 deletions

examples/mlperf/model_train.py

@@ -273,7 +273,7 @@ def train_resnet():
     else:
       it = iter(tqdm(batch_load_resnet(batch_size=EVAL_BS, val=True, shuffle=False, pad_first_batch=True), total=steps_in_val_epoch))
     i, proc = 0, data_get(it)
     prev_cookies = []
     while proc is not None:
       GlobalCounters.reset()
@@ -446,7 +446,7 @@ def train_unet3d():
     loss.backward()
     optim.step()
     return loss.realize()
   @Tensor.train(mode=False)
   @Tensor.test()
   def eval_step(model, x, y):
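For reference, Tensor.train and Tensor.test are context managers that also work as decorators: Tensor.train(mode=False) clears Tensor.training for the duration of the call, and Tensor.test() is assumed here to disable gradient tracking in the same way. A minimal sketch of stacking them as on eval_step above (toy_eval_step is a hypothetical stand-in):

from tinygrad import Tensor

@Tensor.train(mode=False)   # Tensor.training is False inside the call, restored on exit
@Tensor.test()              # assumption: disables gradient tracking for the call
def toy_eval_step(x: Tensor) -> Tensor:
  return (x * 2).sum()

print(toy_eval_step(Tensor.ones(4)).item())  # 8.0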
@@ -455,7 +455,7 @@ def train_unet3d():
     loss = dice_ce_loss(y_hat, y)
     score = dice_score(y_hat, y)
     return loss.realize(), score.realize()
   if WANDB: wandb.init(config=config, project=PROJ_NAME)
   step_times, start_epoch = [], 1
@@ -464,7 +464,7 @@ def train_unet3d():
   next_eval_at = start_eval_at
   print(f"Training on {GPUS}")
   if BENCHMARK: print("Benchmarking UNet3D")
   else: print(f"Start evaluation at epoch {start_eval_at} and every {evaluate_every} epoch(s) afterwards")
@@ -574,9 +574,9 @@ def train_rnnt():
 @TinyJit
 def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor,
                     masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
-  if len(GPUS) > 1:
-    for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
-      t.shard_(GPUS, axis=0)
+  for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
+    if len(GPUS) > 1: t.shard_(GPUS, axis=0)
+    else: t.to_(GPUS[0])
   optimizer.zero_grad()
   lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
@@ -584,7 +584,7 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te
   (loss * loss_scaler).backward()
   global_norm = Tensor([0.0], dtype=dtypes.float32, device=optimizer[0].device).realize()
   for p in optimizer.params:
     p.grad = p.grad / loss_scaler
     global_norm += p.grad.float().square().sum()
   global_norm = global_norm.sqrt()
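The block above unscales each gradient and accumulates a global L2 norm; pinning the accumulator with device=optimizer[0].device keeps the sum on one device, which matters under the new same-device check. A rough standalone sketch of the arithmetic (toy gradients, hypothetical loss_scaler value):

from tinygrad import Tensor, dtypes

loss_scaler = 128.0
grads = [Tensor([3.0, 4.0]) * loss_scaler, Tensor([0.0]) * loss_scaler]  # toy scaled grads

global_norm = Tensor([0.0], dtype=dtypes.float32)
for g in grads:
  g = g / loss_scaler                       # undo the loss scaling
  global_norm += g.float().square().sum()   # accumulate squared L2 norm
print(global_norm.sqrt().item())            # 5.0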
@@ -597,9 +597,9 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te
 @TinyJit
 def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor,
                    masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
-  if len(GPUS) > 1:
-    for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
-      t.shard_(GPUS, axis=0)
+  for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
+    if len(GPUS) > 1: t.shard_(GPUS, axis=0)
+    else: t.to_(GPUS[0])
   lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
   masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = \
     model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
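Both BERT steps now place every input explicitly instead of leaving single-device inputs wherever they were created, so nothing reaches the scheduler on a mismatched device. A minimal sketch of the pattern in isolation (the two-entry GPUS list is a hypothetical stand-in for the real device list):

from tinygrad import Tensor, Device

GPUS = [f"{Device.DEFAULT}:{i}" for i in range(2)]  # hypothetical device list

t = Tensor.empty(8, 128)
if len(GPUS) > 1: t.shard_(GPUS, axis=0)  # in-place shard along the batch axis
else: t.to_(GPUS[0])                      # in-place move to the single device
print(t.device)                           # a tuple of devices when sharded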

test/test_schedule.py

@@ -72,6 +72,13 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
 def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, {})
 
 class TestSchedule(unittest.TestCase):
+  @unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch")
+  def test_error_on_device_mismatch(self):
+    a = Tensor.empty(10)
+    b = Tensor.empty(10, device="CPU")
+    c = a+b
+    with self.assertRaises(RuntimeError): check_schedule(c, 1)
+
   def test_basic_binop_fusion(self):
     a = Tensor.empty(10)
     b = Tensor.empty(10)
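check_schedule is a helper defined in this test file; outside the harness the same failure can be reproduced with Tensor.schedule(), which runs into the check added in schedule.py below. A minimal sketch, assuming the default device is not CPU:

from tinygrad import Tensor

a = Tensor.empty(10)                 # lives on Device.DEFAULT
b = Tensor.empty(10, device="CPU")   # deliberately mismatched
try:
  (a + b).schedule()                 # scheduling the binop hits the device check
except RuntimeError as e:
  print(e)                           # "all buffers must be on the same device: ..."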

tinygrad/engine/schedule.py

@@ -384,6 +384,7 @@ def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
   ast = sink.src[1].arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in sink.src[1].src if s.op is Ops.ASSIGN}).sink()
   # add buffer ops
   ast = graph_rewrite(ast, add_buffer_ops, bufs:=[sink.buf_uop], bottom_up=True)
+  if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
   # unbind_vars + push views to edges
   ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
   # fix_kernel_ops
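The guard only fires for compute kernels (ast.op is Ops.SINK), presumably so that non-kernel schedule items such as cross-device copies are not rejected. all_same comes from tinygrad.helpers and is a plain equality fold; roughly:

# rough equivalent of the helper used in the check above (tinygrad.helpers.all_same)
def all_same(items: list) -> bool:
  return all(x == items[0] for x in items)

assert all_same(["CPU", "CPU"])
assert not all_same(["CPU", "METAL"])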