diff --git a/docs/quickstart.md b/docs/quickstart.md index 08f6f1e7e6..9ef9ec73db 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -145,7 +145,7 @@ def sparse_categorical_crossentropy(out, Y, ignore_index=-1): loss_mask = Y != ignore_index num_classes = out.shape[-1] y_counter = Tensor.arange(num_classes, requires_grad=False).unsqueeze(0).expand(Y.numel(), num_classes) - y = (y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) + y = (y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) y = y * loss_mask.reshape(-1, 1) y = y.reshape(*Y.shape, num_classes) return out.log_softmax().mul(y).sum() / loss_mask.sum() @@ -165,7 +165,7 @@ opt = SGD([net.l1.weight, net.l2.weight], lr=3e-4) We can see that we are passing in the parameters of our neural network to the optimizer. This is due to the fact that the optimizer needs to know which parameters to update. -There is a simpler way to do this just by using `get_parameters(net)` from `tinygrad.state` which will return a list of all the parameters in the neural network. +There is a simpler way to do this just by using `get_parameters(net)` from `tinygrad.nn.state` which will return a list of all the parameters in the neural network. The parameters are just listed out explicitly here for clarity. Now that we have our network, loss function, and optimizer defined all we are missing is the data to train on! @@ -291,7 +291,7 @@ The standard weight format for tinygrad is [safetensors](https://github.com/hugg There are functions in [state.py](/tinygrad/state.py) to save and load models to and from this format. ```python -from tinygrad.state import safe_save, safe_load, get_state_dict, load_state_dict +from tinygrad.nn.state import safe_save, safe_load, get_state_dict, load_state_dict # first we need the state dict of our model state_dict = get_state_dict(net) diff --git a/examples/benchmark_train_efficientnet.py b/examples/benchmark_train_efficientnet.py index c22e0262f6..18dcfddda8 100755 --- a/examples/benchmark_train_efficientnet.py +++ b/examples/benchmark_train_efficientnet.py @@ -3,7 +3,7 @@ import gc import time from tqdm import trange from models.efficientnet import EfficientNet -from tinygrad.state import get_parameters +from tinygrad.nn.state import get_parameters from tinygrad.nn import optim from tinygrad.tensor import Tensor from tinygrad.ops import GlobalCounters diff --git a/examples/compile_efficientnet.py b/examples/compile_efficientnet.py index bf7ade2d44..5c2cb57f98 100644 --- a/examples/compile_efficientnet.py +++ b/examples/compile_efficientnet.py @@ -1,6 +1,6 @@ from models.efficientnet import EfficientNet from tinygrad.tensor import Tensor -from tinygrad.state import safe_save +from tinygrad.nn.state import safe_save from extra.utils import fetch from extra.export_model import export_model from tinygrad.helpers import getenv diff --git a/examples/deep_deterministic_policy_gradient.py b/examples/deep_deterministic_policy_gradient.py index 906069f0db..aa88c97dd4 100644 --- a/examples/deep_deterministic_policy_gradient.py +++ b/examples/deep_deterministic_policy_gradient.py @@ -1,7 +1,7 @@ from typing import Optional, Tuple from numpy.typing import NDArray -from tinygrad.state import get_parameters +from tinygrad.nn.state import get_parameters from tinygrad.tensor import Tensor from tinygrad.nn import optim from tinygrad.helpers import getenv diff --git a/examples/gpt2.py b/examples/gpt2.py index d695ddd52b..562312cb3f 100644 --- a/examples/gpt2.py +++ b/examples/gpt2.py @@ -132,7 +132,7 @@ class GPT2: @staticmethod def build(model_size="gpt2"): import tiktoken - from tinygrad.state import torch_load, load_state_dict, get_state_dict + from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict from extra.utils import fetch_as_file tokenizer = tiktoken.get_encoding("gpt2") diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py index 96bab82e9c..3581c9526c 100644 --- a/examples/hlb_cifar10.py +++ b/examples/hlb_cifar10.py @@ -14,7 +14,7 @@ import random import numpy as np from extra.datasets import fetch_cifar, cifar_mean, cifar_std from tinygrad import nn -from tinygrad.state import get_state_dict +from tinygrad.nn.state import get_state_dict from tinygrad.nn import optim from tinygrad.lazy import Device from tinygrad.tensor import Tensor diff --git a/examples/llama.py b/examples/llama.py index 77f198604f..d81167efc9 100755 --- a/examples/llama.py +++ b/examples/llama.py @@ -251,7 +251,7 @@ class LLaMa: sp_model = SentencePieceProcessor(model_file=str(tokenizer_path)) assert sp_model.vocab_size() == VOCAB_SIZE - from tinygrad.state import torch_load, load_state_dict + from tinygrad.nn.state import torch_load, load_state_dict params = MODEL_PARAMS[model_gen][model_size] model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear) if quantize else Transformer(**params["args"]) weights = concat_weights([torch_load(filename) for filename in [f"{model_path}/{model_size}/consolidated.{i:02d}.pth" for i in range(params["files"])]]) diff --git a/examples/mnist_gan.py b/examples/mnist_gan.py index e7f91036c6..0d4c0a34c1 100644 --- a/examples/mnist_gan.py +++ b/examples/mnist_gan.py @@ -3,7 +3,7 @@ import numpy as np from tqdm import trange import torch from torchvision.utils import make_grid, save_image -from tinygrad.state import get_parameters +from tinygrad.nn.state import get_parameters from tinygrad.tensor import Tensor from tinygrad.helpers import getenv from tinygrad.nn import optim diff --git a/examples/serious_mnist.py b/examples/serious_mnist.py index 87c3937e95..b0c4c69ae5 100644 --- a/examples/serious_mnist.py +++ b/examples/serious_mnist.py @@ -2,7 +2,7 @@ #inspired by https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb import sys import numpy as np -from tinygrad.state import get_parameters +from tinygrad.nn.state import get_parameters from tinygrad.tensor import Tensor from tinygrad.nn import BatchNorm2d, optim from tinygrad.helpers import getenv diff --git a/examples/simple_conv_bn.py b/examples/simple_conv_bn.py index 36ccaaba51..16182c26af 100644 --- a/examples/simple_conv_bn.py +++ b/examples/simple_conv_bn.py @@ -2,7 +2,7 @@ from tinygrad.tensor import Tensor from tinygrad.nn import Conv2d, BatchNorm2d -from tinygrad.state import get_parameters +from tinygrad.nn.state import get_parameters if __name__ == "__main__": Tensor.training = True diff --git a/examples/so_vits_svc.py b/examples/so_vits_svc.py index bef84f3296..37a81be6fb 100644 --- a/examples/so_vits_svc.py +++ b/examples/so_vits_svc.py @@ -7,7 +7,7 @@ from typing import Tuple, Optional, Type from tinygrad import nn from tinygrad.tensor import Tensor from tinygrad.helpers import dtypes, getenv -from tinygrad.state import torch_load +from tinygrad.nn.state import torch_load from examples.vits import ResidualCouplingBlock, PosteriorEncoder, Encoder, ResBlock1, ResBlock2, LRELU_SLOPE, sequence_mask, split, download_if_not_present, get_hparams_from_file, load_checkpoint, weight_norm, HParams from examples.sovits_helpers import preprocess import soundfile diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index 810089b9a9..577b39a548 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -12,7 +12,7 @@ from tinygrad.tensor import Tensor from tinygrad.helpers import dtypes, GlobalCounters from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding from extra.utils import download_file -from tinygrad.state import torch_load, load_state_dict, get_state_dict +from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict class AttnBlock: def __init__(self, in_channels): diff --git a/examples/train_efficientnet.py b/examples/train_efficientnet.py index 22dc290866..98a4612cda 100644 --- a/examples/train_efficientnet.py +++ b/examples/train_efficientnet.py @@ -3,7 +3,7 @@ import time from multiprocessing import Process, Queue import numpy as np from tqdm import trange -from tinygrad.state import get_parameters +from tinygrad.nn.state import get_parameters from tinygrad.nn import optim from tinygrad.helpers import getenv from tinygrad.tensor import Tensor diff --git a/examples/train_resnet.py b/examples/train_resnet.py index 81b2e7cd6d..34d778cb0a 100755 --- a/examples/train_resnet.py +++ b/examples/train_resnet.py @@ -2,7 +2,7 @@ import numpy as np from PIL import Image -from tinygrad.state import get_parameters +from tinygrad.nn.state import get_parameters from tinygrad.nn import optim from tinygrad.helpers import getenv from extra.training import train, evaluate diff --git a/examples/transformer.py b/examples/transformer.py index 7bc2157e9f..2257107a5e 100755 --- a/examples/transformer.py +++ b/examples/transformer.py @@ -2,7 +2,7 @@ import numpy as np import random -from tinygrad.state import get_parameters +from tinygrad.nn.state import get_parameters from tinygrad.nn.optim import Adam from extra.training import train, evaluate from models.transformer import Transformer diff --git a/examples/vits.py b/examples/vits.py index 9b4e13ba65..7b88490c9d 100644 --- a/examples/vits.py +++ b/examples/vits.py @@ -5,7 +5,7 @@ from typing import List from extra.utils import download_file from tinygrad import nn from tinygrad.helpers import dtypes -from tinygrad.state import torch_load +from tinygrad.nn.state import torch_load from tinygrad.tensor import Tensor from unidecode import unidecode diff --git a/examples/whisper.py b/examples/whisper.py index 4107c08ded..e36824c5ed 100644 --- a/examples/whisper.py +++ b/examples/whisper.py @@ -7,7 +7,7 @@ import multiprocessing import numpy as np from typing import Optional from extra.utils import download_file -from tinygrad.state import torch_load, load_state_dict +from tinygrad.nn.state import torch_load, load_state_dict from tinygrad.helpers import getenv import tinygrad.nn as nn from tinygrad.tensor import Tensor diff --git a/examples/yolov8.py b/examples/yolov8.py index db3c421fdf..b8e232765e 100644 --- a/examples/yolov8.py +++ b/examples/yolov8.py @@ -8,7 +8,7 @@ import cv2 from collections import defaultdict import os import time, io, sys -from tinygrad.state import safe_load, load_state_dict +from tinygrad.nn.state import safe_load, load_state_dict #Model architecture from https://github.com/ultralytics/ultralytics/issues/189 diff --git a/extra/export_model.py b/extra/export_model.py index f58c60b631..085ee9caeb 100644 --- a/extra/export_model.py +++ b/extra/export_model.py @@ -2,7 +2,7 @@ from typing import Tuple, Dict, List from tinygrad.helpers import DType from tinygrad.tensor import Device, Tensor from tinygrad.jit import TinyJit -from tinygrad.state import get_state_dict +from tinygrad.nn.state import get_state_dict import json def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]: diff --git a/models/efficientnet.py b/models/efficientnet.py index 81868392ca..017720dfbd 100644 --- a/models/efficientnet.py +++ b/models/efficientnet.py @@ -143,7 +143,7 @@ class EfficientNet: } from extra.utils import fetch_as_file - from tinygrad.state import torch_load + from tinygrad.nn.state import torch_load b0 = torch_load(fetch_as_file(model_urls[self.number])) for k,v in b0.items(): if k.endswith("num_batches_tracked"): continue diff --git a/models/mask_rcnn.py b/models/mask_rcnn.py index e91f70c66f..aea0d3e489 100644 --- a/models/mask_rcnn.py +++ b/models/mask_rcnn.py @@ -7,7 +7,7 @@ from tinygrad import nn from tinygrad.tensor import Tensor from tinygrad.helpers import dtypes from extra.utils import get_child, download_file -from tinygrad.state import torch_load +from tinygrad.nn.state import torch_load from models.resnet import ResNet from models.retinanet import nms as _box_nms diff --git a/test/external/external_test_allocator_on_models.py b/test/external/external_test_allocator_on_models.py index 36717945d3..a8e79d7c12 100644 --- a/test/external/external_test_allocator_on_models.py +++ b/test/external/external_test_allocator_on_models.py @@ -2,7 +2,7 @@ import unittest, gc import numpy as np from tinygrad.tensor import Tensor -from tinygrad.state import get_parameters, get_state_dict +from tinygrad.nn.state import get_parameters, get_state_dict from tinygrad.ops import GlobalCounters, LazyOp, LoadOps from tinygrad.runtime.lib import RawBuffer, LRUAllocator from tinygrad.helpers import dtypes, prod diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py index ff9aa4fe55..1bf2f99895 100644 --- a/test/external/external_test_opt.py +++ b/test/external/external_test_opt.py @@ -36,7 +36,7 @@ from models.convnext import ConvNeXt from models.efficientnet import EfficientNet from models.resnet import ResNet18 from models.vit import ViT -from tinygrad.state import get_parameters +from tinygrad.nn.state import get_parameters @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented") class TestInferenceMinKernels(unittest.TestCase): diff --git a/test/external/external_test_speed_llama.py b/test/external/external_test_speed_llama.py index c009119e01..8b83ee3f66 100644 --- a/test/external/external_test_speed_llama.py +++ b/test/external/external_test_speed_llama.py @@ -5,7 +5,7 @@ from examples.llama import Transformer, MODEL_PARAMS from test.test_net_speed import start_profile, stop_profile from tinygrad.tensor import Tensor from tinygrad.lazy import Device -from tinygrad.state import get_state_dict +from tinygrad.nn.state import get_state_dict from tinygrad.ops import Compiled from tinygrad.helpers import dtypes, prod from tinygrad.runtime.lib import RawBuffer diff --git a/test/external/external_test_yolov8.py b/test/external/external_test_yolov8.py index cc5f6bc76c..98bf555cf6 100644 --- a/test/external/external_test_yolov8.py +++ b/test/external/external_test_yolov8.py @@ -6,29 +6,29 @@ import unittest import io, cv2, os import onnxruntime as ort import ultralytics -from tinygrad.state import safe_load, load_state_dict +from tinygrad.nn.state import safe_load, load_state_dict class TestYOLOv8(unittest.TestCase): def test_all_load_weights(self): for variant in ['n', 's', 'm', 'l', 'x']: weights_location = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.safetensors' download_file(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{variant}.safetensors', weights_location) - - depth, width, ratio = get_variant_multiples(variant) - TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80) + + depth, width, ratio = get_variant_multiples(variant) + TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80) state_dict = safe_load(weights_location) load_state_dict(TinyYolov8, state_dict) print(f'successfully loaded weights for yolov{variant}') - + def test_predictions(self): test_image_urls = ['https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg', 'https://www.aljazeera.com/wp-content/uploads/2022/10/2022-04-28T192650Z_1186456067_UP1EI4S1I0P14_RTRMADP_3_SOCCER-ENGLAND-MUN-CHE-REPORT.jpg'] variant = 'n' weights_location = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.safetensors' - depth, width, ratio = get_variant_multiples(variant) - TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80) + depth, width, ratio = get_variant_multiples(variant) + TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80) state_dict = safe_load(weights_location) load_state_dict(TinyYolov8, state_dict) - + for i in range(len(test_image_urls)): img_stream = io.BytesIO(fetch(test_image_urls[i])) img = cv2.imdecode(np.frombuffer(img_stream.read(), np.uint8), 1) @@ -37,41 +37,40 @@ class TestYOLOv8(unittest.TestCase): post_predictions = postprocess(preds=predictions, img=test_image, orig_imgs=[img]) labels = label_predictions(post_predictions) assert labels == {5: 1, 0: 4, 11: 1} if i == 0 else labels == {0: 13, 29: 1, 32: 1} - + def test_forward_pass_torch_onnx(self): variant = 'n' - weights_location_onnx = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.onnx' - weights_location_pt = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.pt' - weights_location = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.safetensors' + weights_location_onnx = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.onnx' + weights_location_pt = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.pt' + weights_location = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.safetensors' download_file(f'https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8{variant}.pt', weights_location_pt) # the ultralytics export prints a lot of unneccesary things if not os.path.isfile(weights_location_onnx): - model = ultralytics.YOLO(model=weights_location_pt, task='Detect') - model.export(format="onnx",imgsz=[640, 480]) + model = ultralytics.YOLO(model=weights_location_pt, task='Detect') + model.export(format="onnx",imgsz=[640, 480]) - depth, width, ratio = get_variant_multiples(variant) - TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80) + depth, width, ratio = get_variant_multiples(variant) + TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80) state_dict = safe_load(weights_location) load_state_dict(TinyYolov8, state_dict) - + image_location = [np.frombuffer(io.BytesIO(fetch('https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg')).read(), np.uint8)] orig_image = [cv2.imdecode(image_location[0], 1)] - + input_image = preprocess(orig_image) - + onnx_session = ort.InferenceSession(weights_location_onnx) onnx_input_name = onnx_session.get_inputs()[0].name onnx_output_name = onnx_session.get_outputs()[0].name onnx_output = onnx_session.run([onnx_output_name], {onnx_input_name: input_image.numpy()}) tiny_output = TinyYolov8(input_image) - - # currently rtol is 0.025 because there is a 1-2% difference in our predictions - # because of the zero padding in SPPF module (line 280) maxpooling layers rather than the -infinity in torch. - # This difference does not make a difference "visually". + + # currently rtol is 0.025 because there is a 1-2% difference in our predictions + # because of the zero padding in SPPF module (line 280) maxpooling layers rather than the -infinity in torch. + # This difference does not make a difference "visually". np.testing.assert_allclose(onnx_output[0], tiny_output.numpy(), atol=5e-4, rtol=0.025) - + if __name__ == '__main__': unittest.main() - \ No newline at end of file diff --git a/test/external/graph_batchnorm.py b/test/external/graph_batchnorm.py index 2fa98b05c8..f0813edf35 100644 --- a/test/external/graph_batchnorm.py +++ b/test/external/graph_batchnorm.py @@ -1,5 +1,5 @@ import unittest -from tinygrad.state import get_parameters +from tinygrad.nn.state import get_parameters from tinygrad.tensor import Tensor from tinygrad.nn import Conv2d, BatchNorm2d, optim @@ -33,7 +33,7 @@ class TestBatchnorm(unittest.TestCase): return self.c2(self.c(x)).relu() lm = LilModel() model_step(lm) - + def test_two_conv_bn(self): class LilModel: def __init__(self): diff --git a/test/extra/test_lr_scheduler.py b/test/extra/test_lr_scheduler.py index 283652b48a..9aa9b86341 100644 --- a/test/extra/test_lr_scheduler.py +++ b/test/extra/test_lr_scheduler.py @@ -2,7 +2,7 @@ import numpy as np import torch import unittest from tinygrad.tensor import Tensor -from tinygrad.state import get_parameters +from tinygrad.nn.state import get_parameters from tinygrad.nn.optim import Adam from extra.lr_scheduler import MultiStepLR, ReduceLROnPlateau, CosineAnnealingLR, OneCycleLR from extra.training import train, evaluate diff --git a/test/extra/test_utils.py b/test/extra/test_utils.py index 6350403ade..c73a8ca455 100644 --- a/test/extra/test_utils.py +++ b/test/extra/test_utils.py @@ -6,9 +6,9 @@ from unittest.mock import patch, MagicMock import torch import numpy as np -from tinygrad.helpers import getenv +from tinygrad.helpers import getenv from extra.utils import fetch, temp, download_file -from tinygrad.state import torch_load +from tinygrad.nn.state import torch_load from PIL import Image @unittest.skipIf(getenv("CI", "") != "", "no internet tests in CI") @@ -33,7 +33,7 @@ class TestFetchRelative(unittest.TestCase): os.chdir(self.tempdir.name) with open('test_file.txt', 'x') as f: f.write("12345") - + def tearDown(self): os.chdir(self.working_dir) self.tempdir.cleanup() @@ -41,7 +41,7 @@ class TestFetchRelative(unittest.TestCase): #test ./ def test_fetch_relative_dotslash(self): self.assertEqual(b'12345', fetch("./test_file.txt")) - + #test ../ def test_fetch_relative_dotdotslash(self): os.mkdir('test_file_path') @@ -92,7 +92,7 @@ class TestUtils(unittest.TestCase): ) if isfloat16: model = model.half() - path = temp(f"test_load_{isfloat16}.pt") + path = temp(f"test_load_{isfloat16}.pt") torch.save(model.state_dict(), path) model2 = torch_load(path) diff --git a/test/models/test_bert.py b/test/models/test_bert.py index 82cf9ff9f2..8a20616225 100644 --- a/test/models/test_bert.py +++ b/test/models/test_bert.py @@ -13,7 +13,7 @@ def get_question_samp(bsz, seq_len, vocab_size, seed): return in_ids, mask, seg_ids def set_equal_weights(mdl, torch_mdl): - from tinygrad.state import get_state_dict + from tinygrad.nn.state import get_state_dict state, torch_state = get_state_dict(mdl), torch_mdl.state_dict() assert len(state) == len(torch_state) for k, v in state.items(): diff --git a/test/models/test_end2end.py b/test/models/test_end2end.py index b206e50bc7..b8e8f53a66 100644 --- a/test/models/test_end2end.py +++ b/test/models/test_end2end.py @@ -2,7 +2,7 @@ import torch from torch import nn import unittest import numpy as np -from tinygrad.state import get_parameters, get_state_dict +from tinygrad.nn.state import get_parameters, get_state_dict from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d from tinygrad.tensor import Tensor from extra.datasets import fetch_mnist diff --git a/test/models/test_mnist.py b/test/models/test_mnist.py index fca4c85084..f3f37c3460 100644 --- a/test/models/test_mnist.py +++ b/test/models/test_mnist.py @@ -1,7 +1,7 @@ #!/usr/bin/env python import unittest import numpy as np -from tinygrad.state import get_parameters +from tinygrad.nn.state import get_parameters from tinygrad.tensor import Tensor, Device from tinygrad.nn import optim, BatchNorm2d from extra.training import train, evaluate diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py index d28bac16da..d2e6f0048b 100644 --- a/test/models/test_real_world.py +++ b/test/models/test_real_world.py @@ -1,7 +1,7 @@ import unittest, time from tinygrad.tensor import Tensor from tinygrad.nn import optim -from tinygrad.state import get_parameters +from tinygrad.nn.state import get_parameters from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE from tinygrad.ops import GlobalCounters, LazyOp, LoadOps from tinygrad.lazy import Device diff --git a/test/models/test_train.py b/test/models/test_train.py index 3f58358564..8931fe3640 100644 --- a/test/models/test_train.py +++ b/test/models/test_train.py @@ -1,7 +1,7 @@ import unittest import time import numpy as np -from tinygrad.state import get_parameters +from tinygrad.nn.state import get_parameters from tinygrad.nn import optim from tinygrad.tensor import Device from tinygrad.helpers import getenv diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py index b142b7e0a5..afb360eff2 100644 --- a/test/unit/test_disk_tensor.py +++ b/test/unit/test_disk_tensor.py @@ -2,7 +2,7 @@ import pathlib import unittest import numpy as np from tinygrad.tensor import Tensor, Device -from tinygrad.state import safe_load, safe_save, get_state_dict, torch_load +from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load from tinygrad.helpers import dtypes from tinygrad.runtime.ops_disk import RawDiskBuffer from tinygrad.helpers import Timing diff --git a/tinygrad/state.py b/tinygrad/nn/state.py similarity index 100% rename from tinygrad/state.py rename to tinygrad/nn/state.py