feat(engine): add support for MLX AI provider (#437)

* docs(CONTRIBUTING.md): update `TODO.md` reference (#435)

Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>

* feat(engine): add support for MLX AI provider
docs/engine: update documentation to include new engine providers

* fix(mlx.ts): add repetition_penalty option to generateCommitMessage method for improved model behavior

---------

Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
Co-authored-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
This commit is contained in:
albi ⚡️
2024-12-09 11:02:38 +01:00
committed by GitHub
parent dd65b9c3e3
commit 26ebfb416d
7 changed files with 145 additions and 16 deletions

View File

@@ -18,7 +18,7 @@ To get started, follow these steps:
1. Clone the project repository locally.
2. Install dependencies with `npm install`.
3. Run the project with `npm run dev`.
4. See [issues](https://github.com/di-sukharev/opencommit/issues) or [TODO.md](../TODO.md) to help the project.
4. See [issues](https://github.com/di-sukharev/opencommit/issues) or [TODO.md](TODO.md) to help the project.
## Commit message guidelines

View File

@@ -431,8 +431,8 @@ var require_escape = __commonJS({
}
function escapeArgument(arg, doubleEscapeMetaChars) {
arg = `${arg}`;
arg = arg.replace(/(\\*)"/g, '$1$1\\"');
arg = arg.replace(/(\\*)$/, "$1$1");
arg = arg.replace(/(?=(\\+?)?)\1"/g, '$1$1\\"');
arg = arg.replace(/(?=(\\+?)?)\1$/, "$1$1");
arg = `"${arg}"`;
arg = arg.replace(metaCharsRegExp, "^$1");
if (doubleEscapeMetaChars) {
@@ -578,7 +578,7 @@ var require_enoent = __commonJS({
const originalEmit = cp.emit;
cp.emit = function(name, arg1) {
if (name === "exit") {
const err = verifyENOENT(arg1, parsed, "spawn");
const err = verifyENOENT(arg1, parsed);
if (err) {
return originalEmit.call(cp, "error", err);
}
@@ -27389,7 +27389,8 @@ var package_default = {
"test:unit:docker": "npm run test:docker-build && DOCKER_CONTENT_TRUST=0 docker run --rm oco-test npm run test:unit",
"test:e2e": "npm run test:e2e:setup && jest test/e2e",
"test:e2e:setup": "sh test/e2e/setup.sh",
"test:e2e:docker": "npm run test:docker-build && DOCKER_CONTENT_TRUST=0 docker run --rm oco-test npm run test:e2e"
"test:e2e:docker": "npm run test:docker-build && DOCKER_CONTENT_TRUST=0 docker run --rm oco-test npm run test:e2e",
"mlx:start": "OCO_AI_PROVIDER='mlx' node ./out/cli.cjs"
},
devDependencies: {
"@commitlint/types": "^17.4.4",
@@ -29933,6 +29934,8 @@ var getDefaultModel = (provider) => {
switch (provider) {
case "ollama":
return "";
case "mlx":
return "";
case "anthropic":
return MODEL_LIST.anthropic[0];
case "gemini":
@@ -29964,7 +29967,7 @@ var configValidators = {
validateConfig(
"OCO_API_KEY",
value,
'You need to provide the OCO_API_KEY when OCO_AI_PROVIDER set to "openai" (default) or "ollama" or "azure" or "gemini" or "flowise" or "anthropic". Run `oco config set OCO_API_KEY=your_key OCO_AI_PROVIDER=openai`'
'You need to provide the OCO_API_KEY when OCO_AI_PROVIDER set to "openai" (default) or "ollama" or "mlx" or "azure" or "gemini" or "flowise" or "anthropic". Run `oco config set OCO_API_KEY=your_key OCO_AI_PROVIDER=openai`'
);
return value;
},
@@ -30070,8 +30073,8 @@ var configValidators = {
"test",
"flowise",
"groq"
].includes(value) || value.startsWith("ollama"),
`${value} is not supported yet, use 'ollama', 'anthropic', 'azure', 'gemini', 'flowise' or 'openai' (default)`
].includes(value) || value.startsWith("ollama") || value.startsWith("mlx"),
`${value} is not supported yet, use 'ollama', 'mlx', anthropic', 'azure', 'gemini', 'flowise' or 'openai' (default)`
);
return value;
},
@@ -30111,6 +30114,7 @@ var OCO_AI_PROVIDER_ENUM = /* @__PURE__ */ ((OCO_AI_PROVIDER_ENUM2) => {
OCO_AI_PROVIDER_ENUM2["TEST"] = "test";
OCO_AI_PROVIDER_ENUM2["FLOWISE"] = "flowise";
OCO_AI_PROVIDER_ENUM2["GROQ"] = "groq";
OCO_AI_PROVIDER_ENUM2["MLX"] = "mlx";
return OCO_AI_PROVIDER_ENUM2;
})(OCO_AI_PROVIDER_ENUM || {});
var defaultConfigPath = (0, import_path.join)((0, import_os.homedir)(), ".opencommit");
@@ -44524,6 +44528,38 @@ var GroqEngine = class extends OpenAiEngine {
}
};
// src/engine/mlx.ts
var MLXEngine = class {
constructor(config7) {
this.config = config7;
this.client = axios_default.create({
url: config7.baseURL ? `${config7.baseURL}/${config7.apiKey}` : "http://localhost:8080/v1/chat/completions",
headers: { "Content-Type": "application/json" }
});
}
async generateCommitMessage(messages) {
const params = {
messages,
temperature: 0,
top_p: 0.1,
repetition_penalty: 1.5,
stream: false
};
try {
const response = await this.client.post(
this.client.getUri(this.config),
params
);
const choices = response.data.choices;
const message = choices[0].message;
return message?.content;
} catch (err) {
const message = err.response?.data?.error ?? err.message;
throw new Error(`MLX provider error: ${message}`);
}
}
};
// src/utils/engine.ts
function getEngine() {
const config7 = getConfig();
@@ -44550,6 +44586,8 @@ function getEngine() {
return new FlowiseEngine(DEFAULT_CONFIG2);
case "groq" /* GROQ */:
return new GroqEngine(DEFAULT_CONFIG2);
case "mlx" /* MLX */:
return new MLXEngine(DEFAULT_CONFIG2);
default:
return new OpenAiEngine(DEFAULT_CONFIG2);
}

View File

@@ -48745,6 +48745,8 @@ var getDefaultModel = (provider) => {
switch (provider) {
case "ollama":
return "";
case "mlx":
return "";
case "anthropic":
return MODEL_LIST.anthropic[0];
case "gemini":
@@ -48776,7 +48778,7 @@ var configValidators = {
validateConfig(
"OCO_API_KEY",
value,
'You need to provide the OCO_API_KEY when OCO_AI_PROVIDER set to "openai" (default) or "ollama" or "azure" or "gemini" or "flowise" or "anthropic". Run `oco config set OCO_API_KEY=your_key OCO_AI_PROVIDER=openai`'
'You need to provide the OCO_API_KEY when OCO_AI_PROVIDER set to "openai" (default) or "ollama" or "mlx" or "azure" or "gemini" or "flowise" or "anthropic". Run `oco config set OCO_API_KEY=your_key OCO_AI_PROVIDER=openai`'
);
return value;
},
@@ -48882,8 +48884,8 @@ var configValidators = {
"test",
"flowise",
"groq"
].includes(value) || value.startsWith("ollama"),
`${value} is not supported yet, use 'ollama', 'anthropic', 'azure', 'gemini', 'flowise' or 'openai' (default)`
].includes(value) || value.startsWith("ollama") || value.startsWith("mlx"),
`${value} is not supported yet, use 'ollama', 'mlx', anthropic', 'azure', 'gemini', 'flowise' or 'openai' (default)`
);
return value;
},
@@ -63325,6 +63327,38 @@ var GroqEngine = class extends OpenAiEngine {
}
};
// src/engine/mlx.ts
var MLXEngine = class {
constructor(config6) {
this.config = config6;
this.client = axios_default.create({
url: config6.baseURL ? `${config6.baseURL}/${config6.apiKey}` : "http://localhost:8080/v1/chat/completions",
headers: { "Content-Type": "application/json" }
});
}
async generateCommitMessage(messages) {
const params = {
messages,
temperature: 0,
top_p: 0.1,
repetition_penalty: 1.5,
stream: false
};
try {
const response = await this.client.post(
this.client.getUri(this.config),
params
);
const choices = response.data.choices;
const message = choices[0].message;
return message?.content;
} catch (err) {
const message = err.response?.data?.error ?? err.message;
throw new Error(`MLX provider error: ${message}`);
}
}
};
// src/utils/engine.ts
function getEngine() {
const config6 = getConfig();
@@ -63351,6 +63385,8 @@ function getEngine() {
return new FlowiseEngine(DEFAULT_CONFIG2);
case "groq" /* GROQ */:
return new GroqEngine(DEFAULT_CONFIG2);
case "mlx" /* MLX */:
return new MLXEngine(DEFAULT_CONFIG2);
default:
return new OpenAiEngine(DEFAULT_CONFIG2);
}

View File

@@ -58,7 +58,8 @@
"test:unit:docker": "npm run test:docker-build && DOCKER_CONTENT_TRUST=0 docker run --rm oco-test npm run test:unit",
"test:e2e": "npm run test:e2e:setup && jest test/e2e",
"test:e2e:setup": "sh test/e2e/setup.sh",
"test:e2e:docker": "npm run test:docker-build && DOCKER_CONTENT_TRUST=0 docker run --rm oco-test npm run test:e2e"
"test:e2e:docker": "npm run test:docker-build && DOCKER_CONTENT_TRUST=0 docker run --rm oco-test npm run test:e2e",
"mlx:start": "OCO_AI_PROVIDER='mlx' node ./out/cli.cjs"
},
"devDependencies": {
"@commitlint/types": "^17.4.4",

View File

@@ -93,6 +93,8 @@ const getDefaultModel = (provider: string | undefined): string => {
switch (provider) {
case 'ollama':
return '';
case 'mlx':
return '';
case 'anthropic':
return MODEL_LIST.anthropic[0];
case 'gemini':
@@ -138,7 +140,7 @@ export const configValidators = {
validateConfig(
'OCO_API_KEY',
value,
'You need to provide the OCO_API_KEY when OCO_AI_PROVIDER set to "openai" (default) or "ollama" or "azure" or "gemini" or "flowise" or "anthropic". Run `oco config set OCO_API_KEY=your_key OCO_AI_PROVIDER=openai`'
'You need to provide the OCO_API_KEY when OCO_AI_PROVIDER set to "openai" (default) or "ollama" or "mlx" or "azure" or "gemini" or "flowise" or "anthropic". Run `oco config set OCO_API_KEY=your_key OCO_AI_PROVIDER=openai`'
);
return value;
@@ -261,8 +263,8 @@ export const configValidators = {
'test',
'flowise',
'groq'
].includes(value) || value.startsWith('ollama'),
`${value} is not supported yet, use 'ollama', 'anthropic', 'azure', 'gemini', 'flowise' or 'openai' (default)`
].includes(value) || value.startsWith('ollama') || value.startsWith('mlx'),
`${value} is not supported yet, use 'ollama', 'mlx', anthropic', 'azure', 'gemini', 'flowise' or 'openai' (default)`
);
return value;
@@ -307,7 +309,8 @@ export enum OCO_AI_PROVIDER_ENUM {
AZURE = 'azure',
TEST = 'test',
FLOWISE = 'flowise',
GROQ = 'groq'
GROQ = 'groq',
MLX = 'mlx'
}
export type ConfigType = {

47
src/engine/mlx.ts Normal file
View File

@@ -0,0 +1,47 @@
import axios, { AxiosInstance } from 'axios';
import { OpenAI } from 'openai';
import { AiEngine, AiEngineConfig } from './Engine';
import { chown } from 'fs';
interface MLXConfig extends AiEngineConfig {}
export class MLXEngine implements AiEngine {
config: MLXConfig;
client: AxiosInstance;
constructor(config) {
this.config = config;
this.client = axios.create({
url: config.baseURL
? `${config.baseURL}/${config.apiKey}`
: 'http://localhost:8080/v1/chat/completions',
headers: { 'Content-Type': 'application/json' }
});
}
async generateCommitMessage(
messages: Array<OpenAI.Chat.Completions.ChatCompletionMessageParam>):
Promise<string | undefined> {
const params = {
messages,
temperature: 0,
top_p: 0.1,
repetition_penalty: 1.5,
stream: false
};
try {
const response = await this.client.post(
this.client.getUri(this.config),
params
);
const choices = response.data.choices;
const message = choices[0].message;
return message?.content;
} catch (err: any) {
const message = err.response?.data?.error ?? err.message;
throw new Error(`MLX provider error: ${message}`);
}
}
}

View File

@@ -8,6 +8,7 @@ import { OllamaEngine } from '../engine/ollama';
import { OpenAiEngine } from '../engine/openAi';
import { TestAi, TestMockType } from '../engine/testAi';
import { GroqEngine } from '../engine/groq';
import { MLXEngine } from '../engine/mlx';
export function getEngine(): AiEngine {
const config = getConfig();
@@ -43,6 +44,9 @@ export function getEngine(): AiEngine {
case OCO_AI_PROVIDER_ENUM.GROQ:
return new GroqEngine(DEFAULT_CONFIG);
case OCO_AI_PROVIDER_ENUM.MLX:
return new MLXEngine(DEFAULT_CONFIG);
default:
return new OpenAiEngine(DEFAULT_CONFIG);
}