diff --git a/extra/optimization/rl.py b/extra/optimization/rl.py index 9b2dd2d8d5..a088ae337f 100644 --- a/extra/optimization/rl.py +++ b/extra/optimization/rl.py @@ -5,7 +5,7 @@ from tinygrad.tensor import Tensor from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict from tinygrad.features.search import actions, bufs_from_lin, time_linearizer, get_linearizer_actions from tinygrad.nn.optim import Adam -from extra.optimization.pretrain_policynet import PolicyNet +from extra.optimization.extract_policynet import PolicyNet from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats if __name__ == "__main__": diff --git a/extra/optimization/test_net.py b/extra/optimization/test_net.py index 9dd40af613..851c2a85bd 100644 --- a/extra/optimization/test_net.py +++ b/extra/optimization/test_net.py @@ -8,7 +8,7 @@ from tinygrad.tensor import Tensor from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict from tinygrad.features.search import bufs_from_lin, time_linearizer, actions, get_linearizer_actions from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats -from extra.optimization.pretrain_policynet import PolicyNet +from extra.optimization.extract_policynet import PolicyNet from extra.optimization.pretrain_valuenet import ValueNet VALUE = getenv("VALUE")