mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-06 21:53:53 -05:00)
remove Tensor.no_grad, it's meaningless now [pr] (#10556)
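The hunks below delete the `Tensor.no_grad` toggles (and the `Tensor.test` helper built on them) from the examples, tests, and `tinygrad/tensor.py`: gradients are only collected for tensors with `requires_grad=True` when `backward()` is called, so a global inference flag no longer changes anything. A minimal sketch of the resulting behavior, assuming current tinygrad semantics (not code from this commit):

from tinygrad import Tensor

w = Tensor.rand(16, 4, requires_grad=True)  # a trainable parameter
x = Tensor.rand(1, 16)                      # plain input, no gradient requested

# inference: just compute, no Tensor.no_grad toggle needed
y = (x @ w).relu().realize()

# training: backward() only fills gradients for tensors with requires_grad=True
loss = (x @ w).relu().sum()
loss.backward()
print(w.grad.shape)  # (16, 4)
print(x.grad)        # None, x never asked for a gradient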
@@ -78,10 +78,7 @@ if __name__ == "__main__":
   @TinyJit
   def get_action(obs:Tensor) -> Tensor:
-    # TODO: with no_grad
-    Tensor.no_grad = True
     ret = model(obs)[0].exp().multinomial().realize()
-    Tensor.no_grad = False
     return ret

   st, steps = time.perf_counter(), 0

@@ -138,7 +138,6 @@ if __name__ == "__main__":
   eval_batchsize = 2500
   @TinyJit
-  @Tensor.test()
   def val_step() -> Tuple[Tensor, Tensor]:
     loss, acc = [], []
     for i in range(0, X_test.size(0), eval_batchsize):

@@ -34,7 +34,6 @@ if __name__ == "__main__":
     return loss

   @TinyJit
-  @Tensor.test()
   def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100

   test_acc = float('nan')

@@ -23,8 +23,6 @@ def create_fixed_tokenizer(output_file):
 # echo -en "write 2+2\nwrite hello world\ny\n" | TEMP=0 python3 examples/coder.py

 if __name__ == "__main__":
-  Tensor.no_grad = True
-
   # https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json
   with Timing("create model: "):
     model = Transformer(4096, 14336, n_heads=32, n_layers=32, norm_eps=1e-5, vocab_size=32002, n_kv_heads=8, max_context=4096, jit=getenv("JIT", 1))

@@ -159,7 +159,6 @@ def init_vits(
   text_mapper = TextMapper(apply_cleaners=True, symbols=symbols)

   # Load the model.
-  Tensor.no_grad = True
   if seed is not None:
     Tensor.manual_seed(seed)
     np.random.seed(seed)

@@ -221,7 +220,6 @@ def mp_output_stream(q: mp.Queue, counter: mp.Value, num_channels: int, sample_r
 if __name__ == "__main__":
   import nltk
   nltk.download("punkt")
-  Tensor.no_grad = True
   # Parse CLI arguments
   parser = argparse.ArgumentParser("Have a tiny conversation with tinygrad")

@@ -201,7 +201,6 @@ class GPT2:
 # **** main code ****

 if __name__ == "__main__":
-  Tensor.no_grad = True
   print(f"using {Device.DEFAULT} backend")
   default_prompt = "What is the answer to life, the universe, and everything?"
@@ -267,13 +267,10 @@ def train_cifar():
   @TinyJit
   def update(self, net, decay):
-    # TODO with Tensor.no_grad()
-    Tensor.no_grad = True
     for net_ema_param, (param_name, net_param) in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).items()):
       # batchnorm currently is not being tracked
       if not ("num_batches_tracked" in param_name) and not ("running" in param_name):
         net_ema_param.assign(net_ema_param.detach()*decay + net_param.detach()*(1.-decay)).realize()
-    Tensor.no_grad = False

   set_seed(getenv('SEED', hyp['seed']))
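In the EMA update above, `.detach()` on both operands already keeps the assignment out of the autograd graph, which is why the surrounding `no_grad` toggle could simply be dropped. A standalone sketch of that idea (illustrative names and shapes, not the hlb_cifar10 code):

from tinygrad import Tensor

decay = 0.99
w     = Tensor.rand(4, 4, requires_grad=True)  # trainable weight
ema_w = Tensor.rand(4, 4).realize()            # its EMA shadow copy

# .detach() stops gradients from flowing through the EMA update,
# so no global no_grad flag is needed around it
ema_w = (ema_w.detach()*decay + w.detach()*(1.0 - decay)).realize()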
@@ -331,7 +331,6 @@ int main()
 \end{code}
 """
 if __name__ == "__main__":
-  Tensor.no_grad = True
   print(f"using {Device.DEFAULT} backend")

   parser = argparse.ArgumentParser(description="Run LLaMA in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

@@ -233,8 +233,6 @@ def prefill(model, toks, start_pos=0):
   return start_pos

 if __name__ == "__main__":
-  Tensor.no_grad = True
-
   parser = argparse.ArgumentParser()
   parser.add_argument("--download_model", action="store_true", help="Download a model")
   parser.add_argument("--model", type=Path, help="Model path")

@@ -146,7 +146,6 @@ if __name__ == "__main__":
     return loss

   @TinyJit
-  @Tensor.test()
   def sample(z:Tensor, cond:Tensor) -> Tensor:
     return model.sample(z, cond, Tensor.full_like(cond, 10), sample_steps=getenv("SAMPLE_STEPS", 20))[-1]

@@ -9,7 +9,6 @@ from extra.bench_log import BenchEvent, WallTimeEvent
 def tlog(x): print(f"{x:25s} @ {time.perf_counter()-start:5.2f}s")

 def eval_resnet():
-  Tensor.no_grad = True
   with WallTimeEvent(BenchEvent.FULL):
     # Resnet50-v1.5
     from extra.models.resnet import ResNet50

@@ -245,7 +244,6 @@ def eval_mrcnn():
 if __name__ == "__main__":
   # inference only
   Tensor.training = False
-  Tensor.no_grad = True

   models = getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(",")
   for m in models:

@@ -60,7 +60,6 @@ def spec_mrcnn():
 if __name__ == "__main__":
   # inference only for now
   Tensor.training = False
-  Tensor.no_grad = True

   for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(","):
     nm = f"spec_{m}"

@@ -791,7 +791,6 @@ def train_unet3d():
     return loss.realize()

   @Tensor.train(mode=False)
-  @Tensor.test()
   def eval_step(model, x, y):
     y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS)
     y_hat, y = Tensor(y_hat), Tensor(y, requires_grad=False)

@@ -52,8 +52,6 @@ def load_model(model_path:Path, model_params:Dict[str, Union[int, float]]) -> Tr

 if __name__ == "__main__":
-  Tensor.no_grad = True
-
   parser = argparse.ArgumentParser(description="Run QwQ in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
   parser.add_argument("--size", choices=["32B"], default="32B", help="Model size")
   parser.add_argument("--count", type=int, default=30, help="Max number of tokens to generate")

@@ -107,7 +107,6 @@ if __name__ == "__main__":
   assert args.width % F == 0, f"img_width must be multiple of {F}, got {args.width}"
   assert args.height % F == 0, f"img_height must be multiple of {F}, got {args.height}"

-  Tensor.no_grad = True
   if args.seed is not None:
     Tensor.manual_seed(args.seed)

@@ -378,7 +378,6 @@ if __name__ == "__main__":
   parser.add_argument('--noshow', action='store_true', help="Don't show the image")
   args = parser.parse_args()

-  Tensor.no_grad = True
   if args.seed is not None:
     Tensor.manual_seed(args.seed)

@@ -587,7 +587,7 @@ if __name__=="__main__":
   vits_model = args.model
   encoder_location, vits_location = ENCODER_MODELS[ENCODER_MODEL], VITS_MODELS[vits_model]

-  Tensor.no_grad, Tensor.training = True, False
+  Tensor.training = False
   # Get Synthesizer and ContentVec
   net_g, hps = Synthesizer.load_from_pretrained(vits_location[0], vits_location[2], vits_location[1], vits_location[3])
   Encoder = get_encoder(hps.model.ssl_dim)

@@ -229,7 +229,6 @@ if __name__ == "__main__":
   parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength")
   args = parser.parse_args()

-  Tensor.no_grad = True
   model = StableDiffusion()

   # load in weights

@@ -45,8 +45,7 @@ if __name__ == "__main__":
   print("*** scheduled training")

   # evaluate the model
-  with Tensor.test():
-    test_acc = ((model(X_test).argmax(axis=1) == Y_test).mean()*100)
+  test_acc = ((model(X_test).argmax(axis=1) == Y_test).mean()*100)
   print("*** scheduled eval")

   # NOTE: there's no kernels run in the scheduling phase

@@ -109,7 +109,6 @@ if __name__=="__main__":
   tokenizer = Tokenizer(str(tokenizer_path))

   model_path = fetch("https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-f16.gguf", "Llama-3.2-1B-Instruct-f16.gguf", subdir="llama3-1b-instruct")
-  Tensor.no_grad = True
   max_context=1024
   tok = 128000
   TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P = 0.95, 0, 0.0, 0.0, 0.0

@@ -707,7 +707,6 @@ if __name__ == '__main__':
   text_mapper = TextMapper(apply_cleaners=True, symbols=symbols)

   # Load the model.
-  Tensor.no_grad = True
   if args.seed is not None:
     Tensor.manual_seed(args.seed)
     np.random.seed(args.seed)

@@ -82,7 +82,6 @@ if __name__ == "__main__":
   args = parser.parse_args()
   Device.DEFAULT = "WEBGPU"

-  Tensor.no_grad = True
   model = StableDiffusion()

   # load in weights

@@ -59,7 +59,6 @@ if __name__ == "__main__":
   img = Tensor(preprocess(chicken_img))

   Tensor.training = False
-  Tensor.no_grad = True

   out = model(img).numpy()
   print(_LABELS[out.argmax()])
@@ -112,9 +112,8 @@ class OnnxRunner:
   def __init__(self, model: ModelProto):
     # parse model protobuf
     self.is_training = any(n.domain in {"ai.onnx.training", "ai.onnx.preview.training"} for n in model.graph.node)
-    self.old_training, self.old_no_grad = Tensor.training, Tensor.no_grad
+    self.old_training = Tensor.training
     Tensor.training = True if self.is_training else False
-    Tensor.no_grad = False if self.is_training else True
     self.graph_values = {"": None, **{x.name:buffer_parse(x) for x in model.graph.initializer}}
     self.graph_inputs = {x.name:type_parse(x.type) for x in model.graph.input if x.name not in self.graph_values}
     self.graph_outputs = tuple(x.name for x in model.graph.output)

@@ -176,9 +175,9 @@ class OnnxRunner:
       self.graph_values.update(dict(zip(node.outputs, ret[:len(node.outputs)], strict=True)))

       if node.num == limit:
-        Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad
+        Tensor.training = self.old_training
         return {name:self.graph_values[name] for name in node.outputs}
-    Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad
+    Tensor.training = self.old_training
     return {name:self.graph_values[name] for name in self.graph_outputs}

 ####################
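After this change `OnnxRunner` only saves, forces, and restores `Tensor.training` around a run. A minimal sketch of that save/restore pattern in isolation (`_TrainingFlag` is an illustrative helper, not part of the tinygrad ONNX code):

from tinygrad import Tensor

class _TrainingFlag:
  """Force Tensor.training for the duration of a run, then restore the old value."""
  def __init__(self, is_training: bool): self.is_training = is_training
  def __enter__(self): self.old_training, Tensor.training = Tensor.training, self.is_training
  def __exit__(self, *exc): Tensor.training = self.old_training

# usage sketch: run inference with training forced off, previous value restored afterwards
with _TrainingFlag(False):
  out = (Tensor.rand(1, 8) @ Tensor.rand(8, 2)).realize()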
@@ -82,7 +82,7 @@ if __name__ == "__main__":
     ys.append(Y[sel])
     return Tensor(xs), Tensor(ys)

-  Tensor.no_grad, Tensor.training = False, True
+  Tensor.training = True
   losses = []
   test_losses = []
   test_accuracy = 0

@@ -104,7 +104,7 @@ if __name__ == "__main__":
     ys.append(Y[sel])
     return Tensor(xs), Tensor(ys)

-  Tensor.no_grad, Tensor.training = False, True
+  Tensor.training = True
   losses = []
   test_losses = []
   test_loss = float('inf')

@@ -64,7 +64,7 @@ if __name__ == "__main__":
     ys.append(Y[sel])
     return Tensor(xs), Tensor(ys)

-  Tensor.no_grad, Tensor.training = False, True
+  Tensor.training = True
   losses = []
   test_losses = []
   test_loss = float('inf')

@@ -18,7 +18,7 @@ if __name__ == "__main__":
   # select a world
   all_feats, all_acts, all_rews = [], [], []
   while 1:
-    Tensor.no_grad, Tensor.training = True, False
+    Tensor.training = False
     lin = ast_str_to_lin(random.choice(ast_strs))
     rawbufs = bufs_from_lin(lin)
     tm = last_tm = base_tm = time_linearizer(lin, rawbufs)

@@ -63,7 +63,7 @@ if __name__ == "__main__":

     BS = 32
     if len(all_feats) >= BS:
-      Tensor.no_grad, Tensor.training = False, True
+      Tensor.training = True
       x = Tensor(all_feats[:BS])
       mask = np.zeros((BS, len(actions)+1), dtype=np.float32)
       mask[range(BS), all_acts[:BS]] = all_rews[:BS]

@@ -79,7 +79,6 @@ if __name__ == "__main__":

   resnet18 = load()

-  @Tensor.test()
   def _forward(im): return resnet18(im)
   forward = TinyJit(_forward, prune=True)
test/external/external_llama_eval.py (vendored)
@@ -69,7 +69,6 @@ class LLaMaAdaptor(BaseLM):
     return self.llama.tokenizer.decode(tokens)

   def _model_call(self, inps):
-    Tensor.no_grad = True
     return torch.Tensor(self.llama.model(Tensor(inps.numpy()), 0).numpy())

   def greedy_until(self, requests):
test/external/external_test_image.py (vendored)
@@ -8,7 +8,6 @@ os.environ['GPU'] = '1'
 os.environ['OPT'] = '2'
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Conv2d
-Tensor.no_grad = True

 class TestImage(unittest.TestCase):
   def test_create_image(self):

@@ -93,9 +93,7 @@ class TestOnnxModel(unittest.TestCase):
     et = time.monotonic()
     print(f"ran openpilot model in {(et-st)*1000.0:.2f} ms, waited {(mt2-mt)*1000.0:.2f} ms for realize, {(et-mt2)*1000.0:.2f} ms for GPU queue")

-    Tensor.no_grad = True
     torch_out = run_onnx_torch(onnx_model, inputs).numpy()
-    Tensor.no_grad = False
     print(tinygrad_out, torch_out)
     np.testing.assert_allclose(tinygrad_out, torch_out, atol=1e-4, rtol=1e-2)

@@ -26,12 +26,10 @@ class TestConv(unittest.TestCase):
     print(ret)

   def test_lazycache(self):
-    Tensor.no_grad = True
     x = Tensor.rand(1, 32)
     y = Tensor.rand(32)
     out = x + y.reshape((1,32,1)).reshape((1,32)) + y.reshape((1,32,1)).reshape((1,32))
     out.numpy()
-    Tensor.no_grad = False

   def test_simple_biased(self):
     C = 8

@@ -43,35 +41,28 @@ class TestConv(unittest.TestCase):
     print(ret.numpy())

   def test_two_binops_no_rerun_small(self):
-    Tensor.no_grad = True
     x = Tensor.rand(1,1,32,32)
     w = Tensor.rand(1,1,3,3)
     out = x.conv2d(w, padding=(1,1))
     np.testing.assert_allclose(out.relu().numpy(), np.maximum(out.numpy(), 0))
-    Tensor.no_grad = False

   def test_two_binops_no_rerun(self):
-    Tensor.no_grad = True
     x = Tensor.randn(1,12,128,256)
     w = Tensor.randn(32,12,3,3)
     out = x.conv2d(w, stride=(2,2), padding=(1,1))
     r1, r2 = out.relu(), (out-1)
     np.testing.assert_allclose(r1.numpy(), np.maximum(out.numpy(), 0))
     np.testing.assert_allclose(r2.numpy(), out.numpy() - 1)
-    Tensor.no_grad = False

   def test_two_overlapping_binops_no_rerun(self):
-    Tensor.no_grad = True
     x = Tensor.randn(1,12,128,256)
     w = Tensor.randn(32,12,3,3)
     out = x.conv2d(w, stride=(2,2), padding=(1,1))
     r1, r2 = out.relu(), out.elu()
     np.testing.assert_allclose(r1.numpy(), np.maximum(out.numpy(), 0))
     np.testing.assert_allclose(r2.numpy(), np.where(out.numpy() > 0, out.numpy(), (np.exp(out.numpy()) - 1)), atol=1e-5)
-    Tensor.no_grad = False

   def test_two_overlapping_binops_no_rerun_wino(self):
-    Tensor.no_grad = True
     with Context(WINO=1):
       x = Tensor.randn(1,4,16,16)
       w = Tensor.randn(6,4,3,3)

@@ -79,10 +70,8 @@ class TestConv(unittest.TestCase):
       r1, r2 = out.relu(), out.elu()
       np.testing.assert_allclose(r1.numpy(), np.maximum(out.numpy(), 0))
       np.testing.assert_allclose(r2.numpy(), np.where(out.numpy() > 0, out.numpy(), (np.exp(out.numpy()) - 1)), atol=1e-5)
-    Tensor.no_grad = False

   def test_first_three(self):
-    Tensor.no_grad = True
     x = Tensor.rand(1,12,128,256)

     w = Tensor.rand(32,12,3,3)

@@ -96,10 +85,8 @@ class TestConv(unittest.TestCase):

     x = x.numpy()
     print(x.shape)
-    Tensor.no_grad = False

   def test_elu(self):
-    Tensor.no_grad = True
     x = Tensor.rand(1,12,128,256)

     w = Tensor.rand(32,12,3,3)

@@ -110,17 +97,13 @@ class TestConv(unittest.TestCase):
     w = Tensor.rand(32,1,3,3)
     x = x.conv2d(w, padding=(1,1), groups=32)
     x.numpy()
-    Tensor.no_grad = False

   def test_reduce_relu(self):
-    Tensor.no_grad = True
     x = Tensor.rand(1,12,128,256)
     x = x.sum(keepdim=True).relu()
     x.numpy()
-    Tensor.no_grad = False

   def test_bias(self):
-    Tensor.no_grad = True
     from tinygrad.nn import Conv2d
     x = Tensor.rand(1,12,128,256)
     c = Conv2d(12, 32, 3)

@@ -128,7 +111,6 @@ class TestConv(unittest.TestCase):
     w = Tensor.uniform(32, 1, 3, 3)
     x = x.conv2d(w, groups=32)
     x.numpy()
-    Tensor.no_grad = False

   def test_multiadd(self):
     w = Tensor.rand(32)
@@ -740,12 +740,11 @@ class TestInferenceMode(unittest.TestCase):
     x = Tensor(x_init, requires_grad=True)
     m = Tensor(m_init, requires_grad=True)
     W = Tensor(W_init, requires_grad=True)
-    with Tensor.test():
-      tmp = x.mul(m)
-      mm = tmp.matmul(W)
-      out = mm.relu()
-      out = out.sum()
-      out.backward()
+    tmp = x.mul(m)
+    mm = tmp.matmul(W)
+    out = mm.relu()
+    out = out.sum()
+    #out.backward()
     assert x.grad is None
     assert m.grad is None
     assert tmp.grad is None

@@ -757,13 +756,12 @@ class TestInferenceMode(unittest.TestCase):
     x = Tensor(x_init, requires_grad=True)
     m = Tensor(m_init, requires_grad=True)
     W = Tensor(W_init, requires_grad=True)
-    @Tensor.test()
     def f(x, m, W):
       tmp = x.mul(m)
       mm = tmp.matmul(W)
       out = mm.relu()
       out = out.sum()
-      out.backward()
+      #out.backward()
       assert x.grad is None
       assert m.grad is None
       assert tmp.grad is None
@@ -123,7 +123,6 @@ class Tensor(MathTrait):
   """
   __slots__ = "lazydata", "requires_grad", "grad"
   training: ClassVar[bool] = False
-  no_grad: ClassVar[bool] = False

   def __init__(self, data:ConstType|bytes|list|tuple|UOp|'np.ndarray'|pathlib.Path|None, # type: ignore [name-defined] # noqa: F821
               device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None):

@@ -194,11 +193,6 @@ class Tensor(MathTrait):
     def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
     def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev

-  class test(ContextDecorator):
-    def __init__(self, mode:bool = True): self.mode = mode
-    def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
-    def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
-
   def __repr__(self):
     ld = self.lazydata
     ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]} {ld.st if ld.base is not ld else (ld.op, ld.realized)}>"
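With the `test` ContextDecorator gone, `Tensor.train` is the only remaining class-level mode switch, and it keeps the same save/restore pattern shown above. A small usage sketch (assuming the `training: ClassVar[bool] = False` default from the first tensor.py hunk):

from tinygrad import Tensor

assert Tensor.training is False   # class default

with Tensor.train():              # context manager form
  assert Tensor.training is True
assert Tensor.training is False   # previous value restored on exit

@Tensor.train(mode=False)         # decorator form, as in the eval_step hunk above
def run_eval(): return Tensor.training

assert run_eval() is False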
@@ -931,7 +925,7 @@ class Tensor(MathTrait):
     """
     all_uops = self.lazydata.toposort()
     tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \
-      t.lazydata in all_uops and t.requires_grad and not Tensor.no_grad]
+      t.lazydata in all_uops and t.requires_grad]
     # clear contexts
     for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)):
       assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"