CLANG -> CPU (#9189)
@@ -15,9 +15,9 @@ if __name__ == "__main__":
   if getenv("WEBGPU"):
     safe_save(get_state_dict(model), (dirname / "net.safetensors").as_posix())
     load_state_dict(model, safe_load(str(dirname / "net.safetensors")))
-  mode = "clang" if getenv("CLANG", "") != "" else "webgpu" if getenv("WEBGPU", "") != "" else ""
+  mode = "clang" if getenv("CPU", "") != "" else "webgpu" if getenv("WEBGPU", "") != "" else ""
   prg, inp_sizes, out_sizes, state = export_model(model, mode, Tensor.randn(1,3,224,224))
-  if getenv("CLANG", "") == "":
+  if getenv("CPU", "") == "":
     ext = "js" if getenv("WEBGPU", "") != "" else "json"
     with open(dirname / f"net.{ext}", "w") as text_file:
       text_file.write(prg)
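Note that only the selecting environment variable changes in the hunk above; the target string passed to export_model is still "clang". A minimal sketch of exporting a small model to C under the renamed variable (illustrative only; it assumes the extra.export_model helper used by this example and a toy model defined here, neither of which is part of the diff):

import os
os.environ["CPU"] = "1"                       # was CLANG=1 before this commit

from tinygrad import Tensor, nn
from extra.export_model import export_model   # helper used by the example above

class TinyNet:
  def __init__(self): self.l1 = nn.Linear(16, 4)
  def forward(self, x): return self.l1(x).relu()

# same return shape as in the hunk: generated program plus I/O buffer metadata
prg, inp_sizes, out_sizes, state = export_model(TinyNet(), "clang", Tensor.randn(1, 16))
print(prg[:200])                              # peek at the generated C source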
@@ -68,6 +68,6 @@ if __name__ == "__main__":
     else printf("%s\\n", lbls[best_idx]);
   }""")

-  # CLANG=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/showcase/stable_diffusion_by_tinygrad.jpg
+  # CPU=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/showcase/stable_diffusion_by_tinygrad.jpg
   # category : 281 (tabby, tabby cat) with 9.452788
   print('\n'.join(cprog))
@@ -1,7 +1,7 @@
 # An example to compile a small Tensorflow model to extremely portable C code

 import os, sys
-os.environ["CLANG"] = '1'
+os.environ["CPU"] = '1'
 os.environ["JIT"] = '2'

 import numpy as np
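The hunk above only swaps the environment variable name; the pattern it relies on is unchanged: the variable is set before tinygrad is imported so the backend is picked up when the default device is resolved. A minimal sketch of that ordering (an assumption about intent, not part of the diff):

import os
os.environ["CPU"] = "1"     # was os.environ["CLANG"] = '1' before this commit
os.environ["JIT"] = "2"

from tinygrad import Tensor, Device   # import after the env vars are set
print(Device.DEFAULT)                 # expected to report "CPU"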
@@ -2,7 +2,7 @@
 import os
 if "NOOPT" not in os.environ: os.environ["NOOPT"] = "1"
 from tinygrad import Device, nn, Tensor, dtypes, Variable
-Device.DEFAULT = "CLANG"
+Device.DEFAULT = "CPU"
 from train_gpt2 import GPT, GPTConfig
 from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GlobalCounters, ansilen, to_function_name
 from tinygrad.engine.realize import get_kernel, run_schedule
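For scripts that set the device programmatically rather than through the environment, the same rename applies to the device string itself. A small sketch of that form (illustrative; not part of the diff):

from tinygrad import Device, Tensor

Device.DEFAULT = "CPU"        # was "CLANG" before this commit
x = Tensor.randn(4, 4)        # created on the default device
print(x.device)               # expected: "CPU"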
@@ -43,9 +43,9 @@ if __name__ == "__main__":
   ast_dedup = dedup([si.ast for si in sched if si.ast.op is Ops.SINK])
   srcs = {}
   for ast in ast_dedup:
-    k = get_kernel(Device["CLANG"].renderer, ast)
+    k = get_kernel(Device["CPU"].renderer, ast)
     k.linearize()
-    src = Device["CLANG"].renderer.render(to_function_name(k.name), k.uops)
+    src = Device["CPU"].renderer.render(to_function_name(k.name), k.uops)
     srcs[ast] = (k.name, src)
   print("functions:", len(srcs))
   used_buffers = dedup(flatten([si.bufs for si in sched]))
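A sketch of the kernel-rendering loop above applied to a small computation; it assumes the internal entry points shown in the hunk (Tensor.schedule, get_kernel, Kernel.linearize, renderer.render) keep the signatures used there:

from tinygrad import Tensor, Device
from tinygrad.ops import Ops
from tinygrad.helpers import to_function_name
from tinygrad.engine.realize import get_kernel

out = (Tensor.randn(8, 8) @ Tensor.randn(8, 8)).relu()
sched = out.schedule()                            # ScheduleItems for this computation
for si in sched:
  if si.ast.op is not Ops.SINK: continue          # only compute kernels carry a SINK ast
  k = get_kernel(Device["CPU"].renderer, si.ast)  # was Device["CLANG"] before this commit
  k.linearize()
  src = Device["CPU"].renderer.render(to_function_name(k.name), k.uops)
  print(src)                                      # C source for one kernel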
@@ -170,13 +170,13 @@ def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None, pad_fir

 def process_batch_bert(data: List[dict]) -> dict[str, Tensor]:
   return {
-    "input_ids": Tensor(np.concatenate([s["input_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"),
-    "input_mask": Tensor(np.concatenate([s["input_mask"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"),
-    "segment_ids": Tensor(np.concatenate([s["segment_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"),
-    "masked_lm_positions": Tensor(np.concatenate([s["masked_lm_positions"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"),
-    "masked_lm_ids": Tensor(np.concatenate([s["masked_lm_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"),
-    "masked_lm_weights": Tensor(np.concatenate([s["masked_lm_weights"] for s in data], axis=0), dtype=dtypes.float32, device="CLANG"),
-    "next_sentence_labels": Tensor(np.concatenate([s["next_sentence_labels"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"),
+    "input_ids": Tensor(np.concatenate([s["input_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
+    "input_mask": Tensor(np.concatenate([s["input_mask"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
+    "segment_ids": Tensor(np.concatenate([s["segment_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
+    "masked_lm_positions": Tensor(np.concatenate([s["masked_lm_positions"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
+    "masked_lm_ids": Tensor(np.concatenate([s["masked_lm_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
+    "masked_lm_weights": Tensor(np.concatenate([s["masked_lm_weights"] for s in data], axis=0), dtype=dtypes.float32, device="CPU"),
+    "next_sentence_labels": Tensor(np.concatenate([s["next_sentence_labels"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
   }

 def load_file(file: str):
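The data-loader change above is mechanical: host-side staging tensors that were pinned to the CLANG device are now pinned to CPU. A minimal sketch of the same pattern in isolation (shapes are illustrative):

import numpy as np
from tinygrad import Tensor, dtypes

ids = np.zeros((1, 512), dtype=np.int32)             # fake token ids
t = Tensor(ids, dtype=dtypes.int32, device="CPU")    # was device="CLANG"
print(t.shape, t.device)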
@@ -222,11 +222,11 @@ def get_mlperf_bert_model():

 def get_fake_data_bert(BS:int):
   return {
-    "input_ids": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CLANG"),
-    "input_mask": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CLANG"),
-    "segment_ids": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CLANG"),
-    "masked_lm_positions": Tensor.empty((BS, 76), dtype=dtypes.int32, device="CLANG"),
-    "masked_lm_ids": Tensor.empty((BS, 76), dtype=dtypes.int32, device="CLANG"),
-    "masked_lm_weights": Tensor.empty((BS, 76), dtype=dtypes.float32, device="CLANG"),
-    "next_sentence_labels": Tensor.empty((BS, 1), dtype=dtypes.int32, device="CLANG"),
+    "input_ids": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CPU"),
+    "input_mask": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CPU"),
+    "segment_ids": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CPU"),
+    "masked_lm_positions": Tensor.empty((BS, 76), dtype=dtypes.int32, device="CPU"),
+    "masked_lm_ids": Tensor.empty((BS, 76), dtype=dtypes.int32, device="CPU"),
+    "masked_lm_weights": Tensor.empty((BS, 76), dtype=dtypes.float32, device="CPU"),
+    "next_sentence_labels": Tensor.empty((BS, 1), dtype=dtypes.int32, device="CPU"),
   }