diff --git a/extra/datasets/sops.gz b/extra/datasets/sops.gz new file mode 100644 index 0000000000..20f279f8e5 Binary files /dev/null and b/extra/datasets/sops.gz differ diff --git a/extra/optimization/helpers.py b/extra/optimization/helpers.py index 369f5d3b06..a5acb371df 100644 --- a/extra/optimization/helpers.py +++ b/extra/optimization/helpers.py @@ -10,11 +10,14 @@ inf, nan = float('inf'), float('nan') from tinygrad.codegen.linearizer import Linearizer def ast_str_to_lin(ast_str): return Linearizer(eval(ast_str)) -# load worlds +# load worlds, a dataset of about 12k kernels +import gzip +from pathlib import Path import random from tinygrad.helpers import dedup def load_worlds(filter_reduce=True, filter_noimage=True, filter_novariable=True): - ast_strs = dedup(open("/tmp/sops").read().strip().split("\n")) + fn = Path(__file__).parent.parent / "datasets/sops.gz" + ast_strs = dedup(gzip.open(fn).read().decode('utf-8').strip().split("\n")) if filter_reduce: ast_strs = [x for x in ast_strs if "ReduceOps" in x] if filter_noimage: ast_strs = [x for x in ast_strs if "dtypes.image" not in x] if filter_novariable: ast_strs = [x for x in ast_strs if "Variable" not in x] diff --git a/extra/optimization/rl.py b/extra/optimization/rl.py index 0bfeb88e91..4eab7d8228 100644 --- a/extra/optimization/rl.py +++ b/extra/optimization/rl.py @@ -1,3 +1,4 @@ +import os import numpy as np import math, random from tinygrad.tensor import Tensor @@ -9,7 +10,7 @@ from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats if __name__ == "__main__": net = PolicyNet() - load_state_dict(net, safe_load("/tmp/policynet.safetensors")) + if os.path.isfile("/tmp/policynet.safetensors"): load_state_dict(net, safe_load("/tmp/policynet.safetensors")) optim = Adam(get_parameters(net)) ast_strs = load_worlds() @@ -40,7 +41,7 @@ if __name__ == "__main__": rews.append(((last_tm-tm)/base_tm)) last_tm = tm except Exception: - rews.append(-1.0) + rews.append(-0.5) break #print(f"{tm*1e6:10.2f}", lin.colored_shape()) @@ -49,7 +50,7 @@ if __name__ == "__main__": print(f"***** EPISODE {len(rews)} steps, {sum(rews):5.2f} reward, {base_tm*1e6:12.2f} -> {tm*1e6:12.2f} : {lin.colored_shape()}") all_feats += feats all_acts += acts - all_rews += rews + all_rews += np.cumsum(rews).tolist() BS = 32 if len(all_feats) >= BS: