Mirror of https://github.com/tinygrad/tinygrad.git
test for device mismatch [pr] (#9250)

* test for device mismatch [pr]
* fix bert
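For context, here is a minimal sketch (not part of the diff) of the failure mode this PR adds a test and a check for, assuming the default device is not CPU (otherwise both tensors land on the same device and nothing is wrong):

from tinygrad import Tensor

a = Tensor.empty(10)                 # on Device.DEFAULT
b = Tensor.empty(10, device="CPU")   # explicitly on CPU
c = a + b                            # building the graph does not raise
try:
  c.realize()                        # scheduling does: the buffers sit on different devices
except RuntimeError as e:
  print(e)                           # "all buffers must be on the same device: ..."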
@@ -273,7 +273,7 @@ def train_resnet():
      else:
        it = iter(tqdm(batch_load_resnet(batch_size=EVAL_BS, val=True, shuffle=False, pad_first_batch=True), total=steps_in_val_epoch))
        i, proc = 0, data_get(it)

      prev_cookies = []
      while proc is not None:
        GlobalCounters.reset()
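The data_get/proc structure above fetches the next batch before the current step finishes, which is what lets the loader overlap with device work; the loader cookies it hands back are what prev_cookies holds onto. A rough stand-alone sketch of just the loop shape in plain Python, with a hypothetical data_get that only wraps next():

def data_get(it):
  try: return next(it)
  except StopIteration: return None

it = iter(range(5))            # stand-in for the batch iterator
proc = data_get(it)
while proc is not None:
  nxt = data_get(it)           # request the next item before the current step is done
  print("processing", proc)
  proc = nxt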
@@ -446,7 +446,7 @@ def train_unet3d():
    loss.backward()
    optim.step()
    return loss.realize()

  @Tensor.train(mode=False)
  @Tensor.test()
  def eval_step(model, x, y):
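A self-contained sketch of the decorator stack used on eval_step above: TinyJit replays cached kernels after the warm-up calls, Tensor.train(mode=False) turns off training-mode behaviour, and Tensor.test() runs without gradient tracking. The Linear model and shapes here are only for illustration, not from the repo:

from tinygrad import Tensor, TinyJit
from tinygrad.nn import Linear

model = Linear(16, 4)

@TinyJit
@Tensor.train(mode=False)
@Tensor.test()
def eval_step(x: Tensor) -> Tensor:
  # predicted class per row; realize so the JIT captures concrete output buffers
  return model(x).argmax(axis=1).realize()

for _ in range(3):
  print(eval_step(Tensor.rand(8, 16)).numpy())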
@@ -455,7 +455,7 @@ def train_unet3d():
    loss = dice_ce_loss(y_hat, y)
    score = dice_score(y_hat, y)
    return loss.realize(), score.realize()

  if WANDB: wandb.init(config=config, project=PROJ_NAME)

  step_times, start_epoch = [], 1
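dice_ce_loss and dice_score above come from the UNet3D training code; as a reminder of what a dice score measures, here is a rough soft-dice sketch (not the repo's implementation):

from tinygrad import Tensor

def soft_dice(pred: Tensor, target: Tensor, eps: float = 1e-6) -> Tensor:
  # dice = 2*|intersection| / (|pred| + |target|), smoothed with eps
  inter = (pred * target).sum()
  return (2 * inter + eps) / (pred.sum() + target.sum() + eps)

print(soft_dice(Tensor([1.0, 0.0, 1.0]), Tensor([1.0, 1.0, 0.0])).numpy())  # 0.5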
@@ -464,7 +464,7 @@ def train_unet3d():
  next_eval_at = start_eval_at

  print(f"Training on {GPUS}")

  if BENCHMARK: print("Benchmarking UNet3D")
  else: print(f"Start evaluation at epoch {start_eval_at} and every {evaluate_every} epoch(s) afterwards")
@@ -574,9 +574,9 @@ def train_rnnt():
@TinyJit
def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor,
                    masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
-  if len(GPUS) > 1:
-    for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
-      t.shard_(GPUS, axis=0)
+  for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
+    if len(GPUS) > 1: t.shard_(GPUS, axis=0)
+    else: t.to_(GPUS[0])
  optimizer.zero_grad()

  lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
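The bert fix above makes the single-device path explicit: with several GPUS the batch tensors are sharded along axis 0, and with exactly one entry they are still moved onto that device instead of being left wherever they were created. A minimal sketch of the two branches, assuming two devices of the default backend are available (the GPUS list is hypothetical):

from tinygrad import Tensor, Device

GPUS = [f"{Device.DEFAULT}:{i}" for i in range(2)]   # e.g. ["CUDA:0", "CUDA:1"]
t = Tensor.empty(8, 128)
if len(GPUS) > 1: t.shard_(GPUS, axis=0)   # split the batch dimension across devices
else: t.to_(GPUS[0])                       # single device: move it there explicitly
print(t.device)                            # a tuple of devices when sharded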
@@ -584,7 +584,7 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te
  (loss * loss_scaler).backward()

  global_norm = Tensor([0.0], dtype=dtypes.float32, device=optimizer[0].device).realize()
  for p in optimizer.params:
    p.grad = p.grad / loss_scaler
    global_norm += p.grad.float().square().sum()
  global_norm = global_norm.sqrt()
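The gradient-norm accumulator above is created on optimizer[0].device so that every addition in the loop happens on one device. A small sketch of the same idea, using plain tensors as stand-ins for gradients and CPU as the stand-in device:

from tinygrad import Tensor, dtypes

params = [Tensor.rand(4, 4, device="CPU"), Tensor.rand(4, device="CPU")]
# accumulate on the same device the tensors live on, otherwise the sum would mix devices
global_norm = Tensor([0.0], dtype=dtypes.float32, device=params[0].device)
for p in params:
  global_norm = global_norm + p.float().square().sum()
print(global_norm.sqrt().numpy())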
@@ -597,9 +597,9 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te
@TinyJit
def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor,
                   masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
-  if len(GPUS) > 1:
-    for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
-      t.shard_(GPUS, axis=0)
+  for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
+    if len(GPUS) > 1: t.shard_(GPUS, axis=0)
+    else: t.to_(GPUS[0])
  lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
  masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = \
    model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
@@ -72,6 +72,13 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, {})

class TestSchedule(unittest.TestCase):
+  @unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch")
+  def test_error_on_device_mismatch(self):
+    a = Tensor.empty(10)
+    b = Tensor.empty(10, device="CPU")
+    c = a+b
+    with self.assertRaises(RuntimeError): check_schedule(c, 1)
+
  def test_basic_binop_fusion(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
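To run just the new test, something like the following should work, assuming this hunk belongs to test/test_schedule.py (the file name is not shown in this excerpt) and the default device is not CPU, since otherwise the skipIf applies:

python3 -m pytest test/test_schedule.py -k test_error_on_device_mismatch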
@@ -384,6 +384,7 @@ def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
  ast = sink.src[1].arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in sink.src[1].src if s.op is Ops.ASSIGN}).sink()
  # add buffer ops
  ast = graph_rewrite(ast, add_buffer_ops, bufs:=[sink.buf_uop], bottom_up=True)
+  if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
  # unbind_vars + push views to edges
  ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
  # fix_kernel_ops
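The new check relies on all_same from tinygrad.helpers, which simply reports whether every element of a list is equal. A tiny illustration of how it behaves on device lists:

from tinygrad.helpers import all_same

print(all_same(["CPU", "CPU", "CPU"]))   # True: one device, scheduling proceeds
print(all_same(["CPU", "CUDA"]))         # False: mixed devices trigger the RuntimeError above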