mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
fix assertion message for supported device in export_model (#9957)
This commit is contained in:
@@ -238,7 +238,7 @@ export default {model_name};
|
||||
"""
|
||||
|
||||
def export_model(model, target:str, *inputs, model_name: Optional[str] = "model", stream_weights=False):
|
||||
assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CPU, CUDA, GPU, METAL are supported"
|
||||
assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, f"only {', '.join(EXPORT_SUPPORTED_DEVICE)} are supported"
|
||||
with Context(JIT=2): run,special_names = jit_model(model, *inputs)
|
||||
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
|
||||
state = get_state_dict(model)
|
||||
|
||||
Reference in New Issue
Block a user