Allow multi-input model export (#1995)
* Allow multi-input model export
* Add model export unit test
* Fix efficientnet compilation
* Only run model export test on JIT supported devices
* Skip export model test if not EXPORT_SUPPORTED_DEVICE
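In short: export_model now takes the target string before a variadic list of input tensors, and its second return value becomes a dict of buffer sizes keyed by generated names ("input0", "input1", ...) instead of a single size. A minimal sketch of the new call shape (the two-input model here is hypothetical, not from the diff):

    from extra.export_model import export_model
    from tinygrad.tensor import Tensor

    class TwoInput:  # hypothetical model illustrating the variadic signature
      def forward(self, a, b): return a + b

    # target comes first now; inputs follow as varargs
    prg, inp_sizes, out_size, state = export_model(TwoInput(), "clang", Tensor.rand(4,4), Tensor.rand(4,4))
    # inp_sizes maps "input0", "input1", ... to buffer sizes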
--- a/examples/compile_efficientnet.py
+++ b/examples/compile_efficientnet.py
@@ -11,7 +11,7 @@ if __name__ == "__main__":
   model = EfficientNet(0)
   model.load_from_pretrained()
   mode = "clang" if getenv("CLANG", "") != "" else "webgpu" if getenv("WEBGPU", "") != "" else ""
-  prg, inp_size, out_size, state = export_model(model, Tensor.randn(1,3,224,224), mode)
+  prg, inp_sizes, out_size, state = export_model(model, mode, Tensor.randn(1,3,224,224))
   dirname = Path(__file__).parent
   if getenv("CLANG", "") == "":
     safe_save(state, (dirname / "net.safetensors").as_posix())
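The unpacked second value changes shape accordingly; for this single-input EfficientNet export that presumably means:

    # before: inp_size was one int for the lone "input" buffer
    # after: inp_sizes maps each generated input name to its buffer size
    assert list(inp_sizes) == ["input0"]  # single-input model -> one generated name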
@@ -27,8 +27,9 @@ if __name__ == "__main__":
     lbls = fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
     lbls = ast.literal_eval(lbls.decode('utf-8'))
     lbls = ['"'+lbls[i]+'"' for i in range(1000)]
+    inputs = "\n".join([f"float {inp}[{inp_size}];" for inp,inp_size in inp_sizes.items()])
     cprog.append(f"char *lbls[] = {{{','.join(lbls)}}};")
-    cprog.append(f"float input[{inp_size}];")
+    cprog.append(inputs)
     cprog.append(f"float outputs[{out_size}];")

     # buffers (empty + weights)
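The new inputs string expands into one C declaration per model input; a quick check of what the join produces (the size value is assumed for illustration):

    inp_sizes = {"input0": 150528}  # assumed value, matching the (1,3,224,224) tensor
    inputs = "\n".join([f"float {inp}[{inp_size}];" for inp,inp_size in inp_sizes.items()])
    print(inputs)  # -> float input0[150528];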
@@ -47,11 +48,11 @@ if __name__ == "__main__":
       int tx = (x/224.)*X;
       int ty = (y/224.)*Y;
       for (int c = 0; c < 3; c++) {
-        input[c*224*224 + y*224 + x] = (image[ty*X*chan + tx*chan + c] / 255.0 - 0.45) / 0.225;
+        input0[c*224*224 + y*224 + x] = (image[ty*X*chan + tx*chan + c] / 255.0 - 0.45) / 0.225;
       }
     }
   }
-  net(input, outputs);
+  net(input0, outputs);
   float best = -INFINITY;
   int best_idx = -1;
   for (int i = 0; i < 1000; i++) {
--- a/extra/export_model.py
+++ b/extra/export_model.py
@@ -5,6 +5,8 @@ from tinygrad.jit import TinyJit
 from tinygrad.nn.state import get_state_dict
 import json

+EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CLANG", "CUDA", "GPU", "METAL"]
+
 def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
   functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
   for fxn,args,var_vals in run.jit_cache:
@@ -25,24 +27,26 @@ def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str]

   return functions, statements, {name:(size, dtype, key) for (name,size,dtype,key) in bufs.values()}, bufs_to_save

-def jit_model(model, the_input:Tensor) -> Tuple[TinyJit,Dict[int,str]]:
+def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]:
   assert hasattr(model, "forward") or callable(model), "model needs a forward function"
   @TinyJit
-  def run(x): return (model.forward(x) if hasattr(model, "forward") else model(x)).realize()
+  def run(*x): return (model.forward(*x) if hasattr(model, "forward") else model(*x)).realize()

   # twice to run the JIT
-  for _ in range(2): the_output = run(the_input)
+  for _ in range(2): the_output = run(*args)
+  special_names = {}

   # hack to put the inputs back
-  assert len(run.input_replace) == 1, f"didn't get one input to replace {run.input_replace}"
   for (j,i),idx in run.input_replace.items():
-    run.jit_cache[j][1][i] = the_input.lazydata.realized
+    realized_input = args[idx[0]].lazydata.realized
+    run.jit_cache[j][1][i] = realized_input
+    special_names[id(realized_input)] = f'input{idx[0]}'

   # TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
-  special_names = {id(the_input.lazydata.realized): "input", id(the_output.lazydata.realized): "outputs"}
+  special_names[id(the_output.lazydata.realized)] = "outputs"
   return run, special_names

-def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]], bufs_to_save:Dict[str,Tensor]) -> str:
+def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]], bufs_to_save:Dict[str,Tensor], input_names:List[str]) -> str:
   from tinygrad.runtime.ops_clang import CLANG_PROGRAM_HEADER
   cprog = [CLANG_PROGRAM_HEADER]

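jit_model no longer asserts a single input; instead it walks run.input_replace, which maps each (jit-cache entry, argument slot) to the position of the caller input that feeds it, and names each realized buffer input0, input1, ... A sketch of that bookkeeping in isolation (the helper name and example mapping are illustrative, not from the diff):

    def name_jit_inputs(run, args):
      # mirrors the loop in jit_model: name each realized input after its caller position
      names = {}
      for (j, i), idx in run.input_replace.items():
        realized = args[idx[0]].lazydata.realized
        run.jit_cache[j][1][i] = realized       # put the realized buffer back into the jit cache
        names[id(realized)] = f"input{idx[0]}"  # e.g. a mapping like {(0,1): (0,)} yields "input0"
      return names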
@@ -50,16 +54,19 @@ def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,in
     weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
     cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")

+  inputs = ", ".join([f'float* {input}' for input in input_names])
   cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in ['input', 'outputs']]
   cprog += list(functions.values())
-  cprog += ["void net(float* input, float* outputs) {"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
+  cprog += [f"void net({inputs}, float* outputs) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
   return '\n'.join(cprog)

-def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names) -> Tuple[str,int,int]:
+def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names) -> Tuple[str,int,int]:
   kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
   kernel_names = ', '.join([name for (name, _args, _global_size, _local_size) in statements])
-  kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
+  kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
   _bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))") + ";" for name,(size,dtype,_key) in bufs.items()])
+  gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,input_name in enumerate(input_names)])
+  input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
   return f"""
 const getTensorMetadata = (safetensorBuffer) => {{
   const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
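With multiple inputs, the generated C entry point gains one float* parameter per input. The signature string is assembled exactly as in the diff; for two hypothetical inputs:

    input_names = ["input0", "input1"]  # as produced by jit_model
    inputs = ", ".join([f"float* {name}" for name in input_names])
    print(f"void net({inputs}, float* outputs) {{")
    # -> void net(float* input0, float* input1, float* outputs) {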
@@ -98,19 +105,16 @@ const setupNet = async (device, safetensor) => {{

     {_bufs}

-    const gpuWriteBuffer = device.createBuffer({{size:input.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});
+    {gpu_write_bufs}

     const gpuReadBuffer = device.createBuffer({{ size: outputs.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});

     const kernels = [{kernel_names}];
     const piplines = await Promise.all(kernels.map(name => device.createComputePipelineAsync({{layout: "auto", compute: {{ module: device.createShaderModule({{ code: name }}), entryPoint: "main" }}}})));

-    return async (data) => {{
-        await gpuWriteBuffer.mapAsync(GPUMapMode.WRITE);
-        new Float32Array(gpuWriteBuffer.getMappedRange()).set(data);
-        gpuWriteBuffer.unmap();
-
+    return async ({",".join([f"_{input_name}" for input_name in input_names])}) => {{
         const commandEncoder = device.createCommandEncoder();
-        commandEncoder.copyBufferToBuffer(gpuWriteBuffer, 0, input, 0, gpuWriteBuffer.size);
+        {input_writers}
         {kernel_calls}
         commandEncoder.copyBufferToBuffer(outputs, 0, gpuReadBuffer, 0, outputs.size);
         const gpuCommands = commandEncoder.finish();
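The generated setupNet thus allocates one staging gpuWriteBuffer{i} per input and returns an async function taking one argument per input. The parameter list is derived from input_names the same way as in the template above (two-input case hypothetical):

    input_names = ["input0", "input1"]
    params = ",".join([f"_{name}" for name in input_names])
    print(f"return async ({params}) => {{")  # -> return async (_input0,_input1) => {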
@@ -125,24 +129,25 @@ const setupNet = async (device, safetensor) => {{
 }}
 """ + f"\n\nconst loadNet = async (device) => {{ return await fetch('net.safetensors').then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}"

-def export_model(model, input:Tensor, target:str):
-  assert Device.DEFAULT in ["WEBGPU", "CLANG", "CUDA", "GPU", "METAL"], "only WEBGPU, CLANG, CUDA, GPU, METAL are supported"
-  run,special_names = jit_model(model, input)
+def export_model(model, target:str, *inputs):
+  assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CLANG, CUDA, GPU, METAL are supported"
+  run,special_names = jit_model(model, *inputs)
   functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
   state = get_state_dict(model)
   weight_names = {id(x.lazydata.realized): name for name, x in state.items()}
+  input_names = [name for _,name in special_names.items() if "input" in name]
   prg = ""
   if target == "clang":
-    prg = export_model_clang(functions, statements, bufs, bufs_to_save)
+    prg = export_model_clang(functions, statements, bufs, bufs_to_save, input_names)
   elif target == "webgpu":
-    prg = export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names)
+    prg = export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names)
   else:
     prg = json.dumps({
       "backend": Device.DEFAULT,
-      "input": {
-        "size": bufs['input'][0],
-        "dtype": bufs['input'][1].name
-      },
+      "inputs": [{
+        "size": bufs[name][0],
+        "dtype": bufs[name][1].name
+      } for name in input_names],
       "output": {
         "size": bufs["outputs"][0],
         "dtype": bufs["outputs"][1].name
@@ -163,4 +168,4 @@ def export_model(model, input:Tensor, target:str):
       }
     })

-  return prg, bufs['input'][0], bufs['outputs'][0], state
+  return prg, {input:bufs[input][0] for input in input_names}, bufs['outputs'][0], state
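For targets other than clang and webgpu, the JSON now carries an "inputs" list instead of a single "input" object, and the dict returned to Python is built from the same input_names in the same order. A plausible sanity check, assuming prg and inp_sizes come from an export_model call on such a backend:

    import json
    prg_dict = json.loads(prg)  # prg is a JSON string for these targets
    assert len(prg_dict["inputs"]) == len(inp_sizes)
    assert [d["size"] for d in prg_dict["inputs"]] == list(inp_sizes.values())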
--- /dev/null
+++ b/test/extra/test_export_model.py
@@ -0,0 +1,29 @@
+import unittest
+from extra.export_model import export_model, EXPORT_SUPPORTED_DEVICE
+from tinygrad.tensor import Tensor, Device
+import json
+
+class MockMultiInputModel:
+  def forward(self, x1, x2, x3):
+    return x1 + x2 + x3
+
+# TODO: move compile_efficientnet tests here
+@unittest.skipUnless(Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, f"Model export is not supported on {Device.DEFAULT}")
+class TextModelExport(unittest.TestCase):
+  def test_multi_input_model_export(self):
+    model = MockMultiInputModel()
+    inputs = [Tensor.rand(2,2), Tensor.rand(2,2), Tensor.rand(2,2)]
+    prg, inp_sizes, _, _ = export_model(model, "", *inputs)
+    prg = json.loads(prg)
+
+    assert len(inputs) == len(prg["inputs"]) == len(inp_sizes), f"Model and exported inputs don't match: mdl={len(inputs)}, prg={len(prg['inputs'])}, inp_sizes={len(inp_sizes)}"
+
+    for i in range(len(inputs)):
+      assert f"input{i}" in inp_sizes, f"input{i} not captured in inp_sizes"
+      assert f"input{i}" in prg["buffers"], f"input{i} not captured in exported buffers"
+
+    for i, exported_input in enumerate(prg["inputs"]):
+      assert inputs[i].dtype.name == exported_input["dtype"], f"Model and exported input dtype don't match: mdl={inputs[i].dtype.name}, prg={exported_input['dtype']}"
+
+if __name__ == '__main__':
+  unittest.main()
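The test is skipped unless Device.DEFAULT is among EXPORT_SUPPORTED_DEVICE. A plausible way to run it from the repo root (PYTHONPATH set so that extra.export_model resolves):

    PYTHONPATH=. python3 test/extra/test_export_model.py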