Compare commits

..

777 Commits

Author SHA1 Message Date
psychedelicious
434afdd21d chore(ui): lint 2024-08-26 22:57:56 +10:00
psychedelicious
ea267c0e09 chore: release v4.2.9.dev4
Canvas dev build.
2024-08-26 22:46:07 +10:00
psychedelicious
35b483b83a feat(ui): rough out undo/redo on canvas UNDO ME? 2024-08-26 22:45:11 +10:00
psychedelicious
217b2759c3 fix(ui): handle error from internal konva method
We are dipping into konva's private API for preview images and it appears to be unsafe (got an error once). Wrapped in a try/catch.
2024-08-26 22:44:47 +10:00
psychedelicious
7d528dc03d feat(ui): split out loras state from canvas rendering state 2024-08-26 22:29:28 +10:00
psychedelicious
bc40c1a99f feat(ui): split out session state from canvas rendering state 2024-08-26 22:20:04 +10:00
psychedelicious
a2b508016e feat(ui): split out settings state from canvas rendering state 2024-08-26 22:02:56 +10:00
psychedelicious
3ce8294379 feat(ui): split out tool state from canvas rendering state 2024-08-26 21:52:43 +10:00
psychedelicious
9c3da8de8e feat(ui): split out params/compositing state from canvas rendering state
First step to restoring undo/redo - the undoable state must be in its own slice. So params and settings must be isolated.
2024-08-26 21:41:47 +10:00
psychedelicious
ccba597e58 feat(ui): add CanvasModuleBase class to standardize canvas APIs
I did this ages ago but undid it for some reason, not sure why. Caught a few issues related to subscriptions.
2024-08-26 21:12:31 +10:00
psychedelicious
8028943cdd feat(ui): move selected tool and tool buffer out of redux
This ephemeral state can live in the canvas classes.
2024-08-26 19:59:06 +10:00
psychedelicious
cdd8b60fd0 feat(ui): move ephemeral state into canvas classes
Things like `$lastCursorPos` are now created within the canvas drawing classes. Consumers in react access them via `useCanvasManager`.

For example:
```tsx
const canvasManager = useCanvasManager();
const lastCursorPos = useStore(canvasManager.stateApi.$lastCursorPos);
```
2024-08-26 19:14:56 +10:00
psychedelicious
c624754404 feat(ui): normalize all actions to accept an entityIdentifier
Previously, canvas actions specific to an entity type only needed the id of that entity type. This allowed you to pass in the id of an entity of the wrong type.

All actions for a specific entity now take a full entity identifier, and the entity identifier type can be narrowed.

`selectEntity` and `selectEntityOrThrow` now need a full entity identifier, and narrow their return values to a specific entity type _if_ the entity identifier is narrowed.

The types for canvas entities are updated with optional type parameters for this purpose.

All reducers, actions and components have been updated.
2024-08-26 18:52:28 +10:00
psychedelicious
0505634296 feat(ui): move events into modules who care about them 2024-08-26 10:34:59 +10:00
psychedelicious
e06ea5d595 fix(ui): color picker resets brush opacity 2024-08-26 08:58:16 +10:00
psychedelicious
838a0574a5 fix(ui): scaled bbox loses sync 2024-08-26 08:55:33 +10:00
psychedelicious
ed3581c70c feat(ui): add context menu to entity list 2024-08-24 19:50:36 +10:00
psychedelicious
522cae6a42 chore(ui): bump @invoke-ai/ui-library 2024-08-24 19:49:43 +10:00
psychedelicious
a9032a34f2 fix(ui): missing vae precision in graph builders 2024-08-24 18:35:30 +10:00
psychedelicious
6af5a22d3a chore: release v4.2.9.dev3
Instead of using dates, just going to increment.
2024-08-24 15:03:43 +10:00
psychedelicious
8cf4321010 feat(ui): use new Result utils for enqueueing 2024-08-24 14:49:17 +10:00
psychedelicious
aa7f2b096a fix(ui): graph building issue w/ controlnet 2024-08-24 14:48:18 +10:00
psychedelicious
bf0824b56d feat(ui): add Result type & helpers
Wrappers to capture errors and turn into results:
- `withResult` wraps a sync function
- `withResultAsync` wraps an async function

Comments, tests.
2024-08-24 14:46:58 +10:00
psychedelicious
056c56d322 chore: release v4.2.9.dev20240824 2024-08-24 12:36:25 +10:00
psychedelicious
afc6f83d72 fix(ui): lint & fix issues with adding regional ip adapters 2024-08-24 12:32:38 +10:00
psychedelicious
c776ac3af2 feat(ui): add knipignore tag
I'm not ready to delete some things but still want to build the app.
2024-08-24 12:32:00 +10:00
psychedelicious
b7b3683bef feat(ui): duplicate entity 2024-08-24 12:20:35 +10:00
psychedelicious
fb26b6824a feat(ui): autocomplete on getPrefixeId 2024-08-24 12:20:26 +10:00
psychedelicious
63d8ad912f feat(ui): paste canvas gens back on source in generate mode 2024-08-24 11:56:24 +10:00
psychedelicious
bbd7d7fc17 chore(ui): typegen 2024-08-24 11:55:50 +10:00
psychedelicious
6507a78182 feat(nodes): CanvasV2MaskAndCropInvocation can paste generated image back on source
This is needed for `Generate` mode.
2024-08-24 11:55:43 +10:00
psychedelicious
22f46517f4 fix(ui): extraneous entity preview updates 2024-08-24 11:28:05 +10:00
psychedelicious
45596e1f94 fix(ui): newly-added entities are selected 2024-08-24 11:14:58 +10:00
psychedelicious
6de0dbe854 feat(ui): add crosshair to color picker 2024-08-24 10:51:34 +10:00
psychedelicious
011827fa29 fix(ui): color picker ignores alpha 2024-08-24 10:16:27 +10:00
psychedelicious
fc6d244071 fix(ui): calculate renderable entities correctly in tool module 2024-08-24 10:10:21 +10:00
psychedelicious
cd3da886d6 feat(ui): better color picker 2024-08-24 10:10:04 +10:00
psychedelicious
c013c55d92 feat(ui): colored mask preview image 2024-08-24 08:54:20 +10:00
psychedelicious
cd3dd7db0d fix(ui): new rectangles don't trigger rerender 2024-08-23 23:24:16 +10:00
psychedelicious
1fdcce9429 chore: bump version v4.2.9.dev20240823 2024-08-23 20:52:16 +10:00
psychedelicious
181e40926d feat(ui): disable most interaction while filtering 2024-08-23 20:32:49 +10:00
psychedelicious
c62ede5878 fix(ui): filter preview offset 2024-08-23 20:24:40 +10:00
psychedelicious
a2ad5f1a9a feat(ui): tweak layout of staging area toolbar 2024-08-23 19:55:02 +10:00
psychedelicious
ff74a5356f chore(ui): typegen 2024-08-23 19:52:37 +10:00
psychedelicious
92dc30dace tidy(app): clean up app changes for canvas v2 2024-08-23 19:52:04 +10:00
psychedelicious
3af577b210 feat(ui): use singleton for clear q confirm dialog 2024-08-23 19:47:51 +10:00
psychedelicious
d0464330f7 fix(ui): rip out broken recall logic, NO TS ERRORS 2024-08-23 19:47:51 +10:00
psychedelicious
dd3ef4a80f chore(ui): lint 2024-08-23 19:47:51 +10:00
psychedelicious
0ced891944 fix(ui): staging area interaction scopes 2024-08-23 19:47:51 +10:00
psychedelicious
10a5452df9 fix(ui): staging area actions 2024-08-23 19:47:51 +10:00
psychedelicious
cb97969bbc tidy(ui): more cleanup 2024-08-23 19:47:51 +10:00
psychedelicious
71e742e238 fix(ui): upscale tab graph 2024-08-23 19:47:51 +10:00
psychedelicious
fadd20fb8e fix(ui): sdxl graph builder 2024-08-23 19:47:51 +10:00
psychedelicious
01b9ca78e4 fix(ui): select next entity in the list when deleting 2024-08-23 19:47:51 +10:00
psychedelicious
2baf825f34 feat(ui): fix delete layer hotkey 2024-08-23 19:47:51 +10:00
psychedelicious
1fa8048509 tidy(ui): "eye dropper" -> "color picker" 2024-08-23 19:47:51 +10:00
psychedelicious
a000ad75f6 tidy(ui): regional guidance buttons 2024-08-23 19:47:51 +10:00
psychedelicious
aefb2339bb feat(ui): update entity list menu 2024-08-23 19:47:51 +10:00
psychedelicious
a4f8671f86 feat(ui): add log debug button 2024-08-23 19:47:51 +10:00
psychedelicious
73530ba54f chore(ui): lint 2024-08-23 19:47:51 +10:00
psychedelicious
685eb9927d chore(ui): prettier 2024-08-23 19:47:51 +10:00
psychedelicious
ee57302fc3 chore(ui): eslint 2024-08-23 19:47:51 +10:00
psychedelicious
c1fb9cdb93 tidy(ui): remove unused stuff 4 2024-08-23 19:47:36 +10:00
psychedelicious
aa6d441552 tidy(ui): remove unused stuff 3 2024-08-23 19:47:01 +10:00
psychedelicious
25d8d4c2e9 tidy(ui): remove unused pkg @chakra-ui/react-use-size 2024-08-23 19:47:01 +10:00
psychedelicious
427ea6da5c feat(ui): revise graph building for control layers, fix issues w/ invocation complete events 2024-08-23 19:47:01 +10:00
psychedelicious
d9f4266630 feat(ui): use unique id for metadata in Graph class 2024-08-23 19:47:01 +10:00
psychedelicious
96f6e9e683 tidy(ui): remove unused stuff 2 2024-08-23 19:47:01 +10:00
psychedelicious
f10248e3f5 tidy(ui): remove unused stuff 2024-08-23 19:47:01 +10:00
psychedelicious
6a21f5fde1 tidy(ui): reduce use of parseify util 2024-08-23 19:47:01 +10:00
psychedelicious
ff20dd509a feat(ui): refine canvas entity list items & menus 2024-08-23 19:47:01 +10:00
psychedelicious
78a59b5b78 feat(ui): canvas layer preview, revised reactivity for adapters 2024-08-23 19:47:01 +10:00
psychedelicious
46bfbbbc87 feat(ui): add SyncableMap
Can be used with useSyncExternal store to make a `Map` reactive.
2024-08-23 19:47:01 +10:00
psychedelicious
a6d73d0773 tidy(ui): removed unused transform methods from canvasmanager 2024-08-23 19:47:01 +10:00
psychedelicious
6578e8bef8 feat(ui): transform tool ux 2024-08-23 19:47:01 +10:00
psychedelicious
0596d25e07 feat(ui): rough out canvas mode 2024-08-23 19:47:01 +10:00
psychedelicious
86e8ce9139 feat(ui): add canvas autosave checkbox 2024-08-23 19:47:01 +10:00
psychedelicious
5aa2957da4 fix(ui): memory leak when getting image DTO
must unsubscribe!
2024-08-23 19:47:01 +10:00
psychedelicious
82f0cb2c8c feat(ui): rework settings menu 2024-08-23 19:47:01 +10:00
psychedelicious
fa48145cbc feat(ui): no entities fallback buttons 2024-08-23 19:47:01 +10:00
psychedelicious
6d1edc330d perf(ui): optimize gallery image delete button rendering 2024-08-23 19:47:01 +10:00
psychedelicious
97c0d3f6be feat(ui): remove "solid" background option 2024-08-23 19:47:01 +10:00
psychedelicious
a79a25ad63 tidy(ui): organise files and classes 2024-08-23 19:47:01 +10:00
psychedelicious
6a8ceef404 tidy(ui): abstract compositing logic to module 2024-08-23 19:47:01 +10:00
psychedelicious
3539670d93 fix(ui): fix canvas cache property access 2024-08-23 19:47:01 +10:00
psychedelicious
c54bc32ef6 tidy(ui): clean up CanvasFilter class 2024-08-23 19:47:01 +10:00
psychedelicious
fee293e289 tidy(ui): clean up a few bits and bobs 2024-08-23 19:47:01 +10:00
psychedelicious
747eef9ccc tidy(ui): abstract canvas rendering logic to module 2024-08-23 19:47:01 +10:00
psychedelicious
7d2df399ed tidy(ui): abstract caching logic to module 2024-08-23 19:47:01 +10:00
psychedelicious
68fad5cdcc tidy(ui): abstract worker logic to module 2024-08-23 19:47:01 +10:00
psychedelicious
b4d656c203 tidy(ui): abstract stage logic into module 2024-08-23 19:47:01 +10:00
psychedelicious
3136d89d52 feat(ui): add entity group hiding 2024-08-23 19:47:01 +10:00
psychedelicious
27e829b955 feat(ui): move all caching out of redux
While we lose the benefit of the caches persisting across reloads, this is a much simpler way to handle things. If we need a persistent cache, we can explore it in the future.
2024-08-23 19:47:01 +10:00
psychedelicious
e03e870d5b feat(ui): revised rasterization caching
- use `stable-hash` to generate stable, non-crypto hashes for cache entries, instead of using deep object comparisons
- use an object to store image name caches
2024-08-23 19:47:01 +10:00
psychedelicious
9465ff450b feat(ui): revise filter implementation 2024-08-23 19:47:01 +10:00
psychedelicious
92906a9575 fix(ui): add button to delete inpaint mask 2024-08-23 19:47:01 +10:00
psychedelicious
77f206abe4 feat(ui): add contexts/hooks to access entity adapters directly 2024-08-23 19:47:01 +10:00
psychedelicious
44a3f61580 feat(ui): add CanvasManagerProviderGate
This context waits to render its children its until the canvas manager is available. Then its children have access to the manager directly via hook.
2024-08-23 19:47:01 +10:00
psychedelicious
0a2afed08b feat(ui) do not set $canvasManager until ready 2024-08-23 19:47:01 +10:00
psychedelicious
9b3b961105 fix(ui): inpaint mask naming 2024-08-23 19:47:01 +10:00
psychedelicious
9b1828e1aa feat(ui): efficient canvas compositing
Also solves issue of exporting layers at different opacities than what is visible
2024-08-23 19:47:01 +10:00
psychedelicious
5101873f49 feat(ui): allow multiple inpaint masks
This is easier than making it a nullable singleton
2024-08-23 19:47:01 +10:00
psychedelicious
c612f18114 fix(ui): missing rasterization cache invalidations 2024-08-23 19:47:01 +10:00
psychedelicious
7e400d876f feat(ui): iterate on filter UI, flow 2024-08-23 19:47:01 +10:00
psychedelicious
677dddcfc9 fix(ui): rehydration data loss 2024-08-23 19:47:01 +10:00
psychedelicious
0792b9175e feat(ui): sort log namespaces 2024-08-23 19:47:01 +10:00
psychedelicious
e4829f80af fix(ui): do not merge arrays by index during rehydration 2024-08-23 19:47:01 +10:00
psychedelicious
bb760f3eb4 fix(ui): clone parsed data during state rehydration
Without this, the objects and arrays in `parsed` could be mutated, and the log statment would show the mutated data.
2024-08-23 19:47:01 +10:00
psychedelicious
388c65287b fix(ui): fix logger filter
was accidetnally replacing the filter instead of appending to it.
2024-08-23 19:47:01 +10:00
psychedelicious
12cd41e05c fix(ui): race condition queue status
Sequence of events causing the race condition:
- Enqueue batch
- Invalidate `SessionQueueStatus` tag
- Request updated queue status via HTTP - batch still processing at this point
- Batch completes
- Event emitted saying so
- Optimistically update the queue status cache, it is correct
- HTTP request makes it back and overwrites the optimistic update, indicating the batch is still in progress

FIxed by not invalidating the cache.
2024-08-23 19:47:01 +10:00
psychedelicious
7765c03949 fix(ui): handle opacity for masks 2024-08-23 19:47:01 +10:00
psychedelicious
3daa80c57f feat(ui): default background to checkerboard 2024-08-23 19:47:01 +10:00
psychedelicious
5dbbef4ebd feat(ui): clean up logging namespaces, allow skipping namespaces 2024-08-23 19:47:01 +10:00
psychedelicious
db33b3f7b5 chore(ui): bump ui library 2024-08-23 19:47:01 +10:00
psychedelicious
8ffcf2a6be fix(ui): do not allow drawing if layer disabled 2024-08-23 19:47:01 +10:00
psychedelicious
2e7ae6a07e fix(ui): stale state causing race conditions & extraneous renders 2024-08-23 19:47:01 +10:00
psychedelicious
fea1711f0c fix(ui): do not clear buffer when rendering "real" objects 2024-08-23 19:47:01 +10:00
psychedelicious
2a3546db97 tidy(ui): remove "filter" from CanvasImageState 2024-08-23 19:47:01 +10:00
psychedelicious
285c266612 feat(ui): better editable title 2024-08-23 19:47:01 +10:00
psychedelicious
426ad54c53 fix(ui): stroke eraserline 2024-08-23 19:47:01 +10:00
psychedelicious
fc75f7919f feat(ui): restore transparency effect for control layers 2024-08-23 19:47:01 +10:00
psychedelicious
6c6b1aaff6 feat(ui): use text cursor for entity title 2024-08-23 19:47:01 +10:00
psychedelicious
c319d653ac tidy(ui): remove extraneous logging in CanvasStateApi 2024-08-23 19:47:01 +10:00
psychedelicious
d887e474e7 feat(ui): better buffer commit logic 2024-08-23 19:47:01 +10:00
psychedelicious
da7b52d6ba feat(ui): render buffer separately from "real" objects 2024-08-23 19:47:01 +10:00
psychedelicious
b5aa308593 fix(ui): pixelRect should always be integer 2024-08-23 19:47:01 +10:00
psychedelicious
0b7ceb3bb6 fix(ui): only update stage attrs when stage itself is dragged 2024-08-23 19:47:01 +10:00
psychedelicious
3a70cefda2 feat(ui): add line simplification
This fixes some awkward issues where line segments stack up.
2024-08-23 19:47:01 +10:00
psychedelicious
4b609251e1 fix(ui): various things listening when they need not listen 2024-08-23 19:47:01 +10:00
psychedelicious
0839eac0f7 feat(ui): layer opacity via caching 2024-08-23 19:47:01 +10:00
psychedelicious
5f2a7feeee feat(ui): reset view fits all visible objects 2024-08-23 19:47:01 +10:00
psychedelicious
982535eb92 fix(ui): rerenders when changing canvas scale 2024-08-23 19:47:01 +10:00
psychedelicious
0c2b8edc8d fix(ui): do not render rasterized layer unless renderObjects=true 2024-08-23 19:47:01 +10:00
psychedelicious
f78f4ca25f feat(ui): revise app layout strategy, add interaction scopes for hotkeys 2024-08-23 19:47:01 +10:00
psychedelicious
d6b3e6c07d feat(ui): tweak mask patterns 2024-08-23 19:47:01 +10:00
psychedelicious
071ff8e74a fix(ui): dynamic prompts recalcs when presets are loaded 2024-08-23 19:47:01 +10:00
psychedelicious
1ea8aafca1 fix(ui): use style preset prompts correctly 2024-08-23 19:46:05 +10:00
psychedelicious
533dd221f8 fix(ui): discard selected staging image not all other images 2024-08-23 19:46:05 +10:00
psychedelicious
2b325c6683 fix(ui): respect image size in staging preview 2024-08-23 19:46:05 +10:00
psychedelicious
3845b1b3e6 tidy(ui): cleanup after events change 2024-08-23 19:46:05 +10:00
psychedelicious
cea7890a67 feat(ui): move socket event handling out of redux
Download events and invocation status events (including progress images) are very frequent. There's no real need for these to pass through redux. Handling them outside redux is a significant performance win - far fewer store subscription calls, far fewer trips through middleware.

All event handling is moved outside middleware. Cleanup of unused actions and listeners to follow.
2024-08-23 19:46:05 +10:00
psychedelicious
c38fe8025d fix(ui): rebase conflicts 2024-08-23 19:46:05 +10:00
psychedelicious
f1de95349c fix(ui): update compositing rect when fill changes 2024-08-23 19:46:05 +10:00
psychedelicious
2950775fa7 feat(ui): add canvas background style 2024-08-23 19:46:05 +10:00
psychedelicious
cb293fd7ac feat(ui): mask layers choose own opacity 2024-08-23 19:46:05 +10:00
psychedelicious
43b3fab6be feat(ui): mask fill patterns 2024-08-23 19:46:05 +10:00
psychedelicious
d4b0dbce49 build(ui): add vite types to tsconfig 2024-08-23 19:46:05 +10:00
psychedelicious
137b810669 fix(ui): do not smooth pixel data when using eyeDropper 2024-08-23 19:46:05 +10:00
psychedelicious
c172657324 tidy(ui): tool components & translations 2024-08-23 19:46:05 +10:00
psychedelicious
97c966b04f feat(ui): rough out eyedropper tool
It's a bit slow bc we are converting the stage to canvas on every mouse move. Also need to improve the visual but it works.
2024-08-23 19:46:05 +10:00
psychedelicious
7178fc6253 fix(ui): ip adapters work 2024-08-23 19:46:05 +10:00
psychedelicious
4adb2eabf5 feat(ui): rename layers 2024-08-23 19:46:05 +10:00
psychedelicious
9f2c815e13 feat(ui): revise entity menus 2024-08-23 19:46:05 +10:00
psychedelicious
1435557d1d feat(ui): split control layers from raster layers for UI and internal state, same rendering as raster layers 2024-08-23 19:46:05 +10:00
psychedelicious
96abf687f6 feat(ui): implement cache for image rasterization, rip out some old controladapters code 2024-08-23 19:46:05 +10:00
psychedelicious
636d9a7209 feat(ui, app): use layer as control (wip) 2024-08-23 19:46:05 +10:00
psychedelicious
3b36eb0223 feat(ui): add contextmenu for canvas entities 2024-08-23 19:46:05 +10:00
psychedelicious
388c97bff0 feat(ui): more better logging & naming 2024-08-23 19:46:05 +10:00
psychedelicious
b1cb018695 feat(ui): better logging w/ path 2024-08-23 19:46:05 +10:00
psychedelicious
df78dd7953 feat(ui): always show marks on canvas scale slider 2024-08-23 19:46:05 +10:00
psychedelicious
0dc344a22e fix(ui): do not import button from chakra 2024-08-23 19:46:05 +10:00
psychedelicious
350d7f6f14 fix(ui): scaled bbox preview 2024-08-23 19:46:05 +10:00
psychedelicious
11059ee2d4 feat(ui): tidy up atoms 2024-08-23 19:46:05 +10:00
psychedelicious
c90d3f3bb9 feat(ui): convert all my pubsubs to atoms
its the same but better
2024-08-23 19:46:05 +10:00
psychedelicious
7f6d439fd1 feat(ui): add trnalsation 2024-08-23 19:46:05 +10:00
psychedelicious
783a78f069 fix(ui): give up on thumbnail loading, causes flash during transformer 2024-08-23 19:46:05 +10:00
psychedelicious
0ff031950d fix(ui): depth anything v2 2024-08-23 19:46:05 +10:00
psychedelicious
d7e8f3d756 tidy(ui): remove unused code, comments 2024-08-23 19:46:05 +10:00
psychedelicious
4668ea449b fix(ui): staging area works 2024-08-23 19:46:05 +10:00
psychedelicious
30d318d021 feat(nodes): temp disable canvas output crop 2024-08-23 19:46:05 +10:00
psychedelicious
de96f97e5f fix(ui): max scale 1 when reset view 2024-08-23 19:46:05 +10:00
psychedelicious
57c0a2dfb1 feat(ui): better scale changer component, reset view functionality 2024-08-23 19:46:05 +10:00
psychedelicious
cd4e464bde fix(ui): img2img 2024-08-23 19:46:05 +10:00
psychedelicious
49e48c3eb7 feat(ui): add manual scale controls 2024-08-23 19:46:05 +10:00
psychedelicious
edd3b3bce9 fix(ui): do not await clearBuffer 2024-08-23 19:46:04 +10:00
psychedelicious
f8bfb66108 feat(ui): dnd image into layer 2024-08-23 19:46:04 +10:00
psychedelicious
3b6a76cbf3 fix(ui): do not await commitBuffer 2024-08-23 19:46:04 +10:00
psychedelicious
e0b60e4320 fix(ui): properly destroy entities in manager cleanup 2024-08-23 19:46:04 +10:00
psychedelicious
2159319035 tidy(ui): clearer component names for regional guidance 2024-08-23 19:46:04 +10:00
psychedelicious
b170fc232e tidy(ui): clearer component names for ip adapter 2024-08-23 19:46:04 +10:00
psychedelicious
594da60f2f tidy(ui): clearer component names for inpaint mask 2024-08-23 19:46:04 +10:00
psychedelicious
6a432f6518 tidy(ui): clearer component names for control adapters 2024-08-23 19:46:04 +10:00
psychedelicious
eb8eacfec6 feat(ui): simplify canvas list item headers 2024-08-23 19:46:04 +10:00
psychedelicious
c8d04d42e2 fix(ui): ip adapter list item 2024-08-23 19:46:04 +10:00
psychedelicious
d39c9de81e tidy(ui): clean up unused logic 2024-08-23 19:46:04 +10:00
psychedelicious
a27d39b9ff feat(ui): clean up state, add mutex for image loading, add thumbnail loading 2024-08-23 19:46:04 +10:00
psychedelicious
6b385614f0 chore(ui): add async-mutex dep 2024-08-23 19:46:04 +10:00
psychedelicious
3ae7250ef7 feat(ui): txt2img, img2img, inpaint & outpaint working 2024-08-23 19:46:04 +10:00
psychedelicious
a42d0ce1d2 feat(ui): no padding on transformer outlines 2024-08-23 19:46:04 +10:00
psychedelicious
d9131f7563 feat(ui): restore object count to layer titles 2024-08-23 19:46:04 +10:00
psychedelicious
bdce958f29 tidy(ui): "useIsEntitySelected" -> "useEntityIsSelected" 2024-08-23 19:46:04 +10:00
psychedelicious
3c86f1e979 tidy(ui): move transformer statics into class 2024-08-23 19:46:04 +10:00
psychedelicious
894b8a29b9 tidy(ui): massive cleanup
- create a context for entity identifiers, massively simplifying UI for each entity int he list
- consolidate common redux actions
- remove now-unused code
2024-08-23 19:46:04 +10:00
psychedelicious
8436a44973 perf(ui): do not add duplicate points to lines 2024-08-23 19:46:04 +10:00
psychedelicious
f9f9ec3688 feat(ui): up line tension to 0.3 2024-08-23 19:46:04 +10:00
psychedelicious
5a98d7a1f6 perf(ui): disable stroke, perfect draw on compositing rect 2024-08-23 19:46:04 +10:00
psychedelicious
f9bc96e497 tidy(ui): remove unused code, initial image 2024-08-23 19:46:04 +10:00
psychedelicious
56350ff5dc tidy(ui): remove unused state & actions 2024-08-23 19:46:04 +10:00
psychedelicious
6c1139340c feat(ui): region mask rendering 2024-08-23 19:46:04 +10:00
psychedelicious
641b1a7e6f feat(ui): esc cancels drawing buffer
maybe this is not wanted? we'll see
2024-08-23 19:46:04 +10:00
psychedelicious
674a3f462f fix(ui): render transformer over objects, fix issue w/ inpaint rect color 2024-08-23 19:46:04 +10:00
psychedelicious
2283186d3a fix(ui): brush preview fill for inpaint/region 2024-08-23 19:46:04 +10:00
psychedelicious
340af1fe50 fix(ui): no objects rendered until vis toggled 2024-08-23 19:46:04 +10:00
psychedelicious
9378656d78 feat(ui): inpaint mask transform 2024-08-23 19:46:04 +10:00
psychedelicious
5f0413e222 fix(ui): layer accidental early set isFirstRender=false 2024-08-23 19:46:04 +10:00
psychedelicious
c3e47515b1 fix(ui): inpaint mask rendering 2024-08-23 19:46:04 +10:00
psychedelicious
5dcef6fa0d feat(ui): wip inpaint mask uses new API 2024-08-23 19:46:04 +10:00
psychedelicious
31ac02cd93 feat(ui): move updatePosition to transformer 2024-08-23 19:46:04 +10:00
psychedelicious
ab16976084 feat(ui): move resetScale to transformer 2024-08-23 19:46:04 +10:00
psychedelicious
8e2b7845e1 tidy(ui): more imperative naming 2024-08-23 19:46:04 +10:00
psychedelicious
3973bce342 tidy(ui): use imperative names for setters in stateapi 2024-08-23 19:46:04 +10:00
psychedelicious
f63847a504 fix(ui): commit drawing buffer on tool change, fixing bbox not calculating 2024-08-23 19:46:04 +10:00
psychedelicious
07e3529948 fix(ui): sync transformer when requesting bbox calc 2024-08-23 19:46:04 +10:00
psychedelicious
03e1c60694 tidy(ui): rename union CanvasEntity -> CanvasEntityState 2024-08-23 19:46:04 +10:00
psychedelicious
d766ed71fc fix(ui): request rect calc immediately on transform, hiding rect 2024-08-23 19:46:04 +10:00
psychedelicious
ae68ef142a feat(ui): move bbox calculation to transformer 2024-08-23 19:46:04 +10:00
psychedelicious
20f55768c4 feat(ui): use set for transformer subscriptions 2024-08-23 19:46:04 +10:00
psychedelicious
c4fad4456e tidy(ui): clean up worker tasks when complete 2024-08-23 19:46:04 +10:00
psychedelicious
78f5ec44ad tidy(ui): remove unused code in CanvasTool 2024-08-23 19:46:04 +10:00
psychedelicious
e14ba86942 feat(ui): use pubsub for isTransforming on manager 2024-08-23 19:46:04 +10:00
psychedelicious
d4e7720f6b docs(ui): update transformer docstrings 2024-08-23 19:46:04 +10:00
psychedelicious
a3f0e7e1cb feat(ui): revised event pubsub, transformer logic split out 2024-08-23 19:46:04 +10:00
psychedelicious
30a696c476 feat(ui): add simple pubsub 2024-08-23 19:46:04 +10:00
psychedelicious
66d6c64e16 feat(ui): document & clean up object renderer 2024-08-23 19:46:04 +10:00
psychedelicious
d15be9b57c feat(ui): split out object renderer 2024-08-23 19:46:04 +10:00
psychedelicious
e5da902fd0 fix(ui): unable to hold shit while transforming to retain ratio 2024-08-23 19:46:04 +10:00
psychedelicious
fc558094c2 tidy(ui): rename canvas stuff 2024-08-23 19:46:04 +10:00
psychedelicious
ad9312e989 tidy(ui): consolidate getLoggingContext builders 2024-08-23 19:46:04 +10:00
psychedelicious
8e1a70b008 fix(ui): align all tools to 1px grid
- Offset brush tool by 0.5px when width is odd, ensuring each stroke edge is exactly on a pixel boundary
- Round the rect tool also
2024-08-23 19:46:04 +10:00
psychedelicious
17f88cd5ad feat(ui): disable image smoothing on layers 2024-08-23 19:46:04 +10:00
psychedelicious
298f1919fa fix(ui): round position when rasterizing layer 2024-08-23 19:46:04 +10:00
psychedelicious
4d20cc11d4 feat(ui): continue modularizing transform 2024-08-23 19:46:04 +10:00
psychedelicious
14f249a2f0 feat(ui): fix a few things that didn't unsubscribe correctly, add helper to manage subscriptions 2024-08-23 19:46:04 +10:00
psychedelicious
b9746a6c2c feat(ui): merge bbox outline into transformer 2024-08-23 19:46:04 +10:00
psychedelicious
94f298a6f4 fix(ui): update parent's pos not transformers 2024-08-23 19:46:04 +10:00
psychedelicious
8d3a8178da feat(ui): merge interaction rect into transformer class 2024-08-23 19:46:04 +10:00
psychedelicious
cad4212fe8 feat(ui): prepare staging area 2024-08-23 19:46:04 +10:00
psychedelicious
cff28dfaa9 feat(ui): typing for logging context 2024-08-23 19:46:04 +10:00
psychedelicious
70d7509fcc feat(ui): remove inheritance of CanvasObject
JS is terrible
2024-08-23 19:46:04 +10:00
psychedelicious
cf83af7a27 feat(ui): split & document transformer logic, iterate on class structures 2024-08-23 19:46:04 +10:00
psychedelicious
5c5a405c0f feat(ui): rotation snap to nearest 45deg when holding shift 2024-08-23 19:46:04 +10:00
psychedelicious
0208e4b232 feat(ui): expose subscribe method for nanostores 2024-08-23 19:46:04 +10:00
psychedelicious
e940754795 tidy(ui): remove layer scaling reducers 2024-08-23 19:46:04 +10:00
psychedelicious
dc9fa1a735 fix(ui): pixel-perfect transforms 2024-08-23 19:46:04 +10:00
psychedelicious
08591fbf6d fix(ui): layer visibility toggle 2024-08-23 19:46:04 +10:00
psychedelicious
74db71bb5d fix(nodes): fix canvas mask erode
it wasn't eroding enough and caused incorrect transparency in result images
2024-08-23 19:46:04 +10:00
psychedelicious
60dbe798a5 fix(ui): do not reset layer on first render 2024-08-23 19:46:04 +10:00
psychedelicious
0e676605fe feat(ui): revised logging and naming setup, fix staging area 2024-08-23 19:46:04 +10:00
psychedelicious
3f781016f6 feat(ui): add repr methods to layer and object classes 2024-08-23 19:46:04 +10:00
psychedelicious
17cd2f6b02 feat(ui): use nanoid(10) instead of uuidv4 for canvas
Shorter ids makes it much more readable
2024-08-23 19:46:04 +10:00
psychedelicious
99102a1b34 build(ui): add nanoid as explicit dep 2024-08-23 19:46:04 +10:00
psychedelicious
8d72e7d9e8 fix(ui): move CanvasImage's konva image to correct object 2024-08-23 19:46:04 +10:00
psychedelicious
0b6b6f97ad fix(ui): prevent flash when applying transform 2024-08-23 19:46:04 +10:00
psychedelicious
fb2f6382b1 build(ui): add eslint rules for async stuff 2024-08-23 19:46:04 +10:00
psychedelicious
1ddea87c35 feat(ui): trying to fix flicker after transform 2024-08-23 19:46:04 +10:00
psychedelicious
ea02323095 feat(ui): transform cleanup 2024-08-23 19:46:04 +10:00
psychedelicious
49733091c7 feat(ui): fix transform when rotated 2024-08-23 19:46:04 +10:00
psychedelicious
cf833fd6e2 fix(ui): use pixel bbox when image is in layer 2024-08-23 19:46:04 +10:00
psychedelicious
ba5cf07ab8 fix(ui): transforming when axes flipped 2024-08-23 19:46:04 +10:00
psychedelicious
d15321a373 feat(ui): hallelujah (???) 2024-08-23 19:46:04 +10:00
psychedelicious
de597a5eb4 feat(ui): add debug button 2024-08-23 19:46:04 +10:00
psychedelicious
e5f5cbdf5c fix(ui): transformer padding 2024-08-23 19:46:04 +10:00
psychedelicious
7d4342bbff feat(ui): wip transform mode 2 2024-08-23 19:46:04 +10:00
psychedelicious
7f8a1d8d20 feat(ui): wip transform mode 2024-08-23 19:46:04 +10:00
psychedelicious
65353ac1e1 feat(ui): wip transform mode 2024-08-23 19:46:04 +10:00
psychedelicious
7f9a31ca4a fix(ui): dnd to canvas broke 2024-08-23 19:46:04 +10:00
psychedelicious
592eb2886c fix(ui): conflicts after rebasing 2024-08-23 19:46:04 +10:00
psychedelicious
c220dd8987 fix(ui): imageDropped listener 2024-08-23 19:46:04 +10:00
psychedelicious
a263beb0d5 wip 2024-08-23 19:46:04 +10:00
psychedelicious
46b7c510eb fix(ui): transform tool seems to be working 2024-08-23 19:46:04 +10:00
psychedelicious
f405e472ea fix(ui): move tool fixes, add transform tool 2024-08-23 19:46:04 +10:00
psychedelicious
7bdfd3ef5f feat(ui): move tool now only moves 2024-08-23 19:46:04 +10:00
psychedelicious
778ee2c679 feat(ui): layer bbox calc in worker 2024-08-23 19:46:04 +10:00
psychedelicious
e70339ff3e feat(ui): tweaked entity & group selection styles 2024-08-23 19:46:04 +10:00
psychedelicious
88c57a9750 feat(ui): canvas entity list headers 2024-08-23 19:46:04 +10:00
psychedelicious
137252128b tidy(ui): CanvasRegion 2024-08-23 19:46:04 +10:00
psychedelicious
d4297b1345 tidy(ui): CanvasRect 2024-08-23 19:46:04 +10:00
psychedelicious
6059bc7b47 tidy(ui): CanvasLayer 2024-08-23 19:46:04 +10:00
psychedelicious
c3ff3eb51f tidy(ui): CanvasInpaintMask 2024-08-23 19:46:04 +10:00
psychedelicious
0b7751c413 tidy(ui): CanvasInitialImage 2024-08-23 19:46:04 +10:00
psychedelicious
d7f1c30624 tidy(ui): CanvasImage 2024-08-23 19:46:04 +10:00
psychedelicious
3f4d7dbeea tidy(ui): CanvasEraserLine 2024-08-23 19:46:04 +10:00
psychedelicious
19b6ae2907 tidy(ui): CanvasControlAdapter 2024-08-23 19:46:04 +10:00
psychedelicious
769f96ff9f tidy(ui): CanvasBrushLine 2024-08-23 19:46:04 +10:00
psychedelicious
fdaf75faa4 tidy(ui): CanvasBbox 2024-08-23 19:46:04 +10:00
psychedelicious
1380bb7ae6 tidy(ui): CanvasBackground 2024-08-23 19:46:04 +10:00
psychedelicious
9483c8cc29 tidy(ui): update canvas classes, organise location of konva nodes 2024-08-23 19:46:04 +10:00
psychedelicious
2ef8a8cf5a feat(ui): add names to all konva objects
Makes troubleshooting much simpler
2024-08-23 19:46:04 +10:00
psychedelicious
d296ec1932 fix(ui): do not await creating new canvas image
If you await this, it causes a race condition where multiple images are created.
2024-08-23 19:46:04 +10:00
psychedelicious
444ad3dae1 feat(ui): use position and dimensions instead of separate x,y,width,height attrs 2024-08-23 19:46:04 +10:00
psychedelicious
8cdcc71378 fix(ui): remove weird rtkq hook wrapper
I do not understand why I did that initially but it doesn't work with TS.
2024-08-23 19:46:04 +10:00
psychedelicious
e8bc06cfd3 feat(ui): rename types size and position to dimensions and coordinate 2024-08-23 19:46:04 +10:00
psychedelicious
67a0a024e9 tidy(ui): hide layer settings by default 2024-08-23 19:46:04 +10:00
psychedelicious
bd2c46c267 fix(ui): layer rendering when starting as disabled 2024-08-23 19:46:04 +10:00
psychedelicious
5acb27f350 feat(invocation): reduce canvas v2 mask & crop mask dilation 2024-08-23 19:46:04 +10:00
psychedelicious
7271b12d0f feat(ui): de-jank staging area and progress images 2024-08-23 19:46:04 +10:00
psychedelicious
4a79467a33 feat(ui): update staging handling to work w/ cropped mask 2024-08-23 19:46:04 +10:00
psychedelicious
5501bb87a3 chore(ui): typegen 2024-08-23 19:46:04 +10:00
psychedelicious
561610e296 feat(app): update CanvasV2MaskAndCropInvocation 2024-08-23 19:46:04 +10:00
psychedelicious
b76609ef18 feat(ui): use new canvas output node 2024-08-23 19:46:04 +10:00
psychedelicious
070b78501b chore(ui): typegen 2024-08-23 19:46:04 +10:00
psychedelicious
50df4f4ab6 feat(app): add CanvasV2MaskAndCropInvocation & CanvasV2MaskAndCropOutput
This handles some masking and cropping that the canvas needs.
2024-08-23 19:46:04 +10:00
psychedelicious
9bbf430125 fix(ui): restore nodes output tracking 2024-08-23 19:46:04 +10:00
psychedelicious
384a90958a feat(ui): rip out document size
barely knew ye
2024-08-23 19:46:04 +10:00
psychedelicious
0e4a25b029 feat(ui): convert initial image to layer when starting canvas session 2024-08-23 19:46:04 +10:00
psychedelicious
4a44e171fd fix(ui): fix layer transparency calculation 2024-08-23 19:46:04 +10:00
psychedelicious
9bc57a6f59 fix(ui): reset initial image when resetting canvas 2024-08-23 19:46:03 +10:00
psychedelicious
4341ed7ab4 fix(ui): reset node executions states when loading workflow 2024-08-23 19:46:03 +10:00
psychedelicious
97ce72c542 fix(ui): entity display list 2024-08-23 19:46:03 +10:00
psychedelicious
a2c78a57a7 feat(ui): img2img working 2024-08-23 19:46:03 +10:00
psychedelicious
044a713dc9 feat(ui): rough out img2img on canvas 2024-08-23 19:46:03 +10:00
psychedelicious
b8479c5fe2 UNDO ME WIP 2024-08-23 19:46:03 +10:00
psychedelicious
4e5d056824 feat(ui): log invocation source id on socket event 2024-08-23 19:46:03 +10:00
psychedelicious
118278b372 feat(ui): restore document size overlay renderer 2024-08-23 19:46:03 +10:00
psychedelicious
8e8c255f3f feat(ui): make documnet size a rect 2024-08-23 19:46:03 +10:00
psychedelicious
1575bee401 refactor(ui): remove modular imagesize components
This is no longer necessary with canvas v2 and added a ton of extraneous redux actions when changing the image size. Also renamed to document size
2024-08-23 19:46:03 +10:00
psychedelicious
249bbfc883 feat(ui): initialState is for generation mode 2024-08-23 19:46:03 +10:00
psychedelicious
3993ae410f feat(ui): split out canvas entity list component 2024-08-23 19:46:03 +10:00
psychedelicious
edf040e3d2 feat(ui): hide bbox button when no canvas session active 2024-08-23 19:46:03 +10:00
psychedelicious
66fd077ee7 tidy(ui): remove unused naming objects/utils
The canvas manager means we don't need to worry about konva node names as we never directly select konva nodes.
2024-08-23 19:46:03 +10:00
psychedelicious
b93462ebb6 feat(ui): split up tool chooser buttons
Prep for distinct toolbars for generation vs canvas modes
2024-08-23 19:46:03 +10:00
psychedelicious
aae60d0cdc feat(ui): add useAssertSingleton util hook
This simple hook asserts that it is only ever called once. Particularly useful for things like hotkeys hooks.
2024-08-23 19:46:03 +10:00
psychedelicious
d4da00e607 feat(ui): "stagingArea" -> "session" 2024-08-23 19:46:03 +10:00
psychedelicious
0c539ff00b feat(ui): add reset button to canvas 2024-08-23 19:46:03 +10:00
psychedelicious
5983cbf26c feat(ui): add snapToRect util 2024-08-23 19:46:03 +10:00
psychedelicious
c513d6e3af fix(ui): fiddle with control adapter filters
some jank still
2024-08-23 19:46:03 +10:00
psychedelicious
9d57c0e631 feat(ui): temp disable doc size overlay 2024-08-23 19:46:03 +10:00
psychedelicious
a1923a8966 feat(ui): no animation on layer selection
Felt sluggish
2024-08-23 19:46:03 +10:00
psychedelicious
d988e18731 feat(ui): use canvas as source for control images (wip) 2024-08-23 19:46:03 +10:00
psychedelicious
51008da2dd fix(ui): control adapter translate & scale 2024-08-23 19:46:03 +10:00
psychedelicious
6ccc1f5672 tidy(ui): removed unused state related to non-buffered drawing 2024-08-23 19:46:03 +10:00
psychedelicious
4a556f84e0 feat(ui): control adapter image rendering 2024-08-23 19:46:03 +10:00
psychedelicious
2f21a2220d fix(ui): do not floor bbox calc, it cuts off the last pixels 2024-08-23 19:46:03 +10:00
psychedelicious
91a420b13e feat(ui): fix issue where creating line needs 2 points 2024-08-23 19:46:03 +10:00
psychedelicious
c27da3581b fix(ui): edge cases when holding shift and drawing lines 2024-08-23 19:46:03 +10:00
psychedelicious
961dfbce93 fix(ui): set buffered rect color to full alpha 2024-08-23 19:46:03 +10:00
psychedelicious
df39c825ae fix(ui): handle mouseup correctly 2024-08-23 19:46:03 +10:00
psychedelicious
3f6ee1b7a4 feat(ui): buffered rect drawing 2024-08-23 19:46:03 +10:00
psychedelicious
908e504a6f fix(ui): buffered drawing edge cases 2024-08-23 19:46:03 +10:00
psychedelicious
f2fa41afc5 perf(ui): do not use stage.find 2024-08-23 19:46:03 +10:00
psychedelicious
440ff40ad5 perf(ui): object groups do not listen 2024-08-23 19:46:03 +10:00
psychedelicious
5c15458e15 perf(ui): buffered drawing (wip) 2024-08-23 19:46:03 +10:00
psychedelicious
be5b474f1e tidy(ui): organise files 2024-08-23 19:46:03 +10:00
psychedelicious
cee178c2b6 tidy(ui): organise files 2024-08-23 19:46:03 +10:00
psychedelicious
27657f8b7a tidy(ui): organise files 2024-08-23 19:46:03 +10:00
psychedelicious
e0cde3815a fix(ui): background rendering 2024-08-23 19:46:03 +10:00
psychedelicious
09d0421de4 pkg(ui): remove unused deps react-konva & use-image 2024-08-23 19:46:03 +10:00
psychedelicious
47b94d563c feat(ui): organize konva state and files 2024-08-23 19:46:03 +10:00
psychedelicious
0b5d20c9f0 fix(ui): merge conflicts in image deletion listener 2024-08-23 19:46:03 +10:00
psychedelicious
80e7e1293a fix(ui): region rendering 2024-08-23 19:46:03 +10:00
psychedelicious
3a82b0cbc1 fix(ui): inpaint mask rendering 2024-08-23 19:46:03 +10:00
psychedelicious
a27cbc13b6 fix(ui): staging area rendering 2024-08-23 19:46:03 +10:00
psychedelicious
a8f962eb3f fix(ui): stale selected entity 2024-08-23 19:46:03 +10:00
psychedelicious
7f40d23f19 fix(ui): staging area image offset 2024-08-23 19:46:03 +10:00
psychedelicious
918354cd9d feat(ui): tweak layer ui component 2024-08-23 19:46:03 +10:00
psychedelicious
eef9278ee6 fix(ui): resetting layer resets position 2024-08-23 19:46:03 +10:00
psychedelicious
2c32e2e5c1 feat(ui): updated layer list component styling 2024-08-23 19:46:03 +10:00
psychedelicious
6f05654db5 feat(ui): transformable layers 2024-08-23 19:46:03 +10:00
psychedelicious
1d31b6902f feat(ui): move tool icon is pointer like in other apps 2024-08-23 19:46:03 +10:00
psychedelicious
5a7d615e64 feat(ui): do not floor cursor position 2024-08-23 19:46:03 +10:00
psychedelicious
1dbf9e4ed4 feat(ui): disable gallery hotkeys while staging 2024-08-23 19:46:03 +10:00
psychedelicious
5dcc6ee203 feat(ui): revised canvas progress & staging image handling 2024-08-23 19:46:03 +10:00
psychedelicious
84a4e6ae3f feat(ui): show queue item origin in queue list 2024-08-23 19:46:03 +10:00
psychedelicious
f283bfd68f chore(ui): typegen 2024-08-23 19:46:03 +10:00
psychedelicious
6e5ff7b79c feat(app): add origin to session queue
The origin is an optional field indicating the queue item's origin. For example, "canvas" when the queue item originated from the canvas or "workflows" when the queue item originated from the workflows tab. If omitted, we assume the queue item originated from the API directly.

- Add migration to add the nullable column to the `session_queue` table.
- Update relevant event payloads with the new field.
- Add `cancel_by_origin` method to `session_queue` service and corresponding route. This is required for the canvas to bail out early when staging images.
- Add `origin` to both `SessionQueueItem` and `Batch` - it needs to be provided initially via the batch and then passed onto the queue item.
-
2024-08-23 19:46:03 +10:00
psychedelicious
7c3800d03f fix(ui): denoise start on outpainting 2024-08-23 19:46:03 +10:00
psychedelicious
941db90518 feat(ui): add redux events for queue cleared & batch enqueued socket events 2024-08-23 19:46:03 +10:00
psychedelicious
0d9ecf0f90 feat(ui): canvas staging area works 2024-08-23 19:46:03 +10:00
psychedelicious
9c77023a11 feat(ui): switch to view tool when staging 2024-08-23 19:46:03 +10:00
psychedelicious
b55378c63c tidy(ui): disable preview images on every enqueue 2024-08-23 19:46:03 +10:00
psychedelicious
946c2a49ab feat(ui): rough out save staging image 2024-08-23 19:46:03 +10:00
psychedelicious
b823c31ec6 feat(ui): staging area image visibility toggle 2024-08-23 19:46:03 +10:00
psychedelicious
ec6361e5cb fix(ui): batch building after removing canvas files 2024-08-23 19:46:03 +10:00
psychedelicious
0c26d28278 feat(ui): make Graph class's getMetadataNode public 2024-08-23 19:46:03 +10:00
psychedelicious
c5172d4c5a tidy(ui): remove old canvas graphs 2024-08-23 19:46:03 +10:00
psychedelicious
89de04775e fix(ui): do not select already-selected entity 2024-08-23 19:46:03 +10:00
psychedelicious
b4c3c940b5 tidy(ui): naming things 2024-08-23 19:46:03 +10:00
psychedelicious
aee2aad959 tidy(ui): file organisation 2024-08-23 19:46:03 +10:00
psychedelicious
5ca48a8a5f fix(ui): reset cursor pos when fitting document 2024-08-23 19:46:03 +10:00
psychedelicious
1806aa187b feat(ui): staging area works more better 2024-08-23 19:46:03 +10:00
psychedelicious
7824cb7a1a feat(ui): staging area barely works 2024-08-23 19:46:03 +10:00
psychedelicious
9807a896f4 feat(ui): consolidate konva API 2024-08-23 19:46:03 +10:00
psychedelicious
19866f057d feat(ui): consolidate konva API 2024-08-23 19:46:03 +10:00
psychedelicious
ec4eae3c9c feat(ui): staging area (rendering wip) 2024-08-23 19:46:03 +10:00
psychedelicious
bea0cba038 tidy(ui): type "Dimensions" -> "Size" 2024-08-23 19:46:03 +10:00
psychedelicious
48ee75af9c feat(ui): add updateNode to Graph 2024-08-23 19:46:03 +10:00
psychedelicious
929c593d2f feat(ui): sdxl graphs 2024-08-23 19:46:03 +10:00
psychedelicious
221f32eca7 feat(ui): sd1 outpaint graph 2024-08-23 19:46:03 +10:00
psychedelicious
c3acc15e8b tests(ui): add missing tests for Graph class 2024-08-23 19:46:03 +10:00
psychedelicious
1b653278fc feat(ui): add Graph.getid() util 2024-08-23 19:46:03 +10:00
psychedelicious
cc9062ee46 feat(ui): outpaint graph, organize builder a bit 2024-08-23 19:46:03 +10:00
psychedelicious
91c0feb0ad feat(ui): inpaint sd1 graph 2024-08-23 19:46:03 +10:00
psychedelicious
ae60292ac8 feat(ui): temp disable image caching while testing 2024-08-23 19:46:03 +10:00
psychedelicious
a6ca17b19d feat(ui): txt2img & img2img graphs 2024-08-23 19:46:03 +10:00
psychedelicious
6a4a5ece74 feat(ui): minor change to canvas bbox state type 2024-08-23 19:46:03 +10:00
psychedelicious
9b81860307 feat(ui): simplified konva node to blob/imagedata utils 2024-08-23 19:46:03 +10:00
psychedelicious
5f4a3928d2 feat(ui): node manager getter/setter 2024-08-23 19:46:03 +10:00
psychedelicious
b703884763 feat(ui): generation mode calculation, fudged graphs 2024-08-23 19:46:03 +10:00
psychedelicious
32da98ab8f feat(ui): add utils for getting images from canvas 2024-08-23 19:46:03 +10:00
psychedelicious
bd5a85bf70 feat(ui): even more simplified API - lean on the konva node manager to abstract imperative state API & rendering 2024-08-23 19:46:03 +10:00
psychedelicious
d045f24014 feat(ui): revised docstrings for renderers & simplified api 2024-08-23 19:46:03 +10:00
psychedelicious
2aad3f89c3 feat(ui): inpaint mask UI components 2024-08-23 19:46:03 +10:00
psychedelicious
dd54d19f00 feat(ui): inpaint mask rendering (wip) 2024-08-23 19:46:03 +10:00
psychedelicious
0ed6591d8c fix(ui): models loaded handler 2024-08-23 19:46:03 +10:00
psychedelicious
712e090134 feat(ui): internal state for inpaint mask 2024-08-23 19:46:03 +10:00
psychedelicious
8fc2a1d1cf refactor(ui): divvy up canvas state a bit 2024-08-23 19:46:03 +10:00
psychedelicious
cc15c1593e feat(ui): get region and base layer canvas to blob logic working 2024-08-23 19:46:03 +10:00
psychedelicious
9997d3abda refactor(ui): node manager handles more tedious annoying stuff 2024-08-23 19:46:03 +10:00
psychedelicious
031471e785 feat(ui): use node manager for addRegions 2024-08-23 19:46:03 +10:00
psychedelicious
2e860c6791 feat(ui): persist bbox 2024-08-23 19:46:03 +10:00
psychedelicious
d071a9e17d fix(ui): fix generation graphs 2024-08-23 19:46:03 +10:00
psychedelicious
ed53d33321 feat(ui): add toggle for clipToBbox 2024-08-23 19:46:03 +10:00
psychedelicious
382bc6d978 feat(ui): rename konva node manager 2024-08-23 19:46:03 +10:00
psychedelicious
dab42e258c refactor(ui): create classes to abstract mgmt of konva nodes 2024-08-23 19:46:03 +10:00
psychedelicious
81556410bb tidy(ui): organise renderers 2024-08-23 19:46:03 +10:00
psychedelicious
1f2dfd473c refactor(ui): create entity to konva node map abstraction (wip)
Instead of chaining konva `find` and `findOne` methods, all konva nodes are added to a mapping object. Finding and manipulating them is much simpler.

Done for regions and layers, wip for control adapters.
2024-08-23 19:46:03 +10:00
psychedelicious
8f0f51be2c perf(ui): fix lag w/ region rendering
Needed to memoize these selectors
2024-08-23 19:46:03 +10:00
psychedelicious
7179e250ed feat(ui): move canvas fill color picker to right 2024-08-23 19:46:03 +10:00
psychedelicious
5bec091fd6 refactor(ui): remove unused ellipse & polygon objects 2024-08-23 19:46:03 +10:00
psychedelicious
2c5896cb0c fix(ui): incorrect rect/brush/eraser positions 2024-08-23 19:46:03 +10:00
psychedelicious
93ff252dc0 refactor(ui): enable global debugging flag 2024-08-23 19:46:03 +10:00
psychedelicious
ac52224455 refactor(ui): disable the preview renderer for now 2024-08-23 19:46:03 +10:00
psychedelicious
4087cad23d tweak(ui): canvas editor layout 2024-08-23 19:46:03 +10:00
psychedelicious
e936b1ff8f perf(ui): memoize layeractionsmenu valid actions 2024-08-23 19:46:03 +10:00
psychedelicious
b7f9c5e221 refactor(ui): decouple konva renderer from react
Subscribe to redux store directly, skipping all the react overhead.

With react in dev mode, a typical frame while using the brush tool on almost-empty canvas is reduced from ~7.5ms to ~3.5ms. All things considered, this still feels slow, but it's a massive improvement.
2024-08-23 19:46:03 +10:00
psychedelicious
fc5467150e feat(ui): clip lines to bbox 2024-08-23 19:46:03 +10:00
psychedelicious
4dcab357a0 fix(ui): document fit positioning 2024-08-23 19:46:03 +10:00
psychedelicious
695e464255 feat(ui): document bounds overlay 2024-08-23 19:46:03 +10:00
psychedelicious
9999b60c3b tidy(ui): background layer 2024-08-23 19:46:03 +10:00
psychedelicious
e7df53e260 refactor(ui): use "entity" instead of "data" for canvas 2024-08-23 19:46:03 +10:00
psychedelicious
844590a571 feat(ui): brush size border radius = 1 2024-08-23 19:46:03 +10:00
psychedelicious
9622beaa0d fix(ui): canvas HUD doesn't interrupt tool 2024-08-23 19:46:03 +10:00
psychedelicious
007e2553a8 refactor(ui): split up canvas entity renderers, temp disable preview 2024-08-23 19:46:03 +10:00
psychedelicious
15ad4e3f5e fix(ui): delete all layers button 2024-08-23 19:46:03 +10:00
psychedelicious
be5094fcb4 fix(ui): ignore keyboard shortcuts in input/textarea elements 2024-08-23 19:46:03 +10:00
psychedelicious
a20a861680 fix(ui): canvas entity ids getting clobbered 2024-08-23 19:46:03 +10:00
psychedelicious
396d0a4bc0 fix(ui): move lora followup fixes 2024-08-23 19:46:03 +10:00
psychedelicious
ca9314e077 chore(ui): lint 2024-08-23 19:46:03 +10:00
psychedelicious
4b848798e7 refactor(ui): move loras to canvas slice 2024-08-23 19:46:03 +10:00
psychedelicious
083bcbc77d fix(ui): layer is selected when added 2024-08-23 19:46:03 +10:00
psychedelicious
e8cdc9ae62 feat(ui): r to center & fit stage on document 2024-08-23 19:46:03 +10:00
psychedelicious
8abfa759a4 feat(ui): better HUD 2024-08-23 19:46:03 +10:00
psychedelicious
f6a324b633 fix(ui): always use current brush width when making straight lines 2024-08-23 19:46:03 +10:00
psychedelicious
f083be9391 feat(ui): hold shift w/ brush to draw straight line 2024-08-23 19:46:03 +10:00
psychedelicious
091e2fb751 fix(ui): update bg on canvas resize 2024-08-23 19:46:03 +10:00
psychedelicious
d8539daf1f refactor(ui): better hud 2024-08-23 19:46:03 +10:00
psychedelicious
7ec059f5fa refactor(ui): scaled tool preview border 2024-08-23 19:46:03 +10:00
psychedelicious
4f2ecdefd2 refactor(ui): port remaining canvasV1 rendering logic to V2, remove old code 2024-08-23 19:46:03 +10:00
psychedelicious
e8891a1988 refactor(ui): fix more types 2024-08-23 19:46:03 +10:00
psychedelicious
37d2607f34 refactor(ui): metadata recall (wip)
just enough let the app run
2024-08-23 19:46:03 +10:00
psychedelicious
0e7b10d3d9 refactor(ui): undo/redo button temp fix 2024-08-23 19:46:03 +10:00
psychedelicious
1f85888638 refactor(ui): fix renderer stuff 2024-08-23 19:46:03 +10:00
psychedelicious
c1f9a129fa refactor(ui): fix misc types 2024-08-23 19:46:03 +10:00
psychedelicious
7ccc5ba398 refactor(ui): fix gallery stuff 2024-08-23 19:46:03 +10:00
psychedelicious
5e1a6ae334 refactor(ui): fix delete image stuff 2024-08-23 19:46:03 +10:00
psychedelicious
3f6cf638f9 refactor(ui): fix useIsReadyToEnqueue for new adapterType field 2024-08-23 19:46:03 +10:00
psychedelicious
46e062a828 refactor(ui): update generation tab graphs 2024-08-23 19:46:02 +10:00
psychedelicious
cc3a0b5d6c refactor(ui): add adapterType to ControlAdapterData 2024-08-23 19:46:02 +10:00
psychedelicious
775479ab7b refactor(ui): update components & logic to use new unified slice (again) 2024-08-23 19:46:02 +10:00
psychedelicious
6b9e0e6d63 refactor(ui): update components & logic to use new unified slice 2024-08-23 19:46:02 +10:00
psychedelicious
83a5c87f5e refactor(ui): merge compositing, params into canvasV2 slice 2024-08-23 19:46:02 +10:00
psychedelicious
84fde74331 refactor(ui): add scaled bbox state 2024-08-23 19:46:02 +10:00
psychedelicious
a517e29b39 refactor(ui): update dnd/image upload 2024-08-23 19:46:02 +10:00
psychedelicious
ccceba7565 refactor(ui): update size/prompts state 2024-08-23 19:46:02 +10:00
psychedelicious
5fc7a03669 refactor(ui): rip out old control adapter implementation 2024-08-23 19:46:02 +10:00
psychedelicious
8864ad1b50 refactor(ui): canvas v2 (wip)
fix entity count select
2024-08-23 19:46:02 +10:00
psychedelicious
f2989885fb refactor(ui): canvas v2 (wip)
delete unused file
2024-08-23 19:46:02 +10:00
psychedelicious
19c66e5c76 refactor(ui): canvas v2 (wip)
merge all canvas state reducers into one big slice (but with the logic split across files so it's not hell)
2024-08-23 19:46:02 +10:00
psychedelicious
8a6690a57c refactor(ui): canvas v2 (wip)
Fix a few more components
2024-08-23 19:46:02 +10:00
psychedelicious
2cc60f253a refactor(ui): canvas v2 (wip)
missed a spot
2024-08-23 19:46:02 +10:00
psychedelicious
cb69872dd3 refactor(ui): canvas v2 (wip)
Redo all UI components for different canvas entity types
2024-08-23 19:46:02 +10:00
psychedelicious
ba66d7c9a6 refactor(ui): canvas v2 (wip) 2024-08-23 19:46:02 +10:00
psychedelicious
9fe727c9f8 refactor(ui): canvas v2 (wip) 2024-08-23 19:46:02 +10:00
psychedelicious
58c656224f refactor(ui): canvas v2 (wip) 2024-08-23 19:46:02 +10:00
psychedelicious
c51253f5f6 refactor(ui): canvas v2 (wip) 2024-08-23 19:46:02 +10:00
psychedelicious
6c1d1588fc feat(ui): bbox tool 2024-08-23 19:46:02 +10:00
psychedelicious
95d6183a6c fix(ui): rect tool preview 2024-08-23 19:46:02 +10:00
psychedelicious
f4c9facdaf fix(ui): multiple stages 2024-08-23 19:46:02 +10:00
psychedelicious
a274e6f165 feat(ui): decouple konva logic from nanostores 2024-08-23 19:46:02 +10:00
psychedelicious
154e3e6f64 feat(ui): store all stage attrs in nanostores 2024-08-23 19:46:02 +10:00
psychedelicious
2c3ac972e5 feat(ui): round stage scale 2024-08-23 19:46:02 +10:00
psychedelicious
2e2e072b0b chore(ui): bump konva 2024-08-23 19:46:02 +10:00
psychedelicious
9d8dd2bf66 feat(ui): generation bbox transformation working
whew
2024-08-23 19:46:02 +10:00
psychedelicious
9047f6db30 feat(ui): wip generation bbox 2024-08-23 19:46:02 +10:00
psychedelicious
5ab345ee63 feat(ui): wip generation bbox 2024-08-23 19:46:02 +10:00
psychedelicious
d8a83acd3a feat(ui): CL zoom and pan, some rendering optimizations 2024-08-23 19:46:02 +10:00
psychedelicious
1f58e5756b Revert "feat(ui): add x,y,scaleX,scaleY,rotation to objects"
This reverts commit 53318b396c967c72326a7e4dea09667b2ab20bdd.
2024-08-23 19:46:02 +10:00
psychedelicious
744acb8f07 feat(ui): layers manage their own bbox 2024-08-23 19:46:02 +10:00
psychedelicious
ae7228d821 docs(ui): konva image object docstrings 2024-08-23 19:46:02 +10:00
psychedelicious
99d8b3a7bf feat(ui): add x,y,scaleX,scaleY,rotation to objects 2024-08-23 19:46:02 +10:00
psychedelicious
3fbe65bbcf fix(ui): show color picker when using rect tool 2024-08-23 19:46:02 +10:00
psychedelicious
f5d879d8e7 feat(ui): image loading fallback for raster layers 2024-08-23 19:46:02 +10:00
psychedelicious
cbc5a4f8e6 feat(ui): bbox calc for raster layers 2024-08-23 19:46:02 +10:00
psychedelicious
37ac7d8ed5 feat(ui): do not fill brush preview when drawing 2024-08-23 19:46:02 +10:00
psychedelicious
bee3fa339d fix(ui): brush spacing handling 2024-08-23 19:46:02 +10:00
psychedelicious
c171fe2b96 fix(ui): jank when starting a shape when not already focused on stage 2024-08-23 19:46:02 +10:00
psychedelicious
1fa8032fdb feat(ui): wip raster layers
I meant to split this up into smaller commits and undo some of it, but I committed afterwards and it's tedious to undo.
2024-08-23 19:46:02 +10:00
psychedelicious
5c0676bcc2 feat(ui): support image objects on raster layers
Just the UI and internal state, not rendering yet.
2024-08-23 19:46:02 +10:00
psychedelicious
cefd9a027c tidy(ui): clean up event handlers
Separate logic for each tool in preparation for ellipse and polygon tools.
2024-08-23 19:46:02 +10:00
psychedelicious
1bce156de1 feat(ui): raster layer reset, object group util 2024-08-23 19:46:02 +10:00
psychedelicious
cd4f63f2fd feat(ui): rect shape preview now has fill 2024-08-23 19:46:02 +10:00
psychedelicious
3c7140cbf3 feat(ui): cancel shape drawing on esc 2024-08-23 19:46:02 +10:00
psychedelicious
b71ba63b5a feat(ui): temp disable history on CL 2024-08-23 19:46:02 +10:00
psychedelicious
d540e2c0d3 feat(ui): raster layer logic
- Deduplicate shared logic
- Split up giant renderers file into separate cohesive files
- Tons of cleanup
- Progress on raster layer functionality
2024-08-23 19:46:02 +10:00
psychedelicious
d79fafc5f5 feat(ui): add raster layer rendering and interaction (WIP) 2024-08-23 19:46:02 +10:00
psychedelicious
9e93fa2092 feat(ui): scaffold out raster layers
Raster layers may have images, lines and shapes. These will replace initial image layers and provide sketching functionality like we have on canvas.
2024-08-23 19:46:02 +10:00
psychedelicious
392e9b4882 refactor(ui): revise types for line and rect objects
- Create separate object types for brush and eraser lines, instead of a single type that has a `tool` field.
- Create new object type for rect shapes.
- Add logic to schemas to migrate old object types to new.
- Update renderers & reducers.
2024-08-23 19:46:02 +10:00
psychedelicious
231e5ec94a chore: bump version v4.2.8post1 2024-08-23 06:55:30 +10:00
Mary Hipp
e5bb6f9693 lint fix 2024-08-23 06:46:19 +10:00
Mary Hipp
da7dee44c6 fix(ui): use empty string fallback if unable to parse prompts when creating style preset from existing image 2024-08-23 06:46:19 +10:00
Eugene Brodsky
83144f4fe3 fix(docs): follow-up docker readme fixes 2024-08-22 11:19:07 -04:00
psychedelicious
c451f52ea3 chore(ui): lint 2024-08-22 21:00:09 +10:00
psychedelicious
8a2c78f2e1 fix(ui): dynamic prompts not recalculating when deleting or updating a style preset
The root cause was the active style preset not being reset when it was deleted, or no longer present in the list of style presets.

- Add extra reducer to `stylePresetSlice` to reset the active preset if it is deleted or otherwise unavailable
- Update the dynamic prompts listener to trigger on delete/update/list of style presets
2024-08-22 21:00:09 +10:00
psychedelicious
bcc78bde9b chore: bump version to v4.2.8 2024-08-22 21:00:09 +10:00
Васянатор
054bb6fe0a translationBot(ui): update translation (Russian)
Currently translated at 100.0% (1367 of 1367 strings)

Co-authored-by: Васянатор <ilabulanov339@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ru/
Translation: InvokeAI/Web UI
2024-08-22 13:09:56 +10:00
Riccardo Giovanetti
4f4aa6d92e translationBot(ui): update translation (Italian)
Currently translated at 98.4% (1346 of 1367 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.4% (1346 of 1367 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-08-22 13:09:56 +10:00
Hosted Weblate
eac51ac6f5 translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2024-08-22 13:09:56 +10:00
psychedelicious
9f349a7c0a fix(ui): do not constrain width of hide/show boards button
lets translations display fully
2024-08-22 11:36:07 +10:00
psychedelicious
918afa5b15 fix(ui): show more of current board name 2024-08-22 11:36:07 +10:00
psychedelicious
eb1113f95c feat(ui): add translation string for "Upscale" 2024-08-22 11:36:07 +10:00
psychedelicious
4f4ba7b462 tidy(ui): clean up ActiveStylePreset markup 2024-08-21 09:06:41 +10:00
Mary Hipp
2298be0e6b fix(ui): error handling if unable to convert image URL to blob 2024-08-21 09:06:41 +10:00
Mary Hipp
63494dfca7 remove extra slash in exports path 2024-08-21 09:06:41 +10:00
Mary Hipp
36a1d39454 fix(ui): handle badge styling when template name is long 2024-08-21 09:06:41 +10:00
Mary Hipp
a6f6d5c400 fix(ui): add loading state to button when creating or updating a style preset 2024-08-21 09:06:41 +10:00
Mary Hipp
e85f221aca fix(ui): clear prompt template when prompts are recalled 2024-08-21 09:04:35 +10:00
Mary Hipp
d4797e37dc fix(ui): properly unwrap delete style preset API request so that error is caught 2024-08-19 16:12:39 -04:00
Mary Hipp
3e7923d072 fix(api): allow updating of type for style preset 2024-08-19 16:12:39 -04:00
psychedelicious
a85d69ce3d tidy(ui): getViewModeChunks.tsx -> .ts 2024-08-19 08:25:39 +10:00
psychedelicious
96db006c99 fix(ui): edge case with getViewModeChunks 2024-08-19 08:25:39 +10:00
psychedelicious
8ca57d03d8 tests(ui): add tests for getViewModeChunks 2024-08-19 08:25:39 +10:00
psychedelicious
6c404ce5f8 fix(ui): prompt template preset preview out of order 2024-08-19 08:25:39 +10:00
psychedelicious
584e07182b fix(ui): use translations for style preset strings 2024-08-17 21:27:53 +10:00
psychedelicious
f787e9acf6 chore: bump version v4.2.8rc2 2024-08-16 21:47:06 +10:00
psychedelicious
5a24b89e54 fix(app): include style preset defaults in build 2024-08-16 21:47:06 +10:00
psychedelicious
9b482e2a4f chore: bump version to v4.2.8rc1 2024-08-16 10:53:19 +10:00
Max
df4dbe2d57 Fix invoke.sh not detecting symlinks
When invoke.sh is executed using a symlink with a working directory outside of InvokeAI's root directory, it will fail.

invoke.sh attempts to cd into the correct directory at the start of the script, but will cd into the directory of the symlink instead. This commit fixes that.
2024-08-16 10:40:59 +10:00
psychedelicious
713bd11177 feat(ui, api): prompt template export (#6745)
## Summary

Adds option to download all prompt templates to a CSV

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-08-16 10:38:50 +10:00
psychedelicious
182571df4b Merge branch 'main' into maryhipp/export-presets 2024-08-16 10:17:07 +10:00
psychedelicious
29bfe492b6 ui: translations update from weblate (#6746)
Translations update from [Hosted Weblate](https://hosted.weblate.org)
for [InvokeAI/Web
UI](https://hosted.weblate.org/projects/invokeai/web-ui/).



Current translation status:

![Weblate translation
status](https://hosted.weblate.org/widget/invokeai/web-ui/horizontal-auto.svg)
2024-08-16 10:16:51 +10:00
psychedelicious
3fb4e3050c feat(ui): focus in textarea after inserting placeholder 2024-08-16 10:14:25 +10:00
psychedelicious
39c7ec3cd9 feat(ui): per type fallbacks for templates 2024-08-16 10:11:43 +10:00
psychedelicious
26bfbdec7f feat(ui): use buttons instead of menu for preset import/export 2024-08-16 09:58:19 +10:00
psychedelicious
7a3eaa8da9 feat(api): save file as prompt_templates.csv 2024-08-16 09:51:46 +10:00
Mary Hipp
599db7296f export only user style presets 2024-08-15 16:07:32 -04:00
Riccardo Giovanetti
042aab4295 translationBot(ui): update translation (Italian)
Currently translated at 98.6% (1340 of 1359 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-08-15 20:44:02 +02:00
Mary Hipp
24f298283f clean up, add context menu to import/download templates 2024-08-15 12:39:55 -04:00
Mary Hipp
68dac6349d Merge remote-tracking branch 'origin/main' into maryhipp/export-presets 2024-08-15 11:21:56 -04:00
chainchompa
b675fc19e8 feat: add base prop for selectedWorkflow to allow loading a workflow on launch (#6742)
## Summary
added a base prop for selectedWorkflow to allow loading a workflow on
launch

<!--A description of the changes in this PR. Include the kind of change
(fix, feature, docs, etc), the "why" and the "how". Screenshots or
videos are useful for frontend changes.-->

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions
can test by loading InvokeAIUI with a selectedWorkflow prop of the
workflow ID
<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-08-15 10:52:23 -04:00
chainchompa
659019cfd6 Merge branch 'main' into chainchompa/preselect-workflows 2024-08-15 10:40:44 -04:00
Mary Hipp
dcd61e1f82 pin ruff version in python check gha 2024-08-15 09:47:49 -04:00
Mary Hipp
f5c99b1488 exclude jupyter notebooks from ruff 2024-08-15 09:47:49 -04:00
Mary Hipp
810be3e1d4 update import directions to include JSON 2024-08-15 09:47:49 -04:00
psychedelicious
60d754d1df feat(api): tidy style presets import logic
- Extract parsing into utility function
- Log import errors
- Forbid extra properties on the imported data
2024-08-15 09:47:49 -04:00
psychedelicious
bd07c86db9 feat(ui): make style preset menu trigger look like button 2024-08-15 09:47:49 -04:00
psychedelicious
bcbf8b6bd8 feat(ui): revert to using {prompt} for prompt template placeholder 2024-08-15 09:47:49 -04:00
psychedelicious
356661459b feat(api): support JSON for preset imports
This allows us to support Fooocus format presets.
2024-08-15 09:47:49 -04:00
psychedelicious
deb917825e feat(api): use pydantic validation during style preset import
- Enforce name is present and not an empty string
- Provide empty string as default for positive and negative prompt
- Add `positive_prompt` as validation alias for `prompt` field
- Strip whitespace automatically
- Create `TypeAdapter` to validate the whole list in one go
2024-08-15 09:47:49 -04:00
psychedelicious
15415c6d85 feat(ui): use dropzone for style preset upload
Easier to accept multiple file types and supper drag and drop in the future.
2024-08-15 09:47:49 -04:00
Mary Hipp
76b0380b5f feat(ui): create component to upload CSV of style presets to import 2024-08-15 09:47:49 -04:00
Mary Hipp
2d58754789 feat(api): add endpoint to take a CSV, parse it, validate it, and create many style preset entries 2024-08-15 09:47:49 -04:00
chainchompa
9cdf1f599c Merge branch 'main' into chainchompa/preselect-workflows 2024-08-15 09:25:19 -04:00
chainchompa
268be97ba0 remove ref, make options optional for useGetLoadWorkflow 2024-08-15 09:18:41 -04:00
Mary Hipp
a9014673a0 wip export 2024-08-15 09:00:11 -04:00
psychedelicious
d36c43a10f ui: translations update from weblate (#6727)
Translations update from [Hosted Weblate](https://hosted.weblate.org)
for [InvokeAI/Web
UI](https://hosted.weblate.org/projects/invokeai/web-ui/).



Current translation status:

![Weblate translation
status](https://hosted.weblate.org/widget/invokeai/web-ui/horizontal-auto.svg)
2024-08-15 08:48:03 +10:00
Phrixus2023
54a5c4e482 translationBot(ui): update translation (Chinese (Simplified))
Currently translated at 98.1% (1296 of 1320 strings)

Co-authored-by: Phrixus2023 <920414016@qq.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/zh_Hans/
Translation: InvokeAI/Web UI
2024-08-15 00:46:01 +02:00
Riccardo Giovanetti
5e09a244e3 translationBot(ui): update translation (Italian)
Currently translated at 98.5% (1336 of 1355 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.5% (1302 of 1321 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.6% (1302 of 1320 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-08-15 00:46:01 +02:00
chainchompa
88648dca1a change selectedWorkflow to selectedWorkflowId 2024-08-14 11:22:37 -04:00
chainchompa
8840df2b00 Merge branch 'main' into chainchompa/preselect-workflows 2024-08-14 09:02:12 -04:00
chainchompa
af159acbdf cleanup 2024-08-14 08:58:38 -04:00
chainchompa
471719bbbe add base prop for selectedWorkflow to allow loading a workflow on launch 2024-08-14 08:47:02 -04:00
psychedelicious
b126f2ffd5 feat(ui, api): prompt templates (#6729)
## Summary

Adds prompt templates to the UI. Demo video is attached.
* added default prompt templates to seed database on startup (these
cannot be edited or deleted by users via the UI)
* can create fresh prompt template, create from an image in gallery that
has prompt metadata, or copy an existing prompt template and modify
* if a template is active, can view what your prompt will be invoked as
by switching to "view mode"



https://github.com/user-attachments/assets/32d84e0c-b04c-48da-bae5-aa6eb685d209



## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-08-14 12:49:31 +10:00
psychedelicious
9938f12ef0 Merge branch 'main' into maryhipp/style-presets 2024-08-14 12:33:30 +10:00
psychedelicious
982c266073 tidy: remove extra characters in prompt templates 2024-08-14 12:31:57 +10:00
psychedelicious
5c37391883 fix(ui): do not show [prompt] in preset preview 2024-08-14 12:29:05 +10:00
psychedelicious
ddeafc6833 fix(ui): minimize layout shift when overlaying preset prompt preview 2024-08-14 12:24:57 +10:00
psychedelicious
41b2d5d013 fix(ui): prompt preview not working preset starts with [prompt] 2024-08-14 12:21:38 +10:00
psychedelicious
29d6f48901 fix(ui): prompt shows thru prompt label text 2024-08-14 12:01:49 +10:00
psychedelicious
d5c9f4e47f chore(ui): revert framer-motion upgrade
`framer-motion` 11 breaks a lot of stuff in profoundly unintuitive ways, holy crap. UI lib rolled back its dep, pulling in latest version of that
2024-08-14 06:12:00 +10:00
psychedelicious
24d73387d8 build(ui): fix chakra deps
We had multiple versions of @emotion/react, stemming from an extraneous dependency on @chakra-ui/react. Removed the extraneosu dep
2024-08-14 06:12:00 +10:00
Mary Hipp
e0d3927265 feat: add flag for allowPrivateStylePresets that shows a type field when creating a style preset 2024-08-13 14:08:54 -04:00
Mary Hipp
e5f7c2a9b7 add type safety / validation to form data payloads and allow type to be passed through api 2024-08-13 13:00:31 -04:00
Mary Hipp
b0760710d5 add the rest of default style presets, update image service to return default images correctly by name, add tooltip popover to images in UI 2024-08-13 11:33:15 -04:00
Mary Hipp
764accc921 update config docstring 2024-08-12 15:17:40 -04:00
Mary Hipp
6a01fce9c1 fix payloads for stringified data 2024-08-12 15:16:22 -04:00
Mary Hipp
9c732ac3b1 Merge remote-tracking branch 'origin/main' into maryhipp/style-presets 2024-08-12 14:53:45 -04:00
Mary Hipp
b70891c661 update descriptoin of placeholder in modal 2024-08-12 13:37:04 -04:00
Mary Hipp
4dbf851741 ui: add labels to prompt boxes 2024-08-12 13:33:39 -04:00
Mary Hipp
6c927a9fd4 move mdoal state into nanostore 2024-08-12 12:46:02 -04:00
Mary Hipp
096f001634 ui: add ability to copy template 2024-08-12 12:32:31 -04:00
Mary Hipp
4837e578b2 api: update dir path for style preset images, update payload for create/update formdata 2024-08-12 12:00:14 -04:00
Mary Hipp
1e547ef912 UI more pr feedback 2024-08-12 11:59:25 -04:00
psychedelicious
f6b8970bd1 fix(app): create reference to events task to prevent accidental GC
This wasn't a problem, but it's advised in the official docs so I've done it.
2024-08-12 07:49:58 +10:00
psychedelicious
29325a7214 fix(app): use asyncio queue and existing event loop for events
Around the time we (I) implemented pydantic events, I noticed a short pause between progress images every 4 or 5 steps when generating with SDXL. It didn't happen with SD1.5, but I did notice that with SD1.5, we'd get 4 or 5 progress events simultaneously. I'd expect one event every ~25ms, matching my it/s with SD1.5. Mysterious!

Digging in, I found an issue is related to our use of a synchronous queue for events. When the event queue is empty, we must call `asyncio.sleep` before checking again. We were sleeping for 100ms.

Said another way, every time we clear the event queue, we have to wait 100ms before another event can be dispatched, even if it is put on the queue immediately after we start waiting. In practice, this means our events get buffered into batches, dispatched once every 100ms.

This explains why I was getting batches of 4 or 5 SD1.5 progress events at once, but not the intermittent SDXL delay.

But this 100ms wait has another effect when the events are put on the queue in intervals that don't perfectly line up with the 100ms wait. This is most noticeable when the time between events is >100ms, and can add up to 100ms delay before the event is dispatched.

For example, say the queue is empty and we start a 100ms wait. Then, immediately after - like 0.01ms later - we push an event on to the queue. We still need to wait another 99.9ms before that event will be dispatched. That's the SDXL delay.

The easy fix is to reduce the sleep to something like 0.01 seconds, but this feels kinda dirty. Can't we just wait on the queue and dispatch every event immediately? Not with the normal synchronous queue - but we can with `asyncio.Queue`.

I switched the events queue to use `asyncio.Queue` (as seen in this commit), which lets us asynchronous wait on the queue in a loop.

Unfortunately, I ran into another issue - events now felt like their timing was inconsistent, but in a different way than with the 100ms sleep. The time between pushing events on the queue and dispatching them was not consistently ~0ms as I'd expect - it was highly variable from ~0ms up to ~100ms.

This is resolved by passing the asyncio loop directly into the events service and using its methods to create the task and interact with the queue. I don't fully understand why this resolved the issue, because either way we are interacting with the same event loop (as shown by `asyncio.get_running_loop()`). I suppose there's some scheduling magic happening.
2024-08-12 07:49:58 +10:00
psychedelicious
8ecf72838d fix(api): image downloads with correct filename
Closes #6730
2024-08-10 09:53:56 -04:00
psychedelicious
c3ab8a6aa8 chore(ui): bump rest of deps 2024-08-10 07:45:23 -04:00
psychedelicious
1931aa3e70 chore(ui): typegen 2024-08-10 07:45:23 -04:00
psychedelicious
d3d8055055 feat(ui): update typegen script 2024-08-10 07:45:23 -04:00
psychedelicious
476b0a0403 chore(ui): bump openapi-typescript 2024-08-10 07:45:23 -04:00
psychedelicious
f66584713c fix(api): sort OpenAPI schema properties for InvocationOutputMap
This makes the schema output deterministic!
2024-08-10 07:45:23 -04:00
psychedelicious
33624fc2fa fix(api): duplicate operation id for get_image_full
There's a FastAPI bug that results in the OpenAPI spec outputting the same operation id for each operation when specifying multiple HTTP methods.

- Discussion: https://github.com/tiangolo/fastapi/discussions/8449
- Pending PR to fix: https://github.com/tiangolo/fastapi/pull/10694

In our case, we have a `get_image_full` endpoint that handles GET and HEAD.

This results in an invalid OpenAPI schema. A workaround is to use two route decorators for the operation handler. This works as expected - HEAD requests get the header, and GET requests get the resource. And the OpenAPI schema is valid.
2024-08-10 07:45:23 -04:00
Mary Hipp
41c3e73a3c fix tests 2024-08-09 16:31:42 -04:00
Mary Hipp
97553a7de2 API/DB updates per PR feedback 2024-08-09 16:27:37 -04:00
Mary Hipp
12ba15bfa9 UI updates per PR feedback 2024-08-09 16:00:13 -04:00
Mary Hipp
09d1e190e7 show warning for maxUpscaleDimension if model tab is disabled 2024-08-09 14:07:55 -04:00
Mary Hipp
8eb5d08499 missed translation 2024-08-08 16:01:16 -04:00
Mary Hipp
9be6acde7d require name to submit style preset 2024-08-08 15:53:21 -04:00
Mary Hipp
5f83bb0069 update config docstring 2024-08-08 15:20:43 -04:00
Mary Hipp
b138882abc fix tests? 2024-08-08 15:18:32 -04:00
Mary Hipp
0cd7cdb52e remove send2trash 2024-08-08 15:13:36 -04:00
Mary Hipp
1d8b7e2bcf ruff 2024-08-08 15:08:45 -04:00
Mary Hipp
6461f4758d lint fix 2024-08-08 15:07:58 -04:00
Mary Hipp
3189ab6863 get dynamic prompts working 2024-08-08 15:07:23 -04:00
Mary Hipp
3f9a674d4b seed default presets and handle them in UI 2024-08-08 15:02:41 -04:00
Mary Hipp
587f59b25b focus on prompt textarea when exiting view mode by clicking 2024-08-08 14:38:50 -04:00
Mary Hipp
4952eada87 ruff format 2024-08-08 14:22:40 -04:00
Mary Hipp
581029ebaa ruff 2024-08-08 14:21:37 -04:00
Mary Hipp
42d68780de lint 2024-08-08 14:19:33 -04:00
Mary Hipp
28032a2f80 more cleanup 2024-08-08 14:18:05 -04:00
Mary Hipp
e381e021e9 knip lint 2024-08-08 14:00:17 -04:00
Mary Hipp
641af64f93 regnerate schema 2024-08-08 13:58:25 -04:00
Mary Hipp
a7b83c8b5b Merge remote-tracking branch 'origin/main' into maryhipp/style-presets 2024-08-08 13:56:59 -04:00
Mary Hipp
4cc41e0188 translations and lint fix 2024-08-08 13:56:37 -04:00
Mary Hipp
442fc02429 resize images to 100x100 for style preset images 2024-08-08 12:56:55 -04:00
Mary Hipp
9a4d075074 fix path for style_preset_images, fix png type when converting blobs to files, built view mode components 2024-08-08 12:31:20 -04:00
Sergey Borisov
17ff8196cb Remove tmp code 2024-08-07 22:06:05 -04:00
Sergey Borisov
68f993998a Add support for norm layer 2024-08-07 22:06:05 -04:00
Sergey Borisov
7da6120b39 Fix LoKR refactor bug 2024-08-07 22:06:05 -04:00
blessedcoolant
6cd40965c4 Depth Anything V2 (#6674)
- Updated the previous DepthAnything manual implementation to use the
`transformers` implementation instead. So we can get upstream features.
- Plugged in the DepthAnything models to be handled by Invoke's Model
Manager.
- `small_v2` model will use DepthAnythingV2. This has been added as a
new model option and is now also the default in the Linear UI.


![opera_TxRhmbFole](https://github.com/user-attachments/assets/2a25abe3-ba0b-4f97-b75a-2ce5fd6246e6)


# Merge

Review and merge.
2024-08-07 20:26:58 +05:30
Kent Keirsey
408a1d6dbb Merge branch 'main' into depth_anything_v2 2024-08-07 10:45:56 -04:00
Mary Hipp
0b0abfbe8f clean up image implementation 2024-08-07 10:36:38 -04:00
Mary Hipp
cc96dcf0ed style preset images 2024-08-07 09:58:27 -04:00
Mary Hipp
2604fd9fde a whole bunch of stuff 2024-08-06 15:31:13 -04:00
Hosted Weblate
140670d00e translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

translationBot(ui): update translation files

Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2024-08-06 17:54:47 +10:00
Phrixus2023
70233fae5d translationBot(ui): update translation (Chinese (Simplified))
Currently translated at 98.1% (1296 of 1321 strings)

Co-authored-by: Phrixus2023 <920414016@qq.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/zh_Hans/
Translation: InvokeAI/Web UI
2024-08-06 17:54:47 +10:00
Alexander Eichhorn
6f457a6c4c translationBot(ui): update translation (German)
Currently translated at 65.1% (860 of 1321 strings)

Co-authored-by: Alexander Eichhorn <pfannkuchensack@einfach-doof.de>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2024-08-06 17:54:47 +10:00
B N
5c319f5356 translationBot(ui): update translation (German)
Currently translated at 64.8% (857 of 1321 strings)

Co-authored-by: B N <berndnieschalk@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2024-08-06 17:54:47 +10:00
Riccardo Giovanetti
991a04f090 translationBot(ui): update translation (Italian)
Currently translated at 98.6% (1303 of 1321 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.6% (1302 of 1320 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.6% (1294 of 1312 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-08-06 17:54:47 +10:00
psychedelicious
c39fa75113 docs(ui): add comment in useIsTooLargeToUpscale 2024-08-06 11:49:35 +10:00
psychedelicious
f7863e17ce docs(ui): add docstring for maxUpscaleDimension 2024-08-06 11:49:35 +10:00
psychedelicious
7c526390ed fix(ui): compare upscaledPixels vs square of max dimension 2024-08-06 11:49:35 +10:00
Mary Hipp
2cff20f87a update translations, change config value to be dimension instead of total pixels 2024-08-06 11:49:35 +10:00
Mary Hipp
90ec757802 lint 2024-08-06 11:49:35 +10:00
Mary Hipp
4b85dfcefe (ui): restore optioanl limit on upcsale output resolution 2024-08-06 11:49:35 +10:00
Mary Hipp
21deefdc41 (ui): add image resolution badge to initial upscale image 2024-08-06 11:49:35 +10:00
Mary Hipp
857d74bbfe wip apply and calculate prompt with interpolation 2024-08-05 19:11:48 -04:00
Mary Hipp
fd7a635777 (ui) the most basic crud ui: view list of presets, create a new preset, edit/delete existing presets 2024-08-05 15:48:23 -04:00
Mary Hipp
af9110e964 fix prompt concat logic 2024-08-05 13:42:28 -04:00
Mary Hipp
a61209206b remove custom SDXL prompts component 2024-08-05 13:40:46 -04:00
Mary Hipp
e05cc62e5f add style presets API layer to UI 2024-08-05 13:37:07 -04:00
psychedelicious
4d4f921a4e build: exclude matplotlib 3.9.1
There was a problem w/ this release on windows and the builds were pulled from pypi. When installing invoke on windows, pip attempts to build from source, but most (all?) systems won't have the prerequisites for this and installs fail.

This also affects GH actions.

The simple fix is to exclude version 3.9.1 from our deps.

For more information, see https://github.com/matplotlib/matplotlib/issues/28551
2024-08-05 08:38:44 +10:00
psychedelicious
98db8f395b feat(app): clean up DiskImageStorage types 2024-08-04 09:43:20 +10:00
psychedelicious
f465a956a3 feat(ui): remove "images can be restored" messages 2024-08-04 09:43:20 +10:00
psychedelicious
9edb02d7ef build: remove send2trash dependency 2024-08-04 09:43:20 +10:00
psychedelicious
6c4cf58a31 feat(app): delete model_images instead of using send2trash 2024-08-04 09:43:20 +10:00
psychedelicious
08993c0d29 feat(app): delete images instead of using send2trash
Closes #6709
2024-08-04 09:43:20 +10:00
blessedcoolant
4f8a4b0f22 Merge branch 'main' into depth_anything_v2 2024-08-03 00:38:57 +05:30
blessedcoolant
a743f3c9b5 fix: implement model to func for depth anything 2024-08-03 00:37:17 +05:30
Mary Hipp
217fe40d99 feat(api): add style_presets router, make sure all CRUD is working, add is_default 2024-08-02 12:29:54 -04:00
Mary Hipp
b76bf50b93 feat(db,api): create new table for style presets, build out record storage service for style presets 2024-08-01 22:20:11 -04:00
Mary Hipp
571ba87e13 fix(ui): include upscale metadata for SDXL multidiffusion 2024-08-01 21:30:42 -04:00
Ryan Dick
f27b6e2b44 Add Grounded SAM support (text prompt image segmentation) (#6701)
## Summary

This PR enables Grounded SAM workflows
(https://arxiv.org/pdf/2401.14159) via the following:
- `GroundingDinoInvocation` for running a Grounding DINO model.
- `SegmentAnythingModelInvocation` for running a SAM model.
- `MaskTensorToImageInvocation` for convenient visualization.

Other notes:
- Uses the transformers implementation of Grounding DINO and SAM.
- The new models are treated as 'utility models' meaning that they are
not visible in the Models tab, and are downloaded automatically the
first time that they are used.

<img width="874" alt="image"
src="https://github.com/user-attachments/assets/1cbaa97d-0e27-4943-86b1-dc7327ba8675">

## Example

Input image

![be10ec0c-20a8-4ac7-840e-d1a05fffdb6a](https://github.com/user-attachments/assets/bf21572c-635d-4703-b4ab-7aba658a9671)

Prompt: "wheels", all other configs default
Result:

![2221c44e-64e6-4b18-b4cb-610514b7a554](https://github.com/user-attachments/assets/344b91f4-7f4a-4b70-8e2e-3b4a0e55176d)

## Related Issues / Discussions

Thanks to @blessedcoolant for the initial draft here:
https://github.com/invoke-ai/InvokeAI/pull/6678

## QA Instructions

Manual tests:
- [ ] Test that default settings work well.
- [ ] Test with / without apply_polygon_refinement
- [ ] Test mask_filter options
- [ ] Test detection_threshold values
- [ ] Test RGB input image
- [ ] Test RGBA input image
- [ ] Test grayscale input image
- [ ] Smoke test that an empty mask is returned when 0 objects are
detected
- [ ] Test on CPU
- [ ] Test on MPS (Works on Mac OS, but had to force both models to run
on CPU instead of MPS)

Performance:
- Peak GPU memory utilization with both Grounding DINO and SAM models
loaded is ~4.5GB. (The models do not need to be loaded at the same time,
so could be offloaded by the MM if needed.)
- On an RTX4090, with the models already cached, node execution takes
~0.6 secs.
- On my CPU, with the models cached, node execution takes ~10secs.

## Merge Plan

No special instructions.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2024-08-01 20:40:18 +02:00
Ryan Dick
981475a624 Merge branch 'main' into ryan/grounded-sam 2024-08-01 20:30:35 +02:00
Ryan Dick
27ac61a4fb Expose all model options in the GroundingDinoInvocation and the SegmentAnythingInvocation. 2024-08-01 14:23:32 -04:00
Ryan Dick
675ffc2757 Remove BoundingBoxInvocation field name overrides. 2024-08-01 14:05:44 -04:00
Ryan Dick
44b21f10f1 Add a pydantic model_validator to BoundingBoxField to check the validity of the coords. 2024-08-01 14:00:57 -04:00
Ryan Dick
c6d49e8b1f Shorten SegmentAnythingInvocation and GroundingDinoInvocatino docstrings, since they are used as the invocation descriptions in the UI. 2024-08-01 10:17:42 -04:00
Ryan Dick
e6a512aa86 (minor) Tweak order of mask operations. 2024-08-01 10:12:24 -04:00
Ryan Dick
c3a6a6fb22 Rename SegmentAnythingModelInvocation -> SegmentAnythingInvocation. 2024-08-01 10:00:36 -04:00
Ryan Dick
b9dc3460ba Rename SegmentAnythingModel -> SegmentAnythingPipeline. 2024-08-01 09:57:47 -04:00
Ryan Dick
63581ec980 (minor) Add None check to fix static type checking error. 2024-08-01 09:51:53 -04:00
chainchompa
08b1feeed7 add base prop for destination to direct users to different tabs on initial load (#6706)
## Summary
- we want a way to load the studio while being directed to a specific
tab, introduced a destination prop to achieve that
<!--A description of the changes in this PR. Include the kind of change
(fix, feature, docs, etc), the "why" and the "how". Screenshots or
videos are useful for frontend changes.-->

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-07-31 19:25:36 -04:00
blessedcoolant
f5cfdcf32d feat: Add BoundingBox Primitive Node 2024-08-01 04:09:08 +05:30
chainchompa
e78fb428f0 simplify destination prop handling 2024-07-31 18:06:22 -04:00
chainchompa
31e270e32c add base prop for destination to direct users to different tabs 2024-07-31 17:20:51 -04:00
Ryan Dick
b5832768dc Return a MaskOutput from SegmentAnythingModelInvocation. And add a MaskTensorToImageInvocation. 2024-07-31 17:16:14 -04:00
Ryan Dick
4ce64b69cb Modular backend - LoRA/LyCORIS (#6667)
## Summary

Code for lora patching from #6577.
Additionally made it the way, that lora can patch not only `weight`, but
also `bias`, because saw some loras which doing it.

## Related Issues / Discussions

#6606 

https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run with and without set `USE_MODULAR_DENOISE` environment.

## Merge Plan

Replace old lora patcher with new after review done.
If you think that there should be some kind of tests - feel free to add.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-07-31 21:31:31 +02:00
Ryan Dick
5a9173f766 Merge branch 'main' into stalker-modular_lora 2024-07-31 15:13:22 -04:00
Ryan Dick
0bb7ed44f6 Add some docs to OriginalWeightsStorage and fix type hints. 2024-07-31 15:08:24 -04:00
blessedcoolant
332bc9da5b fix: Update depth anything node default to v2 2024-07-31 23:52:29 +05:30
blessedcoolant
08def3da95 fix: Update canvas depth anything processor default to v2 2024-07-31 23:50:13 +05:30
blessedcoolant
daf899f9c4 fix: Move the manual image resizing out of the depth anything pipeline 2024-07-31 23:38:12 +05:30
blessedcoolant
13fb2d1f49 fix: Add Depth Anything V2 as a new option
It is also now the default in the UI replacing Depth Anything V1 small
2024-07-31 23:29:43 +05:30
blessedcoolant
95dde802ea fix: assert the return depth map to be a PIL image 2024-07-31 23:22:01 +05:30
Ryan Dick
fca119773b Split invokeai/backend/image_util/segment_anything/ dir into grounding_dino/ and segment_anything/ 2024-07-31 12:28:47 -04:00
Ryan Dick
0193267a53 Split GroundedSamInvocation into GroundingDinoInvocation and SegmentAnythingModelInvocation. 2024-07-31 12:20:23 -04:00
blessedcoolant
b4cf78a95d fix: make DA Pipeline a subclass of RawModel 2024-07-31 21:14:49 +05:30
Ryan Dick
73386826d6 Make GroundingDinoPipeline and SegmentAnythingModel subclasses of RawModel for type checking purposes. 2024-07-31 10:25:34 -04:00
Ryan Dick
9f448fecb7 Move invokeai/backend/grounded_sam -> invokeai/backend/image_util/grounded_sam 2024-07-31 10:00:30 -04:00
Ryan Dick
bcd1483a14 Re-order GroundedSAMInvocation._to_numpy_masks(...) to do slightly more work on the GPU. 2024-07-31 09:51:14 -04:00
Ryan Dick
e206890e25 Use staticmethods rather than inner functions for the Grounding DINO and SAM model loaders. 2024-07-31 09:28:52 -04:00
Ryan Dick
0a7048f650 (minor) Simplify GroundedSAMInvocation._merge_masks(...). 2024-07-31 08:58:51 -04:00
Ryan Dick
e8ecf5e155 (minor) Move apply_polygon_refinement condition up a layer. 2024-07-31 08:50:56 -04:00
Ryan Dick
33e8604b57 Make Grounding DINO DetectionResult a Pydantic model. 2024-07-31 08:47:00 -04:00
Ryan Dick
cec7399366 (minor) Use a new variable name to satisfy type checks. 2024-07-31 08:27:01 -04:00
Ryan Dick
bdae81e429 (minor) Simplify GroundedSAMInvocation._filter_detections() 2024-07-31 08:25:19 -04:00
Ryan Dick
67c32f3d6c Fix typo: zip(..., strict=True) 2024-07-31 08:15:28 -04:00
blessedcoolant
94d64b8a78 Fix gradient mask values range (#6688)
## Summary

Gradient mask node outputs mask tensor with values in range [-1, 1],
which unexpected range for mask.
It handled in denoise node the way it translates to [0, 2] mask, which
looks even more wrongly)
From discussion with @dunkeroni I understand him as he thought that
negative values will be treated same as 0, so clamping values not change
intended node logic.

## Related Issues / Discussions

#6643 

## QA Instructions

\-

## Merge Plan

\-

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-07-31 06:37:32 +05:30
blessedcoolant
fa3c0c81b3 Merge branch 'main' into stalker7779/fix_gradient_mask 2024-07-31 06:30:44 +05:30
blessedcoolant
66547b99c1 Add more karras schedulers (#6695)
## Summary

Add karras variants of `deis`, `unipc`, `kdpm2` and `kdpm_2_a`
schedulers.
Also added `dpmpp_3` schedulers, but `dpmpp_3s` currently bugged, so
added only 3m:
https://github.com/huggingface/diffusers/issues/9007

## Related Issues / Discussions

\-

## QA Instructions

\-

## Merge Plan

~@psychedelicious We need to decide what to do with schedulers order, as
it looks a bit broken:~

![image](https://github.com/user-attachments/assets/e41674af-d87c-4432-8014-c90bd86965a6)

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-07-31 06:09:26 +05:30
blessedcoolant
328e58be4c Merge branch 'main' into stalker7779/new_karras_schedulers 2024-07-31 05:56:13 +05:30
blessedcoolant
18f89ed5ed fix: Make DepthAnything work with Invoke's Model Management 2024-07-31 03:57:54 +05:30
Ryan Dick
5701c79fab Prevent Grounding DINO and Segment Anything from being moved to MPS - they don't work on MPS devices. 2024-07-30 23:04:15 +02:00
Ryan Dick
2da9f913f3 Add detection_result.py - was forgotten in a prior commit 2024-07-30 16:04:29 -04:00
Ryan Dick
6b10b59abe Make GroundedSAMInvocation work with any input image mode (RGB, RGBA, grayscale). 2024-07-30 15:55:57 -04:00
Ryan Dick
918f77bce0 Move some logic from GroundedSAMInvocation to the backend classes. 2024-07-30 15:34:33 -04:00
blessedcoolant
f170697ebe Merge branch 'main' into depth_anything_v2 2024-07-31 00:53:32 +05:30
blessedcoolant
556c6a1d84 fix: Update DepthAnything to use the transformers implementation 2024-07-31 00:51:55 +05:30
Ryan Dick
aca2a2fa13 Add mask_filter and detection_threshold options to the GroundedSAMInvocation. 2024-07-30 14:22:40 -04:00
Ryan Dick
ff6398f7d8 Add a GroundedSamInvocation for image segmentation from a text prompt (Grounding DINO + Segment Anything Model). 2024-07-30 11:12:26 -04:00
Sergey Borisov
cf996472b9 Suggested changes
Co-Authored-By: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2024-07-30 04:50:56 +03:00
Sergey Borisov
156d14c349 Run api regen 2024-07-30 04:05:21 +03:00
Sergey Borisov
86f705bf48 Optimize weights handling 2024-07-30 03:39:01 +03:00
Sergey Borisov
1fd9631f2d Comments fix
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
2024-07-30 00:39:50 +03:00
Sergey Borisov
2227a2357f Suggested changes + simplify weights logic in patching
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
2024-07-30 00:34:37 +03:00
Sergey Borisov
58e7ab157d Ruff format 2024-07-29 22:59:17 +03:00
Sergey Borisov
8d16fa6a49 Remove dpmpp_3s schedulers as it bugged now 2024-07-29 22:55:45 +03:00
Sergey Borisov
55e810efa3 Add dpmpp_3 schedulers 2024-07-29 22:52:15 +03:00
chainchompa
2755316021 update delete board modal to be more descriptive (#6690)
## Summary

<!--A description of the changes in this PR. Include the kind of change
(fix, feature, docs, etc), the "why" and the "how". Screenshots or
videos are useful for frontend changes.-->

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-07-29 13:43:17 -04:00
chainchompa
6525f18610 Merge branch 'main' into chainchompa/board-delete-info 2024-07-29 12:52:36 -04:00
Ryan Dick
2ad13ac7eb Modular backend - inpaint (#6643)
## Summary

Code for inpainting and inpaint models handling from
https://github.com/invoke-ai/InvokeAI/pull/6577.
Separated in 2 extensions as discussed briefly before, so wait for
discussion about such implementation.

## Related Issues / Discussions

#6606

https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run with and without set `USE_MODULAR_DENOISE` environment.
Try and compare outputs between backends in cases:
- Normal generation on inpaint model
- Inpainting on inpaint model
- Inpainting on normal model

## Merge Plan

Nope.
If you think that there should be some kind of tests - feel free to add.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-07-29 10:27:25 -04:00
Ryan Dick
693a3eaff5 Merge branch 'main' into stalker-modular_inpaint-2 2024-07-29 10:14:45 -04:00
chainchompa
ffca792d5b edited copy for deleted boards message 2024-07-29 09:46:08 -04:00
Sergey Borisov
86a92bb6b5 Add more karras schedulers 2024-07-29 15:14:34 +03:00
psychedelicious
171a4e6d80 fix(ui): race condition when deleting a board and resetting selected/auto-add
We were checking the selected and auto-add board ids against the query cache to see if they still exist. If not, we reset.

This only works if the query cache is updated by the time we do the check - race condition!

We already have the board id from the query args, so there's no need to check the query cache - just compare the deleted board ID directly.

Previously this file's several listeners were all in a single one and I had adapted/split its logic up a bit wonkily, introducing these problems.
2024-07-29 11:36:03 +10:00
psychedelicious
e3a75a8adf fix(ui): fix logic to reset selected/auto-add boards when toggling show archived boards
The logic was incorrect in two ways:
1. We only ran the logic if we _enable_ showing archived boards. It should be run we we _disable_ showing archived boards.
2. If we couldn't find the selected board in the query cache, we didn't do the reset. This is wrong - if the board isn't in the query cache, we _should_ do the reset. This inverted logic makes more sense before the fix for issue 1.
2024-07-29 11:36:03 +10:00
Ryan Dick
ee7503ce13 Modular backend - T2I Adapter (#6662)
## Summary

T2I Adapter code from #6577.

## Related Issues / Discussions

#6606 

https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run with and without set `USE_MODULAR_DENOISE` environment.

## Merge Plan

Nope.
If you think that there should be some kind of tests - feel free to add.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-07-28 15:52:04 -04:00
Sergey Borisov
8500bac3ca Use logger for warning 2024-07-28 22:51:52 +03:00
Ryan Dick
310719eb4c Merge branch 'main' into stalker-modular_t2i_adapter 2024-07-28 15:30:00 -04:00
Ryan Dick
e8e24822ec Modular backend - Seamless (#6651)
## Summary

Seamless code from #6577.

## Related Issues / Discussions

#6606 

https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run with and without set `USE_MODULAR_DENOISE` environment.

## Merge Plan

Nope.
If you think that there should be some kind of tests - feel free to add.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-07-28 13:57:38 -04:00
Ryan Dick
c57a7afb87 Merge branch 'main' into stalker7779/modular_seamless 2024-07-28 13:49:43 -04:00
Sergey Borisov
84d028898c Revert wrong comment copy 2024-07-27 13:20:58 +03:00
Sergey Borisov
ed0174fbc6 Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
2024-07-27 13:18:28 +03:00
Sergey Borisov
9e582563eb Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
2024-07-27 04:25:15 +03:00
Sergey Borisov
faa88f72bf Make lora as separate extensions 2024-07-27 02:39:53 +03:00
chainchompa
0d69a31df0 Merge branch 'main' into chainchompa/board-delete-info 2024-07-26 14:03:18 -04:00
brandonrising
daa5a88eb2 Update docker image to use pnpm version 8 2024-07-26 13:57:33 -04:00
Sergey Borisov
5b84e117b2 Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
2024-07-26 20:51:12 +03:00
chainchompa
eb257d2d28 update delete board modal to be more descriptive 2024-07-26 13:34:25 -04:00
Sergey Borisov
5810cee6c9 Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
2024-07-26 19:47:28 +03:00
Sergey Borisov
eef88d1f83 Update gradient mask node version 2024-07-26 19:33:41 +03:00
Sergey Borisov
78f6850fc0 Fix gradient mask values range 2024-07-26 19:28:00 +03:00
Sergey Borisov
bd8890be11 Revert "Fix create gradient mask node output"
This reverts commit 9d1fcba415.
2024-07-26 19:24:46 +03:00
Sergey Borisov
adf1a977ea Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
2024-07-26 19:22:26 +03:00
Mary Hipp
e1509bcb45 bump version to 4.2.7 2024-07-26 09:11:17 -07:00
psychedelicious
edcaf8287d feat(app): remove beta from multidiffusion workflows 2024-07-26 13:47:51 +10:00
psychedelicious
39bd30f2a0 feat(app): update default workflows
- Update `MultiDiffusion SDXL (Beta)`
- Add `MultiDiffusion SD1.5 (Beta)`
2024-07-26 13:47:51 +10:00
psychedelicious
102b47190f feat(ui): update qr code cnet starter model
- For SD1.5, use the new V2 version
- Add the SDXL version
2024-07-26 13:34:32 +10:00
Mary Hipp
269fe2e3bb track accordions in tabs separately so open/close state isnt shared 2024-07-26 08:20:24 +10:00
Mary Hipp
b32aa1c77f fix missing quote in translation 2024-07-26 08:20:24 +10:00
Mary Hipp Rogers
6656544ed5 tooltip copy updates
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2024-07-26 08:20:24 +10:00
Mary Hipp
4c75b93410 feat(ui): add informational popovers for upscale params 2024-07-26 08:20:24 +10:00
Mary Hipp
5be0de967d feat(ui): close generation and advanced accordions when switching to upscale tab 2024-07-26 08:20:24 +10:00
psychedelicious
f8e27b837b fix(ui): memoize model manager components 2024-07-26 07:52:10 +10:00
psychedelicious
47414be1e6 fix(ui): dropped model config cache breaking model edit UI
The model edit UI's composition allows for the model edit form to be instantiated before the model's config has been received. This results in the form having no values - all the fields are blank instead of populated by the model config.

Part of the fix is to pass the model config around directly instead of relying on _all_ components to fetch the model directly.

I also fixed a crapload of performance issues related to improper use of redux selectors.
2024-07-26 07:52:10 +10:00
psychedelicious
74cef38bcf fix(backend): add refiner to single-file load_classes
Fixes single-file refiner loading.
2024-07-26 05:08:01 +10:00
psychedelicious
bb876b8d4e fix(ui): copied edges must have new ids set
Problems this was causing:
- Deleting an edge was a copy of another edge deletes both edges
- Deleting a node that was a copy-with-edges of another node deletes its edges and it's original edges, leaving what I will call "ghost noodles" behind
2024-07-26 04:54:33 +10:00
blessedcoolant
e5d9ca013e fix: use v1 models for large and base versions 2024-07-25 17:24:12 +05:30
blessedcoolant
4166c756ce wip: depth_anything_v2 init lint fixes 2024-07-25 14:41:22 +05:30
blessedcoolant
4f0dfbd34d wip: depth_anything_v2 initial implementation 2024-07-25 13:53:06 +05:30
Sergey Borisov
46c632e7cc Change layer detection keys according to LyCORIS repository 2024-07-25 02:10:47 +03:00
Sergey Borisov
653f63ae71 Add layer keys check 2024-07-25 02:03:08 +03:00
Sergey Borisov
8a9e2f57a4 Handle bias in full/diff lora layer 2024-07-25 02:02:37 +03:00
Sergey Borisov
31949ed2f2 Refactor code a bit 2024-07-25 02:00:30 +03:00
Sergey Borisov
0ccb304b8b Ruff format 2024-07-24 16:01:29 +03:00
Sergey Borisov
ab0bfa709a Handle loras in modular denoise 2024-07-24 05:07:29 +03:00
Sergey Borisov
6af659b1da Handle t2i adapter in modular denoise 2024-07-24 02:55:33 +03:00
Sergey Borisov
416d29fb83 Ruff format 2024-07-24 01:17:28 +03:00
Sergey Borisov
19c00241c6 Use non-inverted mask generally(except inpaint model handling) 2024-07-24 00:59:13 +03:00
Sergey Borisov
c323a760a5 Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
2024-07-23 23:34:28 +03:00
Sergey Borisov
9d1fcba415 Fix create gradient mask node output 2024-07-23 23:29:28 +03:00
Sergey Borisov
ca21996a97 Remove old seamless class 2024-07-23 18:04:33 +03:00
Sergey Borisov
62aa064e56 Handle seamless in modular denoise 2024-07-23 18:03:59 +03:00
Sergey Borisov
87eb018380 Revert debug change 2024-07-22 23:49:20 +03:00
Sergey Borisov
5003e5d763 Same changes as in other PRs, add check for running inpainting on inpaint model without source image
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
2024-07-22 23:47:39 +03:00
Sergey Borisov
58f3072b91 Handle inpainting on normal models 2024-07-21 22:17:29 +03:00
Sergey Borisov
9e7b470189 Handle inpaint models 2024-07-21 20:45:55 +03:00
916 changed files with 48532 additions and 61195 deletions

View File

@@ -62,7 +62,7 @@ jobs:
- name: install ruff
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: pip install ruff
run: pip install ruff==0.6.0
shell: bash
- name: ruff check

View File

@@ -1,158 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "aeb428d0-0817-462c-b5d8-455a0615d305",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from PIL import Image\n",
"import numpy as np\n",
"import cv2\n",
"\n",
"from invokeai.backend.vto_workflow.overlay_pattern import generate_dress_mask, multiply_images\n",
"from invokeai.backend.vto_workflow.extract_channel import extract_channel, ImageChannel\n",
"from invokeai.backend.vto_workflow.seamless_mapping import map_seamless_tiles\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6140d4b7-8238-431c-848e-6f6ae27652f5",
"metadata": {},
"outputs": [],
"source": [
" # Load the model image.\n",
"model_image = Image.open(\"/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/dress.jpeg\")\n",
"\n",
"# Load the pattern image.\n",
"pattern_image = Image.open(\"/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/pattern1.jpg\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fb7186ba-dc0c-4520-ac30-49073a65601a",
"metadata": {},
"outputs": [],
"source": [
"mask = generate_dress_mask(model_image)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b935de4-94c5-4be5-bf8e-a5a6e445c811",
"metadata": {},
"outputs": [],
"source": [
"# Visualize mask\n",
"model_image_np = np.array(model_image)\n",
"masked_model_image = (model_image_np * np.expand_dims(mask, -1).astype(np.float32)).astype(np.uint8)\n",
"mask_image = Image.fromarray(masked_model_image)\n",
"mask_image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e51bb545",
"metadata": {},
"outputs": [],
"source": [
"shadows = extract_channel(np.array(model_image), ImageChannel.LAB_L)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec43de4a",
"metadata": {},
"outputs": [],
"source": [
"# Visualize masked shadows\n",
"masked_shadows = (shadows * mask).astype(np.uint8)\n",
"masked_shadows_image = Image.fromarray(masked_shadows)\n",
"masked_shadows_image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dbb53794",
"metadata": {},
"outputs": [],
"source": [
"# Tile the pattern.\n",
"expanded_pattern = map_seamless_tiles(seamless_tile=pattern_image, target_hw=(model_image.height, model_image.width), num_repeats_h=10.0)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f4f22d02",
"metadata": {},
"outputs": [],
"source": [
"# Multiply the pattern by the shadows.\n",
"pattern_with_shadows = multiply_images(expanded_pattern, shadows)\n",
"pattern_with_shadows"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "97db42b0",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "de32f7e3",
"metadata": {},
"outputs": [],
"source": [
"# Merge the pattern with the model image.\n",
"pattern_with_shadows_np = np.array(pattern_with_shadows)\n",
"merged_image = np.where(mask[:, :, None], pattern_with_shadows_np,model_image_np)\n",
"merged_image = Image.fromarray(merged_image)\n",
"merged_image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff1d4044",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -55,6 +55,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
FROM node:20-slim AS web-builder
ENV PNPM_HOME="/pnpm"
ENV PATH="$PNPM_HOME:$PATH"
RUN corepack use pnpm@8.x
RUN corepack enable
WORKDIR /build

View File

@@ -1,20 +1,22 @@
# Invoke in Docker
- Ensure that Docker can use the GPU on your system
- This documentation assumes Linux, but should work similarly under Windows with WSL2
First things first:
- Ensure that Docker can use your [NVIDIA][nvidia docker docs] or [AMD][amd docker docs] GPU.
- This document assumes a Linux system, but should work similarly under Windows with WSL2.
- We don't recommend running Invoke in Docker on macOS at this time. It works, but very slowly.
## Quickstart :lightning:
## Quickstart
No `docker compose`, no persistence, just a simple one-liner using the official images:
No `docker compose`, no persistence, single command, using the official images:
**CUDA:**
**CUDA (NVIDIA GPU):**
```bash
docker run --runtime=nvidia --gpus=all --publish 9090:9090 ghcr.io/invoke-ai/invokeai
```
**ROCm:**
**ROCm (AMD GPU):**
```bash
docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invoke-ai/invokeai:main-rocm
@@ -22,12 +24,20 @@ docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invok
Open `http://localhost:9090` in your browser once the container finishes booting, install some models, and generate away!
> [!TIP]
> To persist your data (including downloaded models) outside of the container, add a `--volume/-v` flag to the above command, e.g.: `docker run --volume /some/local/path:/invokeai <...the rest of the command>`
### Data persistence
To persist your generated images and downloaded models outside of the container, add a `--volume/-v` flag to the above command, e.g.:
```bash
docker run --volume /some/local/path:/invokeai {...etc...}
```
`/some/local/path/invokeai` will contain all your data.
It can *usually* be reused between different installs of Invoke. Tread with caution and read the release notes!
## Customize the container
We ship the `run.sh` script, which is a convenient wrapper around `docker compose` for cases where custom image build args are needed. Alternatively, the familiar `docker compose` commands work just as well.
The included `run.sh` script is a convenience wrapper around `docker compose`. It can be helpful for passing additional build arguments to `docker compose`. Alternatively, the familiar `docker compose` commands work just as well.
```bash
cd docker
@@ -38,11 +48,14 @@ cp .env.sample .env
It will take a few minutes to build the image the first time. Once the application starts up, open `http://localhost:9090` in your browser to invoke!
>[!TIP]
>When using the `run.sh` script, the container will continue running after Ctrl+C. To shut it down, use the `docker compose down` command.
## Docker setup in detail
#### Linux
1. Ensure builkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
1. Ensure buildkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
2. Install the `docker compose` plugin using your package manager, or follow a [tutorial](https://docs.docker.com/compose/install/linux/#install-using-the-repository).
- The deprecated `docker-compose` (hyphenated) CLI probably won't work. Update to a recent version.
3. Ensure docker daemon is able to access the GPU.
@@ -98,25 +111,7 @@ GPU_DRIVER=cuda
Any environment variables supported by InvokeAI can be set here. See the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
## Even More Customizing!
---
See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below.
### Reconfigure the runtime directory
Can be used to download additional models from the supported model list
In conjunction with `INVOKEAI_ROOT` can be also used to initialize a runtime directory
```yaml
command:
- invokeai-configure
- --yes
```
Or install models:
```yaml
command:
- invokeai-model-install
```
[nvidia docker docs]: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
[amd docker docs]: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html

View File

@@ -17,7 +17,7 @@
set -eu
# Ensure we're in the correct folder in case user's CWD is somewhere else
scriptdir=$(dirname "$0")
scriptdir=$(dirname $(readlink -f "$0"))
cd "$scriptdir"
. .venv/bin/activate

View File

@@ -1,5 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
from logging import Logger
import torch
@@ -31,6 +32,8 @@ from invokeai.app.services.session_processor.session_processor_default import (
)
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
from invokeai.app.services.urls.urls_default import LocalUrlService
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@@ -63,7 +66,12 @@ class ApiDependencies:
invoker: Invoker
@staticmethod
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger) -> None:
def initialize(
config: InvokeAIAppConfig,
event_handler_id: int,
loop: asyncio.AbstractEventLoop,
logger: Logger = logger,
) -> None:
logger.info(f"InvokeAI version {__version__}")
logger.info(f"Root directory = {str(config.root_path)}")
@@ -74,6 +82,7 @@ class ApiDependencies:
image_files = DiskImageFileStorage(f"{output_folder}/images")
model_images_folder = config.models_path
style_presets_folder = config.style_presets_path
db = init_db(config=config, logger=logger, image_files=image_files)
@@ -84,7 +93,7 @@ class ApiDependencies:
board_images = BoardImagesService()
board_records = SqliteBoardRecordStorage(db=db)
boards = BoardService()
events = FastAPIEventService(event_handler_id)
events = FastAPIEventService(event_handler_id, loop=loop)
bulk_download = BulkDownloadService()
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
@@ -109,6 +118,8 @@ class ApiDependencies:
session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService()
workflow_records = SqliteWorkflowRecordsStorage(db=db)
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
services = InvocationServices(
board_image_records=board_image_records,
@@ -134,6 +145,8 @@ class ApiDependencies:
workflow_records=workflow_records,
tensors=tensors,
conditioning=conditioning,
style_preset_records=style_preset_records,
style_preset_image_files=style_preset_image_files,
)
ApiDependencies.invoker = Invoker(services)

View File

@@ -218,9 +218,8 @@ async def get_image_workflow(
raise HTTPException(status_code=404)
@images_router.api_route(
@images_router.get(
"/i/{image_name}/full",
methods=["GET", "HEAD"],
operation_id="get_image_full",
response_class=Response,
responses={
@@ -231,6 +230,18 @@ async def get_image_workflow(
404: {"description": "Image not found"},
},
)
@images_router.head(
"/i/{image_name}/full",
operation_id="get_image_full_head",
response_class=Response,
responses={
200: {
"description": "Return the full-resolution image",
"content": {"image/png": {}},
},
404: {"description": "Image not found"},
},
)
async def get_image_full(
image_name: str = Path(description="The name of full-resolution image file to get"),
) -> Response:
@@ -242,6 +253,7 @@ async def get_image_full(
content = f.read()
response = Response(content, media_type="image/png")
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
response.headers["Content-Disposition"] = f'inline; filename="{image_name}"'
return response
except Exception:
raise HTTPException(status_code=404)

View File

@@ -11,6 +11,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
Batch,
BatchStatus,
CancelByBatchIDsResult,
CancelByOriginResult,
ClearResult,
EnqueueBatchResult,
PruneResult,
@@ -105,6 +106,19 @@ async def cancel_by_batch_ids(
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
@session_queue_router.put(
"/{queue_id}/cancel_by_origin",
operation_id="cancel_by_origin",
responses={200: {"model": CancelByBatchIDsResult}},
)
async def cancel_by_origin(
queue_id: str = Path(description="The queue id to perform this operation on"),
origin: str = Query(description="The origin to cancel all queue items for"),
) -> CancelByOriginResult:
"""Immediately cancels all queue items with the given origin"""
return ApiDependencies.invoker.services.session_queue.cancel_by_origin(queue_id=queue_id, origin=origin)
@session_queue_router.put(
"/{queue_id}/clear",
operation_id="clear",

View File

@@ -0,0 +1,274 @@
import csv
import io
import json
import traceback
from typing import Optional
import pydantic
from fastapi import APIRouter, File, Form, HTTPException, Path, Response, UploadFile
from fastapi.responses import FileResponse
from PIL import Image
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE
from invokeai.app.services.style_preset_images.style_preset_images_common import StylePresetImageFileNotFoundException
from invokeai.app.services.style_preset_records.style_preset_records_common import (
InvalidPresetImportDataError,
PresetData,
PresetType,
StylePresetChanges,
StylePresetNotFoundError,
StylePresetRecordWithImage,
StylePresetWithoutId,
UnsupportedFileTypeError,
parse_presets_from_file,
)
class StylePresetFormData(BaseModel):
name: str = Field(description="Preset name")
positive_prompt: str = Field(description="Positive prompt")
negative_prompt: str = Field(description="Negative prompt")
type: PresetType = Field(description="Preset type")
style_presets_router = APIRouter(prefix="/v1/style_presets", tags=["style_presets"])
@style_presets_router.get(
"/i/{style_preset_id}",
operation_id="get_style_preset",
responses={
200: {"model": StylePresetRecordWithImage},
},
)
async def get_style_preset(
style_preset_id: str = Path(description="The style preset to get"),
) -> StylePresetRecordWithImage:
"""Gets a style preset"""
try:
image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
style_preset = ApiDependencies.invoker.services.style_preset_records.get(style_preset_id)
return StylePresetRecordWithImage(image=image, **style_preset.model_dump())
except StylePresetNotFoundError:
raise HTTPException(status_code=404, detail="Style preset not found")
@style_presets_router.patch(
"/i/{style_preset_id}",
operation_id="update_style_preset",
responses={
200: {"model": StylePresetRecordWithImage},
},
)
async def update_style_preset(
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
style_preset_id: str = Path(description="The id of the style preset to update"),
data: str = Form(description="The data of the style preset to update"),
) -> StylePresetRecordWithImage:
"""Updates a style preset"""
if image is not None:
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
try:
ApiDependencies.invoker.services.style_preset_image_files.save(style_preset_id, pil_image)
except ValueError as e:
raise HTTPException(status_code=409, detail=str(e))
else:
try:
ApiDependencies.invoker.services.style_preset_image_files.delete(style_preset_id)
except StylePresetImageFileNotFoundException:
pass
try:
parsed_data = json.loads(data)
validated_data = StylePresetFormData(**parsed_data)
name = validated_data.name
type = validated_data.type
positive_prompt = validated_data.positive_prompt
negative_prompt = validated_data.negative_prompt
except pydantic.ValidationError:
raise HTTPException(status_code=400, detail="Invalid preset data")
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
changes = StylePresetChanges(name=name, preset_data=preset_data, type=type)
style_preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
style_preset = ApiDependencies.invoker.services.style_preset_records.update(
style_preset_id=style_preset_id, changes=changes
)
return StylePresetRecordWithImage(image=style_preset_image, **style_preset.model_dump())
@style_presets_router.delete(
"/i/{style_preset_id}",
operation_id="delete_style_preset",
)
async def delete_style_preset(
style_preset_id: str = Path(description="The style preset to delete"),
) -> None:
"""Deletes a style preset"""
try:
ApiDependencies.invoker.services.style_preset_image_files.delete(style_preset_id)
except StylePresetImageFileNotFoundException:
pass
ApiDependencies.invoker.services.style_preset_records.delete(style_preset_id)
@style_presets_router.post(
"/",
operation_id="create_style_preset",
responses={
200: {"model": StylePresetRecordWithImage},
},
)
async def create_style_preset(
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
data: str = Form(description="The data of the style preset to create"),
) -> StylePresetRecordWithImage:
"""Creates a style preset"""
try:
parsed_data = json.loads(data)
validated_data = StylePresetFormData(**parsed_data)
name = validated_data.name
type = validated_data.type
positive_prompt = validated_data.positive_prompt
negative_prompt = validated_data.negative_prompt
except pydantic.ValidationError:
raise HTTPException(status_code=400, detail="Invalid preset data")
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
style_preset = StylePresetWithoutId(name=name, preset_data=preset_data, type=type)
new_style_preset = ApiDependencies.invoker.services.style_preset_records.create(style_preset=style_preset)
if image is not None:
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
try:
ApiDependencies.invoker.services.style_preset_image_files.save(new_style_preset.id, pil_image)
except ValueError as e:
raise HTTPException(status_code=409, detail=str(e))
preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(new_style_preset.id)
return StylePresetRecordWithImage(image=preset_image, **new_style_preset.model_dump())
@style_presets_router.get(
"/",
operation_id="list_style_presets",
responses={
200: {"model": list[StylePresetRecordWithImage]},
},
)
async def list_style_presets() -> list[StylePresetRecordWithImage]:
"""Gets a page of style presets"""
style_presets_with_image: list[StylePresetRecordWithImage] = []
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many()
for preset in style_presets:
image = ApiDependencies.invoker.services.style_preset_image_files.get_url(preset.id)
style_preset_with_image = StylePresetRecordWithImage(image=image, **preset.model_dump())
style_presets_with_image.append(style_preset_with_image)
return style_presets_with_image
@style_presets_router.get(
"/i/{style_preset_id}/image",
operation_id="get_style_preset_image",
responses={
200: {
"description": "The style preset image was fetched successfully",
},
400: {"description": "Bad request"},
404: {"description": "The style preset image could not be found"},
},
status_code=200,
)
async def get_style_preset_image(
style_preset_id: str = Path(description="The id of the style preset image to get"),
) -> FileResponse:
"""Gets an image file that previews the model"""
try:
path = ApiDependencies.invoker.services.style_preset_image_files.get_path(style_preset_id)
response = FileResponse(
path,
media_type="image/png",
filename=style_preset_id + ".png",
content_disposition_type="inline",
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response
except Exception:
raise HTTPException(status_code=404)
@style_presets_router.get(
"/export",
operation_id="export_style_presets",
responses={200: {"content": {"text/csv": {}}, "description": "A CSV file with the requested data."}},
status_code=200,
)
async def export_style_presets():
# Create an in-memory stream to store the CSV data
output = io.StringIO()
writer = csv.writer(output)
# Write the header
writer.writerow(["name", "prompt", "negative_prompt"])
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many(type=PresetType.User)
for preset in style_presets:
writer.writerow([preset.name, preset.preset_data.positive_prompt, preset.preset_data.negative_prompt])
csv_data = output.getvalue()
output.close()
return Response(
content=csv_data,
media_type="text/csv",
headers={"Content-Disposition": "attachment; filename=prompt_templates.csv"},
)
@style_presets_router.post(
"/import",
operation_id="import_style_presets",
)
async def import_style_presets(file: UploadFile = File(description="The file to import")):
try:
style_presets = await parse_presets_from_file(file)
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)
except InvalidPresetImportDataError as e:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=400, detail=str(e))
except UnsupportedFileTypeError as e:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail=str(e))

View File

@@ -30,6 +30,7 @@ from invokeai.app.api.routers import (
images,
model_manager,
session_queue,
style_presets,
utilities,
workflows,
)
@@ -55,11 +56,13 @@ mimetypes.add_type("text/css", ".css")
torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}")
loop = asyncio.new_event_loop()
@asynccontextmanager
async def lifespan(app: FastAPI):
# Add startup event to load dependencies
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, loop=loop, logger=logger)
yield
# Shut down threads
ApiDependencies.shutdown()
@@ -106,6 +109,7 @@ app.include_router(board_images.board_images_router, prefix="/api")
app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")
app.include_router(style_presets.style_presets_router, prefix="/api")
app.openapi = get_openapi_func(app)
@@ -184,8 +188,6 @@ def invoke_api() -> None:
check_cudnn(logger)
# Start our own event loop for eventing usage
loop = asyncio.new_event_loop()
config = uvicorn.Config(
app=app,
host=app_config.host,

View File

@@ -80,12 +80,12 @@ class CompelInvocation(BaseInvocation):
with (
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora_text_encoder(
text_encoder,
loras=_lora_loader(),
model_state_dict=model_state_dict,
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
@@ -175,13 +175,13 @@ class SDXLPromptInvocationBase:
with (
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (state_dict, text_encoder),
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora(
text_encoder,
loras=_lora_loader(),
prefix=lora_prefix,
model_state_dict=state_dict,
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),

View File

@@ -21,6 +21,8 @@ from controlnet_aux import (
from controlnet_aux.util import HWC3, ade_palette
from PIL import Image
from pydantic import BaseModel, Field, field_validator, model_validator
from transformers import pipeline
from transformers.pipelines import DepthEstimationPipeline
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@@ -44,13 +46,12 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
from invokeai.backend.image_util.canny import get_canny_edges
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
from invokeai.backend.image_util.hed import HEDProcessor
from invokeai.backend.image_util.lineart import LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
from invokeai.backend.util.devices import TorchDevice
class ControlField(BaseModel):
@@ -592,7 +593,14 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
return color_map
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small", "small_v2"]
# DepthAnything V2 Small model is licensed under Apache 2.0 but not the base and large models.
DEPTH_ANYTHING_MODELS = {
"large": "LiheYoung/depth-anything-large-hf",
"base": "LiheYoung/depth-anything-base-hf",
"small": "LiheYoung/depth-anything-small-hf",
"small_v2": "depth-anything/Depth-Anything-V2-Small-hf",
}
@invocation(
@@ -600,28 +608,33 @@ DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
title="Depth Anything Processor",
tags=["controlnet", "depth", "depth anything"],
category="controlnet",
version="1.1.2",
version="1.1.3",
)
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a depth map based on the Depth Anything algorithm"""
model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
default="small", description="The size of the depth model to use"
default="small_v2", description="The size of the depth model to use"
)
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
def loader(model_path: Path):
return DepthAnythingDetector.load_model(
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
)
def load_depth_anything(model_path: Path):
depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
return DepthAnythingPipeline(depth_anything_pipeline)
with self._context.models.load_remote_model(
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
) as model:
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
return processed_image
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
) as depth_anything_detector:
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
depth_map = depth_anything_detector.generate_depth(image)
# Resizing to user target specified size
new_height = int(image.size[1] * (self.resolution / image.size[0]))
depth_map = depth_map.resize((self.resolution, new_height))
return depth_map
@invocation(

View File

@@ -39,7 +39,7 @@ class GradientMaskOutput(BaseInvocationOutput):
title="Create Gradient Mask",
tags=["mask", "denoise"],
category="latents",
version="1.1.0",
version="1.2.0",
)
class CreateGradientMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
@@ -93,6 +93,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
# redistribute blur so that the original edges are 0 and blur outwards to 1
blur_tensor = (blur_tensor - 0.5) * 2
blur_tensor[blur_tensor < 0] = 0.0
threshold = 1 - self.minimum_denoise

View File

@@ -37,9 +37,9 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion import PipelineIntermediateState
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
@@ -60,8 +60,13 @@ from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionB
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@@ -498,6 +503,33 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
)
@staticmethod
def parse_t2i_adapter_field(
exit_stack: ExitStack,
context: InvocationContext,
t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
ext_manager: ExtensionsManager,
) -> None:
if t2i_adapters is None:
return
# Handle the possibility that t2i_adapters could be a list or a single T2IAdapterField.
if isinstance(t2i_adapters, T2IAdapterField):
t2i_adapters = [t2i_adapters]
for t2i_adapter_field in t2i_adapters:
ext_manager.add_extension(
T2IAdapterExt(
node_context=context,
model_id=t2i_adapter_field.t2i_adapter_model,
image=context.images.get_pil(t2i_adapter_field.image.image_name),
weight=t2i_adapter_field.weight,
begin_step_percent=t2i_adapter_field.begin_step_percent,
end_step_percent=t2i_adapter_field.end_step_percent,
resize_mode=t2i_adapter_field.resize_mode,
)
)
def prep_ip_adapter_image_prompts(
self,
context: InvocationContext,
@@ -707,7 +739,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
else:
masked_latents = torch.where(mask < 0.5, 0.0, latents)
return 1 - mask, masked_latents, self.denoise_mask.gradient
return mask, masked_latents, self.denoise_mask.gradient
@staticmethod
def prepare_noise_and_latents(
@@ -765,10 +797,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
dtype = TorchDevice.choose_torch_dtype()
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
latents = latents.to(device=device, dtype=dtype)
if noise is not None:
noise = noise.to(device=device, dtype=dtype)
_, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data(
@@ -801,21 +829,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
denoising_end=self.denoising_end,
)
denoise_ctx = DenoiseContext(
inputs=DenoiseInputs(
orig_latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
attention_processor_cls=CustomAttnProcessor2_0,
),
unet=None,
scheduler=scheduler,
)
# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)
@@ -833,6 +846,50 @@ class DenoiseLatentsInvocation(BaseInvocation):
if self.unet.freeu_config:
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
### lora
if self.unet.loras:
for lora_field in self.unet.loras:
ext_manager.add_extension(
LoRAExt(
node_context=context,
model_id=lora_field.lora,
weight=lora_field.weight,
)
)
### seamless
if self.unet.seamless_axes:
ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
### inpaint
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
# NOTE: We used to identify inpainting models by inpecting the shape of the loaded UNet model weights. Now we
# use the ModelVariantType config. During testing, there was a report of a user with models that had an
# incorrect ModelVariantType value. Re-installing the model fixed the issue. If this issue turns out to be
# prevalent, we will have to revisit how we initialize the inpainting extensions.
if unet_config.variant == ModelVariantType.Inpaint:
ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
elif mask is not None:
ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
# Initialize context for modular denoise
latents = latents.to(device=device, dtype=dtype)
if noise is not None:
noise = noise.to(device=device, dtype=dtype)
denoise_ctx = DenoiseContext(
inputs=DenoiseInputs(
orig_latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
attention_processor_cls=CustomAttnProcessor2_0,
),
unet=None,
scheduler=scheduler,
)
# context for loading additional models
with ExitStack() as exit_stack:
# later should be smth like:
@@ -840,6 +897,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# ext = extension_field.to_extension(exit_stack, context, ext_manager)
# ext_manager.add_extension(ext)
self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)
# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
@@ -871,6 +929,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
# At this point, the mask ranges from 0 (leave unchanged) to 1 (inpaint).
# We invert the mask here for compatibility with the old backend implementation.
if mask is not None:
mask = 1 - mask
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
# below. Investigate whether this is appropriate.
@@ -913,14 +975,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
unet_info.model_on_device() as (model_state_dict, unet),
unet_info.model_on_device() as (cached_weights, unet),
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
set_seamless(unet, self.unet.seamless_axes), # FIXME
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(
unet,
loras=_lora_loader(),
model_state_dict=model_state_dict,
cached_weights=cached_weights,
),
):
assert isinstance(unet, UNet2DConditionModel)

View File

@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Callable, Optional, Tuple
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, model_validator
from pydantic.fields import _Unset
from pydantic_core import PydanticUndefined
@@ -242,6 +242,31 @@ class ConditioningField(BaseModel):
)
class BoundingBoxField(BaseModel):
"""A bounding box primitive value."""
x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
score: Optional[float] = Field(
default=None,
ge=0.0,
le=1.0,
description="The score associated with the bounding box. In the range [0, 1]. This value is typically set "
"when the bounding box was produced by a detector and has an associated confidence score.",
)
@model_validator(mode="after")
def check_coords(self):
if self.x_min > self.x_max:
raise ValueError(f"x_min ({self.x_min}) is greater than x_max ({self.x_max}).")
if self.y_min > self.y_max:
raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
return self
class MetadataField(RootModel[dict[str, Any]]):
"""
Pydantic model for metadata with custom root of type dict[str, Any].

View File

@@ -0,0 +1,100 @@
from pathlib import Path
from typing import Literal
import torch
from PIL import Image
from transformers import pipeline
from transformers.pipelines import ZeroShotObjectDetectionPipeline
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
from invokeai.app.invocations.primitives import BoundingBoxCollectionOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
GroundingDinoModelKey = Literal["grounding-dino-tiny", "grounding-dino-base"]
GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
"grounding-dino-tiny": "IDEA-Research/grounding-dino-tiny",
"grounding-dino-base": "IDEA-Research/grounding-dino-base",
}
@invocation(
"grounding_dino",
title="Grounding DINO (Text Prompt Object Detection)",
tags=["prompt", "object detection"],
category="image",
version="1.0.0",
)
class GroundingDinoInvocation(BaseInvocation):
"""Runs a Grounding DINO model. Performs zero-shot bounding-box object detection from a text prompt."""
# Reference:
# - https://arxiv.org/pdf/2303.05499
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
model: GroundingDinoModelKey = InputField(description="The Grounding DINO model to use.")
prompt: str = InputField(description="The prompt describing the object to segment.")
image: ImageField = InputField(description="The image to segment.")
detection_threshold: float = InputField(
description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be returned.",
ge=0.0,
le=1.0,
default=0.3,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> BoundingBoxCollectionOutput:
# The model expects a 3-channel RGB image.
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
detections = self._detect(
context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
)
# Convert detections to BoundingBoxCollectionOutput.
bounding_boxes: list[BoundingBoxField] = []
for detection in detections:
bounding_boxes.append(
BoundingBoxField(
x_min=detection.box.xmin,
x_max=detection.box.xmax,
y_min=detection.box.ymin,
y_max=detection.box.ymax,
score=detection.score,
)
)
return BoundingBoxCollectionOutput(collection=bounding_boxes)
@staticmethod
def _load_grounding_dino(model_path: Path):
grounding_dino_pipeline = pipeline(
model=str(model_path),
task="zero-shot-object-detection",
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
return GroundingDinoPipeline(grounding_dino_pipeline)
def _detect(
self,
context: InvocationContext,
image: Image.Image,
labels: list[str],
threshold: float = 0.3,
) -> list[DetectionResult]:
"""Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
# TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
# actually makes a difference.
labels = [label if label.endswith(".") else label + "." for label in labels]
with context.models.load_remote_model(
source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino
) as detector:
assert isinstance(detector, GroundingDinoPipeline)
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)

View File

@@ -6,13 +6,19 @@ import cv2
import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import (
ColorField,
FieldDescriptions,
ImageField,
InputField,
OutputField,
WithBoard,
WithMetadata,
)
@@ -1007,3 +1013,62 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)
return ImageOutput.build(image_dto)
@invocation_output("canvas_v2_mask_and_crop_output")
class CanvasV2MaskAndCropOutput(ImageOutput):
offset_x: int = OutputField(description="The x offset of the image, after cropping")
offset_y: int = OutputField(description="The y offset of the image, after cropping")
@invocation(
"canvas_v2_mask_and_crop",
title="Canvas V2 Mask and Crop",
tags=["image", "mask", "id"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Handles Canvas V2 image output masking and cropping"""
source_image: ImageField | None = InputField(
default=None,
description="The source image onto which the masked generated image is pasted. If omitted, the masked generated image is returned with transparency.",
)
generated_image: ImageField = InputField(description="The image to apply the mask to")
mask: ImageField = InputField(description="The mask to apply")
mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by")
def _prepare_mask(self, mask: Image.Image) -> Image.Image:
mask_array = numpy.array(mask)
kernel = numpy.ones((self.mask_blur, self.mask_blur), numpy.uint8)
dilated_mask_array = cv2.erode(mask_array, kernel, iterations=3)
dilated_mask = Image.fromarray(dilated_mask_array)
if self.mask_blur > 0:
mask = dilated_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
return ImageOps.invert(mask.convert("L"))
def invoke(self, context: InvocationContext) -> CanvasV2MaskAndCropOutput:
mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
if self.source_image:
generated_image = context.images.get_pil(self.generated_image.image_name)
source_image = context.images.get_pil(self.source_image.image_name)
source_image.paste(generated_image, (0, 0), mask)
image_dto = context.images.save(image=source_image)
else:
generated_image = context.images.get_pil(self.generated_image.image_name)
generated_image.putalpha(mask)
image_dto = context.images.save(image=generated_image)
# bbox = image.getbbox()
# image = image.crop(bbox)
return CanvasV2MaskAndCropOutput(
image=ImageField(image_name=image_dto.image_name),
offset_x=0,
offset_y=0,
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -24,7 +24,7 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion import set_seamless
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
@@ -59,7 +59,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device)
if self.fp32:

View File

@@ -1,9 +1,10 @@
import numpy as np
import torch
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata
from invokeai.app.invocations.primitives import MaskOutput
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
@invocation(
@@ -118,3 +119,27 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
height=mask.shape[1],
width=mask.shape[2],
)
@invocation(
"tensor_mask_to_image",
title="Tensor Mask to Image",
tags=["mask"],
category="mask",
version="1.0.0",
)
class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Convert a mask tensor to an image."""
mask: TensorField = InputField(description="The mask tensor to convert.")
def invoke(self, context: InvocationContext) -> ImageOutput:
mask = context.tensors.load(self.mask.tensor_name)
# Ensure that the mask is binary.
if mask.dtype != torch.bool:
mask = mask > 0.5
mask_np = (mask.float() * 255).byte().cpu().numpy()
mask_pil = Image.fromarray(mask_np, mode="L")
image_dto = context.images.save(image=mask_pil)
return ImageOutput.build(image_dto)

View File

@@ -7,6 +7,7 @@ import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
BoundingBoxField,
ColorField,
ConditioningField,
DenoiseMaskField,
@@ -469,3 +470,42 @@ class ConditioningCollectionInvocation(BaseInvocation):
# endregion
# region BoundingBox
@invocation_output("bounding_box_output")
class BoundingBoxOutput(BaseInvocationOutput):
"""Base class for nodes that output a single bounding box"""
bounding_box: BoundingBoxField = OutputField(description="The output bounding box.")
@invocation_output("bounding_box_collection_output")
class BoundingBoxCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of bounding boxes"""
collection: list[BoundingBoxField] = OutputField(description="The output bounding boxes.", title="Bounding Boxes")
@invocation(
"bounding_box",
title="Bounding Box",
tags=["primitives", "segmentation", "collection", "bounding box"],
category="primitives",
version="1.0.0",
)
class BoundingBoxInvocation(BaseInvocation):
"""Create a bounding box manually by supplying box coordinates"""
x_min: int = InputField(default=0, description="x-coordinate of the bounding box's top left vertex")
y_min: int = InputField(default=0, description="y-coordinate of the bounding box's top left vertex")
x_max: int = InputField(default=0, description="x-coordinate of the bounding box's bottom right vertex")
y_max: int = InputField(default=0, description="y-coordinate of the bounding box's bottom right vertex")
def invoke(self, context: InvocationContext) -> BoundingBoxOutput:
bounding_box = BoundingBoxField(x_min=self.x_min, y_min=self.y_min, x_max=self.x_max, y_max=self.y_max)
return BoundingBoxOutput(bounding_box=bounding_box)
# endregion

View File

@@ -1,76 +1,161 @@
from typing import Dict, cast
from pathlib import Path
from typing import Literal
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
from invokeai.app.invocations.primitives import MaskOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.grounding_segment_anything.gsa import GroundingSegmentAnythingDetector
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
GROUNDING_SEGMENT_ANYTHING_MODELS = {
"groundingdino_swint_ogc": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
"segment_anything_vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]
SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
"segment-anything-base": "facebook/sam-vit-base",
"segment-anything-large": "facebook/sam-vit-large",
"segment-anything-huge": "facebook/sam-vit-huge",
}
@invocation(
"segment_anything",
title="Segment Anything",
tags=["grounding_dino", "segment", "anything"],
category="image",
tags=["prompt", "segmentation"],
category="segmentation",
version="1.0.0",
)
class SegmentAnythingInvocation(BaseInvocation):
"""Automatically generate masks from an image using GroundingDINO & Segment Anything"""
"""Runs a Segment Anything Model."""
image: ImageField = InputField(description="The image to process")
prompt: str = InputField(default="", description="Keywords to segment", title="Prompt")
box_threshold: float = InputField(
default=0.5, ge=0, le=1, description="Threshold of box detection", title="Box Threshold"
# Reference:
# - https://arxiv.org/pdf/2304.02643
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
image: ImageField = InputField(description="The image to segment.")
bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
apply_polygon_refinement: bool = InputField(
description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
default=True,
)
text_threshold: float = InputField(
default=0.5, ge=0, le=1, description="Threshold of text detection", title="Text Threshold"
)
nms_threshold: float = InputField(
default=0.8, ge=0, le=1, description="Threshold of nms detection", title="NMS Threshold"
mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
description="The filtering to apply to the detected masks before merging them into a final output.",
default="all",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
input_image = context.images.get_pil(self.image.image_name)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> MaskOutput:
# The models expect a 3-channel RGB image.
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
grounding_dino_model = context.models.load_remote_model(
GROUNDING_SEGMENT_ANYTHING_MODELS["groundingdino_swint_ogc"]
)
segment_anything_model = context.models.load_remote_model(
GROUNDING_SEGMENT_ANYTHING_MODELS["segment_anything_vit_h"]
if len(self.bounding_boxes) == 0:
combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
else:
masks = self._segment(context=context, image=image_pil)
masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
# masks contains bool values, so we merge them via max-reduce.
combined_mask, _ = torch.stack(masks).max(dim=0)
mask_tensor_name = context.tensors.save(combined_mask)
height, width = combined_mask.shape
return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
@staticmethod
def _load_sam_model(model_path: Path):
sam_model = AutoModelForMaskGeneration.from_pretrained(
model_path,
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
assert isinstance(sam_model, SamModel)
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
assert isinstance(sam_processor, SamProcessor)
return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
def _segment(
self,
context: InvocationContext,
image: Image.Image,
) -> list[torch.Tensor]:
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
# Convert the bounding boxes to the SAM input format.
sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
with (
grounding_dino_model.model_on_device() as (_, grounding_dino_state_dict),
segment_anything_model.model_on_device() as (_, segment_anything_state_dict),
context.models.load_remote_model(
source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
) as sam_pipeline,
):
if not grounding_dino_state_dict or not segment_anything_state_dict:
raise RuntimeError("Unable to load segmentation models")
assert isinstance(sam_pipeline, SegmentAnythingPipeline)
masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
grounding_dino = GroundingSegmentAnythingDetector.build_grounding_dino(
cast(Dict[str, torch.Tensor], grounding_dino_state_dict), TorchDevice.choose_torch_device()
)
segment_anything = GroundingSegmentAnythingDetector.build_segment_anything(
cast(Dict[str, torch.Tensor], segment_anything_state_dict), TorchDevice.choose_torch_device()
)
detector = GroundingSegmentAnythingDetector(grounding_dino, segment_anything)
masks = self._process_masks(masks)
if self.apply_polygon_refinement:
masks = self._apply_polygon_refinement(masks)
mask = detector.predict(
input_image, self.prompt, self.box_threshold, self.text_threshold, self.nms_threshold
)
image_dto = context.images.save(mask)
return masks
"""Builds an ImageOutput and its ImageField"""
processed_image_field = ImageField(image_name=image_dto.image_name)
return ImageOutput(
image=processed_image_field,
width=input_image.width,
height=input_image.height,
)
def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
"""Convert the tensor output from the Segment Anything model from a tensor of shape
[num_masks, channels, height, width] to a list of tensors of shape [height, width].
"""
assert masks.dtype == torch.bool
# [num_masks, channels, height, width] -> [num_masks, height, width]
masks, _ = masks.max(dim=1)
# Split the first dimension into a list of masks.
return list(masks.cpu().unbind(dim=0))
def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
"""Apply polygon refinement to the masks.
Convert each mask to a polygon, then back to a mask. This has the following effect:
- Smooth the edges of the mask slightly.
- Ensure that each mask consists of a single closed polygon
- Removes small mask pieces.
- Removes holes from the mask.
"""
# Convert tensor masks to np masks.
np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]
# Apply polygon refinement.
for idx, mask in enumerate(np_masks):
shape = mask.shape
assert len(shape) == 2 # Assert length to satisfy type checker.
polygon = mask_to_polygon(mask)
mask = polygon_to_mask(polygon, shape)
np_masks[idx] = mask
# Convert np masks back to tensor masks.
masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]
return masks
def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
"""Filter the detected masks based on the specified mask filter."""
assert len(masks) == len(bounding_boxes)
if self.mask_filter == "all":
return masks
elif self.mask_filter == "largest":
# Find the largest mask.
return [max(masks, key=lambda x: float(x.sum()))]
elif self.mask_filter == "highest_box_score":
# Find the index of the bounding box with the highest score.
# Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
# cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
# reasonable fallback since the expected score range is [0.0, 1.0].
max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
return [masks[max_score_idx]]
else:
raise ValueError(f"Invalid mask filter: {self.mask_filter}")

View File

@@ -1,81 +0,0 @@
import cv2
import numpy as np
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.vto_workflow.extract_channel import ImageChannel, extract_channel
from invokeai.backend.vto_workflow.overlay_pattern import multiply_images
from invokeai.backend.vto_workflow.seamless_mapping import map_seamless_tiles
@invocation("vto", title="Virtual Try-On", tags=["vto"], category="vto", version="1.1.0")
class VTOInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Virtual try-on."""
original_image: ImageField = InputField(description="The input image")
clothing_mask: ImageField = InputField(description="Clothing mask.")
pattern_image: ImageField = InputField(description="Pattern image.")
pattern_vertical_repeats: float = InputField(
description="Number of vertical repeats for the pattern.", gt=0.01, default=1.0
)
shading_max: float = InputField(
description="The lightness of the light spots on the clothing. Default is 1.0. Typically in the range [0.7, 1.2]. Must be > shading_min",
default=1.0,
ge=0.0,
)
shading_min: float = InputField(
description="The lightness of the dark spots on the clothing. Default id 0.5. Typically in the range [0.2, 0.7]",
default=0.5,
ge=0.0,
)
mask_dilation: int = InputField(
description="The number of pixels to dilate the mask by. Default is 1.",
default=1,
ge=0,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
# TODO(ryand): Avoid all the unnecessary flip-flopping between PIL and numpy.
original_image = context.images.get_pil(self.original_image.image_name)
clothing_mask = context.images.get_pil(self.clothing_mask.image_name)
pattern_image = context.images.get_pil(self.pattern_image.image_name)
shadows = extract_channel(np.array(original_image), ImageChannel.LAB_L)
# Clip the shadows to the 0.05 and 0.95 percentiles to eliminate outliers.
shadows = np.clip(shadows, np.percentile(shadows, 5), np.percentile(shadows, 95))
# Normalize the shadows to the range [shading_min, shading_max].
assert self.shading_min < self.shading_max
shadows = shadows.astype(np.float32)
shadows = (shadows - shadows.min()) / (shadows.max() - shadows.min())
shadows = self.shading_min + (self.shading_max - self.shading_min) * shadows
shadows = np.clip(shadows, 0.0, 1.0)
shadows = (shadows * 255).astype(np.uint8)
expanded_pattern = map_seamless_tiles(
seamless_tile=pattern_image,
target_hw=(original_image.height, original_image.width),
num_repeats_h=self.pattern_vertical_repeats,
)
pattern_with_shadows = multiply_images(expanded_pattern, Image.fromarray(shadows))
# Dilate the mask.
clothing_mask_np = np.array(clothing_mask)
if self.mask_dilation > 0:
clothing_mask_np = cv2.dilate(clothing_mask_np, np.ones((3, 3), np.uint8), iterations=self.mask_dilation)
# Merge the pattern with the model image.
pattern_with_shadows_np = np.array(pattern_with_shadows)
original_image_np = np.array(original_image)
merged_image = np.where(clothing_mask_np[:, :, None], pattern_with_shadows_np, original_image_np)
merged_image = Image.fromarray(merged_image)
image_dto = context.images.save(image=merged_image)
return ImageOutput.build(image_dto)

View File

@@ -91,6 +91,7 @@ class InvokeAIAppConfig(BaseSettings):
db_dir: Path to InvokeAI databases directory.
outputs_dir: Path to directory for outputs.
custom_nodes_dir: Path to directory for custom nodes.
style_presets_dir: Path to directory for style presets.
log_handlers: Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".
log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.<br>Valid values: `plain`, `color`, `syslog`, `legacy`
log_level: Emit logging messages at this level or higher.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
@@ -153,6 +154,7 @@ class InvokeAIAppConfig(BaseSettings):
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
custom_nodes_dir: Path = Field(default=Path("nodes"), description="Path to directory for custom nodes.")
style_presets_dir: Path = Field(default=Path("style_presets"), description="Path to directory for style presets.")
# LOGGING
log_handlers: list[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".')
@@ -300,6 +302,11 @@ class InvokeAIAppConfig(BaseSettings):
"""Path to the models directory, resolved to an absolute path.."""
return self._resolve(self.models_dir)
@property
def style_presets_path(self) -> Path:
"""Path to the style presets directory, resolved to an absolute path.."""
return self._resolve(self.style_presets_dir)
@property
def convert_cache_path(self) -> Path:
"""Path to the converted cache models directory, resolved to an absolute path.."""

View File

@@ -88,6 +88,7 @@ class QueueItemEventBase(QueueEventBase):
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
origin: str | None = Field(default=None, description="The origin of the batch")
class InvocationEventBase(QueueItemEventBase):
@@ -95,8 +96,6 @@ class InvocationEventBase(QueueItemEventBase):
session_id: str = Field(description="The ID of the session (aka graph execution state)")
queue_id: str = Field(description="The ID of the queue")
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
invocation: AnyInvocation = Field(description="The ID of the invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
@@ -114,6 +113,7 @@ class InvocationStartedEvent(InvocationEventBase):
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -147,6 +147,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -184,6 +185,7 @@ class InvocationCompleteEvent(InvocationEventBase):
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -216,6 +218,7 @@ class InvocationErrorEvent(InvocationEventBase):
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -253,6 +256,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id,
status=queue_item.status,
error_type=queue_item.error_type,
@@ -279,12 +283,14 @@ class BatchEnqueuedEvent(QueueEventBase):
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
)
priority: int = Field(description="The priority of the batch")
origin: str | None = Field(default=None, description="The origin of the batch")
@classmethod
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
return cls(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
origin=enqueue_result.batch.origin,
enqueued=enqueue_result.enqueued,
requested=enqueue_result.requested,
priority=enqueue_result.priority,

View File

@@ -1,46 +1,44 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
import threading
from queue import Empty, Queue
from fastapi_events.dispatcher import dispatch
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.events.events_common import (
EventBase,
)
from invokeai.app.services.events.events_common import EventBase
class FastAPIEventService(EventServiceBase):
def __init__(self, event_handler_id: int) -> None:
def __init__(self, event_handler_id: int, loop: asyncio.AbstractEventLoop) -> None:
self.event_handler_id = event_handler_id
self._queue = Queue[EventBase | None]()
self._queue = asyncio.Queue[EventBase | None]()
self._stop_event = threading.Event()
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
self._loop = loop
# We need to store a reference to the task so it doesn't get GC'd
# See: https://docs.python.org/3/library/asyncio-task.html#creating-tasks
self._background_tasks: set[asyncio.Task[None]] = set()
task = self._loop.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.remove)
super().__init__()
def stop(self, *args, **kwargs):
self._stop_event.set()
self._queue.put(None)
self._loop.call_soon_threadsafe(self._queue.put_nowait, None)
def dispatch(self, event: EventBase) -> None:
self._queue.put(event)
self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
async def _dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self._queue.get(block=False)
event = await self._queue.get()
if not event: # Probably stopping
continue
# Leave the payloads as live pydantic models
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
except Empty:
await asyncio.sleep(0.1)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error

View File

@@ -1,11 +1,10 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from pathlib import Path
from queue import Queue
from typing import Dict, Optional, Union
from typing import Optional, Union
from PIL import Image, PngImagePlugin
from PIL.Image import Image as PILImageType
from send2trash import send2trash
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.image_files.image_files_common import (
@@ -20,18 +19,12 @@ from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
class DiskImageFileStorage(ImageFileStorageBase):
"""Stores images on disk"""
__output_folder: Path
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[Path, PILImageType]
__max_cache_size: int
__invoker: Invoker
def __init__(self, output_folder: Union[str, Path]):
self.__cache = {}
self.__cache_ids = Queue()
self.__cache: dict[Path, PILImageType] = {}
self.__cache_ids = Queue[Path]()
self.__max_cache_size = 10 # TODO: get this from config
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__thumbnails_folder = self.__output_folder / "thumbnails"
# Validate required output folders at launch
self.__validate_storage_folders()
@@ -103,7 +96,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image_path = self.get_path(image_name)
if image_path.exists():
send2trash(image_path)
image_path.unlink()
if image_path in self.__cache:
del self.__cache[image_path]
@@ -111,7 +104,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
thumbnail_path = self.get_path(thumbnail_name, True)
if thumbnail_path.exists():
send2trash(thumbnail_path)
thumbnail_path.unlink()
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
except Exception as e:

View File

@@ -4,6 +4,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
if TYPE_CHECKING:
from logging import Logger
@@ -61,6 +63,8 @@ class InvocationServices:
workflow_records: "WorkflowRecordsStorageBase",
tensors: "ObjectSerializerBase[torch.Tensor]",
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
style_preset_records: "StylePresetRecordsStorageBase",
style_preset_image_files: "StylePresetImageFileStorageBase",
):
self.board_images = board_images
self.board_image_records = board_image_records
@@ -85,3 +89,5 @@ class InvocationServices:
self.workflow_records = workflow_records
self.tensors = tensors
self.conditioning = conditioning
self.style_preset_records = style_preset_records
self.style_preset_image_files = style_preset_image_files

View File

@@ -2,7 +2,6 @@ from pathlib import Path
from PIL import Image
from PIL.Image import Image as PILImageType
from send2trash import send2trash
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_images.model_images_base import ModelImageFileStorageBase
@@ -70,7 +69,7 @@ class ModelImageFileStorageDisk(ModelImageFileStorageBase):
if not self._validate_path(path):
raise ModelImageFileNotFoundException
send2trash(path)
path.unlink()
except Exception as e:
raise ModelImageFileDeleteException from e

View File

@@ -6,6 +6,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
Batch,
BatchStatus,
CancelByBatchIDsResult,
CancelByOriginResult,
CancelByQueueIDResult,
ClearResult,
EnqueueBatchResult,
@@ -95,6 +96,11 @@ class SessionQueueBase(ABC):
"""Cancels all queue items with matching batch IDs"""
pass
@abstractmethod
def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
"""Cancels all queue items with the given batch origin"""
pass
@abstractmethod
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
"""Cancels all queue items with matching queue ID"""

View File

@@ -77,6 +77,7 @@ BatchDataCollection: TypeAlias = list[list[BatchDatum]]
class Batch(BaseModel):
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
origin: str | None = Field(default=None, description="The origin of this batch.")
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
graph: Graph = Field(description="The graph to initialize the session with")
workflow: Optional[WorkflowWithoutID] = Field(
@@ -195,6 +196,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
priority: int = Field(default=0, description="The priority of this queue item")
batch_id: str = Field(description="The ID of the batch associated with this queue item")
origin: str | None = Field(default=None, description="The origin of this queue item. ")
session_id: str = Field(
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
)
@@ -294,6 +296,7 @@ class SessionQueueStatus(BaseModel):
class BatchStatus(BaseModel):
queue_id: str = Field(..., description="The ID of the queue")
batch_id: str = Field(..., description="The ID of the batch")
origin: str | None = Field(..., description="The origin of the batch")
pending: int = Field(..., description="Number of queue items with status 'pending'")
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
completed: int = Field(..., description="Number of queue items with status 'complete'")
@@ -328,6 +331,12 @@ class CancelByBatchIDsResult(BaseModel):
canceled: int = Field(..., description="Number of queue items canceled")
class CancelByOriginResult(BaseModel):
"""Result of canceling by list of batch ids"""
canceled: int = Field(..., description="Number of queue items canceled")
class CancelByQueueIDResult(CancelByBatchIDsResult):
"""Result of canceling by queue id"""
@@ -433,6 +442,7 @@ class SessionQueueValueToInsert(NamedTuple):
field_values: Optional[str] # field_values json
priority: int # priority
workflow: Optional[str] # workflow json
origin: str | None
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
@@ -453,6 +463,7 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json)
priority, # priority
json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json)
batch.origin, # origin
)
)
return values_to_insert

View File

@@ -10,6 +10,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
Batch,
BatchStatus,
CancelByBatchIDsResult,
CancelByOriginResult,
CancelByQueueIDResult,
ClearResult,
EnqueueBatchResult,
@@ -127,8 +128,8 @@ class SqliteSessionQueue(SessionQueueBase):
self.__cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow)
VALUES (?, ?, ?, ?, ?, ?, ?)
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
@@ -417,11 +418,7 @@ class SqliteSessionQueue(SessionQueueBase):
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
current_queue_item, batch_status, queue_status
)
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self.__conn.rollback()
raise
@@ -429,6 +426,46 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.release()
return CancelByBatchIDsResult(canceled=count)
def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
try:
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
where = """--sql
WHERE
queue_id == ?
AND origin == ?
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
"""
params = (queue_id, origin)
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
params,
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
{where};
""",
params,
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.origin == origin:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CancelByOriginResult(canceled=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
try:
current_queue_item = self.get_current(queue_id)
@@ -541,7 +578,8 @@ class SqliteSessionQueue(SessionQueueBase):
started_at,
session_id,
batch_id,
queue_id
queue_id,
origin
FROM session_queue
WHERE queue_id = ?
"""
@@ -621,7 +659,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT status, count(*)
SELECT status, count(*), origin
FROM session_queue
WHERE
queue_id = ?
@@ -633,6 +671,7 @@ class SqliteSessionQueue(SessionQueueBase):
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
except Exception:
self.__conn.rollback()
raise
@@ -641,6 +680,7 @@ class SqliteSessionQueue(SessionQueueBase):
return BatchStatus(
batch_id=batch_id,
origin=origin,
queue_id=queue_id,
pending=counts.get("pending", 0),
in_progress=counts.get("in_progress", 0),

View File

@@ -16,6 +16,8 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_15 import build_migration_15
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -49,6 +51,8 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
migrator.register_migration(build_migration_12(app_config=config))
migrator.register_migration(build_migration_13())
migrator.register_migration(build_migration_14())
migrator.register_migration(build_migration_15())
migrator.run_migrations()
return db

View File

@@ -0,0 +1,61 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration14Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._create_style_presets(cursor)
def _create_style_presets(self, cursor: sqlite3.Cursor) -> None:
"""Create the table used to store style presets."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS style_presets (
id TEXT NOT NULL PRIMARY KEY,
name TEXT NOT NULL,
preset_data TEXT NOT NULL,
type TEXT NOT NULL DEFAULT "user",
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
]
# Add trigger for `updated_at`.
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS style_presets
AFTER UPDATE
ON style_presets FOR EACH ROW
BEGIN
UPDATE style_presets SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
]
# Add indexes for searchable fields
indices = [
"CREATE INDEX IF NOT EXISTS idx_style_presets_name ON style_presets(name);",
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def build_migration_14() -> Migration:
"""
Build the migration from database version 13 to 14..
This migration does the following:
- Create the table used to store style presets.
"""
migration_14 = Migration(
from_version=13,
to_version=14,
callback=Migration14Callback(),
)
return migration_14

View File

@@ -0,0 +1,31 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration15Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._add_origin_col(cursor)
def _add_origin_col(self, cursor: sqlite3.Cursor) -> None:
"""
- Adds `origin` column to the session queue table.
"""
cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;")
def build_migration_15() -> Migration:
"""
Build the migration from database version 14 to 15.
This migration does the following:
- Adds `origin` column to the session queue table.
"""
migration_15 = Migration(
from_version=14,
to_version=15,
callback=Migration15Callback(),
)
return migration_15

Binary file not shown.

After

Width:  |  Height:  |  Size: 98 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 138 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 122 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 123 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 160 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 146 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 119 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 117 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 79 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 156 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 141 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 96 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 91 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 88 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 107 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 132 KiB

View File

@@ -0,0 +1,33 @@
from abc import ABC, abstractmethod
from pathlib import Path
from PIL.Image import Image as PILImageType
class StylePresetImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
def get(self, style_preset_id: str) -> PILImageType:
"""Retrieves a style preset image as PIL Image."""
pass
@abstractmethod
def get_path(self, style_preset_id: str) -> Path:
"""Gets the internal path to a style preset image."""
pass
@abstractmethod
def get_url(self, style_preset_id: str) -> str | None:
"""Gets the URL to fetch a style preset image."""
pass
@abstractmethod
def save(self, style_preset_id: str, image: PILImageType) -> None:
"""Saves a style preset image."""
pass
@abstractmethod
def delete(self, style_preset_id: str) -> None:
"""Deletes a style preset image."""
pass

View File

@@ -0,0 +1,19 @@
class StylePresetImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message: str = "Style preset image file not found"):
super().__init__(message)
class StylePresetImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message: str = "Style preset image file not saved"):
super().__init__(message)
class StylePresetImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message: str = "Style preset image file not deleted"):
super().__init__(message)

View File

@@ -0,0 +1,88 @@
from pathlib import Path
from PIL import Image
from PIL.Image import Image as PILImageType
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
from invokeai.app.services.style_preset_images.style_preset_images_common import (
StylePresetImageFileDeleteException,
StylePresetImageFileNotFoundException,
StylePresetImageFileSaveException,
)
from invokeai.app.services.style_preset_records.style_preset_records_common import PresetType
from invokeai.app.util.misc import uuid_string
from invokeai.app.util.thumbnails import make_thumbnail
class StylePresetImageFileStorageDisk(StylePresetImageFileStorageBase):
"""Stores images on disk"""
def __init__(self, style_preset_images_folder: Path):
self._style_preset_images_folder = style_preset_images_folder
self._validate_storage_folders()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def get(self, style_preset_id: str) -> PILImageType:
try:
path = self.get_path(style_preset_id)
return Image.open(path)
except FileNotFoundError as e:
raise StylePresetImageFileNotFoundException from e
def save(self, style_preset_id: str, image: PILImageType) -> None:
try:
self._validate_storage_folders()
image_path = self._style_preset_images_folder / (style_preset_id + ".webp")
thumbnail = make_thumbnail(image, 256)
thumbnail.save(image_path, format="webp")
except Exception as e:
raise StylePresetImageFileSaveException from e
def get_path(self, style_preset_id: str) -> Path:
style_preset = self._invoker.services.style_preset_records.get(style_preset_id)
if style_preset.type is PresetType.Default:
default_images_dir = Path(__file__).parent / Path("default_style_preset_images")
path = default_images_dir / (style_preset.name + ".png")
else:
path = self._style_preset_images_folder / (style_preset_id + ".webp")
return path
def get_url(self, style_preset_id: str) -> str | None:
path = self.get_path(style_preset_id)
if not self._validate_path(path):
return
url = self._invoker.services.urls.get_style_preset_image_url(style_preset_id)
# The image URL never changes, so we must add random query string to it to prevent caching
url += f"?{uuid_string()}"
return url
def delete(self, style_preset_id: str) -> None:
try:
path = self.get_path(style_preset_id)
if not self._validate_path(path):
raise StylePresetImageFileNotFoundException
path.unlink()
except StylePresetImageFileNotFoundException as e:
raise StylePresetImageFileNotFoundException from e
except Exception as e:
raise StylePresetImageFileDeleteException from e
def _validate_path(self, path: Path) -> bool:
"""Validates the path given for an image."""
return path.exists()
def _validate_storage_folders(self) -> None:
"""Checks if the required folders exist and create them if they don't"""
self._style_preset_images_folder.mkdir(parents=True, exist_ok=True)

View File

@@ -0,0 +1,146 @@
[
{
"name": "Photography (General)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt}. photography. f/2.8 macro photo, bokeh, photorealism",
"negative_prompt": "painting, digital art. sketch, blurry"
}
},
{
"name": "Photography (Studio Lighting)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt}, photography. f/8 photo. centered subject, studio lighting.",
"negative_prompt": "painting, digital art. sketch, blurry"
}
},
{
"name": "Photography (Landscape)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt}, landscape photograph, f/12, lifelike, highly detailed.",
"negative_prompt": "painting, digital art. sketch, blurry"
}
},
{
"name": "Photography (Portrait)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt}. photography. portraiture. catch light in eyes. one flash. rembrandt lighting. Soft box. dark shadows. High contrast. 80mm lens. F2.8.",
"negative_prompt": "painting, digital art. sketch, blurry"
}
},
{
"name": "Photography (Black and White)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} photography. natural light. 80mm lens. F1.4. strong contrast, hard light. dark contrast. blurred background. black and white",
"negative_prompt": "painting, digital art. sketch, colour+"
}
},
{
"name": "Architectural Visualization",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt}. architectural photography, f/12, luxury, aesthetically pleasing form and function.",
"negative_prompt": "painting, digital art. sketch, blurry"
}
},
{
"name": "Concept Art (Fantasy)",
"type": "default",
"preset_data": {
"positive_prompt": "concept artwork of a {prompt}. (digital painterly art style)++, mythological, (textured 2d dry media brushpack)++, glazed brushstrokes, otherworldly. painting+, illustration+",
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
}
},
{
"name": "Concept Art (Sci-Fi)",
"type": "default",
"preset_data": {
"positive_prompt": "(concept art)++, {prompt}, (sleek futurism)++, (textured 2d dry media)++, metallic highlights, digital painting style",
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
}
},
{
"name": "Concept Art (Character)",
"type": "default",
"preset_data": {
"positive_prompt": "(character concept art)++, stylized painterly digital painting of {prompt}, (painterly, impasto. Dry brush.)++",
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
}
},
{
"name": "Concept Art (Painterly)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} oil painting. high contrast. impasto. sfumato. chiaroscuro. Palette knife.",
"negative_prompt": "photo. smooth. border. frame"
}
},
{
"name": "Environment Art",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} environment artwork, hyper-realistic digital painting style with cinematic composition, atmospheric, depth and detail, voluminous. textured dry brush 2d media",
"negative_prompt": "photo, distorted, blurry, out of focus. sketch."
}
},
{
"name": "Interior Design (Visualization)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} interior design photo, gentle shadows, light mid-tones, dimension, mix of smooth and textured surfaces, focus on negative space and clean lines, focus",
"negative_prompt": "photo, distorted. sketch."
}
},
{
"name": "Product Rendering",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} high quality product photography, 3d rendering with key lighting, shallow depth of field, simple plain background, studio lighting.",
"negative_prompt": "blurry, sketch, messy, dirty. unfinished."
}
},
{
"name": "Sketch",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} black and white pencil drawing, off-center composition, cross-hatching for shadows, bold strokes, textured paper. sketch+++",
"negative_prompt": "blurry, photo, painting, color. messy, dirty. unfinished. frame, borders."
}
},
{
"name": "Line Art",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} Line art. bold outline. simplistic. white background. 2d",
"negative_prompt": "photo. digital art. greyscale. solid black. painting"
}
},
{
"name": "Anime",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} anime++, bold outline, cel-shaded coloring, shounen, seinen",
"negative_prompt": "(photo)+++. greyscale. solid black. painting"
}
},
{
"name": "Illustration",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} illustration, bold linework, illustrative details, vector art style, flat coloring",
"negative_prompt": "(photo)+++. greyscale. painting, black and white."
}
},
{
"name": "Vehicles",
"type": "default",
"preset_data": {
"positive_prompt": "A weird futuristic normal auto, {prompt} elegant design, nice color, nice wheels",
"negative_prompt": "sketch. digital art. greyscale. painting"
}
}
]

View File

@@ -0,0 +1,42 @@
from abc import ABC, abstractmethod
from invokeai.app.services.style_preset_records.style_preset_records_common import (
PresetType,
StylePresetChanges,
StylePresetRecordDTO,
StylePresetWithoutId,
)
class StylePresetRecordsStorageBase(ABC):
"""Base class for style preset storage services."""
@abstractmethod
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
"""Get style preset by id."""
pass
@abstractmethod
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
"""Creates a style preset."""
pass
@abstractmethod
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
"""Creates many style presets."""
pass
@abstractmethod
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
"""Updates a style preset."""
pass
@abstractmethod
def delete(self, style_preset_id: str) -> None:
"""Deletes a style preset."""
pass
@abstractmethod
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
"""Gets many workflows."""
pass

View File

@@ -0,0 +1,139 @@
import codecs
import csv
import json
from enum import Enum
from typing import Any, Optional
import pydantic
from fastapi import UploadFile
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
from invokeai.app.util.metaenum import MetaEnum
class StylePresetNotFoundError(Exception):
"""Raised when a style preset is not found"""
class PresetData(BaseModel, extra="forbid"):
positive_prompt: str = Field(description="Positive prompt")
negative_prompt: str = Field(description="Negative prompt")
PresetDataValidator = TypeAdapter(PresetData)
class PresetType(str, Enum, metaclass=MetaEnum):
User = "user"
Default = "default"
Project = "project"
class StylePresetChanges(BaseModel, extra="forbid"):
name: Optional[str] = Field(default=None, description="The style preset's new name.")
preset_data: Optional[PresetData] = Field(default=None, description="The updated data for style preset.")
type: Optional[PresetType] = Field(description="The updated type of the style preset")
class StylePresetWithoutId(BaseModel):
name: str = Field(description="The name of the style preset.")
preset_data: PresetData = Field(description="The preset data")
type: PresetType = Field(description="The type of style preset")
class StylePresetRecordDTO(StylePresetWithoutId):
id: str = Field(description="The style preset ID.")
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "StylePresetRecordDTO":
data["preset_data"] = PresetDataValidator.validate_json(data.get("preset_data", ""))
return StylePresetRecordDTOValidator.validate_python(data)
StylePresetRecordDTOValidator = TypeAdapter(StylePresetRecordDTO)
class StylePresetRecordWithImage(StylePresetRecordDTO):
image: Optional[str] = Field(description="The path for image")
class StylePresetImportRow(BaseModel):
name: str = Field(min_length=1, description="The name of the preset.")
positive_prompt: str = Field(
default="",
description="The positive prompt for the preset.",
validation_alias=AliasChoices("positive_prompt", "prompt"),
)
negative_prompt: str = Field(default="", description="The negative prompt for the preset.")
model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
StylePresetImportList = list[StylePresetImportRow]
StylePresetImportListTypeAdapter = TypeAdapter(StylePresetImportList)
class UnsupportedFileTypeError(ValueError):
"""Raised when an unsupported file type is encountered"""
pass
class InvalidPresetImportDataError(ValueError):
"""Raised when invalid preset import data is encountered"""
pass
async def parse_presets_from_file(file: UploadFile) -> list[StylePresetWithoutId]:
"""Parses style presets from a file. The file must be a CSV or JSON file.
If CSV, the file must have the following columns:
- name
- prompt (or positive_prompt)
- negative_prompt
If JSON, the file must be a list of objects with the following keys:
- name
- prompt (or positive_prompt)
- negative_prompt
Args:
file (UploadFile): The file to parse.
Returns:
list[StylePresetWithoutId]: The parsed style presets.
Raises:
UnsupportedFileTypeError: If the file type is not supported.
InvalidPresetImportDataError: If the data in the file is invalid.
"""
if file.content_type not in ["text/csv", "application/json"]:
raise UnsupportedFileTypeError()
if file.content_type == "text/csv":
csv_reader = csv.DictReader(codecs.iterdecode(file.file, "utf-8"))
data = list(csv_reader)
else: # file.content_type == "application/json":
json_data = await file.read()
data = json.loads(json_data)
try:
imported_presets = StylePresetImportListTypeAdapter.validate_python(data)
style_presets: list[StylePresetWithoutId] = []
for imported in imported_presets:
preset_data = PresetData(positive_prompt=imported.positive_prompt, negative_prompt=imported.negative_prompt)
style_preset = StylePresetWithoutId(name=imported.name, preset_data=preset_data, type=PresetType.User)
style_presets.append(style_preset)
except pydantic.ValidationError as e:
if file.content_type == "text/csv":
msg = "Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt' and name cannot be blank"
else: # file.content_type == "application/json":
msg = "Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt' and name cannot be blank"
raise InvalidPresetImportDataError(msg) from e
finally:
file.file.close()
return style_presets

View File

@@ -0,0 +1,215 @@
import json
from pathlib import Path
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
from invokeai.app.services.style_preset_records.style_preset_records_common import (
PresetType,
StylePresetChanges,
StylePresetNotFoundError,
StylePresetRecordDTO,
StylePresetWithoutId,
)
from invokeai.app.util.misc import uuid_string
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
self._sync_default_style_presets()
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
"""Gets a style preset by ID."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = self._cursor.fetchone()
if row is None:
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
return StylePresetRecordDTO.from_dict(dict(row))
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
style_preset_id = uuid_string()
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
id,
name,
preset_data,
type
)
VALUES (?, ?, ?, ?);
""",
(
style_preset_id,
style_preset.name,
style_preset.preset_data.model_dump_json(),
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(style_preset_id)
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
style_preset_ids = []
try:
self._lock.acquire()
for style_preset in style_presets:
style_preset_id = uuid_string()
style_preset_ids.append(style_preset_id)
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
id,
name,
preset_data,
type
)
VALUES (?, ?, ?, ?);
""",
(
style_preset_id,
style_preset.name,
style_preset.preset_data.model_dump_json(),
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
try:
self._lock.acquire()
# Change the name of a style preset
if changes.name is not None:
self._cursor.execute(
"""--sql
UPDATE style_presets
SET name = ?
WHERE id = ?;
""",
(changes.name, style_preset_id),
)
# Change the preset data for a style preset
if changes.preset_data is not None:
self._cursor.execute(
"""--sql
UPDATE style_presets
SET preset_data = ?
WHERE id = ?;
""",
(changes.preset_data.model_dump_json(), style_preset_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(style_preset_id)
def delete(self, style_preset_id: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE from style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
try:
self._lock.acquire()
main_query = """
SELECT
*
FROM style_presets
"""
if type is not None:
main_query += "WHERE type = ? "
main_query += "ORDER BY LOWER(name) ASC"
if type is not None:
self._cursor.execute(main_query, (type,))
else:
self._cursor.execute(main_query)
rows = self._cursor.fetchall()
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
return style_presets
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def _sync_default_style_presets(self) -> None:
"""Syncs default style presets to the database. Internal use only."""
# First delete all existing default style presets
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM style_presets
WHERE type = "default";
"""
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
# Next, parse and create the default style presets
with self._lock, open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
presets = json.load(file)
for preset in presets:
style_preset = StylePresetWithoutId.model_validate(preset)
self.create(style_preset)

View File

@@ -13,3 +13,8 @@ class UrlServiceBase(ABC):
def get_model_image_url(self, model_key: str) -> str:
"""Gets the URL for a model image"""
pass
@abstractmethod
def get_style_preset_image_url(self, style_preset_id: str) -> str:
"""Gets the URL for a style preset image"""
pass

View File

@@ -19,3 +19,6 @@ class LocalUrlService(UrlServiceBase):
def get_model_image_url(self, model_key: str) -> str:
return f"{self._base_url_v2}/models/i/{model_key}/image"
def get_style_preset_image_url(self, style_preset_id: str) -> str:
return f"{self._base_url}/style_presets/i/{style_preset_id}/image"

View File

@@ -81,7 +81,7 @@ def get_openapi_func(
# Add the output map to the schema
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
"type": "object",
"properties": invocation_output_map_properties,
"properties": dict(sorted(invocation_output_map_properties.items())),
"required": invocation_output_map_required,
}

View File

@@ -1,90 +0,0 @@
from pathlib import Path
from typing import Literal
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from einops import repeat
from PIL import Image
from torchvision.transforms import Compose
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
from invokeai.backend.util.logging import InvokeAILogger
config = get_config()
logger = InvokeAILogger.get_logger(config=config)
DEPTH_ANYTHING_MODELS = {
"large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
"base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
"small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
}
transform = Compose(
[
Resize(
width=518,
height=518,
resize_target=False,
keep_aspect_ratio=True,
ensure_multiple_of=14,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_CUBIC,
),
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
PrepareForNet(),
]
)
class DepthAnythingDetector:
def __init__(self, model: DPT_DINOv2, device: torch.device) -> None:
self.model = model
self.device = device
@staticmethod
def load_model(
model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small"
) -> DPT_DINOv2:
match model_size:
case "small":
model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
case "base":
model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
case "large":
model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu"))
model.eval()
model.to(device)
return model
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
if not self.model:
logger.warn("DepthAnything model was not loaded. Returning original image")
return image
np_image = np.array(image, dtype=np.uint8)
np_image = np_image[:, :, ::-1] / 255.0
image_height, image_width = np_image.shape[:2]
np_image = transform({"image": np_image})["image"]
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)
with torch.no_grad():
depth = self.model(tensor_image)
depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
depth_map = Image.fromarray(depth_map)
new_height = int(image_height * (resolution / image_width))
depth_map = depth_map.resize((resolution, new_height))
return depth_map

View File

@@ -0,0 +1,31 @@
from typing import Optional
import torch
from PIL import Image
from transformers.pipelines import DepthEstimationPipeline
from invokeai.backend.raw_model import RawModel
class DepthAnythingPipeline(RawModel):
"""Custom wrapper for the Depth Estimation pipeline from transformers adding compatibility
for Invoke's Model Management System"""
def __init__(self, pipeline: DepthEstimationPipeline) -> None:
self._pipeline = pipeline
def generate_depth(self, image: Image.Image) -> Image.Image:
depth_map = self._pipeline(image)["depth"]
assert isinstance(depth_map, Image.Image)
return depth_map
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
if device is not None and device.type not in {"cpu", "cuda"}:
device = None
self._pipeline.model.to(device=device, dtype=dtype)
self._pipeline.device = self._pipeline.model.device
def calc_size(self) -> int:
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self._pipeline.model)

View File

@@ -1,145 +0,0 @@
import torch.nn as nn
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
if len(in_shape) >= 4:
out_shape4 = out_shape
if expand:
out_shape1 = out_shape
out_shape2 = out_shape * 2
out_shape3 = out_shape * 4
if len(in_shape) >= 4:
out_shape4 = out_shape * 8
scratch.layer1_rn = nn.Conv2d(
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
if len(in_shape) >= 4:
scratch.layer4_rn = nn.Conv2d(
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
return scratch
class ResidualConvUnit(nn.Module):
"""Residual convolution module."""
def __init__(self, features, activation, bn):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.bn = bn
self.groups = 1
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
if self.bn:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.bn:
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
if self.bn:
out = self.bn2(out)
if self.groups > 1:
out = self.conv_merge(out)
return self.skip_add.add(out, x)
class FeatureFusionBlock(nn.Module):
"""Feature fusion block."""
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()
self.deconv = deconv
self.align_corners = align_corners
self.groups = 1
self.expand = expand
out_features = features
if self.expand:
out_features = features // 2
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
self.skip_add = nn.quantized.FloatFunctional()
self.size = size
def forward(self, *xs, size=None):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
output = self.resConfUnit2(output)
if (size is None) and (self.size is None):
modifier = {"scale_factor": 2}
elif size is None:
modifier = {"size": self.size}
else:
modifier = {"size": size}
output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
output = self.out_conv(output)
return output

View File

@@ -1,183 +0,0 @@
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from invokeai.backend.image_util.depth_anything.model.blocks import FeatureFusionBlock, _make_scratch
torchhub_path = Path(__file__).parent.parent / "torchhub"
def _make_fusion_block(features, use_bn, size=None):
return FeatureFusionBlock(
features,
nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
size=size,
)
class DPTHead(nn.Module):
def __init__(self, nclass, in_channels, features, out_channels, use_bn=False, use_clstoken=False):
super(DPTHead, self).__init__()
self.nclass = nclass
self.use_clstoken = use_clstoken
self.projects = nn.ModuleList(
[
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channel,
kernel_size=1,
stride=1,
padding=0,
)
for out_channel in out_channels
]
)
self.resize_layers = nn.ModuleList(
[
nn.ConvTranspose2d(
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
),
nn.ConvTranspose2d(
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
),
nn.Identity(),
nn.Conv2d(
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
),
]
)
if use_clstoken:
self.readout_projects = nn.ModuleList()
for _ in range(len(self.projects)):
self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))
self.scratch = _make_scratch(
out_channels,
features,
groups=1,
expand=False,
)
self.scratch.stem_transpose = None
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
head_features_1 = features
head_features_2 = 32
if nclass > 1:
self.scratch.output_conv = nn.Sequential(
nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
)
else:
self.scratch.output_conv1 = nn.Conv2d(
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
)
self.scratch.output_conv2 = nn.Sequential(
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True),
nn.Identity(),
)
def forward(self, out_features, patch_h, patch_w):
out = []
for i, x in enumerate(out_features):
if self.use_clstoken:
x, cls_token = x[0], x[1]
readout = cls_token.unsqueeze(1).expand_as(x)
x = self.readout_projects[i](torch.cat((x, readout), -1))
else:
x = x[0]
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
x = self.projects[i](x)
x = self.resize_layers[i](x)
out.append(x)
layer_1, layer_2, layer_3, layer_4 = out
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv1(path_1)
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
out = self.scratch.output_conv2(out)
return out
class DPT_DINOv2(nn.Module):
def __init__(
self,
features,
out_channels,
encoder="vitl",
use_bn=False,
use_clstoken=False,
):
super(DPT_DINOv2, self).__init__()
assert encoder in ["vits", "vitb", "vitl"]
# # in case the Internet connection is not stable, please load the DINOv2 locally
# if use_local:
# self.pretrained = torch.hub.load(
# torchhub_path / "facebookresearch_dinov2_main",
# "dinov2_{:}14".format(encoder),
# source="local",
# pretrained=False,
# )
# else:
# self.pretrained = torch.hub.load(
# "facebookresearch/dinov2",
# "dinov2_{:}14".format(encoder),
# )
self.pretrained = torch.hub.load(
"facebookresearch/dinov2",
"dinov2_{:}14".format(encoder),
)
dim = self.pretrained.blocks[0].attn.qkv.in_features
self.depth_head = DPTHead(1, dim, features, out_channels=out_channels, use_bn=use_bn, use_clstoken=use_clstoken)
def forward(self, x):
h, w = x.shape[-2:]
features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
patch_h, patch_w = h // 14, w // 14
depth = self.depth_head(features, patch_h, patch_w)
depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
depth = F.relu(depth)
return depth.squeeze(1)

View File

@@ -1,227 +0,0 @@
import math
import cv2
import numpy as np
import torch
import torch.nn.functional as F
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
Args:
sample (dict): sample
size (tuple): image size
Returns:
tuple: new size
"""
shape = list(sample["disparity"].shape)
if shape[0] >= size[0] and shape[1] >= size[1]:
return sample
scale = [0, 0]
scale[0] = size[0] / shape[0]
scale[1] = size[1] / shape[1]
scale = max(scale)
shape[0] = math.ceil(scale * shape[0])
shape[1] = math.ceil(scale * shape[1])
# resize
sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method)
sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST)
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
tuple(shape[::-1]),
interpolation=cv2.INTER_NEAREST,
)
sample["mask"] = sample["mask"].astype(bool)
return tuple(shape)
class Resize(object):
"""Resize sample to given size (width, height)."""
def __init__(
self,
width,
height,
resize_target=True,
keep_aspect_ratio=False,
ensure_multiple_of=1,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_AREA,
):
"""Init.
Args:
width (int): desired output width
height (int): desired output height
resize_target (bool, optional):
True: Resize the full sample (image, mask, target).
False: Resize image only.
Defaults to True.
keep_aspect_ratio (bool, optional):
True: Keep the aspect ratio of the input sample.
Output sample might not have the given width and height, and
resize behaviour depends on the parameter 'resize_method'.
Defaults to False.
ensure_multiple_of (int, optional):
Output width and height is constrained to be multiple of this parameter.
Defaults to 1.
resize_method (str, optional):
"lower_bound": Output will be at least as large as the given size.
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller
than given size.)
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
Defaults to "lower_bound".
"""
self.__width = width
self.__height = height
self.__resize_target = resize_target
self.__keep_aspect_ratio = keep_aspect_ratio
self.__multiple_of = ensure_multiple_of
self.__resize_method = resize_method
self.__image_interpolation_method = image_interpolation_method
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
if max_val is not None and y > max_val:
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
if y < min_val:
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
return y
def get_size(self, width, height):
# determine new height and width
scale_height = self.__height / height
scale_width = self.__width / width
if self.__keep_aspect_ratio:
if self.__resize_method == "lower_bound":
# scale such that output size is lower bound
if scale_width > scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "upper_bound":
# scale such that output size is upper bound
if scale_width < scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "minimal":
# scale as least as possbile
if abs(1 - scale_width) < abs(1 - scale_height):
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
if self.__resize_method == "lower_bound":
new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
elif self.__resize_method == "upper_bound":
new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
elif self.__resize_method == "minimal":
new_height = self.constrain_to_multiple_of(scale_height * height)
new_width = self.constrain_to_multiple_of(scale_width * width)
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
return (new_width, new_height)
def __call__(self, sample):
width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
# resize sample
sample["image"] = cv2.resize(
sample["image"],
(width, height),
interpolation=self.__image_interpolation_method,
)
if self.__resize_target:
if "disparity" in sample:
sample["disparity"] = cv2.resize(
sample["disparity"],
(width, height),
interpolation=cv2.INTER_NEAREST,
)
if "depth" in sample:
sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
if "semseg_mask" in sample:
# sample["semseg_mask"] = cv2.resize(
# sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
# )
sample["semseg_mask"] = F.interpolate(
torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode="nearest"
).numpy()[0, 0]
if "mask" in sample:
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
(width, height),
interpolation=cv2.INTER_NEAREST,
)
# sample["mask"] = sample["mask"].astype(bool)
# print(sample['image'].shape, sample['depth'].shape)
return sample
class NormalizeImage(object):
"""Normlize image by given mean and std."""
def __init__(self, mean, std):
self.__mean = mean
self.__std = std
def __call__(self, sample):
sample["image"] = (sample["image"] - self.__mean) / self.__std
return sample
class PrepareForNet(object):
"""Prepare sample for usage as network input."""
def __init__(self):
pass
def __call__(self, sample):
image = np.transpose(sample["image"], (2, 0, 1))
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
if "mask" in sample:
sample["mask"] = sample["mask"].astype(np.float32)
sample["mask"] = np.ascontiguousarray(sample["mask"])
if "depth" in sample:
depth = sample["depth"].astype(np.float32)
sample["depth"] = np.ascontiguousarray(depth)
if "semseg_mask" in sample:
sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])
return sample

View File

@@ -0,0 +1,22 @@
from pydantic import BaseModel, ConfigDict
class BoundingBox(BaseModel):
"""Bounding box helper class."""
xmin: int
ymin: int
xmax: int
ymax: int
class DetectionResult(BaseModel):
"""Detection result from Grounding DINO."""
score: float
label: str
box: BoundingBox
model_config = ConfigDict(
# Allow arbitrary types for mask, since it will be a numpy array.
arbitrary_types_allowed=True
)

View File

@@ -0,0 +1,37 @@
from typing import Optional
import torch
from PIL import Image
from transformers.pipelines import ZeroShotObjectDetectionPipeline
from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
from invokeai.backend.raw_model import RawModel
class GroundingDinoPipeline(RawModel):
"""A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory
management system.
"""
def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
self._pipeline = pipeline
def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]:
results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
assert results is not None
results = [DetectionResult.model_validate(result) for result in results]
return results
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
# HACK(ryand): The GroundingDinoPipeline does not work on MPS devices. We only allow it to be moved to CPU or
# CUDA.
if device is not None and device.type not in {"cpu", "cuda"}:
device = None
self._pipeline.model.to(device=device, dtype=dtype)
self._pipeline.device = self._pipeline.model.device
def calc_size(self) -> int:
# HACK(ryand): Fix the circular import issue.
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self._pipeline.model)

View File

@@ -1,43 +0,0 @@
batch_size = 1
modelname = "groundingdino"
backbone = "swin_B_384_22k"
position_embedding = "sine"
pe_temperatureH = 20
pe_temperatureW = 20
return_interm_indices = [1, 2, 3]
backbone_freeze_keywords = None
enc_layers = 6
dec_layers = 6
pre_norm = False
dim_feedforward = 2048
hidden_dim = 256
dropout = 0.0
nheads = 8
num_queries = 900
query_dim = 4
num_patterns = 0
num_feature_levels = 4
enc_n_points = 4
dec_n_points = 4
two_stage_type = "standard"
two_stage_bbox_embed_share = False
two_stage_class_embed_share = False
transformer_activation = "relu"
dec_pred_bbox_embed_share = True
dn_box_noise_scale = 1.0
dn_label_noise_ratio = 0.5
dn_label_coef = 1.0
dn_bbox_coef = 1.0
embed_init_tgt = True
dn_labelbook_size = 2000
max_text_len = 256
text_encoder_type = "bert-base-uncased"
use_text_enhancer = True
use_fusion_layer = True
use_checkpoint = True
use_transformer_ckpt = True
use_text_cross_attention = True
text_dropout = 0.0
fusion_dropout = 0.0
fusion_droppath = 0.1
sub_sentence_present = True

View File

@@ -1,43 +0,0 @@
batch_size = 1
modelname = "groundingdino"
backbone = "swin_T_224_1k"
position_embedding = "sine"
pe_temperatureH = 20
pe_temperatureW = 20
return_interm_indices = [1, 2, 3]
backbone_freeze_keywords = None
enc_layers = 6
dec_layers = 6
pre_norm = False
dim_feedforward = 2048
hidden_dim = 256
dropout = 0.0
nheads = 8
num_queries = 900
query_dim = 4
num_patterns = 0
num_feature_levels = 4
enc_n_points = 4
dec_n_points = 4
two_stage_type = "standard"
two_stage_bbox_embed_share = False
two_stage_class_embed_share = False
transformer_activation = "relu"
dec_pred_bbox_embed_share = True
dn_box_noise_scale = 1.0
dn_label_noise_ratio = 0.5
dn_label_coef = 1.0
dn_bbox_coef = 1.0
embed_init_tgt = True
dn_labelbook_size = 2000
max_text_len = 256
text_encoder_type = "bert-base-uncased"
use_text_enhancer = True
use_fusion_layer = True
use_checkpoint = True
use_transformer_ckpt = True
use_text_cross_attention = True
text_dropout = 0.0
fusion_dropout = 0.0
fusion_droppath = 0.1
sub_sentence_present = True

View File

@@ -1,299 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Transforms and data augmentation for both image + bbox.
"""
import os
import random
import PIL
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.box_ops import box_xyxy_to_cxcywh
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import interpolate
def crop(image, target, region):
cropped_image = F.crop(image, *region)
target = target.copy()
i, j, h, w = region
# should we do something wrt the original size?
target["size"] = torch.tensor([h, w])
fields = ["labels", "area", "iscrowd", "positive_map"]
if "boxes" in target:
boxes = target["boxes"]
max_size = torch.as_tensor([w, h], dtype=torch.float32)
cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
cropped_boxes = cropped_boxes.clamp(min=0)
area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
target["boxes"] = cropped_boxes.reshape(-1, 4)
target["area"] = area
fields.append("boxes")
if "masks" in target:
# FIXME should we update the area here if there are no boxes?
target["masks"] = target["masks"][:, i : i + h, j : j + w]
fields.append("masks")
# remove elements for which the boxes or masks that have zero area
if "boxes" in target or "masks" in target:
# favor boxes selection when defining which elements to keep
# this is compatible with previous implementation
if "boxes" in target:
cropped_boxes = target["boxes"].reshape(-1, 2, 2)
keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
else:
keep = target["masks"].flatten(1).any(1)
for field in fields:
if field in target:
target[field] = target[field][keep]
if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO":
# for debug and visualization only.
if "strings_positive" in target:
target["strings_positive"] = [_i for _i, _j in zip(target["strings_positive"], keep, strict=False) if _j]
return cropped_image, target
def hflip(image, target):
flipped_image = F.hflip(image)
w, h = image.size
target = target.copy()
if "boxes" in target:
boxes = target["boxes"]
boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
target["boxes"] = boxes
if "masks" in target:
target["masks"] = target["masks"].flip(-1)
return flipped_image, target
def resize(image, target, size, max_size=None):
# size can be min_size (scalar) or (w, h) tuple
def get_size_with_aspect_ratio(image_size, size, max_size=None):
w, h = image_size
if max_size is not None:
min_original_size = float(min((w, h)))
max_original_size = float(max((w, h)))
if max_original_size / min_original_size * size > max_size:
size = int(round(max_size * min_original_size / max_original_size))
if (w <= h and w == size) or (h <= w and h == size):
return (h, w)
if w < h:
ow = size
oh = int(size * h / w)
else:
oh = size
ow = int(size * w / h)
return (oh, ow)
def get_size(image_size, size, max_size=None):
if isinstance(size, (list, tuple)):
return size[::-1]
else:
return get_size_with_aspect_ratio(image_size, size, max_size)
size = get_size(image.size, size, max_size)
rescaled_image = F.resize(image, size)
if target is None:
return rescaled_image, None
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size, strict=False))
ratio_width, ratio_height = ratios
target = target.copy()
if "boxes" in target:
boxes = target["boxes"]
scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
target["boxes"] = scaled_boxes
if "area" in target:
area = target["area"]
scaled_area = area * (ratio_width * ratio_height)
target["area"] = scaled_area
h, w = size
target["size"] = torch.tensor([h, w])
if "masks" in target:
target["masks"] = interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
return rescaled_image, target
def pad(image, target, padding):
# assumes that we only pad on the bottom right corners
padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
if target is None:
return padded_image, None
target = target.copy()
# should we do something wrt the original size?
target["size"] = torch.tensor(padded_image.size[::-1])
if "masks" in target:
target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1]))
return padded_image, target
class ResizeDebug(object):
def __init__(self, size):
self.size = size
def __call__(self, img, target):
return resize(img, target, self.size)
class RandomCrop(object):
def __init__(self, size):
self.size = size
def __call__(self, img, target):
region = T.RandomCrop.get_params(img, self.size)
return crop(img, target, region)
class RandomSizeCrop(object):
def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False):
# respect_boxes: True to keep all boxes
# False to tolerence box filter
self.min_size = min_size
self.max_size = max_size
self.respect_boxes = respect_boxes
def __call__(self, img: PIL.Image.Image, target: dict):
init_boxes = len(target["boxes"])
max_patience = 10
for i in range(max_patience):
w = random.randint(self.min_size, min(img.width, self.max_size))
h = random.randint(self.min_size, min(img.height, self.max_size))
region = T.RandomCrop.get_params(img, [h, w])
result_img, result_target = crop(img, target, region)
if not self.respect_boxes or len(result_target["boxes"]) == init_boxes or i == max_patience - 1:
return result_img, result_target
return result_img, result_target
class CenterCrop(object):
def __init__(self, size):
self.size = size
def __call__(self, img, target):
image_width, image_height = img.size
crop_height, crop_width = self.size
crop_top = int(round((image_height - crop_height) / 2.0))
crop_left = int(round((image_width - crop_width) / 2.0))
return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
class RandomHorizontalFlip(object):
def __init__(self, p=0.5):
self.p = p
def __call__(self, img, target):
if random.random() < self.p:
return hflip(img, target)
return img, target
class RandomResize(object):
def __init__(self, sizes, max_size=None):
assert isinstance(sizes, (list, tuple))
self.sizes = sizes
self.max_size = max_size
def __call__(self, img, target=None):
size = random.choice(self.sizes)
return resize(img, target, size, self.max_size)
class RandomPad(object):
def __init__(self, max_pad):
self.max_pad = max_pad
def __call__(self, img, target):
pad_x = random.randint(0, self.max_pad)
pad_y = random.randint(0, self.max_pad)
return pad(img, target, (pad_x, pad_y))
class RandomSelect(object):
"""
Randomly selects between transforms1 and transforms2,
with probability p for transforms1 and (1 - p) for transforms2
"""
def __init__(self, transforms1, transforms2, p=0.5):
self.transforms1 = transforms1
self.transforms2 = transforms2
self.p = p
def __call__(self, img, target):
if random.random() < self.p:
return self.transforms1(img, target)
return self.transforms2(img, target)
class ToTensor(object):
def __call__(self, img, target):
return F.to_tensor(img), target
class RandomErasing(object):
def __init__(self, *args, **kwargs):
self.eraser = T.RandomErasing(*args, **kwargs)
def __call__(self, img, target):
return self.eraser(img), target
class Normalize(object):
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, image, target=None):
image = F.normalize(image, mean=self.mean, std=self.std)
if target is None:
return image, None
target = target.copy()
h, w = image.shape[-2:]
if "boxes" in target:
boxes = target["boxes"]
boxes = box_xyxy_to_cxcywh(boxes)
boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
target["boxes"] = boxes
return image, target
class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, image, target):
for t in self.transforms:
image, target = t(image, target)
return image, target
def __repr__(self):
format_string = self.__class__.__name__ + "("
for t in self.transforms:
format_string += "\n"
format_string += " {0}".format(t)
format_string += "\n)"
return format_string

View File

@@ -1,17 +0,0 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copied from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.groundingdino import (
build_groundingdino,
)

View File

@@ -1,217 +0,0 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copied from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
"""
Backbone modules.
"""
from typing import Dict, List
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.backbone.position_encoding import (
build_position_encoding,
)
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.backbone.swin_transformer import (
build_swin_transformer,
)
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor, is_main_process
class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
without which any other models than torchvision.models.resnet[18,34,50,101]
produce nans.
"""
def __init__(self, n):
super(FrozenBatchNorm2d, self).__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = 1e-5
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
class BackboneBase(nn.Module):
def __init__(
self,
backbone: nn.Module,
train_backbone: bool,
num_channels: int,
return_interm_indices: list,
):
super().__init__()
for name, parameter in backbone.named_parameters():
if not train_backbone or "layer2" not in name and "layer3" not in name and "layer4" not in name:
parameter.requires_grad_(False)
return_layers = {}
for idx, layer_index in enumerate(return_interm_indices):
return_layers.update({"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)})
# if len:
# if use_stage1_feature:
# return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
# else:
# return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
# else:
# return_layers = {'layer4': "0"}
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.num_channels = num_channels
def forward(self, tensor_list: NestedTensor):
xs = self.body(tensor_list.tensors)
out: Dict[str, NestedTensor] = {}
for name, x in xs.items():
m = tensor_list.mask
assert m is not None
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
out[name] = NestedTensor(x, mask)
# import ipdb; ipdb.set_trace()
return out
class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(
self,
name: str,
train_backbone: bool,
dilation: bool,
return_interm_indices: list,
batch_norm=FrozenBatchNorm2d,
):
if name in ["resnet18", "resnet34", "resnet50", "resnet101"]:
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(),
norm_layer=batch_norm,
)
else:
raise NotImplementedError("Why you can get here with name {}".format(name))
# num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available."
assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
num_channels_all = [256, 512, 1024, 2048]
num_channels = num_channels_all[4 - len(return_interm_indices) :]
super().__init__(backbone, train_backbone, num_channels, return_interm_indices)
class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for name, x in xs.items():
out.append(x)
# position encoding
pos.append(self[1](x).to(x.tensors.dtype))
return out, pos
def build_backbone(args):
"""
Useful args:
- backbone: backbone name
- lr_backbone:
- dilation
- return_interm_indices: available: [0,1,2,3], [1,2,3], [3]
- backbone_freeze_keywords:
- use_checkpoint: for swin only for now
"""
position_embedding = build_position_encoding(args)
train_backbone = True
if not train_backbone:
raise ValueError("Please set lr_backbone > 0")
return_interm_indices = args.return_interm_indices
assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
args.backbone_freeze_keywords
use_checkpoint = getattr(args, "use_checkpoint", False)
if args.backbone in ["resnet50", "resnet101"]:
backbone = Backbone(
args.backbone,
train_backbone,
args.dilation,
return_interm_indices,
batch_norm=FrozenBatchNorm2d,
)
bb_num_channels = backbone.num_channels
elif args.backbone in [
"swin_T_224_1k",
"swin_B_224_22k",
"swin_B_384_22k",
"swin_L_224_22k",
"swin_L_384_22k",
]:
pretrain_img_size = int(args.backbone.split("_")[-2])
backbone = build_swin_transformer(
args.backbone,
pretrain_img_size=pretrain_img_size,
out_indices=tuple(return_interm_indices),
dilation=False,
use_checkpoint=use_checkpoint,
)
bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
else:
raise NotImplementedError("Unknown backbone {}".format(args.backbone))
assert len(bb_num_channels) == len(
return_interm_indices
), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"
model = Joiner(backbone, position_embedding)
model.num_channels = bb_num_channels
assert isinstance(bb_num_channels, List), "bb_num_channels is expected to be a List but {}".format(
type(bb_num_channels)
)
# import ipdb; ipdb.set_trace()
return model

View File

@@ -1,176 +0,0 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# DINO
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copied from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
mask = tensor_list.mask
assert mask is not None
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
# if os.environ.get("SHILONG_AMP", None) == '1':
# eps = 1e-4
# else:
# eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
class PositionEmbeddingSineHW(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperatureH = temperatureH
self.temperatureW = temperatureW
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
mask = tensor_list.mask
assert mask is not None
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
# import ipdb; ipdb.set_trace()
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode="floor")) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_tx
dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode="floor")) / self.num_pos_feats)
pos_y = y_embed[:, :, :, None] / dim_ty
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
# import ipdb; ipdb.set_trace()
return pos
class PositionEmbeddingLearned(nn.Module):
"""
Absolute pos embedding, learned.
"""
def __init__(self, num_pos_feats=256):
super().__init__()
self.row_embed = nn.Embedding(50, num_pos_feats)
self.col_embed = nn.Embedding(50, num_pos_feats)
self.reset_parameters()
def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
h, w = x.shape[-2:]
i = torch.arange(w, device=x.device)
j = torch.arange(h, device=x.device)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
pos = (
torch.cat(
[
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
],
dim=-1,
)
.permute(2, 0, 1)
.unsqueeze(0)
.repeat(x.shape[0], 1, 1, 1)
)
return pos
def build_position_encoding(args):
N_steps = args.hidden_dim // 2
if args.position_embedding in ("v2", "sine"):
# TODO find a better way of exposing other arguments
position_embedding = PositionEmbeddingSineHW(
N_steps,
temperatureH=args.pe_temperatureH,
temperatureW=args.pe_temperatureW,
normalize=True,
)
elif args.position_embedding in ("v3", "learned"):
position_embedding = PositionEmbeddingLearned(N_steps)
else:
raise ValueError(f"not supported {args.position_embedding}")
return position_embedding

View File

@@ -1,766 +0,0 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# DINO
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# --------------------------------------------------------
# modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
# --------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor
class Mlp(nn.Module):
"""Multilayer perceptron."""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
"""Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(
self,
dim,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=0.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""Forward function.
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class SwinTransformerBlock(nn.Module):
"""Swin Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(
self,
dim,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.H = None
self.W = None
def forward(self, x, mask_matrix):
"""Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
mask_matrix: Attention mask for cyclic shift.
"""
B, L, C = x.shape
H, W = self.H, self.W
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# pad feature maps to multiples of window size
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
attn_mask = mask_matrix
else:
shifted_x = x
attn_mask = None
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchMerging(nn.Module):
"""Patch Merging Layer
Args:
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x, H, W):
"""Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
class BasicLayer(nn.Module):
"""A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of feature channels
depth (int): Depths of this stage.
num_heads (int): Number of attention head.
window_size (int): Local window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(
self,
dim,
depth,
num_heads,
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False,
):
super().__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList(
[
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
)
for i in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, H, W):
"""Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device, dtype=x.dtype) # 1 Hp Wp 1
h_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
w_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
for blk in self.blocks:
blk.H, blk.W = H, W
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, attn_mask)
else:
x = blk(x, attn_mask)
if self.downsample is not None:
x_down = self.downsample(x, H, W)
Wh, Ww = (H + 1) // 2, (W + 1) // 2
return x, H, W, x_down, Wh, Ww
else:
return x, H, W, x, H, W
class PatchEmbed(nn.Module):
"""Image to Patch Embedding
Args:
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, H, W = x.size()
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
x = self.proj(x) # B C Wh Ww
if self.norm is not None:
Wh, Ww = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
return x
class SwinTransformer(nn.Module):
"""Swin Transformer backbone.
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
pretrain_img_size (int): Input image size for training the pretrained model,
used in absolute postion embedding. Default 224.
patch_size (int | tuple(int)): Patch size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
depths (tuple[int]): Depths of each Swin Transformer stage.
num_heads (tuple[int]): Number of attention head of each stage.
window_size (int): Window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
drop_rate (float): Dropout rate.
attn_drop_rate (float): Attention dropout rate. Default: 0.
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
out_indices (Sequence[int]): Output from which stages.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
dilation (bool): if True, the output size if 16x downsample, ow 32x downsample.
"""
def __init__(
self,
pretrain_img_size=224,
patch_size=4,
in_chans=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.2,
norm_layer=nn.LayerNorm,
ape=False,
patch_norm=True,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
dilation=False,
use_checkpoint=False,
):
super().__init__()
self.pretrain_img_size = pretrain_img_size
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.dilation = dilation
# if use_checkpoint:
# print("use_checkpoint!!!!!!!!!!!!!!!!!!!!!!!!")
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None,
)
# absolute position embedding
if self.ape:
pretrain_img_size = to_2tuple(pretrain_img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [
pretrain_img_size[0] // patch_size[0],
pretrain_img_size[1] // patch_size[1],
]
self.absolute_pos_embed = nn.Parameter(
torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
)
trunc_normal_(self.absolute_pos_embed, std=0.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
# prepare downsample list
downsamplelist = [PatchMerging for i in range(self.num_layers)]
downsamplelist[-1] = None
num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
if self.dilation:
downsamplelist[-2] = None
num_features[-1] = int(embed_dim * 2 ** (self.num_layers - 1)) // 2
for i_layer in range(self.num_layers):
layer = BasicLayer(
# dim=int(embed_dim * 2 ** i_layer),
dim=num_features[i_layer],
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
norm_layer=norm_layer,
# downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
downsample=downsamplelist[i_layer],
use_checkpoint=use_checkpoint,
)
self.layers.append(layer)
# num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
self.num_features = num_features
# add a norm layer for each output
for i_layer in out_indices:
layer = norm_layer(num_features[i_layer])
layer_name = f"norm{i_layer}"
self.add_module(layer_name, layer)
self._freeze_stages()
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
if self.frozen_stages >= 1 and self.ape:
self.absolute_pos_embed.requires_grad = False
if self.frozen_stages >= 2:
self.pos_drop.eval()
for i in range(0, self.frozen_stages - 1):
m = self.layers[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
# def init_weights(self, pretrained=None):
# """Initialize the weights in backbone.
# Args:
# pretrained (str, optional): Path to pre-trained weights.
# Defaults to None.
# """
# def _init_weights(m):
# if isinstance(m, nn.Linear):
# trunc_normal_(m.weight, std=.02)
# if isinstance(m, nn.Linear) and m.bias is not None:
# nn.init.constant_(m.bias, 0)
# elif isinstance(m, nn.LayerNorm):
# nn.init.constant_(m.bias, 0)
# nn.init.constant_(m.weight, 1.0)
# if isinstance(pretrained, str):
# self.apply(_init_weights)
# logger = get_root_logger()
# load_checkpoint(self, pretrained, strict=False, logger=logger)
# elif pretrained is None:
# self.apply(_init_weights)
# else:
# raise TypeError('pretrained must be a str or None')
def forward_raw(self, x):
"""Forward function."""
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
if self.ape:
# interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic")
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
else:
x = x.flatten(2).transpose(1, 2)
x = self.pos_drop(x)
outs = []
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
# import ipdb; ipdb.set_trace()
if i in self.out_indices:
norm_layer = getattr(self, f"norm{i}")
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
outs.append(out)
# in:
# torch.Size([2, 3, 1024, 1024])
# outs:
# [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
# torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
return tuple(outs)
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
"""Forward function."""
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
if self.ape:
# interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic")
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
else:
x = x.flatten(2).transpose(1, 2)
x = self.pos_drop(x)
outs = []
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
if i in self.out_indices:
norm_layer = getattr(self, f"norm{i}")
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
outs.append(out)
# in:
# torch.Size([2, 3, 1024, 1024])
# out:
# [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
# torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
# collect for nesttensors
outs_dict = {}
for idx, out_i in enumerate(outs):
m = tensor_list.mask
assert m is not None
mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[0]
outs_dict[idx] = NestedTensor(out_i, mask)
return outs_dict
def train(self, mode=True):
"""Convert the model into training mode while keep layers freezed."""
super(SwinTransformer, self).train(mode)
self._freeze_stages()
def build_swin_transformer(modelname, pretrain_img_size, **kw):
assert modelname in [
"swin_T_224_1k",
"swin_B_224_22k",
"swin_B_384_22k",
"swin_L_224_22k",
"swin_L_384_22k",
]
model_para_dict = {
"swin_T_224_1k": dict(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7),
"swin_B_224_22k": dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7),
"swin_B_384_22k": dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12),
"swin_L_224_22k": dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7),
"swin_L_384_22k": dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12),
}
kw_cgf = model_para_dict[modelname]
kw_cgf.update(kw)
model = SwinTransformer(pretrain_img_size=pretrain_img_size, **kw_cgf)
return model
if __name__ == "__main__":
model = build_swin_transformer("swin_L_384_22k", 384, dilation=True)
x = torch.rand(2, 3, 1024, 1024)
y = model.forward_raw(x)
import ipdb
ipdb.set_trace()
x = torch.rand(2, 3, 384, 384)
y = model.forward_raw(x)

View File

@@ -1,250 +0,0 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
import torch
from torch import nn
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
class BertModelWarper(nn.Module):
def __init__(self, bert_model):
super().__init__()
# self.bert = bert_modelc
self.config = bert_model.config
self.embeddings = bert_model.embeddings
self.encoder = bert_model.encoder
self.pooler = bert_model.pooler
self.get_extended_attention_mask = bert_model.get_extended_attention_mask
self.invert_attention_mask = bert_model.invert_attention_mask
self.get_head_mask = bert_model.get_head_mask
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
batch_size, seq_length = input_shape
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size, seq_length = input_shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
# import ipdb; ipdb.set_trace()
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
class TextEncoderShell(nn.Module):
def __init__(self, text_encoder):
super().__init__()
self.text_encoder = text_encoder
self.config = self.text_encoder.config
def forward(self, **kw):
# feed into text encoder
return self.text_encoder(**kw)
def generate_masks_with_special_tokens(tokenized, special_tokens_list, tokenizer):
"""Generate attention mask between each pair of special tokens
Args:
input_ids (torch.Tensor): input ids. Shape: [bs, num_token]
special_tokens_mask (list): special tokens mask.
Returns:
torch.Tensor: attention mask between each special tokens.
"""
input_ids = tokenized["input_ids"]
bs, num_token = input_ids.shape
# special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
for special_token in special_tokens_list:
special_tokens_mask |= input_ids == special_token
# idxs: each row is a list of indices of special tokens
idxs = torch.nonzero(special_tokens_mask)
# generate attention mask and positional ids
attention_mask = torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1)
position_ids = torch.zeros((bs, num_token), device=input_ids.device)
previous_col = 0
for i in range(idxs.shape[0]):
row, col = idxs[i]
if (col == 0) or (col == num_token - 1):
attention_mask[row, col, col] = True
position_ids[row, col] = 0
else:
attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
position_ids[row, previous_col + 1 : col + 1] = torch.arange(0, col - previous_col, device=input_ids.device)
previous_col = col
# # padding mask
# padding_mask = tokenized['attention_mask']
# attention_mask = attention_mask & padding_mask.unsqueeze(1).bool() & padding_mask.unsqueeze(2).bool()
return attention_mask, position_ids.to(torch.long)
def generate_masks_with_special_tokens_and_transfer_map(tokenized, special_tokens_list, tokenizer):
"""Generate attention mask between each pair of special tokens
Args:
input_ids (torch.Tensor): input ids. Shape: [bs, num_token]
special_tokens_mask (list): special tokens mask.
Returns:
torch.Tensor: attention mask between each special tokens.
"""
input_ids = tokenized["input_ids"]
bs, num_token = input_ids.shape
# special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
for special_token in special_tokens_list:
special_tokens_mask |= input_ids == special_token
# idxs: each row is a list of indices of special tokens
idxs = torch.nonzero(special_tokens_mask)
# generate attention mask and positional ids
attention_mask = torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1)
position_ids = torch.zeros((bs, num_token), device=input_ids.device)
cate_to_token_mask_list = [[] for _ in range(bs)]
previous_col = 0
for i in range(idxs.shape[0]):
row, col = idxs[i]
if (col == 0) or (col == num_token - 1):
attention_mask[row, col, col] = True
position_ids[row, col] = 0
else:
attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
position_ids[row, previous_col + 1 : col + 1] = torch.arange(0, col - previous_col, device=input_ids.device)
c2t_maski = torch.zeros((num_token), device=input_ids.device).bool()
c2t_maski[previous_col + 1 : col] = True
cate_to_token_mask_list[row].append(c2t_maski)
previous_col = col
cate_to_token_mask_list = [
torch.stack(cate_to_token_mask_listi, dim=0) for cate_to_token_mask_listi in cate_to_token_mask_list
]
# # padding mask
# padding_mask = tokenized['attention_mask']
# attention_mask = attention_mask & padding_mask.unsqueeze(1).bool() & padding_mask.unsqueeze(2).bool()
return attention_mask, position_ids.to(torch.long), cate_to_token_mask_list

View File

@@ -1,295 +0,0 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath
class FeatureResizer(nn.Module):
"""
This class takes as input a set of embeddings of dimension C1 and outputs a set of
embedding of dimension C2, after a linear transformation, dropout and normalization (LN).
"""
def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True):
super().__init__()
self.do_ln = do_ln
# Object feature encoding
self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True)
self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12)
self.dropout = nn.Dropout(dropout)
def forward(self, encoder_features):
x = self.fc(encoder_features)
if self.do_ln:
x = self.layer_norm(x)
output = self.dropout(x)
return output
def l1norm(X, dim, eps=1e-8):
"""L1-normalize columns of X"""
norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
X = torch.div(X, norm)
return X
def l2norm(X, dim, eps=1e-8):
"""L2-normalize columns of X"""
norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
X = torch.div(X, norm)
return X
def func_attention(query, context, smooth=1, raw_feature_norm="softmax", eps=1e-8):
"""
query: (n_context, queryL, d)
context: (n_context, sourceL, d)
"""
_, queryL = query.size(0), query.size(1)
batch_size, sourceL = context.size(0), context.size(1)
# Get attention
# --> (batch, d, queryL)
queryT = torch.transpose(query, 1, 2)
# (batch, sourceL, d)(batch, d, queryL)
# --> (batch, sourceL, queryL)
attn = torch.bmm(context, queryT)
if raw_feature_norm == "softmax":
# --> (batch*sourceL, queryL)
attn = attn.view(batch_size * sourceL, queryL)
attn = nn.Softmax()(attn)
# --> (batch, sourceL, queryL)
attn = attn.view(batch_size, sourceL, queryL)
elif raw_feature_norm == "l2norm":
attn = l2norm(attn, 2)
elif raw_feature_norm == "clipped_l2norm":
attn = nn.LeakyReLU(0.1)(attn)
attn = l2norm(attn, 2)
else:
raise ValueError("unknown first norm type:", raw_feature_norm)
# --> (batch, queryL, sourceL)
attn = torch.transpose(attn, 1, 2).contiguous()
# --> (batch*queryL, sourceL)
attn = attn.view(batch_size * queryL, sourceL)
attn = nn.Softmax()(attn * smooth)
# --> (batch, queryL, sourceL)
attn = attn.view(batch_size, queryL, sourceL)
# --> (batch, sourceL, queryL)
attnT = torch.transpose(attn, 1, 2).contiguous()
# --> (batch, d, sourceL)
contextT = torch.transpose(context, 1, 2)
# (batch x d x sourceL)(batch x sourceL x queryL)
# --> (batch, d, queryL)
weightedContext = torch.bmm(contextT, attnT)
# --> (batch, queryL, d)
weightedContext = torch.transpose(weightedContext, 1, 2)
return weightedContext, attnT
class BiMultiHeadAttention(nn.Module):
def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None):
super(BiMultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.v_dim = v_dim
self.l_dim = l_dim
assert (
self.head_dim * self.num_heads == self.embed_dim
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and \
`num_heads`: {self.num_heads})."
self.scale = self.head_dim ** (-0.5)
self.dropout = dropout
self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)
self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)
self.stable_softmax_2d = True
self.clamp_min_for_underflow = True
self.clamp_max_for_overflow = True
self._reset_parameters()
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def _reset_parameters(self):
nn.init.xavier_uniform_(self.v_proj.weight)
self.v_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.l_proj.weight)
self.l_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.values_v_proj.weight)
self.values_v_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.values_l_proj.weight)
self.values_l_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.out_v_proj.weight)
self.out_v_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.out_l_proj.weight)
self.out_l_proj.bias.data.fill_(0)
def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
"""_summary_
Args:
v (_type_): bs, n_img, dim
l (_type_): bs, n_text, dim
attention_mask_v (_type_, optional): _description_. bs, n_img
attention_mask_l (_type_, optional): _description_. bs, n_text
Returns:
_type_: _description_
"""
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
# import ipdb; ipdb.set_trace()
bsz, tgt_len, _ = v.size()
query_states = self.v_proj(v) * self.scale
key_states = self._shape(self.l_proj(l), -1, bsz)
value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
value_l_states = self._shape(self.values_l_proj(l), -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_v_states = value_v_states.view(*proj_shape)
value_l_states = value_l_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) # bs*nhead, nimg, ntxt
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, \
but is {attn_weights.size()}"
)
if self.stable_softmax_2d:
attn_weights = attn_weights - attn_weights.max()
if self.clamp_min_for_underflow:
attn_weights = torch.clamp(
attn_weights, min=-50000
) # Do not increase -50000, data type half has quite limited range
if self.clamp_max_for_overflow:
attn_weights = torch.clamp(
attn_weights, max=50000
) # Do not increase 50000, data type half has quite limited range
attn_weights_T = attn_weights.transpose(1, 2)
attn_weights_l = attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[0]
if self.clamp_min_for_underflow:
attn_weights_l = torch.clamp(
attn_weights_l, min=-50000
) # Do not increase -50000, data type half has quite limited range
if self.clamp_max_for_overflow:
attn_weights_l = torch.clamp(
attn_weights_l, max=50000
) # Do not increase 50000, data type half has quite limited range
# mask vison for language
if attention_mask_v is not None:
attention_mask_v = attention_mask_v[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
attn_weights_l.masked_fill_(attention_mask_v, float("-inf"))
attn_weights_l = attn_weights_l.softmax(dim=-1)
# mask language for vision
if attention_mask_l is not None:
attention_mask_l = attention_mask_l[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
attn_weights.masked_fill_(attention_mask_l, float("-inf"))
attn_weights_v = attn_weights.softmax(dim=-1)
attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)
attn_output_v = torch.bmm(attn_probs_v, value_l_states)
attn_output_l = torch.bmm(attn_probs_l, value_v_states)
if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, \
but is {attn_output_v.size()}"
)
if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
raise ValueError(
f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, \
but is {attn_output_l.size()}"
)
attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output_v = attn_output_v.transpose(1, 2)
attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)
attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
attn_output_l = attn_output_l.transpose(1, 2)
attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)
attn_output_v = self.out_v_proj(attn_output_v)
attn_output_l = self.out_l_proj(attn_output_l)
return attn_output_v, attn_output_l
# Bi-Direction MHA (text->image, image->text)
class BiAttentionBlock(nn.Module):
def __init__(
self,
v_dim,
l_dim,
embed_dim,
num_heads,
dropout=0.1,
drop_path=0.0,
init_values=1e-4,
cfg=None,
):
"""
Inputs:
embed_dim - Dimensionality of input and attention feature vectors
hidden_dim - Dimensionality of hidden layer in feed-forward network
(usually 2-4x larger than embed_dim)
num_heads - Number of heads to use in the Multi-Head Attention block
dropout - Amount of dropout to apply in the feed-forward network
"""
super(BiAttentionBlock, self).__init__()
# pre layer norm
self.layer_norm_v = nn.LayerNorm(v_dim)
self.layer_norm_l = nn.LayerNorm(l_dim)
self.attn = BiMultiHeadAttention(
v_dim=v_dim, l_dim=l_dim, embed_dim=embed_dim, num_heads=num_heads, dropout=dropout
)
# add layer scale for training stability
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True)
self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True)
def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
v = self.layer_norm_v(v)
l = self.layer_norm_l(l)
delta_v, delta_l = self.attn(v, l, attention_mask_v=attention_mask_v, attention_mask_l=attention_mask_l)
# v, l = v + delta_v, l + delta_l
v = v + self.drop_path(self.gamma_v * delta_v)
l = l + self.drop_path(self.gamma_l * delta_l)
return v, l
# def forward(self, v:List[torch.Tensor], l, attention_mask_v=None, attention_mask_l=None)

View File

@@ -1,362 +0,0 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR model and criterion classes.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------
import copy
from typing import List
import torch
import torch.nn.functional as F
from torch import nn
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util import get_tokenlizer
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import (
NestedTensor,
inverse_sigmoid,
nested_tensor_from_tensor_list,
)
from ..registry import MODULE_BUILD_FUNCS
from .backbone import build_backbone
from .bertwarper import BertModelWarper, generate_masks_with_special_tokens_and_transfer_map
from .transformer import build_transformer
from .utils import MLP, ContrastiveEmbed
class GroundingDINO(nn.Module):
"""This is the Cross-Attention Detector module that performs object detection"""
def __init__(
self,
backbone,
transformer,
num_queries,
aux_loss=False,
iter_update=False,
query_dim=2,
num_feature_levels=1,
nheads=8,
# two stage
two_stage_type="no", # ['no', 'standard']
dec_pred_bbox_embed_share=True,
two_stage_class_embed_share=True,
two_stage_bbox_embed_share=True,
num_patterns=0,
dn_number=100,
dn_box_noise_scale=0.4,
dn_label_noise_ratio=0.5,
dn_labelbook_size=100,
text_encoder_type="bert-base-uncased",
sub_sentence_present=True,
max_text_len=256,
):
"""Initializes the model.
Parameters:
backbone: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.num_queries = num_queries
self.transformer = transformer
self.hidden_dim = hidden_dim = transformer.d_model
self.num_feature_levels = num_feature_levels
self.nheads = nheads
self.max_text_len = 256
self.sub_sentence_present = sub_sentence_present
# setting query dim
self.query_dim = query_dim
assert query_dim == 4
# for dn training
self.num_patterns = num_patterns
self.dn_number = dn_number
self.dn_box_noise_scale = dn_box_noise_scale
self.dn_label_noise_ratio = dn_label_noise_ratio
self.dn_labelbook_size = dn_labelbook_size
# bert
self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type)
self.bert = get_tokenlizer.get_pretrained_language_model(text_encoder_type)
self.bert.pooler.dense.weight.requires_grad_(False)
self.bert.pooler.dense.bias.requires_grad_(False)
self.bert = BertModelWarper(bert_model=self.bert)
self.feat_map = nn.Linear(self.bert.config.hidden_size, self.hidden_dim, bias=True)
nn.init.constant_(self.feat_map.bias.data, 0)
nn.init.xavier_uniform_(self.feat_map.weight.data)
# freeze
# special tokens
self.specical_tokens = self.tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])
# prepare input projection layers
if num_feature_levels > 1:
num_backbone_outs = len(backbone.num_channels)
input_proj_list = []
for _ in range(num_backbone_outs):
in_channels = backbone.num_channels[_]
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
nn.GroupNorm(32, hidden_dim),
)
)
for _ in range(num_feature_levels - num_backbone_outs):
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(32, hidden_dim),
)
)
in_channels = hidden_dim
self.input_proj = nn.ModuleList(input_proj_list)
else:
assert two_stage_type == "no", "two_stage_type should be no if num_feature_levels=1 !!!"
self.input_proj = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
nn.GroupNorm(32, hidden_dim),
)
]
)
self.backbone = backbone
self.aux_loss = aux_loss
self.box_pred_damping = None
self.iter_update = iter_update
assert iter_update, "Why not iter_update?"
# prepare pred layers
self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
# prepare class & box embed
_class_embed = ContrastiveEmbed()
_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
if dec_pred_bbox_embed_share:
box_embed_layerlist = [_bbox_embed for i in range(transformer.num_decoder_layers)]
else:
box_embed_layerlist = [copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers)]
class_embed_layerlist = [_class_embed for i in range(transformer.num_decoder_layers)]
self.bbox_embed = nn.ModuleList(box_embed_layerlist)
self.class_embed = nn.ModuleList(class_embed_layerlist)
self.transformer.decoder.bbox_embed = self.bbox_embed
self.transformer.decoder.class_embed = self.class_embed
# two stage
self.two_stage_type = two_stage_type
assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(two_stage_type)
if two_stage_type != "no":
if two_stage_bbox_embed_share:
assert dec_pred_bbox_embed_share
self.transformer.enc_out_bbox_embed = _bbox_embed
else:
self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
if two_stage_class_embed_share:
assert dec_pred_bbox_embed_share
self.transformer.enc_out_class_embed = _class_embed
else:
self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)
self.refpoint_embed = None
self._reset_parameters()
def _reset_parameters(self):
# init input_proj
for proj in self.input_proj:
nn.init.xavier_uniform_(proj[0].weight, gain=1)
nn.init.constant_(proj[0].bias, 0)
def init_ref_points(self, use_num_queries):
self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)
def forward(self, samples: NestedTensor, targets: List = None, **kw):
"""The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
It returns a dict with the following elements:
- "pred_logits": the classification logits (including no-object) for all queries.
Shape= [batch_size x num_queries x num_classes]
- "pred_boxes": The normalized boxes coordinates for all queries, represented as
(center_x, center_y, width, height). These values are normalized in [0, 1],
relative to the size of each individual image (disregarding possible padding).
See PostProcess for information on how to retrieve the unnormalized bounding box.
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
dictionnaries containing the two above keys for each decoder layer.
"""
if targets is None:
captions = kw["captions"]
else:
captions = [t["caption"] for t in targets]
len(captions)
# encoder texts
tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(samples.device)
(
text_self_attention_masks,
position_ids,
cate_to_token_mask_list,
) = generate_masks_with_special_tokens_and_transfer_map(tokenized, self.specical_tokens, self.tokenizer)
if text_self_attention_masks.shape[1] > self.max_text_len:
text_self_attention_masks = text_self_attention_masks[:, : self.max_text_len, : self.max_text_len]
position_ids = position_ids[:, : self.max_text_len]
tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]
# extract text embeddings
if self.sub_sentence_present:
tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
tokenized_for_encoder["attention_mask"] = text_self_attention_masks
tokenized_for_encoder["position_ids"] = position_ids
else:
# import ipdb; ipdb.set_trace()
tokenized_for_encoder = tokenized
bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768
encoded_text = self.feat_map(bert_output["last_hidden_state"]) # bs, 195, d_model
text_token_mask = tokenized.attention_mask.bool() # bs, 195
# text_token_mask: True for nomask, False for mask
# text_self_attention_masks: True for nomask, False for mask
if encoded_text.shape[1] > self.max_text_len:
encoded_text = encoded_text[:, : self.max_text_len, :]
text_token_mask = text_token_mask[:, : self.max_text_len]
position_ids = position_ids[:, : self.max_text_len]
text_self_attention_masks = text_self_attention_masks[:, : self.max_text_len, : self.max_text_len]
text_dict = {
"encoded_text": encoded_text, # bs, 195, d_model
"text_token_mask": text_token_mask, # bs, 195
"position_ids": position_ids, # bs, 195
"text_self_attention_masks": text_self_attention_masks, # bs, 195,195
}
# import ipdb; ipdb.set_trace()
if isinstance(samples, (list, torch.Tensor)):
samples = nested_tensor_from_tensor_list(samples)
features, poss = self.backbone(samples)
srcs = []
masks = []
for l, feat in enumerate(features):
src, mask = feat.decompose()
srcs.append(self.input_proj[l](src))
masks.append(mask)
assert mask is not None
if self.num_feature_levels > len(srcs):
_len_srcs = len(srcs)
for l in range(_len_srcs, self.num_feature_levels):
if l == _len_srcs:
src = self.input_proj[l](features[-1].tensors)
else:
src = self.input_proj[l](srcs[-1])
m = samples.mask
mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
srcs.append(src)
masks.append(mask)
poss.append(pos_l)
input_query_bbox = input_query_label = attn_mask = None
hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, text_dict
)
# deformable-detr-like anchor update
outputs_coord_list = []
for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(zip(reference[:-1], self.bbox_embed, hs)):
layer_delta_unsig = layer_bbox_embed(layer_hs)
layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
layer_outputs_unsig = layer_outputs_unsig.sigmoid()
outputs_coord_list.append(layer_outputs_unsig)
outputs_coord_list = torch.stack(outputs_coord_list)
# output
outputs_class = torch.stack(
[layer_cls_embed(layer_hs, text_dict) for layer_cls_embed, layer_hs in zip(self.class_embed, hs)]
)
out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord_list[-1]}
# # for intermediate outputs
# if self.aux_loss:
# out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord_list)
# # for encoder output
# if hs_enc is not None:
# # prepare intermediate outputs
# interm_coord = ref_enc[-1]
# interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
# out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
# out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
return out
@torch.jit.unused
def _set_aux_loss(self, outputs_class, outputs_coord):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return [{"pred_logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
@MODULE_BUILD_FUNCS.registe_with_name(module_name="groundingdino")
def build_groundingdino(args):
backbone = build_backbone(args)
transformer = build_transformer(args)
dn_labelbook_size = args.dn_labelbook_size
dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share
sub_sentence_present = args.sub_sentence_present
model = GroundingDINO(
backbone,
transformer,
num_queries=args.num_queries,
aux_loss=True,
iter_update=True,
query_dim=4,
num_feature_levels=args.num_feature_levels,
nheads=args.nheads,
dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
two_stage_type=args.two_stage_type,
two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
two_stage_class_embed_share=args.two_stage_class_embed_share,
num_patterns=args.num_patterns,
dn_number=0,
dn_box_noise_scale=args.dn_box_noise_scale,
dn_label_noise_ratio=args.dn_label_noise_ratio,
dn_labelbook_size=dn_labelbook_size,
text_encoder_type=args.text_encoder_type,
sub_sentence_present=sub_sentence_present,
max_text_len=args.max_text_len,
)
return model

View File

@@ -1,340 +0,0 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/functions/ms_deform_attn_func.py
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/multi_scale_deform_attn.py
# ------------------------------------------------------------------------------------------------
import math
import warnings
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import constant_, xavier_uniform_
# helpers
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
return (n & (n - 1) == 0) and n != 0
def multi_scale_deformable_attn_pytorch(
value: torch.Tensor,
value_spatial_shapes: torch.Tensor,
sampling_locations: torch.Tensor,
attention_weights: torch.Tensor,
) -> torch.Tensor:
bs, _, num_heads, embed_dims = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level, (H_, W_) in enumerate(value_spatial_shapes):
# bs, H_*W_, num_heads, embed_dims ->
# bs, H_*W_, num_heads*embed_dims ->
# bs, num_heads*embed_dims, H_*W_ ->
# bs*num_heads, embed_dims, H_, W_
value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
# bs, num_queries, num_heads, num_points, 2 ->
# bs, num_heads, num_queries, num_points, 2 ->
# bs*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
# bs*num_heads, embed_dims, num_queries, num_points
sampling_value_l_ = F.grid_sample(
value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
)
sampling_value_list.append(sampling_value_l_)
# (bs, num_queries, num_heads, num_levels, num_points) ->
# (bs, num_heads, num_queries, num_levels, num_points) ->
# (bs, num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(
bs * num_heads, 1, num_queries, num_levels * num_points
)
output = (
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
.sum(-1)
.view(bs, num_heads * embed_dims, num_queries)
)
return output.transpose(1, 2).contiguous()
class MultiScaleDeformableAttention(nn.Module):
"""Multi-Scale Deformable Attention Module used in Deformable-DETR
`Deformable DETR: Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dim (int): The embedding dimension of Attention. Default: 256.
num_heads (int): The number of attention heads. Default: 8.
num_levels (int): The number of feature map used in Attention. Default: 4.
num_points (int): The number of sampling points for each query
in each head. Default: 4.
img2col_steps (int): The step used in image_to_column. Defualt: 64.
dropout (float): Dropout layer used in output. Default: 0.1.
batch_first (bool): if ``True``, then the input and output tensor will be
provided as `(bs, n, embed_dim)`. Default: False. `(n, bs, embed_dim)`
"""
def __init__(
self,
embed_dim: int = 256,
num_heads: int = 8,
num_levels: int = 4,
num_points: int = 4,
img2col_step: int = 64,
batch_first: bool = False,
):
super().__init__()
if embed_dim % num_heads != 0:
raise ValueError("embed_dim must be divisible by num_heads, but got {} and {}".format(embed_dim, num_heads))
head_dim = embed_dim // num_heads
self.batch_first = batch_first
if not _is_power_of_2(head_dim):
warnings.warn(
"""
You'd better set d_model in MSDeformAttn to make sure that
each dim of the attention head a power of 2, which is more efficient.
"""
)
self.im2col_step = img2col_step
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_levels = num_levels
self.num_points = num_points
self.sampling_offsets = nn.Linear(embed_dim, num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dim, num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.output_proj = nn.Linear(embed_dim, embed_dim)
self.init_weights()
def _reset_parameters(self):
return self.init_weights()
def init_weights(self):
"""
Default initialization for Parameters of Module.
"""
constant_(self.sampling_offsets.weight.data, 0.0)
thetas = torch.arange(self.num_heads, dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
.view(self.num_heads, 1, 1, 2)
.repeat(1, self.num_levels, self.num_points, 1)
)
for i in range(self.num_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.0)
constant_(self.attention_weights.bias.data, 0.0)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.0)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.0)
def freeze_sampling_offsets(self):
print("Freeze sampling offsets")
self.sampling_offsets.weight.requires_grad = False
self.sampling_offsets.bias.requires_grad = False
def freeze_attention_weights(self):
print("Freeze attention weights")
self.attention_weights.weight.requires_grad = False
self.attention_weights.bias.requires_grad = False
def forward(
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
query_pos: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
reference_points: Optional[torch.Tensor] = None,
spatial_shapes: Optional[torch.Tensor] = None,
level_start_index: Optional[torch.Tensor] = None,
**kwargs
) -> torch.Tensor:
"""Forward Function of MultiScaleDeformableAttention
Args:
query (torch.Tensor): Query embeddings with shape
`(num_query, bs, embed_dim)`
key (torch.Tensor): Key embeddings with shape
`(num_key, bs, embed_dim)`
value (torch.Tensor): Value embeddings with shape
`(num_key, bs, embed_dim)`
query_pos (torch.Tensor): The position embedding for `query`. Default: None.
key_padding_mask (torch.Tensor): ByteTensor for `query`, with shape `(bs, num_key)`,
indicating which elements within `key` to be ignored in attention.
reference_points (torch.Tensor): The normalized reference points
with shape `(bs, num_query, num_levels, 2)`,
all elements is range in [0, 1], top-left (0, 0),
bottom-right (1, 1), including padding are.
or `(N, Length_{query}, num_levels, 4)`, add additional
two dimensions `(h, w)` to form reference boxes.
spatial_shapes (torch.Tensor): Spatial shape of features in different levels.
With shape `(num_levels, 2)`, last dimension represents `(h, w)`.
level_start_index (torch.Tensor): The start index of each level. A tensor with
shape `(num_levels, )` which can be represented as
`[0, h_0 * w_0, h_0 * w_0 + h_1 * w_1, ...]`.
Returns:
torch.Tensor: forward results with shape `(num_query, bs, embed_dim)`
"""
if value is None:
value = query
if query_pos is not None:
query = query + query_pos
if not self.batch_first:
# change to (bs, num_query ,embed_dims)
query = query.permute(1, 0, 2)
value = value.permute(1, 0, 2)
bs, num_query, _ = query.shape
bs, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
value = self.value_proj(value)
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], float(0))
value = value.view(bs, num_value, self.num_heads, -1)
sampling_offsets = self.sampling_offsets(query).view(
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2
)
attention_weights = self.attention_weights(query).view(
bs, num_query, self.num_heads, self.num_levels * self.num_points
)
attention_weights = attention_weights.softmax(-1)
attention_weights = attention_weights.view(
bs,
num_query,
self.num_heads,
self.num_levels,
self.num_points,
)
# bs, num_query, num_heads, num_levels, num_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = (
reference_points[:, :, None, :, None, :]
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
)
elif reference_points.shape[-1] == 4:
sampling_locations = (
reference_points[:, :, None, :, None, :2]
+ sampling_offsets / self.num_points * reference_points[:, :, None, :, None, 2:] * 0.5
)
else:
raise ValueError(
"Last dim of reference_points must be 2 or 4, but get {} instead.".format(reference_points.shape[-1])
)
# if torch.cuda.is_available() and value.is_cuda:
# halffloat = False
# if value.dtype == torch.float16:
# halffloat = True
# value = value.float()
# sampling_locations = sampling_locations.float()
# attention_weights = attention_weights.float()
# output = MultiScaleDeformableAttnFunction.apply(
# value,
# spatial_shapes,
# level_start_index,
# sampling_locations,
# attention_weights,
# self.im2col_step,
# )
# if halffloat:
# output = output.half()
# else:
# output = multi_scale_deformable_attn_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
output = multi_scale_deformable_attn_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
output = self.output_proj(output)
if not self.batch_first:
output = output.permute(1, 0, 2)
return output
def create_dummy_class(klass, dependency, message=""):
"""
When a dependency of a class is not available, create a dummy class which throws ImportError
when used.
Args:
klass (str): name of the class.
dependency (str): name of the dependency.
message: extra message to print
Returns:
class: a class object
"""
err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, klass)
if message:
err = err + " " + message
class _DummyMetaClass(type):
# throw error on class attribute access
def __getattr__(_, __): # noqa: B902
raise ImportError(err)
class _Dummy(object, metaclass=_DummyMetaClass):
# throw error on constructor
def __init__(self, *args, **kwargs):
raise ImportError(err)
return _Dummy
def create_dummy_func(func, dependency, message=""):
"""
When a dependency of a function is not available, create a dummy function which throws
ImportError when used.
Args:
func (str): name of the function.
dependency (str or list[str]): name(s) of the dependency.
message: extra message to print
Returns:
function: a function object
"""
err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, func)
if message:
err = err + " " + message
if isinstance(dependency, (list, tuple)):
dependency = ",".join(dependency)
def _dummy(*args, **kwargs):
raise ImportError(err)
return _dummy

View File

@@ -1,927 +0,0 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# DINO
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR Transformer class.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
from typing import Optional
import torch
import torch.utils.checkpoint as checkpoint
from torch import Tensor, nn
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import inverse_sigmoid
from .fuse_modules import BiAttentionBlock
from .ms_deform_attn import MultiScaleDeformableAttention as MSDeformAttn
from .transformer_vanilla import TransformerEncoderLayer
from .utils import (
MLP,
_get_activation_fn,
_get_clones,
gen_encoder_output_proposals,
gen_sineembed_for_position,
get_sine_pos_embed,
)
class Transformer(nn.Module):
def __init__(
self,
d_model=256,
nhead=8,
num_queries=300,
num_encoder_layers=6,
num_unicoder_layers=0,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.0,
activation="relu",
normalize_before=False,
return_intermediate_dec=False,
query_dim=4,
num_patterns=0,
# for deformable encoder
num_feature_levels=1,
enc_n_points=4,
dec_n_points=4,
# init query
learnable_tgt_init=False,
# two stage
two_stage_type="no", # ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1']
embed_init_tgt=False,
# for text
use_text_enhancer=False,
use_fusion_layer=False,
use_checkpoint=False,
use_transformer_ckpt=False,
use_text_cross_attention=False,
text_dropout=0.1,
fusion_dropout=0.1,
fusion_droppath=0.0,
):
super().__init__()
self.num_feature_levels = num_feature_levels
self.num_encoder_layers = num_encoder_layers
self.num_unicoder_layers = num_unicoder_layers
self.num_decoder_layers = num_decoder_layers
self.num_queries = num_queries
assert query_dim == 4
# choose encoder layer type
encoder_layer = DeformableTransformerEncoderLayer(
d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points
)
if use_text_enhancer:
text_enhance_layer = TransformerEncoderLayer(
d_model=d_model,
nhead=nhead // 2,
dim_feedforward=dim_feedforward // 2,
dropout=text_dropout,
)
else:
text_enhance_layer = None
if use_fusion_layer:
feature_fusion_layer = BiAttentionBlock(
v_dim=d_model,
l_dim=d_model,
embed_dim=dim_feedforward // 2,
num_heads=nhead // 2,
dropout=fusion_dropout,
drop_path=fusion_droppath,
)
else:
feature_fusion_layer = None
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
assert encoder_norm is None
self.encoder = TransformerEncoder(
encoder_layer,
num_encoder_layers,
d_model=d_model,
num_queries=num_queries,
text_enhance_layer=text_enhance_layer,
feature_fusion_layer=feature_fusion_layer,
use_checkpoint=use_checkpoint,
use_transformer_ckpt=use_transformer_ckpt,
)
# choose decoder layer type
decoder_layer = DeformableTransformerDecoderLayer(
d_model,
dim_feedforward,
dropout,
activation,
num_feature_levels,
nhead,
dec_n_points,
use_text_cross_attention=use_text_cross_attention,
)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = TransformerDecoder(
decoder_layer,
num_decoder_layers,
decoder_norm,
return_intermediate=return_intermediate_dec,
d_model=d_model,
query_dim=query_dim,
num_feature_levels=num_feature_levels,
)
self.d_model = d_model
self.nhead = nhead
self.dec_layers = num_decoder_layers
self.num_queries = num_queries # useful for single stage model only
self.num_patterns = num_patterns
if not isinstance(num_patterns, int):
Warning("num_patterns should be int but {}".format(type(num_patterns)))
self.num_patterns = 0
if num_feature_levels > 1:
if self.num_encoder_layers > 0:
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
else:
self.level_embed = None
self.learnable_tgt_init = learnable_tgt_init
assert learnable_tgt_init, "why not learnable_tgt_init"
self.embed_init_tgt = embed_init_tgt
if (two_stage_type != "no" and embed_init_tgt) or (two_stage_type == "no"):
self.tgt_embed = nn.Embedding(self.num_queries, d_model)
nn.init.normal_(self.tgt_embed.weight.data)
else:
self.tgt_embed = None
# for two stage
self.two_stage_type = two_stage_type
assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(two_stage_type)
if two_stage_type == "standard":
# anchor selection at the output of encoder
self.enc_output = nn.Linear(d_model, d_model)
self.enc_output_norm = nn.LayerNorm(d_model)
self.two_stage_wh_embedding = None
if two_stage_type == "no":
self.init_ref_points(num_queries) # init self.refpoint_embed
self.enc_out_class_embed = None
self.enc_out_bbox_embed = None
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if isinstance(m, MSDeformAttn):
m._reset_parameters()
if self.num_feature_levels > 1 and self.level_embed is not None:
nn.init.normal_(self.level_embed)
def get_valid_ratio(self, mask):
_, H, W = mask.shape
valid_H = torch.sum(~mask[:, :, 0], 1)
valid_W = torch.sum(~mask[:, 0, :], 1)
valid_ratio_h = valid_H.float() / H
valid_ratio_w = valid_W.float() / W
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
return valid_ratio
def init_ref_points(self, use_num_queries):
self.refpoint_embed = nn.Embedding(use_num_queries, 4)
def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None, text_dict=None):
"""
Input:
- srcs: List of multi features [bs, ci, hi, wi]
- masks: List of multi masks [bs, hi, wi]
- refpoint_embed: [bs, num_dn, 4]. None in infer
- pos_embeds: List of multi pos embeds [bs, ci, hi, wi]
- tgt: [bs, num_dn, d_model]. None in infer
"""
# prepare input for encoder
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
bs, c, h, w = src.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
src = src.flatten(2).transpose(1, 2) # bs, hw, c
mask = mask.flatten(1) # bs, hw
pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
if self.num_feature_levels > 1 and self.level_embed is not None:
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
else:
lvl_pos_embed = pos_embed
lvl_pos_embed_flatten.append(lvl_pos_embed)
src_flatten.append(src)
mask_flatten.append(mask)
src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1).to(src.dtype)
# two stage
# enc_topk_proposals = enc_refpoint_embed = None
#########################################################
# Begin Encoder
#########################################################
memory, memory_text = self.encoder(
src_flatten,
pos=lvl_pos_embed_flatten,
level_start_index=level_start_index,
spatial_shapes=spatial_shapes,
valid_ratios=valid_ratios,
key_padding_mask=mask_flatten,
memory_text=text_dict["encoded_text"],
text_attention_mask=~text_dict["text_token_mask"],
# we ~ the mask . False means use the token; True means pad the token
position_ids=text_dict["position_ids"],
text_self_attention_masks=text_dict["text_self_attention_masks"],
)
#########################################################
# End Encoder
# - memory: bs, \sum{hw}, c
# - mask_flatten: bs, \sum{hw}
# - lvl_pos_embed_flatten: bs, \sum{hw}, c
# - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
# - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
#########################################################
text_dict["encoded_text"] = memory_text
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
# if memory.isnan().any() | memory.isinf().any():
# import ipdb; ipdb.set_trace()
if self.two_stage_type == "standard":
output_memory, output_proposals = gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)
output_memory = self.enc_output_norm(self.enc_output(output_memory))
if text_dict is not None:
enc_outputs_class_unselected = self.enc_out_class_embed(output_memory, text_dict)
else:
enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
topk_logits = enc_outputs_class_unselected.max(-1)[0]
enc_outputs_coord_unselected = (
self.enc_out_bbox_embed(output_memory) + output_proposals
) # (bs, \sum{hw}, 4) unsigmoid
topk = self.num_queries
topk_proposals = torch.topk(topk_logits, topk, dim=1)[1] # bs, nq
# gather boxes
refpoint_embed_undetach = torch.gather(
enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
) # unsigmoid
refpoint_embed_ = refpoint_embed_undetach.detach()
init_box_proposal = torch.gather(
output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
).sigmoid() # sigmoid
# gather tgt
tgt_undetach = torch.gather(output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model))
if self.embed_init_tgt:
tgt_ = self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) # nq, bs, d_model
else:
tgt_ = tgt_undetach.detach()
if refpoint_embed is not None:
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
tgt = torch.cat([tgt, tgt_], dim=1)
else:
refpoint_embed, tgt = refpoint_embed_, tgt_
elif self.two_stage_type == "no":
tgt_ = self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) # nq, bs, d_model
refpoint_embed_ = self.refpoint_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) # nq, bs, 4
if refpoint_embed is not None:
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
tgt = torch.cat([tgt, tgt_], dim=1)
else:
refpoint_embed, tgt = refpoint_embed_, tgt_
if self.num_patterns > 0:
tgt_embed = tgt.repeat(1, self.num_patterns, 1)
refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(
self.num_queries, 1
) # 1, n_q*n_pat, d_model
tgt = tgt_embed + tgt_pat
init_box_proposal = refpoint_embed_.sigmoid()
else:
raise NotImplementedError("unknown two_stage_type {}".format(self.two_stage_type))
#########################################################
# End preparing tgt
# - tgt: bs, NQ, d_model
# - refpoint_embed(unsigmoid): bs, NQ, d_model
#########################################################
#########################################################
# Begin Decoder
#########################################################
hs, references = self.decoder(
tgt=tgt.transpose(0, 1),
memory=memory.transpose(0, 1),
memory_key_padding_mask=mask_flatten,
pos=lvl_pos_embed_flatten.transpose(0, 1),
refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
level_start_index=level_start_index,
spatial_shapes=spatial_shapes,
valid_ratios=valid_ratios,
tgt_mask=attn_mask,
memory_text=text_dict["encoded_text"],
text_attention_mask=~text_dict["text_token_mask"],
# we ~ the mask . False means use the token; True means pad the token
)
#########################################################
# End Decoder
# hs: n_dec, bs, nq, d_model
# references: n_dec+1, bs, nq, query_dim
#########################################################
#########################################################
# Begin postprocess
#########################################################
if self.two_stage_type == "standard":
hs_enc = tgt_undetach.unsqueeze(0)
ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
else:
hs_enc = ref_enc = None
#########################################################
# End postprocess
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or (n_enc, bs, nq, d_model) or None
# ref_enc: (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or (n_enc, bs, nq, d_model) or None
#########################################################
return hs, references, hs_enc, ref_enc, init_box_proposal
# hs: (n_dec, bs, nq, d_model)
# references: sigmoid coordinates. (n_dec+1, bs, bq, 4)
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or None
# ref_enc: sigmoid coordinates. \
# (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or None
class TransformerEncoder(nn.Module):
def __init__(
self,
encoder_layer,
num_layers,
d_model=256,
num_queries=300,
enc_layer_share=False,
text_enhance_layer=None,
feature_fusion_layer=None,
use_checkpoint=False,
use_transformer_ckpt=False,
):
"""_summary_
Args:
encoder_layer (_type_): _description_
num_layers (_type_): _description_
norm (_type_, optional): _description_. Defaults to None.
d_model (int, optional): _description_. Defaults to 256.
num_queries (int, optional): _description_. Defaults to 300.
enc_layer_share (bool, optional): _description_. Defaults to False.
"""
super().__init__()
# prepare layers
self.layers = []
self.text_layers = []
self.fusion_layers = []
if num_layers > 0:
self.layers = _get_clones(encoder_layer, num_layers, layer_share=enc_layer_share)
if text_enhance_layer is not None:
self.text_layers = _get_clones(text_enhance_layer, num_layers, layer_share=enc_layer_share)
if feature_fusion_layer is not None:
self.fusion_layers = _get_clones(feature_fusion_layer, num_layers, layer_share=enc_layer_share)
else:
self.layers = []
del encoder_layer
if text_enhance_layer is not None:
self.text_layers = []
del text_enhance_layer
if feature_fusion_layer is not None:
self.fusion_layers = []
del feature_fusion_layer
self.query_scale = None
self.num_queries = num_queries
self.num_layers = num_layers
self.d_model = d_model
self.use_checkpoint = use_checkpoint
self.use_transformer_ckpt = use_transformer_ckpt
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
reference_points_list = []
for lvl, (H_, W_) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
)
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
def forward(
self,
# for images
src: Tensor,
pos: Tensor,
spatial_shapes: Tensor,
level_start_index: Tensor,
valid_ratios: Tensor,
key_padding_mask: Tensor,
# for texts
memory_text: Tensor = None,
text_attention_mask: Tensor = None,
pos_text: Tensor = None,
text_self_attention_masks: Tensor = None,
position_ids: Tensor = None,
):
"""
Input:
- src: [bs, sum(hi*wi), 256]
- pos: pos embed for src. [bs, sum(hi*wi), 256]
- spatial_shapes: h,w of each level [num_level, 2]
- level_start_index: [num_level] start point of level in sum(hi*wi).
- valid_ratios: [bs, num_level, 2]
- key_padding_mask: [bs, sum(hi*wi)]
- memory_text: bs, n_text, 256
- text_attention_mask: bs, n_text
False for no padding; True for padding
- pos_text: bs, n_text, 256
- position_ids: bs, n_text
Intermedia:
- reference_points: [bs, sum(hi*wi), num_level, 2]
Outpus:
- output: [bs, sum(hi*wi), 256]
"""
output = src
# preparation and reshape
if self.num_layers > 0:
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
if self.text_layers:
# generate pos_text
bs, n_text, text_dim = memory_text.shape
if pos_text is None and position_ids is None:
pos_text = (
torch.arange(n_text, device=memory_text.device).float().unsqueeze(0).unsqueeze(-1).repeat(bs, 1, 1)
)
pos_text = get_sine_pos_embed(pos_text, num_pos_feats=256, exchange_xy=False)
if position_ids is not None:
pos_text = get_sine_pos_embed(position_ids[..., None], num_pos_feats=256, exchange_xy=False)
pos_text = pos_text.to(src.dtype)
# main process
for layer_id, layer in enumerate(self.layers):
# if output.isnan().any() or memory_text.isnan().any():
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
# import ipdb; ipdb.set_trace()
if self.fusion_layers:
if self.use_checkpoint:
output, memory_text = checkpoint.checkpoint(
self.fusion_layers[layer_id],
output,
memory_text,
key_padding_mask,
text_attention_mask,
)
else:
output, memory_text = self.fusion_layers[layer_id](
v=output,
l=memory_text,
attention_mask_v=key_padding_mask,
attention_mask_l=text_attention_mask,
)
if self.text_layers:
memory_text = self.text_layers[layer_id](
src=memory_text.transpose(0, 1),
src_mask=~text_self_attention_masks, # note we use ~ for mask here
src_key_padding_mask=text_attention_mask,
pos=(pos_text.transpose(0, 1) if pos_text is not None else None),
).transpose(0, 1)
# main process
if self.use_transformer_ckpt:
output = checkpoint.checkpoint(
layer,
output,
pos,
reference_points,
spatial_shapes,
level_start_index,
key_padding_mask,
)
else:
output = layer(
src=output,
pos=pos,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
key_padding_mask=key_padding_mask,
)
return output, memory_text
class TransformerDecoder(nn.Module):
def __init__(
self,
decoder_layer,
num_layers,
norm=None,
return_intermediate=False,
d_model=256,
query_dim=4,
num_feature_levels=1,
):
super().__init__()
if num_layers > 0:
self.layers = _get_clones(decoder_layer, num_layers)
else:
self.layers = []
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
assert return_intermediate, "support return_intermediate only"
self.query_dim = query_dim
assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
self.num_feature_levels = num_feature_levels
self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
self.query_pos_sine_scale = None
self.query_scale = None
self.bbox_embed = None
self.class_embed = None
self.d_model = d_model
self.ref_anchor_head = None
def forward(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2
# for memory
level_start_index: Optional[Tensor] = None, # num_levels
spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
valid_ratios: Optional[Tensor] = None,
# for text
memory_text: Optional[Tensor] = None,
text_attention_mask: Optional[Tensor] = None,
):
"""
Input:
- tgt: nq, bs, d_model
- memory: hw, bs, d_model
- pos: hw, bs, d_model
- refpoints_unsigmoid: nq, bs, 2/4
- valid_ratios/spatial_shapes: bs, nlevel, 2
"""
output = tgt
intermediate = []
reference_points = refpoints_unsigmoid.sigmoid()
ref_points = [reference_points]
for layer_id, layer in enumerate(self.layers):
if reference_points.shape[-1] == 4:
reference_points_input = (
reference_points[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
) # nq, bs, nlevel, 4
else:
assert reference_points.shape[-1] == 2
reference_points_input = reference_points[:, :, None] * valid_ratios[None, :]
query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :]) # nq, bs, 256*2
# conditional query
raw_query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256
pos_scale = self.query_scale(output) if self.query_scale is not None else 1
query_pos = pos_scale * raw_query_pos
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
# if query_pos.isnan().any() | query_pos.isinf().any():
# import ipdb; ipdb.set_trace()
# main process
output = layer(
tgt=output,
tgt_query_pos=query_pos,
tgt_query_sine_embed=query_sine_embed,
tgt_key_padding_mask=tgt_key_padding_mask,
tgt_reference_points=reference_points_input,
memory_text=memory_text,
text_attention_mask=text_attention_mask,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
memory_level_start_index=level_start_index,
memory_spatial_shapes=spatial_shapes,
memory_pos=pos,
self_attn_mask=tgt_mask,
cross_attn_mask=memory_mask,
)
if output.isnan().any() | output.isinf().any():
print(f"output layer_id {layer_id} is nan")
try:
num_nan = output.isnan().sum().item()
num_inf = output.isinf().sum().item()
print(f"num_nan {num_nan}, num_inf {num_inf}")
except Exception as e:
print(e)
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
# import ipdb; ipdb.set_trace()
# iter update
if self.bbox_embed is not None:
# box_holder = self.bbox_embed(output)
# box_holder[..., :self.query_dim] += inverse_sigmoid(reference_points)
# new_reference_points = box_holder[..., :self.query_dim].sigmoid()
reference_before_sigmoid = inverse_sigmoid(reference_points)
delta_unsig = self.bbox_embed[layer_id](output)
outputs_unsig = delta_unsig + reference_before_sigmoid
new_reference_points = outputs_unsig.sigmoid()
reference_points = new_reference_points.detach()
# if layer_id != self.num_layers - 1:
ref_points.append(new_reference_points)
intermediate.append(self.norm(output))
return [
[itm_out.transpose(0, 1) for itm_out in intermediate],
[itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points],
]
class DeformableTransformerEncoderLayer(nn.Module):
def __init__(
self,
d_model=256,
d_ffn=1024,
dropout=0.1,
activation="relu",
n_levels=4,
n_heads=8,
n_points=4,
):
super().__init__()
# self attention
self.self_attn = MSDeformAttn(
embed_dim=d_model,
num_levels=n_levels,
num_heads=n_heads,
num_points=n_points,
batch_first=True,
)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation, d_model=d_ffn)
self.dropout2 = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout3 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, src):
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
src = src + self.dropout3(src2)
src = self.norm2(src)
return src
def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None):
# self attention
# import ipdb; ipdb.set_trace()
src2 = self.self_attn(
query=self.with_pos_embed(src, pos),
reference_points=reference_points,
value=src,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
key_padding_mask=key_padding_mask,
)
src = src + self.dropout1(src2)
src = self.norm1(src)
# ffn
src = self.forward_ffn(src)
return src
class DeformableTransformerDecoderLayer(nn.Module):
def __init__(
self,
d_model=256,
d_ffn=1024,
dropout=0.1,
activation="relu",
n_levels=4,
n_heads=8,
n_points=4,
use_text_feat_guide=False,
use_text_cross_attention=False,
):
super().__init__()
# cross attention
self.cross_attn = MSDeformAttn(
embed_dim=d_model,
num_levels=n_levels,
num_heads=n_heads,
num_points=n_points,
batch_first=True,
)
self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.norm1 = nn.LayerNorm(d_model)
# cross attention text
if use_text_cross_attention:
self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.catext_norm = nn.LayerNorm(d_model)
# self attention
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.norm2 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.norm3 = nn.LayerNorm(d_model)
self.key_aware_proj = None
self.use_text_feat_guide = use_text_feat_guide
assert not use_text_feat_guide
self.use_text_cross_attention = use_text_cross_attention
def rm_self_attn_modules(self):
self.self_attn = None
self.dropout2 = None
self.norm2 = None
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, tgt):
with torch.cuda.amp.autocast(enabled=False):
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout4(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward(
self,
# for tgt
tgt: Optional[Tensor], # nq, bs, d_model
tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
tgt_key_padding_mask: Optional[Tensor] = None,
tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
memory_text: Optional[Tensor] = None, # bs, num_token, d_model
text_attention_mask: Optional[Tensor] = None, # bs, num_token
# for memory
memory: Optional[Tensor] = None, # hw, bs, d_model
memory_key_padding_mask: Optional[Tensor] = None,
memory_level_start_index: Optional[Tensor] = None, # num_levels
memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
memory_pos: Optional[Tensor] = None, # pos for memory
# sa
self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
):
"""
Input:
- tgt/tgt_query_pos: nq, bs, d_model
-
"""
assert cross_attn_mask is None
# self attention
if self.self_attn is not None:
# import ipdb; ipdb.set_trace()
q = k = self.with_pos_embed(tgt, tgt_query_pos)
tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
if self.use_text_cross_attention:
tgt2 = self.ca_text(
self.with_pos_embed(tgt, tgt_query_pos),
memory_text.transpose(0, 1),
memory_text.transpose(0, 1),
key_padding_mask=text_attention_mask,
)[0]
tgt = tgt + self.catext_dropout(tgt2)
tgt = self.catext_norm(tgt)
tgt2 = self.cross_attn(
query=self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
reference_points=tgt_reference_points.transpose(0, 1).contiguous(),
value=memory.transpose(0, 1),
spatial_shapes=memory_spatial_shapes,
level_start_index=memory_level_start_index,
key_padding_mask=memory_key_padding_mask,
).transpose(0, 1)
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
# ffn
tgt = self.forward_ffn(tgt)
return tgt
def build_transformer(args):
return Transformer(
d_model=args.hidden_dim,
dropout=args.dropout,
nhead=args.nheads,
num_queries=args.num_queries,
dim_feedforward=args.dim_feedforward,
num_encoder_layers=args.enc_layers,
num_decoder_layers=args.dec_layers,
normalize_before=args.pre_norm,
return_intermediate_dec=True,
query_dim=args.query_dim,
activation=args.transformer_activation,
num_patterns=args.num_patterns,
num_feature_levels=args.num_feature_levels,
enc_n_points=args.enc_n_points,
dec_n_points=args.dec_n_points,
learnable_tgt_init=True,
# two stage
two_stage_type=args.two_stage_type, # ['no', 'standard', 'early']
embed_init_tgt=args.embed_init_tgt,
use_text_enhancer=args.use_text_enhancer,
use_fusion_layer=args.use_fusion_layer,
use_checkpoint=args.use_checkpoint,
use_transformer_ckpt=args.use_transformer_ckpt,
use_text_cross_attention=args.use_text_cross_attention,
text_dropout=args.text_dropout,
fusion_dropout=args.fusion_dropout,
fusion_droppath=args.fusion_droppath,
)

View File

@@ -1,115 +0,0 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
* positional encodings are passed in MHattention
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
"""
from typing import Optional
import torch
from torch import Tensor, nn
from .utils import _get_activation_fn, _get_clones
class TextTransformer(nn.Module):
def __init__(self, num_layers, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1):
super().__init__()
self.num_layers = num_layers
self.d_model = d_model
self.nheads = nheads
self.dim_feedforward = dim_feedforward
self.norm = None
single_encoder_layer = TransformerEncoderLayer(
d_model=d_model, nhead=nheads, dim_feedforward=dim_feedforward, dropout=dropout
)
self.layers = _get_clones(single_encoder_layer, num_layers)
def forward(self, memory_text: torch.Tensor, text_attention_mask: torch.Tensor):
"""
Args:
text_attention_mask: bs, num_token
memory_text: bs, num_token, d_model
Raises:
RuntimeError: _description_
Returns:
output: bs, num_token, d_model
"""
output = memory_text.transpose(0, 1)
for layer in self.layers:
output = layer(output, src_key_padding_mask=text_attention_mask)
if self.norm is not None:
output = self.norm(output)
return output.transpose(0, 1)
class TransformerEncoderLayer(nn.Module):
def __init__(
self,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
self.nhead = nhead
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
# repeat attn mask
if src_mask.dim() == 3 and src_mask.shape[0] == src.shape[1]:
# bs, num_q, num_k
src_mask = src_mask.repeat(self.nhead, 1, 1)
q = k = self.with_pos_embed(src, pos)
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)[0]
# src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src

View File

@@ -1,258 +0,0 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
import copy
import math
import torch
import torch.nn.functional as F
from torch import Tensor, nn
def _get_clones(module, N, layer_share=False):
# import ipdb; ipdb.set_trace()
if layer_share:
return nn.ModuleList([module for i in range(N)])
else:
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def get_sine_pos_embed(
pos_tensor: torch.Tensor,
num_pos_feats: int = 128,
temperature: int = 10000,
exchange_xy: bool = True,
):
"""generate sine position embedding from a position tensor
Args:
pos_tensor (torch.Tensor): shape: [..., n].
num_pos_feats (int): projected shape for each float in the tensor.
temperature (int): temperature in the sine/cosine function.
exchange_xy (bool, optional): exchange pos x and pos y. \
For example, input tensor is [x,y], the results will be [pos(y), pos(x)]. Defaults to True.
Returns:
pos_embed (torch.Tensor): shape: [..., n*num_pos_feats].
"""
scale = 2 * math.pi
dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
def sine_func(x: torch.Tensor):
sin_x = x * scale / dim_t
sin_x = torch.stack((sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=3).flatten(2)
return sin_x
pos_res = [sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)]
if exchange_xy:
pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
pos_res = torch.cat(pos_res, dim=-1)
return pos_res
def gen_encoder_output_proposals(memory: Tensor, memory_padding_mask: Tensor, spatial_shapes: Tensor, learnedwh=None):
"""
Input:
- memory: bs, \sum{hw}, d_model
- memory_padding_mask: bs, \sum{hw}
- spatial_shapes: nlevel, 2
- learnedwh: 2
Output:
- output_memory: bs, \sum{hw}, d_model
- output_proposals: bs, \sum{hw}, 4
"""
N_, S_, C_ = memory.shape
proposals = []
_cur = 0
for lvl, (H_, W_) in enumerate(spatial_shapes):
mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(N_, H_, W_, 1)
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
# import ipdb; ipdb.set_trace()
grid_y, grid_x = torch.meshgrid(
torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
)
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
if learnedwh is not None:
# import ipdb; ipdb.set_trace()
wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0**lvl)
else:
wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
# scale = torch.cat([W_[None].unsqueeze(-1), H_[None].unsqueeze(-1)], 1).view(1, 1, 1, 2).repeat(N_, 1, 1, 1)
# grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
# wh = torch.ones_like(grid) / scale
proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
proposals.append(proposal)
_cur += H_ * W_
# import ipdb; ipdb.set_trace()
output_proposals = torch.cat(proposals, 1)
output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
output_proposals = torch.log(output_proposals / (1 - output_proposals)) # unsigmoid
output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf"))
output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))
output_memory = memory
output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
# output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
# output_memory = output_memory.masked_fill(~output_proposals_valid, float('inf'))
output_proposals = output_proposals.to(output_memory.dtype)
return output_memory, output_proposals
class RandomBoxPerturber:
def __init__(self, x_noise_scale=0.2, y_noise_scale=0.2, w_noise_scale=0.2, h_noise_scale=0.2) -> None:
self.noise_scale = torch.Tensor([x_noise_scale, y_noise_scale, w_noise_scale, h_noise_scale])
def __call__(self, refanchors: Tensor) -> Tensor:
nq, bs, query_dim = refanchors.shape
device = refanchors.device
noise_raw = torch.rand_like(refanchors)
noise_scale = self.noise_scale.to(device)[:query_dim]
new_refanchors = refanchors * (1 + (noise_raw - 0.5) * noise_scale)
return new_refanchors.clamp_(0, 1)
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, no_reduction=False):
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
alpha: (optional) Weighting factor in range (0,1) to balance
positive vs negative examples. Default = -1 (no weighting).
gamma: Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples.
Returns:
Loss tensor
"""
prob = inputs.sigmoid()
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
if no_reduction:
return loss
return loss.mean(1).sum() / num_boxes
class MLP(nn.Module):
"""Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
def _get_activation_fn(activation, d_model=256, batch_dim=0):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
if activation == "prelu":
return nn.PReLU()
if activation == "selu":
return F.selu
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
def gen_sineembed_for_position(pos_tensor):
# n_query, bs, _ = pos_tensor.size()
# sineembed_tensor = torch.zeros(n_query, bs, 256)
scale = 2 * math.pi
dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / 128)
x_embed = pos_tensor[:, :, 0] * scale
y_embed = pos_tensor[:, :, 1] * scale
pos_x = x_embed[:, :, None] / dim_t
pos_y = y_embed[:, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
if pos_tensor.size(-1) == 2:
pos = torch.cat((pos_y, pos_x), dim=2)
elif pos_tensor.size(-1) == 4:
w_embed = pos_tensor[:, :, 2] * scale
pos_w = w_embed[:, :, None] / dim_t
pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
h_embed = pos_tensor[:, :, 3] * scale
pos_h = h_embed[:, :, None] / dim_t
pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
else:
raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
pos = pos.to(pos_tensor.dtype)
return pos
class ContrastiveEmbed(nn.Module):
def __init__(self, max_text_len=256):
"""
Args:
max_text_len: max length of text.
"""
super().__init__()
self.max_text_len = max_text_len
def forward(self, x, text_dict):
"""_summary_
Args:
x (_type_): _description_
text_dict (_type_): _description_
{
'encoded_text': encoded_text, # bs, 195, d_model
'text_token_mask': text_token_mask, # bs, 195
# True for used tokens. False for padding tokens
}
Returns:
_type_: _description_
"""
assert isinstance(text_dict, dict)
y = text_dict["encoded_text"]
text_token_mask = text_dict["text_token_mask"]
res = x @ y.transpose(-1, -2)
res.masked_fill_(~text_token_mask[:, None, :], float("-inf"))
# padding to max_text_len
new_res = torch.full((*res.shape[:-1], self.max_text_len), float("-inf"), device=res.device, dtype=res.dtype)
new_res[..., : res.shape[-1]] = res
return new_res

View File

@@ -1,18 +0,0 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .GroundingDINO import build_groundingdino # noqa
def build_model(args):
# we use register to maintain models from catdet6 on.
from .registry import MODULE_BUILD_FUNCS
assert args.modelname in MODULE_BUILD_FUNCS._module_dict
build_func = MODULE_BUILD_FUNCS.get(args.modelname)
model = build_func(args)
return model

View File

@@ -1,60 +0,0 @@
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Author: Yihao Chen
# @Date: 2021-08-16 16:03:17
# @Last Modified by: Shilong Liu
# @Last Modified time: 2022-01-23 15:26
# modified from mmcv
import inspect
from functools import partial
class Registry(object):
def __init__(self, name):
self._name = name
self._module_dict = dict()
def __repr__(self):
format_str = self.__class__.__name__ + "(name={}, items={})".format(self._name, list(self._module_dict.keys()))
return format_str
def __len__(self):
return len(self._module_dict)
@property
def name(self):
return self._name
@property
def module_dict(self):
return self._module_dict
def get(self, key):
return self._module_dict.get(key, None)
def registe_with_name(self, module_name=None, force=False):
return partial(self.register, module_name=module_name, force=force)
def register(self, module_build_function, module_name=None, force=False):
"""Register a module build function.
Args:
module (:obj:`nn.Module`): Module to be registered.
"""
if not inspect.isfunction(module_build_function):
raise TypeError("module_build_function must be a function, but got {}".format(type(module_build_function)))
if module_name is None:
module_name = module_build_function.__name__
if not force and module_name in self._module_dict:
raise KeyError("{} is already registered in {}".format(module_name, self.name))
self._module_dict[module_name] = module_build_function
return module_build_function
MODULE_BUILD_FUNCS = Registry("model build functions")

View File

@@ -1 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

View File

@@ -1,140 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Utilities for bounding box manipulation and GIoU.
"""
import torch
from torchvision.ops.boxes import box_area
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=-1)
def box_xyxy_to_cxcywh(x):
x0, y0, x1, y1 = x.unbind(-1)
b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
return torch.stack(b, dim=-1)
# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)
# import ipdb; ipdb.set_trace()
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
union = area1[:, None] + area2 - inter
iou = inter / (union + 1e-6)
return iou, union
def generalized_box_iou(boxes1, boxes2):
"""
Generalized IoU from https://giou.stanford.edu/
The boxes should be in [x0, y0, x1, y1] format
Returns a [N, M] pairwise matrix, where N = len(boxes1)
and M = len(boxes2)
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
# except:
# import ipdb; ipdb.set_trace()
iou, union = box_iou(boxes1, boxes2)
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
wh = (rb - lt).clamp(min=0) # [N,M,2]
area = wh[:, :, 0] * wh[:, :, 1]
return iou - (area - union) / (area + 1e-6)
# modified from torchvision to also return the union
def box_iou_pairwise(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)
lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2]
rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2]
wh = (rb - lt).clamp(min=0) # [N,2]
inter = wh[:, 0] * wh[:, 1] # [N]
union = area1 + area2 - inter
iou = inter / union
return iou, union
def generalized_box_iou_pairwise(boxes1, boxes2):
"""
Generalized IoU from https://giou.stanford.edu/
Input:
- boxes1, boxes2: N,4
Output:
- giou: N, 4
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
assert boxes1.shape == boxes2.shape
iou, union = box_iou_pairwise(boxes1, boxes2) # N, 4
lt = torch.min(boxes1[:, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])
wh = (rb - lt).clamp(min=0) # [N,2]
area = wh[:, 0] * wh[:, 1]
return iou - (area - union) / area
def masks_to_boxes(masks):
"""Compute the bounding boxes around the provided masks
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
Returns a [N, 4] tensors, with the boxes in xyxy format
"""
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device)
h, w = masks.shape[-2:]
y = torch.arange(0, h, dtype=torch.float)
x = torch.arange(0, w, dtype=torch.float)
y, x = torch.meshgrid(y, x)
x_mask = masks * x.unsqueeze(0)
x_max = x_mask.flatten(1).max(-1)[0]
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
y_mask = masks * y.unsqueeze(0)
y_max = y_mask.flatten(1).max(-1)[0]
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
return torch.stack([x_min, y_min, x_max, y_max], 1)
if __name__ == "__main__":
x = torch.rand(5, 4)
y = torch.rand(3, 4)
iou, union = box_iou(x, y)
import ipdb
ipdb.set_trace()

View File

@@ -1,24 +0,0 @@
from transformers import AutoTokenizer, BertModel, RobertaModel
def get_tokenlizer(text_encoder_type):
if not isinstance(text_encoder_type, str):
# print("text_encoder_type is not a str")
if hasattr(text_encoder_type, "text_encoder_type"):
text_encoder_type = text_encoder_type.text_encoder_type
elif text_encoder_type.get("text_encoder_type", False):
text_encoder_type = text_encoder_type.get("text_encoder_type")
else:
raise ValueError("Unknown type of text_encoder_type: {}".format(type(text_encoder_type)))
print("final text_encoder_type: {}".format(text_encoder_type))
tokenizer = AutoTokenizer.from_pretrained(text_encoder_type)
return tokenizer
def get_pretrained_language_model(text_encoder_type):
if text_encoder_type == "bert-base-uncased":
return BertModel.from_pretrained(text_encoder_type)
if text_encoder_type == "roberta-base":
return RobertaModel.from_pretrained(text_encoder_type)
raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type))

View File

@@ -1,221 +0,0 @@
from typing import Dict, List, Tuple
import cv2
import numpy as np
import supervision as sv
import torch
from PIL import Image
from torchvision.ops import box_convert
import invokeai.backend.image_util.grounding_segment_anything.groundingdino.datasets.transforms as T
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models import build_model
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import clean_state_dict
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.slconfig import SLConfig
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.utils import get_phrases_from_posmap
# ----------------------------------------------------------------------------------------------------------------------
# OLD API
# ----------------------------------------------------------------------------------------------------------------------
def preprocess_caption(caption: str) -> str:
result = caption.lower().strip()
if result.endswith("."):
return result
return result + "."
def load_model(model_config_path: str, model_state_dict: Dict[str, torch.Tensor], device: str = "cuda"):
args = SLConfig.fromfile(model_config_path)
args.device = device
model = build_model(args)
model.load_state_dict(clean_state_dict(model_state_dict["model"]), strict=False)
model.eval()
return model
def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
image_source = Image.open(image_path).convert("RGB")
image = np.asarray(image_source)
image_transformed, _ = transform(image_source, None)
return image, image_transformed
def predict(
model, image: torch.Tensor, caption: str, box_threshold: float, text_threshold: float, device: str = "cuda"
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
caption = preprocess_caption(caption=caption)
model = model.to(device)
image = image.to(device)
with torch.no_grad():
outputs = model(image[None], captions=[caption])
prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0] # prediction_logits.shape = (nq, 256)
prediction_boxes = outputs["pred_boxes"].cpu()[0] # prediction_boxes.shape = (nq, 4)
mask = prediction_logits.max(dim=1)[0] > box_threshold
logits = prediction_logits[mask] # logits.shape = (n, 256)
boxes = prediction_boxes[mask] # boxes.shape = (n, 4)
tokenizer = model.tokenizer
tokenized = tokenizer(caption)
phrases = [
get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace(".", "") for logit in logits
]
return boxes, logits.max(dim=1)[0], phrases
def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str]) -> np.ndarray:
h, w, _ = image_source.shape
boxes = boxes * torch.Tensor([w, h, w, h])
xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
detections = sv.Detections(xyxy=xyxy)
labels = [f"{phrase} {logit:.2f}" for phrase, logit in zip(phrases, logits)]
box_annotator = sv.BoxAnnotator()
annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
return annotated_frame
# ----------------------------------------------------------------------------------------------------------------------
# NEW API
# ----------------------------------------------------------------------------------------------------------------------
class Model:
def __init__(self, model_config_path: str, model_state_dict: Dict[str, torch.Tensor], device: str = "cuda"):
self.model = load_model(
model_config_path=model_config_path, model_state_dict=model_state_dict, device=device
).to(device)
self.device = device
def predict_with_caption(
self, image: np.ndarray, caption: str, box_threshold: float = 0.35, text_threshold: float = 0.25
) -> Tuple[sv.Detections, List[str]]:
"""
import cv2
image = cv2.imread(IMAGE_PATH)
model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
detections, labels = model.predict_with_caption(
image=image,
caption=caption,
box_threshold=BOX_THRESHOLD,
text_threshold=TEXT_THRESHOLD
)
import supervision as sv
box_annotator = sv.BoxAnnotator()
annotated_image = box_annotator.annotate(scene=image, detections=detections, labels=labels)
"""
processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
boxes, logits, phrases = predict(
model=self.model,
image=processed_image,
caption=caption,
box_threshold=box_threshold,
text_threshold=text_threshold,
device=self.device,
)
source_h, source_w, _ = image.shape
detections = Model.post_process_result(source_h=source_h, source_w=source_w, boxes=boxes, logits=logits)
return detections, phrases
def predict_with_classes(
self, image: np.ndarray, classes: List[str], box_threshold: float, text_threshold: float
) -> sv.Detections:
"""
import cv2
image = cv2.imread(IMAGE_PATH)
model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
detections = model.predict_with_classes(
image=image,
classes=CLASSES,
box_threshold=BOX_THRESHOLD,
text_threshold=TEXT_THRESHOLD
)
import supervision as sv
box_annotator = sv.BoxAnnotator()
annotated_image = box_annotator.annotate(scene=image, detections=detections)
"""
caption = ". ".join(classes)
processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
boxes, logits, phrases = predict(
model=self.model,
image=processed_image,
caption=caption,
box_threshold=box_threshold,
text_threshold=text_threshold,
device=self.device,
)
source_h, source_w, _ = image.shape
detections = Model.post_process_result(source_h=source_h, source_w=source_w, boxes=boxes, logits=logits)
class_id = Model.phrases2classes(phrases=phrases, classes=classes)
detections.class_id = class_id
return detections
@staticmethod
def preprocess_image(image_bgr: np.ndarray) -> torch.Tensor:
transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
image_pillow = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
image_transformed, _ = transform(image_pillow, None)
return image_transformed
@staticmethod
def post_process_result(source_h: int, source_w: int, boxes: torch.Tensor, logits: torch.Tensor) -> sv.Detections:
boxes = boxes * torch.Tensor([source_w, source_h, source_w, source_h])
xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
confidence = logits.numpy()
return sv.Detections(xyxy=xyxy, confidence=confidence)
@staticmethod
def phrases2classes(phrases: List[str], classes: List[str]) -> np.ndarray:
class_ids = []
for phrase in phrases:
try:
# class_ids.append(classes.index(phrase))
class_ids.append(Model.find_index(phrase, classes))
except ValueError:
class_ids.append(None)
return np.array(class_ids)
@staticmethod
def find_index(string, lst):
# if meet string like "lake river" will only keep "lake"
# this is an hack implementation for visualization which will be updated in the future
string = string.lower().split()[0]
for i, s in enumerate(lst):
if string in s.lower():
return i
print(
"There's a wrong phrase happen, this is because of our post-process merged wrong tokens, which will be \
modified in the future. We will assign it with a random label at this time."
)
return 0

View File

@@ -1,701 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
import colorsys
import datetime
import functools
import io
import json
import os
import pickle
import subprocess
import time
from collections import OrderedDict, defaultdict, deque
from typing import List, Optional
import numpy as np
import torch
import torch.distributed as dist
# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
from torch import Tensor
__torchvision_need_compat_flag = float(torchvision.__version__.split(".")[1]) < 7
if __torchvision_need_compat_flag:
from torchvision.ops import _new_empty_tensor
from torchvision.ops.misc import _output_size
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
if d.shape[0] == 0:
return 0
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
if os.environ.get("SHILONG_AMP", None) == "1":
eps = 1e-4
else:
eps = 1e-6
return self.total / (self.count + eps)
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value,
)
@functools.lru_cache()
def _get_global_gloo_group():
"""
Return a process group based on gloo backend, containing all the ranks
The result is cached.
"""
if dist.get_backend() == "nccl":
return dist.new_group(backend="gloo")
return dist.group.WORLD
def all_gather_cpu(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
cpu_group = _get_global_gloo_group()
buffer = io.BytesIO()
torch.save(data, buffer)
data_view = buffer.getbuffer()
device = "cuda" if cpu_group is None else "cpu"
tensor = torch.ByteTensor(data_view).to(device)
# obtain Tensor size of each rank
local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
size_list = [torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)]
if cpu_group is None:
dist.all_gather(size_list, local_size)
else:
print("gathering on cpu")
dist.all_gather(size_list, local_size, group=cpu_group)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
assert isinstance(local_size.item(), int)
local_size = int(local_size.item())
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
if local_size != max_size:
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device=device)
tensor = torch.cat((tensor, padding), dim=0)
if cpu_group is None:
dist.all_gather(tensor_list, tensor)
else:
dist.all_gather(tensor_list, tensor, group=cpu_group)
data_list = []
for size, tensor in zip(size_list, tensor_list, strict=False):
tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
buffer = io.BytesIO(tensor.cpu().numpy())
obj = torch.load(buffer)
data_list.append(obj)
return data_list
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
if os.getenv("CPU_REDUCE") == "1":
return all_gather_cpu(data)
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
# obtain Tensor size of each rank
local_size = torch.tensor([tensor.numel()], device="cuda")
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
if local_size != max_size:
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list, strict=False):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def reduce_dict(input_dict, average=True):
"""
Args:
input_dict (dict): all the values will be reduced
average (bool): whether to do average or sum
Reduce the values in the dictionary from all processes so that all processes
have the averaged results. Returns a dict with the same fields as
input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.all_reduce(values)
if average:
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values, strict=False)}
return reduced_dict
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
# print(name, str(meter))
# import ipdb;ipdb.set_trace()
if meter.count > 0:
loss_str.append("{}: {}".format(name, str(meter)))
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None, logger=None):
if logger is None:
print_func = print
else:
print_func = logger.info
i = 0
if not header:
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available():
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
"max mem: {memory:.0f}",
]
)
else:
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
]
)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
# import ipdb; ipdb.set_trace()
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print_func(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB,
)
)
else:
print_func(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
)
)
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print_func("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
def get_sha():
cwd = os.path.dirname(os.path.abspath(__file__))
def _run(command):
return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
sha = "N/A"
diff = "clean"
branch = "N/A"
try:
sha = _run(["git", "rev-parse", "HEAD"])
subprocess.check_output(["git", "diff"], cwd=cwd)
diff = _run(["git", "diff-index", "HEAD"])
diff = "has uncommited changes" if diff else "clean"
branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
except Exception:
pass
message = f"sha: {sha}, status: {diff}, branch: {branch}"
return message
def collate_fn(batch):
# import ipdb; ipdb.set_trace()
batch = list(zip(*batch, strict=False))
batch[0] = nested_tensor_from_tensor_list(batch[0])
return tuple(batch)
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
class NestedTensor(object):
def __init__(self, tensors, mask: Optional[Tensor]):
self.tensors = tensors
self.mask = mask
if mask == "auto":
self.mask = torch.zeros_like(tensors).to(tensors.device)
if self.mask.dim() == 3:
self.mask = self.mask.sum(0).to(bool)
elif self.mask.dim() == 4:
self.mask = self.mask.sum(1).to(bool)
else:
raise ValueError("tensors dim must be 3 or 4 but {}({})".format(self.tensors.dim(), self.tensors.shape))
def imgsize(self):
res = []
for i in range(self.tensors.shape[0]):
mask = self.mask[i]
maxH = (~mask).sum(0).max()
maxW = (~mask).sum(1).max()
res.append(torch.Tensor([maxH, maxW]))
return res
def to(self, device):
# type: (Device) -> NestedTensor # noqa
cast_tensor = self.tensors.to(device)
mask = self.mask
if mask is not None:
assert mask is not None
cast_mask = mask.to(device)
else:
cast_mask = None
return NestedTensor(cast_tensor, cast_mask)
def to_img_list_single(self, tensor, mask):
assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(tensor.dim())
maxH = (~mask).sum(0).max()
maxW = (~mask).sum(1).max()
img = tensor[:, :maxH, :maxW]
return img
def to_img_list(self):
"""remove the padding and convert to img list
Returns:
[type]: [description]
"""
if self.tensors.dim() == 3:
return self.to_img_list_single(self.tensors, self.mask)
else:
res = []
for i in range(self.tensors.shape[0]):
tensor_i = self.tensors[i]
mask_i = self.mask[i]
res.append(self.to_img_list_single(tensor_i, mask_i))
return res
@property
def device(self):
return self.tensors.device
def decompose(self):
return self.tensors, self.mask
def __repr__(self):
return str(self.tensors)
@property
def shape(self):
return {"tensors.shape": self.tensors.shape, "mask.shape": self.mask.shape}
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
# TODO make this more general
if tensor_list[0].ndim == 3:
if torchvision._is_tracing():
# nested_tensor_from_tensor_list() does not export well to ONNX
# call _onnx_nested_tensor_from_tensor_list() instead
return _onnx_nested_tensor_from_tensor_list(tensor_list)
# TODO make it support different-sized images
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
batch_shape = [len(tensor_list)] + max_size
b, c, h, w = batch_shape
dtype = tensor_list[0].dtype
device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], : img.shape[2]] = False
else:
raise ValueError("not supported")
return NestedTensor(tensor, mask)
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
max_size = []
for i in range(tensor_list[0].dim()):
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
max_size.append(max_size_i)
max_size = tuple(max_size)
# work around for
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
# m[: img.shape[1], :img.shape[2]] = False
# which is not yet supported in onnx
padded_imgs = []
padded_masks = []
for img in tensor_list:
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)]
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
padded_imgs.append(padded_img)
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
padded_masks.append(padded_mask.to(torch.bool))
tensor = torch.stack(padded_imgs)
mask = torch.stack(padded_masks)
return NestedTensor(tensor, mask=mask)
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if "WORLD_SIZE" in os.environ and os.environ["WORLD_SIZE"] != "": # 'RANK' in os.environ and
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = args.local_rank = int(os.environ["LOCAL_RANK"])
# launch by torch.distributed.launch
# Single node
# python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 1 --rank 0 ...
# Multi nodes
# python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 0 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ...
# python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 1 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ...
# args.rank = int(os.environ.get('OMPI_COMM_WORLD_RANK'))
# local_world_size = int(os.environ['GPU_PER_NODE_COUNT'])
# args.world_size = args.world_size * local_world_size
# args.gpu = args.local_rank = int(os.environ['LOCAL_RANK'])
# args.rank = args.rank * local_world_size + args.local_rank
print("world size: {}, rank: {}, local rank: {}".format(args.world_size, args.rank, args.local_rank))
print(json.dumps(dict(os.environ), indent=2))
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.local_rank = int(os.environ["SLURM_LOCALID"])
args.world_size = int(os.environ["SLURM_NPROCS"])
print(
"world size: {}, world rank: {}, local rank: {}, device_count: {}".format(
args.world_size, args.rank, args.local_rank, torch.cuda.device_count()
)
)
else:
print("Not using distributed mode")
args.distributed = False
args.world_size = 1
args.rank = 0
args.local_rank = 0
return
print("world_size:{} rank:{} local_rank:{}".format(args.world_size, args.rank, args.local_rank))
args.distributed = True
torch.cuda.set_device(args.local_rank)
args.dist_backend = "nccl"
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(
backend=args.dist_backend,
world_size=args.world_size,
rank=args.rank,
init_method=args.dist_url,
)
print("Before torch.distributed.barrier()")
torch.distributed.barrier()
print("End torch.distributed.barrier()")
setup_for_distributed(args.rank == 0)
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
if target.numel() == 0:
return [torch.zeros([], device=output.device)]
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
@torch.no_grad()
def accuracy_onehot(pred, gt):
"""_summary_
Args:
pred (_type_): n, c
gt (_type_): n, c
"""
tp = ((pred - gt).abs().sum(-1) < 1e-4).float().sum()
acc = tp / gt.shape[0] * 100
return acc
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
"""
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
class can go away.
"""
if __torchvision_need_compat_flag < 0.7:
if input.numel() > 0:
return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
output_shape = _output_size(2, input, size, scale_factor)
output_shape = list(input.shape[:-2]) + list(output_shape)
return _new_empty_tensor(input, output_shape)
else:
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
class color_sys:
def __init__(self, num_colors) -> None:
self.num_colors = num_colors
colors = []
for i in np.arange(0.0, 360.0, 360.0 / num_colors):
hue = i / 360.0
lightness = (50 + np.random.rand() * 10) / 100.0
saturation = (90 + np.random.rand() * 10) / 100.0
colors.append(tuple([int(j * 255) for j in colorsys.hls_to_rgb(hue, lightness, saturation)]))
self.colors = colors
def __call__(self, idx):
return self.colors[idx]
def inverse_sigmoid(x, eps=1e-3):
x = x.clamp(min=0, max=1)
x1 = x.clamp(min=eps)
x2 = (1 - x).clamp(min=eps)
return torch.log(x1 / x2)
def clean_state_dict(state_dict):
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k[:7] == "module.":
k = k[7:] # remove `module.`
new_state_dict[k] = v
return new_state_dict

View File

@@ -1,419 +0,0 @@
# ==========================================================
# Modified from mmcv
# ==========================================================
import ast
import os.path as osp
import platform
import shutil
import sys
import tempfile
from argparse import Action
from importlib import import_module
from addict import Dict
from yapf.yapflib.yapf_api import FormatCode
BASE_KEY = "_base_"
DELETE_KEY = "_delete_"
RESERVED_KEYS = ["filename", "text", "pretty_text", "get", "dump", "merge_from_dict"]
def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
if not osp.isfile(filename):
raise FileNotFoundError(msg_tmpl.format(filename))
class ConfigDict(Dict):
def __missing__(self, name):
raise KeyError(name)
def __getattr__(self, name):
try:
value = super(ConfigDict, self).__getattr__(name)
except KeyError:
ex = AttributeError(f"'{self.__class__.__name__}' object has no " f"attribute '{name}'")
except Exception as e:
ex = e
else:
return value
raise ex
class SLConfig(object):
"""
config files.
only support .py file as config now.
ref: mmcv.utils.config
Example:
>>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
>>> cfg.a
1
>>> cfg.b
{'b1': [0, 1]}
>>> cfg.b.b1
[0, 1]
>>> cfg = Config.fromfile('tests/data/config/a.py')
>>> cfg.filename
"/home/kchen/projects/mmcv/tests/data/config/a.py"
>>> cfg.item4
'test'
>>> cfg
"Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
"{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
"""
@staticmethod
def _validate_py_syntax(filename):
with open(filename) as f:
content = f.read()
try:
ast.parse(content)
except SyntaxError:
raise SyntaxError("There are syntax errors in config " f"file {filename}")
@staticmethod
def _file2dict(filename):
filename = osp.abspath(osp.expanduser(filename))
check_file_exist(filename)
if filename.lower().endswith(".py"):
with tempfile.TemporaryDirectory() as temp_config_dir:
temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=".py")
temp_config_name = osp.basename(temp_config_file.name)
if platform.system() == "Windows":
temp_config_file.close()
shutil.copyfile(filename, osp.join(temp_config_dir, temp_config_name))
temp_module_name = osp.splitext(temp_config_name)[0]
sys.path.insert(0, temp_config_dir)
SLConfig._validate_py_syntax(filename)
mod = import_module(temp_module_name)
sys.path.pop(0)
cfg_dict = {name: value for name, value in mod.__dict__.items() if not name.startswith("__")}
# delete imported module
del sys.modules[temp_module_name]
# close temp file
temp_config_file.close()
elif filename.lower().endswith((".yml", ".yaml", ".json")):
from .slio import slload
cfg_dict = slload(filename)
else:
raise IOError("Only py/yml/yaml/json type are supported now!")
cfg_text = filename + "\n"
with open(filename, "r") as f:
cfg_text += f.read()
# parse the base file
if BASE_KEY in cfg_dict:
cfg_dir = osp.dirname(filename)
base_filename = cfg_dict.pop(BASE_KEY)
base_filename = base_filename if isinstance(base_filename, list) else [base_filename]
cfg_dict_list = list()
cfg_text_list = list()
for f in base_filename:
_cfg_dict, _cfg_text = SLConfig._file2dict(osp.join(cfg_dir, f))
cfg_dict_list.append(_cfg_dict)
cfg_text_list.append(_cfg_text)
base_cfg_dict = dict()
for c in cfg_dict_list:
if len(base_cfg_dict.keys() & c.keys()) > 0:
raise KeyError("Duplicate key is not allowed among bases")
# TODO Allow the duplicate key while warnning user
base_cfg_dict.update(c)
base_cfg_dict = SLConfig._merge_a_into_b(cfg_dict, base_cfg_dict)
cfg_dict = base_cfg_dict
# merge cfg_text
cfg_text_list.append(cfg_text)
cfg_text = "\n".join(cfg_text_list)
return cfg_dict, cfg_text
@staticmethod
def _merge_a_into_b(a, b):
"""merge dict `a` into dict `b` (non-inplace).
values in `a` will overwrite `b`.
copy first to avoid inplace modification
Args:
a ([type]): [description]
b ([type]): [description]
Returns:
[dict]: [description]
"""
# import ipdb; ipdb.set_trace()
if not isinstance(a, dict):
return a
b = b.copy()
for k, v in a.items():
if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
if not isinstance(b[k], dict) and not isinstance(b[k], list):
# if :
# import ipdb; ipdb.set_trace()
raise TypeError(
f"{k}={v} in child config cannot inherit from base "
f"because {k} is a dict in the child config but is of "
f"type {type(b[k])} in base config. You may set "
f"`{DELETE_KEY}=True` to ignore the base config"
)
b[k] = SLConfig._merge_a_into_b(v, b[k])
elif isinstance(b, list):
try:
_ = int(k)
except:
raise TypeError(f"b is a list, " f"index {k} should be an int when input but {type(k)}")
b[int(k)] = SLConfig._merge_a_into_b(v, b[int(k)])
else:
b[k] = v
return b
@staticmethod
def fromfile(filename):
cfg_dict, cfg_text = SLConfig._file2dict(filename)
return SLConfig(cfg_dict, cfg_text=cfg_text, filename=filename)
def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
if cfg_dict is None:
cfg_dict = dict()
elif not isinstance(cfg_dict, dict):
raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}")
for key in cfg_dict:
if key in RESERVED_KEYS:
raise KeyError(f"{key} is reserved for config file")
super(SLConfig, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict))
super(SLConfig, self).__setattr__("_filename", filename)
if cfg_text:
text = cfg_text
elif filename:
with open(filename, "r") as f:
text = f.read()
else:
text = ""
super(SLConfig, self).__setattr__("_text", text)
@property
def filename(self):
return self._filename
@property
def text(self):
return self._text
@property
def pretty_text(self):
indent = 4
def _indent(s_, num_spaces):
s = s_.split("\n")
if len(s) == 1:
return s_
first = s.pop(0)
s = [(num_spaces * " ") + line for line in s]
s = "\n".join(s)
s = first + "\n" + s
return s
def _format_basic_types(k, v, use_mapping=False):
if isinstance(v, str):
v_str = f"'{v}'"
else:
v_str = str(v)
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f"{k_str}: {v_str}"
else:
attr_str = f"{str(k)}={v_str}"
attr_str = _indent(attr_str, indent)
return attr_str
def _format_list(k, v, use_mapping=False):
# check if all items in the list are dict
if all(isinstance(_, dict) for _ in v):
v_str = "[\n"
v_str += "\n".join(f"dict({_indent(_format_dict(v_), indent)})," for v_ in v).rstrip(",")
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f"{k_str}: {v_str}"
else:
attr_str = f"{str(k)}={v_str}"
attr_str = _indent(attr_str, indent) + "]"
else:
attr_str = _format_basic_types(k, v, use_mapping)
return attr_str
def _contain_invalid_identifier(dict_str):
contain_invalid_identifier = False
for key_name in dict_str:
contain_invalid_identifier |= not str(key_name).isidentifier()
return contain_invalid_identifier
def _format_dict(input_dict, outest_level=False):
r = ""
s = []
use_mapping = _contain_invalid_identifier(input_dict)
if use_mapping:
r += "{"
for idx, (k, v) in enumerate(input_dict.items()):
is_last = idx >= len(input_dict) - 1
end = "" if outest_level or is_last else ","
if isinstance(v, dict):
v_str = "\n" + _format_dict(v)
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f"{k_str}: dict({v_str}"
else:
attr_str = f"{str(k)}=dict({v_str}"
attr_str = _indent(attr_str, indent) + ")" + end
elif isinstance(v, list):
attr_str = _format_list(k, v, use_mapping) + end
else:
attr_str = _format_basic_types(k, v, use_mapping) + end
s.append(attr_str)
r += "\n".join(s)
if use_mapping:
r += "}"
return r
cfg_dict = self._cfg_dict.to_dict()
text = _format_dict(cfg_dict, outest_level=True)
# copied from setup.cfg
yapf_style = dict(
based_on_style="pep8",
blank_line_before_nested_class_or_def=True,
split_before_expression_after_opening_paren=True,
)
text, _ = FormatCode(text, style_config=yapf_style, verify=True)
return text
def __repr__(self):
return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}"
def __len__(self):
return len(self._cfg_dict)
def __getattr__(self, name):
# # debug
# print('+'*15)
# print('name=%s' % name)
# print("addr:", id(self))
# # print('type(self):', type(self))
# print(self.__dict__)
# print('+'*15)
# if self.__dict__ == {}:
# raise ValueError
return getattr(self._cfg_dict, name)
def __getitem__(self, name):
return self._cfg_dict.__getitem__(name)
def __setattr__(self, name, value):
if isinstance(value, dict):
value = ConfigDict(value)
self._cfg_dict.__setattr__(name, value)
def __setitem__(self, name, value):
if isinstance(value, dict):
value = ConfigDict(value)
self._cfg_dict.__setitem__(name, value)
def __iter__(self):
return iter(self._cfg_dict)
def dump(self, file=None):
# import ipdb; ipdb.set_trace()
if file is None:
return self.pretty_text
else:
with open(file, "w") as f:
f.write(self.pretty_text)
def merge_from_dict(self, options):
"""Merge list into cfg_dict
Merge the dict parsed by MultipleKVAction into this cfg.
Examples:
>>> options = {'model.backbone.depth': 50,
... 'model.backbone.with_cp':True}
>>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
>>> cfg.merge_from_dict(options)
>>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
>>> assert cfg_dict == dict(
... model=dict(backbone=dict(depth=50, with_cp=True)))
Args:
options (dict): dict of configs to merge from.
"""
option_cfg_dict = {}
for full_key, v in options.items():
d = option_cfg_dict
key_list = full_key.split(".")
for subkey in key_list[:-1]:
d.setdefault(subkey, ConfigDict())
d = d[subkey]
subkey = key_list[-1]
d[subkey] = v
cfg_dict = super(SLConfig, self).__getattribute__("_cfg_dict")
super(SLConfig, self).__setattr__("_cfg_dict", SLConfig._merge_a_into_b(option_cfg_dict, cfg_dict))
# for multiprocess
def __setstate__(self, state):
self.__init__(state)
def copy(self):
return SLConfig(self._cfg_dict.copy())
def deepcopy(self):
return SLConfig(self._cfg_dict.deepcopy())
class DictAction(Action):
"""
argparse action to split an argument into KEY=VALUE form
on the first = and append to a dictionary. List options should
be passed as comma separated values, i.e KEY=V1,V2,V3
"""
@staticmethod
def _parse_int_float_bool(val):
try:
return int(val)
except ValueError:
pass
try:
return float(val)
except ValueError:
pass
if val.lower() in ["true", "false"]:
return True if val.lower() == "true" else False
if val.lower() in ["none", "null"]:
return None
return val
def __call__(self, parser, namespace, values, option_string=None):
options = {}
for kv in values:
key, val = kv.split("=", maxsplit=1)
val = [self._parse_int_float_bool(v) for v in val.split(",")]
if len(val) == 1:
val = val[0]
options[key] = val
setattr(namespace, self.dest, options)

View File

@@ -1,178 +0,0 @@
# ==========================================================
# Modified from mmcv
# ==========================================================
import json
import pickle
from abc import ABCMeta, abstractmethod
from pathlib import Path
import yaml
try:
from yaml import CDumper as Dumper
from yaml import CLoader as Loader
except ImportError:
from yaml import Dumper, Loader
# ===========================
# Rigister handler
# ===========================
class BaseFileHandler(metaclass=ABCMeta):
@abstractmethod
def load_from_fileobj(self, file, **kwargs):
pass
@abstractmethod
def dump_to_fileobj(self, obj, file, **kwargs):
pass
@abstractmethod
def dump_to_str(self, obj, **kwargs):
pass
def load_from_path(self, filepath, mode="r", **kwargs):
with open(filepath, mode) as f:
return self.load_from_fileobj(f, **kwargs)
def dump_to_path(self, obj, filepath, mode="w", **kwargs):
with open(filepath, mode) as f:
self.dump_to_fileobj(obj, f, **kwargs)
class JsonHandler(BaseFileHandler):
def load_from_fileobj(self, file):
return json.load(file)
def dump_to_fileobj(self, obj, file, **kwargs):
json.dump(obj, file, **kwargs)
def dump_to_str(self, obj, **kwargs):
return json.dumps(obj, **kwargs)
class PickleHandler(BaseFileHandler):
def load_from_fileobj(self, file, **kwargs):
return pickle.load(file, **kwargs)
def load_from_path(self, filepath, **kwargs):
return super(PickleHandler, self).load_from_path(filepath, mode="rb", **kwargs)
def dump_to_str(self, obj, **kwargs):
kwargs.setdefault("protocol", 2)
return pickle.dumps(obj, **kwargs)
def dump_to_fileobj(self, obj, file, **kwargs):
kwargs.setdefault("protocol", 2)
pickle.dump(obj, file, **kwargs)
def dump_to_path(self, obj, filepath, **kwargs):
super(PickleHandler, self).dump_to_path(obj, filepath, mode="wb", **kwargs)
class YamlHandler(BaseFileHandler):
def load_from_fileobj(self, file, **kwargs):
kwargs.setdefault("Loader", Loader)
return yaml.load(file, **kwargs)
def dump_to_fileobj(self, obj, file, **kwargs):
kwargs.setdefault("Dumper", Dumper)
yaml.dump(obj, file, **kwargs)
def dump_to_str(self, obj, **kwargs):
kwargs.setdefault("Dumper", Dumper)
return yaml.dump(obj, **kwargs)
file_handlers = {
"json": JsonHandler(),
"yaml": YamlHandler(),
"yml": YamlHandler(),
"pickle": PickleHandler(),
"pkl": PickleHandler(),
}
# ===========================
# load and dump
# ===========================
def is_str(x):
"""Whether the input is an string instance.
Note: This method is deprecated since python 2 is no longer supported.
"""
return isinstance(x, str)
def slload(file, file_format=None, **kwargs):
"""Load data from json/yaml/pickle files.
This method provides a unified api for loading data from serialized files.
Args:
file (str or :obj:`Path` or file-like object): Filename or a file-like
object.
file_format (str, optional): If not specified, the file format will be
inferred from the file extension, otherwise use the specified one.
Currently supported formats include "json", "yaml/yml" and
"pickle/pkl".
Returns:
The content from the file.
"""
if isinstance(file, Path):
file = str(file)
if file_format is None and is_str(file):
file_format = file.split(".")[-1]
if file_format not in file_handlers:
raise TypeError(f"Unsupported format: {file_format}")
handler = file_handlers[file_format]
if is_str(file):
obj = handler.load_from_path(file, **kwargs)
elif hasattr(file, "read"):
obj = handler.load_from_fileobj(file, **kwargs)
else:
raise TypeError('"file" must be a filepath str or a file-object')
return obj
def sldump(obj, file=None, file_format=None, **kwargs):
"""Dump data to json/yaml/pickle strings or files.
This method provides a unified api for dumping data as strings or to files,
and also supports custom arguments for each file format.
Args:
obj (any): The python object to be dumped.
file (str or :obj:`Path` or file-like object, optional): If not
specified, then the object is dump to a str, otherwise to a file
specified by the filename or file-like object.
file_format (str, optional): Same as :func:`load`.
Returns:
bool: True for success, False otherwise.
"""
if isinstance(file, Path):
file = str(file)
if file_format is None:
if is_str(file):
file_format = file.split(".")[-1]
elif file is None:
raise ValueError("file_format must be specified since file is None")
if file_format not in file_handlers:
raise TypeError(f"Unsupported format: {file_format}")
handler = file_handlers[file_format]
if file is None:
return handler.dump_to_str(obj, **kwargs)
elif is_str(file):
handler.dump_to_path(obj, file, **kwargs)
elif hasattr(file, "write"):
handler.dump_to_fileobj(obj, file, **kwargs)
else:
raise TypeError('"file" must be a filename str or a file-object')

View File

@@ -1,62 +0,0 @@
import json
import time
class TimeCounter:
def __init__(self) -> None:
pass
def clear(self):
self.timedict = {}
self.basetime = time.perf_counter()
def timeit(self, name):
nowtime = time.perf_counter() - self.basetime
self.timedict[name] = nowtime
self.basetime = time.perf_counter()
class TimeHolder:
def __init__(self) -> None:
self.timedict = {}
def update(self, _timedict: dict):
for k, v in _timedict.items():
if k not in self.timedict:
self.timedict[k] = AverageMeter(name=k, val_only=True)
self.timedict[k].update(val=v)
def final_res(self):
return {k: v.avg for k, v in self.timedict.items()}
def __str__(self):
return json.dumps(self.final_res(), indent=2)
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=":f", val_only=False):
self.name = name
self.fmt = fmt
self.val_only = val_only
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
if self.val_only:
fmtstr = "{name} {val" + self.fmt + "}"
else:
fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
return fmtstr.format(**self.__dict__)

View File

@@ -1,598 +0,0 @@
import argparse
import json
import warnings
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Dict, List
import numpy as np
import torch
from transformers import AutoTokenizer
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.slconfig import SLConfig
def slprint(x, name="x"):
if isinstance(x, (torch.Tensor, np.ndarray)):
print(f"{name}.shape:", x.shape)
elif isinstance(x, (tuple, list)):
print("type x:", type(x))
for i in range(min(10, len(x))):
slprint(x[i], f"{name}[{i}]")
elif isinstance(x, dict):
for k, v in x.items():
slprint(v, f"{name}[{k}]")
else:
print(f"{name}.type:", type(x))
def clean_state_dict(state_dict):
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k[:7] == "module.":
k = k[7:] # remove `module.`
new_state_dict[k] = v
return new_state_dict
def renorm(img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) -> torch.FloatTensor:
# img: tensor(3,H,W) or tensor(B,3,H,W)
# return: same as img
assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
if img.dim() == 3:
assert img.size(0) == 3, 'img.size(0) shoule be 3 but "%d". (%s)' % (
img.size(0),
str(img.size()),
)
img_perm = img.permute(1, 2, 0)
mean = torch.Tensor(mean)
std = torch.Tensor(std)
img_res = img_perm * std + mean
return img_res.permute(2, 0, 1)
else: # img.dim() == 4
assert img.size(1) == 3, 'img.size(1) shoule be 3 but "%d". (%s)' % (
img.size(1),
str(img.size()),
)
img_perm = img.permute(0, 2, 3, 1)
mean = torch.Tensor(mean)
std = torch.Tensor(std)
img_res = img_perm * std + mean
return img_res.permute(0, 3, 1, 2)
class CocoClassMapper:
def __init__(self) -> None:
self.category_map_str = {
"1": 1,
"2": 2,
"3": 3,
"4": 4,
"5": 5,
"6": 6,
"7": 7,
"8": 8,
"9": 9,
"10": 10,
"11": 11,
"13": 12,
"14": 13,
"15": 14,
"16": 15,
"17": 16,
"18": 17,
"19": 18,
"20": 19,
"21": 20,
"22": 21,
"23": 22,
"24": 23,
"25": 24,
"27": 25,
"28": 26,
"31": 27,
"32": 28,
"33": 29,
"34": 30,
"35": 31,
"36": 32,
"37": 33,
"38": 34,
"39": 35,
"40": 36,
"41": 37,
"42": 38,
"43": 39,
"44": 40,
"46": 41,
"47": 42,
"48": 43,
"49": 44,
"50": 45,
"51": 46,
"52": 47,
"53": 48,
"54": 49,
"55": 50,
"56": 51,
"57": 52,
"58": 53,
"59": 54,
"60": 55,
"61": 56,
"62": 57,
"63": 58,
"64": 59,
"65": 60,
"67": 61,
"70": 62,
"72": 63,
"73": 64,
"74": 65,
"75": 66,
"76": 67,
"77": 68,
"78": 69,
"79": 70,
"80": 71,
"81": 72,
"82": 73,
"84": 74,
"85": 75,
"86": 76,
"87": 77,
"88": 78,
"89": 79,
"90": 80,
}
self.origin2compact_mapper = {int(k): v - 1 for k, v in self.category_map_str.items()}
self.compact2origin_mapper = {int(v - 1): int(k) for k, v in self.category_map_str.items()}
def origin2compact(self, idx):
return self.origin2compact_mapper[int(idx)]
def compact2origin(self, idx):
return self.compact2origin_mapper[int(idx)]
def to_device(item, device):
if isinstance(item, torch.Tensor):
return item.to(device)
elif isinstance(item, list):
return [to_device(i, device) for i in item]
elif isinstance(item, dict):
return {k: to_device(v, device) for k, v in item.items()}
else:
raise NotImplementedError("Call Shilong if you use other containers! type: {}".format(type(item)))
#
def get_gaussian_mean(x, axis, other_axis, softmax=True):
"""
Args:
x (float): Input images(BxCxHxW)
axis (int): The index for weighted mean
other_axis (int): The other index
Returns: weighted index for axis, BxC
"""
mat2line = torch.sum(x, axis=other_axis)
# mat2line = mat2line / mat2line.mean() * 10
if softmax:
u = torch.softmax(mat2line, axis=2)
else:
u = mat2line / (mat2line.sum(2, keepdim=True) + 1e-6)
size = x.shape[axis]
ind = torch.linspace(0, 1, size).to(x.device)
batch = x.shape[0]
channel = x.shape[1]
index = ind.repeat([batch, channel, 1])
mean_position = torch.sum(index * u, dim=2)
return mean_position
def get_expected_points_from_map(hm, softmax=True):
"""get_gaussian_map_from_points
B,C,H,W -> B,N,2 float(0, 1) float(0, 1)
softargmax function
Args:
hm (float): Input images(BxCxHxW)
Returns:
weighted index for axis, BxCx2. float between 0 and 1.
"""
# hm = 10*hm
B, C, H, W = hm.shape
y_mean = get_gaussian_mean(hm, 2, 3, softmax=softmax) # B,C
x_mean = get_gaussian_mean(hm, 3, 2, softmax=softmax) # B,C
# return torch.cat((x_mean.unsqueeze(-1), y_mean.unsqueeze(-1)), 2)
return torch.stack([x_mean, y_mean], dim=2)
# Positional encoding (section 5.1)
# borrow from nerf
class Embedder:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.create_embedding_fn()
def create_embedding_fn(self):
embed_fns = []
d = self.kwargs["input_dims"]
out_dim = 0
if self.kwargs["include_input"]:
embed_fns.append(lambda x: x)
out_dim += d
max_freq = self.kwargs["max_freq_log2"]
N_freqs = self.kwargs["num_freqs"]
if self.kwargs["log_sampling"]:
freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs)
else:
freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs)
for freq in freq_bands:
for p_fn in self.kwargs["periodic_fns"]:
embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
out_dim += d
self.embed_fns = embed_fns
self.out_dim = out_dim
def embed(self, inputs):
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(multires, i=0):
import torch.nn as nn
if i == -1:
return nn.Identity(), 3
embed_kwargs = {
"include_input": True,
"input_dims": 3,
"max_freq_log2": multires - 1,
"num_freqs": multires,
"log_sampling": True,
"periodic_fns": [torch.sin, torch.cos],
}
embedder_obj = Embedder(**embed_kwargs)
embed = lambda x, eo=embedder_obj: eo.embed(x)
return embed, embedder_obj.out_dim
class APOPMeter:
def __init__(self) -> None:
self.tp = 0
self.fp = 0
self.tn = 0
self.fn = 0
def update(self, pred, gt):
"""
Input:
pred, gt: Tensor()
"""
assert pred.shape == gt.shape
self.tp += torch.logical_and(pred == 1, gt == 1).sum().item()
self.fp += torch.logical_and(pred == 1, gt == 0).sum().item()
self.tn += torch.logical_and(pred == 0, gt == 0).sum().item()
self.tn += torch.logical_and(pred == 1, gt == 0).sum().item()
def update_cm(self, tp, fp, tn, fn):
self.tp += tp
self.fp += fp
self.tn += tn
self.tn += fn
def inverse_sigmoid(x, eps=1e-5):
x = x.clamp(min=0, max=1)
x1 = x.clamp(min=eps)
x2 = (1 - x).clamp(min=eps)
return torch.log(x1 / x2)
def get_raw_dict(args):
"""
return the dicf contained in args.
e.g:
>>> with open(path, 'w') as f:
json.dump(get_raw_dict(args), f, indent=2)
"""
if isinstance(args, argparse.Namespace):
return vars(args)
elif isinstance(args, dict):
return args
elif isinstance(args, SLConfig):
return args._cfg_dict
else:
raise NotImplementedError("Unknown type {}".format(type(args)))
def stat_tensors(tensor):
assert tensor.dim() == 1
tensor_sm = tensor.softmax(0)
entropy = (tensor_sm * torch.log(tensor_sm + 1e-9)).sum()
return {
"max": tensor.max(),
"min": tensor.min(),
"mean": tensor.mean(),
"var": tensor.var(),
"std": tensor.var() ** 0.5,
"entropy": entropy,
}
class NiceRepr:
"""Inherit from this class and define ``__nice__`` to "nicely" print your
objects.
Defines ``__str__`` and ``__repr__`` in terms of ``__nice__`` function
Classes that inherit from :class:`NiceRepr` should redefine ``__nice__``.
If the inheriting class has a ``__len__``, method then the default
``__nice__`` method will return its length.
Example:
>>> class Foo(NiceRepr):
... def __nice__(self):
... return 'info'
>>> foo = Foo()
>>> assert str(foo) == '<Foo(info)>'
>>> assert repr(foo).startswith('<Foo(info) at ')
Example:
>>> class Bar(NiceRepr):
... pass
>>> bar = Bar()
>>> import pytest
>>> with pytest.warns(None) as record:
>>> assert 'object at' in str(bar)
>>> assert 'object at' in repr(bar)
Example:
>>> class Baz(NiceRepr):
... def __len__(self):
... return 5
>>> baz = Baz()
>>> assert str(baz) == '<Baz(5)>'
"""
def __nice__(self):
"""str: a "nice" summary string describing this module"""
if hasattr(self, "__len__"):
# It is a common pattern for objects to use __len__ in __nice__
# As a convenience we define a default __nice__ for these objects
return str(len(self))
else:
# In all other cases force the subclass to overload __nice__
raise NotImplementedError(f"Define the __nice__ method for {self.__class__!r}")
def __repr__(self):
"""str: the string of the module"""
try:
nice = self.__nice__()
classname = self.__class__.__name__
return f"<{classname}({nice}) at {hex(id(self))}>"
except NotImplementedError as ex:
warnings.warn(str(ex), category=RuntimeWarning)
return object.__repr__(self)
def __str__(self):
"""str: the string of the module"""
try:
classname = self.__class__.__name__
nice = self.__nice__()
return f"<{classname}({nice})>"
except NotImplementedError as ex:
warnings.warn(str(ex), category=RuntimeWarning)
return object.__repr__(self)
def ensure_rng(rng=None):
"""Coerces input into a random number generator.
If the input is None, then a global random state is returned.
If the input is a numeric value, then that is used as a seed to construct a
random state. Otherwise the input is returned as-is.
Adapted from [1]_.
Args:
rng (int | numpy.random.RandomState | None):
if None, then defaults to the global rng. Otherwise this can be an
integer or a RandomState class
Returns:
(numpy.random.RandomState) : rng -
a numpy random number generator
References:
.. [1] https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270 # noqa: E501
"""
if rng is None:
rng = np.random.mtrand._rand
elif isinstance(rng, int):
rng = np.random.RandomState(rng)
else:
rng = rng
return rng
def random_boxes(num=1, scale=1, rng=None):
"""Simple version of ``kwimage.Boxes.random``
Returns:
Tensor: shape (n, 4) in x1, y1, x2, y2 format.
References:
https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390
Example:
>>> num = 3
>>> scale = 512
>>> rng = 0
>>> boxes = random_boxes(num, scale, rng)
>>> print(boxes)
tensor([[280.9925, 278.9802, 308.6148, 366.1769],
[216.9113, 330.6978, 224.0446, 456.5878],
[405.3632, 196.3221, 493.3953, 270.7942]])
"""
rng = ensure_rng(rng)
tlbr = rng.rand(num, 4).astype(np.float32)
tl_x = np.minimum(tlbr[:, 0], tlbr[:, 2])
tl_y = np.minimum(tlbr[:, 1], tlbr[:, 3])
br_x = np.maximum(tlbr[:, 0], tlbr[:, 2])
br_y = np.maximum(tlbr[:, 1], tlbr[:, 3])
tlbr[:, 0] = tl_x * scale
tlbr[:, 1] = tl_y * scale
tlbr[:, 2] = br_x * scale
tlbr[:, 3] = br_y * scale
boxes = torch.from_numpy(tlbr)
return boxes
class ModelEma(torch.nn.Module):
def __init__(self, model, decay=0.9997, device=None):
super(ModelEma, self).__init__()
# make a copy of the model for accumulating moving average of weights
self.module = deepcopy(model)
self.module.eval()
# import ipdb; ipdb.set_trace()
self.decay = decay
self.device = device # perform ema on different device from model if set
if self.device is not None:
self.module.to(device=device)
def _update(self, model, update_fn):
with torch.no_grad():
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
if self.device is not None:
model_v = model_v.to(device=self.device)
ema_v.copy_(update_fn(ema_v, model_v))
def update(self, model):
self._update(model, update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m)
def set(self, model):
self._update(model, update_fn=lambda e, m: m)
class BestMetricSingle:
def __init__(self, init_res=0.0, better="large") -> None:
self.init_res = init_res
self.best_res = init_res
self.best_ep = -1
self.better = better
assert better in ["large", "small"]
def isbetter(self, new_res, old_res):
if self.better == "large":
return new_res > old_res
if self.better == "small":
return new_res < old_res
def update(self, new_res, ep):
if self.isbetter(new_res, self.best_res):
self.best_res = new_res
self.best_ep = ep
return True
return False
def __str__(self) -> str:
return "best_res: {}\t best_ep: {}".format(self.best_res, self.best_ep)
def __repr__(self) -> str:
return self.__str__()
def summary(self) -> dict:
return {
"best_res": self.best_res,
"best_ep": self.best_ep,
}
class BestMetricHolder:
def __init__(self, init_res=0.0, better="large", use_ema=False) -> None:
self.best_all = BestMetricSingle(init_res, better)
self.use_ema = use_ema
if use_ema:
self.best_ema = BestMetricSingle(init_res, better)
self.best_regular = BestMetricSingle(init_res, better)
def update(self, new_res, epoch, is_ema=False):
"""
return if the results is the best.
"""
if not self.use_ema:
return self.best_all.update(new_res, epoch)
else:
if is_ema:
self.best_ema.update(new_res, epoch)
return self.best_all.update(new_res, epoch)
else:
self.best_regular.update(new_res, epoch)
return self.best_all.update(new_res, epoch)
def summary(self):
if not self.use_ema:
return self.best_all.summary()
res = {}
res.update({f"all_{k}": v for k, v in self.best_all.summary().items()})
res.update({f"regular_{k}": v for k, v in self.best_regular.summary().items()})
res.update({f"ema_{k}": v for k, v in self.best_ema.summary().items()})
return res
def __repr__(self) -> str:
return json.dumps(self.summary(), indent=2)
def __str__(self) -> str:
return self.__repr__()
def targets_to(targets: List[Dict[str, Any]], device):
"""Moves the target dicts to the given device."""
excluded_keys = [
"questionId",
"tokens_positive",
"strings_positive",
"tokens",
"dataset_name",
"sentence_id",
"original_img_id",
"nb_eval",
"task_id",
"original_id",
"token_span",
"caption",
"dataset_type",
]
return [{k: v.to(device) if k not in excluded_keys else v for k, v in t.items()} for t in targets]
def get_phrases_from_posmap(posmap: torch.BoolTensor, tokenized: Dict, tokenizer: AutoTokenizer):
assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor"
if posmap.dim() == 1:
non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist()
token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]
return tokenizer.decode(token_ids)
else:
raise NotImplementedError("posmap must be 1-dim")

Some files were not shown because too many files have changed in this diff Show More