refactor models for better modularity, use sweaterdog ollama/local

This commit is contained in:
MaxRobinsonTheGreat 2025-08-20 13:08:59 -05:00
parent 0cd4dcd420
commit 2a38d310fc
4 changed files with 150 additions and 179 deletions

88
src/models/_model_map.js Normal file
View file

@ -0,0 +1,88 @@
import { Gemini } from './gemini.js';
import { GPT } from './gpt.js';
import { Claude } from './claude.js';
import { Mistral } from './mistral.js';
import { ReplicateAPI } from './replicate.js';
import { Ollama } from './ollama.js';
import { Novita } from './novita.js';
import { GroqCloudAPI } from './groq.js';
import { HuggingFace } from './huggingface.js';
import { Qwen } from "./qwen.js";
import { Grok } from "./grok.js";
import { DeepSeek } from './deepseek.js';
import { Hyperbolic } from './hyperbolic.js';
import { GLHF } from './glhf.js';
import { OpenRouter } from './openrouter.js';
import { VLLM } from './vllm.js';
// Registry of supported model backends; register new model classes here.
// Each key is the api prefix used in a model string and each value is the class
// that serves it, e.g. 'openai/gpt-4o' resolves to GPT.
const apiMap = {
    openai: GPT,
    google: Gemini,
    anthropic: Claude,
    replicate: ReplicateAPI,
    ollama: Ollama,
    mistral: Mistral,
    groq: GroqCloudAPI,
    huggingface: HuggingFace,
    novita: Novita,
    qwen: Qwen,
    grok: Grok,
    deepseek: DeepSeek,
    hyperbolic: Hyperbolic,
    glhf: GLHF,
    openrouter: OpenRouter,
    vllm: VLLM,
};
/**
 * Decide which api backend a model profile belongs to and normalise its model name.
 *
 * Resolution order when the profile does not already name an api:
 *   1. the model string starts with one of the keys of `apiMap` (e.g. 'openai/gpt-4o');
 *   2. legacy 'local' models are mapped to 'ollama';
 *   3. a few well-known model families that are commonly written without a prefix.
 * An api that is set explicitly on the profile is always respected, so that e.g.
 * `{api: 'ollama', model: 'mistral'}` is not rerouted to the Mistral backend.
 *
 * @param {string|String|Object} profile - a model string, or a profile object with an
 *   optional `model` string and optional `api`, `url` and `params` fields.
 * @returns {Object} the profile object (mutated in place when an object was passed) with
 *   `api` set and `model` stripped of its leading '<api>/' prefix; `model` is null when
 *   no specific model name remains.
 * @throws {Error} if the profile is not a string/object or no api can be determined.
 */
export function selectAPI(profile) {
    if (typeof profile === 'string' || profile instanceof String) {
        profile = {model: String(profile)};
    }
    if (profile === null || typeof profile !== 'object') {
        throw new Error(`Unknown model: ${profile}`);
    }

    // A profile may carry only an api (no model); treat that as an empty model name
    // instead of crashing on `undefined.startsWith`.
    let modelName = profile.model == null ? '' : String(profile.model);

    if (!profile.api) {
        const prefixedApi = Object.keys(apiMap).find(key => modelName.startsWith(key));
        if (prefixedApi) {
            profile.api = prefixedApi;
        }
        else if (modelName.includes('local')) {
            // backwards compatibility with local->ollama
            profile.api = 'ollama';
            modelName = modelName.replace('local/', '');
            // a bare 'local' means "use the default ollama model"
            if (modelName === 'local') {
                modelName = '';
            }
        }
        // check for some common models that do not require prefixes
        else if (modelName.includes('gpt') || modelName.includes('o1') || modelName.includes('o3')) {
            profile.api = 'openai';
        }
        else if (modelName.includes('claude')) {
            profile.api = 'anthropic';
        }
        else if (modelName.includes('gemini')) {
            profile.api = 'google';
        }
        else if (modelName.includes('grok')) {
            profile.api = 'grok';
        }
        else if (modelName.includes('mistral')) {
            profile.api = 'mistral';
        }
        else if (modelName.includes('deepseek')) {
            profile.api = 'deepseek';
        }
        else if (modelName.includes('qwen')) {
            profile.api = 'qwen';
        }
    }

    if (!profile.api) {
        // Error() ignores extra positional arguments, so the name must be part of the message.
        throw new Error(`Unknown model: ${modelName}`);
    }

    // Remove the api prefix only when it actually leads the string.
    const apiPrefix = profile.api + '/';
    if (modelName.startsWith(apiPrefix)) {
        modelName = modelName.slice(apiPrefix.length);
    }
    profile.model = modelName === '' ? null : modelName; // if model is empty, set to null
    return profile;
}
/**
 * Instantiate the model class registered in `apiMap` for a resolved profile.
 *
 * @param {Object} profile - profile with an `api` key of `apiMap`, plus optional
 *   `model`, `url` and `params`, which are passed to the model class constructor.
 *   If `model` is itself an api name it is reset to null (mutating the profile) so
 *   the class falls back to its default model.
 * @returns {Object} a new instance of the class registered for `profile.api`.
 * @throws {Error} if `profile.api` is not a key of `apiMap`.
 */
