encapsulate the exported webgpu model (#8203)

This commit is contained in:
Ahmed Harmouche
2024-12-13 10:55:37 +01:00
committed by GitHub
parent 5864627abe
commit 651f72442c
4 changed files with 32 additions and 28 deletions

View File

@@ -17,8 +17,11 @@ canvas { display: none; }
* { text-align: center; font-family: monospace; }
</style>
<title>tinygrad has WebGPU</title>
<script src="../../net.js"></script>
<link rel="icon" type="image/x-icon" href="https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/logo.png">
<script type="module">
import model from "../../net.js";
window.model = model;
</script>
</head>
<body>
<h1>WebGPU <a href="https://github.com/geohot/tinygrad">tinygrad</a> EfficientNet!</h1>
@@ -61,8 +64,6 @@ canvas { display: none; }
const getLabels = async () => (await fetch("https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json")).json();
const getSavetensorBuffer = async () => new Uint8Array(await (await fetch("../../net.safetensors")).arrayBuffer());
const reorderChannelsAndRemoveAlpha = (data) => {
const out = [];
let i = 0;
@@ -97,9 +98,8 @@ canvas { display: none; }
try {
resultText.innerHTML = "loading..."
labels = await getLabels();
const safetensor = await getSavetensorBuffer();
const device = await getDevice();
net = await timer(() => setupNet(device, safetensor), "(compilation)");
net = await timer(() => model.load(device, '../../net.safetensors'), "(compilation)");
resultText.innerHTML = "ready"
} catch (e) {
error(e)

View File

@@ -12,7 +12,7 @@ if __name__ == "__main__":
yolo_infer = YOLOv8(w=0.25, r=2.0, d=0.33, num_classes=80)
state_dict = safe_load(get_weights_location(yolo_variant))
load_state_dict(yolo_infer, state_dict)
prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,416,416))
prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,416,416), model_name="yolov8")
dirname = Path(__file__).parent
safe_save(state, (dirname / "net.safetensors").as_posix())
with open(dirname / f"net.js", "w") as text_file:

View File

@@ -4,7 +4,10 @@
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>YOLOv8 tinygrad WebGPU</title>
<script src="./net.js"></script>
<script type="module">
import yolov8 from "./net.js"
window.yolov8 = yolov8;
</script>
<style>
body {
text-align: center;
@@ -213,7 +216,7 @@
wgpuError.style.display = "block";
loadingContainer.style.display = "none";
}
net = await loadNet(device);
net = await yolov8.load(device, "./net.safetensors");
loadingContainer.style.display = "none";
}
let start = performance.now();