mirror of https://github.com/tinygrad/tinygrad.git
Enable Multi-Output Export (#2179)
* Enable Multi-Output Export
* Add test
* Update examples and lint
* fix padding
* test ops
* dummy commit to rerun test
* revert cuda lint
* Enforce tuple/list of tensors
* subscripted generics
* put back webgpu test
* Re-enable WebGPU Efficientnet test
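As a quick orientation before the diff: export_model previously returned a single output size; after this change it returns a dict of sizes keyed by output name. A minimal sketch of the new interface, assuming a tinygrad checkout at this commit with the helpers in extra/export_model as shown below (the two-output model mirrors MockMultiOutputModel from the new test):

from tinygrad.tensor import Tensor
from extra.export_model import export_model

class MultiOut:
  def __call__(self, x):
    # two outputs with different shapes, as in the new test
    return x + 2.0, x.pad(((0, 0), (0, 1))) + 1.0

# the empty target selects the generic JSON description used by the new test
prg, inp_sizes, out_sizes, state = export_model(MultiOut(), "", Tensor.rand(2, 2))
print(out_sizes)  # one entry per model output, keyed "output0", "output1"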
.github/workflows/test.yml (vendored)

@@ -243,6 +243,10 @@ jobs:
       run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto test/test_dtype.py
+    - name: Build WEBGPU Efficientnet
+      run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m examples.compile_efficientnet
     - name: Install Puppeteer
       run: npm install puppeteer
+    - name: Run WEBGPU Efficientnet
+      run: node test/test_webgpu.js

   tests:
     strategy:
examples/compile_efficientnet.py

@@ -11,7 +11,7 @@ if __name__ == "__main__":
   model = EfficientNet(0)
   model.load_from_pretrained()
   mode = "clang" if getenv("CLANG", "") != "" else "webgpu" if getenv("WEBGPU", "") != "" else ""
-  prg, inp_sizes, out_size, state = export_model(model, mode, Tensor.randn(1,3,224,224))
+  prg, inp_sizes, out_sizes, state = export_model(model, mode, Tensor.randn(1,3,224,224))
   dirname = Path(__file__).parent
   if getenv("CLANG", "") == "":
     safe_save(state, (dirname / "net.safetensors").as_posix())
@@ -28,9 +28,10 @@ if __name__ == "__main__":
     lbls = ast.literal_eval(lbls.decode('utf-8'))
     lbls = ['"'+lbls[i]+'"' for i in range(1000)]
     inputs = "\n".join([f"float {inp}[{inp_size}];" for inp,inp_size in inp_sizes.items()])
+    outputs = "\n".join([f"float {out}[{out_size}];" for out,out_size in out_sizes.items()])
     cprog.append(f"char *lbls[] = {{{','.join(lbls)}}};")
     cprog.append(inputs)
-    cprog.append(f"float outputs[{out_size}];")
+    cprog.append(outputs)

     # buffers (empty + weights)
     cprog.append("""
@@ -52,12 +53,12 @@ if __name__ == "__main__":
       }
      }
    }
-   net(input0, outputs);
+   net(input0, output0);
    float best = -INFINITY;
    int best_idx = -1;
    for (int i = 0; i < 1000; i++) {
-     if (outputs[i] > best) {
-       best = outputs[i];
+     if (output0[i] > best) {
+       best = output0[i];
       best_idx = i;
     }
   }
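The C path now emits one buffer declaration per output, driven by the out_sizes dict. A standalone re-run of the declaration builder from the diff above, with a hypothetical single-output table standing in for EfficientNet's 1000-class head:

# hypothetical out_sizes for a single 1000-float output
out_sizes = {"output0": 1000}
outputs = "\n".join([f"float {out}[{out_size}];" for out, out_size in out_sizes.items()])
print(outputs)  # float output0[1000];

The generated entry point changes accordingly, from net(input0, outputs) to net(input0, output0).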
@@ -18,6 +18,7 @@ canvas { display: none; }
 </style>
 <title>tinygrad has WebGPU</title>
 <script src="./net.js"></script>
+<link rel="icon" type="image/x-icon" href="https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/logo.png">
 </head>
 <body>
 <h1>WebGPU <a href="https://github.com/geohot/tinygrad">tinygrad</a> EfficientNet!</h1>
@@ -111,7 +112,7 @@ canvas { display: none; }
     const input = reorderChannelsAndRemoveAlpha(Array.from(data).map((pix) => (pix / 255.0) * 0.45 - 0.225));
     const out = await timer(() => net(new Float32Array(input)));

-    const arr = Array.from(new Float32Array(out));
+    const arr = Array.from(new Float32Array(out[0]));
     const index = arr.indexOf(Math.max(...arr));

     resultText.textContent = labels[index];
extra/export_model.py

@@ -30,7 +30,11 @@ def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str]
 def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]:
   assert hasattr(model, "forward") or callable(model), "model needs a forward function"
   @TinyJit
-  def run(*x): return (model.forward(*x) if hasattr(model, "forward") else model(*x)).realize()
+  def run(*x):
+    out = model.forward(*x) if hasattr(model, "forward") else model(*x)
+    assert isinstance(out, tuple) or isinstance(out, list) or isinstance(out, Tensor), "model output must be a Tensor, tuple, or a list of Tensors for export"
+    out = [out] if isinstance(out, Tensor) else out
+    return [o.realize() for o in out]

   # twice to run the JIT
   for _ in range(2): the_output = run(*args)
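The rewritten run() wrapper is the heart of the change: whatever the model returns is normalized to a list of Tensors before realization. The same normalization as a standalone sketch (Tensor is tinygrad's; isinstance with a tuple of types condenses the diff's three checks into one):

from tinygrad.tensor import Tensor

def normalize_outputs(out):
  # bare Tensor -> single-element list; tuple/list of Tensors passes through
  assert isinstance(out, (tuple, list, Tensor)), "model output must be a Tensor, tuple, or a list of Tensors for export"
  return [out] if isinstance(out, Tensor) else list(out)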
@@ -43,10 +47,11 @@ def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]:
     special_names[id(realized_input)] = f'input{idx[0]}'

   # TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
-  special_names[id(the_output.lazydata.realized)] = "outputs"
+  for i, output in enumerate(the_output):
+    special_names[id(output.lazydata.realized)] = f'output{i}'
   return run, special_names

-def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]], bufs_to_save:Dict[str,Tensor], input_names:List[str]) -> str:
+def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]], bufs_to_save:Dict[str,Tensor], input_names:List[str], output_names:List[str]) -> str:
   from tinygrad.runtime.ops_clang import CLANG_PROGRAM_HEADER
   cprog = [CLANG_PROGRAM_HEADER]
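With outputs normalized to a list, every realized output buffer gets its own special_names entry instead of the single "outputs" slot. An illustrative mapping after jit_model runs on a two-output model (the integer keys are hypothetical stand-ins for real id(buffer) values):

# hypothetical id() keys, shortened for readability
special_names = {
  140001: "input0",   # realized input buffer
  140002: "output0",  # first realized output buffer
  140003: "output1",  # second realized output buffer
}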
@@ -55,18 +60,23 @@ def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,in
     cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")

   inputs = ", ".join([f'float* {input}' for input in input_names])
+  outputs = ", ".join([f'float* {output}' for output in output_names])
   cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in ['input', 'outputs']]
   cprog += list(functions.values())
-  cprog += [f"void net({inputs}, float* outputs) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
+  cprog += [f"void net({inputs}, {outputs}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
   return '\n'.join(cprog)

-def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names) -> Tuple[str,int,int]:
+def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names) -> Tuple[str,int,int]:
   kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
   kernel_names = ', '.join([name for (name, _args, _global_size, _local_size) in statements])
   kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
   _bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))") + ";" for name,(size,dtype,_key) in bufs.items()])
   gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,input_name in enumerate(input_names)])
   input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
+  gpu_read_bufs = '\n '.join([f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" for i,output_name in enumerate(output_names)])
+  outbuf_copies = '\n '.join([f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" for i,output_name in enumerate(output_names)])
+  output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new Float32Array(gpuReadBuffer{i}.size);\n resultBuffer{i}.set(new Float32Array(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
+  output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
   return f"""
 const getTensorMetadata = (safetensorBuffer) => {{
   const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
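The four new template strings fan the old single read-back path out into one staging buffer, one copy, and one reader per output. For instance, output_return from the diff evaluates to a JS array literal:

output_names = ["output0", "output1"]  # hypothetical two-output model
output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
print(output_return)  # [resultBuffer0,resultBuffer1]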
@@ -107,7 +117,7 @@ const setupNet = async (device, safetensor) => {{

     {gpu_write_bufs}

-    const gpuReadBuffer = device.createBuffer({{ size: outputs.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});
+    {gpu_read_bufs}

     const kernels = [{kernel_names}];
     const piplines = await Promise.all(kernels.map(name => device.createComputePipelineAsync({{layout: "auto", compute: {{ module: device.createShaderModule({{ code: name }}), entryPoint: "main" }}}})));
@@ -116,15 +126,12 @@ const setupNet = async (device, safetensor) => {{
         const commandEncoder = device.createCommandEncoder();
         {input_writers}
         {kernel_calls}
-        commandEncoder.copyBufferToBuffer(outputs, 0, gpuReadBuffer, 0, outputs.size);
+        {outbuf_copies}
         const gpuCommands = commandEncoder.finish();
         device.queue.submit([gpuCommands]);

-        await gpuReadBuffer.mapAsync(GPUMapMode.READ);
-        const resultBuffer = new Float32Array(gpuReadBuffer.size/4);
-        resultBuffer.set(new Float32Array(gpuReadBuffer.getMappedRange()));
-        gpuReadBuffer.unmap();
-        return resultBuffer;
+        {output_readers}
+        return {output_return};
     }}
 }}
 """ + f"\n\nconst loadNet = async (device) => {{ return await fetch('net.safetensors').then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}"
@@ -136,11 +143,12 @@ def export_model(model, target:str, *inputs):
   state = get_state_dict(model)
   weight_names = {id(x.lazydata.realized): name for name, x in state.items()}
   input_names = [name for _,name in special_names.items() if "input" in name]
+  output_names = [name for _,name in special_names.items() if "output" in name]
   prg = ""
   if target == "clang":
-    prg = export_model_clang(functions, statements, bufs, bufs_to_save, input_names)
+    prg = export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names)
   elif target == "webgpu":
-    prg = export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names)
+    prg = export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names)
   else:
     prg = json.dumps({
       "backend": Device.DEFAULT,
@@ -148,10 +156,10 @@ def export_model(model, target:str, *inputs):
         "size": bufs[name][0],
         "dtype": bufs[name][1].name
       } for name in input_names],
-      "output": {
-        "size": bufs["outputs"][0],
-        "dtype": bufs["outputs"][1].name
-      },
+      "outputs": [{
+        "size": bufs[name][0],
+        "dtype": bufs[name][1].name
+      } for name in output_names],
       "functions": functions,
       "statements": [{
         "kernel": kernel,
@@ -168,4 +176,4 @@ def export_model(model, target:str, *inputs):
     }
   })

-  return prg, {input:bufs[input][0] for input in input_names}, bufs['outputs'][0], state
+  return prg, {input:bufs[input][0] for input in input_names}, {output:bufs[output][0] for output in output_names}, state
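In the JSON target, the scalar "output" entry becomes an "outputs" list built exactly like "inputs", and export_model now returns a size dict per output. A sketch of the resulting JSON shape, using a hypothetical buffer table with plain dtype strings standing in for tinygrad dtype objects:

import json

bufs = {"output0": (4, "float"), "output1": (6, "float")}  # hypothetical name -> (size, dtype)
output_names = ["output0", "output1"]
print(json.dumps({"outputs": [{"size": bufs[n][0], "dtype": bufs[n][1]} for n in output_names]}))
# {"outputs": [{"size": 4, "dtype": "float"}, {"size": 6, "dtype": "float"}]}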
@@ -7,6 +7,10 @@ class MockMultiInputModel:
   def forward(self, x1, x2, x3):
     return x1 + x2 + x3

+class MockMultiOutputModel:
+  def __call__(self, x1):
+    return x1 + 2.0, x1.pad(((0, 0), (0, 1))) + 1.0
+
 # TODO: move compile_efficientnet tests here
 @unittest.skipUnless(Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, f"Model export is not supported on {Device.DEFAULT}")
 class TextModelExport(unittest.TestCase):
@@ -18,12 +22,29 @@ class TextModelExport(unittest.TestCase):

     assert len(inputs) == len(prg["inputs"]) == len(inp_sizes), f"Model and exported inputs don't match: mdl={len(inputs)}, prg={len(prg['inputs'])}, inp_sizes={len(inp_sizes)}"

     for i in range(len(inputs)):
       assert f"input{i}" in inp_sizes, f"input{i} not captured in inp_sizes"
       assert f"input{i}" in prg["buffers"], f"input{i} not captured in exported buffers"

     for i, exported_input in enumerate(prg["inputs"]):
       assert inputs[i].dtype.name == exported_input["dtype"], f"Model and exported input dtype don't match: mdl={inputs[i].dtype.name}, prg={exported_input['dtype']}"

+  def test_multi_output_model_export(self):
+    model = MockMultiOutputModel()
+    input = Tensor.rand(2,2)
+    outputs = model(input)
+    prg, _, out_sizes, _ = export_model(model, "", input)
+    prg = json.loads(prg)
+
+    assert len(outputs) == len(prg["outputs"]) == len(out_sizes), f"Model and exported outputs don't match: mdl={len(outputs)}, prg={len(prg['outputs'])}, out_sizes={len(out_sizes)}"
+
+    for i in range(len(outputs)):
+      assert f"output{i}" in out_sizes, f"output{i} not captured in out_sizes"
+      assert f"output{i}" in prg["buffers"], f"output{i} not captured in exported buffers"
+
+    for i, exported_output in enumerate(prg["outputs"]):
+      assert outputs[i].dtype.name == exported_output["dtype"], f"Model and exported output dtype don't match: mdl={outputs[i].dtype.name}, prg={exported_output['dtype']}"
+

 if __name__ == '__main__':
   unittest.main()
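A worked check of the shapes the new test relies on: for a (2,2) input, the first output keeps shape (2,2), while pad(((0, 0), (0, 1))) appends one column to the second, giving (2,3). Whether out_sizes reports element counts or bytes is determined by export_model above; the shapes themselves are:

from tinygrad.tensor import Tensor

x = Tensor.rand(2, 2)
y0, y1 = x + 2.0, x.pad(((0, 0), (0, 1))) + 1.0
print(y0.shape, y1.shape)  # (2, 2) (2, 3) -> 4 and 6 elements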
test/test_webgpu.js

@@ -35,7 +35,7 @@ puppeteer.launch({ headless: false, args: ["--enable-unsafe-webgpu"]}).then(asyn
   page.on("console", message => console.log(`message from console ${message.text()}`))
     .on("pageerror", ({ message }) => console.log(`error from page ${message}`))

-  const res = await page.goto("http://localhost:8000/examples/webgpu/index.html");
+  const res = await page.goto("http://localhost:8000/examples/index.html");
   if(res.status() != 200) throw new Error("Failed to load page");
   const textSelector = await page.waitForSelector("#result");
   const buttonSelector = await page.waitForSelector("input[type=button]");