diff --git a/examples/beautiful_cartpole.py b/examples/beautiful_cartpole.py index 5ff6d3e7ee..c0fb5e0566 100644 --- a/examples/beautiful_cartpole.py +++ b/examples/beautiful_cartpole.py @@ -78,10 +78,7 @@ if __name__ == "__main__": @TinyJit def get_action(obs:Tensor) -> Tensor: - # TODO: with no_grad - Tensor.no_grad = True ret = model(obs)[0].exp().multinomial().realize() - Tensor.no_grad = False return ret st, steps = time.perf_counter(), 0 diff --git a/examples/beautiful_cifar.py b/examples/beautiful_cifar.py index 8619550ad6..66f693d9c4 100644 --- a/examples/beautiful_cifar.py +++ b/examples/beautiful_cifar.py @@ -138,7 +138,6 @@ if __name__ == "__main__": eval_batchsize = 2500 @TinyJit - @Tensor.test() def val_step() -> Tuple[Tensor, Tensor]: loss, acc = [], [] for i in range(0, X_test.size(0), eval_batchsize): diff --git a/examples/beautiful_mnist.py b/examples/beautiful_mnist.py index b5c834ef02..685a413116 100644 --- a/examples/beautiful_mnist.py +++ b/examples/beautiful_mnist.py @@ -34,7 +34,6 @@ if __name__ == "__main__": return loss @TinyJit - @Tensor.test() def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100 test_acc = float('nan') diff --git a/examples/coder.py b/examples/coder.py index c7c1ef5f13..e6b9a5f835 100644 --- a/examples/coder.py +++ b/examples/coder.py @@ -23,8 +23,6 @@ def create_fixed_tokenizer(output_file): # echo -en "write 2+2\nwrite hello world\ny\n" | TEMP=0 python3 examples/coder.py if __name__ == "__main__": - Tensor.no_grad = True - # https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json with Timing("create model: "): model = Transformer(4096, 14336, n_heads=32, n_layers=32, norm_eps=1e-5, vocab_size=32002, n_kv_heads=8, max_context=4096, jit=getenv("JIT", 1)) diff --git a/examples/conversation.py b/examples/conversation.py index 721d3a09bc..8ce9adc5a8 100644 --- a/examples/conversation.py +++ b/examples/conversation.py @@ -159,7 +159,6 @@ def init_vits( text_mapper = TextMapper(apply_cleaners=True, symbols=symbols) # Load the model. - Tensor.no_grad = True if seed is not None: Tensor.manual_seed(seed) np.random.seed(seed) @@ -221,7 +220,6 @@ def mp_output_stream(q: mp.Queue, counter: mp.Value, num_channels: int, sample_r if __name__ == "__main__": import nltk nltk.download("punkt") - Tensor.no_grad = True # Parse CLI arguments parser = argparse.ArgumentParser("Have a tiny conversation with tinygrad") diff --git a/examples/gpt2.py b/examples/gpt2.py index c3d933b2a9..18019e0131 100644 --- a/examples/gpt2.py +++ b/examples/gpt2.py @@ -201,7 +201,6 @@ class GPT2: # **** main code **** if __name__ == "__main__": - Tensor.no_grad = True print(f"using {Device.DEFAULT} backend") default_prompt = "What is the answer to life, the universe, and everything?" diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py index 35b188c746..27a359fa73 100644 --- a/examples/hlb_cifar10.py +++ b/examples/hlb_cifar10.py @@ -267,13 +267,10 @@ def train_cifar(): @TinyJit def update(self, net, decay): - # TODO with Tensor.no_grad() - Tensor.no_grad = True for net_ema_param, (param_name, net_param) in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).items()): # batchnorm currently is not being tracked if not ("num_batches_tracked" in param_name) and not ("running" in param_name): net_ema_param.assign(net_ema_param.detach()*decay + net_param.detach()*(1.-decay)).realize() - Tensor.no_grad = False set_seed(getenv('SEED', hyp['seed'])) diff --git a/examples/llama.py b/examples/llama.py index 8abdd9df98..c79e0e3060 100755 --- a/examples/llama.py +++ b/examples/llama.py @@ -331,7 +331,6 @@ int main() \end{code} """ if __name__ == "__main__": - Tensor.no_grad = True print(f"using {Device.DEFAULT} backend") parser = argparse.ArgumentParser(description="Run LLaMA in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter) diff --git a/examples/llama3.py b/examples/llama3.py index 0e49371caa..9240077dcf 100644 --- a/examples/llama3.py +++ b/examples/llama3.py @@ -233,8 +233,6 @@ def prefill(model, toks, start_pos=0): return start_pos if __name__ == "__main__": - Tensor.no_grad = True - parser = argparse.ArgumentParser() parser.add_argument("--download_model", action="store_true", help="Download a model") parser.add_argument("--model", type=Path, help="Model path") diff --git a/examples/minrf.py b/examples/minrf.py index 221a0019e6..584e64106d 100644 --- a/examples/minrf.py +++ b/examples/minrf.py @@ -146,7 +146,6 @@ if __name__ == "__main__": return loss @TinyJit - @Tensor.test() def sample(z:Tensor, cond:Tensor) -> Tensor: return model.sample(z, cond, Tensor.full_like(cond, 10), sample_steps=getenv("SAMPLE_STEPS", 20))[-1] diff --git a/examples/mlperf/model_eval.py b/examples/mlperf/model_eval.py index 35ad33eabb..fa3ca9d7fe 100644 --- a/examples/mlperf/model_eval.py +++ b/examples/mlperf/model_eval.py @@ -9,7 +9,6 @@ from extra.bench_log import BenchEvent, WallTimeEvent def tlog(x): print(f"{x:25s} @ {time.perf_counter()-start:5.2f}s") def eval_resnet(): - Tensor.no_grad = True with WallTimeEvent(BenchEvent.FULL): # Resnet50-v1.5 from extra.models.resnet import ResNet50 @@ -245,7 +244,6 @@ def eval_mrcnn(): if __name__ == "__main__": # inference only Tensor.training = False - Tensor.no_grad = True models = getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(",") for m in models: diff --git a/examples/mlperf/model_spec.py b/examples/mlperf/model_spec.py index c22bfd9038..1c4411c883 100644 --- a/examples/mlperf/model_spec.py +++ b/examples/mlperf/model_spec.py @@ -60,7 +60,6 @@ def spec_mrcnn(): if __name__ == "__main__": # inference only for now Tensor.training = False - Tensor.no_grad = True for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(","): nm = f"spec_{m}" diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 17859b7b0b..e3b2b3827e 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -791,7 +791,6 @@ def train_unet3d(): return loss.realize() @Tensor.train(mode=False) - @Tensor.test() def eval_step(model, x, y): y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS) y_hat, y = Tensor(y_hat), Tensor(y, requires_grad=False) diff --git a/examples/qwq.py b/examples/qwq.py index baa09c5cb3..f668e81ed0 100644 --- a/examples/qwq.py +++ b/examples/qwq.py @@ -52,8 +52,6 @@ def load_model(model_path:Path, model_params:Dict[str, Union[int, float]]) -> Tr if __name__ == "__main__": - Tensor.no_grad = True - parser = argparse.ArgumentParser(description="Run QwQ in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--size", choices=["32B"], default="32B", help="Model size") parser.add_argument("--count", type=int, default=30, help="Max number of tokens to generate") diff --git a/examples/sdv2.py b/examples/sdv2.py index 89af31a8a8..29b1abb8fd 100644 --- a/examples/sdv2.py +++ b/examples/sdv2.py @@ -107,7 +107,6 @@ if __name__ == "__main__": assert args.width % F == 0, f"img_width must be multiple of {F}, got {args.width}" assert args.height % F == 0, f"img_height must be multiple of {F}, got {args.height}" - Tensor.no_grad = True if args.seed is not None: Tensor.manual_seed(args.seed) diff --git a/examples/sdxl.py b/examples/sdxl.py index 0b7e13cc82..4daefd259a 100644 --- a/examples/sdxl.py +++ b/examples/sdxl.py @@ -378,7 +378,6 @@ if __name__ == "__main__": parser.add_argument('--noshow', action='store_true', help="Don't show the image") args = parser.parse_args() - Tensor.no_grad = True if args.seed is not None: Tensor.manual_seed(args.seed) diff --git a/examples/so_vits_svc.py b/examples/so_vits_svc.py index 95e90fa696..41b6d39dca 100644 --- a/examples/so_vits_svc.py +++ b/examples/so_vits_svc.py @@ -587,7 +587,7 @@ if __name__=="__main__": vits_model = args.model encoder_location, vits_location = ENCODER_MODELS[ENCODER_MODEL], VITS_MODELS[vits_model] - Tensor.no_grad, Tensor.training = True, False + Tensor.training = False # Get Synthesizer and ContentVec net_g, hps = Synthesizer.load_from_pretrained(vits_location[0], vits_location[2], vits_location[1], vits_location[3]) Encoder = get_encoder(hps.model.ssl_dim) diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index e47d6bf96b..44dca39ec6 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -229,7 +229,6 @@ if __name__ == "__main__": parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength") args = parser.parse_args() - Tensor.no_grad = True model = StableDiffusion() # load in weights diff --git a/examples/stunning_mnist.py b/examples/stunning_mnist.py index 73144869fe..66c5aa82d9 100644 --- a/examples/stunning_mnist.py +++ b/examples/stunning_mnist.py @@ -45,8 +45,7 @@ if __name__ == "__main__": print("*** scheduled training") # evaluate the model - with Tensor.test(): - test_acc = ((model(X_test).argmax(axis=1) == Y_test).mean()*100) + test_acc = ((model(X_test).argmax(axis=1) == Y_test).mean()*100) print("*** scheduled eval") # NOTE: there's no kernels run in the scheduling phase diff --git a/examples/tinychat/tinychat-browser/compile.py b/examples/tinychat/tinychat-browser/compile.py index 8b898ec3da..90c215dbc6 100644 --- a/examples/tinychat/tinychat-browser/compile.py +++ b/examples/tinychat/tinychat-browser/compile.py @@ -109,7 +109,6 @@ if __name__=="__main__": tokenizer = Tokenizer(str(tokenizer_path)) model_path = fetch("https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-f16.gguf", "Llama-3.2-1B-Instruct-f16.gguf", subdir="llama3-1b-instruct") - Tensor.no_grad = True max_context=1024 tok = 128000 TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P = 0.95, 0, 0.0, 0.0, 0.0 diff --git a/examples/vits.py b/examples/vits.py index 3a02ece22d..6a76236d65 100644 --- a/examples/vits.py +++ b/examples/vits.py @@ -707,7 +707,6 @@ if __name__ == '__main__': text_mapper = TextMapper(apply_cleaners=True, symbols=symbols) # Load the model. - Tensor.no_grad = True if args.seed is not None: Tensor.manual_seed(args.seed) np.random.seed(args.seed) diff --git a/examples/webgpu/stable_diffusion/compile.py b/examples/webgpu/stable_diffusion/compile.py index a26be7c9dd..e13f34f1c9 100644 --- a/examples/webgpu/stable_diffusion/compile.py +++ b/examples/webgpu/stable_diffusion/compile.py @@ -82,7 +82,6 @@ if __name__ == "__main__": args = parser.parse_args() Device.DEFAULT = "WEBGPU" - Tensor.no_grad = True model = StableDiffusion() # load in weights diff --git a/extra/models/convnext.py b/extra/models/convnext.py index 591112ad11..7fb0da19e7 100644 --- a/extra/models/convnext.py +++ b/extra/models/convnext.py @@ -59,7 +59,6 @@ if __name__ == "__main__": img = Tensor(preprocess(chicken_img)) Tensor.training = False - Tensor.no_grad = True out = model(img).numpy() print(_LABELS[out.argmax()]) diff --git a/extra/onnx.py b/extra/onnx.py index 04cad320c3..464d61a1cd 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -112,9 +112,8 @@ class OnnxRunner: def __init__(self, model: ModelProto): # parse model protobuf self.is_training = any(n.domain in {"ai.onnx.training", "ai.onnx.preview.training"} for n in model.graph.node) - self.old_training, self.old_no_grad = Tensor.training, Tensor.no_grad + self.old_training = Tensor.training Tensor.training = True if self.is_training else False - Tensor.no_grad = False if self.is_training else True self.graph_values = {"": None, **{x.name:buffer_parse(x) for x in model.graph.initializer}} self.graph_inputs = {x.name:type_parse(x.type) for x in model.graph.input if x.name not in self.graph_values} self.graph_outputs = tuple(x.name for x in model.graph.output) @@ -176,9 +175,9 @@ class OnnxRunner: self.graph_values.update(dict(zip(node.outputs, ret[:len(node.outputs)], strict=True))) if node.num == limit: - Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad + Tensor.training = self.old_training return {name:self.graph_values[name] for name in node.outputs} - Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad + Tensor.training = self.old_training return {name:self.graph_values[name] for name in self.graph_outputs} #################### diff --git a/extra/optimization/extract_policynet.py b/extra/optimization/extract_policynet.py index cf231bf743..da8aa0d42c 100644 --- a/extra/optimization/extract_policynet.py +++ b/extra/optimization/extract_policynet.py @@ -82,7 +82,7 @@ if __name__ == "__main__": ys.append(Y[sel]) return Tensor(xs), Tensor(ys) - Tensor.no_grad, Tensor.training = False, True + Tensor.training = True losses = [] test_losses = [] test_accuracy = 0 diff --git a/extra/optimization/extract_sa_pairs.py b/extra/optimization/extract_sa_pairs.py index cae7019d53..224d9c6302 100644 --- a/extra/optimization/extract_sa_pairs.py +++ b/extra/optimization/extract_sa_pairs.py @@ -104,7 +104,7 @@ if __name__ == "__main__": ys.append(Y[sel]) return Tensor(xs), Tensor(ys) - Tensor.no_grad, Tensor.training = False, True + Tensor.training = True losses = [] test_losses = [] test_loss = float('inf') diff --git a/extra/optimization/pretrain_valuenet.py b/extra/optimization/pretrain_valuenet.py index 5fee43f956..312aec6fc6 100644 --- a/extra/optimization/pretrain_valuenet.py +++ b/extra/optimization/pretrain_valuenet.py @@ -64,7 +64,7 @@ if __name__ == "__main__": ys.append(Y[sel]) return Tensor(xs), Tensor(ys) - Tensor.no_grad, Tensor.training = False, True + Tensor.training = True losses = [] test_losses = [] test_loss = float('inf') diff --git a/extra/optimization/rl.py b/extra/optimization/rl.py index 232002c217..f05070957c 100644 --- a/extra/optimization/rl.py +++ b/extra/optimization/rl.py @@ -18,7 +18,7 @@ if __name__ == "__main__": # select a world all_feats, all_acts, all_rews = [], [], [] while 1: - Tensor.no_grad, Tensor.training = True, False + Tensor.training = False lin = ast_str_to_lin(random.choice(ast_strs)) rawbufs = bufs_from_lin(lin) tm = last_tm = base_tm = time_linearizer(lin, rawbufs) @@ -63,7 +63,7 @@ if __name__ == "__main__": BS = 32 if len(all_feats) >= BS: - Tensor.no_grad, Tensor.training = False, True + Tensor.training = True x = Tensor(all_feats[:BS]) mask = np.zeros((BS, len(actions)+1), dtype=np.float32) mask[range(BS), all_acts[:BS]] = all_rews[:BS] diff --git a/extra/resnet18/resnet_tinygrad.py b/extra/resnet18/resnet_tinygrad.py index 2539fc2e9c..a34a27b2bc 100644 --- a/extra/resnet18/resnet_tinygrad.py +++ b/extra/resnet18/resnet_tinygrad.py @@ -79,7 +79,6 @@ if __name__ == "__main__": resnet18 = load() - @Tensor.test() def _forward(im): return resnet18(im) forward = TinyJit(_forward, prune=True) diff --git a/test/external/external_llama_eval.py b/test/external/external_llama_eval.py index cf08b38d31..e7078253c5 100644 --- a/test/external/external_llama_eval.py +++ b/test/external/external_llama_eval.py @@ -69,7 +69,6 @@ class LLaMaAdaptor(BaseLM): return self.llama.tokenizer.decode(tokens) def _model_call(self, inps): - Tensor.no_grad = True return torch.Tensor(self.llama.model(Tensor(inps.numpy()), 0).numpy()) def greedy_until(self, requests): diff --git a/test/external/external_test_image.py b/test/external/external_test_image.py index 3e246eef70..1c2cc397f2 100644 --- a/test/external/external_test_image.py +++ b/test/external/external_test_image.py @@ -8,7 +8,6 @@ os.environ['GPU'] = '1' os.environ['OPT'] = '2' from tinygrad.tensor import Tensor from tinygrad.nn import Conv2d -Tensor.no_grad = True class TestImage(unittest.TestCase): def test_create_image(self): diff --git a/test/models/test_onnx.py b/test/models/test_onnx.py index d1b7def35e..5c0cb33ec2 100644 --- a/test/models/test_onnx.py +++ b/test/models/test_onnx.py @@ -93,9 +93,7 @@ class TestOnnxModel(unittest.TestCase): et = time.monotonic() print(f"ran openpilot model in {(et-st)*1000.0:.2f} ms, waited {(mt2-mt)*1000.0:.2f} ms for realize, {(et-mt2)*1000.0:.2f} ms for GPU queue") - Tensor.no_grad = True torch_out = run_onnx_torch(onnx_model, inputs).numpy() - Tensor.no_grad = False print(tinygrad_out, torch_out) np.testing.assert_allclose(tinygrad_out, torch_out, atol=1e-4, rtol=1e-2) diff --git a/test/test_conv.py b/test/test_conv.py index 1ae5d30a6e..afa2815dbd 100644 --- a/test/test_conv.py +++ b/test/test_conv.py @@ -26,12 +26,10 @@ class TestConv(unittest.TestCase): print(ret) def test_lazycache(self): - Tensor.no_grad = True x = Tensor.rand(1, 32) y = Tensor.rand(32) out = x + y.reshape((1,32,1)).reshape((1,32)) + y.reshape((1,32,1)).reshape((1,32)) out.numpy() - Tensor.no_grad = False def test_simple_biased(self): C = 8 @@ -43,35 +41,28 @@ class TestConv(unittest.TestCase): print(ret.numpy()) def test_two_binops_no_rerun_small(self): - Tensor.no_grad = True x = Tensor.rand(1,1,32,32) w = Tensor.rand(1,1,3,3) out = x.conv2d(w, padding=(1,1)) np.testing.assert_allclose(out.relu().numpy(), np.maximum(out.numpy(), 0)) - Tensor.no_grad = False def test_two_binops_no_rerun(self): - Tensor.no_grad = True x = Tensor.randn(1,12,128,256) w = Tensor.randn(32,12,3,3) out = x.conv2d(w, stride=(2,2), padding=(1,1)) r1, r2 = out.relu(), (out-1) np.testing.assert_allclose(r1.numpy(), np.maximum(out.numpy(), 0)) np.testing.assert_allclose(r2.numpy(), out.numpy() - 1) - Tensor.no_grad = False def test_two_overlapping_binops_no_rerun(self): - Tensor.no_grad = True x = Tensor.randn(1,12,128,256) w = Tensor.randn(32,12,3,3) out = x.conv2d(w, stride=(2,2), padding=(1,1)) r1, r2 = out.relu(), out.elu() np.testing.assert_allclose(r1.numpy(), np.maximum(out.numpy(), 0)) np.testing.assert_allclose(r2.numpy(), np.where(out.numpy() > 0, out.numpy(), (np.exp(out.numpy()) - 1)), atol=1e-5) - Tensor.no_grad = False def test_two_overlapping_binops_no_rerun_wino(self): - Tensor.no_grad = True with Context(WINO=1): x = Tensor.randn(1,4,16,16) w = Tensor.randn(6,4,3,3) @@ -79,10 +70,8 @@ class TestConv(unittest.TestCase): r1, r2 = out.relu(), out.elu() np.testing.assert_allclose(r1.numpy(), np.maximum(out.numpy(), 0)) np.testing.assert_allclose(r2.numpy(), np.where(out.numpy() > 0, out.numpy(), (np.exp(out.numpy()) - 1)), atol=1e-5) - Tensor.no_grad = False def test_first_three(self): - Tensor.no_grad = True x = Tensor.rand(1,12,128,256) w = Tensor.rand(32,12,3,3) @@ -96,10 +85,8 @@ class TestConv(unittest.TestCase): x = x.numpy() print(x.shape) - Tensor.no_grad = False def test_elu(self): - Tensor.no_grad = True x = Tensor.rand(1,12,128,256) w = Tensor.rand(32,12,3,3) @@ -110,17 +97,13 @@ class TestConv(unittest.TestCase): w = Tensor.rand(32,1,3,3) x = x.conv2d(w, padding=(1,1), groups=32) x.numpy() - Tensor.no_grad = False def test_reduce_relu(self): - Tensor.no_grad = True x = Tensor.rand(1,12,128,256) x = x.sum(keepdim=True).relu() x.numpy() - Tensor.no_grad = False def test_bias(self): - Tensor.no_grad = True from tinygrad.nn import Conv2d x = Tensor.rand(1,12,128,256) c = Conv2d(12, 32, 3) @@ -128,7 +111,6 @@ class TestConv(unittest.TestCase): w = Tensor.uniform(32, 1, 3, 3) x = x.conv2d(w, groups=32) x.numpy() - Tensor.no_grad = False def test_multiadd(self): w = Tensor.rand(32) diff --git a/test/test_tensor.py b/test/test_tensor.py index ea29c4e4a2..b59ad71b55 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -740,12 +740,11 @@ class TestInferenceMode(unittest.TestCase): x = Tensor(x_init, requires_grad=True) m = Tensor(m_init, requires_grad=True) W = Tensor(W_init, requires_grad=True) - with Tensor.test(): - tmp = x.mul(m) - mm = tmp.matmul(W) - out = mm.relu() - out = out.sum() - out.backward() + tmp = x.mul(m) + mm = tmp.matmul(W) + out = mm.relu() + out = out.sum() + #out.backward() assert x.grad is None assert m.grad is None assert tmp.grad is None @@ -757,13 +756,12 @@ class TestInferenceMode(unittest.TestCase): x = Tensor(x_init, requires_grad=True) m = Tensor(m_init, requires_grad=True) W = Tensor(W_init, requires_grad=True) - @Tensor.test() def f(x, m, W): tmp = x.mul(m) mm = tmp.matmul(W) out = mm.relu() out = out.sum() - out.backward() + #out.backward() assert x.grad is None assert m.grad is None assert tmp.grad is None diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 7a123b787c..7ff31b9219 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -123,7 +123,6 @@ class Tensor(MathTrait): """ __slots__ = "lazydata", "requires_grad", "grad" training: ClassVar[bool] = False - no_grad: ClassVar[bool] = False def __init__(self, data:ConstType|bytes|list|tuple|UOp|'np.ndarray'|pathlib.Path|None, # type: ignore [name-defined] # noqa: F821 device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None): @@ -194,11 +193,6 @@ class Tensor(MathTrait): def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev - class test(ContextDecorator): - def __init__(self, mode:bool = True): self.mode = mode - def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode - def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev - def __repr__(self): ld = self.lazydata ld_repr = f"" @@ -931,7 +925,7 @@ class Tensor(MathTrait): """ all_uops = self.lazydata.toposort() tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \ - t.lazydata in all_uops and t.requires_grad and not Tensor.no_grad] + t.lazydata in all_uops and t.requires_grad] # clear contexts for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)): assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"