Compare commits

..

4 Commits

Author SHA1 Message Date
psychedelicious
6af5a22d3a chore: release v4.2.9.dev3
Instead of using dates, just going to increment.
2024-08-24 15:03:43 +10:00
psychedelicious
8cf4321010 feat(ui): use new Result utils for enqueueing 2024-08-24 14:49:17 +10:00
psychedelicious
aa7f2b096a fix(ui): graph building issue w/ controlnet 2024-08-24 14:48:18 +10:00
psychedelicious
bf0824b56d feat(ui): add Result type & helpers
Wrappers to capture errors and turn into results:
- `withResult` wraps a sync function
- `withResultAsync` wraps an async function

Comments, tests.
2024-08-24 14:46:58 +10:00
7 changed files with 224 additions and 38 deletions

View File

@@ -1,6 +1,9 @@
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import type { Result } from 'common/util/result';
import { isErr, withResult, withResultAsync } from 'common/util/result';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import { sessionStagingAreaReset, sessionStartedStaging } from 'features/controlLayers/store/canvasV2Slice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
@@ -27,48 +30,70 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
assert(manager, 'No model found in state');
let didStartStaging = false;
if (!state.canvasV2.session.isStaging && state.canvasV2.session.mode === 'compose') {
dispatch(sessionStartedStaging());
didStartStaging = true;
}
try {
let g: Graph;
let noise: Invocation<'noise'>;
let posCond: Invocation<'compel' | 'sdxl_compel_prompt'>;
assert(model, 'No model found in state');
const base = model.base;
if (base === 'sdxl') {
const result = await buildSDXLGraph(state, manager);
g = result.g;
noise = result.noise;
posCond = result.posCond;
} else if (base === 'sd-1' || base === 'sd-2') {
const result = await buildSD1Graph(state, manager);
g = result.g;
noise = result.noise;
posCond = result.posCond;
} else {
assert(false, `No graph builders for base ${base}`);
}
const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond);
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
fixedCacheKey: 'enqueueBatch',
})
);
req.reset();
await req.unwrap();
} catch (error) {
log.error({ error: serializeError(error) }, 'Failed to enqueue batch');
const abortStaging = () => {
if (didStartStaging && getState().canvasV2.session.isStaging) {
dispatch(sessionStagingAreaReset());
}
};
let buildGraphResult: Result<
{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel' | 'sdxl_compel_prompt'> },
Error
>;
assert(model, 'No model found in state');
const base = model.base;
switch (base) {
case 'sdxl':
buildGraphResult = await withResultAsync(() => buildSDXLGraph(state, manager));
break;
case 'sd-1':
case `sd-2`:
buildGraphResult = await withResultAsync(() => buildSD1Graph(state, manager));
break;
default:
assert(false, `No graph builders for base ${base}`);
}
if (isErr(buildGraphResult)) {
log.error({ error: serializeError(buildGraphResult.error) }, 'Failed to build graph');
abortStaging();
return;
}
const { g, noise, posCond } = buildGraphResult.value;
const prepareBatchResult = withResult(() => prepareLinearUIBatch(state, g, prepend, noise, posCond));
if (isErr(prepareBatchResult)) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
abortStaging();
return;
}
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(prepareBatchResult.value, {
fixedCacheKey: 'enqueueBatch',
})
);
req.reset();
const enqueueResult = await withResultAsync(() => req.unwrap());
if (isErr(enqueueResult)) {
log.error({ error: serializeError(enqueueResult.error) }, 'Failed to enqueue batch');
abortStaging();
return;
}
log.debug({ batchConfig: prepareBatchResult.value } as SerializableObject, 'Enqueued batch');
},
});
};

View File

@@ -0,0 +1,72 @@
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { describe, expect, it } from 'vitest';
import type { ErrResult, OkResult } from './result';
import { Err, isErr, isOk, Ok, withResult, withResultAsync } from './result'; // Adjust import as needed
const promiseify = <T>(fn: () => T): (() => Promise<T>) => {
return () =>
new Promise((resolve) => {
resolve(fn());
});
};
describe('Result Utility Functions', () => {
it('Ok() should create an OkResult', () => {
const result = Ok(42);
expect(result).toEqual({ type: 'Ok', value: 42 });
expect(isOk(result)).toBe(true);
expect(isErr(result)).toBe(false);
assert<Equals<OkResult<number>, typeof result>>(result);
});
it('Err() should create an ErrResult', () => {
const error = new Error('Something went wrong');
const result = Err(error);
expect(result).toEqual({ type: 'Err', error });
expect(isOk(result)).toBe(false);
expect(isErr(result)).toBe(true);
assert<Equals<ErrResult<Error>, typeof result>>(result);
});
it('withResult() should return Ok on success', () => {
const fn = () => 42;
const result = withResult(fn);
expect(isOk(result)).toBe(true);
if (isOk(result)) {
expect(result.value).toBe(42);
}
});
it('withResult() should return Err on exception', () => {
const fn = () => {
throw new Error('Failure');
};
const result = withResult(fn);
expect(isErr(result)).toBe(true);
if (isErr(result)) {
expect(result.error.message).toBe('Failure');
}
});
it('withResultAsync() should return Ok on success', async () => {
const fn = promiseify(() => 42);
const result = await withResultAsync(fn);
expect(isOk(result)).toBe(true);
if (isOk(result)) {
expect(result.value).toBe(42);
}
});
it('withResultAsync() should return Err on exception', async () => {
const fn = promiseify(() => {
throw new Error('Async failure');
});
const result = await withResultAsync(fn);
expect(isErr(result)).toBe(true);
if (isErr(result)) {
expect(result.error.message).toBe('Async failure');
}
});
});

