diff --git a/keys.example.json b/keys.example.json index 99286c5..52204ae 100644 --- a/keys.example.json +++ b/keys.example.json @@ -13,5 +13,6 @@ "GHLF_API_KEY": "", "HYPERBOLIC_API_KEY": "", "NOVITA_API_KEY": "", - "OPENROUTER_API_KEY": "" + "OPENROUTER_API_KEY": "", + "CEREBRAS_API_KEY": "" } diff --git a/package.json b/package.json index bb3fd90..a37cf43 100644 --- a/package.json +++ b/package.json @@ -2,6 +2,7 @@ "type": "module", "dependencies": { "@anthropic-ai/sdk": "^0.17.1", + "@cerebras/cerebras_cloud_sdk": "^1.0.0", "@google/generative-ai": "^0.2.1", "@huggingface/inference": "^2.8.1", "@mistralai/mistralai": "^1.1.0", diff --git a/src/models/cerebras.js b/src/models/cerebras.js new file mode 100644 index 0000000..21f1eee --- /dev/null +++ b/src/models/cerebras.js @@ -0,0 +1,43 @@ +import CerebrasSDK from '@cerebras/cerebras_cloud_sdk'; +import { strictFormat } from '../utils/text.js'; +import { getKey } from '../utils/keys.js'; + +export class Cerebras { + constructor(model_name, url, params) { + // Strip the prefix + this.model_name = model_name.replace('cerebras/', ''); + this.url = url; + this.params = params; + + // Initialize client with API key + this.client = new CerebrasSDK({ apiKey: getKey('CEREBRAS_API_KEY') }); + } + + async sendRequest(turns, systemMessage, stop_seq = '***') { + // Format messages array + const messages = strictFormat(turns); + messages.unshift({ role: 'system', content: systemMessage }); + + const pack = { + model: this.model_name || 'llama-4-scout-17b-16e-instruct', + messages, + stream: false, + ...(this.params || {}), + }; + + let res; + try { + const completion = await this.client.chat.completions.create(pack); + // OpenAI-compatible shape + res = completion.choices?.[0]?.message?.content || ''; + } catch (err) { + console.error('Cerebras API error:', err); + res = 'My brain disconnected, try again.'; + } + return res; + } + + async embed(text) { + throw new Error('Embeddings are not supported by Cerebras.'); + } +} diff --git a/src/models/prompter.js b/src/models/prompter.js index e05f5a8..22f23f7 100644 --- a/src/models/prompter.js +++ b/src/models/prompter.js @@ -22,6 +22,7 @@ import { Hyperbolic } from './hyperbolic.js'; import { GLHF } from './glhf.js'; import { OpenRouter } from './openrouter.js'; import { VLLM } from './vllm.js'; +import { Cerebras } from './cerebras.js'; import { promises as fs } from 'fs'; import path from 'path'; import { fileURLToPath } from 'url'; @@ -170,6 +171,8 @@ export class Prompter { profile.api = 'deepseek'; else if (profile.model.includes('mistral')) profile.api = 'mistral'; + else if (profile.model.startsWith('cerebras/')) + profile.api = 'cerebras'; else throw new Error('Unknown model:', profile.model); } @@ -209,6 +212,8 @@ export class Prompter { model = new OpenRouter(profile.model.replace('openrouter/', ''), profile.url, profile.params); else if (profile.api === 'vllm') model = new VLLM(profile.model.replace('vllm/', ''), profile.url, profile.params); + else if (profile.api === 'cerebras') + model = new Cerebras(profile.model.replace('cerebras/', ''), profile.url, profile.params); else throw new Error('Unknown API:', profile.api); return model;