Maryhipp/chatgpt UI (#7969)

* add GPTimage1 as allowed base model

* fix for non-disabled inpaint layers

* lots of boilerplate for adding gpt-image base model and disabling things along with imagen

* handle gpt-image dimensions

* build graph for gpt-image

* lint

* feat(ui): make chatgpt model naming consistent

* feat(ui): graph builder naming

* feat(ui): disable img2img for imagen3

* feat(ui): more naming

* feat(ui): support presigned url prefetch

* feat(ui): disable neg prompt for chatgpt

* docs(ui): update docstring

* feat(ui): fix graph building issues for chatgpt

* fix(ui): node ids for chatgpt/imagen

* chore(ui): typegen

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
Mary Hipp Rogers
2025-04-29 09:38:03 -04:00
committed by GitHub
parent 13d44f47ce
commit 17027c4070
31 changed files with 282 additions and 443 deletions

View File

@@ -24,6 +24,7 @@ export const CanvasAddEntityButtons = memo(() => {
const isReferenceImageEnabled = useIsEntityTypeEnabled('reference_image');
const isRegionalGuidanceEnabled = useIsEntityTypeEnabled('regional_guidance');
const isControlLayerEnabled = useIsEntityTypeEnabled('control_layer');
const isInpaintLayerEnabled = useIsEntityTypeEnabled('inpaint_mask');
return (
<Flex w="full" h="full" justifyContent="center" gap={4}>
@@ -52,6 +53,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addInpaintMask}
isDisabled={!isInpaintLayerEnabled}
>
{t('controlLayers.inpaintMask')}
</Button>

View File

@@ -25,6 +25,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
const isReferenceImageEnabled = useIsEntityTypeEnabled('reference_image');
const isRegionalGuidanceEnabled = useIsEntityTypeEnabled('regional_guidance');
const isControlLayerEnabled = useIsEntityTypeEnabled('control_layer');
const isInpaintLayerEnabled = useIsEntityTypeEnabled('inpaint_mask');
return (
<Menu>
@@ -46,7 +47,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.regional')}>
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask} isDisabled={!isInpaintLayerEnabled}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={!isRegionalGuidanceEnabled}>

View File

@@ -1,5 +1,10 @@
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsCogView4, selectIsImagen3, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import {
selectIsChatGTP4o,
selectIsCogView4,
selectIsImagen3,
selectIsSD3,
} from 'features/controlLayers/store/paramsSlice';
import type { CanvasEntityType } from 'features/controlLayers/store/types';
import { useMemo } from 'react';
import type { Equals } from 'tsafe';
@@ -9,23 +14,24 @@ export const useIsEntityTypeEnabled = (entityType: CanvasEntityType) => {
const isSD3 = useAppSelector(selectIsSD3);
const isCogView4 = useAppSelector(selectIsCogView4);
const isImagen3 = useAppSelector(selectIsImagen3);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isEntityTypeEnabled = useMemo<boolean>(() => {
switch (entityType) {
case 'reference_image':
return !isSD3 && !isCogView4 && !isImagen3;
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
case 'regional_guidance':
return !isSD3 && !isCogView4 && !isImagen3;
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
case 'control_layer':
return !isSD3 && !isCogView4 && !isImagen3;
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
case 'inpaint_mask':
return !isImagen3;
return !isImagen3 && !isChatGPT4o;
case 'raster_layer':
return !isImagen3;
return !isImagen3 && !isChatGPT4o;
default:
assert<Equals<typeof entityType, never>>(false);
}
}, [entityType, isSD3, isCogView4, isImagen3]);
}, [entityType, isSD3, isCogView4, isImagen3, isChatGPT4o]);
return isEntityTypeEnabled;
};

View File

@@ -112,7 +112,7 @@ export class CanvasObjectImage extends CanvasModuleBase {
return;
}
const imageElementResult = await withResultAsync(() => loadImage(imageDTO.image_url));
const imageElementResult = await withResultAsync(() => loadImage(imageDTO.image_url, true));
if (imageElementResult.isErr()) {
// Image loading failed (e.g. the URL to the "physical" image is invalid)
this.onFailedToLoadImage(t('controlLayers.unableToLoadImage', 'Unable to load image'));

View File

@@ -235,8 +235,8 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
if (tool !== 'bbox') {
return NO_ANCHORS;
}
if (model?.base === 'imagen3') {
// The bbox is not resizable in imagen3 mode
if (model?.base === 'imagen3' || model?.base === 'chatgpt-4o') {
// The bbox is not resizable in these modes
return NO_ANCHORS;
}
return ALL_ANCHORS;

View File

@@ -476,15 +476,24 @@ export function getImageDataTransparency(imageData: ImageData): Transparency {
/**
* Loads an image from a URL and returns a promise that resolves with the loaded image element.
* @param src The image source URL
* @param fetchUrlFirst Whether to fetch the image's URL first, assuming the provided `src` will redirect to a different URL. This addresses an issue where CORS headers are dropped during a redirect.
* @returns A promise that resolves with the loaded image element
*/
export function loadImage(src: string): Promise<HTMLImageElement> {
export async function loadImage(src: string, fetchUrlFirst?: boolean): Promise<HTMLImageElement> {
const authToken = $authToken.get();
let url = src;
if (authToken && fetchUrlFirst) {
const response = await fetch(`${src}?url_only=true`, { credentials: 'include' });
const data = await response.json();
url = data.url;
}
return new Promise((resolve, reject) => {
const imageElement = new Image();
imageElement.onload = () => resolve(imageElement);
imageElement.onerror = (error) => reject(error);
imageElement.crossOrigin = $authToken.get() ? 'use-credentials' : 'anonymous';
imageElement.src = src;
imageElement.src = url;
});
}

View File

@@ -67,7 +67,7 @@ import type {
IPMethodV2,
T2IAdapterConfig,
} from './types';
import { getEntityIdentifier, isImagen3AspectRatioID, isRenderableEntity } from './types';
import { getEntityIdentifier, isChatGPT4oAspectRatioID, isImagen3AspectRatioID, isRenderableEntity } from './types';
import {
converters,
getControlLayerState,
@@ -1232,6 +1232,20 @@ export const canvasSlice = createSlice({
}
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
state.bbox.aspectRatio.isLocked = true;
} else if (state.bbox.modelBase === 'chatgpt-4o' && isChatGPT4oAspectRatioID(id)) {
// gpt-image has specific output sizes that are not exactly the same as the aspect ratio. Need special handling.
if (id === '3:2') {
state.bbox.rect.width = 1536;
state.bbox.rect.height = 1024;
} else if (id === '1:1') {
state.bbox.rect.width = 1024;
state.bbox.rect.height = 1024;
} else if (id === '2:3') {
state.bbox.rect.width = 1024;
state.bbox.rect.height = 1536;
}
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
state.bbox.aspectRatio.isLocked = true;
} else {
state.bbox.aspectRatio.isLocked = true;
state.bbox.aspectRatio.value = ASPECT_RATIO_MAP[id].ratio;
@@ -1704,7 +1718,7 @@ export const canvasSlice = createSlice({
const base = model?.base;
if (isMainModelBase(base) && state.bbox.modelBase !== base) {
state.bbox.modelBase = base;
if (base === 'imagen3') {
if (base === 'imagen3' || base === 'chatgpt-4o') {
state.bbox.aspectRatio.isLocked = true;
state.bbox.aspectRatio.value = 1;
state.bbox.aspectRatio.id = '1:1';
@@ -1843,7 +1857,7 @@ export const canvasPersistConfig: PersistConfig<CanvasState> = {
};
const syncScaledSize = (state: CanvasState) => {
if (state.bbox.modelBase === 'imagen3') {
if (state.bbox.modelBase === 'imagen3' || state.bbox.modelBase === 'chatgpt-4o') {
// Imagen3 has fixed sizes. Scaled bbox is not supported.
return;
}

View File

@@ -381,6 +381,7 @@ export const selectIsFLUX = createParamsSelector((params) => params.model?.base
export const selectIsSD3 = createParamsSelector((params) => params.model?.base === 'sd-3');
export const selectIsCogView4 = createParamsSelector((params) => params.model?.base === 'cogview4');
export const selectIsImagen3 = createParamsSelector((params) => params.model?.base === 'imagen3');
export const selectIsChatGTP4o = createParamsSelector((params) => params.model?.base === 'chatgpt-4o');
export const selectModel = createParamsSelector((params) => params.model);
export const selectModelKey = createParamsSelector((params) => params.model?.key);

View File

@@ -388,9 +388,15 @@ export type StagingAreaImage = {
};
export const zAspectRatioID = z.enum(['Free', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
export const zImagen3AspectRatioID = z.enum(['16:9', '4:3', '1:1', '3:4', '9:16']);
export const isImagen3AspectRatioID = (v: unknown): v is z.infer<typeof zImagen3AspectRatioID> =>
zImagen3AspectRatioID.safeParse(v).success;
export const zChatGPT4oAspectRatioID = z.enum(['3:2', '1:1', '2:3']);
export const isChatGPT4oAspectRatioID = (v: unknown): v is z.infer<typeof zChatGPT4oAspectRatioID> =>
zChatGPT4oAspectRatioID.safeParse(v).success;
export type AspectRatioID = z.infer<typeof zAspectRatioID>;
export const isAspectRatioID = (v: unknown): v is AspectRatioID => zAspectRatioID.safeParse(v).success;