View File

@@ -0,0 +1,89 @@
/**
* Represents a successful result.
* @template T The type of the value.
*/
export type OkResult<T> = { type: 'Ok'; value: T };
/**
* Represents a failed result.
* @template E The type of the error.
*/
export type ErrResult<E> = { type: 'Err'; error: E };
/**
* A union type that represents either a successful result (`Ok`) or a failed result (`Err`).
* @template T The type of the value in the `Ok` case.
* @template E The type of the error in the `Err` case.
*/
export type Result<T, E = Error> = OkResult<T> | ErrResult<E>;
/**
* Creates a successful result.
* @template T The type of the value.
* @param {T} value The value to wrap in an `Ok` result.
* @returns {OkResult<T>} The `Ok` result containing the value.
*/
export function Ok<T>(value: T): OkResult<T> {
return { type: 'Ok', value };
}
/**
* Creates a failed result.
* @template E The type of the error.
* @param {E} error The error to wrap in an `Err` result.
* @returns {ErrResult<E>} The `Err` result containing the error.
*/
export function Err<E>(error: E): ErrResult<E> {
return { type: 'Err', error };
}
/**
* Wraps a synchronous function in a try-catch block, returning a `Result`.
* @template T The type of the value returned by the function.
* @param {() => T} fn The function to execute.
* @returns {Result<T>} An `Ok` result if the function succeeds, or an `Err` result if it throws an error.
*/
export function withResult<T>(fn: () => T): Result<T> {
try {
return Ok(fn());
} catch (error) {
return Err(error instanceof Error ? error : new Error(String(error)));
}
}
/**
* Wraps an asynchronous function in a try-catch block, returning a `Promise` of a `Result`.
* @template T The type of the value returned by the function.
* @param {() => Promise<T>} fn The asynchronous function to execute.
* @returns {Promise<Result<T>>} A `Promise` resolving to an `Ok` result if the function succeeds, or an `Err` result if it throws an error.
*/
export async function withResultAsync<T>(fn: () => Promise<T>): Promise<Result<T>> {
try {
const result = await fn();
return Ok(result);
} catch (error) {
return Err(error instanceof Error ? error : new Error(String(error)));
}
}
/**
* Type guard to check if a `Result` is an `Ok` result.
* @template T The type of the value in the `Ok` result.
* @template E The type of the error in the `Err` result.
* @param {Result<T, E>} result The result to check.
* @returns {result is OkResult<T>} `true` if the result is an `Ok` result, otherwise `false`.
*/
export function isOk<T, E>(result: Result<T, E>): result is OkResult<T> {
return result.type === 'Ok';
}
/**
* Type guard to check if a `Result` is an `Err` result.
* @template T The type of the value in the `Ok` result.
* @template E The type of the error in the `Err` result.
* @param {Result<T, E>} result The result to check.
* @returns {result is ErrResult<E>} `true` if the result is an `Err` result, otherwise `false`.
*/
export function isErr<T, E>(result: Result<T, E>): result is ErrResult<E> {
return result.type === 'Err';
}

View File

@@ -36,7 +36,7 @@ export const addControlNets = async (
const adapter = manager.adapters.controlLayers.get(layer.id);
assert(adapter, 'Adapter not found');
const imageDTO = await adapter.renderer.rasterize({ rect: bbox, attrs: { opacity: 1, filters: [] } });
await addControlNetToGraph(g, layer, imageDTO, collector);
addControlNetToGraph(g, layer, imageDTO, collector);
}
return result;
@@ -69,7 +69,7 @@ export const addT2IAdapters = async (
const adapter = manager.adapters.controlLayers.get(layer.id);
assert(adapter, 'Adapter not found');
const imageDTO = await adapter.renderer.rasterize({ rect: bbox, attrs: { opacity: 1, filters: [] } });
await addT2IAdapterToGraph(g, layer, imageDTO, collector);
addT2IAdapterToGraph(g, layer, imageDTO, collector);
}
return result;

View File

@@ -233,7 +233,7 @@ export const buildSD1Graph = async (
state.canvasV2.controlLayers.entities,
g,
state.canvasV2.bbox.rect,
controlNetCollector,
t2iAdapterCollector,
modelConfig.base
);
if (t2iAdapterResult.addedT2IAdapters > 0) {

View File

@@ -236,7 +236,7 @@ export const buildSDXLGraph = async (
state.canvasV2.controlLayers.entities,
g,
state.canvasV2.bbox.rect,
controlNetCollector,
t2iAdapterCollector,
modelConfig.base
);
if (t2iAdapterResult.addedT2IAdapters > 0) {

View File

@@ -1 +1 @@
__version__ = "4.2.9.dev20240824"
__version__ = "4.2.9.dev3"