Clean up the v-diffusion install pipeline (#327)

This commit is contained in:
Quinn Dawkins
2022-09-16 11:47:07 -04:00
committed by GitHub
parent c43448a826
commit 9bd951b083
5 changed files with 39 additions and 4 deletions

View File

@@ -27,7 +27,25 @@ Run the script setup_v_diffusion_pytorch.sh
./v-diffusion-pytorch/cfg_sample.py "New York City, oil on canvas":5 -n 5 -bs 5
```
The runtime device can be specified with `--runtime_device=<device string>`
### Run the v-diffusion model via torch-mlir
```shell
./cfg_sample.py "New York City, oil on canvas":5 -n 1 -bs 1 --steps 2
```
### Run the model stored in the tank
```shell
./cfg_sample_from_mlir.py "New York City, oil on canvas":5 -n 1 -bs 1 --steps 2
```
Note that the current model in the tank requires a static batch size of 1.
### Run the model with preprocessing elements taken out
To run the model without preprocessing, copy `cc12m_1.py` over the version shipped in `v-diffusion-pytorch`:
```shell
cp cc12m_1.py v-diffusion-pytorch/diffusion/models
```
Then run
```shell
./cfg_sample_preprocess.py "New York City, oil on canvas":5 -n 1 -bs 1 --steps 2
```

View File

@@ -67,6 +67,12 @@ p.add_argument(
)
p.add_argument("--checkpoint", type=str, help="the checkpoint to use")
p.add_argument("--device", type=str, help="the device to use")
p.add_argument(
"--runtime_device",
type=str,
help="the device to use with SHARK",
default="cpu",
)
p.add_argument(
"--eta",
type=float,
@@ -235,7 +241,7 @@ mlir_model = module
func_name = "forward"
shark_module = SharkInference(
mlir_model, func_name, device="cuda", mlir_dialect="linalg"
mlir_model, func_name, device=args.runtime_device, mlir_dialect="linalg"
)
shark_module.compile()

View File

@@ -69,6 +69,12 @@ p.add_argument(
)
p.add_argument("--checkpoint", type=str, help="the checkpoint to use")
p.add_argument("--device", type=str, help="the device to use")
p.add_argument(
"--runtime_device",
type=str,
help="the device to use with SHARK",
default="cpu",
)
p.add_argument(
"--eta",
type=float,
@@ -188,7 +194,7 @@ t_in = t[0] * ts
mlir_model, func_name, inputs, golden_out = download_torch_model("v_diffusion")
shark_module = SharkInference(
mlir_model, func_name, device="cpu", mlir_dialect="linalg"
mlir_model, func_name, device=args.runtime_device, mlir_dialect="linalg"
)
shark_module.compile()

View File

@@ -69,6 +69,12 @@ p.add_argument(
)
p.add_argument("--checkpoint", type=str, help="the checkpoint to use")
p.add_argument("--device", type=str, help="the device to use")
p.add_argument(
"--runtime_device",
type=str,
help="the device to use with SHARK",
default="intel-gpu",
)
p.add_argument(
"--eta",
type=float,
@@ -260,7 +266,7 @@ mlir_model = module
func_name = "forward"
shark_module = SharkInference(
mlir_model, func_name, device="intel-gpu", mlir_dialect="linalg"
mlir_model, func_name, device=args.runtime_device, mlir_dialect="linalg"
)
shark_module.compile()

View File

@@ -24,4 +24,3 @@ mkdir checkpoints
wget https://the-eye.eu/public/AI/models/v-diffusion/cc12m_1_cfg.pth -P checkpoints/
cp -r checkpoints/ v-diffusion-pytorch/
cp cc12m_1.py v-diffusion-pytorch/diffusion/models/.