diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index f7c6abc06e..85202ad9c0 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -273,7 +273,7 @@ def train_resnet():
       else:
         it = iter(tqdm(batch_load_resnet(batch_size=EVAL_BS, val=True, shuffle=False, pad_first_batch=True), total=steps_in_val_epoch))
         i, proc = 0, data_get(it)
-
+
       prev_cookies = []
       while proc is not None:
         GlobalCounters.reset()
@@ -446,7 +446,7 @@ def train_unet3d():
     loss.backward()
     optim.step()
     return loss.realize()
-
+
   @Tensor.train(mode=False)
   @Tensor.test()
   def eval_step(model, x, y):
@@ -455,7 +455,7 @@ def train_unet3d():
     loss = dice_ce_loss(y_hat, y)
     score = dice_score(y_hat, y)
     return loss.realize(), score.realize()
-
+
   if WANDB: wandb.init(config=config, project=PROJ_NAME)
 
   step_times, start_epoch = [], 1
@@ -464,7 +464,7 @@ def train_unet3d():
   next_eval_at = start_eval_at
 
   print(f"Training on {GPUS}")
-
+
   if BENCHMARK: print("Benchmarking UNet3D")
   else: print(f"Start evaluation at epoch {start_eval_at} and every {evaluate_every} epoch(s) afterwards")
 
@@ -574,9 +574,9 @@ def train_rnnt():
 
 @TinyJit
 def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
-  if len(GPUS) > 1:
-    for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
-      t.shard_(GPUS, axis=0)
+  for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
+    if len(GPUS) > 1: t.shard_(GPUS, axis=0)
+    else: t.to_(GPUS[0])
 
   optimizer.zero_grad()
   lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
@@ -584,7 +584,7 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te
   (loss * loss_scaler).backward()
 
   global_norm = Tensor([0.0], dtype=dtypes.float32, device=optimizer[0].device).realize()
-  for p in optimizer.params: 
+  for p in optimizer.params:
     p.grad = p.grad / loss_scaler
     global_norm += p.grad.float().square().sum()
   global_norm = global_norm.sqrt()
@@ -597,9 +597,9 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te
 @TinyJit
 def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
-  if len(GPUS) > 1:
-    for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
-      t.shard_(GPUS, axis=0)
+  for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
+    if len(GPUS) > 1: t.shard_(GPUS, axis=0)
+    else: t.to_(GPUS[0])
 
   lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
   masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = \
     model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 767d18ef93..80268a5c1f 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -72,6 +72,13 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
 def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, {})
 
 class TestSchedule(unittest.TestCase):
+  @unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch")
+  def test_error_on_device_mismatch(self):
+    a = Tensor.empty(10)
+    b = Tensor.empty(10, device="CPU")
+    c = a+b
+    with self.assertRaises(RuntimeError): check_schedule(c, 1)
+
   def test_basic_binop_fusion(self):
     a = Tensor.empty(10)
     b = Tensor.empty(10)
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index ac4785074f..b1460f2212 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -384,6 +384,7 @@ def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
   ast = sink.src[1].arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in sink.src[1].src if s.op is Ops.ASSIGN}).sink()
   # add buffer ops
   ast = graph_rewrite(ast, add_buffer_ops, bufs:=[sink.buf_uop], bottom_up=True)
+  if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
   # unbind_vars + push views to edges
   ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
   # fix_kernel_ops
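
Note (not part of the patch): a minimal sketch of what the new check in schedule_uop means for user code, mirroring the added test_error_on_device_mismatch. It assumes Device.DEFAULT is not "CPU", so the two tensors really end up on different devices.

# Sketch only: shows the device-mismatch error introduced in schedule_uop.
# Assumes Device.DEFAULT != "CPU" so `a` and `b` live on different devices.
from tinygrad import Tensor

a = Tensor.empty(10)                # on Device.DEFAULT
b = Tensor.empty(10, device="CPU")  # explicitly on CPU
c = a + b                           # lazy; nothing is scheduled yet

try:
  c.schedule()                      # scheduling now fails fast instead of building a bad kernel
except RuntimeError as e:
  print(e)                          # "all buffers must be on the same device: [...]"

# Moving one operand first keeps all kernel buffers on one device, which is what the
# BERT train/eval steps now guarantee via shard_ (multi-GPU) or to_ (single GPU).
(a + b.to(a.device)).schedule()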