Compare commits

...

308 Commits

Author SHA1 Message Date
Lincoln Stein
4ffdf73412 Merge branch 'development' into main
This merge adds the following major features:

* Support for image variations.

* Security fix for webGUI (binds to localhost by default; use
--host=0.0.0.0 to allow access from an external interface).

* Scalable configs/models.yaml configuration file for adding more
models as they become available.

* More tuning and exception handling for M1 hardware running MPS.

* Various documentation fixes.
2022-09-03 11:58:46 -04:00
Lincoln Stein
9130ad7e08 make results section of webgui full width 2022-09-03 11:58:05 -04:00
tildebyte
d66010410c FEAT: add notebook for Windows for from-zero install and run (#164)
* Update README.md

Those []() link pairs get me every time.

* New issue template

* Added issue templates

* feat(install+run): add notebook for Windows for from-zero install...

...and run

Tested with JupyterLab and VSCode

Signed-off-by: Ben Alkov <ben.alkov@gmail.com>

Signed-off-by: Ben Alkov <ben.alkov@gmail.com>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
Co-authored-by: James Reynolds <magnusviri@users.noreply.github.com>
Co-authored-by: James Reynolds <magnsuviri@me.com>
2022-09-03 11:49:37 -04:00
Lincoln Stein
6566c2298c add scalable support for new models using a configs/models.yaml file 2022-09-03 11:45:21 -04:00
Lincoln Stein
063b4a1995 add ability to specify location of config file (models.yaml) 2022-09-03 11:36:04 -04:00
Lincoln Stein
18cdb556bd update requirements.txt to run on m1 w/pip 2022-09-03 10:56:06 -04:00
Lincoln Stein
8d16a69b80 Merge branch 'erickhun-patch-1' into development 2022-09-03 10:45:12 -04:00
Lincoln Stein
a406b588b4 Merge branch 'development' into patch-1 2022-09-03 10:43:59 -04:00
Lincoln Stein
5454a0edc2 code cleanup
* check that a fixed side is provided when requesting a variant parameter sweep (-v)
* move _get_noise() into outer scope to improve readability -
refactoring of big method call needed
2022-09-03 10:40:20 -04:00
Lincoln Stein
fe5cc79249 fixes dream.py mps seed 2022-09-03 10:11:46 -04:00
Ben Alkov
361cc42829 fix(readme): Add individual OS links to TOC; fix TOC changelog link 2022-09-03 09:46:06 -04:00
Lincoln Stein
91cce6b4c3 move special-casing test for precision on mps into T2I class 2022-09-03 09:43:18 -04:00
Lincoln Stein
d0df894c9f Merge branch 'cgodley-web-host-port' into development 2022-09-03 09:33:46 -04:00
Lincoln Stein
f46916d521 Add warning message about change in default host 2022-09-03 09:33:02 -04:00
Lincoln Stein
12755c6ef6 Merge branch 'web-host-port' of github.com:cgodley/stable-diffusion into cgodley-web-host-port
This allows host and port to be set on the --web command line, and changes the default binding from 0.0.0.0 to 127.0.0.1.
2022-09-03 09:12:32 -04:00
Lincoln Stein
cc4f33bf3a Merge branch 'bakkot-variant-commas' into development
Change the image variation weighting syntax to match the prompt weighting
syntax
2022-09-03 09:08:20 -04:00
Lincoln Stein
d8c0d020eb remove space between -V and its value in generated prompt, for consistency with other switches 2022-09-03 09:08:10 -04:00
Kevin Gibbons
e918cb1a8a replace list delimiters in variations syntax 2022-09-02 23:51:22 -07:00
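For illustration, a minimal sketch of parsing the `-V` variations list described in the commits above, assuming comma-separated `seed:weight` pairs that mirror the prompt-weighting syntax (dream.py's actual parser and validation may differ):

```
# Hypothetical parser for a "-V seed:weight,seed:weight" variations list.
def parse_variations(spec):
    pairs = []
    for part in spec.split(','):
        seed, _, weight = part.partition(':')
        pairs.append((int(seed), float(weight)))
    return pairs

print(parse_variations('12345:0.3,99999:0.7'))  # [(12345, 0.3), (99999, 0.7)]
```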
Eric Khun
0163310a47 Merge branch 'development' into patch-1 2022-09-03 03:27:25 +00:00
Lincoln Stein
423d25716d Fix unclosed code section 2022-09-02 18:04:15 -04:00
Lincoln Stein
1d999ba974 Fix reference to variations walkthru
Had wrong href for the VARIATIONS.md file
2022-09-02 18:01:52 -04:00
Lincoln Stein
27d4bb5624 Merge branch 'development' of github.com:lstein/stable-diffusion into development
Synchronize with earlier changes in development
2022-09-02 18:00:02 -04:00
Lincoln Stein
c78b496da6 Merge branch 'bakkot-seed-fuzz' into development
This adds support for variations and mixtures of weighted variations
2022-09-02 17:58:42 -04:00
Lincoln Stein
dd2af3f93c added walkthru, small code fixes 2022-09-02 17:54:55 -04:00
Lincoln Stein
2d65b03f05 Merge branch 'seed-fuzz' of github.com:bakkot/stable-diffusion into bakkot-seed-fuzz 2022-09-02 16:17:51 -04:00
Cragin Godley
2288412ef2 dream.py: fix indentation 2022-09-02 15:00:07 -04:00
Cragin Godley
6bff985496 dream.py: include 0.0.0.0 in --host help
Co-authored-by: Kevin Gibbons <bakkot@gmail.com>
2022-09-02 14:58:57 -04:00
Cragin Godley
918ade12ed dream.py: use localhost in url when host is 0.0.0.0
Co-authored-by: Kevin Gibbons <bakkot@gmail.com>
2022-09-02 14:56:52 -04:00
Cragin Godley
68f62c8352 web: allow custom host/port, default to 127.0.0.1 for security reasons 2022-09-02 12:27:12 -04:00
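A minimal sketch of the host/port behaviour these commits describe, with the 127.0.0.1 default and the localhost display substitution; the default port and the surrounding dream.py wiring are assumptions:

```
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--host', default='127.0.0.1',
                    help='address to bind to; use 0.0.0.0 to allow access from other machines')
parser.add_argument('--port', type=int, default=9090, help='web server port (assumed default)')
opt = parser.parse_args([])

# When bound to all interfaces, print a URL the user can actually open.
display_host = 'localhost' if opt.host == '0.0.0.0' else opt.host
print(f'Point your browser at http://{display_host}:{opt.port}')
```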
James Reynolds
33936430d0 Added issue templates 2022-09-02 12:16:09 -04:00
James Reynolds
81b3de9c65 New issue template 2022-09-02 12:16:09 -04:00
Simon Vans-Colina
ad6cf6f2f7 Update readme to make it clearer for Windows users 2022-09-02 12:12:48 -04:00
Eric Khun
ecef72ca39 Merge branch 'development' into patch-1 2022-09-02 15:22:30 +00:00
Lincoln Stein
92d1ed744a Merge branch 'magnusviri-readme-mac-update-take2' into development 2022-09-02 10:42:39 -04:00
Lincoln Stein
da4bf95fbc Merge branch 'development' into readme-mac-update-take2 2022-09-02 10:41:10 -04:00
Lincoln Stein
d43c5c01e3 Update README.md
Those []() link pairs get me every time.
2022-09-02 10:38:59 -04:00
Lincoln Stein
51278c7a10 add brief contribution instructions in lieu of full code-of-conduct and contribution guidelines 2022-09-02 10:37:35 -04:00
James Reynolds
6ef7c1ad4e Added psychedelicious' changes 2022-09-02 08:29:28 -06:00
Lincoln Stein
33cc16473f Merge branch 'main' into development
Synchronizing dev with legacy pulls to main.
2022-09-02 10:22:57 -04:00
James Reynolds
1701c2ea94 README-Mac update 2022-09-02 10:22:26 -04:00
Lincoln Stein
2e299a1daf Merge branch 'gabrielrotbart-fix_img2img_m1' into main
This may improve the black image problem when using img2img with some
samplers on M1 hardware.
2022-09-02 10:18:06 -04:00
Lincoln Stein
0b582a40d0 add developer's guidance for refactoring this change 2022-09-02 10:17:51 -04:00
James Reynolds
1306457b27 README-Mac update 2022-09-02 08:17:19 -06:00
gabrielrotbart
f4a19af04f fix scope being set to autocast even for m1 2022-09-02 14:55:24 +03:00
Eric Khun
58545ba057 Update README-Mac-MPS.md 2022-09-02 19:52:13 +08:00
Kevin Gibbons
4fe265735a support generating variations
Co-authored-by: xra <mail@xra.dev>
2022-09-01 23:48:53 -07:00
James Reynolds
2b7f32502c Merge branch 'lstein:main' into main 2022-09-01 19:41:14 -06:00
Lincoln Stein
3ee82d8a3b Merge branch 'toffaletti-dream-m1' into main
This provides support for Apple M1 hardware
2022-09-01 17:55:36 -04:00
Lincoln Stein
629ca09fda Merge branch 'dream-m1' of github.com:toffaletti/stable-diffusion into toffaletti-dream-m1
* Fix conflicts with main branch changes
* Fix logic error in choose_autocast_device() that was causing crashes
on CUDA systems.
2022-09-01 17:54:01 -04:00
Lincoln Stein
833de06299 fix InitImageResizer not found error, closes #294 2022-09-01 16:16:46 -04:00
David Wager
68eabab2af Deprecate --laion400m and --weights arguments
Removes functionality for the --laion400m and --weights arguments and notifies user to use the --model argument instead.
2022-09-01 20:46:53 +01:00
David Wager
a4f69e62d7 Set sensible default for 1.4
Use the file that already exists for the majority of users for the default value.
2022-09-01 20:21:39 +01:00
David Wager
7db51d0171 Merge branch 'main' into main 2022-09-01 19:27:38 +01:00
Lincoln Stein
1b3c7acce3 fix ambiguous naming of self.device 2022-09-01 14:18:17 -04:00
Lincoln Stein
e6b2c15fc5 Merge branch 'main' into fit-init-img
add a --fit option to limit the size of the initial image to the
maximum boundaries specified by width and height.
2022-09-01 14:09:46 -04:00
David Wager
d319b8a762 Reference model from configs/models.yaml
By supplying --model (defaulting to stable-diffusion-1.4) a user can specify which model to load.
Width/Height/Config Location/Weights Location are referenced from configs/models.yaml
2022-09-01 19:04:31 +01:00
David Wager
db580ccefd Create models.yaml
models.yaml can serve as a base for expanding our support for other versions of Latent/Stable Diffusion.
Contained are parameters for default width/height, as well as where to find the config and weights for this model.
Adding a new model is as simple as adding to this file.
2022-09-01 19:02:57 +01:00
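A sketch of what such a configs/models.yaml entry and lookup could look like; the field names and paths here are illustrative, not a copy of the repository's file:

```
import yaml  # PyYAML

SAMPLE = """
stable-diffusion-1.4:
  config: configs/stable-diffusion/v1-inference.yaml
  weights: models/ldm/stable-diffusion-v1/model.ckpt
  width: 512
  height: 512
"""

models = yaml.safe_load(SAMPLE)
choice = 'stable-diffusion-1.4'        # value passed via --model
params = models[choice]
print(params['weights'], params['width'], params['height'])
```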
Brent Ozar
9e99fcbc16 README.md - fixing "further reading" formatting
Fixing typo in header and hyperlinking a file.
2022-09-01 10:27:58 -04:00
Lincoln Stein
346c9b66ec Merge branch 'corajr-main' into main
This improves Mac M1 installation instructions and makes the
environment easier to install.
2022-09-01 10:25:59 -04:00
Lincoln Stein
a52870684a Merge branch 'main' of https://github.com/corajr/stable-diffusion into corajr-main 2022-09-01 10:25:43 -04:00
Lincoln Stein
2455bb38a4 Remove redundant chain of types
torch->cuda and cuda->torch, so torch.cuda.torch.cuda actually works. However, it looks like (and probably is) a typo.
2022-09-01 10:23:45 -04:00
Lincoln Stein
01e05a98de this fixes the inconsistent use of self.device, sometimes a str and sometimes an obj 2022-09-01 10:16:05 -04:00
Cora Johnson-Roberson
2cac4697aa Correct some verbiage in Mac readme. 2022-09-01 10:11:14 -04:00
Lincoln Stein
c5e95adb49 closes #273, crash on M1 machines 2022-09-01 10:01:41 -04:00
Cora Johnson-Roberson
91565970c2 Move environment-mac.yaml to Python 3.9 and patch dream.py for Macs.
I'm using stable-diffusion on a 2022 Macbook M2 Air with 24 GB unified memory.
I see this taking about 2.0s/it.

I've moved many deps from pip to conda-forge, to take advantage of the
precompiled binaries. Some notes for Mac users, since I've seen a lot of
confusion about this:

One doesn't need the `apple` channel to run this on a Mac-- that's only
used by `tensorflow-deps`, required for running tensorflow-metal. For
that, I have an example environment.yml here:

https://developer.apple.com/forums/thread/711792?answerId=723276022#723276022

However, the `CONDA_SUBDIR=osx-arm64` environment variable *is* needed to
ensure that you do not run any Intel-specific packages such as `mkl`,
which will fail with [cryptic errors](https://github.com/CompVis/stable-diffusion/issues/25#issuecomment-1226702274)
on the ARM architecture and cause the environment to break.

I've also added a comment in the env file about 3.10 not working yet.
When it becomes possible to update, those commands run on an osx-arm64
machine should work to determine the new version set.

Here's what a successful run of dream.py should look like:

```
$ python scripts/dream.py --full_precision                                                                                                           SIGABRT(6) ↵  08:42:59
* Initializing, be patient...

Loading model from models/ldm/stable-diffusion-v1/model.ckpt
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
Using slower but more accurate full-precision math (--full_precision)
>> Setting Sampler to k_lms
model loaded in 6.12s

* Initialization done! Awaiting your command (-h for help, 'q' to quit)
dream> "an astronaut riding a horse"
Generating:   0%|                                                                                                                                                                         | 0/1 [00:00<?, ?it/s]/Users/corajr/Documents/lstein/ldm/modules/embedding_manager.py:152: UserWarning: The operator 'aten::nonzero' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/_temp/anaconda/conda-bld/pytorch_1662016319283/work/aten/src/ATen/mps/MPSFallback.mm:11.)
  placeholder_idx = torch.where(
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:37<00:00,  1.95s/it]
Generating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [01:38<00:00, 98.55s/it]
Usage stats:
   1 image(s) generated in 98.60s
   Max VRAM used for this generation: 0.00G
Outputs:
outputs/img-samples/000001.1525943180.png: "an astronaut riding a horse" -s50 -W512 -H512 -C7.5 -Ak_lms -F -S1525943180
```
2022-09-01 09:04:30 -04:00
Jason Toffaletti
09bd9fa47e move autocast device selection to a function 2022-08-31 22:21:14 -07:00
Lincoln Stein
dc30adfbb4 closes #273, crash on M1 machines 2022-09-01 01:09:56 -04:00
Jason Toffaletti
fa98601bfb better error reporting for load_model 2022-08-31 22:03:50 -07:00
Jason Toffaletti
66fe110148 default full_precision to True for mps device 2022-08-31 22:03:50 -07:00
Jason Toffaletti
bf50ab9dd6 changes to get dream.py working on M1
- move all device init logic to T2I.__init__
- handle m1 specific edge case with autocast device type
- check torch.cuda.is_available before using cuda
2022-08-31 22:03:42 -07:00
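A hedged sketch of the device-selection logic these M1 commits point at; the real code lives in choose_autocast_device() and T2I.__init__, which are not reproduced here:

```
import torch

def pick_device():
    # honour CUDA first, but only after checking availability (see commit above)
    if torch.cuda.is_available():
        return torch.device('cuda')
    # guard the attribute: older torch builds have no torch.backends.mps at all
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')

print(pick_device())
```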
James Reynolds
70119602a0 Issue 270 fix (#274)
* check if torch.backends has mps before calling it

* Fixes issue 270

Co-authored-by: James Reynolds <magnsuviri@me.com>
2022-09-01 00:59:20 -04:00
Lincoln Stein
28fe84177e optionally scale initial image to fit box defined by width x height
* This functionality is triggered by the --fit option in the CLI (default
false), and by the "fit" checkbox in the WebGUI (default True)

* In addition, this commit contains a number of whitespace changes to
make the code more readable, as well as an attempt to unify the visual
appearance of info and warning messages.
2022-09-01 00:52:43 -04:00
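A minimal sketch of the --fit behaviour, assuming a Pillow-based resize that keeps the aspect ratio while staying inside the requested width x height; the repository's InitImageResizer may differ in detail:

```
from PIL import Image

def fit_within(image, max_w, max_h):
    # never upscale; only shrink so both sides fit inside the box
    scale = min(max_w / image.width, max_h / image.height, 1.0)
    new_size = (int(image.width * scale), int(image.height * scale))
    return image.resize(new_size, Image.LANCZOS)

# usage: fit_within(Image.open('init.png'), 512, 512)
```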
James Reynolds
35d3f0ed90 Merge branch 'lstein:main' into main 2022-08-31 21:42:12 -06:00
blessedcoolant
0433b3d625 Add Warning When Image Is Too Large (#271)
* Add Warning When Image Is Too Large

* fix incomprehensible formatting introduced by "blue"

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2022-08-31 23:13:21 -04:00
Lincoln Stein
4b560b50c2 fix AttributeError crash when running on non-CUDA systems (#256)
* fix AttributeError crash when running on non-CUDA systems; closes issue #234 and issue #250
* although this prevents dream.py script from crashing immediately on MPS systems, MPS support still very much a work in progress.
2022-08-31 16:59:27 -04:00
Lincoln Stein
9ad79207c2 Merge branch 'main' of github.com:lstein/stable-diffusion into main 2022-08-31 14:44:18 -04:00
Lincoln Stein
0be2351c97 Merge branch 'resolution-checker' of https://github.com/blessedcoolant/stable-diffusion into main 2022-08-31 14:43:17 -04:00
David Wager
ed513397b2 Allow configuration of which SD model to use (#263)
* Allow configuration of which SD model to use

Closes https://github.com/lstein/stable-diffusion/issues/49 The syntax isn't quite the same (opting for --weights over --model), although --weights is more in-line with the existing naming convention.
This method also locks us into models in the models/ldm/stable-diffusion-v1/ directory. Personally, I'm not averse to this, although a secondary solution may be necessary if we wish to supply weights from an external directory.

* Fix typo

* Allow either filename OR filepath input for arg

This approach allows both
--weights SD13 
--weights C:/StableDiffusion/models/ldm/stable-diffusion-v1/SD13.ckpt
2022-08-31 14:20:28 -04:00
_nderscore
c52ba1b022 feat: simplify and enhance prompt weight splitting (#258)
* feat: simplify and enhance prompt weight splitting

* fix: don't shadow the prompt variable

* feat: enable backslash-escaped colons in prompts
2022-08-31 14:00:10 -04:00
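For illustration only, a sketch of splitting a weighted prompt on unescaped colons, which is the idea behind these commits; the codebase's own splitter handles more cases and is not reproduced here:

```
import re

# "text:weight" sections; "\:" stands for a literal colon in the text.
PATTERN = re.compile(r'((?:[^:\\]|\\.)+):([+-]?\d+(?:\.\d+)?)')

def split_weighted(prompt):
    parts = [(m.group(1).strip().replace('\\:', ':'), float(m.group(2)))
             for m in PATTERN.finditer(prompt)]
    return parts or [(prompt.replace('\\:', ':'), 1.0)]

print(split_weighted('blue sphere:0.6 red cube:0.4'))
print(split_weighted(r'ratio 16\:9 landscape'))
```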
Lincoln Stein
d022d0dd11 continue to display in-progress image until the post-processing is done, for better esthetics (#255) 2022-08-31 12:32:56 -04:00
Kevin Gibbons
a14fd69a5a fix progress bar in webui when using strength parameter (#254) 2022-08-31 11:28:11 -04:00
James Reynolds
0d2e6f90c8 Readme update (#253)
* check if torch.backends has mps before calling it

* Updated Mac Readme with latest debugging info

Co-authored-by: James Reynolds <magnsuviri@me.com>
2022-08-31 11:27:13 -04:00
David Ford
58e3562652 Fix merging embeddings (#226)
Fixed merging embeddings based on the changes made in textual inversion. Tested and working. Inverted their logic to prioritize Stable Diffusion implementation over alternatives, but left the option for alternatives to still be used.
2022-08-31 11:24:11 -04:00
Mikhail Tishin
b622819051 Expose img2img strength parameter in Web UI (#239)
* Expose img2img strength parameter in Web UI

* Fix strength label id

Co-authored-by: Mikhail Tishin <michail.tishin@fayrix.com>
Co-authored-by: Kevin Gibbons https://github.com/bakkot
2022-08-31 11:18:32 -04:00
James Reynolds
a547c33327 check if torch.backends has mps before calling it (#245)
Co-authored-by: James Reynolds <magnsuviri@me.com>
2022-08-31 10:56:38 -04:00
Brent Ozar
31b77dbaf8 Readme.md - fix hyperlink to Mac docs (#246)
The square brackets & curly brackets were mixed up.
2022-08-31 10:53:21 -04:00
Tom Elovi Spruce
4280788c18 Fix link to Mac instructions in README (#235) 2022-08-31 10:51:25 -04:00
Lincoln Stein
146e75a1de Merge branch 'bakkot-refactor-pngwriter-2' into main
This fixes regressions in the WebGUI and makes maintenance of pngwriter
easier.
2022-08-31 10:07:57 -04:00
Lincoln Stein
8a2b849620 fix regression in WebGUI progress bar and WebGUI crashes, closes issue #236. Closes issue #249 2022-08-31 10:07:19 -04:00
Lincoln Stein
462a1961e4 fix infinite hang during GFPGAN duration inadvertently introduced during batch_size cleanup 2022-08-31 08:21:49 -04:00
James Reynolds
84c10346fb check if torch.backends has mps before calling it 2022-08-31 03:29:37 -06:00
Jason Toffaletti
2aa8393272 set PYTORCH_ENABLE_MPS_FALLBACK in mac environment (#232)
- this enables cpu fallback for op not yet implemented for m1 gpu
2022-08-31 02:00:40 -04:00
Lincoln Stein
c83d01b369 fix hang during GFPGAN processing due to bug introduced by recent removal of batch_size arg from pngwriter 2022-08-31 01:41:15 -04:00
Lincoln Stein
5354122094 Merge branch 'main' into refactor-pngwriter-2 2022-08-31 01:24:17 -04:00
spezialspezial
64444025a9 Update simplet2i.py (#228)
A typo caused a bug when preinitializing the model: "Unsupported Sampler: klms, Defaulting to plms".
2022-08-31 01:08:46 -04:00
Kevin Gibbons
d566ee092a move make_grid into image_utils 2022-08-30 22:03:53 -07:00
Kevin Gibbons
b983d61e93 tweak format of "result" event in web ui 2022-08-30 22:03:53 -07:00
Kevin Gibbons
153c93bdd4 refactor pngwriter 2022-08-30 22:03:51 -07:00
Lincoln Stein
3be1cee17c avoid crash due to dangling batch_size reference 2022-08-31 00:56:12 -04:00
Lincoln Stein
bdb0651eb2 add support for Apple hardware using MPS acceleration 2022-08-31 00:33:23 -04:00
blessedcoolant
1480ef84dc Add Resolution Checker 2022-08-31 14:54:16 +12:00
Kevin Gibbons
1714816fe2 remove support for batch_size from dream.py (#227)
* remove dream.py support for batch_size

* expect to get a single image
2022-08-30 22:30:12 -04:00
David Ford
b5565d2c82 Update .gitignore (#225)
Include log folders in git ignore.
2022-08-30 20:29:26 -04:00
David Ford
4fad71cd8c Training optimizations (#217)
* Optimizations to the training model

Based on the changes made in
textual_inversion I carried over the relevant changes that improve model training. These changes reduce the amount of memory used, significantly improve the speed at which training runs, and improves the quality of the results.

It also fixes the problem where the model trainer wouldn't automatically stop when it hit the set number of steps.

* Update main.py

Cleaned up whitespace
2022-08-30 15:59:32 -04:00
Lincoln Stein
d126db2413 Update README.md 2022-08-30 15:57:54 -04:00
blessedcoolant
7811d20f21 Add Badges to README.md and add CHANGELOG.md (#205)
* Update README.md - Add Badges

* Add CHANGELOG.md
2022-08-30 15:40:56 -04:00
Yosuke Shinya
d524e5797d Add regression test (#136)
* Add regression test
* fix regression test with full_precision
2022-08-30 15:39:14 -04:00
Kevin Gibbons
8ca4d6542d support progress for img2img (#215)
WebGUI shows progress bar when an initial image is provided.
2022-08-30 15:36:12 -04:00
Lincoln Stein
a51e18ea98 resize initial image to match requested width and height, preserving aspect ratio. Closes #210. Closes #207 (#214) 2022-08-30 15:26:02 -04:00
Lincoln Stein
8bf321f6ae Merge pull request #182 from bakkot/webui-cancel
webui: support cancelation
2022-08-30 12:02:05 -04:00
Kevin Gibbons
5d13207aa6 webui: support cancelation 2022-08-30 08:55:40 -07:00
Lincoln Stein
dae2b26765 remove message about GFPGAN being required, since it is no longer displayed if GFPGAN missing 2022-08-30 09:50:39 -04:00
Lincoln Stein
713b2a03dc Merge branch 'bakkot-sw-drop' into main
This adds a checkbox that shows the intermediate images forming as txt2img()
goes through its denoising steps.
2022-08-30 09:31:19 -04:00
Lincoln Stein
186d0f9d10 Merge branch 'sw-drop' of https://github.com/bakkot/stable-diffusion into bakkot-sw-drop 2022-08-30 09:17:07 -04:00
Lincoln Stein
55b448818e Update README.md
Highlighted Colab notebook addition.
2022-08-29 23:49:42 -04:00
Lincoln Stein
b4babf7680 add a screenshot to description of command-line utility 2022-08-29 23:27:44 -04:00
Lincoln Stein
85f32752fe promote most headings by one level 2022-08-29 23:16:41 -04:00
Lincoln Stein
b757384aba promote most headings by one level 2022-08-29 23:16:21 -04:00
Lincoln Stein
a5d21d7c94 Update README.md
Added a table of contents and a troubleshooting guide.
2022-08-29 23:15:49 -04:00
Lincoln Stein
8f3520e2d5 add troubleshooting guide to README 2022-08-29 23:08:04 -04:00
Lincoln Stein
19e4298cf9 Merge branch 'BlueAmulet-prompt_as_dir' into main
This adds the frequently-requested feature of naming the output
directory after the text prompt.
2022-08-29 22:34:48 -04:00
Lincoln Stein
42ffcd7204 add the recently added commands to the readline command-line-completion list; fix command-line documentation bug, closing issue #188 2022-08-29 22:34:09 -04:00
Lincoln Stein
d48299e56c Merge branch 'prompt_as_dir' of https://github.com/BlueAmulet/stable-diffusion into BlueAmulet-prompt_as_dir 2022-08-29 22:13:37 -04:00
BlueAmulet
2e22d9ecf1 Address bakkot review 2022-08-29 18:10:15 -06:00
Kevin Gibbons
18597ad1d9 fix bug in pngwriter 2022-08-29 16:33:32 -07:00
Kevin Gibbons
0173d3a8fc stream images 2022-08-29 16:33:31 -07:00
Lincoln Stein
e7658b941e Merge pull request #187 from bakkot/webui-upscalers-optional
webui: hide gfpgan part if not installed
2022-08-29 19:29:16 -04:00
Kevin Gibbons
a7a62d39d4 webui: hide gfpgan if not installed 2022-08-29 16:27:44 -07:00
Lincoln Stein
24ce56b3db Merge branch 'webui-sampler-fix' into main 2022-08-29 19:25:49 -04:00
Lincoln Stein
3220f73f0a add missing dropdown element for K_EULER_A 2022-08-29 19:24:52 -04:00
Lincoln Stein
27a1044e65 Merge pull request #199 from david-ford/gitattributes
Create .gitattributes
2022-08-29 19:15:25 -04:00
Lincoln Stein
39c56f20be Merge pull request #200 from david-ford/html-minor-changes
Minor updates to index.html
2022-08-29 19:14:51 -04:00
Lincoln Stein
f6b2ec61b2 Merge pull request #201 from blessedcoolant/placeholder-logo
Add Temp Logo To Repo
2022-08-29 19:13:54 -04:00
blessedcoolant
e57d6fd1a6 Add Temp Logo To Repo 2022-08-30 11:00:11 +12:00
David Ford
1b40a31a89 Update .gitattributes
Co-authored-by: Kevin Gibbons <bakkot@gmail.com>
2022-08-29 16:58:41 -05:00
David Ford
4fce1063c4 Minor updates to index.html
Some minor tweaks to index.html for accessibility and browsers.
2022-08-29 16:55:20 -05:00
David Ford
f9862a3d88 Create .gitattributes
Added a .gitattributes file to autoresolve line endings across different environments.
2022-08-29 16:50:28 -05:00
Lincoln Stein
81ad239197 Merge pull request #192 from david-ford/working-branch
Fix case sensitive check to be case insensitive
2022-08-29 17:48:12 -04:00
David Ford
ed38c97ed8 Removed unrelated changes and updated based on recommendation
Removed the changes to the index.html and .gitattributes for this PR. Will add them in separate PRs.

Applied recommended change for resolving the case issue.
2022-08-29 16:43:34 -05:00
BlueAmulet
4f8e7356b3 Add prompt as output directory feature
Based on previous code by czyz
2022-08-29 14:52:02 -06:00
Lincoln Stein
c363f033e8 Merge pull request #198 from bakkot/instructions-typo
fix path to dream server in readme
2022-08-29 16:41:24 -04:00
Kevin Gibbons
22c25b3615 fix path to dream server in readme 2022-08-29 13:38:52 -07:00
Lincoln Stein
7fe7cdc8c9 Merge pull request #176 from xraxra/show-tokenization
Print out tokenization data during image generation, allowing truncated prompts to be visible.
2022-08-29 15:36:10 -04:00
Lincoln Stein
e26fee78b5 Merge pull request #158 from Cubox/patch-1
Fix wrong help message
2022-08-29 15:12:24 -04:00
Lincoln Stein
63178c6a8c Merge branch 'main' into patch-1 2022-08-29 15:12:14 -04:00
Lincoln Stein
6fb2f1ed6e fixes WebUI so that the selected sampler is actually applied 2022-08-29 14:06:18 -04:00
Lincoln Stein
38701a6d7b Fix IndexError when generating grid; --grid option can now be passed on shell command line 2022-08-29 13:52:44 -04:00
David Ford
31fa92a83f Fix case sensitive check to be case insensitive
Case sensitivity between os.getcwd and os.realpath can fail due to different drive letter casing (C:\ vs c:\). This change addresses that by normalizing the strings before comparing.
2022-08-29 12:46:24 -05:00
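A small sketch of the normalization described here, assuming the comparison is between two filesystem paths; the actual call site in dream.py is not shown in this diff:

```
import os

def same_path(a, b):
    # os.path.normcase folds case (and slashes) on Windows, so "C:\\x" == "c:\\x"
    return os.path.normcase(os.path.realpath(a)) == os.path.normcase(os.path.realpath(b))

print(same_path(os.getcwd(), '.'))  # True
```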
Lincoln Stein
0abfc3cac6 Merge branch 'main' of github.com:lstein/stable-diffusion into main
This fixes issue with grid generation.
2022-08-29 13:39:59 -04:00
Lincoln Stein
d483fcb53a Merge pull request #166 from SMUsamaShah/patch-1
Bug fix in grid
2022-08-29 13:39:50 -04:00
Lincoln Stein
c7db038c96 grid is broken, needs the grid-fix PR#166 to fix 2022-08-29 13:39:20 -04:00
Lincoln Stein
132d23e55d Merge pull request #186 from lstein/reset-properly
Fix form reset logic
2022-08-29 13:03:30 -04:00
Lincoln Stein
90cbc6362c Display new features more prominently in the README 2022-08-29 12:59:34 -04:00
Lincoln Stein
f33ae1bdf4 Display new features more prominently in the README 2022-08-29 12:58:48 -04:00
Kevin Gibbons
754525be82 fix reset logic 2022-08-29 09:47:56 -07:00
Kevin Gibbons
d9eab7f383 use LF not CRLF for files, oh god 2022-08-29 09:47:45 -07:00
Kevin Gibbons
f695988915 fix whitespace in index.html 2022-08-29 09:47:20 -07:00
Lincoln Stein
5d19294810 Merge pull request #128 from artmen1516/feature/colab-notebook
Feature: Add Colab Notebook
2022-08-29 12:37:11 -04:00
Lincoln Stein
77803cf233 Merge branch 'bakkot-rebase-streaming-web' into main
This adds correct treatment of upscaling/face-fixing within the WebUI.
Also adds a basic status message so that the user knows what's happening
during the post-processing steps.
2022-08-29 12:09:44 -04:00
Lincoln Stein
4acfb76be6 correctly handle upscaling in webUI, including displaying status messages during GFPGAN/ESRGAN postprocessing 2022-08-29 12:08:18 -04:00
Lincoln Stein
fd13526454 Merge pull request #175 from bakkot/unicode
read/write plain text files in utf-8, not ascii
2022-08-29 07:03:55 -04:00
Lincoln Stein
7718af041c Merge pull request #178 from lstein/dream-web-upscaling
FEAT: Dream web upscaling
2022-08-29 06:59:21 -04:00
Lincoln Stein
30dbf0e589 Merge pull request #177 from lstein/bugfixes
Bugfixes to image generation logic
2022-08-29 06:58:42 -04:00
tesseractcat
070795a3b4 webui: stream progress events to page 2022-08-28 21:54:10 -07:00
Lincoln Stein
e351d6ffe5 set correct default values for scaling and sampler; closes issues #167 #157 2022-08-29 00:13:18 -04:00
Lincoln Stein
46464ac677 remove unused metadatastr variable 2022-08-28 23:45:50 -04:00
Lincoln Stein
03d8eb19e0 when no callback used, modify results list so that upscaled/face-fixed image replaces the old one 2022-08-28 23:40:04 -04:00
xra
fef632e0e1 tokenization logging (take 2)
This adds an optional -t argument that prints out color-coded tokenization. SD has a maximum of 77 tokens and silently discards tokens over the limit if your prompt is too long.
By using -t you can see how your prompt is being tokenized, which helps with prompt crafting.
2022-08-29 12:28:49 +09:00
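An illustrative sketch of the idea behind -t, using the Hugging Face CLIP tokenizer (an assumption; the script's own tokenizer hookup and colour coding are not reproduced):

```
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
prompt = 'a photograph of an astronaut riding a horse'

for i, tok in enumerate(tokenizer.tokenize(prompt), start=1):
    # CLIP keeps 77 positions including the start/end tokens, so ~75 prompt tokens survive
    marker = '' if i <= 75 else '  <-- past the window SD keeps, silently discarded'
    print(f'{i:3d} {tok}{marker}')
```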
Lincoln Stein
05061a70b3 report errors on non-cuda systems rather than failing silently 2022-08-28 23:13:23 -04:00
Lincoln Stein
617a029ae7 pass outdir from txt2img() and img2img() to prompt2img() correctly 2022-08-28 23:12:49 -04:00
Kevin Gibbons
7ae79b350e write log files in utf-8, not ascii 2022-08-28 20:00:11 -07:00
Lincoln Stein
9a8cd9684e Merge pull request #173 from warner-benjamin/bug-fixes-improvements
Fix grid image saving & sampler selection, log to outdir path, display sampler options once
2022-08-28 22:46:38 -04:00
Lincoln Stein
18899be4ae working, but there is a bug in underlying txt2png() call that is preventing upscaled images from being returned 2022-08-28 22:42:31 -04:00
Benjamin Warner
3ea505bc2d Merge branch 'lstein:main' into bug-fixes-improvements 2022-08-28 21:34:13 -05:00
Lincoln Stein
e2ae6d288d added reset to defaults button and sampler selection 2022-08-28 21:35:52 -04:00
blessedcoolant
41b26e0520 Merge pull request #171 from blessedcoolant/sampler-bug-fix
Fix sampler changer not working.
2022-08-29 13:06:30 +12:00
blessedcoolant
b6053108c1 Merge pull request #168 from blessedcoolant/bug-fixes
Fixed grid image not saving
2022-08-29 13:05:53 +12:00
Lincoln Stein
22365a3f12 begin adding fields for GFPGAN and ESRGAN adjustment; only making public because need to switch computers 2022-08-28 21:04:32 -04:00
Benjamin Warner
594c0eeb8c check in rest of sampler fix 2022-08-28 19:53:25 -05:00
Benjamin Warner
529040708b Fix grid image saving, log to outdir path, display sampler options once 2022-08-28 19:34:55 -05:00
blessedcoolant
f0e2fa781f Grid image not saving after recent changes has been fixed. 2022-08-29 11:29:45 +12:00
blessedcoolant
87b7446228 Fix unique filename bug 2022-08-29 11:28:16 +12:00
blessedcoolant
8a517fdc17 Fix sampler changer not working. 2022-08-29 11:26:19 +12:00
Lincoln Stein
373a2d9c32 Merge branch 'main' of github.com:lstein/stable-diffusion into main 2022-08-28 19:03:45 -04:00
Lincoln Stein
1f8bc9482a added support for changing sampler on prompt line 2022-08-28 19:03:38 -04:00
Lincoln Stein
b85773f332 resolved conflicts and write properly-formatted prompt string (with sampler & upscaling) into image file 2022-08-28 19:01:45 -04:00
Lincoln Stein
ddc0e9b4d8 Merge pull request #133 from bakkot/dir-traversal
prevent directory traversal in the web UI
2022-08-28 18:32:12 -04:00
Lincoln Stein
44a48d0981 Merge pull request #130 from bakkot/patch-1
make web ui default to 512x512
2022-08-28 18:30:32 -04:00
Lincoln Stein
8bbe7936bd close Issue #165 2022-08-28 18:21:20 -04:00
Kevin Gibbons
9e7865704a prevent directory traversal in the web UI 2022-08-28 14:33:30 -07:00
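A generic sketch of the traversal guard this commit describes; the static root and handler shape are assumptions, not the web server's actual code:

```
import os

STATIC_ROOT = os.path.realpath('static/dream_web')  # assumed location of the UI files

def safe_static_path(requested):
    candidate = os.path.realpath(os.path.join(STATIC_ROOT, requested.lstrip('/')))
    # refuse anything that resolves outside the static root (e.g. '../../etc/passwd')
    if os.path.commonpath([candidate, STATIC_ROOT]) != STATIC_ROOT:
        return None
    return candidate

print(safe_static_path('index.html'))
print(safe_static_path('../../etc/passwd'))  # None
```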
Lincoln Stein
ac02a775e4 moved server.py into right location 2022-08-28 17:27:43 -04:00
Lincoln Stein
7c485a1a4a adjusted -U upscaling argument so that it defaults to upscaling strength 0.75 if the second argument is not given 2022-08-28 17:26:39 -04:00
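The -U default described here is simple enough to sketch directly; the exact argparse wiring in dream.py is assumed:

```
def parse_upscale(values):
    # values come from "-U scale [strength]"
    scale = int(values[0])
    strength = float(values[1]) if len(values) > 1 else 0.75  # default per this commit
    return scale, strength

print(parse_upscale(['2']))         # (2, 0.75)
print(parse_upscale(['2', '0.6']))  # (2, 0.6)
```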
Lincoln Stein
36bc989a27 Merge branch 'blessedcoolant-gfpgan-optimization' into main
This reduces VRAM requirements when GFPGAN face fixing and Real-ESRGAN
upscaling are used. --gfpgan flag is no longer needed (or accepted)
2022-08-28 17:06:49 -04:00
Lincoln Stein
ea2ee33be8 cosmetic fixup to how the outputs are reported 2022-08-28 17:06:33 -04:00
Lincoln Stein
5d67986997 Merge branch 'dagf2101-main' into main
This consolidates dream.py with dream_web.py, such that you now invoke
the web server by executing "scripts/dream.py --web (+other options)"
2022-08-28 16:37:50 -04:00
Lincoln Stein
7dfca3dcb5 moved scripts/dream_server.py into ldm/dream/server.py 2022-08-28 16:37:27 -04:00
blessedcoolant
e0de42bd03 Update README.md 2022-08-29 08:25:55 +12:00
blessedcoolant
614974a8e8 Merge branch 'main' into gfpgan-optimization 2022-08-29 08:22:26 +12:00
blessedcoolant
6e49c070bb Optimize and Improve GFPGAN and Real-ESRGAN Pipeline 2022-08-29 08:14:29 +12:00
Lincoln Stein
08a9702b73 Merge branch 'main' of github.com:lstein/stable-diffusion into main 2022-08-28 15:54:33 -04:00
Lincoln Stein
042a9043d1 got rid of the cd and pwd commands, and just allow user to specify --outdir on the command 2022-08-28 15:54:12 -04:00
Lincoln Stein
a7ac93a899 Merge pull request #110 from sajattack/half-precision-embeddings
Support full-precision embeddings in half precision inference mode
2022-08-28 15:36:26 -04:00
Lincoln Stein
3b2569ebdd Merge branch 'yunsaki-main' into main 2022-08-28 14:20:48 -04:00
Lincoln Stein
8b9a520c5c adjusted handling of from_file 2022-08-28 14:20:34 -04:00
Lincoln Stein
ba03289c14 print current and max VRAM usage stats after each round of generation 2022-08-28 13:05:01 -04:00
blessedcoolant
d1551b1bd4 Enable users to set sampler using prompts 2022-08-29 04:27:54 +12:00
Andy Pilate
fab9e1a423 Fix wrong help message 2022-08-28 17:11:24 +02:00
Muhammad Usama
59be6c815d bug fix in grid
In the case of 6 images, the 3rd image was also copied to the 4th box, dropping the last image from the grid.
2022-08-28 15:39:09 +01:00
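A hedged sketch of grid assembly that pastes every image exactly once, which is the behaviour this fix restores; the repository's own make_grid helper may differ:

```
import math
from PIL import Image

def make_image_grid(images):
    cols = math.ceil(math.sqrt(len(images)))
    rows = math.ceil(len(images) / cols)
    w, h = images[0].size
    grid = Image.new('RGB', (cols * w, rows * h))
    for i, img in enumerate(images):          # one paste per image, no duplicates
        grid.paste(img, ((i % cols) * w, (i // cols) * h))
    return grid
```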
artmen1516
ff6c11406b add notebook and readme section 2022-08-27 17:20:44 -07:00
Kevin Gibbons
6f90c7daf6 make web ui default to 512x512 2022-08-26 18:26:22 -07:00
Lincoln Stein
38ed6393fa updated TODO 2022-08-26 14:59:53 -04:00
Lincoln Stein
a5a3300fc6 Merge pull request #114 from TesseractCat/main
Replace numerical size inputs with dropdowns
2022-08-26 14:42:57 -04:00
tesseractcat
0ab03a5fde Replace numerical size inputs with dropdowns 2022-08-26 13:35:27 -04:00
Lincoln Stein
800132970e Merge pull request #105 from shusso/select-device
Move torch.device selection to its own function
2022-08-26 12:23:21 -04:00
Paul Sajna
555f13e469 Merge branch 'main' into half-precision-embeddings 2022-08-26 08:33:46 -07:00
Paul Sajna
9b5101cd8d support full-precision embeddings in half precision mode 2022-08-26 08:30:58 -07:00
yun saki
7040995ceb fixed variable name error 2022-08-26 14:25:49 +02:00
yun saki
5129f256a3 simplet2i: changed image file handling to work as stated in the [docs](https://pillow.readthedocs.io/en/stable/reference/open_files.html) 2022-08-26 14:13:16 +02:00
Lincoln Stein
b0b4ccf521 Merge pull request #101 from BaristaLabs/remove-gpfgan
Set default to none for gfpgan_strength
2022-08-26 07:55:47 -04:00
Samuel Husso
ed72ff3268 Move torch.device selection to its own function 2022-08-26 14:43:18 +03:00
yun saki
89805a5239 fixed mistake in comment 2022-08-26 13:25:12 +02:00
yun saki
e00397f9ca refactored logfile handling; minimised time spent in context managers (with open) 2022-08-26 13:22:53 +02:00
yun saki
12f59e1daa removed log.close(); 'with open' automatically closes the file 2022-08-26 13:12:56 +02:00
yun saki
cf750f62db refactored infile handling 2022-08-26 13:10:37 +02:00
yun saki
0f28663805 remove redundant None check (if var does the same thing) 2022-08-26 12:43:13 +02:00
Sean McLellan
f3fad22cb6 Fix 2022-08-26 05:27:34 -04:00
Sean McLellan
7bf0bc5208 fix comment 2022-08-26 04:08:18 -04:00
Sean McLellan
4e5aa7e714 fix comment 2022-08-26 04:07:01 -04:00
Sean McLellan
46a223f229 Double check for null and 0, and add a comment to indicate intent 2022-08-26 04:05:09 -04:00
Sean McLellan
eb9f0be91a Set default to none for gfpgan_strength 2022-08-26 03:53:55 -04:00
Lincoln Stein
4f02b72c9c prettified all the code using "blue" at the urging of @tildebyte 2022-08-26 03:15:42 -04:00
Lincoln Stein
dd670200bb documentation tweaks for installation and running of the GFPGAN extension; now you have the ability to specify the previous image's seed with -S -1, the one before that with -S -2, and so forth 2022-08-26 02:17:14 -04:00
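A minimal sketch of the -S -1 / -S -2 lookup described here, assuming the session keeps a simple list of previously used seeds:

```
def resolve_seed(requested, previous_seeds):
    # -1 -> last image's seed, -2 -> the one before that, and so on
    if requested < 0:
        return previous_seeds[requested]
    return requested

history = [42, 1525943180, 987654321]
print(resolve_seed(-1, history))  # 987654321
print(resolve_seed(-2, history))  # 1525943180
```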
Lincoln Stein
8f89a2456a something is not quite right; when providing -G1 option on one prompt, and then omitting it on the next, I see a "images do not match" error from GFPGAN 2022-08-26 01:20:01 -04:00
Sean McLellan
407d70a987 Fix backwards logic 2022-08-26 00:49:12 -04:00
Sean McLellan
f1ffb5b51b Fix blend if the target image has been upscaled 2022-08-26 00:45:19 -04:00
Sean McLellan
4f1664ec4f remove params 2022-08-26 00:41:41 -04:00
Sean McLellan
fcdd95b652 Refactor so that behavior is consolidated at top level 2022-08-26 00:39:57 -04:00
Sean McLellan
470a62dbbe Merge branch 'main' of https://github.com/BaristaLabs/stable-diffusion-dream into add-gfpgan-option 2022-08-26 00:26:03 -04:00
Lincoln Stein
2c08cf7175 Merge branch 'more-refactoring' into main
This breaks up the dream utility modules in a more
sensible manner.
2022-08-25 23:59:30 -04:00
Lincoln Stein
539c15966d Update README.md
Put in a plug for Yunsaki's morphing code.
2022-08-25 23:54:44 -04:00
Lincoln Stein
5f844807cb Update README.md
Removed a bit of an uncaught merge conflict warning.
2022-08-25 23:50:56 -04:00
Sean McLellan
cb86b9ae6e Remove the redundancy, better logging 2022-08-25 23:48:35 -04:00
Sean McLellan
3a30a8f2d2 Fix not being able to disable bgupscaler; update readme 2022-08-25 23:39:03 -04:00
Sean McLellan
60ed004328 Update readme, fix defaults for case-sensitive fs's 2022-08-25 23:31:08 -04:00
Sean McLellan
dbb9132f4d Merge branch 'main' of https://github.com/BaristaLabs/stable-diffusion-dream into add-gfpgan-option 2022-08-25 23:19:17 -04:00
Sean McLellan
5711b6d611 Add optional GFPGAN support 2022-08-25 22:57:30 -04:00
Lincoln Stein
f1bed52530 moved dream utilities into their own subfolder 2022-08-25 22:49:15 -04:00
Lincoln Stein
23fb4a72bb Merge branch 'bakkot-more-refactor' into main 2022-08-25 22:19:27 -04:00
Lincoln Stein
c38b6964b4 improved inline error messages slightly 2022-08-25 22:19:12 -04:00
Lincoln Stein
e202441f0c Merge branch 'more-refactor' of https://github.com/bakkot/stable-diffusion into bakkot-more-refactor 2022-08-25 21:55:08 -04:00
Lincoln Stein
d051d86df6 Merge pull request #96 from TesseractCat/main
Keep a log of requests for dream_web
2022-08-25 21:42:17 -04:00
tesseractcat
b49475a54f Keep a log of requests for dream_web 2022-08-25 21:06:17 -04:00
Kevin Gibbons
797de3257c fix batch_size 2022-08-25 17:28:52 -07:00
Kevin Gibbons
31b22e057d switch to generators 2022-08-25 17:06:06 -07:00
Kevin Gibbons
078859207d factor out loop 2022-08-25 16:51:39 -07:00
Kevin Gibbons
a10baf5808 factor out exception handler 2022-08-25 15:13:07 -07:00
Lincoln Stein
0eba55ddbc Merge pull request #91 from veprogames/update-ignorance
Properly remove src from repository, add ignored directory for initial images
2022-08-25 17:32:10 -04:00
Lincoln Stein
19fa222810 refactoring complete; please test carefully! 2022-08-25 17:30:08 -04:00
Lincoln Stein
b3e3b0e861 feature complete; looks like ready for merge 2022-08-25 17:26:48 -04:00
veprogames
dde2994d10 add inputs/ to .gitignore (a place for initial images) 2022-08-25 22:31:24 +02:00
veprogames
888ca39ce2 remove k-diffusion from repository (git rm --cached)
should fix conda environment hanging
2022-08-25 22:29:12 +02:00
Lincoln Stein
f4c95bfec0 Update README.md 2022-08-25 15:33:49 -04:00
Lincoln Stein
91d3e4605e Update README.md 2022-08-25 15:32:48 -04:00
Lincoln Stein
652c67c90e Update README.md 2022-08-25 15:29:41 -04:00
Lincoln Stein
2114c386ad moved index.js .html and .css files into static/dream_web/; changed batch to iterations again 2022-08-25 15:27:43 -04:00
Lincoln Stein
6d2b4cbda1 Merge branch 'main' of github.com:lstein/stable-diffusion into main 2022-08-25 15:15:07 -04:00
Lincoln Stein
562831fc4b Merge branch 'TesseractCat-main' into main 2022-08-25 15:14:50 -04:00
Lincoln Stein
d04518e65e resolved conflicts in use of batch vs iterations 2022-08-25 15:14:38 -04:00
Lincoln Stein
d598b6c79d Update README.md 2022-08-25 15:11:06 -04:00
Lincoln Stein
4ec21a5423 resolved conflicts 2022-08-25 15:09:55 -04:00
Lincoln Stein
b64c902354 added missing image 2022-08-25 15:06:10 -04:00
Lincoln Stein
2ada3288e7 Small cleanups.
- Quenched tokenizer warnings during model initialization.
- Changed "batch" to "iterations" for generating multiple images in
  order to conserve vram.
- Updated README.
- Moved static folder from under scripts to top level. Can store other
  static content there in future.
- Added screenshot of web server in action (to static folder).
2022-08-25 15:03:40 -04:00
tesseractcat
91966e9ffa Fix appearance on mobile 2022-08-25 15:01:08 -04:00
tesseractcat
2ad73246f9 Normalize working directory 2022-08-25 14:27:33 -04:00
tesseractcat
d3a802db69 Fix horizontal divider 2022-08-25 14:18:29 -04:00
tesseractcat
b95908daec Move style and script to individual files 2022-08-25 14:15:08 -04:00
Lincoln Stein
79add5f0b6 Merge branch 'main' of https://github.com/TesseractCat/stable-diffusion into TesseractCat-main 2022-08-25 13:52:44 -04:00
Lincoln Stein
650ae3eb13 Merge pull request #89 from BlueAmulet/remove-accelerate
Remove accelerate library
2022-08-25 13:48:48 -04:00
Lincoln Stein
0e3059728c Merge pull request #85 from JigenD/VRAMutilizationFix
fix VRAM utilization
2022-08-25 13:47:49 -04:00
BlueAmulet
b7735b3788 Fix attribution 2022-08-25 11:13:12 -06:00
BlueAmulet
39b55ae016 Remove accelerate library
This library is not required to use k-diffusion
Make k-diffusion wrapper closer to the other samplers
2022-08-25 11:04:57 -06:00
JigenD
e82c5eba18 PR revision: replace cuda call with dynamic type 2022-08-25 12:18:35 -04:00
Lincoln Stein
1c8ecacddf remove src directory, which is gumming up conda installs; addresses issue #77 2022-08-25 10:43:05 -04:00
Lincoln Stein
26dc05e0e0 document --from_file flag, closes issue #82 2022-08-25 09:47:27 -04:00
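A small sketch of reading prompts from a file as the --from_file flag implies, one prompt per line; how dream.py actually treats blank lines and comments is an assumption here:

```
def iter_prompts(path):
    with open(path, encoding='utf-8') as infile:   # utf-8 per the text-file commits above
        for line in infile:
            line = line.strip()
            if line and not line.startswith('#'):  # assumed: skip blanks and comments
                yield line

# usage: for prompt in iter_prompts('prompts.txt'): ...
```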
Lincoln Stein
49247b4aa4 fix performance regression; closes issue #42 2022-08-25 09:41:12 -04:00
JigenD
eb58276a2c fix VRAM utilization 2022-08-25 08:34:51 -04:00
tesseractcat
72a9d75330 404 on missing file 2022-08-25 01:25:22 -04:00
Lincoln Stein
1a7743f3c2 Merge pull request #79 from BaristaLabs/update-readme-with-variant-disc
Update readme with variant disc
2022-08-25 00:44:45 -04:00
Lincoln Stein
0b4459b707 mostly back to full functionality; just missing grid generation code 2022-08-25 00:42:37 -04:00
Sean McLellan
c521ac08ee Another update 2022-08-25 00:00:39 -04:00
Sean McLellan
29727f3e12 Another update 2022-08-24 23:59:37 -04:00
Sean McLellan
51b9a1d8d3 Update readme.md 2022-08-24 23:55:31 -04:00
tesseractcat
ab131cb55e Add img2img support, fix naming conventions 2022-08-24 23:03:02 -04:00
tesseractcat
269fcf92d9 Reapply prompt config on image click 2022-08-24 21:38:47 -04:00
Lincoln Stein
8b682ac83b Merge pull request #75 from tildebyte/docs-readme-update-109-notes
DOCS: update release features for v1.09 in README - add k_diffusion samplers note
2022-08-24 19:55:10 -04:00
Lincoln Stein
36e4130f1c Merge pull request #72 from BaristaLabs/fix-dependencies
Various fixes in requirements and variant counting.
2022-08-24 19:54:38 -04:00
Lincoln Stein
b978536385 code is reorganized and mostly functional. Grid needs to be brought back online, as well as naming of img2img variants (currently the variants get written but not logged) 2022-08-24 19:47:59 -04:00
tesseractcat
0a7fe6f2d9 Switch to ThreadingHTTPServer 2022-08-24 18:19:50 -04:00
Lincoln Stein
b12955c963 remove unneeded imports from dream.py 2022-08-24 17:57:44 -04:00
Lincoln Stein
9133087850 first draft at big refactoring; will be broken 2022-08-24 17:52:34 -04:00
Ben Alkov
25fa0ad1f2 docs(readme): update release features for v1.09 2022-08-24 17:50:29 -04:00
tesseractcat
df9f088eb4 Preserve prompt across generations 2022-08-24 17:28:59 -04:00
tesseractcat
b1600d4ca3 Update seed on click 2022-08-24 17:26:22 -04:00
tesseractcat
0efc3bf780 Add bare bones web UI 2022-08-24 17:04:30 -04:00
Sean McLellan
dd16fe16bb Fix issue where more than the expected number of variants are generated 2022-08-24 16:26:58 -04:00
Sean McLellan
4d72644db4 Housekeeping 2022-08-24 15:54:49 -04:00
Lincoln Stein
7ea168227c Update README.md
Added a few features that were missed in initial 1.09 commit.
2022-08-24 15:35:10 -04:00
Lincoln Stein
ef8ddffe46 updated README 2022-08-24 15:28:19 -04:00
81 changed files with 9769 additions and 3720 deletions


@@ -0,0 +1,32 @@
import argparse

import numpy as np
from PIL import Image


def read_image_int16(image_path):
    image = Image.open(image_path)
    return np.array(image).astype(np.int16)


def calc_images_mean_L1(image1_path, image2_path):
    image1 = read_image_int16(image1_path)
    image2 = read_image_int16(image2_path)
    assert image1.shape == image2.shape
    mean_L1 = np.abs(image1 - image2).mean()
    return mean_L1


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('image1_path')
    parser.add_argument('image2_path')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    mean_L1 = calc_images_mean_L1(args.image1_path, args.image2_path)
    print(mean_L1)

Binary file not shown (new image added, 416 KiB).


@@ -0,0 +1 @@
"a photograph of an astronaut riding a horse" -s50 -S42


@@ -0,0 +1,20 @@
# generate an image
PROMPT_FILE=".dev_scripts/sample_command.txt"
OUT_DIR="outputs/img-samples/test_regression_txt2img_v1_4"
SAMPLES_DIR=${OUT_DIR}
python scripts/dream.py \
--from_file ${PROMPT_FILE} \
--outdir ${OUT_DIR} \
--sampler plms \
--full_precision
# original output by CompVis/stable-diffusion
IMAGE1=".dev_scripts/images/v1_4_astronaut_rides_horse_plms_step50_seed42.png"
# new output
IMAGE2=`ls -A ${SAMPLES_DIR}/*.png | sort | tail -n 1`
echo ""
echo "comparing the following two images"
echo "IMAGE1: ${IMAGE1}"
echo "IMAGE2: ${IMAGE2}"
python .dev_scripts/diff_images.py ${IMAGE1} ${IMAGE2}


@@ -0,0 +1,23 @@
# generate an image
PROMPT="a photograph of an astronaut riding a horse"
OUT_DIR="outputs/txt2img-samples/test_regression_txt2img_v1_4"
SAMPLES_DIR="outputs/txt2img-samples/test_regression_txt2img_v1_4/samples"
python scripts/orig_scripts/txt2img.py \
--prompt "${PROMPT}" \
--outdir ${OUT_DIR} \
--plms \
--ddim_steps 50 \
--n_samples 1 \
--n_iter 1 \
--seed 42
# original output by CompVis/stable-diffusion
IMAGE1=".dev_scripts/images/v1_4_astronaut_rides_horse_plms_step50_seed42.png"
# new output
IMAGE2=`ls -A ${SAMPLES_DIR}/*.png | sort | tail -n 1`
echo ""
echo "comparing the following two images"
echo "IMAGE1: ${IMAGE1}"
echo "IMAGE2: ${IMAGE2}"
python .dev_scripts/diff_images.py ${IMAGE1} ${IMAGE2}

4
.gitattributes vendored Normal file

@@ -0,0 +1,4 @@
# Auto normalizes line endings on commit so devs don't need to change local settings.
# Only affects text files and ignores other file types.
# For more info see: https://www.aleksandrhovhannisyan.com/blog/crlf-vs-lf-normalizing-line-endings-in-git/
* text=auto

36
.github/ISSUE_TEMPLATE/bug_report.md vendored Normal file

@@ -0,0 +1,36 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
**Describe your environment**
- GPU: [cuda/amd/mps/cpu]
- VRAM: [if known]
- CPU arch: [x86/arm]
- OS: [Linux/Windows/macOS]
- Python: [Anaconda/miniconda/miniforge/pyenv/other (explain)]
- Branch: [if `git status` says anything other than "On branch main" paste it here]
- Commit: [run `git show` and paste the line that starts with "Merge" here]
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
**Expected behavior**
A clear and concise description of what you expected to happen.
**Screenshots**
If applicable, add screenshots to help explain your problem.
**Additional context**
Add any other context about the problem here.


@@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.

12
.gitignore vendored

@@ -2,6 +2,9 @@
outputs/
models/ldm/stable-diffusion-v1/model.ckpt
# ignore a directory which serves as a place for initial images
inputs/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -170,6 +173,13 @@ cython_debug/
#.idea/
src
logs/
**/__pycache__/
outputs
# Logs and associated folders
# created from generated embeddings.
logs
testtube
checkpoints
# If it's a Mac
.DS_Store

13
.gitmodules vendored

@@ -1,13 +0,0 @@
[submodule "taming-transformers"]
path = src/taming-transformers
url = https://github.com/CompVis/taming-transformers.git
ignore = dirty
[submodule "clip"]
path = src/clip
url = https://github.com/openai/CLIP.git
ignore = dirty
[submodule "k-diffusion"]
path = src/k-diffusion
url = https://github.com/lstein/k-diffusion.git
ignore = dirty

137
CHANGELOG.md Normal file

@@ -0,0 +1,137 @@
# **Changelog**
## v1.13 (in process)
- Supports a Google Colab notebook for a standalone server running on Google hardware [Arturo Mendivil](https://github.com/artmen1516)
- WebUI supports GFPGAN/ESRGAN facial reconstruction and upscaling [Kevin Gibbons](https://github.com/bakkot)
- WebUI supports incremental display of in-progress images during generation [Kevin Gibbons](https://github.com/bakkot)
- Output directory can be specified on the dream> command line.
- The grid was displaying duplicated images when not enough images to fill the final row [Muhammad Usama](https://github.com/SMUsamaShah)
- Can specify --grid on dream.py command line as the default.
- Miscellaneous internal bug and stability fixes.
---
## v1.12 (28 August 2022)
- Improved file handling, including ability to read prompts from standard input.
(kudos to [Yunsaki](https://github.com/yunsaki))
- The web server is now integrated with the dream.py script. Invoke by adding --web to
the dream.py command arguments.
- Face restoration and upscaling via GFPGAN and Real-ESRGAN are now automatically
enabled if the GFPGAN directory is located as a sibling to Stable Diffusion.
VRAM requirements are modestly reduced. Thanks to both [Blessedcoolant](https://github.com/blessedcoolant) and
[Oceanswave](https://github.com/oceanswave) for their work on this.
- You can now swap samplers on the dream> command line. [Blessedcoolant](https://github.com/blessedcoolant)
---
## v1.11 (26 August 2022)
- NEW FEATURE: Support upscaling and face enhancement using the GFPGAN module. (kudos to [Oceanswave](https://github.com/Oceanswave))
- You now can specify a seed of -1 to use the previous image's seed, -2 to use the seed for the image generated before that, etc.
Seed memory only extends back to the previous command, but will work on all images generated with the -n# switch.
- Variant generation support temporarily disabled pending more general solution.
- Created a feature branch named **yunsaki-morphing-dream** which adds experimental support for
iteratively modifying the prompt and its parameters. Please see [Pull Request #86](https://github.com/lstein/stable-diffusion/pull/86)
for a synopsis of how this works. Note that when this feature is eventually added to the main branch, it may be modified
significantly.
---
## v1.10 (25 August 2022)
- A barebones but fully functional interactive web server for online generation of txt2img and img2img.
---
## v1.09 (24 August 2022)
- A new -v option allows you to generate multiple variants of an initial image
in img2img mode (kudos to [Oceanswave](https://github.com/Oceanswave);
[see this discussion in the PR for examples and details on use](https://github.com/lstein/stable-diffusion/pull/71#issuecomment-1226700810))
- Added ability to personalize text to image generation (kudos to [Oceanswave](https://github.com/Oceanswave) and [nicolai256](https://github.com/nicolai256))
- Enabled all of the samplers from k_diffusion
---
## v1.08 (24 August 2022)
- Escape single quotes on the dream> command before trying to parse. This avoids
parse errors.
- Removed instruction to get Python3.8 as first step in Windows install.
Anaconda3 does it for you.
- Added bounds checks for numeric arguments that could cause crashes.
- Cleaned up the copyright and license agreement files.
---
## v1.07 (23 August 2022)
- Image filenames will now never fill gaps in the sequence, but will be assigned the
next higher name in the chosen directory. This ensures that the alphabetic and chronological
sort orders are the same.
---
## v1.06 (23 August 2022)
- Added weighted prompt support contributed by [xraxra](https://github.com/xraxra)
- Example of using weighted prompts to tweak a demonic figure contributed by [bmaltais](https://github.com/bmaltais)
---
## v1.05 (22 August 2022 - after the drop)
- Filenames now use the following formats:
000010.95183149.png -- Two files produced by the same command (e.g. -n2),
000010.26742632.png -- distinguished by a different seed.
000011.455191342.01.png -- Two files produced by the same command using
000011.455191342.02.png -- a batch size>1 (e.g. -b2). They have the same seed.
000011.4160627868.grid#1-4.png -- a grid of four images (-g); the whole grid can
be regenerated with the indicated key
- It should no longer be possible for one image to overwrite another
- You can use the "cd" and "pwd" commands at the dream> prompt to set and retrieve
the path of the output directory.
---
## v1.04 (22 August 2022 - after the drop)
- Updated README to reflect installation of the released weights.
- Suppressed very noisy and inconsequential warning when loading the frozen CLIP
tokenizer.
---
## v1.03 (22 August 2022)
- The original txt2img and img2img scripts from the CompViz repository have been moved into
a subfolder named "orig_scripts", to reduce confusion.
---
## v1.02 (21 August 2022)
- A copy of the prompt and all of its switches and options is now stored in the corresponding
image in a tEXt metadata field named "Dream". You can read the prompt using scripts/images2prompt.py,
or an image editor that allows you to explore the full metadata.
**Please run "conda env update -f environment.yaml" to load the k_lms dependencies!!**
---
## v1.01 (21 August 2022)
- added k_lms sampling.
**Please run "conda env update -f environment.yaml" to load the k_lms dependencies!!**
- use half precision arithmetic by default, resulting in faster execution and lower memory requirements
Pass argument --full_precision to dream.py to get slower but more accurate image generation
---
## Links
- **[Read Me](readme.md)**

322
README-Mac-MPS.md Normal file

@@ -0,0 +1,322 @@
# macOS Instructions
Requirements
- macOS 12.3 Monterey or later
- Python
- Patience
- Apple Silicon*
*I haven't tested any of this on Intel Macs but I have read that one person got
it to work, so Apple Silicon might not be required.
Things have moved really fast and so these instructions change often and are
often out-of-date. One of the problems is that there are so many different ways to
run this.
We are trying to build a testing setup so that when we make changes it doesn't
always break.
How to (this hasn't been 100% tested yet):
First get the weights checkpoint download started - it's big:
1. Sign up at https://huggingface.co
2. Go to the [Stable diffusion diffusion model page](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original)
3. Accept the terms and click Access Repository:
4. Download [sd-v1-4.ckpt (4.27 GB)](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/blob/main/sd-v1-4.ckpt) and note where you have saved it (probably the Downloads folder)
While that is downloading, open Terminal and run the following commands one at a time.
```
# install brew (and Xcode command line tools):
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
# install python 3, git, cmake, protobuf:
brew install cmake protobuf rust
# install miniconda (M1 arm64 version):
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o Miniconda3-latest-MacOSX-arm64.sh
/bin/bash Miniconda3-latest-MacOSX-arm64.sh
# clone the repo
git clone https://github.com/lstein/stable-diffusion.git
cd stable-diffusion
#
# wait until the checkpoint file has downloaded, then proceed
#
# create symlink to checkpoint
mkdir -p models/ldm/stable-diffusion-v1/
PATH_TO_CKPT="$HOME/Downloads" # or wherever you saved sd-v1-4.ckpt
ln -s "$PATH_TO_CKPT/sd-v1-4.ckpt" models/ldm/stable-diffusion-v1/model.ckpt
# install packages
PIP_EXISTS_ACTION=w CONDA_SUBDIR=osx-arm64 conda env create -f environment-mac.yaml
conda activate ldm
# only need to do this once
python scripts/preload_models.py
# run SD!
python scripts/dream.py --full_precision # half-precision requires autocast and won't work
```
The original scripts should work as well.
```
python scripts/orig_scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms
```
Note, `export PIP_EXISTS_ACTION=w` is a precaution to fix `conda env create -f environment-mac.yaml`
never finishing in some situations. It isn't strictly required, but it won't hurt.
After you follow all the instructions and run dream.py you might get several
errors. Here are the errors I've seen and found solutions for.
### Is it slow?
Be sure to specify 1 sample and 1 iteration.
```
python ./scripts/orig_scripts/txt2img.py --prompt "ocean" --ddim_steps 5 --n_samples 1 --n_iter 1
```
### Doesn't work anymore?
PyTorch nightly includes support for MPS. Because of this, this setup is
inherently unstable. One morning I woke up and it no longer worked no matter
what I did until I switched to miniforge. However, I have another Mac that works
just fine with Anaconda. If you can't get it to work, please search a little
first because many of the errors will get posted and solved. If you can't find
a solution please [create an issue](https://github.com/lstein/stable-diffusion/issues).
One debugging step is to update to the latest version of PyTorch nightly.
```
conda install pytorch torchvision torchaudio -c pytorch-nightly
```
If `conda env create -f environment-mac.yaml` takes forever, run this:
```
git clean -f
```
And run this:
```
conda clean --yes --all
```
Or you could reset Anaconda:
```
conda update --force-reinstall -y -n base -c defaults conda
```
### "No module named cv2", torch, 'ldm', 'transformers', 'taming', etc.
There are several causes of these errors.
First, did you remember to `conda activate ldm`? If your terminal prompt
begins with "(ldm)" then you activated it. If it begins with "(base)"
or something else you haven't.
Second, you might've run `./scripts/preload_models.py` or `./scripts/dream.py`
instead of `python ./scripts/preload_models.py` or `python ./scripts/dream.py`.
The cause of this error is long so it's below.
Third, if it says you're missing taming you need to rebuild your virtual
environment.
```
conda env remove -n ldm
conda env create -f environment-mac.yaml
```
Fourth, if you have activated the ldm virtual environment and tried rebuilding
it, maybe the problem is that I have something installed that
you don't and you'll just need to manually install it. Make sure you
activate the virtual environment so it installs there instead of
globally.
```
conda activate ldm
pip install <name of missing package>
```
You might also need to install Rust (I mention this again below).
### How many snakes are living in your computer?
Here's the reason why you have to specify which python to use.
There are several versions of python on macOS and the computer is
picking the wrong one. More specifically, preload_models.py and dream.py say to
find the first `python3` in the path environment variable. You can see which one
it is picking with `which python3`. These are the most likely paths you'll see.
```
% which python3
/usr/bin/python3
```
The above path is part of the OS. However, that path is a stub that asks you if
you want to install Xcode. If you have Xcode installed already,
/usr/bin/python3 will execute /Library/Developer/CommandLineTools/usr/bin/python3 or
/Applications/Xcode.app/Contents/Developer/usr/bin/python3 (depending on which
Xcode you've selected with `xcode-select`).
```
% which python3
/opt/homebrew/bin/python3
```
If you installed python3 with Homebrew and you've modified your path to search
for Homebrew binaries before system ones, you'll see the above path.
```
% which python
/opt/anaconda3/bin/python
```
If you drop the "3" you get an entirely different python. Note: starting in
macOS 12.3, /usr/bin/python no longer exists (it was python 2 anyway).
If you have Anaconda installed, this is what you'll see. There is a
/opt/anaconda3/bin/python3 also.
```
(ldm) % which python
/Users/name/miniforge3/envs/ldm/bin/python
```
This is what you'll see if you have miniforge and you've correctly activated
the ldm environment. This is the goal.
It's all a mess and you should know [how to modify the path environment variable](https://support.apple.com/guide/terminal/use-environment-variables-apd382cc5fa-4f58-4449-b20a-41c53c006f8f/mac)
if you want to fix it. Here's a brief hint of all the ways you can modify it
(I don't really have the time to explain it all here):
- ~/.zshrc
- ~/.bash_profile
- ~/.bashrc
- /etc/paths.d
- /etc/paths
Which one you use will depend on what you have installed; putting a file
in /etc/paths.d is what I prefer to do.
### Debugging?
Tired of waiting for your renders to finish before you can see if it
works? Reduce the steps! The image quality will be horrible but at least you'll
get quick feedback.
```
python ./scripts/txt2img.py --prompt "ocean" --ddim_steps 5 --n_samples 1 --n_iter 1
```
### OSError: Can't load tokenizer for 'openai/clip-vit-large-patch14'...
Run the model preload step again:
```
python scripts/preload_models.py
```
### "The operator [name] is not current implemented for the MPS device." (sic)
Example error.
```
...
NotImplementedError: The operator 'aten::_index_put_impl_' is not current implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on [https://github.com/pytorch/pytorch/issues/77764](https://github.com/pytorch/pytorch/issues/77764). As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```
The lstein branch includes this fix in [environment-mac.yaml](https://github.com/lstein/stable-diffusion/blob/main/environment-mac.yaml).
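If you are running a branch that does not yet include that fix, the workaround suggested by the error message itself is to enable the CPU fallback before launching (slower, but it keeps going):
```
export PYTORCH_ENABLE_MPS_FALLBACK=1
python scripts/dream.py --full_precision
```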
### "Could not build wheels for tokenizers"
I have not seen this error because I had Rust installed on my computer before I started playing with Stable Diffusion. The fix is to install Rust.
```
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
```
### How come `--seed` doesn't work?
First this:
> Completely reproducible results are not guaranteed across PyTorch
releases, individual commits, or different platforms. Furthermore,
results may not be reproducible between CPU and GPU executions, even
when using identical seeds.
[PyTorch docs](https://pytorch.org/docs/stable/notes/randomness.html)
Second, we might have a fix that at least gets a mostly consistent seed. We're
still working on it.
### libiomp5.dylib error?
```
OMP: Error #15: Initializing libiomp5.dylib, but found libomp.dylib already initialized.
```
You are likely using an Intel package by mistake. Be sure to run conda with
the environment variable `CONDA_SUBDIR=osx-arm64`, like so:
`CONDA_SUBDIR=osx-arm64 conda install ...`
This error happens with Anaconda on Macs when the Intel-only `mkl` is pulled in by
a dependency. [nomkl](https://stackoverflow.com/questions/66224879/what-is-the-nomkl-python-package-used-for)
is a metapackage designed to prevent this, by making it impossible to install
`mkl`, but if your environment is already broken it may not work.
Do *not* use `os.environ['KMP_DUPLICATE_LIB_OK']='True'` or equivalents as this
masks the underlying issue of using Intel packages.
### Not enough memory.
This seems to be a common problem and is probably the underlying
problem for a lot of symptoms (listed below). The fix is to lower your
image size or to add `model.half()` right after the model is loaded. I
should probably test it out. I've read that the reason this fixes
problems is because it converts the model from 32-bit to 16-bit and
that leaves more RAM for other things. I have no idea how that would
affect the quality of the images though.
See [this issue](https://github.com/CompVis/stable-diffusion/issues/71).
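If you want to experiment with the `model.half()` idea, the general PyTorch pattern looks like the sketch below. It is untested here, and `load_model()` is just a stand-in for however your setup actually builds the model:
```python
import torch

model = load_model()     # hypothetical loader -- substitute your own model-loading code
model = model.half()     # cast the weights from 32-bit to 16-bit floats, roughly halving RAM use
model = model.to("mps")  # keep the model on the Apple GPU
```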
### "Error: product of dimension sizes > 2**31'"
This error happens with img2img, which I haven't played with too much
yet. But I know it's because your image is too big or the resolution
isn't a multiple of 32x32. Because the stable-diffusion model was
trained on images that were 512 x 512, it's always best to use that
output size (which is the default). However, if you're using that size
and you get the above error, try 256 x 256 or 512 x 256 or something
as the source image.
BTW, 2**31-1 = [2,147,483,647](https://en.wikipedia.org/wiki/2,147,483,647#In_computing), which is also 32-bit signed [LONG_MAX](https://en.wikipedia.org/wiki/C_data_types) in C.
### I just got Rickrolled! Do I have a virus?
You don't have a virus. It's part of the project. Here's
[Rick](https://github.com/lstein/stable-diffusion/blob/main/assets/rick.jpeg)
and here's [the
code](https://github.com/lstein/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/scripts/txt2img.py#L79)
that swaps him in. It's an NSFW filter which, IMO, doesn't work very
well (and we call this "computer vision", sheesh).
Actually, this could be happening because there's not enough RAM. You could try the `model.half()` suggestion or specify smaller output images.
### My images come out black
We might have this fixed, we are still testing.
There's a [similar issue](https://github.com/CompVis/stable-diffusion/issues/69)
on CUDA GPU's where the images come out green. Maybe it's the same issue?
Someone in that issue says to use "--precision full", but this fork
actually disables that flag. I don't know why, someone else provided
that code and I don't know what it does. Maybe the `model.half()`
suggestion above would fix this issue too. I should probably test it.
### "view size is not compatible with input tensor's size and stride"
```
File "/opt/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/nn/functional.py", line 2511, in layer_norm
return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
```
Update to the latest version of lstein/stable-diffusion. We were
patching pytorch but we found a file in stable-diffusion that we could
change instead. This is a 32-bit vs 16-bit problem.
### The processor must support the Intel bla bla bla
What? Intel? On an Apple Silicon?
```
Intel MKL FATAL ERROR: This system does not meet the minimum requirements for use of the Intel(R) Math Kernel Library.
The processor must support the Intel(R) Supplemental Streaming SIMD Extensions 3 (Intel(R) SSSE3) instructions.
The processor must support the Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) instructions.
The processor must support the Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
```
This is due to the Intel `mkl` package getting picked up when you try to install
something that depends on it-- Rosetta can translate some Intel instructions but
not the specialized ones here. To avoid this, make sure to use the environment
variable `CONDA_SUBDIR=osx-arm64`, which restricts the Conda environment to only
use ARM packages, and use `nomkl` as described above.

README.md

@@ -1,19 +1,71 @@
<h1 align='center'><b>Stable Diffusion Dream Script</b></h1>
<p align='center'>
<img src="static/logo_temp.png"/>
</p>
<p align="center">
<img src="https://img.shields.io/github/last-commit/lstein/stable-diffusion?logo=Python&logoColor=green&style=for-the-badge" alt="last-commit"/>
<img src="https://img.shields.io/github/stars/lstein/stable-diffusion?logo=GitHub&style=for-the-badge" alt="stars"/>
<br>
<img src="https://img.shields.io/github/issues/lstein/stable-diffusion?logo=GitHub&style=for-the-badge" alt="issues"/>
<img src="https://img.shields.io/github/issues-pr/lstein/stable-diffusion?logo=GitHub&style=for-the-badge" alt="pull-requests"/>
</p>
This is a fork of CompVis/stable-diffusion, the wonderful open source
text-to-image generator. This fork supports:
1. An interactive command-line interface that accepts the same prompt
and switches as the Discord bot.
2. A basic Web interface that allows you to run a local web server for
generating images in your browser.
3. Support for img2img in which you provide a seed image to guide the
image creation. (inpainting & masking coming soon)
4. A notebook for running the code on Google Colab.
5. Upscaling and face fixing using the optional ESRGAN and GFPGAN
packages.
6. Weighted subprompts for prompt tuning.
7. [Image variations](VARIATIONS.md) which allow you to systematically
generate variations of an image you like and combine two or more
images together to combine the best features of both.
8. Textual inversion for customization of the prompt language and images.
9. ...and more!
This fork is rapidly evolving, so use the Issues panel to report bugs
and make feature requests, and check back periodically for
improvements and bug fixes.
# Table of Contents
1. [Major Features](#features)
2. [Changelog](#latest-changes)
3. [Installation](#installation)
1. [Linux](#linux)
1. [Windows](#windows)
1. [MacOS](README-Mac-MPS.md)
4. [Troubleshooting](#troubleshooting)
5. [Contributing](#contributing)
6. [Support](#support)
# Features
## Interactive command-line interface similar to the Discord bot
The _dream.py_ script, located in scripts/dream.py,
provides an interactive interface to image generation similar to
the "dream mothership" bot that Stable AI provided on its Discord
server. Unlike the txt2img.py and img2img.py scripts provided in the
original CompViz/stable-diffusion source code repository, the
time-consuming initialization of the AI model only happens once. After that, image generation
from the command-line interface is very fast.
The script uses the readline library to allow for in-line editing,
@@ -27,17 +79,11 @@ The script is confirmed to work on Linux and Windows systems. It should
work on MacOSX as well, but this is not confirmed. Note that this script
runs from the command-line (CMD or Terminal window), and does not have a GUI.
```
(ldm) ~/stable-diffusion$ python3 ./scripts/dream.py
* Initializing, be patient...
Loading model from models/ldm/text2img-large/model.ckpt
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 872.30 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
Loading Bert tokenizer from "models/bert"
setting sampler to plms
(...more initialization messages...)
* Initialization done! Awaiting your command...
dream> ashley judd riding a camel -n2 -s150
@@ -55,13 +101,17 @@ dream> q
00009.png: "ashley judd riding a camel" -s150 -S 416354203
00010.png: "ashley judd riding a camel" -s150 -S 1362479620
00011.png: "there's a fly in my soup" -n6 -g -S 2685670268
```
<p align='center'>
<img src="static/dream-py-demo.png"/>
</p>
The dream> prompt's arguments are pretty much identical to those used
in the Discord bot, except you don't need to type "!dream" (it doesn't
hurt if you do). A significant change is that creation of individual
images is now the default unless --grid (-g) is given. For backward
compatibility, the -i switch is recognized. For command-line help
type -h (or --help) at the dream> prompt.
The script itself also recognizes a series of command-line switches
@@ -73,26 +123,207 @@ image outputs and the location of the model weight files.
This script also provides an img2img feature that lets you seed your
creations with a drawing or photo. This is a really cool feature that tells
stable diffusion to build the prompt on top of the image you provide, preserving
the original's basic shape and layout. To use it, provide the --init_img
option as shown here:
```
dream> "waterfall and rainbow" --init_img=./init-images/crude_drawing.png --strength=0.5 -s100 -n4
```
The --init_img (-I) option gives the path to the seed picture. --strength (-f) controls how much
the original will be modified, ranging from 0.0 (keep the original intact), to 1.0 (ignore the original
completely). The default is 0.75, and ranges from 0.25-0.75 give interesting results.
You may also pass a -v<count> option to generate <count> variants of the original image. This is done by
passing the first generated image back into img2img the requested number of times. It generates
interesting variants; see the example below.
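For example, to make three variants of the crude drawing used above, you might type something like this (an illustrative command, not output from an actual session):
```
dream> "waterfall and rainbow" --init_img=./init-images/crude_drawing.png --strength=0.5 -s100 -v3
```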
## GFPGAN and Real-ESRGAN Support
The script also provides the ability to do face restoration and
upscaling with the help of GFPGAN and Real-ESRGAN respectively.
To use the ability, clone the **[GFPGAN
repository](https://github.com/TencentARC/GFPGAN)** and follow their
installation instructions. By default, we expect GFPGAN to be
installed in a 'GFPGAN' sibling directory. Be sure that the `"ldm"`
conda environment is active as you install GFPGAN.
You can use the `--gfpgan_dir` argument with `dream.py` to set a
custom path to your GFPGAN directory. _There are other GFPGAN related
boot arguments if you wish to customize further._
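For example, if you cloned GFPGAN somewhere other than the default sibling directory, the launch might look something like this (the path is just an illustration):
```
(ldm) ~/stable-diffusion$ python3 scripts/dream.py --gfpgan_dir ../GFPGAN
```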
You can install **Real-ESRGAN** by typing the following command.
```
pip install realesrgan
```
**Note: Internet connection needed:**
Users whose GPU machines are isolated from the Internet (e.g. on a
University cluster) should be aware that the first time you run
dream.py with GFPGAN and Real-ESRGAN turned on, it will try to
download model files from the Internet. To rectify this, you may run
`python3 scripts/preload_models.py` after you have installed GFPGAN
and all its dependencies.
**Usage**
You will now have access to two new prompt arguments.
**Upscaling**
`-U : <upscaling_factor> <upscaling_strength>`
The upscaling prompt argument takes two values. The first value is a
scaling factor and should be set to either `2` or `4` only. This will
either scale the image 2x or 4x respectively using different models.
You can set the scaling strength between `0` and `1.0` to control the
intensity of the scaling. This is handy because AI upscalers
generally tend to smooth out texture details. If you wish to retain
some of those for natural-looking results, we recommend using values
between `0.5` and `0.8`.
If you do not explicitly specify an upscaling_strength, it will
default to 0.75.
**Face Restoration**
`-G : <gfpgan_strength>`
This prompt argument controls the strength of the face restoration
that is being applied. Similar to upscaling, values between `0.5` and `0.8` are recommended.
You can use either one or both without any conflicts. In cases where
you use both, the image will be first upscaled and then the face
restoration process will be executed to ensure you get the highest
quality facial features.
`--save_orig`
When you use either `-U` or `-G`, the final result you get is upscaled
or face modified. If you want to save the original Stable Diffusion
generation, you can use the `--save_orig` prompt argument to save the
original unaffected version too.
**Example Usage**
```
dream > superman dancing with a panda bear -U 2 0.6 -G 0.4
```
This also works with img2img:
```
dream> a man wearing a pineapple hat -I path/to/your/file.png -U 2 0.5 -G 0.6
```
**Note**
GFPGAN and Real-ESRGAN are both memory intensive. In order to avoid
crashes and memory overloads during the Stable Diffusion process,
these effects are applied after Stable Diffusion has completed its
work.
In single image generations, you will see the output right away but
when you are using multiple iterations, the images will first be
generated and then upscaled and face restored after that process is
complete. While the image generation is taking place, you will still
be able to preview the base images.
If you wish to stop during the image generation but want to upscale or
face restore a particular generated image, pass it again with the same
prompt and generated seed along with the `-U` and `-G` prompt
arguments to perform those actions.
## Google Colab
Stable Diffusion AI Notebook: <a href="https://colab.research.google.com/github/lstein/stable-diffusion/blob/main/Stable_Diffusion_AI_Notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> <br>
Open and follow instructions to use an isolated environment running Dream.<br>
Output example:
![Colab Notebook](static/colab_notebook.png)
## Barebones Web Server
As of version 1.10, this distribution comes with a bare bones web
server (see screenshot). To use it, run the _dream.py_ script by
adding the **--web** option.
```
(ldm) ~/stable-diffusion$ python3 scripts/dream.py --web
```
You can then connect to the server by pointing your web browser at
http://localhost:9090, or to the network name or IP address of the server.
Kudos to [Tesseract Cat](https://github.com/TesseractCat) for
contributing this code, and to [dagf2101](https://github.com/dagf2101)
for refining it.
![Dream Web Server](static/dream_web_server.png)
## Reading Prompts from a File
You can automate dream.py by providing a text file with the prompts
you want to run, one line per prompt. The text file must be composed
with a text editor (e.g. Notepad) and not a word processor. Each line
should look like what you would type at the dream> prompt:
```
a beautiful sunny day in the park, children playing -n4 -C10
stormy weather on a mountain top, goats grazing -s100
innovative packaging for a squid's dinner -S137038382
```
Then pass this file's name to dream.py when you invoke it:
```
(ldm) ~/stable-diffusion$ python3 scripts/dream.py --from_file "path/to/prompts.txt"
```
You may read a series of prompts from standard input by providing a filename of "-":
```
(ldm) ~/stable-diffusion$ echo "a beautiful day" | python3 scripts/dream.py --from_file -
```
## Shortcut for reusing seeds from the previous command
Since it is so common to reuse seeds while refining a prompt, there is
now a shortcut as of version 1.11. Provide a **-S** (or **--seed**)
switch of -1 to use the seed of the most recent image generated. If
you produced multiple images with the **-n** switch, then you can go
back further using -2, -3, etc. up to the first image generated by the
previous command. Sorry, but you can't go back further than one
command.
Here's an example of using this to do a quick refinement. It also
illustrates using the new **-G** switch to turn on face
enhancement (see the previous section):
```
dream> a cute child playing hopscotch -G0.5
[...]
outputs/img-samples/000039.3498014304.png: "a cute child playing hopscotch" -s50 -W512 -H512 -C7.5 -mk_lms -S3498014304
# I wonder what it will look like if I bump up the steps and set facial enhancement to full strength?
dream> a cute child playing hopscotch -G1.0 -s100 -S -1
reusing previous seed 3498014304
[...]
outputs/img-samples/000040.3498014304.png: "a cute child playing hopscotch" -G1.0 -s100 -W512 -H512 -C7.5 -mk_lms -S3498014304
```
## Weighted Prompts
You may weight different sections of the prompt to tell the sampler to attach different levels of
priority to them, by adding :(number) to the end of the section you wish to up- or downweight.
For example consider this prompt:
```
tabby cat:0.25 white duck:0.75 hybrid
```
This will tell the sampler to invest 25% of its effort on the tabby
cat aspect of the image and 75% on the white duck aspect
@@ -107,8 +338,10 @@ and introducing a new vocabulary to the fixed model.
To train, prepare a folder that contains images sized at 512x512 and execute the following:
WINDOWS: As the default backend is not available on Windows, if you're using that platform, set the environment variable `PL_TORCH_DISTRIBUTED_BACKEND=gloo`
```
(ldm) ~/stable-diffusion$ python3 ./main.py --base ./configs/stable-diffusion/v1-finetune.yaml \
-t \
--actual_resume ./models/ldm/stable-diffusion-v1/model.ckpt \
@@ -116,151 +349,118 @@ To train, prepare a folder that contains images sized at 512x512 and execute the
--gpus 0, \
--data_root D:/textual-inversion/my_cat \
--init_word 'cat'
```
During the training process, files will be created in /logs/[project][time][project]/, where you can monitor progress:
- conditioning\* contains the training prompts
- inputs, reconstruction contain the input images for the training epoch
- samples, samples scaled contain a sample of the prompt and one with the init word provided
On a RTX3090, the process for SD will take ~1h @1.6 iterations/sec.
Note: According to the associated paper, the optimal number of images
is 3-5. Your model may not converge if you use more images than that.
Training will run indefinitely, but you may wish to stop it before the
heat death of the universe, when you find a low loss epoch or around
~5000 iterations.
Once the model is trained, specify the trained .pt file when starting
dream using
```
(ldm) ~/stable-diffusion$ python3 ./scripts/dream.py --embedding_path /path/to/embedding.pt --full_precision
```
Then, to utilize your subject at the dream prompt
```
dream> "a photo of *"
```
this also works with image2image
```
dream> "waterfall and rainbow in the style of *" --init_img=./init-images/crude_drawing.png --strength=0.5 -s100 -n4
```
It's also possible to train multiple tokens (modify the placeholder string in configs/stable-diffusion/v1-finetune.yaml) and combine LDM checkpoints using:
```
(ldm) ~/stable-diffusion$ python3 ./scripts/merge_embeddings.py \
--manager_ckpts /path/to/first/embedding.pt /path/to/second/embedding.pt [...] \
--output_path /path/to/output/embedding.pt
```
Credit goes to @rinongal and the repository located at
https://github.com/rinongal/textual_inversion. Please see the
repository and associated paper for details and limitations.
# Latest Changes

- v1.13 (in process)
  - Supports a Google Colab notebook for a standalone server running on Google hardware [Arturo Mendivil](https://github.com/artmen1516)
  - WebUI supports GFPGAN/ESRGAN facial reconstruction and upscaling [Kevin Gibbons](https://github.com/bakkot)
  - WebUI supports incremental display of in-progress images during generation [Kevin Gibbons](https://github.com/bakkot)
  - Can specify --grid on dream.py command line as the default.
  - Miscellaneous internal bug and stability fixes.
  - Works on M1 Apple hardware.
  - Multiple bug fixes.
- v1.08 (24 August 2022)
  - Escape single quotes on the dream> command before trying to parse. This avoids parse errors.
  - A new -v option allows you to generate multiple variants of an initial image in img2img mode. (kudos to Oceanswave)
  - Removed instruction to get Python3.8 as first step in Windows install. Anaconda3 does it for you.
  - Added bounds checks for numeric arguments that could cause crashes.
  - Cleaned up the copyright and license agreement files.
- v1.07 (23 August 2022)
  - Image filenames will now never fill gaps in the sequence, but will be assigned the next higher name in the chosen directory. This ensures that the alphabetic and chronological sort orders are the same.
- v1.06 (23 August 2022)
  - Added weighted prompt support contributed by [xraxra](https://github.com/xraxra)
  - Example of using weighted prompts to tweak a demonic figure contributed by [bmaltais](https://github.com/bmaltais)
- v1.05 (22 August 2022 - after the drop)
  - Filenames now use the following formats:
    000010.95183149.png -- Two files produced by the same command (e.g. -n2),
    000010.26742632.png -- distinguished by a different seed.
    000011.455191342.01.png -- Two files produced by the same command using
    000011.455191342.02.png -- a batch size>1 (e.g. -b2). They have the same seed.
    000011.4160627868.grid#1-4.png -- a grid of four images (-g); the whole grid can
    be regenerated with the indicated key
  - It should no longer be possible for one image to overwrite another
  - You can use the "cd" and "pwd" commands at the dream> prompt to set and retrieve the path of the output directory.
- v1.04 (22 August 2022 - after the drop)
  - Updated README to reflect installation of the released weights.
  - Suppressed very noisy and inconsequential warning when loading the frozen CLIP tokenizer.
- v1.03 (22 August 2022)
  - The original txt2img and img2img scripts from the CompViz repository have been moved into a subfolder named "orig_scripts", to reduce confusion.
- v1.02 (21 August 2022)
  - A copy of the prompt and all of its switches and options is now stored in the corresponding image in a tEXt metadata field named "Dream". You can read the prompt using scripts/images2prompt.py, or an image editor that allows you to explore the full metadata.
    **Please run "conda env update -f environment.yaml" to load the k_lms dependencies!!**
- v1.01 (21 August 2022)
  - added k_lms sampling.
    **Please run "conda env update -f environment.yaml" to load the k_lms dependencies!!**
  - use half precision arithmetic by default, resulting in faster execution and lower memory requirements. Pass argument --full_precision to dream.py to get slower but more accurate image generation.

For older changelogs, please visit **[CHANGELOGS](CHANGELOG.md)**.

# Installation

There are separate installation walkthroughs for [Linux](#linux), [Windows](#windows) and [Macintosh](#macintosh).

## Linux
1. You will need to install the following prerequisites if they are not already available. Use your
operating system's preferred installer
- Python (version 3.8.5 recommended; higher may work)
- git
2. Install the Python Anaconda environment manager using pip3.
```
~$ pip3 install anaconda
```
After installing anaconda, you should log out of your system and log back in. If the installation
worked, your command prompt will be prefixed by the name of the current anaconda environment, "(base)".
3. Copy the stable-diffusion source code from GitHub:
```
(base) ~$ git clone https://github.com/lstein/stable-diffusion.git
```
This will create a stable-diffusion folder where you will follow the rest of the steps.
4. Enter the newly-created stable-diffusion folder. From this step forward make sure that you are working in the stable-diffusion directory!
```
(base) ~$ cd stable-diffusion
(base) ~/stable-diffusion$
```
5. Use anaconda to copy necessary python packages, create a new python environment named "ldm",
and activate the environment.
```
(base) ~/stable-diffusion$ conda env create -f environment.yaml
(base) ~/stable-diffusion$ conda activate ldm
(ldm) ~/stable-diffusion$
```
After these steps, your command prompt will be prefixed by "(ldm)" as shown above.
6. Load a couple of small machine-learning models required by stable diffusion:
```
(ldm) ~/stable-diffusion$ python3 scripts/preload_models.py
```
@@ -280,13 +480,14 @@ to a page that prompts you to click the "download" link. Save the file somewhere
Now run the following commands from within the stable-diffusion directory. This will create a symbolic
link from the stable-diffusion model.ckpt file, to the true location of the sd-v1-4.ckpt file.
```
(ldm) ~/stable-diffusion$ mkdir -p models/ldm/stable-diffusion-v1
(ldm) ~/stable-diffusion$ ln -sf /path/to/sd-v1-4.ckpt models/ldm/stable-diffusion-v1/model.ckpt
```
8. Start generating images!
```
# for the pre-release weights use the -l or --liaon400m switch
(ldm) ~/stable-diffusion$ python3 scripts/dream.py -l
@@ -297,18 +498,45 @@ link from the stable-diffusion model.ckpt file, to the true location of the sd-v
# for additional configuration switches and arguments, use -h or --help
(ldm) ~/stable-diffusion$ python3 scripts/dream.py -h
```
9. Subsequently, to relaunch the script, be sure to run "conda activate ldm" (step 5, second command), enter the "stable-diffusion"
directory, and then launch the dream script (step 8). If you forget to activate the ldm environment, the script will fail with multiple ModuleNotFound errors.
### Updating to newer versions of the script
This distribution is changing rapidly. If you used the "git clone" method (step 3) to download the stable-diffusion directory, then to update to the latest and greatest version, launch the Anaconda window, enter "stable-diffusion", and type:
```
(ldm) ~/stable-diffusion$ git pull
```
This will bring your local copy into sync with the remote one.
## Windows
### Notebook install (semi-automated)
We have a
[Jupyter notebook](https://github.com/lstein/stable-diffusion/blob/main/Stable-Diffusion-local-Windows.ipynb)
with cell-by-cell installation steps. It will download the code in this repo as
one of the steps, so instead of cloning this repo, simply download the notebook
from the link above and load it up in VSCode (with the
appropriate extensions installed)/Jupyter/JupyterLab and start running the cells one-by-one.
Note that you will need NVIDIA drivers, Python 3.10, and Git installed
beforehand - simplified
[step-by-step instructions](https://github.com/lstein/stable-diffusion/wiki/Easy-peasy-Windows-install)
are available in the wiki (you'll only need steps 1, 2, & 3 ).
### Manual installs
#### pip
See
[Easy-peasy Windows install](https://github.com/lstein/stable-diffusion/wiki/Easy-peasy-Windows-install)
in the wiki
#### Conda
1. Install Anaconda3 (miniconda3 version) from here: https://docs.anaconda.com/anaconda/install/windows/
@@ -317,24 +545,30 @@ This will bring your local copy into sync with the remote one.
3. Launch Anaconda from the Windows Start menu. This will bring up a command window. Type all the remaining commands in this window.
4. Run the command:
```
git clone https://github.com/lstein/stable-diffusion.git
```
This will create a stable-diffusion folder where you will follow the rest of the steps.
5. Enter the newly-created stable-diffusion folder. From this step forward make sure that you are working in the stable-diffusion directory!
```
cd stable-diffusion
```
6. Run the following two commands:
```
conda env create -f environment.yaml (step 6a)
conda activate ldm (step 6b)
```
This will install all python requirements and activate the "ldm" environment which sets PATH and other environment variables properly.
7. Run the command:
```
python scripts\preload_models.py
```
@@ -347,30 +581,32 @@ downloaded just-in-time)
8. Now you need to install the weights for the big stable diffusion model.
For running with the released weights, you will first need to set up
an account with Hugging Face (https://huggingface.co). Use your
credentials to log in, and then point your browser at
https://huggingface.co/CompVis/stable-diffusion-v-1-4-original. You
may be asked to sign a license agreement at this point.
Click on "Files and versions" near the top of the page, and then click
on the file named "sd-v1-4.ckpt". You'll be taken to a page that
prompts you to click the "download" link. Now save the file somewhere
safe on your local machine. The weight file is >4 GB in size, so
downloading may take a while.
Now run the following commands from **within the stable-diffusion
directory** to copy the weights file to the right place:
```
mkdir -p models\ldm\stable-diffusion-v1
copy C:\path\to\sd-v1-4.ckpt models\ldm\stable-diffusion-v1\model.ckpt
```
Please replace "C:\path\to\sd-v1.4.ckpt" with the correct path to wherever
you stashed this file. If you prefer not to copy or move the .ckpt file,
you stashed this file. If you prefer not to copy or move the .ckpt file,
you may instead create a shortcut to it from within
"models\ldm\stable-diffusion-v1\".
9. Start generating images!
```
# for the pre-release weights
python scripts\dream.py -l
@@ -378,38 +614,57 @@ python scripts\dream.py -l
# for the post-release weights
python scripts\dream.py
```
10. Subsequently, to relaunch the script, first activate the Anaconda
command window (step 3), enter the stable-diffusion directory (step 5,
"cd \path\to\stable-diffusion"), run "conda activate ldm" (step 6b),
and then launch the dream script (step 9).
**Note:** Tildebyte has written an alternative ["Easy peasy Windows
install"](https://github.com/lstein/stable-diffusion/wiki/Easy-peasy-Windows-install)
which uses the Windows Powershell and pew. If you are having trouble
with Anaconda on Windows, give this a try (or try it first!)
### Updating to newer versions of the script
This distribution is changing rapidly. If you used the "git clone"
method (step 4) to download the stable-diffusion directory, then to
update to the latest and greatest version, launch the Anaconda window,
enter "stable-diffusion", and type:
```
git pull
```
This will bring your local copy into sync with the remote one.
## Macintosh
See [README-Mac-MPS](README-Mac-MPS.md) for instructions.
# Simplified API for text to image generation
For programmers who wish to incorporate stable-diffusion into other
products, this repository includes a simplified API for text to image
generation, which lets you create images from a prompt in just three
lines of code:
```
from ldm.simplet2i import T2I
model = T2I()
outputs = model.txt2img("a unicorn in manhattan")
```
Outputs is a list of lists in the format [[filename1,seed1],[filename2,seed2]...]
Please see ldm/simplet2i.py for more information. A set of example scripts is
coming RSN.
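As a quick usage sketch of the return format described above (each entry is a `[filename, seed]` pair):
```python
from ldm.simplet2i import T2I

model = T2I()
outputs = model.txt2img("a unicorn in manhattan")

# each entry is [filename, seed], per the format described above
for filename, seed in outputs:
    print(f"{filename} was generated with seed {seed}")
```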
# Workaround for machines with limited internet connectivity
My development machine is a GPU node in a high-performance compute
cluster which has no connection to the internet. During model
initialization, stable-diffusion tries to download the Bert tokenizer
and a file needed by the kornia library. This obviously didn't work
for me.
To work around this, I have modified ldm/modules/encoders/modules.py
@@ -420,7 +675,7 @@ prior to running the code on an isolated one. This assumes that both
machines share a common network-mounted filesystem with a common
.cache directory.
```
(ldm) ~/stable-diffusion$ python3 ./scripts/preload_models.py
preloading bert tokenizer...
Downloading: 100%|██████████████████████████████████| 28.0/28.0 [00:00<00:00, 49.3kB/s]
@@ -432,30 +687,113 @@ preloading kornia requirements...
Downloading: "https://github.com/DagnyT/hardnet/raw/master/pretrained/train_liberty_with_aug/checkpoint_liberty_with_aug.pth" to /u/lstein/.cache/torch/hub/checkpoints/checkpoint_liberty_with_aug.pth
100%|███████████████████████████████████████████████| 5.10M/5.10M [00:00<00:00, 101MB/s]
...success
```
If you don't need this change and want to download the files just in
time, copy over the file ldm/modules/encoders/modules.py from the
CompVis/stable-diffusion repository. Or you can run preload_models.py
on the target machine.
# Troubleshooting
Here are a few common installation problems and their solutions. Often
these are caused by incomplete installations or crashes during the
install process.
- PROBLEM: During "conda env create -f environment.yaml", conda
hangs indefinitely.
- SOLUTION: Enter the stable-diffusion directory and completely
remove the "src" directory and all its contents. The safest way
to do this is to enter the stable-diffusion directory and
give the command "git clean -f". If this still doesn't fix
the problem, try "conda clean -all" and then restart at the
"conda env create" step.
---
- PROBLEM: dream.py crashes with the complaint that it can't find
ldm.simplet2i.py. Or it complains that a function is being passed
incorrect parameters.
- SOLUTION: Reinstall the stable diffusion modules. Enter the
stable-diffusion directory and give the command "pip install -e ."
---
- PROBLEM: dream.py dies, complaining of various missing modules, none
of which starts with "ldm".
- SOLUTION: From within the stable-diffusion directory, run "conda env
update -f environment.yaml" This is also frequently the solution to
complaints about an unknown function in a module.
---
- PROBLEM: There's a feature or bugfix in the Stable Diffusion GitHub
that you want to try out.
- SOLUTION: If the fix/feature is on the "main" branch, enter the stable-diffusion
directory and do a "git pull". Usually this will be sufficient, but if
you start to see errors about missing or incorrect modules, use the
command "pip install -e ." and/or "conda env update -f environment.yaml"
(These commands won't break anything.)
- If the feature/fix is on a branch (e.g. "foo-bugfix"), the recipe is similar, but
do a "git pull <name of branch>".
- If the feature/fix is in a pull request that has not yet been made
part of the main branch or a feature/bugfix branch, then from the page
for the desired pull request, look for the line at the top that reads
"xxxx wants to merge xx commits into lstein:main from YYYYYY". Copy
the URL in YYYYYY. It should have the format
https://github.com/<name of contributor>/stable-diffusion/tree/<name
of branch>
- Then **go to the directory above stable-diffusion**, and rename the
directory to "stable-diffusion.lstein", "stable-diffusion.old", or
whatever. You can then git clone the branch that contains the
pull request:
```
git clone -b <name of branch> https://github.com/<name of contributor>/stable-diffusion.git
```
You will need to go through the install procedure again, but it should
be fast because all the dependencies are already loaded.
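Alternatively, if you would rather not re-clone, you can usually fetch a pull request directly into your existing checkout using GitHub's pull-request refs (replace `<PR-number>` with the number shown on the pull request page):
```
git fetch origin pull/<PR-number>/head:pr-<PR-number>
git checkout pr-<PR-number>
```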
# Contributing
Anyone who wishes to contribute to this project, whether
documentation, features, bug fixes, code cleanup, testing, or code
reviews, is very much encouraged to do so. If you are unfamiliar with
how to contribute to GitHub projects, here is a [Getting Started
Guide](https://opensource.com/article/19/7/create-pull-request-github).
A full set of contribution guidelines, along with templates, are in
progress, but for now the most important thing is to **make your pull
request against the "development" branch**, and not against
"main". This will help keep public breakage to a minimum and will
allow you to propose more radical changes.
# Support
For support,
please use this repository's GitHub Issues tracking service. Feel free
to send me an email if you use and like the script.
_Original Author:_ Lincoln D. Stein <lincoln.stein@gmail.com>
_Contributions by:_
[Peter Kowalczyk](https://github.com/slix), [Henry Harrison](https://github.com/hwharrison),
[xraxra](https://github.com/xraxra), [bmaltais](https://github.com/bmaltais), [Sean McLellan](https://github.com/Oceanswave),
[nicolai256](https://github.com/nicolai256), [Benjamin Warner](https://github.com/warner-benjamin),
[tildebyte](https://github.com/tildebyte), [yunsaki](https://github.com/yunsaki), [James Reynolds](https://github.com/magnusviri),
[Tesseract Cat](https://github.com/TesseractCat), and many more!
(If you have contributed and don't see your name on the list of
contributors, please let lstein know about the omission, or make a
pull request)
Original portions of the software are Copyright (c) 2020 Lincoln D. Stein (https://github.com/lstein)
# Further Reading
Please see the original README for more information on this software
and underlying algorithm, located in the file [README-CompViz.md](README-CompViz.md).


@@ -0,0 +1,259 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Easy-peasy Windows install"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that you will need NVIDIA drivers, Python 3.10, and Git installed\n",
"beforehand - simplified\n",
"[step-by-step instructions](https://github.com/lstein/stable-diffusion/wiki/Easy-peasy-Windows-install)\n",
"are available in the wiki (you'll only need steps 1, 2, & 3 )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Run each cell in turn. In VSCode, either hit SHIFT-ENTER, or click on the little ▶️ to the left of the cell. In Jupyter/JupyterLab, you **must** hit SHIFT-ENTER"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install pew"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%cmd\n",
"git clone https://github.com/lstein/stable-diffusion.git"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%cd stable-diffusion"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%writefile requirements.txt\n",
"albumentations==0.4.3\n",
"einops==0.3.0\n",
"huggingface-hub==0.8.1\n",
"imageio-ffmpeg==0.4.2\n",
"imageio==2.9.0\n",
"kornia==0.6.0\n",
"omegaconf==2.1.1\n",
"opencv-python==4.6.0.66\n",
"pillow==9.2.0\n",
"pudb==2019.2\n",
"pytorch-lightning==1.4.2\n",
"streamlit==1.12.0\n",
"# Regular \"taming-transformers\" doesn't seem to work\n",
"taming-transformers-rom1504==0.0.6\n",
"test-tube>=0.7.5\n",
"torch-fidelity==0.3.0\n",
"torchmetrics==0.6.0\n",
"torchvision==0.12.0\n",
"transformers==4.19.2\n",
"git+https://github.com/openai/CLIP.git@main#egg=clip\n",
"git+https://github.com/lstein/k-diffusion.git@master#egg=k-diffusion\n",
"# No CUDA in PyPi builds\n",
"torch@https://download.pytorch.org/whl/cu113/torch-1.11.0%2Bcu113-cp310-cp310-win_amd64.whl\n",
"# No MKL in PyPi builds (faster, more robust than OpenBLAS)\n",
"numpy@https://download.lfd.uci.edu/pythonlibs/archived/numpy-1.22.4+mkl-cp310-cp310-win_amd64.whl\n",
"-e .\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%cmd\n",
"pew new --python 3.10 -r requirements.txt --dont-activate ldm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Switch the notebook kernel to the new 'ldm' environment!\n",
"\n",
"## VSCode: restart VSCode and come back to this cell\n",
"\n",
"1. Ctrl+Shift+P\n",
"1. Type \"Select Interpreter\" and select \"Jupyter: Select Interpreter to Start Jupyter Server\"\n",
"1. VSCode will say that it needs to install packages. Click the \"Install\" button.\n",
"1. Once the install is finished, do 1 & 2 again\n",
"1. Pick 'ldm'\n",
"1. Run the following cell"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%cd stable-diffusion"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"## Jupyter/JupyterLab\n",
"\n",
"1. Run the cell below\n",
"1. Click on the toolbar where it says \"(ipyknel)\" ↗️. You should get a pop-up asking you to \"Select Kernel\". Pick 'ldm' from the drop-down.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### DO NOT RUN THE FOLLOWING CELL IF YOU ARE USING VSCODE!!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# DO NOT RUN THIS CELL IF YOU ARE USING VSCODE!!\n",
"%%cmd\n",
"pew workon ldm\n",
"pip3 install ipykernel\n",
"python -m ipykernel install --name=ldm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### When running the next cell, Jupyter/JupyterLab users might get a warning saying \"IProgress not found\". This can be ignored."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%run \"scripts/preload_models.py\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%cmd\n",
"mkdir \"models/ldm/stable-diffusion-v1\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Now copy the SD model you downloaded from Hugging Face into the above new directory, and (if necessary) rename it to 'model.ckpt'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Now go create some magic!\n",
"\n",
"VSCode\n",
"\n",
"- The actual input box for the 'dream' prompt will appear at the very top of the VSCode window. Type in your commands and hit 'ENTER'.\n",
"- To quit, hit the 'Interrupt' button in the toolbar up there ⬆️ a couple of times, then hit ENTER (you'll probably see a terrifying traceback from Python - just ignore it).\n",
"\n",
"Jupyter/JupyterLab\n",
"\n",
"- The input box for the 'dream' prompt will appear below. Type in your commands and hit 'ENTER'.\n",
"- To quit, hit the interrupt button (⏹️) in the toolbar up there ⬆️ a couple of times, then hit ENTER (you'll probably see a terrifying traceback from Python - just ignore it)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%run \"scripts/dream.py\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Once this seems to be working well, you can try opening a terminal\n",
"\n",
"- VSCode: type ('CTRL+`')\n",
"- Jupyter/JupyterLab: File|New Terminal\n",
"- Or jump out of the notebook entirely, and open Powershell/Command Prompt\n",
"\n",
"Now:\n",
"\n",
"1. `cd` to wherever the 'stable-diffusion' directory is\n",
"1. Run `pew workon ldm`\n",
"1. Run `winpty python scripts\\dream.py`"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.6 ('ldm')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"vscode": {
"interpreter": {
"hash": "a05e4574567b7bc2c98f7f9aa579f9ea5b8739b54844ab610ac85881c4be2659"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}


@@ -0,0 +1,256 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Stable_Diffusion_AI_Notebook.ipynb",
"provenance": [],
"collapsed_sections": [],
"private_outputs": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# Stable Diffusion AI Notebook\n",
"\n",
"<img src=\"https://user-images.githubusercontent.com/60411196/186547976-d9de378a-9de8-4201-9c25-c057a9c59bad.jpeg\" alt=\"stable-diffusion-ai\" width=\"170px\"/> <br>\n",
"#### Instructions:\n",
"1. Execute each cell in order to mount a Dream bot and create images from text. <br>\n",
"2. Once cells 1-8 were run correctly you'll be executing a terminal in cell #9, you'll to enter `pipenv run scripts/dream.py` command to run Dream bot.<br> \n",
"3. After launching dream bot, you'll see: <br> `Dream > ` in terminal. <br> Insert a command, eg. `Dream > Astronaut floating in a distant galaxy`, or type `-h` for help.\n",
"3. After completion you'll see your generated images in path `stable-diffusion/outputs/img-samples/`, you can also display images in cell #10.\n",
"4. To quit Dream bot use `q` command. <br> \n",
"---\n",
"<font color=\"red\">Note:</font> It takes some time to load, but after installing all dependencies you can use the bot all time you want while colab instance is up. <br>\n",
"<font color=\"red\">Requirements:</font> For this notebook to work you need to have [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original) stored in your Google Drive, it will be needed in cell #6\n",
"##### For more details visit Github repository: [lstein/stable-diffusion](https://github.com/lstein/stable-diffusion)\n",
"---\n"
],
"metadata": {
"id": "ycYWcsEKc6w7"
}
},
{
"cell_type": "code",
"source": [
"#@title 1. Check current GPU assigned\n",
"!nvidia-smi -L\n",
"!nvidia-smi"
],
"metadata": {
"cellView": "form",
"id": "a2Z5Qu_o8VtQ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "vbI9ZsQHzjqF"
},
"outputs": [],
"source": [
"#@title 2. Download stable-diffusion Repository\n",
"from os.path import exists\n",
"\n",
"if exists(\"/content/stable-diffusion/\")==True:\n",
" print(\"Already downloaded repo\")\n",
"else:\n",
" !git clone --quiet https://github.com/lstein/stable-diffusion.git # Original repo\n",
" %cd stable-diffusion/\n",
" !git checkout --quiet tags/release-1.09\n",
" "
]
},
{
"cell_type": "code",
"source": [
"#@title 3. Install Python 3.8 \n",
"%%capture --no-stderr\n",
"import gc\n",
"!apt-get -qq install python3.8\n",
"gc.collect()"
],
"metadata": {
"id": "daHlozvwKesj",
"cellView": "form"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 4. Install dependencies from file in a VirtualEnv\n",
"#@markdown Be patient, it takes ~ 5 - 7min <br>\n",
"%%capture --no-stderr\n",
"#Virtual environment\n",
"!pip install pipenv -q\n",
"!pip install colab-xterm\n",
"%load_ext colabxterm\n",
"!pipenv --python 3.8\n",
"!pipenv install -r requirements.txt --skip-lock\n",
"gc.collect()\n"
],
"metadata": {
"cellView": "form",
"id": "QbXcGXYEFSNB"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 5. Mount google Drive\n",
"from google.colab import drive\n",
"drive.mount('/content/drive')"
],
"metadata": {
"cellView": "form",
"id": "YEWPV-sF1RDM"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 6. Drive Path to model\n",
"#@markdown Path should start with /content/drive/path-to-your-file <br>\n",
"#@markdown <font color=\"red\">Note:</font> Model should be downloaded from https://huggingface.co <br>\n",
"#@markdown Lastest release: [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original)\n",
"from os.path import exists\n",
"\n",
"model_path = \"\" #@param {type:\"string\"}\n",
"if exists(model_path)==True:\n",
" print(\"✅ Valid directory\")\n",
"else: \n",
" print(\"❌ File doesn't exist\")"
],
"metadata": {
"cellView": "form",
"id": "zRTJeZ461WGu"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 7. Symlink to model\n",
"\n",
"from os.path import exists\n",
"import os \n",
"\n",
"# Folder creation if it doesn't exist\n",
"if exists(\"/content/stable-diffusion/models/ldm/stable-diffusion-v1\")==True:\n",
" print(\"❗ Dir stable-diffusion-v1 already exists\")\n",
"else:\n",
" %mkdir /content/stable-diffusion/models/ldm/stable-diffusion-v1\n",
" print(\"✅ Dir stable-diffusion-v1 created\")\n",
"\n",
"# Symbolic link if it doesn't exist\n",
"if exists(\"/content/stable-diffusion/models/ldm/stable-diffusion-v1/model.ckpt\")==True:\n",
" print(\"❗ Symlink already created\")\n",
"else: \n",
" src = model_path\n",
" dst = '/content/stable-diffusion/models/ldm/stable-diffusion-v1/model.ckpt'\n",
" os.symlink(src, dst) \n",
" print(\"✅ Symbolic link created successfully\")"
],
"metadata": {
"id": "UY-NNz4I8_aG",
"cellView": "form"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 8. Load small ML models required\n",
"%%capture --no-stderr\n",
"!pipenv run scripts/preload_models.py\n",
"gc.collect()"
],
"metadata": {
"cellView": "form",
"id": "ChIDWxLVHGGJ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 9. Run Terminal and Execute Dream bot\n",
"#@markdown <font color=\"blue\">Steps:</font> <br>\n",
"#@markdown 1. Execute command `pipenv run scripts/dream.py` to run dream bot.<br>\n",
"#@markdown 2. After initialized you'll see `Dream>` line.<br>\n",
"#@markdown 3. Example text: `Astronaut floating in a distant galaxy` <br>\n",
"#@markdown 4. To quit Dream bot use: `q` command.<br>\n",
"\n",
"#Run from virtual env\n",
"\n",
"%xterm\n",
"gc.collect()"
],
"metadata": {
"id": "ir4hCrMIuUpl",
"cellView": "form"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 10. Show generated images\n",
"\n",
"import glob\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"%matplotlib inline\n",
"\n",
"images = []\n",
"for img_path in glob.glob('/content/stable-diffusion/outputs/img-samples/*.png'):\n",
" images.append(mpimg.imread(img_path))\n",
"\n",
"# Remove ticks and labels on x-axis and y-axis both\n",
"\n",
"plt.figure(figsize=(20,10))\n",
"\n",
"columns = 5\n",
"for i, image in enumerate(images):\n",
" ax = plt.subplot(len(images) / columns + 1, columns, i + 1)\n",
" ax.axes.xaxis.set_visible(False)\n",
" ax.axes.yaxis.set_visible(False)\n",
" ax.axis('off')\n",
" plt.imshow(image)\n",
" gc.collect()\n",
"\n"
],
"metadata": {
"cellView": "form",
"id": "qnLohSHmKoGk"
},
"execution_count": null,
"outputs": []
}
]
}


@@ -1,31 +0,0 @@
Feature requests:
1. "gobig" mode - split image into strips, scale up, add detail using
img2img and reassemble with feathering. Issue #66.
2. Port basujindal low VRAM optimizations. Issue #62
3. Store images under folders named after the prompt. Issue #27.
4. Some sort of automation for generating variations. Issues #32 and #47.
5. Support for inpainting masks #68.
6. Support for loading variations of the stable-diffusion
weights #49
7. Support for klms and other non-ddim samplers in img2img() #36
8. Pass a shell command to open up an image viewer on the last
batch of images generated #29.
Code Refactorization:
1. Move the PNG file generation code out of simplet2i and into
separate module. txt2img() and img2img() should return Image
objects, and parent code is responsible for filenaming logic.
2. Refactor redundant code that is shared between txt2img() and
img2img().
3. Experiment with replacing CompViz code with HuggingFace.

113
VARIATIONS.md Normal file

@@ -0,0 +1,113 @@
# Cheat Sheet for Generating Variations
Release 1.13 of SD-Dream adds support for image variations. There are two things that you can do:
1. Generate a series of systematic variations of an image, given a
prompt. The amount of variation from one image to the next can be
controlled.
2. Given two or more variations that you like, you can combine them in
a weighted fashion.
This cheat sheet provides a quick guide for how this works in
practice, using variations to create the desired image of Xena,
Warrior Princess.
## Step 1 -- find a base image that you like
The prompt we will use throughout is "lucy lawless as xena, warrior
princess, character portrait, high resolution." This will be indicated
as "prompt" in the examples below.
First we let SD create a series of images in the usual way, in this case
requesting six iterations:
~~~
dream> lucy lawless as xena, warrior princess, character portrait, high resolution -n6
...
Outputs:
./outputs/Xena/000001.1579445059.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -S1579445059
./outputs/Xena/000001.1880768722.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -S1880768722
./outputs/Xena/000001.332057179.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -S332057179
./outputs/Xena/000001.2224800325.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -S2224800325
./outputs/Xena/000001.465250761.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -S465250761
./outputs/Xena/000001.3357757885.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -S3357757885
~~~
The one with seed 3357757885 looks nice:
<img src="static/variation_walkthru/000001.3357757885.png"/>
Let's try to generate some variations. Using the same seed, we pass
the argument -v0.2 (or --variant_amount), which generates a series of
variations, each differing from the original by a variation amount of 0.2.
This number ranges from 0 to 1.0, with higher numbers producing larger
amounts of variation.
~~~
dream> "prompt" -n6 -S3357757885 -v0.2
...
Outputs:
./outputs/Xena/000002.784039624.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 784039624:0.2 -S3357757885
./outputs/Xena/000002.3647897225.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 3647897225:0.2 -S3357757885
./outputs/Xena/000002.917731034.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 917731034:0.2 -S3357757885
./outputs/Xena/000002.4116285959.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 4116285959:0.2 -S3357757885
./outputs/Xena/000002.1614299449.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 1614299449:0.2 -S3357757885
./outputs/Xena/000002.1335553075.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 1335553075:0.2 -S3357757885
~~~
Note that the output for each image has a -V option giving the
"variant subseed" for that image, consisting of a seed followed by the
variation amount used to generate it.
This gives us a series of closely-related variations, including the
two shown here.
<img src="static/variation_walkthru/000002.3647897225.png">
<img src="static/variation_walkthru/000002.1614299449.png">
I like the expression on Xena's face in the first one (subseed
3647897225), and the armor on her shoulder in the second one (subseed
1614299449). Can we combine them to get the best of both worlds?
We combine the two variations using -V (--with_variations). Again, we
must provide the seed for the originally-chosen image in order for
this to work.
~~~
dream> "prompt" -S3357757885 -V3647897225,0.1;1614299449,0.1
Outputs:
./outputs/Xena/000003.1614299449.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 3647897225:0.1,1614299449:0.1 -S3357757885
~~~
Here we are providing equal weights (0.1 and 0.1) for both
subseeds. The resulting image is close, but not exactly what I
wanted:
<img src="static/variation_walkthru/000003.1614299449.png">
We could either try combining the images with different weights, or
generate more variations around the almost-but-not-quite image. We
choose the latter, using both the -V (combining) and -v (variation
strength) options. Note that we use -n6 to generate six variations:
~~~~
dream> "prompt" -S3357757885 -V3647897225,0.1;1614299449,0.1 -v0.05 -n6
Outputs:
./outputs/Xena/000004.3279757577.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 3647897225:0.1,1614299449:0.1,3279757577:0.05 -S3357757885
./outputs/Xena/000004.2853129515.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 3647897225:0.1,1614299449:0.1,2853129515:0.05 -S3357757885
./outputs/Xena/000004.3747154981.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 3647897225:0.1,1614299449:0.1,3747154981:0.05 -S3357757885
./outputs/Xena/000004.2664260391.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 3647897225:0.1,1614299449:0.1,2664260391:0.05 -S3357757885
./outputs/Xena/000004.1642517170.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 3647897225:0.1,1614299449:0.1,1642517170:0.05 -S3357757885
./outputs/Xena/000004.2183375608.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 3647897225:0.1,1614299449:0.1,2183375608:0.05 -S3357757885
~~~~
This produces six images, all slight variations on the combination of
the chosen two images. Here's the one I like best:
<img src="static/variation_walkthru/000004.3747154981.png">
As you can see, this is a very powerful tool which, when combined with
subprompt weighting, gives you great control over the content and
quality of your generated images.
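
For readers curious what the weighted combination looks like under the hood, here is an illustrative Python sketch of blending per-seed noise tensors according to `seed:weight` pairs. It mirrors the idea described above, not necessarily the exact implementation in this repository; the tensor shape and helper name are made up for the example.
~~~
import torch

def blended_noise(base_seed, subseeds, shape=(1, 4, 64, 64)):
    """subseeds is a list of (seed, weight) pairs; leftover weight stays with base_seed."""
    leftover = 1.0 - sum(w for _, w in subseeds)
    gen = torch.Generator().manual_seed(base_seed)
    noise = leftover * torch.randn(shape, generator=gen)
    for seed, weight in subseeds:
        gen = torch.Generator().manual_seed(seed)
        noise = noise + weight * torch.randn(shape, generator=gen)
    return noise

# e.g. the -V 3647897225,0.1;1614299449,0.1 example above
x_T = blended_noise(3357757885, [(3647897225, 0.1), (1614299449, 0.1)])
~~~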

18
configs/models.yaml Normal file

@@ -0,0 +1,18 @@
# This file describes the alternative machine learning models
# available to the dream script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
laion400m:
config: configs/latent-diffusion/txt2img-1p4B-eval.yaml
weights: models/ldm/text2img-large/model.ckpt
width: 256
height: 256
stable-diffusion-1.4:
config: configs/stable-diffusion/v1-inference.yaml
weights: models/ldm/stable-diffusion-v1/model.ckpt
width: 512
height: 512
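
As a quick sanity check of the format above, the entries can be read with a plain YAML loader. The snippet below is only an illustration (the dream script has its own loader) and assumes it is run from the repository root.
```python
# illustrative only: read configs/models.yaml and pick out one entry
import yaml

with open('configs/models.yaml') as f:
    models = yaml.safe_load(f)

entry = models['stable-diffusion-1.4']
print(entry['config'], entry['weights'], entry['width'], entry['height'])
```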


@@ -73,8 +73,8 @@ model:
data:
target: main.DataModuleFromConfig
params:
batch_size: 2
num_workers: 16
batch_size: 1
num_workers: 2
wrap: false
train:
target: ldm.data.personalized.PersonalizedBase
@@ -92,6 +92,9 @@ data:
repeats: 10
lightning:
modelcheckpoint:
params:
every_n_train_steps: 500
callbacks:
image_logger:
target: main.ImageLogger
@@ -102,4 +105,5 @@ lightning:
trainer:
benchmark: True
max_steps: 6100
max_steps: 4000

58
environment-mac.yaml Normal file

@@ -0,0 +1,58 @@
name: ldm
channels:
- pytorch-nightly
- conda-forge
dependencies:
- python==3.9.13
- pip==22.2.2
# pytorch-nightly, left unpinned
- pytorch
- torchmetrics
- torchvision
# I suggest keeping the other deps sorted for convenience.
# If you wish to upgrade to 3.10, try running this:
#
# ```shell
# CONDA_CMD=conda
# sed -E 's/python==3.9.13/python==3.10.5/;s/ldm/ldm-3.10/;21,99s/- ([^=]+)==.+/- \1/' environment-mac.yaml > /tmp/environment-mac-updated.yml
# CONDA_SUBDIR=osx-arm64 $CONDA_CMD env create -f /tmp/environment-mac-updated.yml && $CONDA_CMD list -n ldm-3.10 | awk ' {print " - " $1 "==" $2;} '
# ```
#
# Unfortunately, as of 2022-08-31, this fails at the pip stage.
- albumentations==1.2.1
- coloredlogs==15.0.1
- einops==0.4.1
- grpcio==1.46.4
- humanfriendly
- imageio-ffmpeg==0.4.7
- imageio==2.21.2
- imgaug==0.4.0
- kornia==0.6.7
- mpmath==1.2.1
- nomkl
- numpy==1.23.2
- omegaconf==2.1.1
- onnx==1.12.0
- onnxruntime==1.12.1
- opencv==4.6.0
- pudb==2022.1
- pytorch-lightning==1.6.5
- scipy==1.9.1
- streamlit==1.12.2
- sympy==1.10.1
- tensorboard==2.9.0
- transformers==4.21.2
- pip:
- invisible-watermark
- test-tube
- tokenizers
- torch-fidelity
- -e git+https://github.com/huggingface/diffusers.git@v0.2.4#egg=diffusers
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
- -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
- -e .
variables:
PYTORCH_ENABLE_MPS_FALLBACK: 1


@@ -18,14 +18,13 @@ dependencies:
- pytorch-lightning==1.4.2
- omegaconf==2.1.1
- test-tube>=0.7.5
- streamlit>=0.73.1
- pillow==9.0.1
- streamlit==1.12.0
- pillow==9.2.0
- einops==0.3.0
- torch-fidelity==0.3.0
- transformers==4.19.2
- torchmetrics==0.6.0
- kornia==0.6
- accelerate==0.12.0
- kornia==0.6.0
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
- -e git+https://github.com/lstein/k-diffusion.git@master#egg=k-diffusion


@@ -1,11 +1,17 @@
from abc import abstractmethod
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
from torch.utils.data import (
Dataset,
ConcatDataset,
ChainDataset,
IterableDataset,
)
class Txt2ImgIterableBaseDataset(IterableDataset):
'''
"""
Define an interface to make the IterableDatasets for text2img data chainable
'''
"""
def __init__(self, num_records=0, valid_ids=None, size=256):
super().__init__()
self.num_records = num_records
@@ -13,11 +19,13 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
self.sample_ids = valid_ids
self.size = size
print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
print(
f'{self.__class__.__name__} dataset contains {self.__len__()} examples.'
)
def __len__(self):
return self.num_records
@abstractmethod
def __iter__(self):
pass
pass


@@ -11,24 +11,34 @@ from tqdm import tqdm
from torch.utils.data import Dataset, Subset
import taming.data.utils as tdu
from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
from taming.data.imagenet import (
str_to_indices,
give_synsets_from_indices,
download,
retrieve,
)
from taming.data.imagenet import ImagePaths
from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
from ldm.modules.image_degradation import (
degradation_fn_bsr,
degradation_fn_bsr_light,
)
def synset2idx(path_to_yaml="data/index_synset.yaml"):
def synset2idx(path_to_yaml='data/index_synset.yaml'):
with open(path_to_yaml) as f:
di2s = yaml.load(f)
return dict((v,k) for k,v in di2s.items())
return dict((v, k) for k, v in di2s.items())
class ImageNetBase(Dataset):
def __init__(self, config=None):
self.config = config or OmegaConf.create()
if not type(self.config)==dict:
if not type(self.config) == dict:
self.config = OmegaConf.to_container(self.config)
self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
self.keep_orig_class_label = self.config.get(
'keep_orig_class_label', False
)
self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
self._prepare()
self._prepare_synset_to_human()
@@ -46,17 +56,23 @@ class ImageNetBase(Dataset):
raise NotImplementedError()
def _filter_relpaths(self, relpaths):
ignore = set([
"n06596364_9591.JPEG",
])
relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
if "sub_indices" in self.config:
indices = str_to_indices(self.config["sub_indices"])
synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
ignore = set(
[
'n06596364_9591.JPEG',
]
)
relpaths = [
rpath for rpath in relpaths if not rpath.split('/')[-1] in ignore
]
if 'sub_indices' in self.config:
indices = str_to_indices(self.config['sub_indices'])
synsets = give_synsets_from_indices(
indices, path_to_yaml=self.idx2syn
) # returns a list of strings
self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
files = []
for rpath in relpaths:
syn = rpath.split("/")[0]
syn = rpath.split('/')[0]
if syn in synsets:
files.append(rpath)
return files
@@ -65,78 +81,89 @@ class ImageNetBase(Dataset):
def _prepare_synset_to_human(self):
SIZE = 2655750
URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
self.human_dict = os.path.join(self.root, "synset_human.txt")
if (not os.path.exists(self.human_dict) or
not os.path.getsize(self.human_dict)==SIZE):
URL = 'https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1'
self.human_dict = os.path.join(self.root, 'synset_human.txt')
if (
not os.path.exists(self.human_dict)
or not os.path.getsize(self.human_dict) == SIZE
):
download(URL, self.human_dict)
def _prepare_idx_to_synset(self):
URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
self.idx2syn = os.path.join(self.root, "index_synset.yaml")
if (not os.path.exists(self.idx2syn)):
URL = 'https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1'
self.idx2syn = os.path.join(self.root, 'index_synset.yaml')
if not os.path.exists(self.idx2syn):
download(URL, self.idx2syn)
def _prepare_human_to_integer_label(self):
URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
if (not os.path.exists(self.human2integer)):
URL = 'https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1'
self.human2integer = os.path.join(
self.root, 'imagenet1000_clsidx_to_labels.txt'
)
if not os.path.exists(self.human2integer):
download(URL, self.human2integer)
with open(self.human2integer, "r") as f:
with open(self.human2integer, 'r') as f:
lines = f.read().splitlines()
assert len(lines) == 1000
self.human2integer_dict = dict()
for line in lines:
value, key = line.split(":")
value, key = line.split(':')
self.human2integer_dict[key] = int(value)
def _load(self):
with open(self.txt_filelist, "r") as f:
with open(self.txt_filelist, 'r') as f:
self.relpaths = f.read().splitlines()
l1 = len(self.relpaths)
self.relpaths = self._filter_relpaths(self.relpaths)
print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
print(
'Removed {} files from filelist during filtering.'.format(
l1 - len(self.relpaths)
)
)
self.synsets = [p.split("/")[0] for p in self.relpaths]
self.synsets = [p.split('/')[0] for p in self.relpaths]
self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
unique_synsets = np.unique(self.synsets)
class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
class_dict = dict(
(synset, i) for i, synset in enumerate(unique_synsets)
)
if not self.keep_orig_class_label:
self.class_labels = [class_dict[s] for s in self.synsets]
else:
self.class_labels = [self.synset2idx[s] for s in self.synsets]
with open(self.human_dict, "r") as f:
with open(self.human_dict, 'r') as f:
human_dict = f.read().splitlines()
human_dict = dict(line.split(maxsplit=1) for line in human_dict)
self.human_labels = [human_dict[s] for s in self.synsets]
labels = {
"relpath": np.array(self.relpaths),
"synsets": np.array(self.synsets),
"class_label": np.array(self.class_labels),
"human_label": np.array(self.human_labels),
'relpath': np.array(self.relpaths),
'synsets': np.array(self.synsets),
'class_label': np.array(self.class_labels),
'human_label': np.array(self.human_labels),
}
if self.process_images:
self.size = retrieve(self.config, "size", default=256)
self.data = ImagePaths(self.abspaths,
labels=labels,
size=self.size,
random_crop=self.random_crop,
)
self.size = retrieve(self.config, 'size', default=256)
self.data = ImagePaths(
self.abspaths,
labels=labels,
size=self.size,
random_crop=self.random_crop,
)
else:
self.data = self.abspaths
class ImageNetTrain(ImageNetBase):
NAME = "ILSVRC2012_train"
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
NAME = 'ILSVRC2012_train'
URL = 'http://www.image-net.org/challenges/LSVRC/2012/'
AT_HASH = 'a306397ccf9c2ead27155983c254227c0fd938e2'
FILES = [
"ILSVRC2012_img_train.tar",
'ILSVRC2012_img_train.tar',
]
SIZES = [
147897477120,
@@ -151,57 +178,64 @@ class ImageNetTrain(ImageNetBase):
if self.data_root:
self.root = os.path.join(self.data_root, self.NAME)
else:
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
cachedir = os.environ.get(
'XDG_CACHE_HOME', os.path.expanduser('~/.cache')
)
self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME)
self.datadir = os.path.join(self.root, "data")
self.txt_filelist = os.path.join(self.root, "filelist.txt")
self.datadir = os.path.join(self.root, 'data')
self.txt_filelist = os.path.join(self.root, 'filelist.txt')
self.expected_length = 1281167
self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
default=True)
self.random_crop = retrieve(
self.config, 'ImageNetTrain/random_crop', default=True
)
if not tdu.is_prepared(self.root):
# prep
print("Preparing dataset {} in {}".format(self.NAME, self.root))
print('Preparing dataset {} in {}'.format(self.NAME, self.root))
datadir = self.datadir
if not os.path.exists(datadir):
path = os.path.join(self.root, self.FILES[0])
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
if (
not os.path.exists(path)
or not os.path.getsize(path) == self.SIZES[0]
):
import academictorrents as at
atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path
print("Extracting {} to {}".format(path, datadir))
print('Extracting {} to {}'.format(path, datadir))
os.makedirs(datadir, exist_ok=True)
with tarfile.open(path, "r:") as tar:
with tarfile.open(path, 'r:') as tar:
tar.extractall(path=datadir)
print("Extracting sub-tars.")
subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
print('Extracting sub-tars.')
subpaths = sorted(glob.glob(os.path.join(datadir, '*.tar')))
for subpath in tqdm(subpaths):
subdir = subpath[:-len(".tar")]
subdir = subpath[: -len('.tar')]
os.makedirs(subdir, exist_ok=True)
with tarfile.open(subpath, "r:") as tar:
with tarfile.open(subpath, 'r:') as tar:
tar.extractall(path=subdir)
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG'))
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist)
filelist = "\n".join(filelist)+"\n"
with open(self.txt_filelist, "w") as f:
filelist = '\n'.join(filelist) + '\n'
with open(self.txt_filelist, 'w') as f:
f.write(filelist)
tdu.mark_prepared(self.root)
class ImageNetValidation(ImageNetBase):
NAME = "ILSVRC2012_validation"
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
NAME = 'ILSVRC2012_validation'
URL = 'http://www.image-net.org/challenges/LSVRC/2012/'
AT_HASH = '5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5'
VS_URL = 'https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1'
FILES = [
"ILSVRC2012_img_val.tar",
"validation_synset.txt",
'ILSVRC2012_img_val.tar',
'validation_synset.txt',
]
SIZES = [
6744924160,
@@ -217,39 +251,49 @@ class ImageNetValidation(ImageNetBase):
if self.data_root:
self.root = os.path.join(self.data_root, self.NAME)
else:
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
self.datadir = os.path.join(self.root, "data")
self.txt_filelist = os.path.join(self.root, "filelist.txt")
cachedir = os.environ.get(
'XDG_CACHE_HOME', os.path.expanduser('~/.cache')
)
self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME)
self.datadir = os.path.join(self.root, 'data')
self.txt_filelist = os.path.join(self.root, 'filelist.txt')
self.expected_length = 50000
self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
default=False)
self.random_crop = retrieve(
self.config, 'ImageNetValidation/random_crop', default=False
)
if not tdu.is_prepared(self.root):
# prep
print("Preparing dataset {} in {}".format(self.NAME, self.root))
print('Preparing dataset {} in {}'.format(self.NAME, self.root))
datadir = self.datadir
if not os.path.exists(datadir):
path = os.path.join(self.root, self.FILES[0])
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
if (
not os.path.exists(path)
or not os.path.getsize(path) == self.SIZES[0]
):
import academictorrents as at
atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path
print("Extracting {} to {}".format(path, datadir))
print('Extracting {} to {}'.format(path, datadir))
os.makedirs(datadir, exist_ok=True)
with tarfile.open(path, "r:") as tar:
with tarfile.open(path, 'r:') as tar:
tar.extractall(path=datadir)
vspath = os.path.join(self.root, self.FILES[1])
if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
if (
not os.path.exists(vspath)
or not os.path.getsize(vspath) == self.SIZES[1]
):
download(self.VS_URL, vspath)
with open(vspath, "r") as f:
with open(vspath, 'r') as f:
synset_dict = f.read().splitlines()
synset_dict = dict(line.split() for line in synset_dict)
print("Reorganizing into synset folders")
print('Reorganizing into synset folders')
synsets = np.unique(list(synset_dict.values()))
for s in synsets:
os.makedirs(os.path.join(datadir, s), exist_ok=True)
@@ -258,21 +302,26 @@ class ImageNetValidation(ImageNetBase):
dst = os.path.join(datadir, v)
shutil.move(src, dst)
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG'))
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist)
filelist = "\n".join(filelist)+"\n"
with open(self.txt_filelist, "w") as f:
filelist = '\n'.join(filelist) + '\n'
with open(self.txt_filelist, 'w') as f:
f.write(filelist)
tdu.mark_prepared(self.root)
class ImageNetSR(Dataset):
def __init__(self, size=None,
degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
random_crop=True):
def __init__(
self,
size=None,
degradation=None,
downscale_f=4,
min_crop_f=0.5,
max_crop_f=1.0,
random_crop=True,
):
"""
Imagenet Superresolution Dataloader
Performs following ops in order:
@@ -296,67 +345,86 @@ class ImageNetSR(Dataset):
self.LR_size = int(size / downscale_f)
self.min_crop_f = min_crop_f
self.max_crop_f = max_crop_f
assert(max_crop_f <= 1.)
assert max_crop_f <= 1.0
self.center_crop = not random_crop
self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
self.image_rescaler = albumentations.SmallestMaxSize(
max_size=size, interpolation=cv2.INTER_AREA
)
self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
self.pil_interpolation = (
False # gets reset later if incase interp_op is from pillow
)
if degradation == "bsrgan":
self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
if degradation == 'bsrgan':
self.degradation_process = partial(
degradation_fn_bsr, sf=downscale_f
)
elif degradation == "bsrgan_light":
self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
elif degradation == 'bsrgan_light':
self.degradation_process = partial(
degradation_fn_bsr_light, sf=downscale_f
)
else:
interpolation_fn = {
"cv_nearest": cv2.INTER_NEAREST,
"cv_bilinear": cv2.INTER_LINEAR,
"cv_bicubic": cv2.INTER_CUBIC,
"cv_area": cv2.INTER_AREA,
"cv_lanczos": cv2.INTER_LANCZOS4,
"pil_nearest": PIL.Image.NEAREST,
"pil_bilinear": PIL.Image.BILINEAR,
"pil_bicubic": PIL.Image.BICUBIC,
"pil_box": PIL.Image.BOX,
"pil_hamming": PIL.Image.HAMMING,
"pil_lanczos": PIL.Image.LANCZOS,
'cv_nearest': cv2.INTER_NEAREST,
'cv_bilinear': cv2.INTER_LINEAR,
'cv_bicubic': cv2.INTER_CUBIC,
'cv_area': cv2.INTER_AREA,
'cv_lanczos': cv2.INTER_LANCZOS4,
'pil_nearest': PIL.Image.NEAREST,
'pil_bilinear': PIL.Image.BILINEAR,
'pil_bicubic': PIL.Image.BICUBIC,
'pil_box': PIL.Image.BOX,
'pil_hamming': PIL.Image.HAMMING,
'pil_lanczos': PIL.Image.LANCZOS,
}[degradation]
self.pil_interpolation = degradation.startswith("pil_")
self.pil_interpolation = degradation.startswith('pil_')
if self.pil_interpolation:
self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
self.degradation_process = partial(
TF.resize,
size=self.LR_size,
interpolation=interpolation_fn,
)
else:
self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
interpolation=interpolation_fn)
self.degradation_process = albumentations.SmallestMaxSize(
max_size=self.LR_size, interpolation=interpolation_fn
)
def __len__(self):
return len(self.base)
def __getitem__(self, i):
example = self.base[i]
image = Image.open(example["file_path_"])
image = Image.open(example['file_path_'])
if not image.mode == "RGB":
image = image.convert("RGB")
if not image.mode == 'RGB':
image = image.convert('RGB')
image = np.array(image).astype(np.uint8)
min_side_len = min(image.shape[:2])
crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
crop_side_len = min_side_len * np.random.uniform(
self.min_crop_f, self.max_crop_f, size=None
)
crop_side_len = int(crop_side_len)
if self.center_crop:
self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
self.cropper = albumentations.CenterCrop(
height=crop_side_len, width=crop_side_len
)
else:
self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
self.cropper = albumentations.RandomCrop(
height=crop_side_len, width=crop_side_len
)
image = self.cropper(image=image)["image"]
image = self.image_rescaler(image=image)["image"]
image = self.cropper(image=image)['image']
image = self.image_rescaler(image=image)['image']
if self.pil_interpolation:
image_pil = PIL.Image.fromarray(image)
@@ -364,10 +432,10 @@ class ImageNetSR(Dataset):
LR_image = np.array(LR_image).astype(np.uint8)
else:
LR_image = self.degradation_process(image=image)["image"]
LR_image = self.degradation_process(image=image)['image']
example["image"] = (image/127.5 - 1.0).astype(np.float32)
example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
example['LR_image'] = (LR_image / 127.5 - 1.0).astype(np.float32)
return example
@@ -377,9 +445,11 @@ class ImageNetSRTrain(ImageNetSR):
super().__init__(**kwargs)
def get_base(self):
with open("data/imagenet_train_hr_indices.p", "rb") as f:
with open('data/imagenet_train_hr_indices.p', 'rb') as f:
indices = pickle.load(f)
dset = ImageNetTrain(process_images=False,)
dset = ImageNetTrain(
process_images=False,
)
return Subset(dset, indices)
@@ -388,7 +458,9 @@ class ImageNetSRValidation(ImageNetSR):
super().__init__(**kwargs)
def get_base(self):
with open("data/imagenet_val_hr_indices.p", "rb") as f:
with open('data/imagenet_val_hr_indices.p', 'rb') as f:
indices = pickle.load(f)
dset = ImageNetValidation(process_images=False,)
dset = ImageNetValidation(
process_images=False,
)
return Subset(dset, indices)


@@ -7,30 +7,33 @@ from torchvision import transforms
class LSUNBase(Dataset):
def __init__(self,
txt_file,
data_root,
size=None,
interpolation="bicubic",
flip_p=0.5
):
def __init__(
self,
txt_file,
data_root,
size=None,
interpolation='bicubic',
flip_p=0.5,
):
self.data_paths = txt_file
self.data_root = data_root
with open(self.data_paths, "r") as f:
with open(self.data_paths, 'r') as f:
self.image_paths = f.read().splitlines()
self._length = len(self.image_paths)
self.labels = {
"relative_file_path_": [l for l in self.image_paths],
"file_path_": [os.path.join(self.data_root, l)
for l in self.image_paths],
'relative_file_path_': [l for l in self.image_paths],
'file_path_': [
os.path.join(self.data_root, l) for l in self.image_paths
],
}
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.interpolation = {
'linear': PIL.Image.LINEAR,
'bilinear': PIL.Image.BILINEAR,
'bicubic': PIL.Image.BICUBIC,
'lanczos': PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
def __len__(self):
@@ -38,55 +41,86 @@ class LSUNBase(Dataset):
def __getitem__(self, i):
example = dict((k, self.labels[k][i]) for k in self.labels)
image = Image.open(example["file_path_"])
if not image.mode == "RGB":
image = image.convert("RGB")
image = Image.open(example['file_path_'])
if not image.mode == 'RGB':
image = image.convert('RGB')
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
h, w, = (
img.shape[0],
img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2,
(w - crop) // 2 : (w + crop) // 2,
]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = image.resize(
(self.size, self.size), resample=self.interpolation
)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
return example
class LSUNChurchesTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
super().__init__(
txt_file='data/lsun/church_outdoor_train.txt',
data_root='data/lsun/churches',
**kwargs
)
class LSUNChurchesValidation(LSUNBase):
def __init__(self, flip_p=0., **kwargs):
super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
flip_p=flip_p, **kwargs)
def __init__(self, flip_p=0.0, **kwargs):
super().__init__(
txt_file='data/lsun/church_outdoor_val.txt',
data_root='data/lsun/churches',
flip_p=flip_p,
**kwargs
)
class LSUNBedroomsTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
super().__init__(
txt_file='data/lsun/bedrooms_train.txt',
data_root='data/lsun/bedrooms',
**kwargs
)
class LSUNBedroomsValidation(LSUNBase):
def __init__(self, flip_p=0.0, **kwargs):
super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
flip_p=flip_p, **kwargs)
super().__init__(
txt_file='data/lsun/bedrooms_val.txt',
data_root='data/lsun/bedrooms',
flip_p=flip_p,
**kwargs
)
class LSUNCatsTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
super().__init__(
txt_file='data/lsun/cat_train.txt',
data_root='data/lsun/cats',
**kwargs
)
class LSUNCatsValidation(LSUNBase):
def __init__(self, flip_p=0., **kwargs):
super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
flip_p=flip_p, **kwargs)
def __init__(self, flip_p=0.0, **kwargs):
super().__init__(
txt_file='data/lsun/cat_val.txt',
data_root='data/lsun/cats',
flip_p=flip_p,
**kwargs
)


@@ -72,31 +72,57 @@ imagenet_dual_templates_small = [
]
per_img_token_list = [
'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
'א',
'ב',
'ג',
'ד',
'ה',
'ו',
'ז',
'ח',
'ט',
'י',
'כ',
'ל',
'מ',
'נ',
'ס',
'ע',
'פ',
'צ',
'ק',
'ר',
'ש',
'ת',
]
class PersonalizedBase(Dataset):
def __init__(self,
data_root,
size=None,
repeats=100,
interpolation="bicubic",
flip_p=0.5,
set="train",
placeholder_token="*",
per_image_tokens=False,
center_crop=False,
mixing_prob=0.25,
coarse_class_text=None,
):
def __init__(
self,
data_root,
size=None,
repeats=100,
interpolation='bicubic',
flip_p=0.5,
set='train',
placeholder_token='*',
per_image_tokens=False,
center_crop=False,
mixing_prob=0.25,
coarse_class_text=None,
):
self.data_root = data_root
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
]
# self._length = len(self.image_paths)
self.num_images = len(self.image_paths)
self._length = self.num_images
self._length = self.num_images
self.placeholder_token = placeholder_token
@@ -107,17 +133,20 @@ class PersonalizedBase(Dataset):
self.coarse_class_text = coarse_class_text
if per_image_tokens:
assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
assert self.num_images < len(
per_img_token_list
), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train":
if set == 'train':
self._length = self.num_images * repeats
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.interpolation = {
'linear': PIL.Image.LINEAR,
'bilinear': PIL.Image.BILINEAR,
'bicubic': PIL.Image.BICUBIC,
'lanczos': PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
def __len__(self):
@@ -127,34 +156,47 @@ class PersonalizedBase(Dataset):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
if not image.mode == 'RGB':
image = image.convert('RGB')
placeholder_string = self.placeholder_token
if self.coarse_class_text:
placeholder_string = f"{self.coarse_class_text} {placeholder_string}"
placeholder_string = (
f'{self.coarse_class_text} {placeholder_string}'
)
if self.per_image_tokens and np.random.uniform() < self.mixing_prob:
text = random.choice(imagenet_dual_templates_small).format(placeholder_string, per_img_token_list[i % self.num_images])
text = random.choice(imagenet_dual_templates_small).format(
placeholder_string, per_img_token_list[i % self.num_images]
)
else:
text = random.choice(imagenet_templates_small).format(placeholder_string)
example["caption"] = text
text = random.choice(imagenet_templates_small).format(
placeholder_string
)
example['caption'] = text
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
h, w, = (
img.shape[0],
img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2,
(w - crop) // 2 : (w + crop) // 2,
]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = image.resize(
(self.size, self.size), resample=self.interpolation
)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
return example


@@ -50,29 +50,55 @@ imagenet_dual_templates_small = [
]
per_img_token_list = [
'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
'א',
'ב',
'ג',
'ד',
'ה',
'ו',
'ז',
'ח',
'ט',
'י',
'כ',
'ל',
'מ',
'נ',
'ס',
'ע',
'פ',
'צ',
'ק',
'ר',
'ש',
'ת',
]
class PersonalizedBase(Dataset):
def __init__(self,
data_root,
size=None,
repeats=100,
interpolation="bicubic",
flip_p=0.5,
set="train",
placeholder_token="*",
per_image_tokens=False,
center_crop=False,
):
def __init__(
self,
data_root,
size=None,
repeats=100,
interpolation='bicubic',
flip_p=0.5,
set='train',
placeholder_token='*',
per_image_tokens=False,
center_crop=False,
):
self.data_root = data_root
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
]
# self._length = len(self.image_paths)
self.num_images = len(self.image_paths)
self._length = self.num_images
self._length = self.num_images
self.placeholder_token = placeholder_token
@@ -80,17 +106,20 @@ class PersonalizedBase(Dataset):
self.center_crop = center_crop
if per_image_tokens:
assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
assert self.num_images < len(
per_img_token_list
), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train":
if set == 'train':
self._length = self.num_images * repeats
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.interpolation = {
'linear': PIL.Image.LINEAR,
'bilinear': PIL.Image.BILINEAR,
'bicubic': PIL.Image.BICUBIC,
'lanczos': PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
def __len__(self):
@@ -100,30 +129,41 @@ class PersonalizedBase(Dataset):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
if not image.mode == 'RGB':
image = image.convert('RGB')
if self.per_image_tokens and np.random.uniform() < 0.25:
text = random.choice(imagenet_dual_templates_small).format(self.placeholder_token, per_img_token_list[i % self.num_images])
text = random.choice(imagenet_dual_templates_small).format(
self.placeholder_token, per_img_token_list[i % self.num_images]
)
else:
text = random.choice(imagenet_templates_small).format(self.placeholder_token)
example["caption"] = text
text = random.choice(imagenet_templates_small).format(
self.placeholder_token
)
example['caption'] = text
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
h, w, = (
img.shape[0],
img.shape[1],
)
img = img[
(h - crop) // 2 : (h + crop) // 2,
(w - crop) // 2 : (w + crop) // 2,
]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = image.resize(
(self.size, self.size), resample=self.interpolation
)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
return example

17
ldm/dream/devices.py Normal file

@@ -0,0 +1,17 @@
import torch
def choose_torch_device() -> str:
'''Convenience routine for guessing which GPU device to run model on'''
if torch.cuda.is_available():
return 'cuda'
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return 'mps'
return 'cpu'
def choose_autocast_device(device) -> str:
'''Returns an autocast compatible device from a torch device'''
device_type = device.type # this returns 'mps' on M1
# autocast only supports cuda or cpu
if device_type not in ('cuda','cpu'):
return 'cpu'
return device_type
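
A minimal usage sketch for the two helpers above; wrapping inference in `torch.autocast` is an assumption about how a caller might use the returned device type, not code taken from this file.
```python
import torch
from ldm.dream.devices import choose_torch_device, choose_autocast_device

device = torch.device(choose_torch_device())      # 'cuda', 'mps' or 'cpu'
autocast_device = choose_autocast_device(device)  # 'cuda' or 'cpu' (never 'mps')

# run a toy computation under autocast on the chosen device
with torch.autocast(device_type=autocast_device):
    x = torch.ones(8, 8, device=device) @ torch.ones(8, 8, device=device)
```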

70
ldm/dream/image_util.py Normal file

@@ -0,0 +1,70 @@
from math import sqrt, floor, ceil
from PIL import Image
class InitImageResizer():
"""Simple class to create resized copies of an Image while preserving the aspect ratio."""
def __init__(self,Image):
self.image = Image
def resize(self,width=None,height=None) -> Image:
"""
Return a copy of the image resized to fit within
a box width x height. The aspect ratio is maintained.
If neither width nor height is provided, a copy of the
original image is returned. If only one is provided, the
other is calculated from the aspect ratio.
Everything is floored to the nearest multiple of 64 so
that the result can be passed to img2img()
"""
im = self.image
ar = im.width/float(im.height)
# Infer missing values from aspect ratio
if not(width or height): # both missing
width = im.width
height = im.height
elif not height: # height missing
height = int(width/ar)
elif not width: # width missing
width = int(height*ar)
# rw and rh are the resizing width and height for the image
# they maintain the aspect ratio, but may not completely fill up
# the requested destination size
(rw,rh) = (width,int(width/ar)) if im.width>=im.height else (int(height*ar),height)
#round everything to multiples of 64
width,height,rw,rh = map(
lambda x: x-x%64, (width,height,rw,rh)
)
# no resize necessary, but return a copy
if im.width == width and im.height == height:
return im.copy()
# otherwise resize the original image so that it fits inside the bounding box
resized_image = self.image.resize((rw,rh),resample=Image.Resampling.LANCZOS)
return resized_image
def make_grid(image_list, rows=None, cols=None):
image_cnt = len(image_list)
if None in (rows, cols):
rows = floor(sqrt(image_cnt)) # try to make it square
cols = ceil(image_cnt / rows)
width = image_list[0].width
height = image_list[0].height
grid_img = Image.new('RGB', (width * cols, height * rows))
i = 0
for r in range(0, rows):
for c in range(0, cols):
if i >= len(image_list):
break
grid_img.paste(image_list[i], (c * width, r * height))
i = i + 1
return grid_img
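
A short usage sketch for `InitImageResizer` and `make_grid`; the file names here are placeholders.
```python
from PIL import Image
from ldm.dream.image_util import InitImageResizer, make_grid

init = Image.open('photo.jpg')                     # placeholder input path
fitted = InitImageResizer(init).resize(width=512)  # height inferred, both floored to a multiple of 64
grid = make_grid([fitted, fitted.rotate(180)], rows=1, cols=2)
grid.save('grid.png')                              # placeholder output path
```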

79
ldm/dream/pngwriter.py Normal file

@@ -0,0 +1,79 @@
"""
Two helper classes for dealing with PNG images and their path names.
PngWriter -- Converts Images generated by T2I into PNGs, finds
appropriate names for them, and writes prompt metadata
into the PNG.
PromptFormatter -- Utility for converting a Namespace of prompt parameters
back into a formatted prompt string with command-line switches.
"""
import os
import re
from PIL import PngImagePlugin
# -------------------image generation utils-----
class PngWriter:
def __init__(self, outdir):
self.outdir = outdir
os.makedirs(outdir, exist_ok=True)
# gives the next unique prefix in outdir
def unique_prefix(self):
# sort reverse alphabetically until we find max+1
dirlist = sorted(os.listdir(self.outdir), reverse=True)
# find the first filename that matches our pattern or return 000000.0.png
existing_name = next(
(f for f in dirlist if re.match('^(\d+)\..*\.png', f)),
'0000000.0.png',
)
basecount = int(existing_name.split('.', 1)[0]) + 1
return f'{basecount:06}'
# saves image named _image_ to outdir/name, writing metadata from prompt
# returns full path of output
def save_image_and_prompt_to_png(self, image, prompt, name):
path = os.path.join(self.outdir, name)
info = PngImagePlugin.PngInfo()
info.add_text('Dream', prompt)
image.save(path, 'PNG', pnginfo=info)
return path
class PromptFormatter:
def __init__(self, t2i, opt):
self.t2i = t2i
self.opt = opt
# note: the t2i object should provide defaults for all these values;
# the opt values, when set, simply override them
def normalize_prompt(self):
"""Normalize the prompt and switches"""
t2i = self.t2i
opt = self.opt
switches = list()
switches.append(f'"{opt.prompt}"')
switches.append(f'-s{opt.steps or t2i.steps}')
switches.append(f'-W{opt.width or t2i.width}')
switches.append(f'-H{opt.height or t2i.height}')
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
switches.append(f'-A{opt.sampler_name or t2i.sampler_name}')
if opt.init_img:
switches.append(f'-I{opt.init_img}')
if opt.fit:
switches.append(f'--fit')
if opt.strength and opt.init_img is not None:
switches.append(f'-f{opt.strength or t2i.strength}')
if opt.gfpgan_strength:
switches.append(f'-G{opt.gfpgan_strength}')
if opt.upscale:
switches.append(f'-U {" ".join([str(u) for u in opt.upscale])}')
if opt.variation_amount > 0:
switches.append(f'-v{opt.variation_amount}')
if opt.with_variations:
formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in opt.with_variations)
switches.append(f'-V{formatted_variations}')
if t2i.full_precision:
switches.append('-F')
return ' '.join(switches)
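
A usage sketch for `PngWriter`; the output directory, seed and prompt are placeholders, and the `<prefix>.<seed>.png` naming follows the convention used elsewhere in this diff.
```python
from PIL import Image
from ldm.dream.pngwriter import PngWriter

writer = PngWriter('outputs/img-samples')   # placeholder output directory (created if missing)
prefix = writer.unique_prefix()             # e.g. '000001' in an empty directory
image = Image.new('RGB', (512, 512))        # stand-in for a generated image
name = f'{prefix}.12345.png'
path = writer.save_image_and_prompt_to_png(image, '"a prompt" -s50 -S12345', name)
print(path)
```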

121
ldm/dream/readline.py Normal file

@@ -0,0 +1,121 @@
"""
Readline helper functions for dream.py (linux and mac only).
"""
import os
import re
import atexit
# ---------------readline utilities---------------------
try:
import readline
readline_available = True
except:
readline_available = False
class Completer:
def __init__(self, options):
self.options = sorted(options)
return
def complete(self, text, state):
buffer = readline.get_line_buffer()
if text.startswith(('-I', '--init_img')):
return self._path_completions(text, state, ('.png','.jpg','.jpeg'))
if buffer.strip().endswith('cd') or text.startswith(('.', '/')):
return self._path_completions(text, state, ())
response = None
if state == 0:
# This is the first time for this text, so build a match list.
if text:
self.matches = [
s for s in self.options if s and s.startswith(text)
]
else:
self.matches = self.options[:]
# Return the state'th item from the match list,
# if we have that many.
try:
response = self.matches[state]
except IndexError:
response = None
return response
def _path_completions(self, text, state, extensions):
# get the path so far
if text.startswith('-I'):
path = text.replace('-I', '', 1).lstrip()
elif text.startswith('--init_img='):
path = text.replace('--init_img=', '', 1).lstrip()
else:
path = text
matches = list()
path = os.path.expanduser(path)
if len(path) == 0:
matches.append(text + './')
else:
dir = os.path.dirname(path)
dir_list = os.listdir(dir)
for n in dir_list:
if n.startswith('.') and len(n) > 1:
continue
full_path = os.path.join(dir, n)
if full_path.startswith(path):
if os.path.isdir(full_path):
matches.append(
os.path.join(os.path.dirname(text), n) + '/'
)
elif n.endswith(extensions):
matches.append(os.path.join(os.path.dirname(text), n))
try:
response = matches[state]
except IndexError:
response = None
return response
if readline_available:
readline.set_completer(
Completer(
[
'--steps','-s',
'--seed','-S',
'--iterations','-n',
'--width','-W','--height','-H',
'--cfg_scale','-C',
'--grid','-g',
'--individual','-i',
'--init_img','-I',
'--strength','-f',
'--variants','-v',
'--outdir','-o',
'--sampler','-A','-m',
'--embedding_path',
'--device',
'--grid','-g',
'--gfpgan_strength','-G',
'--upscale','-U',
'-save_orig','--save_original',
'--skip_normalize','-x',
'--log_tokenization','t',
]
).complete
)
readline.set_completer_delims(' ')
readline.parse_and_bind('tab: complete')
histfile = os.path.join(os.path.expanduser('~'), '.dream_history')
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)
except FileNotFoundError:
pass
atexit.register(readline.write_history_file, histfile)

202
ldm/dream/server.py Normal file

@@ -0,0 +1,202 @@
import json
import base64
import mimetypes
import os
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from ldm.dream.pngwriter import PngWriter
from threading import Event
class CanceledException(Exception):
pass
class DreamServer(BaseHTTPRequestHandler):
model = None
canceled = Event()
def do_GET(self):
if self.path == "/":
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
with open("./static/dream_web/index.html", "rb") as content:
self.wfile.write(content.read())
elif self.path == "/config.js":
# unfortunately this import can't be at the top level, since that would cause a circular import
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
self.send_response(200)
self.send_header("Content-type", "application/javascript")
self.end_headers()
config = {
'gfpgan_model_exists': gfpgan_model_exists
}
self.wfile.write(bytes("let config = " + json.dumps(config) + ";\n", "utf-8"))
elif self.path == "/cancel":
self.canceled.set()
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
self.wfile.write(bytes('{}', 'utf8'))
else:
path = "." + self.path
cwd = os.path.realpath(os.getcwd())
is_in_cwd = os.path.commonprefix((os.path.realpath(path), cwd)) == cwd
if not (is_in_cwd and os.path.exists(path)):
self.send_response(404)
return
mime_type = mimetypes.guess_type(path)[0]
if mime_type is not None:
self.send_response(200)
self.send_header("Content-type", mime_type)
self.end_headers()
with open("." + self.path, "rb") as content:
self.wfile.write(content.read())
else:
self.send_response(404)
def do_POST(self):
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
# unfortunately this import can't be at the top level, since that would cause a circular import
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
content_length = int(self.headers['Content-Length'])
post_data = json.loads(self.rfile.read(content_length))
prompt = post_data['prompt']
initimg = post_data['initimg']
strength = float(post_data['strength'])
iterations = int(post_data['iterations'])
steps = int(post_data['steps'])
width = int(post_data['width'])
height = int(post_data['height'])
fit = 'fit' in post_data
cfgscale = float(post_data['cfgscale'])
sampler_name = post_data['sampler']
gfpgan_strength = float(post_data['gfpgan_strength']) if gfpgan_model_exists else 0
upscale_level = post_data['upscale_level']
upscale_strength = post_data['upscale_strength']
upscale = [int(upscale_level),float(upscale_strength)] if upscale_level != '' else None
progress_images = 'progress_images' in post_data
seed = self.model.seed if int(post_data['seed']) == -1 else int(post_data['seed'])
self.canceled.clear()
print(f">> Request to generate with prompt: {prompt}")
# In order to handle upscaled images, the PngWriter needs to maintain state
# across images generated by each call to prompt2img(), so we define it in
# the outer scope of image_done()
config = post_data.copy() # Shallow copy
config['initimg'] = ''
images_generated = 0 # helps keep track of when upscaling is started
images_upscaled = 0 # helps keep track of when upscaling is completed
pngwriter = PngWriter("./outputs/img-samples/")
prefix = pngwriter.unique_prefix()
# if upscaling is requested, then this will be called twice, once when
# the images are first generated, and then again when after upscaling
# is complete. The upscaling replaces the original file, so the second
# entry should not be inserted into the image list.
def image_done(image, seed, upscaled=False):
name = f'{prefix}.{seed}.png'
path = pngwriter.save_image_and_prompt_to_png(image, f'{prompt} -S{seed}', name)
# Append post_data to log, but only once!
if not upscaled:
with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
log.write(f"{path}: {json.dumps(config)}\n")
self.wfile.write(bytes(json.dumps(
{'event': 'result', 'url': path, 'seed': seed, 'config': config}
) + '\n',"utf-8"))
# control state of the "postprocessing..." message
upscaling_requested = upscale or gfpgan_strength>0
nonlocal images_generated # NB: Is this bad python style? It is typical usage in a perl closure.
nonlocal images_upscaled # NB: Is this bad python style? It is typical usage in a perl closure.
if upscaled:
images_upscaled += 1
else:
images_generated +=1
if upscaling_requested:
action = None
if images_generated >= iterations:
if images_upscaled < iterations:
action = 'upscaling-started'
else:
action = 'upscaling-done'
if action:
x = images_upscaled+1
self.wfile.write(bytes(json.dumps(
{'event':action,'processed_file_cnt':f'{x}/{iterations}'}
) + '\n',"utf-8"))
step_writer = PngWriter('./outputs/intermediates/')
step_index = 1
def image_progress(sample, step):
if self.canceled.is_set():
self.wfile.write(bytes(json.dumps({'event':'canceled'}) + '\n', 'utf-8'))
raise CanceledException
path = None
# since rendering images is moderately expensive, only render every 5th image
# and don't bother with the last one, since it'll render anyway
nonlocal step_index
if progress_images and step % 5 == 0 and step < steps - 1:
image = self.model._sample_to_image(sample)
name = f'{prefix}.{seed}.{step_index}.png'
metadata = f'{prompt} -S{seed} [intermediate]'
path = step_writer.save_image_and_prompt_to_png(image, metadata, name)
step_index += 1
self.wfile.write(bytes(json.dumps(
{'event': 'step', 'step': step + 1, 'url': path}
) + '\n',"utf-8"))
try:
if initimg is None:
# Run txt2img
self.model.prompt2image(prompt,
iterations=iterations,
cfg_scale = cfgscale,
width = width,
height = height,
seed = seed,
steps = steps,
gfpgan_strength = gfpgan_strength,
upscale = upscale,
sampler_name = sampler_name,
step_callback=image_progress,
image_callback=image_done)
else:
# Decode initimg as base64 to temp file
with open("./img2img-tmp.png", "wb") as f:
initimg = initimg.split(",")[1] # Ignore mime type
f.write(base64.b64decode(initimg))
try:
# Run img2img
self.model.prompt2image(prompt,
init_img = "./img2img-tmp.png",
strength = strength,
iterations = iterations,
cfg_scale = cfgscale,
seed = seed,
steps = steps,
sampler_name = sampler_name,
width = width,
height = height,
fit = fit,
gfpgan_strength=gfpgan_strength,
upscale = upscale,
step_callback=image_progress,
image_callback=image_done)
finally:
# Remove the temp file
os.remove("./img2img-tmp.png")
except CanceledException:
print(f"Canceled.")
return
class ThreadingDreamServer(ThreadingHTTPServer):
def __init__(self, server_address):
super(ThreadingDreamServer, self).__init__(server_address, DreamServer)
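
A minimal sketch of how the server might be wired up, assuming `model` is an already-initialized text-to-image object exposing `prompt2image()` (scripts/dream.py performs this setup when run with `--web`); the host, port and function name here are illustrative.
```python
from ldm.dream.server import DreamServer, ThreadingDreamServer

def serve(model, host='127.0.0.1', port=9090):
    DreamServer.model = model                 # the handler class holds the loaded model
    server = ThreadingDreamServer((host, port))
    print(f'serving on http://{host}:{port}')
    try:
        server.serve_forever()
    finally:
        server.server_close()
```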

167
ldm/gfpgan/gfpgan_tools.py Normal file

@@ -0,0 +1,167 @@
import torch
import warnings
import os
import sys
import numpy as np
from PIL import Image
from scripts.dream import create_argv_parser
arg_parser = create_argv_parser()
opt = arg_parser.parse_args()
model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path)
gfpgan_model_exists = os.path.isfile(model_path)
def _run_gfpgan(image, strength, prompt, seed, upsampler_scale=4):
print(f'>> GFPGAN - Restoring Faces: {prompt} : seed:{seed}')
gfpgan = None
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
try:
if not gfpgan_model_exists:
raise Exception('GFPGAN model not found at path ' + model_path)
sys.path.append(os.path.abspath(opt.gfpgan_dir))
from gfpgan import GFPGANer
bg_upsampler = _load_gfpgan_bg_upsampler(
opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile
)
gfpgan = GFPGANer(
model_path=model_path,
upscale=upsampler_scale,
arch='clean',
channel_multiplier=2,
bg_upsampler=bg_upsampler,
)
except Exception:
import traceback
print('>> Error loading GFPGAN:', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
if gfpgan is None:
print(
f'>> GFPGAN not initialized, it must be loaded via the --gfpgan argument'
)
return image
image = image.convert('RGB')
cropped_faces, restored_faces, restored_img = gfpgan.enhance(
np.array(image, dtype=np.uint8),
has_aligned=False,
only_center_face=False,
paste_back=True,
)
res = Image.fromarray(restored_img)
if strength < 1.0:
# Resize the original image to match the restored result if the sizes have changed
if restored_img.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
if torch.cuda.is_available():
torch.cuda.empty_cache()
gfpgan = None
return res
def _load_gfpgan_bg_upsampler(bg_upsampler, upsampler_scale, bg_tile=400):
if bg_upsampler == 'realesrgan':
if not torch.cuda.is_available(): # CPU
warnings.warn(
'The unoptimized RealESRGAN is slow on CPU. We do not use it. '
'If you really want to use it, please modify the corresponding codes.'
)
bg_upsampler = None
else:
model_path = {
2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
}
if upsampler_scale not in model_path:
return None
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
if upsampler_scale == 4:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
)
if upsampler_scale == 2:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
bg_upsampler = RealESRGANer(
scale=upsampler_scale,
model_path=model_path[upsampler_scale],
model=model,
tile=bg_tile,
tile_pad=10,
pre_pad=0,
half=True,
) # need to set False in CPU mode
else:
bg_upsampler = None
return bg_upsampler
def real_esrgan_upscale(image, strength, upsampler_scale, prompt, seed):
print(
f'>> Real-ESRGAN Upscaling: {prompt} : seed:{seed} : scale:{upsampler_scale}x'
)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
try:
upsampler = _load_gfpgan_bg_upsampler(
opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile
)
except Exception:
import traceback
print('>> Error loading Real-ESRGAN:', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
output, img_mode = upsampler.enhance(
np.array(image, dtype=np.uint8),
outscale=upsampler_scale,
alpha_upsampler=opt.gfpgan_bg_upsampler,
)
res = Image.fromarray(output)
if strength < 1.0:
# Resize the original image to match the upscaled result if the sizes have changed
if output.size != image.size:
image = image.resize(res.size)
res = Image.blend(image, res, strength)
if torch.cuda.is_available():
torch.cuda.empty_cache()
upsampler = None
return res
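Both restoration paths above blend the processed result back into the original with PIL's Image.blend, so strength behaves as a simple linear mix (0.0 keeps the input, 1.0 keeps the restored or upscaled output). A minimal sketch of calling the upscaler on a PIL image follows; note that this module parses dream.py's command-line options at import time, so the snippet only illustrates the call signature, and the file names are placeholders.

from PIL import Image
# from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale

img = Image.open('sample.png')            # placeholder input
result = real_esrgan_upscale(
    img,
    strength=0.75,        # linear blend toward the upscaled result
    upsampler_scale=4,    # only 2x and 4x have pretrained weights above
    prompt='example prompt',
    seed=42,
)
result.save('sample.upscaled.png')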


@@ -5,32 +5,49 @@ class LambdaWarmUpCosineScheduler:
"""
note: use with a base_lr of 1.0
"""
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
def __init__(
self,
warm_up_steps,
lr_min,
lr_max,
lr_start,
max_decay_steps,
verbosity_interval=0,
):
self.lr_warm_up_steps = warm_up_steps
self.lr_start = lr_start
self.lr_min = lr_min
self.lr_max = lr_max
self.lr_max_decay_steps = max_decay_steps
self.last_lr = 0.
self.last_lr = 0.0
self.verbosity_interval = verbosity_interval
def schedule(self, n, **kwargs):
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
if n % self.verbosity_interval == 0:
print(
f'current step: {n}, recent lr-multiplier: {self.last_lr}'
)
if n < self.lr_warm_up_steps:
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
lr = (
self.lr_max - self.lr_start
) / self.lr_warm_up_steps * n + self.lr_start
self.last_lr = lr
return lr
else:
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
t = (n - self.lr_warm_up_steps) / (
self.lr_max_decay_steps - self.lr_warm_up_steps
)
t = min(t, 1.0)
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
1 + np.cos(t * np.pi))
1 + np.cos(t * np.pi)
)
self.last_lr = lr
return lr
def __call__(self, n, **kwargs):
return self.schedule(n,**kwargs)
return self.schedule(n, **kwargs)
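In equation form, the multiplier produced by LambdaWarmUpCosineScheduler (meant to scale a base_lr of 1.0, per the docstring) is

$$
\mathrm{lr}(n) =
\begin{cases}
\mathrm{lr}_{\text{start}} + \dfrac{\mathrm{lr}_{\max} - \mathrm{lr}_{\text{start}}}{n_{\text{warm}}}\, n, & n < n_{\text{warm}}, \\[1.5ex]
\mathrm{lr}_{\min} + \tfrac{1}{2}\bigl(\mathrm{lr}_{\max} - \mathrm{lr}_{\min}\bigr)\bigl(1 + \cos(\pi t)\bigr), & t = \min\!\left(\dfrac{n - n_{\text{warm}}}{n_{\text{decay}} - n_{\text{warm}}},\, 1\right),
\end{cases}
$$

where $n_{\text{warm}}$ is warm_up_steps and $n_{\text{decay}}$ is max_decay_steps.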
class LambdaWarmUpCosineScheduler2:
@@ -38,15 +55,30 @@ class LambdaWarmUpCosineScheduler2:
supports repeated iterations, configurable via lists
note: use with a base_lr of 1.0.
"""
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
def __init__(
self,
warm_up_steps,
f_min,
f_max,
f_start,
cycle_lengths,
verbosity_interval=0,
):
assert (
len(warm_up_steps)
== len(f_min)
== len(f_max)
== len(f_start)
== len(cycle_lengths)
)
self.lr_warm_up_steps = warm_up_steps
self.f_start = f_start
self.f_min = f_min
self.f_max = f_max
self.cycle_lengths = cycle_lengths
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
self.last_f = 0.
self.last_f = 0.0
self.verbosity_interval = verbosity_interval
def find_in_interval(self, n):
@@ -60,17 +92,25 @@ class LambdaWarmUpCosineScheduler2:
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}")
if n % self.verbosity_interval == 0:
print(
f'current step: {n}, recent lr-multiplier: {self.last_f}, '
f'current cycle {cycle}'
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
f = (
self.f_max[cycle] - self.f_start[cycle]
) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
t = (n - self.lr_warm_up_steps[cycle]) / (
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
)
t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
1 + np.cos(t * np.pi))
f = self.f_min[cycle] + 0.5 * (
self.f_max[cycle] - self.f_min[cycle]
) * (1 + np.cos(t * np.pi))
self.last_f = f
return f
@@ -79,20 +119,25 @@ class LambdaWarmUpCosineScheduler2:
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}")
if n % self.verbosity_interval == 0:
print(
f'current step: {n}, recent lr-multiplier: {self.last_f}, '
f'current cycle {cycle}'
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
f = (
self.f_max[cycle] - self.f_start[cycle]
) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
self.cycle_lengths[cycle] - n
) / (self.cycle_lengths[cycle])
self.last_f = f
return f
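After the same warm-up ramp, LambdaLinearScheduler replaces the cosine decay with a straight linear ramp within each cycle:

$$
f(n) = f_{\min}^{(c)} + \bigl(f_{\max}^{(c)} - f_{\min}^{(c)}\bigr)\,\frac{L_c - n}{L_c},
$$

where $c$ is the cycle returned by find_in_interval, $L_c$ is cycle_lengths[c], and $n$ is counted from the start of that cycle (after subtracting cum_cycles[c]).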


@@ -6,29 +6,32 @@ from contextlib import contextmanager
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.modules.distributions.distributions import (
DiagonalGaussianDistribution,
)
from ldm.util import instantiate_from_config
class VQModel(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
scheduler_config=None,
lr_g_factor=1.0,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False
):
def __init__(
self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key='image',
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
scheduler_config=None,
lr_g_factor=1.0,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False,
):
super().__init__()
self.embed_dim = embed_dim
self.n_embed = n_embed
@@ -36,24 +39,34 @@ class VQModel(pl.LightningModule):
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
remap=remap,
sane_index_shape=sane_index_shape)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.quantize = VectorQuantizer(
n_embed,
embed_dim,
beta=0.25,
remap=remap,
sane_index_shape=sane_index_shape,
)
self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(
embed_dim, ddconfig['z_channels'], 1
)
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
assert type(colorize_nlabels) == int
self.register_buffer(
'colorize', torch.randn(3, colorize_nlabels, 1, 1)
)
if monitor is not None:
self.monitor = monitor
self.batch_resize_range = batch_resize_range
if self.batch_resize_range is not None:
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
print(
f'{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.'
)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
@@ -66,28 +79,30 @@ class VQModel(pl.LightningModule):
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
print(f'{context}: Switched to EMA weights')
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
print(f'{context}: Restored training weights')
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
sd = torch.load(path, map_location='cpu')['state_dict']
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
print('Deleting key {} from state_dict.'.format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
print(
f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
print(f'Missing Keys: {missing}')
print(f'Unexpected Keys: {unexpected}')
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
@@ -115,7 +130,7 @@ class VQModel(pl.LightningModule):
return dec
def forward(self, input, return_pred_indices=False):
quant, diff, (_,_,ind) = self.encode(input)
quant, diff, (_, _, ind) = self.encode(input)
dec = self.decode(quant)
if return_pred_indices:
return dec, diff, ind
@@ -125,7 +140,11 @@ class VQModel(pl.LightningModule):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
x = (
x.permute(0, 3, 1, 2)
.to(memory_format=torch.contiguous_format)
.float()
)
if self.batch_resize_range is not None:
lower_size = self.batch_resize_range[0]
upper_size = self.batch_resize_range[1]
@@ -133,9 +152,11 @@ class VQModel(pl.LightningModule):
# do the first few batches with max size to avoid later oom
new_resize = upper_size
else:
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
new_resize = np.random.choice(
np.arange(lower_size, upper_size + 16, 16)
)
if new_resize != x.shape[2]:
x = F.interpolate(x, size=new_resize, mode="bicubic")
x = F.interpolate(x, size=new_resize, mode='bicubic')
x = x.detach()
return x
@@ -147,81 +168,139 @@ class VQModel(pl.LightningModule):
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train",
predicted_indices=ind)
aeloss, log_dict_ae = self.loss(
qloss,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
predicted_indices=ind,
)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
self.log_dict(
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True,
)
return aeloss
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
discloss, log_dict_disc = self.loss(
qloss,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
)
self.log_dict(
log_dict_disc,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True,
)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
log_dict_ema = self._validation_step(
batch, batch_idx, suffix='_ema'
)
return log_dict
def _validation_step(self, batch, batch_idx, suffix=""):
def _validation_step(self, batch, batch_idx, suffix=''):
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
predicted_indices=ind
)
aeloss, log_dict_ae = self.loss(
qloss,
x,
xrec,
0,
self.global_step,
last_layer=self.get_last_layer(),
split='val' + suffix,
predicted_indices=ind,
)
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
predicted_indices=ind
)
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
self.log(f"val{suffix}/rec_loss", rec_loss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
self.log(f"val{suffix}/aeloss", aeloss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
discloss, log_dict_disc = self.loss(
qloss,
x,
xrec,
1,
self.global_step,
last_layer=self.get_last_layer(),
split='val' + suffix,
predicted_indices=ind,
)
rec_loss = log_dict_ae[f'val{suffix}/rec_loss']
self.log(
f'val{suffix}/rec_loss',
rec_loss,
prog_bar=True,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=True,
)
self.log(
f'val{suffix}/aeloss',
aeloss,
prog_bar=True,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=True,
)
if version.parse(pl.__version__) >= version.parse('1.4.0'):
del log_dict_ae[f"val{suffix}/rec_loss"]
del log_dict_ae[f'val{suffix}/rec_loss']
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr_d = self.learning_rate
lr_g = self.lr_g_factor*self.learning_rate
print("lr_d", lr_d)
print("lr_g", lr_g)
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quantize.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr_g, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr_d, betas=(0.5, 0.9))
lr_g = self.lr_g_factor * self.learning_rate
print('lr_d', lr_d)
print('lr_g', lr_g)
opt_ae = torch.optim.Adam(
list(self.encoder.parameters())
+ list(self.decoder.parameters())
+ list(self.quantize.parameters())
+ list(self.quant_conv.parameters())
+ list(self.post_quant_conv.parameters()),
lr=lr_g,
betas=(0.5, 0.9),
)
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9)
)
if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
print('Setting up LambdaLR scheduler...')
scheduler = [
{
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
'scheduler': LambdaLR(
opt_ae, lr_lambda=scheduler.schedule
),
'interval': 'step',
'frequency': 1
'frequency': 1,
},
{
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
'scheduler': LambdaLR(
opt_disc, lr_lambda=scheduler.schedule
),
'interval': 'step',
'frequency': 1
'frequency': 1,
},
]
return [opt_ae, opt_disc], scheduler
@@ -235,7 +314,7 @@ class VQModel(pl.LightningModule):
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if only_inputs:
log["inputs"] = x
log['inputs'] = x
return log
xrec, _ = self(x)
if x.shape[1] > 3:
@@ -243,21 +322,24 @@ class VQModel(pl.LightningModule):
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["inputs"] = x
log["reconstructions"] = xrec
log['inputs'] = x
log['reconstructions'] = xrec
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x)
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema
if x.shape[1] > 3:
xrec_ema = self.to_rgb(xrec_ema)
log['reconstructions_ema'] = xrec_ema
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
assert self.image_key == 'segmentation'
if not hasattr(self, 'colorize'):
self.register_buffer(
'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)
)
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x
@@ -283,43 +365,50 @@ class VQModelInterface(VQModel):
class AutoencoderKL(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
):
def __init__(
self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key='image',
colorize_nlabels=None,
monitor=None,
):
super().__init__()
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
assert ddconfig['double_z']
self.quant_conv = torch.nn.Conv2d(
2 * ddconfig['z_channels'], 2 * embed_dim, 1
)
self.post_quant_conv = torch.nn.Conv2d(
embed_dim, ddconfig['z_channels'], 1
)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
assert type(colorize_nlabels) == int
self.register_buffer(
'colorize', torch.randn(3, colorize_nlabels, 1, 1)
)
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
sd = torch.load(path, map_location='cpu')['state_dict']
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
print('Deleting key {} from state_dict.'.format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
print(f'Restored from {path}')
def encode(self, x):
h = self.encoder(x)
@@ -345,7 +434,11 @@ class AutoencoderKL(pl.LightningModule):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
x = (
x.permute(0, 3, 1, 2)
.to(memory_format=torch.contiguous_format)
.float()
)
return x
def training_step(self, batch, batch_idx, optimizer_idx):
@@ -354,44 +447,102 @@ class AutoencoderKL(pl.LightningModule):
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
)
self.log(
'aeloss',
aeloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False,
)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
)
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
self.log(
'discloss',
discloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_disc,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False,
)
return discloss
def validation_step(self, batch, batch_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val")
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
0,
self.global_step,
last_layer=self.get_last_layer(),
split='val',
)
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val")
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
1,
self.global_step,
last_layer=self.get_last_layer(),
split='val',
)
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
self.log('val/rec_loss', log_dict_ae['val/rec_loss'])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
opt_ae = torch.optim.Adam(
list(self.encoder.parameters())
+ list(self.decoder.parameters())
+ list(self.quant_conv.parameters())
+ list(self.post_quant_conv.parameters()),
lr=lr,
betas=(0.5, 0.9),
)
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
)
return [opt_ae, opt_disc], []
def get_last_layer(self):
@@ -409,17 +560,19 @@ class AutoencoderKL(pl.LightningModule):
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
log["inputs"] = x
log['samples'] = self.decode(torch.randn_like(posterior.sample()))
log['reconstructions'] = xrec
log['inputs'] = x
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
assert self.image_key == 'segmentation'
if not hasattr(self, 'colorize'):
self.register_buffer(
'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)
)
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x


@@ -10,13 +10,13 @@ from einops import rearrange
from glob import glob
from natsort import natsorted
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
from ldm.modules.diffusionmodules.openaimodel import (
EncoderUNetModel,
UNetModel,
)
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
__models__ = {
'class_label': EncoderUNetModel,
'segmentation': UNetModel
}
__models__ = {'class_label': EncoderUNetModel, 'segmentation': UNetModel}
def disabled_train(self, mode=True):
@@ -26,37 +26,49 @@ def disabled_train(self, mode=True):
class NoisyLatentImageClassifier(pl.LightningModule):
def __init__(self,
diffusion_path,
num_classes,
ckpt_path=None,
pool='attention',
label_key=None,
diffusion_ckpt_path=None,
scheduler_config=None,
weight_decay=1.e-2,
log_steps=10,
monitor='val/loss',
*args,
**kwargs):
def __init__(
self,
diffusion_path,
num_classes,
ckpt_path=None,
pool='attention',
label_key=None,
diffusion_ckpt_path=None,
scheduler_config=None,
weight_decay=1.0e-2,
log_steps=10,
monitor='val/loss',
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.num_classes = num_classes
# get latest config of diffusion model
diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
diffusion_config = natsorted(
glob(os.path.join(diffusion_path, 'configs', '*-project.yaml'))
)[-1]
self.diffusion_config = OmegaConf.load(diffusion_config).model
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
self.load_diffusion()
self.monitor = monitor
self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
self.numd = (
self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
)
self.log_time_interval = (
self.diffusion_model.num_timesteps // log_steps
)
self.log_steps = log_steps
self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
self.label_key = (
label_key
if not hasattr(self.diffusion_model, 'cond_stage_key')
else self.diffusion_model.cond_stage_key
)
assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
assert (
self.label_key is not None
), 'label_key neither in diffusion model nor in model.params'
if self.label_key not in __models__:
raise NotImplementedError()
@@ -68,22 +80,27 @@ class NoisyLatentImageClassifier(pl.LightningModule):
self.weight_decay = weight_decay
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
sd = torch.load(path, map_location='cpu')
if 'state_dict' in list(sd.keys()):
sd = sd['state_dict']
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
print('Deleting key {} from state_dict.'.format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
missing, unexpected = (
self.load_state_dict(sd, strict=False)
if not only_model
else self.model.load_state_dict(sd, strict=False)
)
print(
f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f'Missing Keys: {missing}')
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
print(f'Unexpected Keys: {unexpected}')
def load_diffusion(self):
model = instantiate_from_config(self.diffusion_config)
@@ -93,17 +110,25 @@ class NoisyLatentImageClassifier(pl.LightningModule):
param.requires_grad = False
def load_classifier(self, ckpt_path, pool):
model_config = deepcopy(self.diffusion_config.params.unet_config.params)
model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
model_config = deepcopy(
self.diffusion_config.params.unet_config.params
)
model_config.in_channels = (
self.diffusion_config.params.unet_config.params.out_channels
)
model_config.out_channels = self.num_classes
if self.label_key == 'class_label':
model_config.pool = pool
self.model = __models__[self.label_key](**model_config)
if ckpt_path is not None:
print('#####################################################################')
print(
'#####################################################################'
)
print(f'load from ckpt "{ckpt_path}"')
print('#####################################################################')
print(
'#####################################################################'
)
self.init_from_ckpt(ckpt_path)
@torch.no_grad()
@@ -111,11 +136,19 @@ class NoisyLatentImageClassifier(pl.LightningModule):
noise = default(noise, lambda: torch.randn_like(x))
continuous_sqrt_alpha_cumprod = None
if self.diffusion_model.use_continuous_noise:
continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
continuous_sqrt_alpha_cumprod = (
self.diffusion_model.sample_continuous_noise_level(
x.shape[0], t + 1
)
)
# todo: make sure t+1 is correct here
return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
return self.diffusion_model.q_sample(
x_start=x,
t=t,
noise=noise,
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod,
)
def forward(self, x_noisy, t, *args, **kwargs):
return self.model(x_noisy, t)
@@ -141,17 +174,21 @@ class NoisyLatentImageClassifier(pl.LightningModule):
targets = rearrange(targets, 'b h w c -> b c h w')
for down in range(self.numd):
h, w = targets.shape[-2:]
targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
targets = F.interpolate(
targets, size=(h // 2, w // 2), mode='nearest'
)
# targets = rearrange(targets,'b c h w -> b h w c')
return targets
def compute_top_k(self, logits, labels, k, reduction="mean"):
def compute_top_k(self, logits, labels, k, reduction='mean'):
_, top_ks = torch.topk(logits, k, dim=1)
if reduction == "mean":
return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
elif reduction == "none":
if reduction == 'mean':
return (
(top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
)
elif reduction == 'none':
return (top_ks == labels[:, None]).float().sum(dim=-1)
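For intuition, a tiny standalone check of the top-k logic above, assuming the same shapes (logits of shape [B, C], integer labels of shape [B]); the numbers are made up for illustration.

import torch

logits = torch.tensor([[0.1, 0.7, 0.2],
                       [0.5, 0.3, 0.2]])
labels = torch.tensor([1, 2])

_, top2 = torch.topk(logits, k=2, dim=1)
acc_at_2 = (top2 == labels[:, None]).float().sum(dim=-1).mean().item()
print(acc_at_2)  # 0.5 -- the first label is in its row's top-2, the second is not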
def on_train_epoch_start(self):
@@ -162,29 +199,59 @@ class NoisyLatentImageClassifier(pl.LightningModule):
def write_logs(self, loss, logits, targets):
log_prefix = 'train' if self.training else 'val'
log = {}
log[f"{log_prefix}/loss"] = loss.mean()
log[f"{log_prefix}/acc@1"] = self.compute_top_k(
logits, targets, k=1, reduction="mean"
log[f'{log_prefix}/loss'] = loss.mean()
log[f'{log_prefix}/acc@1'] = self.compute_top_k(
logits, targets, k=1, reduction='mean'
)
log[f"{log_prefix}/acc@5"] = self.compute_top_k(
logits, targets, k=5, reduction="mean"
log[f'{log_prefix}/acc@5'] = self.compute_top_k(
logits, targets, k=5, reduction='mean'
)
self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
self.log_dict(
log,
prog_bar=False,
logger=True,
on_step=self.training,
on_epoch=True,
)
self.log(
'loss', log[f'{log_prefix}/loss'], prog_bar=True, logger=False
)
self.log(
'global_step',
self.global_step,
logger=False,
on_epoch=False,
prog_bar=True,
)
lr = self.optimizers().param_groups[0]['lr']
self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
self.log(
'lr_abs',
lr,
on_step=True,
logger=True,
on_epoch=False,
prog_bar=True,
)
def shared_step(self, batch, t=None):
x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
x, *_ = self.diffusion_model.get_input(
batch, k=self.diffusion_model.first_stage_key
)
targets = self.get_conditioning(batch)
if targets.dim() == 4:
targets = targets.argmax(dim=1)
if t is None:
t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
t = torch.randint(
0,
self.diffusion_model.num_timesteps,
(x.shape[0],),
device=self.device,
).long()
else:
t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
t = torch.full(
size=(x.shape[0],), fill_value=t, device=self.device
).long()
x_noisy = self.get_x_noisy(x, t)
logits = self(x_noisy, t)
@@ -200,8 +267,14 @@ class NoisyLatentImageClassifier(pl.LightningModule):
return loss
def reset_noise_accs(self):
self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
self.noisy_acc = {
t: {'acc@1': [], 'acc@5': []}
for t in range(
0,
self.diffusion_model.num_timesteps,
self.diffusion_model.log_every_t,
)
}
def on_validation_start(self):
self.reset_noise_accs()
@@ -212,24 +285,35 @@ class NoisyLatentImageClassifier(pl.LightningModule):
for t in self.noisy_acc:
_, logits, _, targets = self.shared_step(batch, t)
self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
self.noisy_acc[t]['acc@1'].append(
self.compute_top_k(logits, targets, k=1, reduction='mean')
)
self.noisy_acc[t]['acc@5'].append(
self.compute_top_k(logits, targets, k=5, reduction='mean')
)
return loss
def configure_optimizers(self):
optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
optimizer = AdamW(
self.model.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay,
)
if self.use_scheduler:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
print('Setting up LambdaLR scheduler...')
scheduler = [
{
'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
'scheduler': LambdaLR(
optimizer, lr_lambda=scheduler.schedule
),
'interval': 'step',
'frequency': 1
}]
'frequency': 1,
}
]
return [optimizer], scheduler
return optimizer
@@ -243,7 +327,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
y = self.get_conditioning(batch)
if self.label_key == 'class_label':
y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
y = log_txt_as_img((x.shape[2], x.shape[3]), batch['human_label'])
log['labels'] = y
if ismap(y):
@@ -256,10 +340,14 @@ class NoisyLatentImageClassifier(pl.LightningModule):
log[f'inputs@t{current_time}'] = x_noisy
pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
pred = F.one_hot(
logits.argmax(dim=1), num_classes=self.num_classes
)
pred = rearrange(pred, 'b h w c -> b c h w')
log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(
pred
)
for key in log:
log[key] = log[key][:N]


@@ -4,89 +4,146 @@ import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.dream.devices import choose_torch_device
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
extract_into_tensor
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
extract_into_tensor,
)
class DDIMSampler(object):
def __init__(self, model, schedule="linear", device="cuda", **kwargs):
def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
self.device = device or choose_torch_device()
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device(self.device):
attr = attr.to(torch.device(self.device))
attr = attr.to(dtype=torch.float32, device=self.device)
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
def make_schedule(
self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.0,
verbose=True,
):
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), 'alphas have to be defined for each timestep'
to_torch = (
lambda x: x.clone()
.detach()
.to(torch.float32)
.to(self.model.device)
)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
self.register_buffer(
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
self.register_buffer(
'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
'sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'log_one_minus_alphas_cumprod',
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
(
ddim_sigmas,
ddim_alphas,
ddim_alphas_prev,
) = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
self.register_buffer(
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
'ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps,
)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
print(
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
)
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
print(
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
@@ -94,30 +151,47 @@ class DDIMSampler(object):
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
samples, intermediates = self.ddim_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,):
def ddim_sampling(
self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
@@ -126,17 +200,38 @@ class DDIMSampler(object):
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
time_range = (
reversed(range(0, timesteps))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = (
timesteps if ddim_use_original_steps else timesteps.shape[0]
)
print(f'Running DDIM Sampling with {total_steps} timesteps')
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps, dynamic_ncols=True)
iterator = tqdm(
time_range,
desc='DDIM Sampler',
total=total_steps,
dynamic_ncols=True,
)
for i, step in enumerate(iterator):
index = total_steps - i - 1
@@ -144,18 +239,30 @@ class DDIMSampler(object):
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
img_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
outs = self.p_sample_ddim(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
img, pred_x0 = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
@@ -164,42 +271,82 @@ class DDIMSampler(object):
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None):
def p_sample_ddim(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
if (
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond
)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
alphas = (
self.model.alphas_cumprod
if use_original_steps
else self.ddim_alphas
)
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = (
sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
)
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
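For reference, the update implemented in p_sample_ddim is the standard DDIM step: with $\bar\alpha_t$ and $\bar\alpha_{t-1}$ the cumulative alphas selected for the current and previous index, $\epsilon_\theta$ the (guidance-corrected) noise prediction, and $z$ the temperature-scaled noise,

$$
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta}{\sqrt{\bar\alpha_t}},
\qquad
x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0 + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,\epsilon_\theta + \sigma_t z .
$$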
@@ -217,26 +364,55 @@ class DDIMSampler(object):
if noise is None:
noise = torch.randn_like(x0)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
return (
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
* noise
)
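The return statement above is the usual forward-diffusion (q-sample) identity used to noise a clean latent to step $t$:

$$
x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I).
$$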
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False):
def decode(
self,
x_latent,
cond,
t_start,
img_callback=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = (
np.arange(self.ddpm_num_timesteps)
if use_original_steps
else self.ddim_timesteps
)
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
print(f'Running DDIM Sampling with {total_steps} timesteps')
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
ts = torch.full(
(x_latent.shape[0],),
step,
device=x_latent.device,
dtype=torch.long,
)
x_dec, _ = self.p_sample_ddim(
x_dec,
cond,
ts,
index=index,
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
if img_callback:
img_callback(x_dec, i)
return x_dec

File diff suppressed because it is too large.


@@ -1,8 +1,8 @@
'''wrapper around part of Karen Crownson's k-duffsion library, making it call compatible with other Samplers'''
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
import k_diffusion as K
import torch
import torch.nn as nn
import accelerate
from ldm.dream.devices import choose_torch_device
class CFGDenoiser(nn.Module):
def __init__(self, model):
@@ -16,59 +16,73 @@ class CFGDenoiser(nn.Module):
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
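The return line above is classifier-free guidance: with $\epsilon_u$ the unconditional prediction, $\epsilon_c$ the conditional prediction, and $s$ = cond_scale,

$$
\hat\epsilon = \epsilon_u + s\,(\epsilon_c - \epsilon_u) = (1 - s)\,\epsilon_u + s\,\epsilon_c .
$$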
class KSampler(object):
def __init__(self,model,schedule="lms", **kwargs):
def __init__(self, model, schedule='lms', device=None, **kwargs):
super().__init__()
self.model = K.external.CompVisDenoiser(model)
self.accelerator = accelerate.Accelerator()
self.device = self.accelerator.device
self.model = K.external.CompVisDenoiser(model)
self.schedule = schedule
self.device = device or choose_torch_device()
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
uncond, cond = self.inner_model(
x_in, sigma_in, cond=cond_in
).chunk(2)
return uncond + (cond - uncond) * cond_scale
# most of these arguments are ignored and are only present for compatibility with
# other samplers
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
):
def route_callback(k_callback_values):
if img_callback is not None:
img_callback(k_callback_values['x'], k_callback_values['i'])
sigmas = self.model.get_sigmas(S)
if x_T:
x = x_T
if x_T is not None:
x = x_T * sigmas[0]
else:
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw
x = (
torch.randn([batch_size, *shape], device=self.device)
* sigmas[0]
) # for GPU draw
model_wrap_cfg = CFGDenoiser(self.model)
extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
None)
def gather(samples_ddim):
return self.accelerator.gather(samples_ddim)
extra_args = {
'cond': conditioning,
'uncond': unconditional_conditioning,
'cond_scale': unconditional_guidance_scale,
}
return (
K.sampling.__dict__[f'sample_{self.schedule}'](
model_wrap_cfg, x, sigmas, extra_args=extra_args,
callback=route_callback
),
None,
)


@@ -4,122 +4,195 @@ import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.dream.devices import choose_torch_device
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
)
class PLMSSampler(object):
def __init__(self, model, schedule="linear", device="cuda", **kwargs):
def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
self.device = device if device else choose_torch_device()
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device(self.device):
attr = attr.to(torch.device(self.device))
attr = attr.to(torch.float32).to(torch.device(self.device))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
def make_schedule(
self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.0,
verbose=True,
):
if ddim_eta != 0:
raise ValueError('ddim_eta must be 0 for PLMS')
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), 'alphas have to be defined for each timestep'
to_torch = (
lambda x: x.clone()
.detach()
.to(torch.float32)
.to(self.model.device)
)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
self.register_buffer(
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
self.register_buffer(
'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
'sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'log_one_minus_alphas_cumprod',
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
)
self.register_buffer(
'sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
(
ddim_sigmas,
ddim_alphas,
ddim_alphas_prev,
) = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
self.register_buffer(
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
'ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps,
)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
print(
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
)
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
print(
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
# print(f'Data shape for PLMS sampling is {size}')
# print(f'Data shape for PLMS sampling is {size}')
samples, intermediates = self.plms_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
samples, intermediates = self.plms_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
@torch.no_grad()
def plms_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,):
def plms_sampling(
self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
@@ -128,42 +201,81 @@ class PLMSSampler(object):
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
# print(f"Running PLMS Sampling with {total_steps} timesteps")
time_range = (
list(reversed(range(0, timesteps)))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = (
timesteps if ddim_use_original_steps else timesteps.shape[0]
)
# print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps, dynamic_ncols=True)
iterator = tqdm(
time_range,
desc='PLMS Sampler',
total=total_steps,
dynamic_ncols=True,
)
old_eps = []
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
ts_next = torch.full(
(b,),
time_range[min(i + 1, len(time_range) - 1)],
device=device,
dtype=torch.long,
)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
img_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps, t_next=ts_next)
outs = self.p_sample_plms(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps,
t_next=ts_next,
)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
@@ -172,47 +284,95 @@ class PLMSSampler(object):
return img, intermediates
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
def p_sample_plms(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
old_eps=None,
t_next=None,
):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
if (
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
e_t_uncond, e_t = self.model.apply_model(
x_in, t_in, c_in
).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond
)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
alphas = (
self.model.alphas_cumprod
if use_original_steps
else self.ddim_alphas
)
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
a_prev = torch.full(
(b, 1, 1, 1), alphas_prev[index], device=device
)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = (
sigma_t
* noise_like(x.shape, device, repeat_noise)
* temperature
)
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
@@ -231,7 +391,12 @@ class PLMSSampler(object):
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3:
# 4th order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
e_t_prime = (
55 * e_t
- 59 * old_eps[-1]
+ 37 * old_eps[-2]
- 9 * old_eps[-3]
) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
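The branches in the hunk above blend the current noise prediction e_t with up to three previous predictions using Adams-Bashforth coefficients before calling get_x_prev_and_pred_x0. A condensed, standalone restatement of that blending step (plain tensors, no model or scheduler attached):

import torch

def plms_blend(e_t, old_eps):
    # Adams-Bashforth combination of current and past noise predictions;
    # the order grows with the length of the history, as in p_sample_plms
    if len(old_eps) == 0:
        return e_t  # the real sampler handles the first step with an extra model call
    if len(old_eps) == 1:
        return (3 * e_t - old_eps[-1]) / 2
    if len(old_eps) == 2:
        return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

history = [torch.randn(1, 4, 8, 8) for _ in range(3)]
e_t_prime = plms_blend(torch.randn(1, 4, 8, 8), history)
print(e_t_prime.shape)  # torch.Size([1, 4, 8, 8])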

View File

@@ -13,7 +13,7 @@ def exists(val):
def uniq(arr):
return{el: True for el in arr}.keys()
return {el: True for el in arr}.keys()
def default(val, d):
@@ -45,19 +45,18 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
@@ -74,7 +73,9 @@ def zero_module(module):
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
class LinearAttention(nn.Module):
@@ -82,17 +83,28 @@ class LinearAttention(nn.Module):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
q, k, v = rearrange(
qkv,
'b (qkv heads c) h w -> qkv b heads c (h w)',
heads=self.heads,
qkv=3,
)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
out = rearrange(
out,
'b heads c (h w) -> b (heads c) h w',
heads=self.heads,
h=h,
w=w,
)
return self.to_out(out)
@@ -102,26 +114,18 @@ class SpatialSelfAttention(nn.Module):
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
@@ -131,12 +135,12 @@ class SpatialSelfAttention(nn.Module):
v = self.v(h_)
# compute attention
b,c,h,w = q.shape
b, c, h, w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
@@ -146,16 +150,18 @@ class SpatialSelfAttention(nn.Module):
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)
return x+h_
return x + h_
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
def __init__(
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0
):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
@@ -163,8 +169,7 @@ class CrossAttention(nn.Module):
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
@@ -175,7 +180,9 @@ class CrossAttention(nn.Module):
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
q, k, v = map(
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)
)
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
@@ -194,21 +201,40 @@ class CrossAttention(nn.Module):
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
):
super().__init__()
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
self.attn1 = CrossAttention(
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
self.attn2 = CrossAttention(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
return checkpoint(
self._forward, (x, context), self.parameters(), self.checkpoint
)
def _forward(self, x, context=None):
x = x.contiguous() if x.device.type == 'mps' else x
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
@@ -223,29 +249,43 @@ class SpatialTransformer(nn.Module):
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None):
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)]
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
self.proj_out = zero_module(nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim,
)
for d in range(depth)
]
)
self.proj_out = zero_module(
nn.Conv2d(
inner_dim, in_channels, kernel_size=1, stride=1, padding=0
)
)
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
@@ -258,4 +298,4 @@ class SpatialTransformer(nn.Module):
x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = self.proj_out(x)
return x + x_in
return x + x_in
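The rearrange pattern 'b n (h d) -> (b h) n d' used by CrossAttention above is what splits the projected tensors into attention heads, and the inverse pattern merges them back after aggregating the values. A toy, self-contained walk-through of the same reshaping with random projections; the dimensions here are made up purely for illustration.

import torch
from torch import einsum
from einops import rearrange

def toy_cross_attention(x, context, heads=8, dim_head=8):
    # minimal head-split attention using the same rearrange pattern as CrossAttention;
    # the linear projections are freshly initialised here just to show the shapes
    b, n, d = x.shape
    inner = heads * dim_head
    to_q = torch.nn.Linear(d, inner, bias=False)
    to_k = torch.nn.Linear(context.shape[-1], inner, bias=False)
    to_v = torch.nn.Linear(context.shape[-1], inner, bias=False)
    q, k, v = to_q(x), to_k(context), to_v(context)
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=heads), (q, k, v))
    sim = einsum('b i d, b j d -> b i j', q, k) * dim_head ** -0.5
    attn = sim.softmax(dim=-1)
    out = einsum('b i j, b j d -> b i d', attn, v)
    return rearrange(out, '(b h) n d -> b n (h d)', h=heads)

x = torch.randn(2, 16, 64)      # image tokens
ctx = torch.randn(2, 77, 64)    # text conditioning tokens
print(toy_cross_attention(x, ctx).shape)  # torch.Size([2, 16, 64])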

File diff suppressed because it is too large

View File

@@ -24,6 +24,7 @@ from ldm.modules.attention import SpatialTransformer
def convert_module_to_f16(x):
pass
def convert_module_to_f32(x):
pass
@@ -42,7 +43,9 @@ class AttentionPool2d(nn.Module):
output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
self.positional_embedding = nn.Parameter(
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels
@@ -97,37 +100,45 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
def __init__(
self, channels, use_conv, dims=2, out_channels=None, padding=1
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
self.conv = conv_nd(
dims, self.channels, self.out_channels, 3, padding=padding
)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest'
)
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
x = F.interpolate(x, scale_factor=2, mode='nearest')
if self.use_conv:
x = self.conv(x)
return x
class TransposedUpsample(nn.Module):
'Learned 2x upsampling without padding'
"""Learned 2x upsampling without padding"""
def __init__(self, channels, out_channels=None, ks=5):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
self.up = nn.ConvTranspose2d(
self.channels, self.out_channels, kernel_size=ks, stride=2
)
def forward(self,x):
def forward(self, x):
return self.up(x)
@@ -140,7 +151,9 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
def __init__(
self, channels, use_conv, dims=2, out_channels=None, padding=1
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -149,7 +162,12 @@ class Downsample(nn.Module):
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
dims,
self.channels,
self.out_channels,
3,
stride=stride,
padding=padding,
)
else:
assert self.channels == self.out_channels
@@ -219,7 +237,9 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
2 * self.out_channels
if use_scale_shift_norm
else self.out_channels,
),
)
self.out_layers = nn.Sequential(
@@ -227,7 +247,9 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
conv_nd(
dims, self.out_channels, self.out_channels, 3, padding=1
)
),
)
@@ -238,7 +260,9 @@ class ResBlock(TimestepBlock):
dims, channels, self.out_channels, 3, padding=1
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 1
)
def forward(self, x, emb):
"""
@@ -251,7 +275,6 @@ class ResBlock(TimestepBlock):
self._forward, (x, emb), self.parameters(), self.use_checkpoint
)
def _forward(self, x, emb):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
@@ -297,7 +320,7 @@ class AttentionBlock(nn.Module):
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}'
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
@@ -312,8 +335,10 @@ class AttentionBlock(nn.Module):
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
#return pt_checkpoint(self._forward, x) # pytorch
return checkpoint(
self._forward, (x,), self.parameters(), True
) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
# return pt_checkpoint(self._forward, x) # pytorch
def _forward(self, x):
b, c, *spatial = x.shape
@@ -340,7 +365,7 @@ def count_flops_attn(model, _x, y):
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial ** 2) * c
matmul_ops = 2 * b * (num_spatial**2) * c
model.total_ops += th.DoubleTensor([matmul_ops])
@@ -362,13 +387,15 @@ class QKVAttentionLegacy(nn.Module):
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(
ch, dim=1
)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
'bct,bcs->bts', q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v)
a = th.einsum('bts,bcs->bct', weight, v)
return a.reshape(bs, -1, length)
@staticmethod
@@ -397,12 +424,14 @@ class QKVAttention(nn.Module):
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts",
'bct,bcs->bts',
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
a = th.einsum(
'bts,bcs->bct', weight, v.reshape(bs * self.n_heads, ch, length)
)
return a.reshape(bs, -1, length)
@staticmethod
@@ -461,19 +490,24 @@ class UNetModel(nn.Module):
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
):
super().__init__()
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
assert (
context_dim is not None
), 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
assert (
use_spatial_transformer
), 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
@@ -481,10 +515,14 @@ class UNetModel(nn.Module):
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
assert (
num_head_channels != -1
), 'Either num_heads or num_head_channels has to be set'
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
assert (
num_heads != -1
), 'Either num_heads or num_head_channels has to be set'
self.image_size = image_size
self.in_channels = in_channels
@@ -545,8 +583,12 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
# num_heads = 1
dim_head = (
ch // num_heads
if use_spatial_transformer
else num_head_channels
)
layers.append(
AttentionBlock(
ch,
@@ -554,8 +596,14 @@ class UNetModel(nn.Module):
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
)
if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -592,8 +640,12 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
# num_heads = 1
dim_head = (
ch // num_heads
if use_spatial_transformer
else num_head_channels
)
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
@@ -609,9 +661,15 @@ class UNetModel(nn.Module):
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
),
)
if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
),
ResBlock(
ch,
time_embed_dim,
@@ -646,8 +704,12 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
# num_heads = 1
dim_head = (
ch // num_heads
if use_spatial_transformer
else num_head_channels
)
layers.append(
AttentionBlock(
ch,
@@ -655,8 +717,14 @@ class UNetModel(nn.Module):
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
)
if not use_spatial_transformer
else SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
)
)
if level and i == num_res_blocks:
@@ -673,7 +741,9 @@ class UNetModel(nn.Module):
up=True,
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
else Upsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
@@ -682,14 +752,16 @@ class UNetModel(nn.Module):
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
zero_module(
conv_nd(dims, model_channels, out_channels, 3, padding=1)
),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
normalization(ch),
conv_nd(dims, model_channels, n_embed, 1),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
normalization(ch),
conv_nd(dims, model_channels, n_embed, 1),
# nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
def convert_to_fp16(self):
"""
@@ -707,7 +779,7 @@ class UNetModel(nn.Module):
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
@@ -718,9 +790,11 @@ class UNetModel(nn.Module):
"""
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
), 'must specify y if and only if the model is class-conditional'
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
t_emb = timestep_embedding(
timesteps, self.model_channels, repeat_only=False
)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
@@ -768,9 +842,9 @@ class EncoderUNetModel(nn.Module):
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
pool="adaptive",
pool='adaptive',
*args,
**kwargs
**kwargs,
):
super().__init__()
@@ -888,7 +962,7 @@ class EncoderUNetModel(nn.Module):
)
self._feature_size += ch
self.pool = pool
if pool == "adaptive":
if pool == 'adaptive':
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
@@ -896,7 +970,7 @@ class EncoderUNetModel(nn.Module):
zero_module(conv_nd(dims, ch, out_channels, 1)),
nn.Flatten(),
)
elif pool == "attention":
elif pool == 'attention':
assert num_head_channels != -1
self.out = nn.Sequential(
normalization(ch),
@@ -905,13 +979,13 @@ class EncoderUNetModel(nn.Module):
(image_size // ds), ch, num_head_channels, out_channels
),
)
elif pool == "spatial":
elif pool == 'spatial':
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
nn.ReLU(),
nn.Linear(2048, self.out_channels),
)
elif pool == "spatial_v2":
elif pool == 'spatial_v2':
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
normalization(2048),
@@ -919,7 +993,7 @@ class EncoderUNetModel(nn.Module):
nn.Linear(2048, self.out_channels),
)
else:
raise NotImplementedError(f"Unexpected {pool} pooling")
raise NotImplementedError(f'Unexpected {pool} pooling')
def convert_to_fp16(self):
"""
@@ -942,20 +1016,21 @@ class EncoderUNetModel(nn.Module):
:param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs.
"""
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
emb = self.time_embed(
timestep_embedding(timesteps, self.model_channels)
)
results = []
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb)
if self.pool.startswith("spatial"):
if self.pool.startswith('spatial'):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = self.middle_block(h, emb)
if self.pool.startswith("spatial"):
if self.pool.startswith('spatial'):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = th.cat(results, axis=-1)
return self.out(h)
else:
h = h.type(x.dtype)
return self.out(h)
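One detail worth calling out in the attention hunks above: both QKVAttention variants scale q and k by 1/sqrt(sqrt(ch)) before the einsum rather than dividing the logits by sqrt(ch) afterwards. The two are algebraically identical, but the pre-scaled form keeps intermediate values smaller and is more stable in float16. A quick standalone check with toy shapes:

import math
import torch as th

bs, heads, ch, length = 2, 4, 16, 10
q = th.randn(bs * heads, ch, length)
k = th.randn(bs * heads, ch, length)

scale = 1 / math.sqrt(math.sqrt(ch))
pre_scaled = th.einsum('bct,bcs->bts', q * scale, k * scale)   # form used in QKVAttention
post_scaled = th.einsum('bct,bcs->bts', q, k) / math.sqrt(ch)  # naive post-division

print(th.allclose(pre_scaled, post_scaled, atol=1e-5))  # True: same logits, better f16 behaviour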

View File

@@ -18,15 +18,24 @@ from einops import repeat
from ldm.util import instantiate_from_config
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
def make_beta_schedule(
schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
if schedule == 'linear':
betas = (
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
torch.linspace(
linear_start**0.5,
linear_end**0.5,
n_timestep,
dtype=torch.float64,
)
** 2
)
elif schedule == "cosine":
elif schedule == 'cosine':
timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep
+ cosine_s
)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
@@ -34,23 +43,41 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
elif schedule == "sqrt":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
elif schedule == 'sqrt_linear':
betas = torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64
)
elif schedule == 'sqrt':
betas = (
torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64
)
** 0.5
)
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
def make_ddim_timesteps(
ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
):
if ddim_discr_method == 'uniform':
c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
elif ddim_discr_method == 'quad':
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
ddim_timesteps = (
(
np.linspace(
0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps
)
)
** 2
).astype(int)
else:
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
raise NotImplementedError(
f'There is no ddim discretization method called "{ddim_discr_method}"'
)
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
@@ -60,17 +87,27 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
return steps_out
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
def make_ddim_sampling_parameters(
alphacums, ddim_timesteps, eta, verbose=True
):
# select alphas for computing the variance schedule
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
alphas_prev = np.asarray(
[alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()
)
# according to the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
sigmas = eta * np.sqrt(
(1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
)
if verbose:
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
print(f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
print(
f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
)
print(
f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
)
return sigmas, alphas, alphas_prev
@@ -109,7 +146,9 @@ def checkpoint(func, inputs, params, flag):
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if False: # disabled checkpointing to allow requires_grad = False for main model
if (
False
): # disabled checkpointing to allow requires_grad = False for main model
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
@@ -129,7 +168,9 @@ class CheckpointFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
ctx.input_tensors = [
x.detach().requires_grad_(True) for x in ctx.input_tensors
]
with torch.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
@@ -160,12 +201,16 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding
@@ -215,6 +260,7 @@ class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
@@ -225,7 +271,7 @@ def conv_nd(dims, *args, **kwargs):
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
raise ValueError(f'unsupported dimensions: {dims}')
def linear(*args, **kwargs):
@@ -245,15 +291,16 @@ def avg_pool_nd(dims, *args, **kwargs):
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
raise ValueError(f'unsupported dimensions: {dims}')
class HybridConditioner(nn.Module):
def __init__(self, c_concat_config, c_crossattn_config):
super().__init__()
self.concat_conditioner = instantiate_from_config(c_concat_config)
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
self.crossattn_conditioner = instantiate_from_config(
c_crossattn_config
)
def forward(self, c_concat, c_crossattn):
c_concat = self.concat_conditioner(c_concat)
@@ -262,6 +309,8 @@ class HybridConditioner(nn.Module):
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
shape[0], *((1,) * (len(shape) - 1))
)
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
return repeat_noise() if repeat else noise()
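The timestep_embedding helper in the hunk above is the standard sinusoidal transformer embedding: half the channels are cosines and half are sines of the timestep at geometrically spaced frequencies, with a zero column appended for odd widths. A condensed standalone version, kept separate from the module for illustration:

import math
import torch

def sinusoidal_timestep_embedding(timesteps, dim, max_period=10000):
    # [N] integer timesteps -> [N, dim] embedding, cosine half then sine half
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    )
    args = timesteps[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # pad odd widths with a zero column, as the original does
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb

print(sinusoidal_timestep_embedding(torch.arange(4), 8).shape)  # torch.Size([4, 8])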

View File

@@ -30,33 +30,45 @@ class DiagonalGaussianDistribution(object):
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
self.var = self.std = torch.zeros_like(self.mean).to(
device=self.parameters.device
)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
x = self.mean + self.std * torch.randn(self.mean.shape).to(
device=self.parameters.device
)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3],
)
def nll(self, sample, dims=[1,2,3]):
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.])
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
logtwopi
+ self.logvar
+ torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self):
return self.mean
@@ -74,7 +86,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
assert tensor is not None, 'at least one argument must be a Tensor'
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().

View File

@@ -10,24 +10,30 @@ class LitEma(nn.Module):
self.m_name2s_name = {}
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
else torch.tensor(-1,dtype=torch.int))
self.register_buffer(
'num_updates',
torch.tensor(0, dtype=torch.int)
if use_num_upates
else torch.tensor(-1, dtype=torch.int),
)
for name, p in model.named_parameters():
if p.requires_grad:
#remove as '.'-character is not allowed in buffers
s_name = name.replace('.','')
self.m_name2s_name.update({name:s_name})
self.register_buffer(s_name,p.clone().detach().data)
# remove as '.'-character is not allowed in buffers
s_name = name.replace('.', '')
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def forward(self,model):
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
decay = min(
self.decay, (1 + self.num_updates) / (10 + self.num_updates)
)
one_minus_decay = 1.0 - decay
@@ -38,8 +44,12 @@ class LitEma(nn.Module):
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
shadow_params[sname] = shadow_params[sname].type_as(
m_param[key]
)
shadow_params[sname].sub_(
one_minus_decay * (shadow_params[sname] - m_param[key])
)
else:
assert not key in self.m_name2s_name
@@ -48,7 +58,9 @@ class LitEma(nn.Module):
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
m_param[key].data.copy_(
shadow_params[self.m_name2s_name[key]].data
)
else:
assert not key in self.m_name2s_name

View File

@@ -8,18 +8,29 @@ from ldm.data.personalized import per_img_token_list
from transformers import CLIPTokenizer
from functools import partial
DEFAULT_PLACEHOLDER_TOKEN = ["*"]
DEFAULT_PLACEHOLDER_TOKEN = ['*']
PROGRESSIVE_SCALE = 2000
def get_clip_token_for_string(tokenizer, string):
batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"]
assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string"
batch_encoding = tokenizer(
string,
truncation=True,
max_length=77,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt',
)
tokens = batch_encoding['input_ids']
assert (
torch.count_nonzero(tokens - 49407) == 2
), f"String '{string}' maps to more than a single token. Please use another string"
return tokens[0, 1]
def get_bert_token_for_string(tokenizer, string):
token = tokenizer(string)
# assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
@@ -28,42 +39,54 @@ def get_bert_token_for_string(tokenizer, string):
return token
def get_embedding_for_clip_token(embedder, token):
return embedder(token.unsqueeze(0))[0, 0]
class EmbeddingManager(nn.Module):
def __init__(
self,
embedder,
placeholder_strings=None,
initializer_words=None,
per_image_tokens=False,
num_vectors_per_token=1,
progressive_words=False,
**kwargs
self,
embedder,
placeholder_strings=None,
initializer_words=None,
per_image_tokens=False,
num_vectors_per_token=1,
progressive_words=False,
**kwargs,
):
super().__init__()
self.string_to_token_dict = {}
self.string_to_param_dict = nn.ParameterDict()
self.initial_embeddings = nn.ParameterDict() # These should not be optimized
self.initial_embeddings = (
nn.ParameterDict()
) # These should not be optimized
self.progressive_words = progressive_words
self.progressive_counter = 0
self.max_vectors_per_token = num_vectors_per_token
if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder
if hasattr(
embedder, 'tokenizer'
): # using Stable Diffusion's CLIP encoder
self.is_clip = True
get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings)
get_token_for_string = partial(
get_clip_token_for_string, embedder.tokenizer
)
get_embedding_for_tkn = partial(
get_embedding_for_clip_token,
embedder.transformer.text_model.embeddings,
)
token_dim = 1280
else: # using LDM's BERT encoder
else: # using LDM's BERT encoder
self.is_clip = False
get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
get_token_for_string = partial(
get_bert_token_for_string, embedder.tknz_fn
)
get_embedding_for_tkn = embedder.transformer.token_emb
token_dim = 1280
@@ -71,79 +94,142 @@ class EmbeddingManager(nn.Module):
placeholder_strings.extend(per_img_token_list)
for idx, placeholder_string in enumerate(placeholder_strings):
token = get_token_for_string(placeholder_string)
if initializer_words and idx < len(initializer_words):
init_word_token = get_token_for_string(initializer_words[idx])
with torch.no_grad():
init_word_embedding = get_embedding_for_tkn(init_word_token.cpu())
init_word_embedding = get_embedding_for_tkn(
init_word_token.cpu()
)
token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True)
self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False)
token_params = torch.nn.Parameter(
init_word_embedding.unsqueeze(0).repeat(
num_vectors_per_token, 1
),
requires_grad=True,
)
self.initial_embeddings[
placeholder_string
] = torch.nn.Parameter(
init_word_embedding.unsqueeze(0).repeat(
num_vectors_per_token, 1
),
requires_grad=False,
)
else:
token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True))
token_params = torch.nn.Parameter(
torch.rand(
size=(num_vectors_per_token, token_dim),
requires_grad=True,
)
)
self.string_to_token_dict[placeholder_string] = token
self.string_to_param_dict[placeholder_string] = token_params
def forward(
self,
tokenized_text,
embedded_text,
self,
tokenized_text,
embedded_text,
):
b, n, device = *tokenized_text.shape, tokenized_text.device
for placeholder_string, placeholder_token in self.string_to_token_dict.items():
for (
placeholder_string,
placeholder_token,
) in self.string_to_token_dict.items():
placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)
placeholder_embedding = self.string_to_param_dict[
placeholder_string
].to(device)
if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement
placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device))
if (
self.max_vectors_per_token == 1
): # If there's only one vector per token, we can do a simple replacement
placeholder_idx = torch.where(
tokenized_text == placeholder_token.to(device)
)
embedded_text[placeholder_idx] = placeholder_embedding
else: # otherwise, need to insert and keep track of changing indices
else: # otherwise, need to insert and keep track of changing indices
if self.progressive_words:
self.progressive_counter += 1
max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE
max_step_tokens = (
1 + self.progressive_counter // PROGRESSIVE_SCALE
)
else:
max_step_tokens = self.max_vectors_per_token
num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens)
num_vectors_for_token = min(
placeholder_embedding.shape[0], max_step_tokens
)
placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device))
placeholder_rows, placeholder_cols = torch.where(
tokenized_text == placeholder_token.to(device)
)
if placeholder_rows.nelement() == 0:
continue
sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True)
sorted_cols, sort_idx = torch.sort(
placeholder_cols, descending=True
)
sorted_rows = placeholder_rows[sort_idx]
for idx in range(len(sorted_rows)):
row = sorted_rows[idx]
col = sorted_cols[idx]
new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n]
new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n]
new_token_row = torch.cat(
[
tokenized_text[row][:col],
placeholder_token.repeat(num_vectors_for_token).to(
device
),
tokenized_text[row][col + 1 :],
],
axis=0,
)[:n]
new_embed_row = torch.cat(
[
embedded_text[row][:col],
placeholder_embedding[:num_vectors_for_token],
embedded_text[row][col + 1 :],
],
axis=0,
)[:n]
embedded_text[row] = new_embed_row
embedded_text[row] = new_embed_row
tokenized_text[row] = new_token_row
return embedded_text
def save(self, ckpt_path):
torch.save({"string_to_token": self.string_to_token_dict,
"string_to_param": self.string_to_param_dict}, ckpt_path)
torch.save(
{
'string_to_token': self.string_to_token_dict,
'string_to_param': self.string_to_param_dict,
},
ckpt_path,
)
def load(self, ckpt_path):
def load(self, ckpt_path, full=True):
ckpt = torch.load(ckpt_path, map_location='cpu')
self.string_to_token_dict = ckpt["string_to_token"]
self.string_to_param_dict = ckpt["string_to_param"]
if not full:
for key, value in self.string_to_param_dict.items():
self.string_to_param_dict[key] = torch.nn.Parameter(value.half())
def get_embedding_norms_squared(self):
all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim
param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders
all_params = torch.cat(
list(self.string_to_param_dict.values()), axis=0
) # num_placeholders x embedding_dim
param_norm_squared = (all_params * all_params).sum(
axis=-1
) # num_placeholders
return param_norm_squared
@@ -151,14 +237,19 @@ class EmbeddingManager(nn.Module):
return self.string_to_param_dict.parameters()
def embedding_to_coarse_loss(self):
loss = 0.
loss = 0.0
num_embeddings = len(self.initial_embeddings)
for key in self.initial_embeddings:
optimized = self.string_to_param_dict[key]
coarse = self.initial_embeddings[key].clone().to(optimized.device)
loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings
loss = (
loss
+ (optimized - coarse)
@ (optimized - coarse).T
/ num_embeddings
)
return loss
return loss

View File

@@ -5,30 +5,41 @@ import clip
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel
import kornia
from ldm.dream.devices import choose_torch_device
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
from ldm.modules.x_transformer import (
Encoder,
TransformerWrapper,
) # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
def _expand_mask(mask, dtype, tgt_len = None):
def _expand_mask(mask, dtype, tgt_len=None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
expanded_mask = (
mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
def _build_causal_attention_mask(bsz, seq_len, dtype):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
mask.fill_(torch.tensor(torch.finfo(dtype).min))
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
mask.fill_(torch.tensor(torch.finfo(dtype).min))
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
class AbstractEncoder(nn.Module):
def __init__(self):
@@ -38,7 +49,6 @@ class AbstractEncoder(nn.Module):
raise NotImplementedError
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key='class'):
super().__init__()
@@ -56,11 +66,22 @@ class ClassEmbedder(nn.Module):
class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
def __init__(
self,
n_embed,
n_layer,
vocab_size,
max_seq_len=77,
device=choose_torch_device(),
):
super().__init__()
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer))
self.transformer = TransformerWrapper(
num_tokens=vocab_size,
max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
)
def forward(self, tokens):
tokens = tokens.to(self.device) # meh
@@ -72,27 +93,44 @@ class TransformerEmbedder(AbstractEncoder):
class BERTTokenizer(AbstractEncoder):
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(self, device="cuda", vq_interface=True, max_length=77):
"""Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(
self, device=choose_torch_device(), vq_interface=True, max_length=77
):
super().__init__()
from transformers import BertTokenizerFast # TODO: add to requirements
from transformers import (
BertTokenizerFast,
) # TODO: add to requirements
# Modified to allow to run on non-internet connected compute nodes.
# Model needs to be loaded into cache from an internet-connected machine
# by running:
# from transformers import BertTokenizerFast
# BertTokenizerFast.from_pretrained("bert-base-uncased")
try:
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",local_files_only=True)
self.tokenizer = BertTokenizerFast.from_pretrained(
'bert-base-uncased', local_files_only=True
)
except OSError:
raise SystemExit("* Couldn't load Bert tokenizer files. Try running scripts/preload_models.py from an internet-connected machine.")
raise SystemExit(
"* Couldn't load Bert tokenizer files. Try running scripts/preload_models.py from an internet-conected machine."
)
self.device = device
self.vq_interface = vq_interface
self.max_length = max_length
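As the comment above notes, the tokenizer is loaded with local_files_only=True; a one-time cache warm-up on an internet-connected machine (what scripts/preload_models.py performs for this model) looks like this:
# Run once on a machine with network access; the files land in the Hugging Face
# cache, after which the offline load above succeeds.
from transformers import BertTokenizerFast
BertTokenizerFast.from_pretrained('bert-base-uncased')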
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt',
)
tokens = batch_encoding['input_ids'].to(self.device)
return tokens
@torch.no_grad()
@@ -108,53 +146,84 @@ class BERTTokenizer(AbstractEncoder):
class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
device="cuda",use_tokenizer=True, embedding_dropout=0.0):
def __init__(
self,
n_embed,
n_layer,
vocab_size=30522,
max_seq_len=77,
device=choose_torch_device(),
use_tokenizer=True,
embedding_dropout=0.0,
):
super().__init__()
self.use_tknz_fn = use_tokenizer
if self.use_tknz_fn:
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
self.tknz_fn = BERTTokenizer(
vq_interface=False, max_length=max_seq_len
)
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
emb_dropout=embedding_dropout)
self.transformer = TransformerWrapper(
num_tokens=vocab_size,
max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
emb_dropout=embedding_dropout,
)
def forward(self, text, embedding_manager=None):
if self.use_tknz_fn:
tokens = self.tknz_fn(text)#.to(self.device)
tokens = self.tknz_fn(text) # .to(self.device)
else:
tokens = text
z = self.transformer(tokens, return_embeddings=True, embedding_manager=embedding_manager)
z = self.transformer(
tokens, return_embeddings=True, embedding_manager=embedding_manager
)
return z
def encode(self, text, **kwargs):
# output of length 77
return self(text, **kwargs)
class SpatialRescaler(nn.Module):
def __init__(self,
n_stages=1,
method='bilinear',
multiplier=0.5,
in_channels=3,
out_channels=None,
bias=False):
def __init__(
self,
n_stages=1,
method='bilinear',
multiplier=0.5,
in_channels=3,
out_channels=None,
bias=False,
):
super().__init__()
self.n_stages = n_stages
assert self.n_stages >= 0
assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
assert method in [
'nearest',
'linear',
'bilinear',
'trilinear',
'bicubic',
'area',
]
self.multiplier = multiplier
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
self.interpolator = partial(
torch.nn.functional.interpolate, mode=method
)
self.remap_output = out_channels is not None
if self.remap_output:
print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
print(
f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.'
)
self.channel_mapper = nn.Conv2d(
in_channels, out_channels, 1, bias=bias
)
def forward(self,x):
def forward(self, x):
for stage in range(self.n_stages):
x = self.interpolator(x, scale_factor=self.multiplier)
if self.remap_output:
x = self.channel_mapper(x)
return x
@@ -162,57 +231,83 @@ class SpatialRescaler(nn.Module):
def encode(self, x):
return self(x)
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
def __init__(
self,
version='openai/clip-vit-large-patch14',
device=choose_torch_device(),
max_length=77,
):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version,local_files_only=True)
self.transformer = CLIPTextModel.from_pretrained(version,local_files_only=True)
self.tokenizer = CLIPTokenizer.from_pretrained(
version, local_files_only=True
)
self.transformer = CLIPTextModel.from_pretrained(
version, local_files_only=True
)
self.device = device
self.max_length = max_length
self.freeze()
def embedding_forward(
self,
input_ids = None,
position_ids = None,
inputs_embeds = None,
embedding_manager = None,
) -> torch.Tensor:
self,
input_ids=None,
position_ids=None,
inputs_embeds=None,
embedding_manager=None,
) -> torch.Tensor:
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
seq_length = (
input_ids.shape[-1]
if input_ids is not None
else inputs_embeds.shape[-2]
)
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
if embedding_manager is not None:
inputs_embeds = embedding_manager(input_ids, inputs_embeds)
if embedding_manager is not None:
inputs_embeds = embedding_manager(input_ids, inputs_embeds)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
return embeddings
self.transformer.text_model.embeddings.forward = embedding_forward.__get__(self.transformer.text_model.embeddings)
self.transformer.text_model.embeddings.forward = (
embedding_forward.__get__(self.transformer.text_model.embeddings)
)
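For readers unfamiliar with the pattern, function.__get__(obj) is how every patch in this class binds a plain function as a method of an existing instance; a standalone toy example:
class Box:
    def __init__(self, value):
        self.value = value

def doubled(self):
    # behaves like a method once bound via the descriptor protocol
    return 2 * self.value

box = Box(21)
box.show = doubled.__get__(box)   # same trick as embedding_forward.__get__(...) above
assert box.show() == 42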
def encoder_forward(
self,
inputs_embeds,
attention_mask = None,
causal_attention_mask = None,
output_attentions = None,
output_hidden_states = None,
return_dict = None,
attention_mask=None,
causal_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict
if return_dict is not None
else self.config.use_return_dict
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -239,44 +334,61 @@ class FrozenCLIPEmbedder(AbstractEncoder):
return hidden_states
self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder)
self.transformer.text_model.encoder.forward = encoder_forward.__get__(
self.transformer.text_model.encoder
)
def text_encoder_forward(
self,
input_ids = None,
attention_mask = None,
position_ids = None,
output_attentions = None,
output_hidden_states = None,
return_dict = None,
embedding_manager = None,
input_ids=None,
attention_mask=None,
position_ids=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
embedding_manager=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict
if return_dict is not None
else self.config.use_return_dict
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is None:
raise ValueError("You have to specify either input_ids")
raise ValueError('You have to specify either input_ids')
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager)
hidden_states = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
embedding_manager=embedding_manager,
)
bsz, seq_len = input_shape
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
hidden_states.device
)
causal_attention_mask = _build_causal_attention_mask(
bsz, seq_len, hidden_states.dtype
).to(hidden_states.device)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
attention_mask = _expand_mask(
attention_mask, hidden_states.dtype
)
last_hidden_state = self.encoder(
inputs_embeds=hidden_states,
@@ -291,17 +403,19 @@ class FrozenCLIPEmbedder(AbstractEncoder):
return last_hidden_state
self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model)
self.transformer.text_model.forward = text_encoder_forward.__get__(
self.transformer.text_model
)
def transformer_forward(
self,
input_ids = None,
attention_mask = None,
position_ids = None,
output_attentions = None,
output_hidden_states = None,
return_dict = None,
embedding_manager = None,
input_ids=None,
attention_mask=None,
position_ids=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
embedding_manager=None,
):
return self.text_model(
input_ids=input_ids,
@@ -310,11 +424,12 @@ class FrozenCLIPEmbedder(AbstractEncoder):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
embedding_manager = embedding_manager
embedding_manager=embedding_manager,
)
self.transformer.forward = transformer_forward.__get__(self.transformer)
self.transformer.forward = transformer_forward.__get__(
self.transformer
)
def freeze(self):
self.transformer = self.transformer.eval()
@@ -322,9 +437,16 @@ class FrozenCLIPEmbedder(AbstractEncoder):
param.requires_grad = False
def forward(self, text, **kwargs):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt',
)
tokens = batch_encoding['input_ids'].to(self.device)
z = self.transformer(input_ids=tokens, **kwargs)
return z
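A hypothetical usage sketch, assuming the openai/clip-vit-large-patch14 weights are already in the local Hugging Face cache (the constructor passes local_files_only=True):
embedder = FrozenCLIPEmbedder(device='cpu')
z = embedder(['a photograph of an astronaut riding a horse'])
print(z.shape)   # expected torch.Size([1, 77, 768]) for ViT-L/14 text hidden states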
@@ -337,9 +459,17 @@ class FrozenCLIPTextEmbedder(nn.Module):
"""
Uses the CLIP transformer encoder for text.
"""
def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
def __init__(
self,
version='ViT-L/14',
device=choose_torch_device(),
max_length=77,
n_repeat=1,
normalize=True,
):
super().__init__()
self.model, _ = clip.load(version, jit=False, device="cpu")
self.model, _ = clip.load(version, jit=False, device=device)
self.device = device
self.max_length = max_length
self.n_repeat = n_repeat
@@ -359,7 +489,7 @@ class FrozenCLIPTextEmbedder(nn.Module):
def encode(self, text):
z = self(text)
if z.ndim==2:
if z.ndim == 2:
z = z[:, None, :]
z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
return z
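The repeat above simply copies the single embedding n_repeat times along a new axis; a standalone illustration of the einops pattern:
import torch
from einops import repeat

z = torch.randn(4, 1, 768)              # b 1 d
z3 = repeat(z, 'b 1 d -> b k d', k=3)   # torch.Size([4, 3, 768])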
@@ -367,29 +497,42 @@ class FrozenCLIPTextEmbedder(nn.Module):
class FrozenClipImageEmbedder(nn.Module):
"""
Uses the CLIP image encoder.
"""
Uses the CLIP image encoder.
"""
def __init__(
self,
model,
jit=False,
device='cuda' if torch.cuda.is_available() else 'cpu',
antialias=False,
):
self,
model,
jit=False,
device=choose_torch_device(),
antialias=False,
):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit)
self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
self.register_buffer(
'mean',
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False,
)
self.register_buffer(
'std',
torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
persistent=False,
)
def preprocess(self, x):
# normalize to [0,1]
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic',align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
x = kornia.geometry.resize(
x,
(224, 224),
interpolation='bicubic',
align_corners=True,
antialias=self.antialias,
)
x = (x + 1.0) / 2.0
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
@@ -399,7 +542,8 @@ class FrozenClipImageEmbedder(nn.Module):
return self.model.encode_image(self.preprocess(x))
if __name__ == "__main__":
if __name__ == '__main__':
from ldm.util import count_params
model = FrozenCLIPEmbedder()
count_params(model, verbose=True)


@@ -1,2 +1,6 @@
from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
from ldm.modules.image_degradation.bsrgan import (
degradation_bsrgan_variant as degradation_fn_bsr,
)
from ldm.modules.image_degradation.bsrgan_light import (
degradation_bsrgan_variant as degradation_fn_bsr_light,
)


@@ -27,16 +27,16 @@ import ldm.modules.image_degradation.utils_image as util
def modcrop_np(img, sf):
'''
"""
Args:
img: numpy image, WxH or WxHxC
sf: scale factor
Return:
cropped image
'''
"""
w, h = img.shape[:2]
im = np.copy(img)
return im[:w - w % sf, :h - h % sf, ...]
return im[: w - w % sf, : h - h % sf, ...]
"""
@@ -54,7 +54,9 @@ def analytic_kernel(k):
# Loop over the small kernel to fill the big one
for r in range(k_size):
for c in range(k_size):
big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += (
k[r, c] * k
)
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop = k_size // 2
cropped_big_k = big_k[crop:-crop, crop:-crop]
@@ -63,7 +65,7 @@ def analytic_kernel(k):
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
""" generate an anisotropic Gaussian kernel
"""generate an anisotropic Gaussian kernel
Args:
ksize : e.g., 15, kernel size
theta : [0, pi], rotation angle range
@@ -74,7 +76,12 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
k : kernel
"""
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
v = np.dot(
np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
),
np.array([1.0, 0.0]),
)
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
D = np.array([[l1, 0], [0, l2]])
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
@@ -126,24 +133,32 @@ def shift_pixel(x, sf, upper_left=True):
def blur(x, k):
'''
"""
x: image, NxcxHxW
k: kernel, Nx1xhxw
'''
"""
n, c = x.shape[:2]
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
k = k.repeat(1, c, 1, 1)
k = k.view(-1, 1, k.shape[2], k.shape[3])
x = x.view(1, -1, x.shape[2], x.shape[3])
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
x = torch.nn.functional.conv2d(
x, k, bias=None, stride=1, padding=0, groups=n * c
)
x = x.view(n, c, x.shape[2], x.shape[3])
return x
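A quick sketch of blur() in isolation; shapes follow the docstring (x is NxCxHxW, k is Nx1xhxw), and each image is convolved with its own kernel via a grouped conv2d:
import torch

x = torch.rand(2, 3, 64, 64)
k = torch.ones(2, 1, 5, 5) / 25.0   # a 5x5 box kernel per image
y = blur(x, k)
print(y.shape)                      # torch.Size([2, 3, 64, 64])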
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
""""
def gen_kernel(
k_size=np.array([15, 15]),
scale_factor=np.array([4, 4]),
min_var=0.6,
max_var=10.0,
noise_level=0,
):
""" "
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
# Kai Zhang
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
@@ -157,13 +172,16 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var
# Set COV matrix using Lambdas and Theta
LAMBDA = np.diag([lambda_1, lambda_2])
Q = np.array([[np.cos(theta), -np.sin(theta)],
[np.sin(theta), np.cos(theta)]])
Q = np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)
SIGMA = Q @ LAMBDA @ Q.T
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
# Set expectation position (shifting kernel for aligned image)
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
MU = k_size // 2 - 0.5 * (
scale_factor - 1
) # - 0.5 * (scale_factor - k_size % 2)
MU = MU[None, None, :, None]
# Create meshgrid for Gaussian
@@ -188,7 +206,9 @@ def fspecial_gaussian(hsize, sigma):
hsize = [hsize, hsize]
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
std = sigma
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
[x, y] = np.meshgrid(
np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)
)
arg = -(x * x + y * y) / (2 * std * std)
h = np.exp(arg)
h[h < scipy.finfo(float).eps * h.max()] = 0
@@ -208,10 +228,10 @@ def fspecial_laplacian(alpha):
def fspecial(filter_type, *args, **kwargs):
'''
"""
python code from:
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
'''
"""
if filter_type == 'gaussian':
return fspecial_gaussian(*args, **kwargs)
if filter_type == 'laplacian':
@@ -226,19 +246,19 @@ def fspecial(filter_type, *args, **kwargs):
def bicubic_degradation(x, sf=3):
'''
"""
Args:
x: HxWxC image, [0, 1]
sf: down-scale factor
Return:
bicubicly downsampled LR image
'''
"""
x = util.imresize_np(x, scale=1 / sf)
return x
def srmd_degradation(x, k, sf=3):
''' blur + bicubic downsampling
"""blur + bicubic downsampling
Args:
x: HxWxC image, [0, 1]
k: hxw, double
@@ -253,14 +273,16 @@ def srmd_degradation(x, k, sf=3):
pages={3262--3271},
year={2018}
}
'''
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
"""
x = ndimage.filters.convolve(
x, np.expand_dims(k, axis=2), mode='wrap'
) # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf)
return x
def dpsr_degradation(x, k, sf=3):
''' bicubic downsampling + blur
"""bicubic downsampling + blur
Args:
x: HxWxC image, [0, 1]
k: hxw, double
@@ -275,21 +297,21 @@ def dpsr_degradation(x, k, sf=3):
pages={1671--1681},
year={2019}
}
'''
"""
x = bicubic_degradation(x, sf=sf)
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
return x
def classical_degradation(x, k, sf=3):
''' blur + downsampling
"""blur + downsampling
Args:
x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
'''
"""
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st = 0
@@ -328,10 +350,19 @@ def add_blur(img, sf=4):
if random.random() < 0.5:
l1 = wd2 * random.random()
l2 = wd2 * random.random()
k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
k = anisotropic_Gaussian(
ksize=2 * random.randint(2, 11) + 3,
theta=random.random() * np.pi,
l1=l1,
l2=l2,
)
else:
k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
k = fspecial(
'gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()
)
img = ndimage.filters.convolve(
img, np.expand_dims(k, axis=2), mode='mirror'
)
return img
@@ -344,7 +375,11 @@ def add_resize(img, sf=4):
sf1 = random.uniform(0.5 / sf, 1)
else:
sf1 = 1.0
img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0)
return img
@@ -366,19 +401,26 @@ def add_resize(img, sf=4):
# img = np.clip(img, 0.0, 1.0)
# return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
rnum = np.random.rand()
if rnum > 0.6: # add color Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(
np.float32
)
elif rnum < 0.4: # add grayscale Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
img = img + np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else: # add noise
L = noise_level2 / 255.
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
img = img + np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
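The noise helpers all expect float32 HxWxC images in [0, 1] and clip back to that range; a minimal sketch:
import numpy as np

clean = np.random.rand(64, 64, 3).astype(np.float32)
noisy = add_Gaussian_noise(clean, noise_level1=5, noise_level2=25)
print(float(noisy.min()) >= 0.0, float(noisy.max()) <= 1.0)   # True True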
@@ -388,28 +430,37 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
img = np.clip(img, 0.0, 1.0)
rnum = random.random()
if rnum > 0.6:
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
img += img * np.random.normal(
0, noise_level / 255.0, img.shape
).astype(np.float32)
elif rnum < 0.4:
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
img += img * np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else:
L = noise_level2 / 255.
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
img += img * np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
def add_Poisson_noise(img):
img = np.clip((img * 255.0).round(), 0, 255) / 255.
img = np.clip((img * 255.0).round(), 0, 255) / 255.0
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
if random.random() < 0.5:
img = np.random.poisson(img * vals).astype(np.float32) / vals
else:
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
noise_gray = (
np.random.poisson(img_gray * vals).astype(np.float32) / vals
- img_gray
)
img += noise_gray[:, :, np.newaxis]
img = np.clip(img, 0.0, 1.0)
return img
@@ -418,7 +469,9 @@ def add_Poisson_noise(img):
def add_JPEG_noise(img):
quality_factor = random.randint(30, 95)
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
result, encimg = cv2.imencode(
'.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]
)
img = cv2.imdecode(encimg, 1)
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
return img
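Sketch of the JPEG round-trip above: encode to an in-memory .jpg at a random quality factor, decode it again, and hand back the artefact-laden float image:
import numpy as np

img = np.random.rand(128, 128, 3).astype(np.float32)   # RGB in [0, 1]
degraded = add_JPEG_noise(img)
print(degraded.shape, degraded.dtype)                   # (128, 128, 3) float32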
@@ -428,10 +481,14 @@ def random_crop(lq, hq, sf=4, lq_patchsize=64):
h, w = lq.shape[:2]
rnd_h = random.randint(0, h - lq_patchsize)
rnd_w = random.randint(0, w - lq_patchsize)
lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
hq = hq[
rnd_h_H : rnd_h_H + lq_patchsize * sf,
rnd_w_H : rnd_w_H + lq_patchsize * sf,
:,
]
return lq, hq
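random_crop keeps the LQ/HQ pair aligned: the low-res crop is lq_patchsize square and the high-res crop is lq_patchsize * sf square at the matching offset. A short sketch:
import numpy as np

hq = np.random.rand(256, 256, 3).astype(np.float32)
lq = hq[::4, ::4, :]                                # stand-in 4x downsample
lq_c, hq_c = random_crop(lq, hq, sf=4, lq_patchsize=16)
print(lq_c.shape, hq_c.shape)                       # (16, 16, 3) (64, 64, 3)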
@@ -452,7 +509,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
sf_ori = sf
h1, w1 = img.shape[:2]
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf:
@@ -462,8 +519,11 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
img = util.imresize_np(img, 1 / 2, True)
img = np.clip(img, 0.0, 1.0)
@@ -472,7 +532,10 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order:
@@ -487,19 +550,30 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
# downsample2
if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf)
img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
k_shifted = (
k_shifted / k_shifted.sum()
) # blur with shifted kernel
img = ndimage.filters.convolve(
img, np.expand_dims(k_shifted, axis=2), mode='mirror'
)
img = img[0::sf, 0::sf, ...] # nearest downsampling
img = np.clip(img, 0.0, 1.0)
elif i == 3:
# downsample3
img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0)
elif i == 4:
@@ -544,15 +618,18 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
sf_ori = sf
h1, w1 = image.shape[:2]
image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = image.shape[:2]
hq = image.copy()
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]))
image = cv2.resize(
image,
(int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
image = util.imresize_np(image, 1 / 2, True)
image = np.clip(image, 0.0, 1.0)
@@ -561,7 +638,10 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order:
@@ -576,19 +656,33 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# downsample2
if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf)
image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
interpolation=random.choice([1, 2, 3]))
image = cv2.resize(
image,
(
int(1 / sf1 * image.shape[1]),
int(1 / sf1 * image.shape[0]),
),
interpolation=random.choice([1, 2, 3]),
)
else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
k_shifted = (
k_shifted / k_shifted.sum()
) # blur with shifted kernel
image = ndimage.filters.convolve(
image, np.expand_dims(k_shifted, axis=2), mode='mirror'
)
image = image[0::sf, 0::sf, ...] # nearest downsampling
image = np.clip(image, 0.0, 1.0)
elif i == 3:
# downsample3
image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
image = cv2.resize(
image,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
image = np.clip(image, 0.0, 1.0)
elif i == 4:
@@ -609,12 +703,19 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# add final JPEG compression noise
image = add_JPEG_noise(image)
image = util.single2uint(image)
example = {"image":image}
example = {'image': image}
return example
# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
def degradation_bsrgan_plus(
img,
sf=4,
shuffle_prob=0.5,
use_sharp=True,
lq_patchsize=64,
isp_model=None,
):
"""
This is an extended degradation model by combining
the degradation models of BSRGAN and Real-ESRGAN
@@ -630,7 +731,7 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc
"""
h1, w1 = img.shape[:2]
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf:
@@ -645,8 +746,12 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc
else:
shuffle_order = list(range(13))
# local shuffle for noise, JPEG is always the last one
shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
shuffle_order[2:6] = random.sample(
shuffle_order[2:6], len(range(2, 6))
)
shuffle_order[9:13] = random.sample(
shuffle_order[9:13], len(range(9, 13))
)
poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
@@ -689,8 +794,11 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc
print('check the shuffle!')
# resize to desired size
img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
# add final JPEG compression noise
img = add_JPEG_noise(img)
@@ -702,29 +810,37 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc
if __name__ == '__main__':
print("hey")
img = util.imread_uint('utils/test.png', 3)
print(img)
img = util.uint2single(img)
print(img)
img = img[:448, :448]
h = img.shape[0] // 4
print("resizing to", h)
sf = 4
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
for i in range(20):
print(i)
img_lq = deg_fn(img)
print(img_lq)
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape)
print(img_hq.shape)
lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0)
lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0)
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
util.imsave(img_concat, str(i) + '.png')
print('hey')
img = util.imread_uint('utils/test.png', 3)
print(img)
img = util.uint2single(img)
print(img)
img = img[:448, :448]
h = img.shape[0] // 4
print('resizing to', h)
sf = 4
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
for i in range(20):
print(i)
img_lq = deg_fn(img)
print(img_lq)
img_lq_bicubic = albumentations.SmallestMaxSize(
max_size=h, interpolation=cv2.INTER_CUBIC
)(image=img)['image']
print(img_lq.shape)
print('bicubic', img_lq_bicubic.shape)
print(img_hq.shape)
lq_nearest = cv2.resize(
util.single2uint(img_lq),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
lq_bicubic_nearest = cv2.resize(
util.single2uint(img_lq_bicubic),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
img_concat = np.concatenate(
[lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1
)
util.imsave(img_concat, str(i) + '.png')


@@ -27,16 +27,16 @@ import ldm.modules.image_degradation.utils_image as util
def modcrop_np(img, sf):
'''
"""
Args:
img: numpy image, WxH or WxHxC
sf: scale factor
Return:
cropped image
'''
"""
w, h = img.shape[:2]
im = np.copy(img)
return im[:w - w % sf, :h - h % sf, ...]
return im[: w - w % sf, : h - h % sf, ...]
"""
@@ -54,7 +54,9 @@ def analytic_kernel(k):
# Loop over the small kernel to fill the big one
for r in range(k_size):
for c in range(k_size):
big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += (
k[r, c] * k
)
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop = k_size // 2
cropped_big_k = big_k[crop:-crop, crop:-crop]
@@ -63,7 +65,7 @@ def analytic_kernel(k):
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
""" generate an anisotropic Gaussian kernel
"""generate an anisotropic Gaussian kernel
Args:
ksize : e.g., 15, kernel size
theta : [0, pi], rotation angle range
@@ -74,7 +76,12 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
k : kernel
"""
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
v = np.dot(
np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
),
np.array([1.0, 0.0]),
)
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
D = np.array([[l1, 0], [0, l2]])
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
@@ -126,24 +133,32 @@ def shift_pixel(x, sf, upper_left=True):
def blur(x, k):
'''
"""
x: image, NxcxHxW
k: kernel, Nx1xhxw
'''
"""
n, c = x.shape[:2]
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
k = k.repeat(1, c, 1, 1)
k = k.view(-1, 1, k.shape[2], k.shape[3])
x = x.view(1, -1, x.shape[2], x.shape[3])
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
x = torch.nn.functional.conv2d(
x, k, bias=None, stride=1, padding=0, groups=n * c
)
x = x.view(n, c, x.shape[2], x.shape[3])
return x
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
""""
def gen_kernel(
k_size=np.array([15, 15]),
scale_factor=np.array([4, 4]),
min_var=0.6,
max_var=10.0,
noise_level=0,
):
""" "
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
# Kai Zhang
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
@@ -157,13 +172,16 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var
# Set COV matrix using Lambdas and Theta
LAMBDA = np.diag([lambda_1, lambda_2])
Q = np.array([[np.cos(theta), -np.sin(theta)],
[np.sin(theta), np.cos(theta)]])
Q = np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)
SIGMA = Q @ LAMBDA @ Q.T
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
# Set expectation position (shifting kernel for aligned image)
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
MU = k_size // 2 - 0.5 * (
scale_factor - 1
) # - 0.5 * (scale_factor - k_size % 2)
MU = MU[None, None, :, None]
# Create meshgrid for Gaussian
@@ -188,7 +206,9 @@ def fspecial_gaussian(hsize, sigma):
hsize = [hsize, hsize]
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
std = sigma
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
[x, y] = np.meshgrid(
np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)
)
arg = -(x * x + y * y) / (2 * std * std)
h = np.exp(arg)
h[h < scipy.finfo(float).eps * h.max()] = 0
@@ -208,10 +228,10 @@ def fspecial_laplacian(alpha):
def fspecial(filter_type, *args, **kwargs):
'''
"""
python code from:
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
'''
"""
if filter_type == 'gaussian':
return fspecial_gaussian(*args, **kwargs)
if filter_type == 'laplacian':
@@ -226,19 +246,19 @@ def fspecial(filter_type, *args, **kwargs):
def bicubic_degradation(x, sf=3):
'''
"""
Args:
x: HxWxC image, [0, 1]
sf: down-scale factor
Return:
bicubicly downsampled LR image
'''
"""
x = util.imresize_np(x, scale=1 / sf)
return x
def srmd_degradation(x, k, sf=3):
''' blur + bicubic downsampling
"""blur + bicubic downsampling
Args:
x: HxWxC image, [0, 1]
k: hxw, double
@@ -253,14 +273,16 @@ def srmd_degradation(x, k, sf=3):
pages={3262--3271},
year={2018}
}
'''
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
"""
x = ndimage.filters.convolve(
x, np.expand_dims(k, axis=2), mode='wrap'
) # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf)
return x
def dpsr_degradation(x, k, sf=3):
''' bicubic downsampling + blur
"""bicubic downsampling + blur
Args:
x: HxWxC image, [0, 1]
k: hxw, double
@@ -275,21 +297,21 @@ def dpsr_degradation(x, k, sf=3):
pages={1671--1681},
year={2019}
}
'''
"""
x = bicubic_degradation(x, sf=sf)
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
return x
def classical_degradation(x, k, sf=3):
''' blur + downsampling
"""blur + downsampling
Args:
x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
'''
"""
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st = 0
@@ -326,16 +348,25 @@ def add_blur(img, sf=4):
wd2 = 4.0 + sf
wd = 2.0 + 0.2 * sf
wd2 = wd2/4
wd = wd/4
wd2 = wd2 / 4
wd = wd / 4
if random.random() < 0.5:
l1 = wd2 * random.random()
l2 = wd2 * random.random()
k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
k = anisotropic_Gaussian(
ksize=random.randint(2, 11) + 3,
theta=random.random() * np.pi,
l1=l1,
l2=l2,
)
else:
k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
k = fspecial(
'gaussian', random.randint(2, 4) + 3, wd * random.random()
)
img = ndimage.filters.convolve(
img, np.expand_dims(k, axis=2), mode='mirror'
)
return img
@@ -348,7 +379,11 @@ def add_resize(img, sf=4):
sf1 = random.uniform(0.5 / sf, 1)
else:
sf1 = 1.0
img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0)
return img
@@ -370,19 +405,26 @@ def add_resize(img, sf=4):
# img = np.clip(img, 0.0, 1.0)
# return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
rnum = np.random.rand()
if rnum > 0.6: # add color Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(
np.float32
)
elif rnum < 0.4: # add grayscale Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
img = img + np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else: # add noise
L = noise_level2 / 255.
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
img = img + np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
@@ -392,28 +434,37 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
img = np.clip(img, 0.0, 1.0)
rnum = random.random()
if rnum > 0.6:
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
img += img * np.random.normal(
0, noise_level / 255.0, img.shape
).astype(np.float32)
elif rnum < 0.4:
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
img += img * np.random.normal(
0, noise_level / 255.0, (*img.shape[:2], 1)
).astype(np.float32)
else:
L = noise_level2 / 255.
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
img += img * np.random.multivariate_normal(
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
def add_Poisson_noise(img):
img = np.clip((img * 255.0).round(), 0, 255) / 255.
img = np.clip((img * 255.0).round(), 0, 255) / 255.0
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
if random.random() < 0.5:
img = np.random.poisson(img * vals).astype(np.float32) / vals
else:
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
noise_gray = (
np.random.poisson(img_gray * vals).astype(np.float32) / vals
- img_gray
)
img += noise_gray[:, :, np.newaxis]
img = np.clip(img, 0.0, 1.0)
return img
@@ -422,7 +473,9 @@ def add_Poisson_noise(img):
def add_JPEG_noise(img):
quality_factor = random.randint(80, 95)
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
result, encimg = cv2.imencode(
'.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]
)
img = cv2.imdecode(encimg, 1)
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
return img
@@ -432,10 +485,14 @@ def random_crop(lq, hq, sf=4, lq_patchsize=64):
h, w = lq.shape[:2]
rnd_h = random.randint(0, h - lq_patchsize)
rnd_w = random.randint(0, w - lq_patchsize)
lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
hq = hq[
rnd_h_H : rnd_h_H + lq_patchsize * sf,
rnd_w_H : rnd_w_H + lq_patchsize * sf,
:,
]
return lq, hq
@@ -456,7 +513,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
sf_ori = sf
h1, w1 = img.shape[:2]
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf:
@@ -466,8 +523,11 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
img = util.imresize_np(img, 1 / 2, True)
img = np.clip(img, 0.0, 1.0)
@@ -476,7 +536,10 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order:
@@ -491,19 +554,30 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
# downsample2
if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf)
img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
k_shifted = (
k_shifted / k_shifted.sum()
) # blur with shifted kernel
img = ndimage.filters.convolve(
img, np.expand_dims(k_shifted, axis=2), mode='mirror'
)
img = img[0::sf, 0::sf, ...] # nearest downsampling
img = np.clip(img, 0.0, 1.0)
elif i == 3:
# downsample3
img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
img = cv2.resize(
img,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0)
elif i == 4:
@@ -548,15 +622,18 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
sf_ori = sf
h1, w1 = image.shape[:2]
image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = image.shape[:2]
hq = image.copy()
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]))
image = cv2.resize(
image,
(int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
image = util.imresize_np(image, 1 / 2, True)
image = np.clip(image, 0.0, 1.0)
@@ -565,7 +642,10 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order:
@@ -583,20 +663,34 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# downsample2
if random.random() < 0.8:
sf1 = random.uniform(1, 2 * sf)
image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
interpolation=random.choice([1, 2, 3]))
image = cv2.resize(
image,
(
int(1 / sf1 * image.shape[1]),
int(1 / sf1 * image.shape[0]),
),
interpolation=random.choice([1, 2, 3]),
)
else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
k_shifted = (
k_shifted / k_shifted.sum()
) # blur with shifted kernel
image = ndimage.filters.convolve(
image, np.expand_dims(k_shifted, axis=2), mode='mirror'
)
image = image[0::sf, 0::sf, ...] # nearest downsampling
image = np.clip(image, 0.0, 1.0)
elif i == 3:
# downsample3
image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
image = cv2.resize(
image,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
image = np.clip(image, 0.0, 1.0)
elif i == 4:
@@ -617,34 +711,41 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# add final JPEG compression noise
image = add_JPEG_noise(image)
image = util.single2uint(image)
example = {"image": image}
example = {'image': image}
return example
if __name__ == '__main__':
print("hey")
print('hey')
img = util.imread_uint('utils/test.png', 3)
img = img[:448, :448]
h = img.shape[0] // 4
print("resizing to", h)
print('resizing to', h)
sf = 4
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
for i in range(20):
print(i)
img_hq = img
img_lq = deg_fn(img)["image"]
img_lq = deg_fn(img)['image']
img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
print(img_lq)
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
img_lq_bicubic = albumentations.SmallestMaxSize(
max_size=h, interpolation=cv2.INTER_CUBIC
)(image=img_hq)['image']
print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape)
print('bicubic', img_lq_bicubic.shape)
print(img_hq.shape)
lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0)
lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0)
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
lq_nearest = cv2.resize(
util.single2uint(img_lq),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
lq_bicubic_nearest = cv2.resize(
util.single2uint(img_lq_bicubic),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
img_concat = np.concatenate(
[lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1
)
util.imsave(img_concat, str(i) + '.png')


@@ -6,13 +6,14 @@ import torch
import cv2
from torchvision.utils import make_grid
from datetime import datetime
#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
'''
"""
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
@@ -20,10 +21,22 @@ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
# https://github.com/twhui/SRGAN-pyTorch
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''
"""
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
IMG_EXTENSIONS = [
'.jpg',
'.JPG',
'.jpeg',
'.JPEG',
'.png',
'.PNG',
'.ppm',
'.PPM',
'.bmp',
'.BMP',
'.tif',
]
def is_image_file(filename):
@@ -49,19 +62,19 @@ def surf(Z, cmap='rainbow', figsize=None):
ax3 = plt.axes(projection='3d')
w, h = Z.shape[:2]
xx = np.arange(0,w,1)
yy = np.arange(0,h,1)
xx = np.arange(0, w, 1)
yy = np.arange(0, h, 1)
X, Y = np.meshgrid(xx, yy)
ax3.plot_surface(X,Y,Z,cmap=cmap)
#ax3.contour(X,Y,Z, zdim='z',offset=-2cmap=cmap)
ax3.plot_surface(X, Y, Z, cmap=cmap)
# ax3.contour(X,Y,Z, zdim='z',offset=-2cmap=cmap)
plt.show()
'''
"""
# --------------------------------------------
# get image paths
# --------------------------------------------
'''
"""
def get_image_paths(dataroot):
@@ -83,26 +96,26 @@ def _get_paths_from_images(path):
return images
'''
"""
# --------------------------------------------
# split large images into small images
# --------------------------------------------
'''
"""
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
w, h = img.shape[:2]
patches = []
if w > p_max and h > p_max:
w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))
h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))
w1.append(w-p_size)
h1.append(h-p_size)
# print(w1)
# print(h1)
w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=np.int))
h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=np.int))
w1.append(w - p_size)
h1.append(h - p_size)
# print(w1)
# print(h1)
for i in w1:
for j in h1:
patches.append(img[i:i+p_size, j:j+p_size,:])
patches.append(img[i : i + p_size, j : j + p_size, :])
else:
patches.append(img)
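A sketch of patches_from_image: images larger than p_max on both sides are tiled into overlapping p_size squares, smaller ones are returned whole. Note that the np.int usage in this module requires NumPy older than 1.24.
import numpy as np

big = np.zeros((1200, 900, 3), dtype=np.uint8)
tiles = patches_from_image(big, p_size=512, p_overlap=64, p_max=800)
print(len(tiles), tiles[0].shape)   # 6 (512, 512, 3)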
@@ -118,11 +131,21 @@ def imssave(imgs, img_path):
for i, img in enumerate(imgs):
if img.ndim == 3:
img = img[:, :, [2, 1, 0]]
new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
new_path = os.path.join(
os.path.dirname(img_path),
img_name + str('_s{:04d}'.format(i)) + '.png',
)
cv2.imwrite(new_path, img)
def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
def split_imageset(
original_dataroot,
taget_dataroot,
n_channels=3,
p_size=800,
p_overlap=96,
p_max=1000,
):
"""
split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
@@ -139,15 +162,18 @@ def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800,
# img_name, ext = os.path.splitext(os.path.basename(img_path))
img = imread_uint(img_path, n_channels=n_channels)
patches = patches_from_image(img, p_size, p_overlap, p_max)
imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
#if original_dataroot == taget_dataroot:
#del img_path
imssave(
patches, os.path.join(taget_dataroot, os.path.basename(img_path))
)
# if original_dataroot == taget_dataroot:
# del img_path
'''
"""
# --------------------------------------------
# makedir
# --------------------------------------------
'''
"""
def mkdir(path):
@@ -171,12 +197,12 @@ def mkdir_and_rename(path):
os.makedirs(path)
'''
"""
# --------------------------------------------
# read image from path
# opencv is fast, but read BGR numpy image
# --------------------------------------------
'''
"""
# --------------------------------------------
@@ -206,6 +232,7 @@ def imsave(img, img_path):
img = img[:, :, [2, 1, 0]]
cv2.imwrite(img_path, img)
def imwrite(img, img_path):
img = np.squeeze(img)
if img.ndim == 3:
@@ -213,7 +240,6 @@ def imwrite(img, img_path):
cv2.imwrite(img_path, img)
# --------------------------------------------
# get single image of size HxWxn_channles (BGR)
# --------------------------------------------
@@ -221,7 +247,7 @@ def read_img(path):
# read image by cv2
# return: Numpy float32, HWC, BGR, [0,1]
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
img = img.astype(np.float32) / 255.
img = img.astype(np.float32) / 255.0
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
# some images have 4 channels
@@ -230,7 +256,7 @@ def read_img(path):
return img
'''
"""
# --------------------------------------------
# image format conversion
# --------------------------------------------
@@ -238,7 +264,7 @@ def read_img(path):
# numpy(single) <---> tensor
# numpy(unit) <---> tensor
# --------------------------------------------
'''
"""
# --------------------------------------------
@@ -248,22 +274,22 @@ def read_img(path):
def uint2single(img):
return np.float32(img/255.)
return np.float32(img / 255.0)
def single2uint(img):
return np.uint8((img.clip(0, 1)*255.).round())
return np.uint8((img.clip(0, 1) * 255.0).round())
def uint162single(img):
return np.float32(img/65535.)
return np.float32(img / 65535.0)
def single2uint16(img):
return np.uint16((img.clip(0, 1)*65535.).round())
return np.uint16((img.clip(0, 1) * 65535.0).round())
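These conversion helpers are straightforward inverses (uint8 [0, 255] <-> float32 [0, 1], plus the 16-bit equivalents); for example:
import numpy as np

u = np.array([[0, 128, 255]], dtype=np.uint8)
s = uint2single(u)       # array([[0., 0.50196078, 1.]], dtype=float32)
back = single2uint(s)    # array([[  0, 128, 255]], dtype=uint8)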
# --------------------------------------------
@@ -275,14 +301,25 @@ def single2uint16(img):
def uint2tensor4(img):
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1)
.float()
.div(255.0)
.unsqueeze(0)
)
# convert uint to 3-dimensional torch tensor
def uint2tensor3(img):
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1)
.float()
.div(255.0)
)
# convert 2/3/4-dimensional torch tensor to uint
@@ -290,7 +327,7 @@ def tensor2uint(img):
img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
if img.ndim == 3:
img = np.transpose(img, (1, 2, 0))
return np.uint8((img*255.0).round())
return np.uint8((img * 255.0).round())
# --------------------------------------------
@@ -305,7 +342,12 @@ def single2tensor3(img):
# convert single (HxWxC) to 4-dimensional torch tensor
def single2tensor4(img):
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1)
.float()
.unsqueeze(0)
)
# convert torch tensor to single
@@ -316,6 +358,7 @@ def tensor2single(img):
return img
# convert torch tensor to single
def tensor2single3(img):
img = img.data.squeeze().float().cpu().numpy()
@@ -327,30 +370,48 @@ def tensor2single3(img):
def single2tensor5(img):
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
return (
torch.from_numpy(np.ascontiguousarray(img))
.permute(2, 0, 1, 3)
.float()
.unsqueeze(0)
)
def single32tensor5(img):
return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
return (
torch.from_numpy(np.ascontiguousarray(img))
.float()
.unsqueeze(0)
.unsqueeze(0)
)
def single42tensor4(img):
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
return (
torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
)
# from skimage.io import imread, imsave
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
'''
"""
Converts a torch Tensor into an image Numpy array of BGR channel order
Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
'''
tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
"""
tensor = (
tensor.squeeze().float().cpu().clamp_(*min_max)
) # squeeze first, then clamp
tensor = (tensor - min_max[0]) / (
min_max[1] - min_max[0]
) # to range [0,1]
n_dim = tensor.dim()
if n_dim == 4:
n_img = len(tensor)
img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
img_np = make_grid(
tensor, nrow=int(math.sqrt(n_img)), normalize=False
).numpy()
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
elif n_dim == 3:
img_np = tensor.numpy()
@@ -359,14 +420,17 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
img_np = tensor.numpy()
else:
raise TypeError(
'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(
n_dim
)
)
if out_type == np.uint8:
img_np = (img_np * 255.0).round()
# Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
return img_np.astype(out_type)
'''
"""
# --------------------------------------------
# Augmentation, flipe and/or rotate
# --------------------------------------------
@@ -374,12 +438,11 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
# (1) augmet_img: numpy image of WxHxC or WxH
# (2) augment_img_tensor4: tensor image 1xCxWxH
# --------------------------------------------
'''
"""
def augment_img(img, mode=0):
'''Kai Zhang (github: https://github.com/cszn)
'''
"""Kai Zhang (github: https://github.com/cszn)"""
if mode == 0:
return img
elif mode == 1:
@@ -399,8 +462,7 @@ def augment_img(img, mode=0):
def augment_img_tensor4(img, mode=0):
'''Kai Zhang (github: https://github.com/cszn)
'''
"""Kai Zhang (github: https://github.com/cszn)"""
if mode == 0:
return img
elif mode == 1:
@@ -420,8 +482,7 @@ def augment_img_tensor4(img, mode=0):
def augment_img_tensor(img, mode=0):
'''Kai Zhang (github: https://github.com/cszn)
'''
"""Kai Zhang (github: https://github.com/cszn)"""
img_size = img.size()
img_np = img.data.cpu().numpy()
if len(img_size) == 3:
@@ -484,11 +545,11 @@ def augment_imgs(img_list, hflip=True, rot=True):
return [_augment(img) for img in img_list]
'''
"""
# --------------------------------------------
# modcrop and shave
# --------------------------------------------
'''
"""
def modcrop(img_in, scale):
@@ -497,11 +558,11 @@ def modcrop(img_in, scale):
if img.ndim == 2:
H, W = img.shape
H_r, W_r = H % scale, W % scale
img = img[:H - H_r, :W - W_r]
img = img[: H - H_r, : W - W_r]
elif img.ndim == 3:
H, W, C = img.shape
H_r, W_r = H % scale, W % scale
img = img[:H - H_r, :W - W_r, :]
img = img[: H - H_r, : W - W_r, :]
else:
raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
return img
@@ -511,11 +572,11 @@ def shave(img_in, border=0):
# img_in: Numpy, HWC or HW
img = np.copy(img_in)
h, w = img.shape[:2]
img = img[border:h-border, border:w-border]
img = img[border : h - border, border : w - border]
return img
'''
"""
# --------------------------------------------
# image processing process on numpy image
# channel_convert(in_c, tar_type, img_list):
@@ -523,74 +584,92 @@ def shave(img_in, border=0):
# bgr2ycbcr(img, only_y=True):
# ycbcr2rgb(img):
# --------------------------------------------
'''
"""
def rgb2ycbcr(img, only_y=True):
'''same as matlab rgb2ycbcr
"""same as matlab rgb2ycbcr
only_y: only return Y channel
Input:
uint8, [0, 255]
float, [0, 1]
'''
"""
in_img_type = img.dtype
img.astype(np.float32)
if in_img_type != np.uint8:
img *= 255.
img *= 255.0
# convert
if only_y:
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
else:
rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
[24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
rlt = np.matmul(
img,
[
[65.481, -37.797, 112.0],
[128.553, -74.203, -93.786],
[24.966, 112.0, -18.214],
],
) / 255.0 + [16, 128, 128]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.
rlt /= 255.0
return rlt.astype(in_img_type)
def ycbcr2rgb(img):
'''same as matlab ycbcr2rgb
"""same as matlab ycbcr2rgb
Input:
uint8, [0, 255]
float, [0, 1]
'''
"""
in_img_type = img.dtype
img.astype(np.float32)
if in_img_type != np.uint8:
img *= 255.
img *= 255.0
# convert
rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
rlt = np.matmul(
img,
[
[0.00456621, 0.00456621, 0.00456621],
[0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0],
],
) * 255.0 + [-222.921, 135.576, -276.836]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.
rlt /= 255.0
return rlt.astype(in_img_type)
def bgr2ycbcr(img, only_y=True):
'''bgr version of rgb2ycbcr
"""bgr version of rgb2ycbcr
only_y: only return Y channel
Input:
uint8, [0, 255]
float, [0, 1]
'''
"""
in_img_type = img.dtype
img.astype(np.float32)
if in_img_type != np.uint8:
img *= 255.
img *= 255.0
# convert
if only_y:
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
else:
rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
[65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
rlt = np.matmul(
img,
[
[24.966, 112.0, -18.214],
[128.553, -74.203, -93.786],
[65.481, -37.797, 112.0],
],
) / 255.0 + [16, 128, 128]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.
rlt /= 255.0
return rlt.astype(in_img_type)
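As a usage note for the color-space helpers above: they preserve the caller's dtype and range (uint8 stays in [0, 255], float stays in [0, 1]), but they scale float inputs in place, so pass a copy if the input is reused. A minimal sketch, again assuming the diffed file is importable as utils_image (placeholder name).

import numpy as np
import utils_image as util   # placeholder module name for the diffed file

bgr_u8 = (np.random.rand(32, 32, 3) * 255).astype(np.uint8)  # OpenCV-style BGR, uint8
y_u8 = util.bgr2ycbcr(bgr_u8, only_y=True)                   # Y channel, uint8 in [16, 235]
print(y_u8.dtype, y_u8.min(), y_u8.max())

rgb_f = np.random.rand(32, 32, 3).astype(np.float32)         # RGB, float32 in [0, 1]
ycbcr = util.rgb2ycbcr(rgb_f.copy(), only_y=False)           # pass a copy: float inputs are scaled in place
rgb_back = util.ycbcr2rgb(ycbcr.copy())                      # approximate inverse
print(np.abs(rgb_back - rgb_f).max())                        # tiny round-trip error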
@@ -608,11 +687,11 @@ def channel_convert(in_c, tar_type, img_list):
return img_list
'''
"""
# --------------------------------------------
# metric, PSNR and SSIM
# --------------------------------------------
'''
"""
# --------------------------------------------
@@ -620,17 +699,17 @@ def channel_convert(in_c, tar_type, img_list):
# --------------------------------------------
def calculate_psnr(img1, img2, border=0):
# img1 and img2 have range [0, 255]
#img1 = img1.squeeze()
#img2 = img2.squeeze()
# img1 = img1.squeeze()
# img2 = img2.squeeze()
if not img1.shape == img2.shape:
raise ValueError('Input images must have the same dimensions.')
h, w = img1.shape[:2]
img1 = img1[border:h-border, border:w-border]
img2 = img2[border:h-border, border:w-border]
img1 = img1[border : h - border, border : w - border]
img2 = img2[border : h - border, border : w - border]
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
mse = np.mean((img1 - img2)**2)
mse = np.mean((img1 - img2) ** 2)
if mse == 0:
return float('inf')
return 20 * math.log10(255.0 / math.sqrt(mse))
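A worked number for the PSNR helper above: identical images return inf, and a mean squared error of exactly 1.0 gives 20*log10(255), roughly 48.13 dB. Minimal sketch, with the same placeholder module name as above.

import numpy as np
import utils_image as util   # placeholder module name for the diffed file

a = np.zeros((16, 16), dtype=np.uint8)
b = a.copy()
b[0, 0] = 16                          # one pixel off by 16 -> MSE = 16**2 / 256 = 1.0

print(util.calculate_psnr(a, a))      # inf
print(util.calculate_psnr(a, b))      # ~48.13 dB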
@@ -640,17 +719,17 @@ def calculate_psnr(img1, img2, border=0):
# SSIM
# --------------------------------------------
def calculate_ssim(img1, img2, border=0):
'''calculate SSIM
"""calculate SSIM
the same outputs as MATLAB's
img1, img2: [0, 255]
'''
#img1 = img1.squeeze()
#img2 = img2.squeeze()
"""
# img1 = img1.squeeze()
# img2 = img2.squeeze()
if not img1.shape == img2.shape:
raise ValueError('Input images must have the same dimensions.')
h, w = img1.shape[:2]
img1 = img1[border:h-border, border:w-border]
img2 = img2[border:h-border, border:w-border]
img1 = img1[border : h - border, border : w - border]
img2 = img2[border : h - border, border : w - border]
if img1.ndim == 2:
return ssim(img1, img2)
@@ -658,7 +737,7 @@ def calculate_ssim(img1, img2, border=0):
if img1.shape[2] == 3:
ssims = []
for i in range(3):
ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
return np.array(ssims).mean()
elif img1.shape[2] == 1:
return ssim(np.squeeze(img1), np.squeeze(img2))
@@ -667,8 +746,8 @@ def calculate_ssim(img1, img2, border=0):
def ssim(img1, img2):
C1 = (0.01 * 255)**2
C2 = (0.03 * 255)**2
C1 = (0.01 * 255) ** 2
C2 = (0.03 * 255) ** 2
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
@@ -684,16 +763,17 @@ def ssim(img1, img2):
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
(sigma1_sq + sigma2_sq + C2))
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
(mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
)
return ssim_map.mean()
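And a matching sanity check for the SSIM path: an image compared with itself scores 1.0, and additive noise pushes the score below 1. The window filtering inside ssim() crops a 5-pixel border, so use images comfortably larger than 11x11.

import numpy as np
import utils_image as util   # placeholder module name for the diffed file

rng = np.random.default_rng(0)
img = (rng.random((32, 32)) * 255).astype(np.uint8)
noisy = np.clip(img.astype(np.int16) + rng.integers(-20, 21, img.shape), 0, 255).astype(np.uint8)

print(util.calculate_ssim(img, img))    # 1.0
print(util.calculate_ssim(img, noisy))  # < 1.0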
'''
"""
# --------------------------------------------
# matlab's bicubic imresize (numpy and torch) [0, 1]
# --------------------------------------------
'''
"""
# matlab 'imresize' function, now only support 'bicubic'
@@ -701,11 +781,14 @@ def cubic(x):
absx = torch.abs(x)
absx2 = absx**2
absx3 = absx**3
return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
(-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
) * (((absx > 1) * (absx <= 2)).type_as(absx))
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
def calculate_weights_indices(
in_length, out_length, scale, kernel, kernel_width, antialiasing
):
if (scale < 1) and (antialiasing):
# Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
kernel_width = kernel_width / scale
@@ -729,8 +812,9 @@ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width
# The indices of the input pixels involved in computing the k-th output
# pixel are in row k of the indices matrix.
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
1, P).expand(out_length, P)
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(
0, P - 1, P
).view(1, P).expand(out_length, P)
# The weights used to compute the k-th output pixel are in row k of the
# weights matrix.
@@ -771,7 +855,11 @@ def imresize(img, scale, antialiasing=True):
if need_squeeze:
img.unsqueeze_(0)
in_C, in_H, in_W = img.size()
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
out_C, out_H, out_W = (
in_C,
math.ceil(in_H * scale),
math.ceil(in_W * scale),
)
kernel_width = 4
kernel = 'cubic'
@@ -782,9 +870,11 @@ def imresize(img, scale, antialiasing=True):
# get weights and indices
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
in_H, out_H, scale, kernel, kernel_width, antialiasing)
in_H, out_H, scale, kernel, kernel_width, antialiasing
)
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
in_W, out_W, scale, kernel, kernel_width, antialiasing)
in_W, out_W, scale, kernel, kernel_width, antialiasing
)
# process H dimension
# symmetric copying
img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
@@ -805,7 +895,11 @@ def imresize(img, scale, antialiasing=True):
for i in range(out_H):
idx = int(indices_H[i][0])
for j in range(out_C):
out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
out_1[j, i, :] = (
img_aug[j, idx : idx + kernel_width, :]
.transpose(0, 1)
.mv(weights_H[i])
)
# process W dimension
# symmetric copying
@@ -827,7 +921,9 @@ def imresize(img, scale, antialiasing=True):
for i in range(out_W):
idx = int(indices_W[i][0])
for j in range(out_C):
out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(
weights_W[i]
)
if need_squeeze:
out_2.squeeze_()
return out_2
@@ -846,7 +942,11 @@ def imresize_np(img, scale, antialiasing=True):
img.unsqueeze_(2)
in_H, in_W, in_C = img.size()
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
out_C, out_H, out_W = (
in_C,
math.ceil(in_H * scale),
math.ceil(in_W * scale),
)
kernel_width = 4
kernel = 'cubic'
@@ -857,9 +957,11 @@ def imresize_np(img, scale, antialiasing=True):
# get weights and indices
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
in_H, out_H, scale, kernel, kernel_width, antialiasing)
in_H, out_H, scale, kernel, kernel_width, antialiasing
)
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
in_W, out_W, scale, kernel, kernel_width, antialiasing)
in_W, out_W, scale, kernel, kernel_width, antialiasing
)
# process H dimension
# symmetric copying
img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
@@ -880,7 +982,11 @@ def imresize_np(img, scale, antialiasing=True):
for i in range(out_H):
idx = int(indices_H[i][0])
for j in range(out_C):
out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
out_1[i, :, j] = (
img_aug[idx : idx + kernel_width, :, j]
.transpose(0, 1)
.mv(weights_H[i])
)
# process W dimension
# symmetric copying
@@ -902,7 +1008,9 @@ def imresize_np(img, scale, antialiasing=True):
for i in range(out_W):
idx = int(indices_W[i][0])
for j in range(out_C):
out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
out_2[:, i, j] = out_1_aug[:, idx : idx + kernel_width, j].mv(
weights_W[i]
)
if need_squeeze:
out_2.squeeze_()
@@ -913,4 +1021,4 @@ if __name__ == '__main__':
print('---')
# img = imread_uint('test.bmp', 3)
# img = uint2single(img)
# img_bicubic = imresize_np(img, 1/4)
# img_bicubic = imresize_np(img, 1/4)
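The commented-out __main__ block above sketches the intended flow; fleshed out, it looks like the following. Both the utils_image module name and test.bmp are placeholders.

import utils_image as util   # placeholder module name for the diffed file

img = util.imread_uint('test.bmp', n_channels=3)   # HxWx3 uint8
img = util.uint2single(img)                        # float32 in [0, 1]
img_bicubic = util.imresize_np(img, 1 / 4)         # matlab-style bicubic downscale
print(img_bicubic.shape)                           # roughly (H/4, W/4, 3)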

View File

@@ -1 +1 @@
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator

View File

@@ -5,13 +5,24 @@ from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/
class LPIPSWithDiscriminator(nn.Module):
def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
disc_loss="hinge"):
def __init__(
self,
disc_start,
logvar_init=0.0,
kl_weight=1.0,
pixelloss_weight=1.0,
disc_num_layers=3,
disc_in_channels=3,
disc_factor=1.0,
disc_weight=1.0,
perceptual_weight=1.0,
use_actnorm=False,
disc_conditional=False,
disc_loss='hinge',
):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
assert disc_loss in ['hinge', 'vanilla']
self.kl_weight = kl_weight
self.pixel_weight = pixelloss_weight
self.perceptual_loss = LPIPS().eval()
@@ -19,42 +30,68 @@ class LPIPSWithDiscriminator(nn.Module):
# output log variance
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm
).apply(weights_init)
self.discriminator = NLayerDiscriminator(
input_nc=disc_in_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm,
).apply(weights_init)
self.discriminator_iter_start = disc_start
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
self.disc_loss = (
hinge_d_loss if disc_loss == 'hinge' else vanilla_d_loss
)
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
nll_grads = torch.autograd.grad(
nll_loss, last_layer, retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, last_layer, retain_graph=True
)[0]
else:
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
nll_grads = torch.autograd.grad(
nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
global_step, last_layer=None, cond=None, split="train",
weights=None):
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
def forward(
self,
inputs,
reconstructions,
posteriors,
optimizer_idx,
global_step,
last_layer=None,
cond=None,
split='train',
weights=None,
):
rec_loss = torch.abs(
inputs.contiguous() - reconstructions.contiguous()
)
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if weights is not None:
weighted_nll_loss = weights*nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
weighted_nll_loss = weights * nll_loss
weighted_nll_loss = (
torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
)
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
kl_loss = posteriors.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
@@ -67,45 +104,72 @@ class LPIPSWithDiscriminator(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous())
else:
assert self.disc_conditional
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous(), cond), dim=1)
)
g_loss = -torch.mean(logits_fake)
if self.disc_factor > 0.0:
try:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
else:
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
disc_factor = adopt_weight(
self.disc_factor,
global_step,
threshold=self.discriminator_iter_start,
)
loss = (
weighted_nll_loss
+ self.kl_weight * kl_loss
+ d_weight * disc_factor * g_loss
)
log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
log = {
'{}/total_loss'.format(split): loss.clone().detach().mean(),
'{}/logvar'.format(split): self.logvar.detach(),
'{}/kl_loss'.format(split): kl_loss.detach().mean(),
'{}/nll_loss'.format(split): nll_loss.detach().mean(),
'{}/rec_loss'.format(split): rec_loss.detach().mean(),
'{}/d_weight'.format(split): d_weight.detach(),
'{}/disc_factor'.format(split): torch.tensor(disc_factor),
'{}/g_loss'.format(split): g_loss.detach().mean(),
}
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
logits_fake = self.discriminator(
reconstructions.contiguous().detach()
)
else:
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
logits_real = self.discriminator(
torch.cat((inputs.contiguous().detach(), cond), dim=1)
)
logits_fake = self.discriminator(
torch.cat(
(reconstructions.contiguous().detach(), cond), dim=1
)
)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
disc_factor = adopt_weight(
self.disc_factor,
global_step,
threshold=self.discriminator_iter_start,
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean()
}
log = {
'{}/disc_loss'.format(split): d_loss.clone().detach().mean(),
'{}/logits_real'.format(split): logits_real.detach().mean(),
'{}/logits_fake'.format(split): logits_fake.detach().mean(),
}
return d_loss, log
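One piece of this file that is easy to misread after the reflow is calculate_adaptive_weight: it scales the GAN term by the ratio of gradient norms at the decoder's last layer, clamped to [0, 1e4]. A standalone toy reproduction of just that formula, using placeholder tensors rather than the repository's objects:

import torch
import torch.nn as nn

last_layer = nn.Parameter(torch.randn(8, 8))   # stands in for the decoder's final weight
x = torch.randn(4, 8)
recon = x @ last_layer                         # toy "reconstruction" path

nll_loss = recon.abs().mean()                  # stands in for the reconstruction / NLL term
g_loss = -recon.sum()                          # stands in for -mean(logits_fake)

nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
print(d_weight)                                # multiplies the GAN loss relative to the NLL loss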

View File

@@ -3,21 +3,25 @@ from torch import nn
import torch.nn.functional as F
from einops import repeat
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.discriminator.model import (
NLayerDiscriminator,
weights_init,
)
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3])
loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3])
loss_real = (weights * loss_real).sum() / weights.sum()
loss_fake = (weights * loss_fake).sum() / weights.sum()
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
def adopt_weight(weight, global_step, threshold=0, value=0.):
def adopt_weight(weight, global_step, threshold=0, value=0.0):
if global_step < threshold:
weight = value
return weight
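adopt_weight above is just a warm-up gate: it returns value (default 0.0) until global_step reaches threshold, which is how the discriminator term stays disabled early in training. Restated standalone for a quick check (the threshold here is an arbitrary example):

def adopt_weight(weight, global_step, threshold=0, value=0.0):
    if global_step < threshold:
        weight = value
    return weight

print(adopt_weight(1.0, global_step=3, threshold=10))    # 0.0 -> GAN term off
print(adopt_weight(1.0, global_step=12, threshold=10))   # 1.0 -> GAN term on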
@@ -26,57 +30,76 @@ def adopt_weight(weight, global_step, threshold=0, value=0.):
def measure_perplexity(predicted_indices, n_embed):
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
encodings = (
F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use
def l1(x, y):
return torch.abs(x-y)
return torch.abs(x - y)
def l2(x, y):
return torch.pow((x-y), 2)
return torch.pow((x - y), 2)
class VQLPIPSWithDiscriminator(nn.Module):
def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
pixel_loss="l1"):
def __init__(
self,
disc_start,
codebook_weight=1.0,
pixelloss_weight=1.0,
disc_num_layers=3,
disc_in_channels=3,
disc_factor=1.0,
disc_weight=1.0,
perceptual_weight=1.0,
use_actnorm=False,
disc_conditional=False,
disc_ndf=64,
disc_loss='hinge',
n_classes=None,
perceptual_loss='lpips',
pixel_loss='l1',
):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
assert perceptual_loss in ["lpips", "clips", "dists"]
assert pixel_loss in ["l1", "l2"]
assert disc_loss in ['hinge', 'vanilla']
assert perceptual_loss in ['lpips', 'clips', 'dists']
assert pixel_loss in ['l1', 'l2']
self.codebook_weight = codebook_weight
self.pixel_weight = pixelloss_weight
if perceptual_loss == "lpips":
print(f"{self.__class__.__name__}: Running with LPIPS.")
if perceptual_loss == 'lpips':
print(f'{self.__class__.__name__}: Running with LPIPS.')
self.perceptual_loss = LPIPS().eval()
else:
raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
raise ValueError(
f'Unknown perceptual loss: >> {perceptual_loss} <<'
)
self.perceptual_weight = perceptual_weight
if pixel_loss == "l1":
if pixel_loss == 'l1':
self.pixel_loss = l1
else:
self.pixel_loss = l2
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm,
ndf=disc_ndf
).apply(weights_init)
self.discriminator = NLayerDiscriminator(
input_nc=disc_in_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm,
ndf=disc_ndf,
).apply(weights_init)
self.discriminator_iter_start = disc_start
if disc_loss == "hinge":
if disc_loss == 'hinge':
self.disc_loss = hinge_d_loss
elif disc_loss == "vanilla":
elif disc_loss == 'vanilla':
self.disc_loss = vanilla_d_loss
else:
raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
print(f'VQLPIPSWithDiscriminator running with {disc_loss} loss.')
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional
@@ -84,31 +107,53 @@ class VQLPIPSWithDiscriminator(nn.Module):
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
nll_grads = torch.autograd.grad(
nll_loss, last_layer, retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, last_layer, retain_graph=True
)[0]
else:
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
nll_grads = torch.autograd.grad(
nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
def forward(
self,
codebook_loss,
inputs,
reconstructions,
optimizer_idx,
global_step,
last_layer=None,
cond=None,
split='train',
predicted_indices=None,
):
if not exists(codebook_loss):
codebook_loss = torch.tensor([0.]).to(inputs.device)
#rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
codebook_loss = torch.tensor([0.0]).to(inputs.device)
# rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
rec_loss = self.pixel_loss(
inputs.contiguous(), reconstructions.contiguous()
)
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss
else:
p_loss = torch.tensor([0.0])
nll_loss = rec_loss
#nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
# nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
nll_loss = torch.mean(nll_loss)
# now the GAN part
@@ -119,49 +164,77 @@ class VQLPIPSWithDiscriminator(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous())
else:
assert self.disc_conditional
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous(), cond), dim=1)
)
g_loss = -torch.mean(logits_fake)
try:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
disc_factor = adopt_weight(
self.disc_factor,
global_step,
threshold=self.discriminator_iter_start,
)
loss = (
nll_loss
+ d_weight * disc_factor * g_loss
+ self.codebook_weight * codebook_loss.mean()
)
log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/quant_loss".format(split): codebook_loss.detach().mean(),
"{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/p_loss".format(split): p_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
log = {
'{}/total_loss'.format(split): loss.clone().detach().mean(),
'{}/quant_loss'.format(split): codebook_loss.detach().mean(),
'{}/nll_loss'.format(split): nll_loss.detach().mean(),
'{}/rec_loss'.format(split): rec_loss.detach().mean(),
'{}/p_loss'.format(split): p_loss.detach().mean(),
'{}/d_weight'.format(split): d_weight.detach(),
'{}/disc_factor'.format(split): torch.tensor(disc_factor),
'{}/g_loss'.format(split): g_loss.detach().mean(),
}
if predicted_indices is not None:
assert self.n_classes is not None
with torch.no_grad():
perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
log[f"{split}/perplexity"] = perplexity
log[f"{split}/cluster_usage"] = cluster_usage
perplexity, cluster_usage = measure_perplexity(
predicted_indices, self.n_classes
)
log[f'{split}/perplexity'] = perplexity
log[f'{split}/cluster_usage'] = cluster_usage
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
logits_fake = self.discriminator(
reconstructions.contiguous().detach()
)
else:
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
logits_real = self.discriminator(
torch.cat((inputs.contiguous().detach(), cond), dim=1)
)
logits_fake = self.discriminator(
torch.cat(
(reconstructions.contiguous().detach(), cond), dim=1
)
)
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
disc_factor = adopt_weight(
self.disc_factor,
global_step,
threshold=self.discriminator_iter_start,
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean()
}
log = {
'{}/disc_loss'.format(split): d_loss.clone().detach().mean(),
'{}/logits_real'.format(split): logits_real.detach().mean(),
'{}/logits_fake'.format(split): logits_fake.detach().mean(),
}
return d_loss, log

View File

@@ -11,15 +11,13 @@ from einops import rearrange, repeat, reduce
DEFAULT_DIM_HEAD = 64
Intermediates = namedtuple('Intermediates', [
'pre_softmax_attn',
'post_softmax_attn'
])
Intermediates = namedtuple(
'Intermediates', ['pre_softmax_attn', 'post_softmax_attn']
)
LayerIntermediates = namedtuple('Intermediates', [
'hiddens',
'attn_intermediates'
])
LayerIntermediates = namedtuple(
'Intermediates', ['hiddens', 'attn_intermediates']
)
class AbsolutePositionalEmbedding(nn.Module):
@@ -39,11 +37,16 @@ class AbsolutePositionalEmbedding(nn.Module):
class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
def forward(self, x, seq_dim=1, offset=0):
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
t = (
torch.arange(x.shape[seq_dim], device=x.device).type_as(
self.inv_freq
)
+ offset
)
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
return emb[None, :, :]
@@ -51,6 +54,7 @@ class FixedPositionalEmbedding(nn.Module):
# helpers
def exists(val):
return val is not None
@@ -64,18 +68,21 @@ def default(val, d):
def always(val):
def inner(*args, **kwargs):
return val
return inner
def not_equals(val):
def inner(x):
return x != val
return inner
def equals(val):
def inner(x):
return x == val
return inner
@@ -85,6 +92,7 @@ def max_neg_value(tensor):
# keyword argument helpers
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
@@ -108,8 +116,15 @@ def group_by_key_prefix(prefix, d):
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
kwargs_with_prefix, kwargs = group_dict_by_key(
partial(string_begins_with, prefix), d
)
kwargs_without_prefix = dict(
map(
lambda x: (x[0][len(prefix) :], x[1]),
tuple(kwargs_with_prefix.items()),
)
)
return kwargs_without_prefix, kwargs
@@ -139,7 +154,7 @@ class Rezero(nn.Module):
class ScaleNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.scale = dim ** -0.5
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))
@@ -151,7 +166,7 @@ class ScaleNorm(nn.Module):
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-8):
super().__init__()
self.scale = dim ** -0.5
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
@@ -173,7 +188,7 @@ class GRUGating(nn.Module):
def forward(self, x, residual):
gated_output = self.gru(
rearrange(x, 'b n d -> (b n) d'),
rearrange(residual, 'b n d -> (b n) d')
rearrange(residual, 'b n d -> (b n) d'),
)
return gated_output.reshape_as(x)
@@ -181,6 +196,7 @@ class GRUGating(nn.Module):
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
@@ -192,19 +208,18 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
@@ -214,23 +229,25 @@ class FeedForward(nn.Module):
# attention.
class Attention(nn.Module):
def __init__(
self,
dim,
dim_head=DEFAULT_DIM_HEAD,
heads=8,
causal=False,
mask=None,
talking_heads=False,
sparse_topk=None,
use_entmax15=False,
num_mem_kv=0,
dropout=0.,
on_attn=False
self,
dim,
dim_head=DEFAULT_DIM_HEAD,
heads=8,
causal=False,
mask=None,
talking_heads=False,
sparse_topk=None,
use_entmax15=False,
num_mem_kv=0,
dropout=0.0,
on_attn=False,
):
super().__init__()
if use_entmax15:
raise NotImplementedError("Check out entmax activation instead of softmax activation!")
self.scale = dim_head ** -0.5
raise NotImplementedError(
'Check out entmax activation instead of softmax activation!'
)
self.scale = dim_head**-0.5
self.heads = heads
self.causal = causal
self.mask = mask
@@ -252,7 +269,7 @@ class Attention(nn.Module):
self.sparse_topk = sparse_topk
# entmax
#self.attn_fn = entmax15 if use_entmax15 else F.softmax
# self.attn_fn = entmax15 if use_entmax15 else F.softmax
self.attn_fn = F.softmax
# add memory key / values
@@ -263,20 +280,29 @@ class Attention(nn.Module):
# attention on attention
self.attn_on_attn = on_attn
self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
self.to_out = (
nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU())
if on_attn
else nn.Linear(inner_dim, dim)
)
def forward(
self,
x,
context=None,
mask=None,
context_mask=None,
rel_pos=None,
sinusoidal_emb=None,
prev_attn=None,
mem=None
self,
x,
context=None,
mask=None,
context_mask=None,
rel_pos=None,
sinusoidal_emb=None,
prev_attn=None,
mem=None,
):
b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
b, n, _, h, talking_heads, device = (
*x.shape,
self.heads,
self.talking_heads,
x.device,
)
kv_input = default(context, x)
q_input = x
@@ -297,23 +323,35 @@ class Attention(nn.Module):
k = self.to_k(k_input)
v = self.to_v(v_input)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
q, k, v = map(
lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)
)
input_mask = None
if any(map(exists, (mask, context_mask))):
q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
q_mask = default(
mask, lambda: torch.ones((b, n), device=device).bool()
)
k_mask = q_mask if not exists(context) else context_mask
k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
k_mask = default(
k_mask,
lambda: torch.ones((b, k.shape[-2]), device=device).bool(),
)
q_mask = rearrange(q_mask, 'b i -> b () i ()')
k_mask = rearrange(k_mask, 'b j -> b () () j')
input_mask = q_mask * k_mask
if self.num_mem_kv > 0:
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
mem_k, mem_v = map(
lambda t: repeat(t, 'h n d -> b h n d', b=b),
(self.mem_k, self.mem_v),
)
k = torch.cat((mem_k, k), dim=-2)
v = torch.cat((mem_v, v), dim=-2)
if exists(input_mask):
input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
input_mask = F.pad(
input_mask, (self.num_mem_kv, 0), value=True
)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
mask_value = max_neg_value(dots)
@@ -324,7 +362,9 @@ class Attention(nn.Module):
pre_softmax_attn = dots
if talking_heads:
dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
dots = einsum(
'b h i j, h k -> b k i j', dots, self.pre_softmax_proj
).contiguous()
if exists(rel_pos):
dots = rel_pos(dots)
@@ -336,7 +376,9 @@ class Attention(nn.Module):
if self.causal:
i, j = dots.shape[-2:]
r = torch.arange(i, device=device)
mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
mask = rearrange(r, 'i -> () () i ()') < rearrange(
r, 'j -> () () () j'
)
mask = F.pad(mask, (j - i, 0), value=False)
dots.masked_fill_(mask, mask_value)
del mask
@@ -354,14 +396,16 @@ class Attention(nn.Module):
attn = self.dropout(attn)
if talking_heads:
attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
attn = einsum(
'b h i j, h k -> b k i j', attn, self.post_softmax_proj
).contiguous()
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
intermediates = Intermediates(
pre_softmax_attn=pre_softmax_attn,
post_softmax_attn=post_softmax_attn
post_softmax_attn=post_softmax_attn,
)
return self.to_out(out), intermediates
@@ -369,28 +413,28 @@ class Attention(nn.Module):
class AttentionLayers(nn.Module):
def __init__(
self,
dim,
depth,
heads=8,
causal=False,
cross_attend=False,
only_cross=False,
use_scalenorm=False,
use_rmsnorm=False,
use_rezero=False,
rel_pos_num_buckets=32,
rel_pos_max_distance=128,
position_infused_attn=False,
custom_layers=None,
sandwich_coef=None,
par_ratio=None,
residual_attn=False,
cross_residual_attn=False,
macaron=False,
pre_norm=True,
gate_residual=False,
**kwargs
self,
dim,
depth,
heads=8,
causal=False,
cross_attend=False,
only_cross=False,
use_scalenorm=False,
use_rmsnorm=False,
use_rezero=False,
rel_pos_num_buckets=32,
rel_pos_max_distance=128,
position_infused_attn=False,
custom_layers=None,
sandwich_coef=None,
par_ratio=None,
residual_attn=False,
cross_residual_attn=False,
macaron=False,
pre_norm=True,
gate_residual=False,
**kwargs,
):
super().__init__()
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
@@ -403,10 +447,14 @@ class AttentionLayers(nn.Module):
self.layers = nn.ModuleList([])
self.has_pos_emb = position_infused_attn
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
self.pia_pos_emb = (
FixedPositionalEmbedding(dim) if position_infused_attn else None
)
self.rotary_pos_emb = always(None)
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
assert (
rel_pos_num_buckets <= rel_pos_max_distance
), 'number of relative position buckets must be less than the relative position max distance'
self.rel_pos = None
self.pre_norm = pre_norm
@@ -438,15 +486,27 @@ class AttentionLayers(nn.Module):
assert 1 < par_ratio <= par_depth, 'par ratio out of range'
default_block = tuple(filter(not_equals('f'), default_block))
par_attn = par_depth // par_ratio
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
depth_cut = (
par_depth * 2 // 3
) # 2 / 3 attention layer cutoff suggested by PAR paper
par_width = (depth_cut + depth_cut // par_attn) // par_attn
assert len(default_block) <= par_width, 'default block is too large for par_ratio'
par_block = default_block + ('f',) * (par_width - len(default_block))
assert (
len(default_block) <= par_width
), 'default block is too large for par_ratio'
par_block = default_block + ('f',) * (
par_width - len(default_block)
)
par_head = par_block * par_attn
layer_types = par_head + ('f',) * (par_depth - len(par_head))
elif exists(sandwich_coef):
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
assert (
sandwich_coef > 0 and sandwich_coef <= depth
), 'sandwich coefficient should be less than the depth'
layer_types = (
('a',) * sandwich_coef
+ default_block * (depth - sandwich_coef)
+ ('f',) * sandwich_coef
)
else:
layer_types = default_block * depth
@@ -455,7 +515,9 @@ class AttentionLayers(nn.Module):
for layer_type in self.layer_types:
if layer_type == 'a':
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
layer = Attention(
dim, heads=heads, causal=causal, **attn_kwargs
)
elif layer_type == 'c':
layer = Attention(dim, heads=heads, **attn_kwargs)
elif layer_type == 'f':
@@ -472,21 +534,17 @@ class AttentionLayers(nn.Module):
else:
residual_fn = Residual()
self.layers.append(nn.ModuleList([
norm_fn(),
layer,
residual_fn
]))
self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
def forward(
self,
x,
context=None,
mask=None,
context_mask=None,
mems=None,
return_hiddens=False,
**kwargs
self,
x,
context=None,
mask=None,
context_mask=None,
mems=None,
return_hiddens=False,
**kwargs,
):
hiddens = []
intermediates = []
@@ -495,7 +553,9 @@ class AttentionLayers(nn.Module):
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
zip(self.layer_types, self.layers)
):
is_last = ind == (len(self.layers) - 1)
if layer_type == 'a':
@@ -508,10 +568,22 @@ class AttentionLayers(nn.Module):
x = norm(x)
if layer_type == 'a':
out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
prev_attn=prev_attn, mem=layer_mem)
out, inter = block(
x,
mask=mask,
sinusoidal_emb=self.pia_pos_emb,
rel_pos=self.rel_pos,
prev_attn=prev_attn,
mem=layer_mem,
)
elif layer_type == 'c':
out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
out, inter = block(
x,
context=context,
mask=mask,
context_mask=context_mask,
prev_attn=prev_cross_attn,
)
elif layer_type == 'f':
out = block(x)
@@ -530,8 +602,7 @@ class AttentionLayers(nn.Module):
if return_hiddens:
intermediates = LayerIntermediates(
hiddens=hiddens,
attn_intermediates=intermediates
hiddens=hiddens, attn_intermediates=intermediates
)
return x, intermediates
@@ -545,23 +616,24 @@ class Encoder(AttentionLayers):
super().__init__(causal=False, **kwargs)
class TransformerWrapper(nn.Module):
def __init__(
self,
*,
num_tokens,
max_seq_len,
attn_layers,
emb_dim=None,
max_mem_len=0.,
emb_dropout=0.,
num_memory_tokens=None,
tie_embedding=False,
use_pos_emb=True
self,
*,
num_tokens,
max_seq_len,
attn_layers,
emb_dim=None,
max_mem_len=0.0,
emb_dropout=0.0,
num_memory_tokens=None,
tie_embedding=False,
use_pos_emb=True,
):
super().__init__()
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
assert isinstance(
attn_layers, AttentionLayers
), 'attention layers must be one of Encoder or Decoder'
dim = attn_layers.dim
emb_dim = default(emb_dim, dim)
@@ -571,23 +643,34 @@ class TransformerWrapper(nn.Module):
self.num_tokens = num_tokens
self.token_emb = nn.Embedding(num_tokens, emb_dim)
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.pos_emb = (
AbsolutePositionalEmbedding(emb_dim, max_seq_len)
if (use_pos_emb and not attn_layers.has_pos_emb)
else always(0)
)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
self.project_emb = (
nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
)
self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
self.init_()
self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
self.to_logits = (
nn.Linear(dim, num_tokens)
if not tie_embedding
else lambda t: t @ self.token_emb.weight.t()
)
# memory tokens (like [cls]) from Memory Transformers paper
num_memory_tokens = default(num_memory_tokens, 0)
self.num_memory_tokens = num_memory_tokens
if num_memory_tokens > 0:
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
self.memory_tokens = nn.Parameter(
torch.randn(num_memory_tokens, dim)
)
# let funnel encoder know number of memory tokens, if specified
if hasattr(attn_layers, 'num_memory_tokens'):
@@ -597,20 +680,20 @@ class TransformerWrapper(nn.Module):
nn.init.normal_(self.token_emb.weight, std=0.02)
def forward(
self,
x,
return_embeddings=False,
mask=None,
return_mems=False,
return_attn=False,
mems=None,
embedding_manager=None,
**kwargs
self,
x,
return_embeddings=False,
mask=None,
return_mems=False,
return_attn=False,
mems=None,
embedding_manager=None,
**kwargs,
):
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
embedded_x = self.token_emb(x)
if embedding_manager:
x = embedding_manager(x, embedded_x)
else:
@@ -629,7 +712,9 @@ class TransformerWrapper(nn.Module):
if exists(mask):
mask = F.pad(mask, (num_mem, 0), value=True)
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
x, intermediates = self.attn_layers(
x, mask=mask, mems=mems, return_hiddens=True, **kwargs
)
x = self.norm(x)
mem, x = x[:, :num_mem], x[:, num_mem:]
@@ -638,13 +723,30 @@ class TransformerWrapper(nn.Module):
if return_mems:
hiddens = intermediates.hiddens
new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
new_mems = (
list(
map(
lambda pair: torch.cat(pair, dim=-2),
zip(mems, hiddens),
)
)
if exists(mems)
else hiddens
)
new_mems = list(
map(
lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems
)
)
return out, new_mems
if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
attn_maps = list(
map(
lambda t: t.post_softmax_attn,
intermediates.attn_intermediates,
)
)
return out, attn_maps
return out
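For orientation, the pieces reformatted in this file snap together roughly like this. The import path is an assumption (in this fork the file appears to live at ldm/modules/x_transformer.py); adjust it to wherever the module actually sits.

import torch
from ldm.modules.x_transformer import Encoder, TransformerWrapper  # assumed path

model = TransformerWrapper(
    num_tokens=1000,                             # vocabulary size
    max_seq_len=77,                              # longest supported sequence
    attn_layers=Encoder(dim=256, depth=4, heads=8),
)

tokens = torch.randint(0, 1000, (2, 77))         # (batch, seq_len) integer ids
logits = model(tokens)                           # (2, 77, 1000)
hidden = model(tokens, return_embeddings=True)   # (2, 77, 256)
print(logits.shape, hidden.shape)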

File diff suppressed because it is too large.

View File

@@ -13,22 +13,25 @@ from queue import Queue
from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot
b = len(xc)
txts = list()
for bi in range(b):
txt = Image.new("RGB", wh, color="white")
txt = Image.new('RGB', wh, color='white')
draw = ImageDraw.Draw(txt)
font = ImageFont.load_default()
nc = int(40 * (wh[0] / 256))
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
lines = '\n'.join(
xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
)
try:
draw.text((0, 0), lines, fill="black", font=font)
draw.text((0, 0), lines, fill='black', font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
print('Cant encode string for logging. Skipping.')
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
@@ -70,22 +73,26 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
print(
f'{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
)
return total_params
def instantiate_from_config(config, **kwargs):
if not "target" in config:
if not 'target' in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":
elif config == '__is_unconditional__':
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
raise KeyError('Expected key `target` to instantiate.')
return get_obj_from_str(config['target'])(
**config.get('params', dict()), **kwargs
)
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
module, cls = string.rsplit('.', 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
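instantiate_from_config above expects a config node carrying a dotted target path plus optional params, which get_obj_from_str resolves via importlib. A minimal sketch; the ldm.util import assumes this diff is ldm/util.py, and the Linear target is only an illustration.

from omegaconf import OmegaConf
from ldm.util import instantiate_from_config   # assumed location of the diffed file

cfg = OmegaConf.create({
    'target': 'torch.nn.Linear',               # any importable dotted path works
    'params': {'in_features': 16, 'out_features': 4},
})
layer = instantiate_from_config(cfg)           # same as torch.nn.Linear(16, 4)
print(layer)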
@@ -101,31 +108,36 @@ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
else:
res = func(data)
Q.put([idx, res])
Q.put("Done")
Q.put('Done')
def parallel_data_prefetch(
func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
func: callable,
data,
n_proc,
target_data_type='ndarray',
cpu_intensive=True,
use_worker_id=False,
):
# if target_data_type not in ["ndarray", "list"]:
# raise ValueError(
# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
# )
if isinstance(data, np.ndarray) and target_data_type == "list":
raise ValueError("list expected but function got ndarray.")
if isinstance(data, np.ndarray) and target_data_type == 'list':
raise ValueError('list expected but function got ndarray.')
elif isinstance(data, abc.Iterable):
if isinstance(data, dict):
print(
f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
)
data = list(data.values())
if target_data_type == "ndarray":
if target_data_type == 'ndarray':
data = np.asarray(data)
else:
data = list(data)
else:
raise TypeError(
f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
f'The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.'
)
if cpu_intensive:
@@ -135,7 +147,7 @@ def parallel_data_prefetch(
Q = Queue(1000)
proc = Thread
# spawn processes
if target_data_type == "ndarray":
if target_data_type == 'ndarray':
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(np.array_split(data, n_proc))
@@ -149,7 +161,7 @@ def parallel_data_prefetch(
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(
[data[i: i + step] for i in range(0, len(data), step)]
[data[i : i + step] for i in range(0, len(data), step)]
)
]
processes = []
@@ -158,7 +170,7 @@ def parallel_data_prefetch(
processes += [p]
# start processes
print(f"Start prefetching...")
print(f'Start prefetching...')
import time
start = time.time()
@@ -171,13 +183,13 @@ def parallel_data_prefetch(
while k < n_proc:
# get result
res = Q.get()
if res == "Done":
if res == 'Done':
k += 1
else:
gather_res[res[0]] = res[1]
except Exception as e:
print("Exception: ", e)
print('Exception: ', e)
for p in processes:
p.terminate()
@@ -185,7 +197,7 @@ def parallel_data_prefetch(
finally:
for p in processes:
p.join()
print(f"Prefetching complete. [{time.time() - start} sec.]")
print(f'Prefetching complete. [{time.time() - start} sec.]')
if target_data_type == 'ndarray':
if not isinstance(gather_res[0], np.ndarray):

main.py (701 changed lines)

File diff suppressed because it is too large.

View File

@@ -14,7 +14,7 @@ from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import ismap
import time
from omegaconf import OmegaConf
from ldm.dream.devices import choose_torch_device
def download_models(mode):
@@ -117,7 +117,8 @@ def get_cond(mode, selected_path):
c = rearrange(c, '1 c h w -> 1 h w c')
c = 2. * c - 1.
c = c.to(torch.device("cuda"))
device = choose_torch_device()
c = c.to(device)
example["LR_image"] = c
example["image"] = c_up
@@ -267,4 +268,4 @@ def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, e
log["sample"] = x_sample
log["time"] = t1 - t0
return log
return log
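choose_torch_device itself is not part of this diff, only its call site. All that is visible here is a zero-argument call whose result feeds .to(); a standalone equivalent of the cuda, then mps, then cpu fallback it is used for could look like this (an assumption about the helper, not a copy of it):

import torch

def choose_torch_device() -> str:
    # prefer CUDA, fall back to Apple MPS, then CPU
    if torch.cuda.is_available():
        return 'cuda'
    if getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
        return 'mps'
    return 'cpu'

c = torch.randn(1, 3, 64, 64)
c = c.to(choose_torch_device())   # replaces the hard-coded torch.device("cuda")
print(c.device)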

View File

@@ -1,23 +1,23 @@
accelerate==0.12.0
albumentations==0.4.3
einops==0.3.0
huggingface-hub==0.8.1
imageio==2.9.0
imageio-ffmpeg==0.4.2
kornia==0.6.0
numpy==1.19.2
numpy==1.23.1
--pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
omegaconf==2.1.1
opencv-python==4.1.2.30
pillow==9.0.1
opencv-python==4.6.0.66
pillow==9.2.0
pudb==2019.2
pytorch
torch==1.12.1
torchvision==0.12.0
pytorch-lightning==1.4.2
streamlit==1.12.0
test-tube>=0.7.5
torch-fidelity==0.3.0
torchmetrics==0.6.0
torchvision
transformers==4.19.2
-e git+https://github.com/openai/CLIP.git@main#egg=clip
-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-e git+https://github.com/lstein/k-diffusion.git@master#egg=k-diffusion
-e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion

View File

@@ -3,426 +3,631 @@
import argparse
import shlex
import atexit
import os
import re
import sys
import copy
from PIL import Image,PngImagePlugin
# readline unavailable on windows systems
try:
import readline
readline_available = True
except:
readline_available = False
debugging = False
import warnings
import time
import ldm.dream.readline
from ldm.dream.pngwriter import PngWriter, PromptFormatter
from ldm.dream.server import DreamServer, ThreadingDreamServer
from ldm.dream.image_util import make_grid
from omegaconf import OmegaConf
def main():
''' Initialize command-line parsers and the diffusion model '''
"""Initialize command-line parsers and the diffusion model"""
arg_parser = create_argv_parser()
opt = arg_parser.parse_args()
opt = arg_parser.parse_args()
if opt.laion400m:
# defaults suitable to the older latent diffusion weights
width = 256
height = 256
config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
weights = "models/ldm/text2img-large/model.ckpt"
else:
# some defaults suitable for stable diffusion weights
width = 512
height = 512
config = "configs/stable-diffusion/v1-inference.yaml"
weights = "models/ldm/stable-diffusion-v1/model.ckpt"
print('--laion400m flag has been deprecated. Please use --model laion400m instead.')
sys.exit(-1)
if opt.weights != 'model':
print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.')
sys.exit(-1)
try:
models = OmegaConf.load(opt.config)
width = models[opt.model].width
height = models[opt.model].height
config = models[opt.model].config
weights = models[opt.model].weights
except (FileNotFoundError, IOError, KeyError) as e:
print(f'{e}. Aborting.')
sys.exit(-1)
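The lookup above implies that each entry in models.yaml carries width, height, config and weights keys. An in-memory illustration of that shape (the model name is a placeholder; the paths echo the defaults shown earlier in this diff):

from omegaconf import OmegaConf

models = OmegaConf.create({
    'stable-diffusion-1.4': {                 # placeholder model name
        'config': 'configs/stable-diffusion/v1-inference.yaml',
        'weights': 'models/ldm/stable-diffusion-v1/model.ckpt',
        'width': 512,
        'height': 512,
    },
})
entry = models['stable-diffusion-1.4']
print(entry.width, entry.height, entry.config, entry.weights)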
# command line history will be stored in a file called "~/.dream_history"
if readline_available:
setup_readline()
print("* Initializing, be patient...\n")
print('* Initializing, be patient...\n')
sys.path.append('.')
from pytorch_lightning import logging
from ldm.simplet2i import T2I
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
import transformers
transformers.logging.set_verbosity_error()
# creating a simple text2image object with a handful of
# defaults passed on the command line.
# additional parameters will be added (or overriden) during
# the user input loop
t2i = T2I(width=width,
height=height,
batch_size=opt.batch_size,
outdir=opt.outdir,
sampler_name=opt.sampler_name,
weights=weights,
full_precision=opt.full_precision,
config=config,
latent_diffusion_weights=opt.laion400m, # this is solely for recreating the prompt
embedding_path=opt.embedding_path,
device=opt.device
t2i = T2I(
width=width,
height=height,
sampler_name=opt.sampler_name,
weights=weights,
full_precision=opt.full_precision,
config=config,
grid = opt.grid,
# this is solely for recreating the prompt
latent_diffusion_weights=opt.laion400m,
embedding_path=opt.embedding_path,
device_type=opt.device
)
# make sure the output directory exists
if not os.path.exists(opt.outdir):
os.makedirs(opt.outdir)
# gets rid of annoying messages about random seed
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
# gets rid of annoying messages about random seed
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
# load the infile as a list of lines
infile = None
try:
if opt.infile is not None:
infile = open(opt.infile,'r')
except FileNotFoundError as e:
print(e)
exit(-1)
if opt.infile:
try:
if os.path.isfile(opt.infile):
infile = open(opt.infile, 'r', encoding='utf-8')
elif opt.infile == '-': # stdin
infile = sys.stdin
else:
raise FileNotFoundError(f'{opt.infile} not found.')
except (FileNotFoundError, IOError) as e:
print(f'{e}. Aborting.')
sys.exit(-1)
# preload the model
tic = time.time()
t2i.load_model()
print("\n* Initialization done! Awaiting your command (-h for help, 'q' to quit, 'cd' to change output dir, 'pwd' to print output dir)...")
print(
f'>> model loaded in', '%4.2fs' % (time.time() - tic)
)
log_path = os.path.join(opt.outdir,'dream_log.txt')
with open(log_path,'a') as log:
cmd_parser = create_cmd_parser()
main_loop(t2i,cmd_parser,log,infile)
log.close()
if infile:
infile.close()
if not infile:
print(
"\n* Initialization done! Awaiting your command (-h for help, 'q' to quit)"
)
cmd_parser = create_cmd_parser()
if opt.web:
dream_server_loop(t2i, opt.host, opt.port)
else:
main_loop(t2i, opt.outdir, opt.prompt_as_dir, cmd_parser, infile)
def main_loop(t2i,parser,log,infile):
''' prompt/read/execute loop '''
def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
"""prompt/read/execute loop"""
done = False
last_seeds = []
path_filter = re.compile(r'[<>:"/\\|?*]')
# os.pathconf is not available on Windows
if hasattr(os, 'pathconf'):
path_max = os.pathconf(outdir, 'PC_PATH_MAX')
name_max = os.pathconf(outdir, 'PC_NAME_MAX')
else:
path_max = 260
name_max = 255
while not done:
try:
command = infile.readline() if infile else input("dream> ")
command = get_next_command(infile)
except EOFError:
done = True
break
if infile and len(command)==0:
done = True
break
# skip empty lines
if not command.strip():
continue
if command.startswith(('#','//')):
if command.startswith(('#', '//')):
continue
# before splitting, escape single quotes so as not to mess
# up the parser
command = command.replace("'","\\'")
command = command.replace("'", "\\'")
try:
elements = shlex.split(command)
except ValueError as e:
print(str(e))
continue
if len(elements)==0:
continue
if elements[0]=='q':
if elements[0] == 'q':
done = True
break
if elements[0]=='cd' and len(elements)>1:
if os.path.exists(elements[1]):
print(f"setting image output directory to {elements[1]}")
t2i.outdir=elements[1]
else:
print(f"directory {elements[1]} does not exist")
continue
if elements[0]=='pwd':
print(f"current output directory is {t2i.outdir}")
continue
if elements[0].startswith('!dream'): # in case a stored prompt still contains the !dream command
if elements[0].startswith(
'!dream'
): # in case a stored prompt still contains the !dream command
elements.pop(0)
# rearrange the arguments to mimic how it works in the Dream bot.
switches = ['']
switches_started = False
for el in elements:
if el[0]=='-' and not switches_started:
if el[0] == '-' and not switches_started:
switches_started = True
if switches_started:
switches.append(el)
else:
switches[0] += el
switches[0] += ' '
switches[0] = switches[0][:len(switches[0])-1]
switches[0] = switches[0][: len(switches[0]) - 1]
try:
opt = parser.parse_args(switches)
opt = parser.parse_args(switches)
except SystemExit:
parser.print_help()
continue
if len(opt.prompt)==0:
print("Try again with a prompt!")
if len(opt.prompt) == 0:
print('Try again with a prompt!')
continue
if opt.seed is not None and opt.seed < 0: # retrieve previous value!
try:
opt.seed = last_seeds[opt.seed]
print(f'reusing previous seed {opt.seed}')
except IndexError:
print(f'No previous seed at position {opt.seed} found')
opt.seed = None
try:
if opt.init_img is None:
results = t2i.txt2img(**vars(opt))
do_grid = opt.grid or t2i.grid
if opt.with_variations is not None:
# shotgun parsing, woo
parts = []
broken = False # python doesn't have labeled loops...
for part in opt.with_variations.split(','):
seed_and_weight = part.split(':')
if len(seed_and_weight) != 2:
print(f'could not parse with_variation part "{part}"')
broken = True
break
try:
seed = int(seed_and_weight[0])
weight = float(seed_and_weight[1])
except ValueError:
print(f'could not parse with_variation part "{part}"')
broken = True
break
parts.append([seed, weight])
if broken:
continue
if len(parts) > 0:
opt.with_variations = parts
else:
assert os.path.exists(opt.init_img),f"No file found at {opt.init_img}. On Linux systems, pressing <tab> after -I will autocomplete a list of possible image files."
if None not in (opt.width,opt.height):
print('Warning: width and height options are ignored when modifying an init image')
results = t2i.img2img(**vars(opt))
opt.with_variations = None
if opt.outdir:
if not os.path.exists(opt.outdir):
os.makedirs(opt.outdir)
current_outdir = opt.outdir
elif prompt_as_dir:
# sanitize the prompt to a valid folder name
subdir = path_filter.sub('_', opt.prompt)[:name_max].rstrip(' .')
# truncate path to maximum allowed length
# 27 is the length of '######.##########.##.png', plus two separators and a NUL
subdir = subdir[:(path_max - 27 - len(os.path.abspath(outdir)))]
current_outdir = os.path.join(outdir, subdir)
print ('Writing files to directory: "' + current_outdir + '"')
# make sure the output directory exists
if not os.path.exists(current_outdir):
os.makedirs(current_outdir)
else:
current_outdir = outdir
# Here is where the images are actually generated!
try:
file_writer = PngWriter(current_outdir)
prefix = file_writer.unique_prefix()
seeds = set()
results = [] # list of filename, prompt pairs
grid_images = dict() # seed -> Image, only used if `do_grid`
def image_writer(image, seed, upscaled=False):
if do_grid:
grid_images[seed] = image
else:
if upscaled and opt.save_original:
filename = f'{prefix}.{seed}.postprocessed.png'
else:
filename = f'{prefix}.{seed}.png'
if opt.variation_amount > 0:
iter_opt = argparse.Namespace(**vars(opt)) # copy
this_variation = [[seed, opt.variation_amount]]
if opt.with_variations is None:
iter_opt.with_variations = this_variation
else:
iter_opt.with_variations = opt.with_variations + this_variation
iter_opt.variation_amount = 0
normalized_prompt = PromptFormatter(t2i, iter_opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{iter_opt.seed}'
elif opt.with_variations is not None:
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{opt.seed}' # use the original seed - the per-iteration value is the last variation-seed
else:
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{seed}'
path = file_writer.save_image_and_prompt_to_png(image, metadata_prompt, filename)
if (not upscaled) or opt.save_original:
# only append to results if we didn't overwrite an earlier output
results.append([path, metadata_prompt])
seeds.add(seed)
t2i.prompt2image(image_callback=image_writer, **vars(opt))
if do_grid and len(grid_images) > 0:
grid_img = make_grid(list(grid_images.values()))
first_seed = next(iter(seeds))
filename = f'{prefix}.{first_seed}.png'
# TODO better metadata for grid images
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{first_seed} --grid -N{len(grid_images)}'
path = file_writer.save_image_and_prompt_to_png(
grid_img, metadata_prompt, filename
)
results = [[path, metadata_prompt]]
last_seeds = list(seeds)
except AssertionError as e:
print(e)
continue
except OSError as e:
print(e)
continue
allVariantResults = []
if opt.variants is not None:
print(f"Generating {opt.variants} variant(s)...")
newopt = copy.deepcopy(opt)
newopt.variants = None
for r in results:
newopt.init_img = r[0]
print(f"\t generating variant for {newopt.init_img}")
for j in range(0, opt.variants):
try:
variantResults = t2i.img2img(**vars(newopt))
allVariantResults.append([newopt,variantResults])
except AssertionError as e:
print(e)
continue
print(f"{opt.variants} Variants generated!")
print('Outputs:')
log_path = os.path.join(current_outdir, 'dream_log.txt')
write_log_message(results, log_path)
print("Outputs:")
write_log_message(t2i,opt,results,log)
if allVariantResults:
print("Variant outputs:")
for vr in allVariantResults:
write_log_message(t2i,vr[0],vr[1],log)
print("goodbye!")
print('goodbye!')
def write_log_message(t2i,opt,results,logfile):
''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata '''
switches = _reconstruct_switches(t2i,opt)
prompt_str = ' '.join(switches)
# when multiple images are produced in batch, then we keep track of where each starts
last_seed = None
img_num = 1
batch_size = opt.batch_size or t2i.batch_size
seenit = {}
seeds = [a[1] for a in results]
if batch_size > 1:
seeds = f"(seeds for each batch row: {seeds})"
def get_next_command(infile=None) -> str: #command string
if infile is None:
command = input('dream> ')
else:
seeds = f"(seeds for individual images: {seeds})"
command = infile.readline()
if not command:
raise EOFError
else:
command = command.strip()
print(f'#{command}')
return command
for r in results:
seed = r[1]
log_message = (f'{r[0]}: {prompt_str} -S{seed}')
def dream_server_loop(t2i, host, port):
print('\n* --web was specified, starting web server...')
# Change working directory to the stable-diffusion directory
os.chdir(
os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
)
if batch_size > 1:
if seed != last_seed:
img_num = 1
log_message += f' # (batch image {img_num} of {batch_size})'
else:
img_num += 1
log_message += f' # (batch image {img_num} of {batch_size})'
last_seed = seed
print(log_message)
logfile.write(log_message+"\n")
logfile.flush()
if r[0] not in seenit:
seenit[r[0]] = True
try:
if opt.grid:
_write_prompt_to_png(r[0],f'{prompt_str} -g -S{seed} {seeds}')
else:
_write_prompt_to_png(r[0],f'{prompt_str} -S{seed}')
except FileNotFoundError:
print(f"Could not open file '{r[0]}' for reading")
# Start server
DreamServer.model = t2i
dream_server = ThreadingDreamServer((host, port))
print(">> Started Stable Diffusion dream server!")
if host == '0.0.0.0':
print(f"Point your browser at http://localhost:{port} or use the host's DNS name or IP address.")
else:
print(">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind to all network interfaces.")
print(f">> Point your browser at http://{host}:{port}.")
def _reconstruct_switches(t2i,opt):
'''Normalize the prompt and switches'''
switches = list()
switches.append(f'"{opt.prompt}"')
switches.append(f'-s{opt.steps or t2i.steps}')
switches.append(f'-b{opt.batch_size or t2i.batch_size}')
switches.append(f'-W{opt.width or t2i.width}')
switches.append(f'-H{opt.height or t2i.height}')
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
switches.append(f'-m{t2i.sampler_name}')
if opt.variants:
switches.append(f'-v{opt.variants}')
if opt.init_img:
switches.append(f'-I{opt.init_img}')
if opt.strength and opt.init_img is not None:
switches.append(f'-f{opt.strength or t2i.strength}')
if t2i.full_precision:
switches.append('-F')
return switches
try:
dream_server.serve_forever()
except KeyboardInterrupt:
pass
dream_server.server_close()
def write_log_message(results, log_path):
"""logs the name of the output image, prompt, and prompt args to the terminal and log file"""
log_lines = [f'{path}: {prompt}\n' for path, prompt in results]
print(*log_lines, sep='')
with open(log_path, 'a', encoding='utf-8') as file:
file.writelines(log_lines)
SAMPLER_CHOICES=[
'ddim',
'k_dpm_2_a',
'k_dpm_2',
'k_euler_a',
'k_euler',
'k_heun',
'k_lms',
'plms',
]
def _write_prompt_to_png(path,prompt):
info = PngImagePlugin.PngInfo()
info.add_text("Dream",prompt)
im = Image.open(path)
im.save(path,"PNG",pnginfo=info)
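Both the old _write_prompt_to_png helper above and the new PngWriter.save_image_and_prompt_to_png path store the reproduction command in a "Dream" text chunk (the PngWriter body is not shown here, so that is an assumption for the new path). A small sketch of reading the prompt back out of a generated file:

from PIL import Image

def read_dream_prompt(path):
    # PNG tEXt chunks show up in the image's info dict when the file is opened
    with Image.open(path) as im:
        return im.info.get('Dream')

# hypothetical filename for illustration:
# read_dream_prompt('outputs/img-samples/000001.1234567890.png')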
def create_argv_parser():
parser = argparse.ArgumentParser(description="Parse script's command line args")
parser.add_argument("--laion400m",
"--latent_diffusion",
"-l",
dest='laion400m',
action='store_true',
help="fallback to the latent diffusion (laion400m) weights and config")
parser.add_argument("--from_file",
dest='infile',
type=str,
help="if specified, load prompts from this file")
parser.add_argument('-n','--iterations',
type=int,
default=1,
help="number of images to generate")
parser.add_argument('-F','--full_precision',
dest='full_precision',
action='store_true',
help="use slower full precision math for calculations")
parser.add_argument('-b','--batch_size',
type=int,
default=1,
help="number of images to produce per iteration (faster, but doesn't generate individual seeds)")
parser.add_argument('--sampler','-m',
dest="sampler_name",
choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'],
default='k_lms',
help="which sampler to use (k_lms) - can only be set on command line")
parser.add_argument('--outdir',
'-o',
type=str,
default="outputs/img-samples",
help="directory in which to place generated images and a log of prompts and seeds")
parser.add_argument('--embedding_path',
type=str,
help="Path to a pre-trained embedding manager checkpoint - can only be set on command line")
parser.add_argument('--device',
'-d',
type=str,
default="cuda",
help="device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if available")
parser = argparse.ArgumentParser(
description="""Generate images using Stable Diffusion.
Use --web to launch the web interface.
Use --from_file to load prompts from a file path or standard input ("-").
Otherwise you will be dropped into an interactive command prompt (type -h for help.)
Other command-line arguments are defaults that can usually be overridden
at the command prompt.
"""
)
parser.add_argument(
'--laion400m',
'--latent_diffusion',
'-l',
dest='laion400m',
action='store_true',
help='Fallback to the latent diffusion (laion400m) weights and config',
)
parser.add_argument(
'--from_file',
dest='infile',
type=str,
help='If specified, load prompts from this file',
)
parser.add_argument(
'-n',
'--iterations',
type=int,
default=1,
help='Number of images to generate',
)
parser.add_argument(
'-F',
'--full_precision',
dest='full_precision',
action='store_true',
help='Use more memory-intensive full precision math for calculations',
)
parser.add_argument(
'-g',
'--grid',
action='store_true',
help='Generate a grid instead of individual images',
)
parser.add_argument(
'-A',
'-m',
'--sampler',
dest='sampler_name',
choices=SAMPLER_CHOICES,
metavar='SAMPLER_NAME',
default='k_lms',
help=f'Set the initial sampler. Default: k_lms. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
)
parser.add_argument(
'--outdir',
'-o',
type=str,
default='outputs/img-samples',
help='Directory to save generated images and a log of prompts and seeds. Default: outputs/img-samples',
)
parser.add_argument(
'--embedding_path',
type=str,
help='Path to a pre-trained embedding manager checkpoint - can only be set on command line',
)
parser.add_argument(
'--prompt_as_dir',
'-p',
action='store_true',
help='Place images in subdirectories named after the prompt.',
)
# GFPGAN related args
parser.add_argument(
'--gfpgan_bg_upsampler',
type=str,
default='realesrgan',
help='Background upsampler. Default: realesrgan. Options: realesrgan, none. Only used if --gfpgan is specified',
)
parser.add_argument(
'--gfpgan_bg_tile',
type=int,
default=400,
help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',
)
parser.add_argument(
'--gfpgan_model_path',
type=str,
default='experiments/pretrained_models/GFPGANv1.3.pth',
help='Indicates the path to the GFPGAN model, relative to --gfpgan_dir.',
)
parser.add_argument(
'--gfpgan_dir',
type=str,
default='../GFPGAN',
help='Indicates the directory containing the GFPGAN code.',
)
parser.add_argument(
'--web',
dest='web',
action='store_true',
help='Start in web server mode.',
)
parser.add_argument(
'--host',
type=str,
default='127.0.0.1',
help='Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.'
)
parser.add_argument(
'--port',
type=int,
default='9090',
help='Web server: Port to listen on'
)
parser.add_argument(
'--weights',
default='model',
help='Indicates the Stable Diffusion model to use.',
)
parser.add_argument(
'--device',
'-d',
type=str,
default='cuda',
help="device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if available"
)
parser.add_argument(
'--model',
default='stable-diffusion-1.4',
help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
)
parser.add_argument(
'--config',
default='configs/models.yaml',
help='Path to configuration file for alternate models.',
)
return parser
def create_cmd_parser():
parser = argparse.ArgumentParser(description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12')
parser = argparse.ArgumentParser(
description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12'
)
parser.add_argument('prompt')
parser.add_argument('-s','--steps',type=int,help="number of steps")
parser.add_argument('-S','--seed',type=int,help="image seed")
parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform (slower, but will provide seeds for individual images)")
parser.add_argument('-b','--batch_size',type=int,default=1,help="number of images to produce per sampling (will not provide seeds for individual images!)")
parser.add_argument('-W','--width',type=int,help="image width, multiple of 64")
parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale")
parser.add_argument('-g','--grid',action='store_true',help="generate a grid")
parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)")
parser.add_argument('-I','--init_img',type=str,help="path to input image for img2img mode (supersedes width and height)")
parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely")
parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")
parser.add_argument('-x','--skip_normalize',action='store_true',help="skip subprompt weight normalization")
parser.add_argument('-s', '--steps', type=int, help='Number of steps')
parser.add_argument(
'-S',
'--seed',
type=int,
help='Image seed; a positive integer, or use -1 for the previous seed, -2 for the one before that, etc.',
)
parser.add_argument(
'-n',
'--iterations',
type=int,
default=1,
help='Number of samplings to perform (slower, but will provide seeds for individual images)',
)
parser.add_argument(
'-W', '--width', type=int, help='Image width, multiple of 64'
)
parser.add_argument(
'-H', '--height', type=int, help='Image height, multiple of 64'
)
parser.add_argument(
'-C',
'--cfg_scale',
default=7.5,
type=float,
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
)
parser.add_argument(
'-g', '--grid', action='store_true', help='generate a grid'
)
parser.add_argument(
'--outdir',
'-o',
type=str,
default=None,
help='Directory to save generated images and a log of prompts and seeds',
)
parser.add_argument(
'-i',
'--individual',
action='store_true',
help='Generate individual files (default)',
)
parser.add_argument(
'-I',
'--init_img',
type=str,
help='Path to input image for img2img mode (supersedes width and height)',
)
parser.add_argument(
'-T',
'-fit',
'--fit',
action='store_true',
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
)
parser.add_argument(
'-f',
'--strength',
default=0.75,
type=float,
help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
)
parser.add_argument(
'-G',
'--gfpgan_strength',
default=0,
type=float,
help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.',
)
parser.add_argument(
'-U',
'--upscale',
nargs='+',
default=None,
type=float,
help='Scale factor (2, 4) for upscaling followed by upscaling strength (0-1.0). If strength not specified, defaults to 0.75'
)
parser.add_argument(
'-save_orig',
'--save_original',
action='store_true',
help='Save original. Use it when upscaling to save both versions.',
)
# variants is going to be superseded by a generalized "prompt-morph" function
# parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")
parser.add_argument(
'-x',
'--skip_normalize',
action='store_true',
help='Skip subprompt weight normalization',
)
parser.add_argument(
'-A',
'-m',
'--sampler',
dest='sampler_name',
default=None,
type=str,
choices=SAMPLER_CHOICES,
metavar='SAMPLER_NAME',
help=f'Switch to a different sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
)
parser.add_argument(
'-t',
'--log_tokenization',
action='store_true',
help='shows how the prompt is split into tokens'
)
parser.add_argument(
'-v',
'--variation_amount',
default=0.0,
type=float,
help='If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different.'
)
parser.add_argument(
'-V',
'--with_variations',
default=None,
type=str,
help='list of variations to apply, in the format `seed:weight,seed:weight,...`'
)
return parser
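For reference, a line typed at the dream> prompt (like the Example in the description above) is tokenized by main_loop() before it ever reaches this parser: shlex splits it, and everything before the first switch is folded back into the prompt argument. A rough standalone sketch of that step, using a sample line only:

import shlex

line = 'a fantastic alien landscape -W1024 -H960 -s100 -n12'
# escape single quotes, then split like a shell command line
elements = shlex.split(line.replace("'", "\\'"))

prompt_words, switches = [], []
for el in elements:
    if switches or el.startswith('-'):
        switches.append(el)
    else:
        prompt_words.append(el)
args = [' '.join(prompt_words)] + switches
# args == ['a fantastic alien landscape', '-W1024', '-H960', '-s100', '-n12']
# and can now be handed to create_cmd_parser().parse_args(args)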
if readline_available:
def setup_readline():
readline.set_completer(Completer(['cd','pwd',
'--steps','-s','--seed','-S','--iterations','-n','--batch_size','-b',
'--width','-W','--height','-H','--cfg_scale','-C','--grid','-g',
'--individual','-i','--init_img','-I','--strength','-f','-v','--variants']).complete)
readline.set_completer_delims(" ")
readline.parse_and_bind('tab: complete')
load_history()
def load_history():
histfile = os.path.join(os.path.expanduser('~'),".dream_history")
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)
except FileNotFoundError:
pass
atexit.register(readline.write_history_file,histfile)
class Completer():
def __init__(self,options):
self.options = sorted(options)
return
def complete(self,text,state):
buffer = readline.get_line_buffer()
if text.startswith(('-I','--init_img')):
return self._path_completions(text,state,('.png'))
if buffer.strip().endswith('cd') or text.startswith(('.','/')):
return self._path_completions(text,state,())
response = None
if state == 0:
# This is the first time for this text, so build a match list.
if text:
self.matches = [s
for s in self.options
if s and s.startswith(text)]
else:
self.matches = self.options[:]
# Return the state'th item from the match list,
# if we have that many.
try:
response = self.matches[state]
except IndexError:
response = None
return response
def _path_completions(self,text,state,extensions):
# get the path so far
if text.startswith('-I'):
path = text.replace('-I','',1).lstrip()
elif text.startswith('--init_img='):
path = text.replace('--init_img=','',1).lstrip()
else:
path = text
matches = list()
path = os.path.expanduser(path)
if len(path)==0:
matches.append(text+'./')
else:
dir = os.path.dirname(path)
dir_list = os.listdir(dir)
for n in dir_list:
if n.startswith('.') and len(n)>1:
continue
full_path = os.path.join(dir,n)
if full_path.startswith(path):
if os.path.isdir(full_path):
matches.append(os.path.join(os.path.dirname(text),n)+'/')
elif n.endswith(extensions):
matches.append(os.path.join(os.path.dirname(text),n))
try:
response = matches[state]
except IndexError:
response = None
return response
if __name__ == "__main__":
if __name__ == '__main__':
main()
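make_grid, imported from ldm.dream.image_util near the top of this file, is what combines the per-seed images when -g/--grid is set; its implementation is not part of this diff. A minimal sketch of the kind of PIL helper it is assumed to be:

from math import ceil, sqrt
from PIL import Image

def make_grid(images, rows=None, cols=None):
    # sketch only; the real ldm.dream.image_util.make_grid may differ
    if not rows and not cols:
        cols = ceil(sqrt(len(images)))
    rows = rows or ceil(len(images) / cols)
    cols = cols or ceil(len(images) / rows)
    w, h = images[0].size
    grid = Image.new('RGB', (cols * w, rows * h), 'white')
    for i, img in enumerate(images):
        grid.paste(img, ((i % cols) * w, (i // cols) * h))
    return grid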

View File

@@ -6,7 +6,7 @@ import numpy as np
import torch
from main import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.dream.devices import choose_torch_device
def make_batch(image, mask, device):
image = np.array(Image.open(image).convert("RGB"))
@@ -61,8 +61,8 @@ if __name__ == "__main__":
model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"],
strict=False)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
device = choose_torch_device()
model = model.to(device)
sampler = DDIMSampler(model)
os.makedirs(opt.outdir, exist_ok=True)

View File

@@ -1,4 +1,4 @@
from ldm.modules.encoders.modules import BERTTokenizer
from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder
from ldm.modules.embedding_manager import EmbeddingManager
import argparse, os
@@ -6,7 +6,7 @@ from functools import partial
import torch
def get_placeholder_loop(placeholder_string, tokenizer):
def get_placeholder_loop(placeholder_string, embedder, use_bert):
new_placeholder = None
@@ -16,10 +16,36 @@ def get_placeholder_loop(placeholder_string, tokenizer):
else:
new_placeholder = input(f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: ")
token = tokenizer(new_placeholder)
token = get_bert_token_for_string(embedder.tknz_fn, new_placeholder) if use_bert else get_clip_token_for_string(embedder.tokenizer, new_placeholder)
if token is not None:
return new_placeholder, token
def get_clip_token_for_string(tokenizer, string):
batch_encoding = tokenizer(
string,
truncation=True,
max_length=77,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt"
)
tokens = batch_encoding["input_ids"]
if torch.count_nonzero(tokens - 49407) == 2:
return tokens[0, 1]
return None
def get_bert_token_for_string(tokenizer, string):
token = tokenizer(string)
if torch.count_nonzero(token) == 3:
return token[0, 1]
return None
if torch.count_nonzero(token) == 3:
return new_placeholder, token[0, 1]
if __name__ == "__main__":
@@ -40,10 +66,20 @@ if __name__ == "__main__":
help="Output path for the merged manager",
)
parser.add_argument(
"-sd", "--use_bert",
action="store_true",
help="Flag to denote that we are not merging stable diffusion embeddings"
)
args = parser.parse_args()
tokenizer = BERTTokenizer(vq_interface=False, max_length=77)
EmbeddingManager = partial(EmbeddingManager, tokenizer, ["*"])
if args.use_bert:
embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()
else:
embedder = FrozenCLIPEmbedder().cuda()
EmbeddingManager = partial(EmbeddingManager, embedder, ["*"])
string_to_token_dict = {}
string_to_param_dict = torch.nn.ParameterDict()
@@ -63,7 +99,7 @@ if __name__ == "__main__":
placeholder_to_src[placeholder_string] = manager_ckpt
else:
new_placeholder, new_token = get_placeholder_loop(placeholder_string, tokenizer)
new_placeholder, new_token = get_placeholder_loop(placeholder_string, embedder, use_bert=args.use_bert)
string_to_token_dict[new_placeholder] = new_token
string_to_param_dict[new_placeholder] = manager.string_to_param_dict[placeholder_string]
@@ -77,7 +113,3 @@ if __name__ == "__main__":
print("Managers merged. Final list of placeholders: ")
print(placeholder_to_src)
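The count_nonzero(tokens - 49407) == 2 test above works because 49406 is CLIP's <start_of_text> id and 49407 is <end_of_text>, which also fills the padding of the 77-token encoding; after the subtraction only the start token and a single real word-piece remain nonzero. A quick standalone check for a candidate placeholder string (downloads the tokenizer on first use):

from transformers import CLIPTokenizer

tok = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
ids = tok('hello', truncation=True, max_length=77, padding='max_length',
          return_tensors='pt')['input_ids']
# prints 2 if 'hello' maps to exactly one CLIP token
print(int((ids - 49407).count_nonzero()))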

View File

@@ -18,6 +18,7 @@ from pytorch_lightning import seed_everything
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.dream.devices import choose_torch_device
def chunk(it, size):
@@ -40,7 +41,7 @@ def load_model_from_config(config, ckpt, verbose=False):
print("unexpected keys:")
print(u)
model.cuda()
model.to(choose_torch_device())
model.eval()
return model
@@ -199,7 +200,7 @@ def main():
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device(choose_torch_device())
model = model.to(device)
if opt.plms:
@@ -241,8 +242,10 @@ def main():
print(f"target t_enc is {t_enc} steps")
precision_scope = autocast if opt.precision == "autocast" else nullcontext
if device.type in ['mps', 'cpu']:
precision_scope = nullcontext # have to use f32 on mps
with torch.no_grad():
with precision_scope("cuda"):
with precision_scope(device.type):
with model.ema_scope():
tic = time.time()
all_samples = list()

View File

@@ -12,14 +12,13 @@ from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
import accelerate
import k_diffusion as K
import torch.nn as nn
from ldm.util import instantiate_from_config
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.dream.devices import choose_torch_device
def chunk(it, size):
it = iter(it)
@@ -41,7 +40,7 @@ def load_model_from_config(config, ckpt, verbose=False):
print("unexpected keys:")
print(u)
model.cuda()
model.to(choose_torch_device())
model.eval()
return model
@@ -191,18 +190,17 @@ def main():
opt.ckpt = "models/ldm/text2img-large/model.ckpt"
opt.outdir = "outputs/txt2img-samples-laion400m"
seed_everything(opt.seed)
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
seed_everything(opt.seed)
device = torch.device(choose_torch_device())
model = model.to(device)
#for klms
model_wrap = K.external.CompVisDenoiser(model)
accelerator = accelerate.Accelerator()
device = accelerator.device
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
@@ -243,16 +241,22 @@ def main():
start_code = None
if opt.fixed_code:
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
shape = [opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f]
if device.type == 'mps':
start_code = torch.randn(shape, device='cpu').to(device)
else:
start_code = torch.randn(shape, device=device)
precision_scope = autocast if opt.precision=="autocast" else nullcontext
if device.type in ['mps', 'cpu']:
precision_scope = nullcontext # have to use f32 on mps
with torch.no_grad():
with precision_scope("cuda"):
with precision_scope(device.type):
with model.ema_scope():
tic = time.time()
all_samples = list()
for n in trange(opt.n_iter, desc="Sampling", disable =not accelerator.is_main_process):
for prompts in tqdm(data, desc="data", disable =not accelerator.is_main_process):
for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
uc = None
if opt.scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
@@ -279,13 +283,10 @@ def main():
x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw
model_wrap_cfg = CFGDenoiser(model_wrap)
extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale}
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process)
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
if opt.klms:
x_sample = accelerator.gather(x_samples_ddim)
if not opt.skip_save:
for x_sample in x_samples_ddim:

View File

@@ -3,32 +3,84 @@
# Before running stable-diffusion on an internet-isolated machine,
# run this script from one with internet connectivity. The
# two machines must share a common .cache directory.
from transformers import CLIPTokenizer, CLIPTextModel
import clip
from transformers import BertTokenizerFast
import sys
import transformers
import os
import warnings
transformers.logging.set_verbosity_error()
# this will preload the Bert tokenizer files
print("preloading bert tokenizer...")
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
print("...success")
print('preloading bert tokenizer...')
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
print('...success')
# this will download requirements for Kornia
print("preloading Kornia requirements (ignore the warnings)...")
import kornia
print("...success")
print('preloading Kornia requirements (ignore the deprecation warnings)...')
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
import kornia
print('...success')
# doesn't work - probably wrong logger
# logging.getLogger('transformers.tokenization_utils').setLevel(logging.ERROR)
version='openai/clip-vit-large-patch14'
version = 'openai/clip-vit-large-patch14'
print('preloading CLIP model (Ignore the warnings)...')
print('preloading CLIP model (Ignore the deprecation warnings)...')
sys.stdout.flush()
import clip
from transformers import CLIPTokenizer, CLIPTextModel
tokenizer =CLIPTokenizer.from_pretrained(version)
transformer=CLIPTextModel.from_pretrained(version)
tokenizer = CLIPTokenizer.from_pretrained(version)
transformer = CLIPTextModel.from_pretrained(version)
print('\n\n...success')
# In the event that the user has installed GFPGAN and also elected to use
# RealESRGAN, this will attempt to download the model needed by RealESRGANer
gfpgan = False
try:
from realesrgan import RealESRGANer
gfpgan = True
except ModuleNotFoundError:
pass
if gfpgan:
print('Loading models from RealESRGAN and facexlib')
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
RealESRGANer(
scale=2,
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
model=RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
),
)
RealESRGANer(
scale=4,
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
model=RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
),
)
FaceRestoreHelper(1, det_model='retinaface_resnet50')
print('...success')
except Exception:
import traceback
print('Error loading GFPGAN:')
print(traceback.format_exc())

Submodule src/clip deleted from d50d76daa6

Submodule src/k-diffusion deleted from db57990687

BIN static/colab_notebook.png (new file, 799 KiB; binary contents not shown)

BIN static/dream-py-demo.png (new file, 499 KiB; binary contents not shown)

View File

@@ -0,0 +1,97 @@
* {
font-family: 'Arial';
}
#header {
text-decoration: dotted underline;
}
#search {
margin-top: 20vh;
margin-left: auto;
margin-right: auto;
max-width: 1024px;
text-align: center;
}
fieldset {
border: none;
}
div {
padding: 10px 10px 10px 10px;
}
#fieldset-search {
display: flex;
}
#scaling-inprocess-message{
font-weight: bold;
font-style: italic;
display: none;
}
#prompt {
flex-grow: 1;
border-radius: 20px 0px 0px 20px;
padding: 5px 10px 5px 10px;
border: 1px solid black;
border-right: none;
outline: none;
}
#submit {
border-radius: 0px 20px 20px 0px;
padding: 5px 10px 5px 10px;
border: 1px solid black;
}
#reset-all {
background-color: pink;
}
#results {
text-align: center;
/* max-width: 1024px; */
margin: auto;
padding-top: 10px;
}
#results img {
cursor: pointer;
height: 30vh;
border-radius: 5px;
margin: 10px;
}
#fieldset-config {
line-height:2em;
}
input[type="number"] {
width: 60px;
}
#seed {
width: 150px;
}
hr {
/* width: 200px; */
}
label {
white-space: nowrap;
}
#progress-section {
display: none;
}
#progress-image {
width: 30vh;
height: 30vh;
}
#cancel-button {
cursor: pointer;
color: red;
}
#txt2img {
background-color: #DCDCDC;
}
#img2img {
background-color: #F5F5F5;
}
#gfpgan {
background-color: #DCDCDC;
}
#progress-section {
background-color: #F5F5F5;
}
#about {
background-color: #DCDCDC;
}

static/dream_web/index.html Normal file (111 lines added)
View File

@@ -0,0 +1,111 @@
<html lang="en">
<head>
<title>Stable Diffusion Dream Server</title>
<meta charset="utf-8">
<link rel="icon" href="data:,">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="stylesheet" href="static/dream_web/index.css">
<script src="config.js"></script>
<script src="static/dream_web/index.js"></script>
</head>
<body>
<div id="search">
<h2 id="header">Stable Diffusion Dream Server</h2>
<form id="generate-form" method="post" action="#">
<div id="txt2img">
<fieldset id="fieldset-search">
<input type="text" id="prompt" name="prompt">
<input type="submit" id="submit" value="Generate">
</fieldset>
<fieldset id="fieldset-config">
<label for="iterations">Images to generate:</label>
<input value="1" type="number" id="iterations" name="iterations" size="4">
<label for="steps">Steps:</label>
<input value="50" type="number" id="steps" name="steps">
<label for="cfgscale">Cfg Scale:</label>
<input value="7.5" type="number" id="cfgscale" name="cfgscale" step="any">
<label for="sampler">Sampler:</label>
<select id="sampler" name="sampler" value="k_lms">
<option value="ddim">DDIM</option>
<option value="plms">PLMS</option>
<option value="k_lms" selected>KLMS</option>
<option value="k_dpm_2">KDPM_2</option>
<option value="k_dpm_2_a">KDPM_2A</option>
<option value="k_euler">KEULER</option>
<option value="k_euler_a">KEULER_A</option>
<option value="k_heun">KHEUN</option>
</select>
<br>
<label title="Set to multiple of 64" for="width">Width:</label>
<select id="width" name="width" value="512">
<option value="64">64</option> <option value="128">128</option>
<option value="192">192</option> <option value="256">256</option>
<option value="320">320</option> <option value="384">384</option>
<option value="448">448</option> <option value="512" selected>512</option>
<option value="576">576</option> <option value="640">640</option>
<option value="704">704</option> <option value="768">768</option>
<option value="832">832</option> <option value="896">896</option>
<option value="960">960</option> <option value="1024">1024</option>
</select>
<label title="Set to multiple of 64" for="height">Height:</label>
<select id="height" name="height" value="512">
<option value="64">64</option> <option value="128">128</option>
<option value="192">192</option> <option value="256">256</option>
<option value="320">320</option> <option value="384">384</option>
<option value="448">448</option> <option value="512" selected>512</option>
<option value="576">576</option> <option value="640">640</option>
<option value="704">704</option> <option value="768">768</option>
<option value="832">832</option> <option value="896">896</option>
<option value="960">960</option> <option value="1024">1024</option>
</select>
<label title="Set to -1 for random seed" for="seed">Seed:</label>
<input value="-1" type="number" id="seed" name="seed">
<button type="button" id="reset-seed">&olarr;</button>
<input type="checkbox" name="progress_images" id="progress_images">
<label for="progress_images">Display in-progress images (slows down generation):</label>
<button type="button" id="reset-all">Reset to Defaults</button>
</div>
<div id="img2img">
<label title="Upload an image to use img2img" for="initimg">Initial image:</label>
<input type="file" id="initimg" name="initimg" accept=".jpg, .jpeg, .png">
<br>
<label for="strength">Img2Img Strength:</label>
<input value="0.75" type="number" id="strength" name="strength" step="0.01" min="0" max="1">
<input type="checkbox" id="fit" name="fit" checked>
<label title="Rescale image to fit within requested width and height" for="fit">Fit to width/height:</label>
</div>
<div id="gfpgan">
<label title="Strength of the gfpgan (face fixing) algorithm." for="gfpgan_strength">GFPGAN Strength (0 to disable):</label>
<input value="0.8" min="0" max="1" type="number" id="gfpgan_strength" name="gfpgan_strength" step="0.05">
<label title="Upscaling to perform using ESRGAN." for="upscale_level">Upscaling Level</label>
<select id="upscale_level" name="upscale_level" value="">
<option value="" selected>None</option>
<option value="2">2x</option>
<option value="4">4x</option>
</select>
<label title="Strength of the esrgan (upscaling) algorithm." for="upscale_strength">Upscale Strength:</label>
<input value="0.75" min="0" max="1" type="number" id="upscale_strength" name="upscale_strength" step="0.05">
</div>
</fieldset>
</form>
<div id="about">For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a></div>
<br>
<div id="progress-section">
<progress id="progress-bar" value="0" max="1"></progress>
<span id="cancel-button" title="Cancel">&#10006;</span>
<br>
<img id="progress-image" src='data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg"/>'></img>
<div id="scaling-inprocess-message">
<i><span>Postprocessing...</span><span id="processing_cnt">1/3</span></i>
</div>
</div>
</div>
<div id="results">
<div id="no-results-message">
<i><p>No results...</p></i>
</div>
</div>
</body>
</html>

static/dream_web/index.js Normal file (161 lines added)
View File

@@ -0,0 +1,161 @@
function toBase64(file) {
return new Promise((resolve, reject) => {
const r = new FileReader();
r.readAsDataURL(file);
r.onload = () => resolve(r.result);
r.onerror = (error) => reject(error);
});
}
function appendOutput(src, seed, config) {
let outputNode = document.createElement("img");
outputNode.src = src;
let altText = seed.toString() + " | " + config.prompt;
outputNode.alt = altText;
outputNode.title = altText;
// Reload image config
outputNode.addEventListener('click', () => {
let form = document.querySelector("#generate-form");
for (const [k, v] of new FormData(form)) {
form.querySelector(`*[name=${k}]`).value = config[k];
}
document.querySelector("#seed").value = seed;
saveFields(document.querySelector("#generate-form"));
});
document.querySelector("#results").prepend(outputNode);
}
function saveFields(form) {
for (const [k, v] of new FormData(form)) {
if (typeof v !== 'object') { // Don't save 'file' type
localStorage.setItem(k, v);
}
}
}
function loadFields(form) {
for (const [k, v] of new FormData(form)) {
const item = localStorage.getItem(k);
if (item != null) {
form.querySelector(`*[name=${k}]`).value = item;
}
}
}
function clearFields(form) {
localStorage.clear();
let prompt = form.prompt.value;
form.reset();
form.prompt.value = prompt;
}
const BLANK_IMAGE_URL = 'data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg"/>';
async function generateSubmit(form) {
const prompt = document.querySelector("#prompt").value;
// Convert file data to base64
let formData = Object.fromEntries(new FormData(form));
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
let strength = formData.strength;
let totalSteps = formData.initimg ? Math.floor(strength * formData.steps) : formData.steps;
let progressSectionEle = document.querySelector('#progress-section');
progressSectionEle.style.display = 'initial';
let progressEle = document.querySelector('#progress-bar');
progressEle.setAttribute('max', totalSteps);
let progressImageEle = document.querySelector('#progress-image');
progressImageEle.src = BLANK_IMAGE_URL;
progressImageEle.style.display = {}.hasOwnProperty.call(formData, 'progress_images') ? 'initial': 'none';
// Post as JSON, using Fetch streaming to get results
fetch(form.action, {
method: form.method,
body: JSON.stringify(formData),
}).then(async (response) => {
const reader = response.body.getReader();
let noOutputs = true;
while (true) {
let {value, done} = await reader.read();
value = new TextDecoder().decode(value);
if (done) {
progressSectionEle.style.display = 'none';
break;
}
for (let event of value.split('\n').filter(e => e !== '')) {
const data = JSON.parse(event);
if (data.event === 'result') {
noOutputs = false;
document.querySelector("#no-results-message")?.remove();
appendOutput(data.url, data.seed, data.config);
progressEle.setAttribute('value', 0);
progressEle.setAttribute('max', totalSteps);
} else if (data.event === 'upscaling-started') {
document.getElementById("processing_cnt").textContent=data.processed_file_cnt;
document.getElementById("scaling-inprocess-message").style.display = "block";
} else if (data.event === 'upscaling-done') {
document.getElementById("scaling-inprocess-message").style.display = "none";
} else if (data.event === 'step') {
progressEle.setAttribute('value', data.step);
if (data.url) {
progressImageEle.src = data.url;
}
} else if (data.event === 'canceled') {
// avoid alerting as if this were an error case
noOutputs = false;
}
}
}
// Re-enable form, remove no-results-message
form.querySelector('fieldset').removeAttribute('disabled');
document.querySelector("#prompt").value = prompt;
document.querySelector('progress').setAttribute('value', '0');
if (noOutputs) {
alert("Error occurred while generating.");
}
});
// Disable form while generating
form.querySelector('fieldset').setAttribute('disabled','');
document.querySelector("#prompt").value = `Generating: "${prompt}"`;
}
window.onload = () => {
document.querySelector("#generate-form").addEventListener('submit', (e) => {
e.preventDefault();
const form = e.target;
generateSubmit(form);
});
document.querySelector("#generate-form").addEventListener('change', (e) => {
saveFields(e.target.form);
});
document.querySelector("#reset-seed").addEventListener('click', (e) => {
document.querySelector("#seed").value = -1;
saveFields(e.target.form);
});
document.querySelector("#reset-all").addEventListener('click', (e) => {
clearFields(e.target.form);
});
loadFields(document.querySelector("#generate-form"));
document.querySelector('#cancel-button').addEventListener('click', () => {
fetch('/cancel').catch(e => {
console.error(e);
});
});
if (!config.gfpgan_model_exists) {
document.querySelector("#gfpgan").style.display = 'none';
}
};
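index.js reads the server response as a stream of newline-delimited JSON events ('step', 'result', 'upscaling-started', 'upscaling-done', 'canceled'). A rough sketch of a minimal Python client speaking the same protocol, assuming a dream server on 127.0.0.1:9090 and a payload shaped like the form fields above (not every field shown):

import json
import requests

payload = {
    'prompt': 'a fantastic alien landscape',
    'iterations': 1, 'steps': 50, 'cfgscale': 7.5, 'sampler': 'k_lms',
    'width': 512, 'height': 512, 'seed': -1,
    'initimg': None, 'strength': 0.75, 'fit': 'on',
    'gfpgan_strength': 0, 'upscale_level': '', 'upscale_strength': 0.75,
}

with requests.post('http://127.0.0.1:9090/', json=payload, stream=True) as resp:
    for raw in resp.iter_lines():
        if not raw:
            continue
        event = json.loads(raw)
        if event['event'] == 'step':
            print('step', event.get('step'))
        elif event['event'] == 'result':
            print('image at', event['url'], 'seed', event['seed'])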

BIN static/dream_web_server.png (new file, 536 KiB; binary contents not shown)

BIN static/logo_temp.png (new file, 34 KiB; binary contents not shown)

Five more binary image files (429 KiB, 445 KiB, 426 KiB, 427 KiB, 424 KiB); filenames and contents not shown.