export function createModel(profile) {
    // Own-property check: a plain `apiMap[name]` lookup would also match inherited
    // members such as 'constructor' or 'toString'.
    const isRegisteredApi = name => Object.prototype.hasOwnProperty.call(apiMap, name);

    if (isRegisteredApi(profile.model)) {
        // if the model value is an api (instead of a specific model name)
        // then set model to null so it uses the default model for that api
        profile.model = null;
    }
    if (!isRegisteredApi(profile.api)) {
        // Error() ignores extra positional arguments, so the name must be part of the message.
        throw new Error(`Unknown api: ${profile.api}`);
    }
    const ModelClass = apiMap[profile.api];
    return new ModelClass(profile.model, profile.url, profile.params);
}

View file

@ -142,15 +142,15 @@ export class Gemini {
}
async embed(text) {
let model;
let model = this.model_name || "text-embedding-004";
if (this.url) {
model = this.genAI.getGenerativeModel(
{ model: "text-embedding-004" },
{ model },
{ baseUrl: this.url }
);
} else {
model = this.genAI.getGenerativeModel(
{ model: "text-embedding-004" }
{ model }
);
}

View file

@ -1,6 +1,6 @@
import { strictFormat } from '../utils/text.js';
export class Local {
export class Ollama {
constructor(model_name, url, params) {
this.model_name = model_name;
this.params = params;
@ -10,11 +10,9 @@ export class Local {
}
async sendRequest(turns, systemMessage) {
let model = this.model_name || 'llama3.1'; // Updated to llama3.1, as it is more performant than llama3
let model = this.model_name || 'sweaterdog/andy-4:micro-q5_k_m';
let messages = strictFormat(turns);
messages.unshift({ role: 'system', content: systemMessage });
// We'll attempt up to 5 times for models with deepseek-r1-esque reasoning if the <think> tags are mismatched.
const maxAttempts = 5;
let attempt = 0;
let finalRes = null;
@ -24,14 +22,14 @@ export class Local {
console.log(`Awaiting local response... (model: ${model}, attempt: ${attempt})`);
let res = null;
try {
res = await this.send(this.chat_endpoint, {
let apiResponse = await this.send(this.chat_endpoint, {
model: model,
messages: messages,
stream: false,
...(this.params || {})
});
if (res) {
res = res['message']['content'];
if (apiResponse) {
res = apiResponse['message']['content'];
} else {
res = 'No response data.';
}
@ -43,36 +41,27 @@ export class Local {
console.log(err);
res = 'My brain disconnected, try again.';
}
}
// If the model name includes "deepseek-r1" or "Andy-3.5-reasoning", then handle the <think> block.
const hasOpenTag = res.includes("<think>");
const hasCloseTag = res.includes("</think>");
// If there's a partial mismatch, retry to get a complete response.
if ((hasOpenTag && !hasCloseTag)) {
console.warn("Partial <think> block detected. Re-generating...");
continue;
}
// If </think> is present but <think> is not, prepend <think>
if (hasCloseTag && !hasOpenTag) {
res = '<think>' + res;
}
// Changed this so that if the model reasons using <think> and </think> but doesn't start the message with <think>, <think> gets prepended to the message so no error occurs.
// If both tags appear, remove them (and everything inside).
if (hasOpenTag && hasCloseTag) {
res = res.replace(/<think>[\s\S]*?<\/think>/g, '');
}
const hasOpenTag = res.includes("<think>");
const hasCloseTag = res.includes("</think>");
if ((hasOpenTag && !hasCloseTag)) {
console.warn("Partial <think> block detected. Re-generating...");
if (attempt < maxAttempts) continue;
}
if (hasCloseTag && !hasOpenTag) {
res = '<think>' + res;
}
if (hasOpenTag && hasCloseTag) {
res = res.replace(/<think>[\s\S]*?<\/think>/g, '').trim();
}
finalRes = res;
break; // Exit the loop if we got a valid response.
break;
}
if (finalRes == null) {
console.warn("Could not get a valid <think> block or normal response after max attempts.");
console.warn("Could not get a valid response after max attempts.");
finalRes = 'I thought too hard, sorry, try again.';
}
return finalRes;
@ -104,4 +93,22 @@ export class Local {
}
return data;
}
async sendVisionRequest(messages, systemMessage, imageBuffer) {
const imageMessages = [...messages];
imageMessages.push({
role: "user",
content: [
{ type: "text", text: systemMessage },
{
type: "image_url",
image_url: {
url: `data:image/jpeg;base64,${imageBuffer.toString('base64')}`
}
}
]
});
return this.sendRequest(imageMessages, systemMessage);
}
}

View file

@ -5,26 +5,10 @@ import { SkillLibrary } from "../agent/library/skill_library.js";
import { stringifyTurns } from '../utils/text.js';
import { getCommand } from '../agent/commands/index.js';
import settings from '../agent/settings.js';
import { Gemini } from './gemini.js';
import { GPT } from './gpt.js';
import { Claude } from './claude.js';
import { Mistral } from './mistral.js';
import { ReplicateAPI } from './replicate.js';
import { Local } from './local.js';
import { Novita } from './novita.js';
import { GroqCloudAPI } from './groq.js';
import { HuggingFace } from './huggingface.js';
import { Qwen } from "./qwen.js";
import { Grok } from "./grok.js";
import { DeepSeek } from './deepseek.js';
import { Hyperbolic } from './hyperbolic.js';
import { GLHF } from './glhf.js';
import { OpenRouter } from './openrouter.js';
import { VLLM } from './vllm.js';
import { promises as fs } from 'fs';
import path from 'path';
import { fileURLToPath } from 'url';
import { selectAPI, createModel } from './_model_map.js';
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);
@ -66,70 +50,46 @@ export class Prompter {
this.last_prompt_time = 0;
this.awaiting_coding = false;
// try to get "max_tokens" parameter, else null
// for backwards compatibility, move max_tokens to params
let max_tokens = null;
if (this.profile.max_tokens)
max_tokens = this.profile.max_tokens;
let chat_model_profile = this._selectAPI(this.profile.model);
this.chat_model = this._createModel(chat_model_profile);
let chat_model_profile = selectAPI(this.profile.model);
this.chat_model = createModel(chat_model_profile);
if (this.profile.code_model) {
let code_model_profile = this._selectAPI(this.profile.code_model);
this.code_model = this._createModel(code_model_profile);
let code_model_profile = selectAPI(this.profile.code_model);
this.code_model = createModel(code_model_profile);
}
else {
this.code_model = this.chat_model;
}
if (this.profile.vision_model) {
let vision_model_profile = this._selectAPI(this.profile.vision_model);
this.vision_model = this._createModel(vision_model_profile);
let vision_model_profile = selectAPI(this.profile.vision_model);
this.vision_model = createModel(vision_model_profile);
}
else {
this.vision_model = this.chat_model;
}
let embedding = this.profile.embedding;
if (embedding === undefined) {
if (chat_model_profile.api !== 'ollama')
embedding = {api: chat_model_profile.api};
else
embedding = {api: 'none'};
}
else if (typeof embedding === 'string' || embedding instanceof String)
embedding = {api: embedding};
console.log('Using embedding settings:', embedding);
try {
if (embedding.api === 'google')
this.embedding_model = new Gemini(embedding.model, embedding.url);
else if (embedding.api === 'openai')
this.embedding_model = new GPT(embedding.model, embedding.url);
else if (embedding.api === 'replicate')
this.embedding_model = new ReplicateAPI(embedding.model, embedding.url);
else if (embedding.api === 'ollama')
this.embedding_model = new Local(embedding.model, embedding.url);
else if (embedding.api === 'qwen')
this.embedding_model = new Qwen(embedding.model, embedding.url);
else if (embedding.api === 'mistral')
this.embedding_model = new Mistral(embedding.model, embedding.url);
else if (embedding.api === 'huggingface')
this.embedding_model = new HuggingFace(embedding.model, embedding.url);
else if (embedding.api === 'novita')
this.embedding_model = new Novita(embedding.model, embedding.url);
else {
this.embedding_model = null;
let embedding_name = embedding ? embedding.api : '[NOT SPECIFIED]'
console.warn('Unsupported embedding: ' + embedding_name + '. Using word-overlap instead, expect reduced performance. Recommend using a supported embedding model. See Readme.');
let embedding_model_profile = null;
if (this.profile.embedding) {
try {
embedding_model_profile = selectAPI(this.profile.embedding);
} catch (e) {
embedding_model_profile = null;
}
}
catch (err) {
console.warn('Warning: Failed to initialize embedding model:', err.message);
console.log('Continuing anyway, using word-overlap instead.');
this.embedding_model = null;
if (embedding_model_profile) {
this.embedding_model = createModel(embedding_model_profile);
}
else {
this.embedding_model = createModel({api: chat_model_profile.api});
}
this.skill_libary = new SkillLibrary(agent, this.embedding_model);
mkdirSync(`./bots/${name}`, { recursive: true });
writeFileSync(`./bots/${name}/last_profile.json`, JSON.stringify(this.profile, null, 4), (err) => {
@ -140,88 +100,6 @@ export class Prompter {
});
}
_selectAPI(profile) {
if (typeof profile === 'string' || profile instanceof String) {
profile = {model: profile};
}
if (!profile.api) {
if (profile.model.includes('openrouter/'))
profile.api = 'openrouter'; // must do first because shares names with other models
else if (profile.model.includes('ollama/'))
profile.api = 'ollama'; // also must do early because shares names with other models
else if (profile.model.includes('gemini'))
profile.api = 'google';
else if (profile.model.includes('vllm/'))
profile.api = 'vllm';
else if (profile.model.includes('gpt') || profile.model.includes('o1')|| profile.model.includes('o3'))
profile.api = 'openai';
else if (profile.model.includes('claude'))
profile.api = 'anthropic';
else if (profile.model.includes('huggingface/'))
profile.api = "huggingface";
else if (profile.model.includes('replicate/'))
profile.api = 'replicate';
else if (profile.model.includes('mistralai/') || profile.model.includes("mistral/"))
model_profile.api = 'mistral';
else if (profile.model.includes("groq/") || profile.model.includes("groqcloud/"))
profile.api = 'groq';
else if (profile.model.includes("glhf/"))
profile.api = 'glhf';
else if (profile.model.includes("hyperbolic/"))
profile.api = 'hyperbolic';
else if (profile.model.includes('novita/'))
profile.api = 'novita';
else if (profile.model.includes('qwen'))
profile.api = 'qwen';
else if (profile.model.includes('grok'))
profile.api = 'xai';
else if (profile.model.includes('deepseek'))
profile.api = 'deepseek';
else if (profile.model.includes('mistral'))
profile.api = 'mistral';
else
throw new Error('Unknown model:', profile.model);
}
return profile;
}
_createModel(profile) {
let model = null;
if (profile.api === 'google')
model = new Gemini(profile.model, profile.url, profile.params);
else if (profile.api === 'openai')
model = new GPT(profile.model, profile.url, profile.params);
else if (profile.api === 'anthropic')
model = new Claude(profile.model, profile.url, profile.params);
else if (profile.api === 'replicate')
model = new ReplicateAPI(profile.model.replace('replicate/', ''), profile.url, profile.params);
else if (profile.api === 'ollama')
model = new Local(profile.model.replace('ollama/', ''), profile.url, profile.params);
else if (profile.api === 'mistral')
model = new Mistral(profile.model, profile.url, profile.params);
else if (profile.api === 'groq')
model = new GroqCloudAPI(profile.model.replace('groq/', '').replace('groqcloud/', ''), profile.url, profile.params);
else if (profile.api === 'huggingface')
model = new HuggingFace(profile.model, profile.url, profile.params);
else if (profile.api === 'glhf')
model = new GLHF(profile.model.replace('glhf/', ''), profile.url, profile.params);
else if (profile.api === 'hyperbolic')
model = new Hyperbolic(profile.model.replace('hyperbolic/', ''), profile.url, profile.params);
else if (profile.api === 'novita')
model = new Novita(profile.model.replace('novita/', ''), profile.url, profile.params);
else if (profile.api === 'qwen')
model = new Qwen(profile.model, profile.url, profile.params);
else if (profile.api === 'xai')
model = new Grok(profile.model, profile.url, profile.params);
else if (profile.api === 'deepseek')
model = new DeepSeek(profile.model, profile.url, profile.params);
else if (profile.api === 'openrouter')
model = new OpenRouter(profile.model.replace('openrouter/', ''), profile.url, profile.params);
else if (profile.api === 'vllm')
model = new VLLM(profile.model.replace('vllm/', ''), profile.url, profile.params);
else
throw new Error('Unknown API:', profile.api);
return model;
}
getName() {
return this.profile.name;
}
@ -482,6 +360,4 @@ export class Prompter {
logFile = path.join(logDir, logFile);
await fs.appendFile(logFile, String(logEntry), 'utf-8');
}
}