Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
simple exporting models (#1344)
* unified exporting (usage sketch below)
* json exporting
* ignore more
* simplified buffer export
* added dtypes
* added assert
* swift example
* fix tests
* linter
* remove whitespace
* fixed tests
* remove swift example
* remove unintended changes
* allow callable models to be used
* whitespace
* more readable json export
* name change
* whitespace
* whitespace
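The first item is the heart of the change: model export now goes through a single extra/export_model.py entry point. A minimal usage sketch, assembled from the examples/compile_efficientnet.py code in this diff (the "webgpu" target and the output file names follow that script; run with WEBGPU=1 so Device.DEFAULT matches the target):

from models.efficientnet import EfficientNet
from extra.export_model import export_model
from tinygrad.tensor import Tensor
from tinygrad.state import safe_save

# target is "clang", "webgpu", or "" (which falls back to JSON export)
model = EfficientNet(0)
model.load_from_pretrained()
prg, inp_size, out_size, state = export_model(model, Tensor.randn(1,3,224,224), "webgpu")
safe_save(state, "net.safetensors")  # weights ship separately as a safetensors file
with open("net.js", "w") as f: f.write(prg)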
2  .github/workflows/test.yml  (vendored)
@@ -184,7 +184,7 @@ jobs:
     - name: Run webgpu pytest
       run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto -m 'webgpu'
     - name: Build WEBGPU Efficientnet
-      run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m examples.webgpu.compile_webgpu
+      run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m examples.compile_efficientnet

   testdocker:
     name: Docker Test
3  .gitignore  (vendored)
@@ -31,8 +31,7 @@ extra/datasets/kits/
 extra/datasets/COCO/
 extra/datasets/audio*
 venv
-examples/webgpu/net.js
-examples/webgpu/net.safetensors
+examples/net.*[js,json,safetensors]
 node_modules
 package.json
 package-lock.json

examples/compile_efficientnet.py
@@ -1,109 +1,67 @@
from models.efficientnet import EfficientNet
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad.state import safe_save
from extra.utils import fetch
import ast

def compile_net(run, special_names):
  functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
  for fxn,args in run.jit_cache:
    functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same
    cargs = []
    for i,arg in enumerate(args):
      key = id(arg)
      if key not in bufs:
        if key in special_names:
          bufs[key] = (special_names[key], arg._memsz, key)
        else:
          bufs[key] = (f"buf_{bufnum}", arg._memsz, key)
          bufnum += 1
        if i > 0: bufs_to_save[bufs[key][0]] = arg # if first usage of a buffer is not an output, and it's not a special name
      cargs.append(bufs[key][0])
    statements.append((fxn.name, cargs, fxn.global_size))

  return functions, statements, bufs, bufs_to_save

def jit_model(model, the_input):
  @TinyJit
  def run(x): return model.forward(x).realize()

  # twice to run the JIT
  the_output = run(the_input)
  the_output = run(the_input)

  # hack to put the inputs back
  assert len(run.input_replace) == 1, f"didn't get one input to replace {run.input_replace}"
  for (j,i),idx in run.input_replace.items():
    run.jit_cache[j][1][i] = the_input.lazydata.realized

  # TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
  special_names = {id(the_input.lazydata.realized): "input", id(the_output.lazydata.realized): "outputs"}
  return run, special_names

from extra.export_model import export_model
from tinygrad.helpers import getenv
import ast, os

if __name__ == "__main__":
  model = EfficientNet(0)
  model.load_from_pretrained()
  run, special_names = jit_model(model, Tensor.randn(1,3,224,224))
  functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
  mode = "clang" if getenv("CLANG", "") != "" else "webgpu" if getenv("WEBGPU", "") != "" else ""
  prg, inp_size, out_size, state = export_model(model, Tensor.randn(1,3,224,224), mode)
  if getenv("CLANG", "") == "":
    safe_save(state, os.path.join(os.path.dirname(__file__), "net.safetensors"))
    ext = "js" if getenv("WEBGPU", "") != "" else "json"
    with open(os.path.join(os.path.dirname(__file__), f"net.{ext}"), "w") as text_file:
      text_file.write(prg)
  else:
    cprog = [prg]
    # image library!
    cprog += ["#define STB_IMAGE_IMPLEMENTATION", fetch("https://raw.githubusercontent.com/nothings/stb/master/stb_image.h").decode('utf-8').replace("half", "_half")]

  # c header
  cprog = ["#include <stdio.h>", "#include <math.h>", "#define max(x,y) ((x>y)?x:y)"]
  # imagenet labels, move to datasets?
  lbls = fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
  lbls = ast.literal_eval(lbls.decode('utf-8'))
  lbls = ['"'+lbls[i]+'"' for i in range(1000)]
  cprog.append(f"char *lbls[] = {{{','.join(lbls)}}};")
    cprog.append(f"float input[{inp_size}];")
    cprog.append(f"float outputs[{out_size}];")

  # save the weights
  for name,cl in bufs_to_save.items():
    weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
    cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")

  # image library!
  cprog += ["#define STB_IMAGE_IMPLEMENTATION", fetch("https://raw.githubusercontent.com/nothings/stb/master/stb_image.h").decode('utf-8')]

    # imagenet labels, move to datasets?
    lbls = fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
    lbls = ast.literal_eval(lbls.decode('utf-8'))
    lbls = ['"'+lbls[i]+'"' for i in range(1000)]
    cprog.append(f"char *lbls[] = {{{','.join(lbls)}}};")

  # buffers (empty + weights)
  cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,len,_key in bufs.values()]

  # the functions
  cprog += list(functions.values())

  # the net
  cprog += ["void net() {"] + [f"{name}({', '.join(args)});" for (name, args, _global_size) in statements] + ["}"]

  cprog += ["""
int main(int argc, char* argv[]) {
int DEBUG = getenv("DEBUG") != NULL ? atoi(getenv("DEBUG")) : 0;
int X=0, Y=0, chan=0;
stbi_uc *image = (argc > 1) ? stbi_load(argv[1], &X, &Y, &chan, 3) : stbi_load_from_file(stdin, &X, &Y, &chan, 3);
assert(image != NULL);
if (DEBUG) printf("loaded image %dx%d channels %d\\n", X, Y, chan);
assert(chan == 3);
// resize to input[1,3,224,224] and rescale
for (int y = 0; y < 224; y++) {
for (int x = 0; x < 224; x++) {
// get sample position
int tx = (x/224.)*X;
int ty = (y/224.)*Y;
for (int c = 0; c < 3; c++) {
input[c*224*224 + y*224 + x] = (image[ty*X*chan + tx*chan + c] / 255.0 - 0.45) / 0.225;
    # buffers (empty + weights)
    cprog.append("""
int main(int argc, char* argv[]) {
int DEBUG = getenv("DEBUG") != NULL ? atoi(getenv("DEBUG")) : 0;
int X=0, Y=0, chan=0;
stbi_uc *image = (argc > 1) ? stbi_load(argv[1], &X, &Y, &chan, 3) : stbi_load_from_file(stdin, &X, &Y, &chan, 3);
assert(image != NULL);
if (DEBUG) printf("loaded image %dx%d channels %d\\n", X, Y, chan);
assert(chan == 3);
// resize to input[1,3,224,224] and rescale
for (int y = 0; y < 224; y++) {
for (int x = 0; x < 224; x++) {
// get sample position
int tx = (x/224.)*X;
int ty = (y/224.)*Y;
for (int c = 0; c < 3; c++) {
input[c*224*224 + y*224 + x] = (image[ty*X*chan + tx*chan + c] / 255.0 - 0.45) / 0.225;
}
}
}
}
net();
float best = -INFINITY;
int best_idx = -1;
for (int i = 0; i < 1000; i++) {
if (outputs[i] > best) {
best = outputs[i];
best_idx = i;
net(input, outputs);
float best = -INFINITY;
int best_idx = -1;
for (int i = 0; i < 1000; i++) {
if (outputs[i] > best) {
best = outputs[i];
best_idx = i;
}
}
}
if (DEBUG) printf("category : %d (%s) with %f\\n", best_idx, lbls[best_idx], best);
else printf("%s\\n", lbls[best_idx]);
}"""]
if (DEBUG) printf("category : %d (%s) with %f\\n", best_idx, lbls[best_idx], best);
else printf("%s\\n", lbls[best_idx]);
}""")

  # CLANG=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/showcase/stable_diffusion_by_tinygrad.jpg
  # category : 281 (tabby, tabby cat) with 9.452788
  print('\n'.join(cprog))
    # CLANG=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/showcase/stable_diffusion_by_tinygrad.jpg
    # category : 281 (tabby, tabby cat) with 9.452788
    print('\n'.join(cprog))
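The CLANG path above pipes the generated C through clang as a stand-alone recognizer. A hypothetical alternative is to build it as a shared library and call the exported net(float*, float*) entry point (that signature comes from export_model_clang, later in this diff) from Python; the -shared compile flags and the numpy buffers here are assumptions, not part of the commit:

import ctypes, os, subprocess, tempfile
import numpy as np

# prg, inp_size, out_size come from export_model(model, x, "clang") above
# (run with CLANG=1 so the JIT produces C kernels)
so = os.path.join(tempfile.gettempdir(), "net.so")
subprocess.run(["clang", "-O2", "-lm", "-shared", "-fPIC", "-x", "c", "-", "-o", so],
               input=prg.encode("utf-8"), check=True)

lib = ctypes.CDLL(so)
inp = np.random.randn(inp_size).astype(np.float32)
out = np.zeros(out_size, dtype=np.float32)
# void net(float* input, float* outputs) is the entry point emitted by export_model_clang
lib.net(inp.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
        out.ctypes.data_as(ctypes.POINTER(ctypes.c_float)))
print(out.argmax())  # index of the top class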

examples/webgpu/compile_webgpu.py (deleted)
@@ -1,87 +0,0 @@
from os import path
from examples.compile_efficientnet import compile_net, jit_model
from models.efficientnet import EfficientNet
from tinygrad.state import get_state_dict, safe_save
from tinygrad.tensor import Tensor

if __name__ == "__main__":
  model = EfficientNet(0)
  model.load_from_pretrained()
  run, special_names = jit_model(model, Tensor.randn(1,3,224,224))
  functions, statements, bufs, _bufs_to_save = compile_net(run, special_names)

  state = get_state_dict(model)
  weights = {id(x.lazydata.realized): name for name, x in state.items()}
  safe_save(state, path.join(path.dirname(__file__), "net.safetensors"))

  kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
  kernel_names = ', '.join([name for (name, _args, _global_size) in statements])
  kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size) in enumerate(statements)])
  bufs = '\n '.join([f"const {buf[0]} = " + (f"createEmptyBuf(device, {buf[1]});" if buf[2] not in weights else f"createWeightBuf(device, {buf[1]}, getTensorBuffer(safetensor, metadata['{weights[buf[2]]}']))") + ";" for buf in bufs.values()])

  prg = f"""const getTensorMetadata = (safetensorBuffer) => {{
  const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
  const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
  return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
}};

const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
  return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
}}

const createEmptyBuf = (device, size) => {{
  return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
}};

const createWeightBuf = (device, size, data) => {{
  const buf = device.createBuffer({{ mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE }});
  new Uint8Array(buf.getMappedRange()).set(data);
  buf.unmap();
  return buf;
}};

const addComputePass = (device, commandEncoder, pipeline, bufs, workgroup) => {{
  const bindGroup = device.createBindGroup({{layout: pipeline.getBindGroupLayout(0), entries: bufs.map((buffer, index) => ({{ binding: index, resource: {{ buffer }} }}))}});
  const passEncoder = commandEncoder.beginComputePass();
  passEncoder.setPipeline(pipeline);
  passEncoder.setBindGroup(0, bindGroup);
  passEncoder.dispatchWorkgroups(...workgroup);
  passEncoder.end();
}};

{kernel_code}

const setupNet = async (device, safetensor) => {{
  const metadata = getTensorMetadata(safetensor);

  {bufs}

  const gpuWriteBuffer = device.createBuffer({{size:input.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});
  const gpuReadBuffer = device.createBuffer({{ size: outputs.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});

  const kernels = [{kernel_names}];
  const pipelines = await Promise.all(kernels.map(name => device.createComputePipelineAsync({{layout: "auto", compute: {{ module: device.createShaderModule({{ code: name }}), entryPoint: "main" }}}})));

  return async (data) => {{
    await gpuWriteBuffer.mapAsync(GPUMapMode.WRITE);
    new Float32Array(gpuWriteBuffer.getMappedRange()).set(data);
    gpuWriteBuffer.unmap();

    const commandEncoder = device.createCommandEncoder();
    commandEncoder.copyBufferToBuffer(gpuWriteBuffer, 0, input, 0, gpuWriteBuffer.size);
    {kernel_calls}
    commandEncoder.copyBufferToBuffer(outputs, 0, gpuReadBuffer, 0, outputs.size);
    const gpuCommands = commandEncoder.finish();
    device.queue.submit([gpuCommands]);

    await gpuReadBuffer.mapAsync(GPUMapMode.READ);
    const resultBuffer = new Float32Array(gpuReadBuffer.size);
    resultBuffer.set(new Float32Array(gpuReadBuffer.getMappedRange()));
    gpuReadBuffer.unmap();
    return resultBuffer;
  }}
}}
"""

  with open(path.join(path.dirname(__file__), "net.js"), "w") as text_file:
    text_file.write(prg)
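The getTensorMetadata helper in the deleted script (and its successor in extra/export_model.py below) reads the safetensors container directly: an 8-byte little-endian length, a JSON header, then raw tensor bytes. For reference, the same parse in Python, a sketch mirroring the JS above and not part of the commit:

import json, struct

def read_safetensors_header(path):
  with open(path, "rb") as f:
    buf = f.read()
  # first 8 bytes: little-endian uint64 length of the JSON header
  (n,) = struct.unpack("<Q", buf[:8])
  meta = json.loads(buf[8:8+n])
  # shift data_offsets past the header, exactly like getTensorMetadata
  return {k: {**v, "data_offsets": [8 + n + o for o in v["data_offsets"]]}
          for k, v in meta.items() if k != "__metadata__"}

A tensor's bytes are then buf[start:end] for its shifted data_offsets, which is what getTensorBuffer does with subarray.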

165  extra/export_model.py  (new file)
@@ -0,0 +1,165 @@
from typing import Tuple, Dict, List
from tinygrad.helpers import DType
from tinygrad.tensor import Device, Tensor
from tinygrad.jit import TinyJit
from tinygrad.state import get_state_dict
import json

def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
  functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
  for fxn,args in run.jit_cache:
    functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same
    cargs = []
    for i,arg in enumerate(args):
      key = id(arg)
      if key not in bufs:
        if key in special_names:
          bufs[key] = (special_names[key], arg._memsz, arg.dtype, key)
        else:
          bufs[key] = (f"buf_{bufnum}", arg._memsz, arg.dtype, key)
          bufnum += 1
        if i > 0: bufs_to_save[bufs[key][0]] = arg # if first usage of a buffer is not an output, and it's not a special name
      cargs.append(bufs[key][0])
    statements.append((fxn.name, cargs, fxn.global_size, fxn.local_size))

  return functions, statements, {name:(size, dtype, key) for (name,size,dtype,key) in bufs.values()}, bufs_to_save

def jit_model(model, the_input:Tensor) -> Tuple[TinyJit,Dict[int,str]]:
  assert hasattr(model, "forward") or callable(model), "model needs a forward function"
  @TinyJit
  def run(x): return (model.forward(x) if hasattr(model, "forward") else model(x)).realize()

  # twice to run the JIT
  for _ in range(2): the_output = run(the_input)

  # hack to put the inputs back
  assert len(run.input_replace) == 1, f"didn't get one input to replace {run.input_replace}"
  for (j,i),idx in run.input_replace.items():
    run.jit_cache[j][1][i] = the_input.lazydata.realized

  # TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
  special_names = {id(the_input.lazydata.realized): "input", id(the_output.lazydata.realized): "outputs"}
  return run, special_names

def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]], bufs_to_save:Dict[str,Tensor]) -> str:
  from tinygrad.runtime.ops_clang import CLANG_PROGRAM_HEADER
  cprog = [CLANG_PROGRAM_HEADER]

  for name,cl in bufs_to_save.items():
    weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
    cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")

  cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in ['input', 'outputs']]
  cprog += list(functions.values())
  cprog += ["void net(float* input, float* outputs) {"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
  return '\n'.join(cprog)

def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names) -> Tuple[str,int,int]:
  kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
  kernel_names = ', '.join([name for (name, _args, _global_size, _local_size) in statements])
  kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements)])
  _bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))") + ";" for name,(size,dtype,_key) in bufs.items()])
  return f"""
const getTensorMetadata = (safetensorBuffer) => {{
  const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
  const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
  return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
}};

const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
  return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
}}

const createEmptyBuf = (device, size) => {{
  return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
}};

const createWeightBuf = (device, size, data) => {{
  const buf = device.createBuffer({{ mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE }});
  new Uint8Array(buf.getMappedRange()).set(data);
  buf.unmap();
  return buf;
}};

const addComputePass = (device, commandEncoder, pipeline, bufs, workgroup) => {{
  const bindGroup = device.createBindGroup({{layout: pipeline.getBindGroupLayout(0), entries: bufs.map((buffer, index) => ({{ binding: index, resource: {{ buffer }} }}))}});
  const passEncoder = commandEncoder.beginComputePass();
  passEncoder.setPipeline(pipeline);
  passEncoder.setBindGroup(0, bindGroup);
  passEncoder.dispatchWorkgroups(...workgroup);
  passEncoder.end();
}};

{kernel_code}

const setupNet = async (device, safetensor) => {{
  const metadata = getTensorMetadata(safetensor);

  {_bufs}

  const gpuWriteBuffer = device.createBuffer({{size:input.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});
  const gpuReadBuffer = device.createBuffer({{ size: outputs.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});

  const kernels = [{kernel_names}];
  const pipelines = await Promise.all(kernels.map(name => device.createComputePipelineAsync({{layout: "auto", compute: {{ module: device.createShaderModule({{ code: name }}), entryPoint: "main" }}}})));

  return async (data) => {{
    await gpuWriteBuffer.mapAsync(GPUMapMode.WRITE);
    new Float32Array(gpuWriteBuffer.getMappedRange()).set(data);
    gpuWriteBuffer.unmap();

    const commandEncoder = device.createCommandEncoder();
    commandEncoder.copyBufferToBuffer(gpuWriteBuffer, 0, input, 0, gpuWriteBuffer.size);
    {kernel_calls}
    commandEncoder.copyBufferToBuffer(outputs, 0, gpuReadBuffer, 0, outputs.size);
    const gpuCommands = commandEncoder.finish();
    device.queue.submit([gpuCommands]);

    await gpuReadBuffer.mapAsync(GPUMapMode.READ);
    const resultBuffer = new Float32Array(gpuReadBuffer.size);
    resultBuffer.set(new Float32Array(gpuReadBuffer.getMappedRange()));
    gpuReadBuffer.unmap();
    return resultBuffer;
  }}
}}
""" + f"\n\nconst loadNet = async (device) => {{ return await fetch('net.safetensors').then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}"

def export_model(model, input:Tensor, target:str):
  assert Device.DEFAULT in ["WEBGPU", "CLANG", "CUDA", "GPU", "METAL"], "only WEBGPU, CLANG, CUDA, GPU, METAL are supported"
  run,special_names = jit_model(model, input)
  functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
  state = get_state_dict(model)
  weight_names = {id(x.lazydata.realized): name for name, x in state.items()}
  prg = ""
  if target == "clang":
    prg = export_model_clang(functions, statements, bufs, bufs_to_save)
  elif target == "webgpu":
    prg = export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names)
  else:
    prg = json.dumps({
      "backend": Device.DEFAULT,
      "input": {
        "size": bufs['input'][0],
        "dtype": bufs['input'][1].name
      },
      "output": {
        "size": bufs["outputs"][0],
        "dtype": bufs["outputs"][1].name
      },
      "functions": functions,
      "statements": [{
        "kernel": kernel,
        "args": args,
        "global_size": global_size,
        "local_size": local_size
      } for (kernel, args, global_size, local_size) in statements],
      "buffers": {
        name: {
          "size": size,
          "dtype": dtype.name,
          "id": weight_names[_key] if _key in weight_names else ""
        } for name, (size,dtype,_key) in bufs.items() if name not in ["input", "outputs"]
      }
    })

  return prg, bufs['input'][0], bufs['outputs'][0], state
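For any other supported backend (CUDA, GPU, METAL), export_model falls through to the JSON branch above, so a consumer only needs json to walk the exported graph. A sketch (field names come from the code above; the net.json file name matches what examples/compile_efficientnet.py writes):

import json

with open("net.json") as f:
  net = json.load(f)

print(net["backend"], net["input"]["size"], net["input"]["dtype"])
for st in net["statements"]:
  # each statement names a kernel in net["functions"] plus its launch dimensions
  print(st["kernel"], st["args"], st["global_size"], st["local_size"])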

@@ -188,7 +188,6 @@ class TestIndexExpressions2d(unittest.TestCase):
      st.expand((base_shape[0], base_shape[1], base_shape[1]))
      self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx//(base_shape[1]*base_shape[1])%base_shape[0]*base_shape[1] + idx%base_shape[1] + offset)
      self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)

  def test_permute_reshape_1(self): # This tests multiple views
    for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
      st.permute((1, 0))

tinygrad/runtime/ops_clang.py
@@ -8,10 +8,10 @@ args = {
   'Linux': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'so', 'exp':''},
   'Darwin': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'dylib', 'exp':''}
 }[platform.system()]

+CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#define bool uchar\n'
 class ClangProgram:
   def __init__(self, name:str, prg:str):
-    prg = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#define bool uchar\n' + prg
+    prg = CLANG_PROGRAM_HEADER + prg
     # TODO: is there a way to not write this to disk?
     fn = f"{tempfile.gettempdir()}/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{args['ext']}"
     if not os.path.exists(fn):