fix assertion message for supported device in export_model (#9957)

This commit is contained in:
akhuntsaria
2025-04-21 15:23:44 +02:00
committed by GitHub
parent 783a191925
commit 2d423e6737

View File

@@ -238,7 +238,7 @@ export default {model_name};
"""
def export_model(model, target:str, *inputs, model_name: Optional[str] = "model", stream_weights=False):
assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CPU, CUDA, GPU, METAL are supported"
assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, f"only {', '.join(EXPORT_SUPPORTED_DEVICE)} are supported"
with Context(JIT=2): run,special_names = jit_model(model, *inputs)
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
state = get_state_dict(model)