Added provider support for Cerebras

This commit is contained in:
Kevin Taylor 2025-05-11 23:22:14 -07:00
parent 4a19f72e75
commit bba08979a8
4 changed files with 51 additions and 1 deletions

View file

@ -13,5 +13,6 @@
"GHLF_API_KEY": "", "GHLF_API_KEY": "",
"HYPERBOLIC_API_KEY": "", "HYPERBOLIC_API_KEY": "",
"NOVITA_API_KEY": "", "NOVITA_API_KEY": "",
"OPENROUTER_API_KEY": "" "OPENROUTER_API_KEY": "",
"CEREBRAS_API_KEY": ""
} }

View file

@ -2,6 +2,7 @@
"type": "module", "type": "module",
"dependencies": { "dependencies": {
"@anthropic-ai/sdk": "^0.17.1", "@anthropic-ai/sdk": "^0.17.1",
"@cerebras/cerebras_cloud_sdk": "^1.0.0",
"@google/generative-ai": "^0.2.1", "@google/generative-ai": "^0.2.1",
"@huggingface/inference": "^2.8.1", "@huggingface/inference": "^2.8.1",
"@mistralai/mistralai": "^1.1.0", "@mistralai/mistralai": "^1.1.0",

43
src/models/cerebras.js Normal file
View file

@ -0,0 +1,43 @@
import CerebrasSDK from '@cerebras/cerebras_cloud_sdk';
import { strictFormat } from '../utils/text.js';
import { getKey } from '../utils/keys.js';
export class Cerebras {
constructor(model_name, url, params) {
// Strip the prefix
this.model_name = model_name.replace('cerebras/', '');
this.url = url;
this.params = params;
// Initialize client with API key
this.client = new CerebrasSDK({ apiKey: getKey('CEREBRAS_API_KEY') });
}
async sendRequest(turns, systemMessage, stop_seq = '***') {
// Format messages array
const messages = strictFormat(turns);
messages.unshift({ role: 'system', content: systemMessage });
const pack = {
model: this.model_name || 'llama-4-scout-17b-16e-instruct',
messages,
stream: false,
...(this.params || {}),
};
let res;
try {
const completion = await this.client.chat.completions.create(pack);
// OpenAI-compatible shape
res = completion.choices?.[0]?.message?.content || '';
} catch (err) {
console.error('Cerebras API error:', err);
res = 'My brain disconnected, try again.';
}
return res;
}
async embed(text) {
throw new Error('Embeddings are not supported by Cerebras.');
}
}

View file

@ -22,6 +22,7 @@ import { Hyperbolic } from './hyperbolic.js';
import { GLHF } from './glhf.js'; import { GLHF } from './glhf.js';
import { OpenRouter } from './openrouter.js'; import { OpenRouter } from './openrouter.js';
import { VLLM } from './vllm.js'; import { VLLM } from './vllm.js';
import { Cerebras } from './cerebras.js';
import { promises as fs } from 'fs'; import { promises as fs } from 'fs';
import path from 'path'; import path from 'path';
import { fileURLToPath } from 'url'; import { fileURLToPath } from 'url';
@ -170,6 +171,8 @@ export class Prompter {
profile.api = 'deepseek'; profile.api = 'deepseek';
else if (profile.model.includes('mistral')) else if (profile.model.includes('mistral'))
profile.api = 'mistral'; profile.api = 'mistral';
else if (profile.model.startsWith('cerebras/'))
profile.api = 'cerebras';
else else
throw new Error('Unknown model:', profile.model); throw new Error('Unknown model:', profile.model);
} }
@ -209,6 +212,8 @@ export class Prompter {
model = new OpenRouter(profile.model.replace('openrouter/', ''), profile.url, profile.params); model = new OpenRouter(profile.model.replace('openrouter/', ''), profile.url, profile.params);
else if (profile.api === 'vllm') else if (profile.api === 'vllm')
model = new VLLM(profile.model.replace('vllm/', ''), profile.url, profile.params); model = new VLLM(profile.model.replace('vllm/', ''), profile.url, profile.params);
else if (profile.api === 'cerebras')
model = new Cerebras(profile.model.replace('cerebras/', ''), profile.url, profile.params);
else else
throw new Error('Unknown API:', profile.api); throw new Error('Unknown API:', profile.api);
return model; return model;