mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 00:58:02 -05:00
feat(ui): seeded random generators
- Add JS Mersenne Twister implementation dependency to use as seeded PRNG. This is not a cryptographically secure algorithm. - Add nullish seed field to float and integer random generators. - Add UI to control the seed. - When seed is not set, behaviour is unchanged - the values are randomized when you Invoke. When seed is set, the random distribution is deterministic depending on the seed. In this case, we can display the values to the user.
This commit is contained in:
@@ -76,6 +76,7 @@
|
||||
"konva": "^9.3.15",
|
||||
"lodash-es": "^4.17.21",
|
||||
"lru-cache": "^11.0.1",
|
||||
"mtwist": "^1.0.2",
|
||||
"nanoid": "^5.0.7",
|
||||
"nanostores": "^0.11.3",
|
||||
"new-github-issue-url": "^1.0.0",
|
||||
|
||||
7
invokeai/frontend/web/pnpm-lock.yaml
generated
7
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -77,6 +77,9 @@ dependencies:
|
||||
lru-cache:
|
||||
specifier: ^11.0.1
|
||||
version: 11.0.1
|
||||
mtwist:
|
||||
specifier: ^1.0.2
|
||||
version: 1.0.2
|
||||
nanoid:
|
||||
specifier: ^5.0.7
|
||||
version: 5.0.7
|
||||
@@ -7016,6 +7019,10 @@ packages:
|
||||
/ms@2.1.3:
|
||||
resolution: {integrity: sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==}
|
||||
|
||||
/mtwist@1.0.2:
|
||||
resolution: {integrity: sha512-eRsSga5jkLg7nNERPOV8vDNxgSwuEcj5upQfJcT0gXfJwXo3pMc7xOga0fu8rXHyrxzl7GFVWWDuaPQgpKDvgw==}
|
||||
dev: false
|
||||
|
||||
/muggle-string@0.3.1:
|
||||
resolution: {integrity: sha512-ckmWDJjphvd/FvZawgygcUeQCxzvohjFO5RxTjj4eq8kw359gFF3E1brjfI+viLMxss5JrHTDRHZvu2/tuy0Qg==}
|
||||
dev: true
|
||||
|
||||
@@ -185,7 +185,8 @@
|
||||
"min": "Min",
|
||||
"max": "Max",
|
||||
"values": "Values",
|
||||
"resetToDefaults": "Reset to Defaults"
|
||||
"resetToDefaults": "Reset to Defaults",
|
||||
"seed": "Seed"
|
||||
},
|
||||
"hrf": {
|
||||
"hrf": "High Resolution Fix",
|
||||
|
||||
@@ -16,7 +16,7 @@ import {
|
||||
getFloatGeneratorDefaults,
|
||||
resolveFloatGeneratorField,
|
||||
} from 'features/nodes/types/field';
|
||||
import { round } from 'lodash-es';
|
||||
import { isNil, round } from 'lodash-es';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
@@ -63,7 +63,10 @@ export const FloatGeneratorFieldInputComponent = memo(
|
||||
|
||||
const [debouncedField] = useDebounce(field, 300);
|
||||
const resolvedValuesAsString = useMemo(() => {
|
||||
if (debouncedField.value.type === FloatGeneratorUniformRandomDistributionType) {
|
||||
if (
|
||||
debouncedField.value.type === FloatGeneratorUniformRandomDistributionType &&
|
||||
isNil(debouncedField.value.seed)
|
||||
) {
|
||||
const { count } = debouncedField.value;
|
||||
return `<${t('nodes.generatorNRandomValues', { count })}>`;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { Checkbox, CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import type { FloatGeneratorUniformRandomDistribution } from 'features/nodes/types/field';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@@ -29,21 +30,47 @@ export const FloatGeneratorUniformRandomDistributionSettings = memo(
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onToggleSeed = useCallback(() => {
|
||||
onChange({ ...state, seed: isNil(state.seed) ? 0 : null });
|
||||
}, [onChange, state]);
|
||||
const onChangeSeed = useCallback(
|
||||
(seed?: number | null) => {
|
||||
onChange({ ...state, seed });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex gap={2} alignItems="flex-end">
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.min')}</FormLabel>
|
||||
<CompositeNumberInput value={state.min} onChange={onChangeMin} min={-Infinity} max={Infinity} step={0.01} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.max')}</FormLabel>
|
||||
<CompositeNumberInput value={state.max} onChange={onChangeMax} min={-Infinity} max={Infinity} step={0.01} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.count')}</FormLabel>
|
||||
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
|
||||
</FormControl>
|
||||
<Flex gap={2} flexDir="column">
|
||||
<Flex gap={2} alignItems="flex-end">
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.min')}</FormLabel>
|
||||
<CompositeNumberInput value={state.min} onChange={onChangeMin} min={-Infinity} max={Infinity} step={0.01} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.max')}</FormLabel>
|
||||
<CompositeNumberInput value={state.max} onChange={onChangeMax} min={-Infinity} max={Infinity} step={0.01} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.count')}</FormLabel>
|
||||
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel alignItems="center" justifyContent="space-between" m={0} display="flex" w="full">
|
||||
{t('common.seed')}
|
||||
<Checkbox onChange={onToggleSeed} isChecked={!isNil(state.seed)} />
|
||||
</FormLabel>
|
||||
<CompositeNumberInput
|
||||
isDisabled={isNil(state.seed)}
|
||||
// This cast is save only because we disable the element when seed is not a number - the `...` is
|
||||
// rendered in the input field in this case
|
||||
value={state.seed ?? ('...' as unknown as number)}
|
||||
onChange={onChangeSeed}
|
||||
min={-Infinity}
|
||||
max={Infinity}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ import {
|
||||
IntegerGeneratorUniformRandomDistributionType,
|
||||
resolveIntegerGeneratorField,
|
||||
} from 'features/nodes/types/field';
|
||||
import { round } from 'lodash-es';
|
||||
import { isNil, round } from 'lodash-es';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
@@ -65,7 +65,10 @@ export const IntegerGeneratorFieldInputComponent = memo(
|
||||
|
||||
const [debouncedField] = useDebounce(field, 300);
|
||||
const resolvedValuesAsString = useMemo(() => {
|
||||
if (debouncedField.value.type === IntegerGeneratorUniformRandomDistributionType) {
|
||||
if (
|
||||
debouncedField.value.type === IntegerGeneratorUniformRandomDistributionType &&
|
||||
isNil(debouncedField.value.seed)
|
||||
) {
|
||||
const { count } = debouncedField.value;
|
||||
return `<${t('nodes.generatorNRandomValues', { count })}>`;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { Checkbox, CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import type { IntegerGeneratorUniformRandomDistribution } from 'features/nodes/types/field';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@@ -29,21 +30,47 @@ export const IntegerGeneratorUniformRandomDistributionSettings = memo(
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onToggleSeed = useCallback(() => {
|
||||
onChange({ ...state, seed: isNil(state.seed) ? 0 : null });
|
||||
}, [onChange, state]);
|
||||
const onChangeSeed = useCallback(
|
||||
(seed?: number | null) => {
|
||||
onChange({ ...state, seed });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex gap={2} alignItems="flex-end">
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.min')}</FormLabel>
|
||||
<CompositeNumberInput value={state.min} onChange={onChangeMin} min={-Infinity} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.max')}</FormLabel>
|
||||
<CompositeNumberInput value={state.max} onChange={onChangeMax} min={-Infinity} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.count')}</FormLabel>
|
||||
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
|
||||
</FormControl>
|
||||
<Flex gap={2} flexDir="column">
|
||||
<Flex gap={2} alignItems="flex-end">
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.min')}</FormLabel>
|
||||
<CompositeNumberInput value={state.min} onChange={onChangeMin} min={-Infinity} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.max')}</FormLabel>
|
||||
<CompositeNumberInput value={state.max} onChange={onChangeMax} min={-Infinity} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.count')}</FormLabel>
|
||||
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel alignItems="center" justifyContent="space-between" m={0} display="flex" w="full">
|
||||
{t('common.seed')}
|
||||
<Checkbox onChange={onToggleSeed} isChecked={!isNil(state.seed)} />
|
||||
</FormLabel>
|
||||
<CompositeNumberInput
|
||||
isDisabled={isNil(state.seed)}
|
||||
// This cast is save only because we disable the element when seed is not a number - the `...` is
|
||||
// rendered in the input field in this case
|
||||
value={state.seed ?? ('...' as unknown as number)}
|
||||
onChange={onChangeSeed}
|
||||
min={-Infinity}
|
||||
max={Infinity}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { buildTypeGuard } from 'features/parameters/types/parameterSchemas';
|
||||
import { trim } from 'lodash-es';
|
||||
import { isNil, trim } from 'lodash-es';
|
||||
import MersenneTwister from 'mtwist';
|
||||
import { assert } from 'tsafe';
|
||||
import { z } from 'zod';
|
||||
|
||||
@@ -1057,13 +1058,22 @@ const zFloatGeneratorUniformRandomDistribution = z.object({
|
||||
min: z.number().default(0),
|
||||
max: z.number().default(1),
|
||||
count: z.number().int().default(10),
|
||||
seed: z.number().int().nullish(),
|
||||
values: z.array(z.number()).nullish(),
|
||||
});
|
||||
export type FloatGeneratorUniformRandomDistribution = z.infer<typeof zFloatGeneratorUniformRandomDistribution>;
|
||||
const getFloatGeneratorUniformRandomDistributionDefaults = () => zFloatGeneratorUniformRandomDistribution.parse({});
|
||||
const getRng = (seed?: number | null) => {
|
||||
if (isNil(seed)) {
|
||||
return () => Math.random();
|
||||
}
|
||||
const m = new MersenneTwister(seed);
|
||||
return () => m.random();
|
||||
};
|
||||
const getFloatGeneratorUniformRandomDistributionValues = (generator: FloatGeneratorUniformRandomDistribution) => {
|
||||
const { min, max, count } = generator;
|
||||
const values = Array.from({ length: count }, () => Math.random() * (max - min) + min);
|
||||
const { min, max, count, seed } = generator;
|
||||
const rng = getRng(seed);
|
||||
const values = Array.from({ length: count }, (_) => rng() * (max - min) + min);
|
||||
return values;
|
||||
};
|
||||
|
||||
@@ -1191,13 +1201,15 @@ const zIntegerGeneratorUniformRandomDistribution = z.object({
|
||||
min: z.number().int().default(0),
|
||||
max: z.number().int().default(10),
|
||||
count: z.number().int().default(10),
|
||||
seed: z.number().int().nullish(),
|
||||
values: z.array(z.number().int()).nullish(),
|
||||
});
|
||||
export type IntegerGeneratorUniformRandomDistribution = z.infer<typeof zIntegerGeneratorUniformRandomDistribution>;
|
||||
const getIntegerGeneratorUniformRandomDistributionDefaults = () => zIntegerGeneratorUniformRandomDistribution.parse({});
|
||||
const getIntegerGeneratorUniformRandomDistributionValues = (generator: IntegerGeneratorUniformRandomDistribution) => {
|
||||
const { min, max, count } = generator;
|
||||
const values = Array.from({ length: count }, () => Math.floor(Math.random() * (max - min + 1)) + min);
|
||||
const { min, max, count, seed } = generator;
|
||||
const rng = getRng(seed);
|
||||
const values = Array.from({ length: count }, () => Math.floor(rng() * (max - min + 1)) + min);
|
||||
return values;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user