// mindcraft/src/agent/prompter.js
// (Captured from a GitHub blame view: 301 lines, 13 KiB, JavaScript.)

import {mkdirSync, readFileSync, writeFileSync} from 'fs';
import {Examples} from '../utils/examples.js';
import {getCommand, getCommandDocs} from './commands/index.js';
import {getSkillDocs} from './library/index.js';
import {stringifyTurns} from '../utils/text.js';
import {cosineSimilarity} from '../utils/math.js';
// blame: 2024-03-23 11:15:53 -05:00
import { Gemini } from '../models/gemini.js';
import { GPT } from '../models/gpt.js';
import { Claude } from '../models/claude.js';
// blame: 2024-05-06 00:18:04 -05:00
import { ReplicateAPI } from '../models/replicate.js';
import { Local } from '../models/local.js';
// blame: 2024-08-31 15:29:34 -07:00
import { GroqCloudAPI } from '../models/groq.js';
import { HuggingFace } from '../models/huggingface.js';
// blame: 2024-10-28 13:29:16 +08:00
import { Qwen } from "../models/qwen.js";
export class Prompter {
    /**
     * Loads a bot profile and instantiates its chat and embedding models.
     *
     * Side effects: creates `./bots/<name>/` and writes a copy of the loaded
     * profile to `./bots/<name>/last_profile.json`.
     *
     * @param {object} agent - The owning agent (later read for name, history, npc, self_prompter).
     * @param {string} fp - Path to the profile JSON file.
     * @throws {Error} If the chat API is unknown or the profile copy cannot be written.
     */
    constructor(agent, fp) {
        this.agent = agent;
        this.profile = JSON.parse(readFileSync(fp, 'utf8'));
        this.convo_examples = null;
        this.coding_examples = null;
        // Maps each skill doc string to its embedding, or null when it has none
        // (no embedding model, or embedding that doc failed).
        this.skill_docs_embeddings = {};

        const name = this.profile.name;
        let chat = this.profile.model;
        this.cooldown = this.profile.cooldown ? this.profile.cooldown : 0;
        this.last_prompt_time = 0;

        // try to get "max_tokens" parameter, else null
        let max_tokens = null;
        if (this.profile.max_tokens)
            max_tokens = this.profile.max_tokens;

        // A bare model string: infer the API from substrings of the model name.
        if (typeof chat === 'string' || chat instanceof String) {
            chat = {model: chat};
            if (chat.model.includes('gemini'))
                chat.api = 'google';
            else if (chat.model.includes('gpt') || chat.model.includes('o1'))
                chat.api = 'openai';
            else if (chat.model.includes('claude'))
                chat.api = 'anthropic';
            else if (chat.model.includes('huggingface/'))
                chat.api = "huggingface";
            else if (chat.model.includes('meta/') || chat.model.includes('mistralai/') || chat.model.includes('replicate/'))
                chat.api = 'replicate';
            else if (chat.model.includes("groq/") || chat.model.includes("groqcloud/"))
                chat.api = 'groq';
            else if (chat.model.includes('qwen'))
                chat.api = 'qwen';
            else
                chat.api = 'ollama';
        }
        console.log('Using chat settings:', chat);

        if (chat.api === 'google')
            this.chat_model = new Gemini(chat.model, chat.url);
        else if (chat.api === 'openai')
            this.chat_model = new GPT(chat.model, chat.url);
        else if (chat.api === 'anthropic')
            this.chat_model = new Claude(chat.model, chat.url);
        else if (chat.api === 'replicate')
            this.chat_model = new ReplicateAPI(chat.model, chat.url);
        else if (chat.api === 'ollama')
            this.chat_model = new Local(chat.model, chat.url);
        else if (chat.api === 'groq') {
            this.chat_model = new GroqCloudAPI(chat.model.replace('groq/', '').replace('groqcloud/', ''), chat.url, max_tokens ? max_tokens : 8192);
        }
        else if (chat.api === 'huggingface')
            this.chat_model = new HuggingFace(chat.model, chat.url);
        else if (chat.api === 'qwen')
            this.chat_model = new Qwen(chat.model, chat.url);
        else
            // Error() takes a single message; interpolate the offending API name.
            throw new Error(`Unknown API: ${chat.api}`);

        // Default the embedding API to the chat API, except for ollama, which gets none.
        let embedding = this.profile.embedding;
        if (embedding === undefined) {
            if (chat.api !== 'ollama')
                embedding = {api: chat.api};
            else
                embedding = {api: 'none'};
        }
        else if (typeof embedding === 'string' || embedding instanceof String)
            embedding = {api: embedding};
        console.log('Using embedding settings:', embedding);

        try {
            if (embedding.api === 'google')
                this.embedding_model = new Gemini(embedding.model, embedding.url);
            else if (embedding.api === 'openai')
                this.embedding_model = new GPT(embedding.model, embedding.url);
            else if (embedding.api === 'replicate')
                this.embedding_model = new ReplicateAPI(embedding.model, embedding.url);
            else if (embedding.api === 'ollama')
                this.embedding_model = new Local(embedding.model, embedding.url);
            else if (embedding.api === 'qwen')
                this.embedding_model = new Qwen(embedding.model, embedding.url);
            else {
                this.embedding_model = null;
                console.log('Unknown embedding: ', embedding ? embedding.api : '[NOT SPECIFIED]', '. Using word overlap.');
            }
        }
        catch (err) {
            console.log('Warning: Failed to initialize embedding model:', err.message);
            console.log('Continuing anyway, using word overlap instead.');
            this.embedding_model = null;
        }

        mkdirSync(`./bots/${name}`, { recursive: true });
        // writeFileSync is synchronous and throws on failure; it accepts an options
        // argument, not a completion callback, so report failure via try/catch.
        try {
            writeFileSync(`./bots/${name}/last_profile.json`, JSON.stringify(this.profile, null, 4));
        } catch (err) {
            throw new Error(`Failed to save profile: ${err.message}`, { cause: err });
        }
        console.log("Copy profile saved.");
    }

    /** @returns {string} The bot name from the loaded profile. */
    getName() {
        return this.profile.name;
    }

    /** @returns {*} The profile's `modes` entry, as loaded. */
    getInitModes() {
        return this.profile.modes;
    }

    /**
     * Loads conversation and coding examples and embeds every skill doc.
     *
     * Failure to load conversation or coding examples is rethrown. A skill doc that
     * fails to embed is only logged; it stays registered (with a null embedding).
     */
    async initExamples() {
        try {
            this.convo_examples = new Examples(this.embedding_model);
            this.coding_examples = new Examples(this.embedding_model);
            const results = await Promise.allSettled([
                this.convo_examples.load(this.profile.conversation_examples),
                this.coding_examples.load(this.profile.coding_examples),
                ...getSkillDocs().map(async (doc) => {
                    // Register every doc up front so it remains selectable (via word
                    // overlap) even without an embedding model or if embedding fails.
                    this.skill_docs_embeddings[doc] = null;
                    if (!this.embedding_model)
                        return;
                    const func_name_desc = doc.split('\n').slice(0, 2).join('');
                    this.skill_docs_embeddings[doc] = await this.embedding_model.embed([func_name_desc]);
                })
            ]);
            // Handle potential failures for conversation and coding examples
            const [convoResult, codingResult, ...skillDocResults] = results;
            if (convoResult.status === 'rejected') {
                console.error('Failed to load conversation examples:', convoResult.reason);
                throw convoResult.reason;
            }
            if (codingResult.status === 'rejected') {
                console.error('Failed to load coding examples:', codingResult.reason);
                throw codingResult.reason;
            }
            skillDocResults.forEach((result, index) => {
                if (result.status === 'rejected') {
                    console.error(`Failed to load skill doc ${index + 1}:`, result.reason);
                }
            });
        } catch (error) {
            console.error('Failed to initialize examples:', error);
            throw error;
        }
    }

    /**
     * Builds the skill-doc section substituted for "$CODE_DOCS".
     *
     * Docs are ranked by relevance to `message` (see scoreSkillDocs). With an empty
     * message nothing can be ranked, so every doc is returned in registration order.
     *
     * @param {string} message - Text to rank docs against (usually the latest non-system turn).
     * @param {number} [select_num] - Max docs to include; a non-number, NaN, or negative value means all.
     * @returns {Promise<string>} The formatted doc section.
     */
    async getRelevantSkillDocs(message, select_num) {
        const doc_keys = Object.keys(this.skill_docs_embeddings);
        let selected_keys = doc_keys;
        if (message) {
            const scores = await this.scoreSkillDocs(message, doc_keys);
            const ranked_keys = doc_keys
                .map((doc_key, i) => ({doc_key, score: scores[i]}))
                // `|| 0` keeps the comparator well-defined for NaN or equal -Infinity scores.
                .sort((a, b) => (b.score - a.score) || 0)
                .map(entry => entry.doc_key);
            let count = ranked_keys.length;
            if (typeof select_num === 'number' && !Number.isNaN(select_num) && select_num >= 0)
                count = Math.min(Math.floor(select_num), count);
            selected_keys = ranked_keys.slice(0, count);
        }
        let relevant_skill_docs = '#### RELEVANT DOCS INFO ###\nThe following functions are listed in descending order of relevance.\n';
        relevant_skill_docs += 'SkillDocs:\n';
        relevant_skill_docs += '###' + selected_keys.join('\n');
        return relevant_skill_docs;
    }

    /**
     * Scores each doc in `doc_keys` against `message`; higher means more relevant.
     *
     * Uses cosine similarity of embeddings when an embedding model exists and at least
     * one doc has an embedding; docs lacking one score -Infinity so they sort last.
     * Otherwise, or if embedding the message throws, every doc is scored by word overlap.
     *
     * @param {string} message - Query text.
     * @param {string[]} doc_keys - Skill doc strings (keys of skill_docs_embeddings).
     * @returns {Promise<number[]>} One score per entry of `doc_keys`, in the same order.
     */
    async scoreSkillDocs(message, doc_keys) {
        const has_doc_embeddings = doc_keys.some(doc_key => this.skill_docs_embeddings[doc_key] != null);
        if (this.embedding_model && has_doc_embeddings) {
            try {
                const message_embedding = await this.embedding_model.embed(message);
                return doc_keys.map(doc_key => {
                    const doc_embedding = this.skill_docs_embeddings[doc_key];
                    return doc_embedding == null ? -Infinity : cosineSimilarity(message_embedding, doc_embedding);
                });
            } catch (err) {
                console.warn('Failed to embed message for skill docs, using word overlap:', err?.message ?? err);
            }
        }
        // Compare against the name + description lines, matching what gets embedded.
        return doc_keys.map(doc_key => Prompter.wordOverlapScore(message, doc_key.split('\n').slice(0, 2).join(' ')));
    }

    /**
     * Word-overlap similarity in [0, 1]: the number of distinct lowercase words the two
     * texts share, divided by the geometric mean of their distinct-word counts.
     * Returns 0 when either text has no words.
     *
     * @param {string} text_a
     * @param {string} text_b
     * @returns {number}
     */
    static wordOverlapScore(text_a, text_b) {
        const toWords = (text) => new Set(String(text).toLowerCase().split(/\W+/).filter(Boolean));
        const words_a = toWords(text_a);
        const words_b = toWords(text_b);
        if (words_a.size === 0 || words_b.size === 0)
            return 0;
        let shared = 0;
        for (const word of words_a) {
            if (words_b.has(word))
                shared++;
        }
        return shared / Math.sqrt(words_a.size * words_b.size);
    }

    /**
     * Substitutes `$PLACEHOLDER` tokens in `prompt` with live agent data.
     * Unrecognized `$UPPER_CASE` tokens left afterwards are reported via console.warn.
     *
     * @param {string} prompt - Template text.
     * @param {Array<{role: string, content: string}>|null} messages - Conversation turns; null is treated as [].
     * @param {object|null} [examples] - Examples instance for $EXAMPLES (skipped when null).
     * @param {Array|null} [to_summarize] - Turns for $TO_SUMMARIZE; null is treated as [].
     * @param {object|null} [last_goals] - Map of goal name -> success boolean for $LAST_GOALS.
     * @returns {Promise<string>} The filled-in prompt.
     */
    async replaceStrings(prompt, messages, examples=null, to_summarize=[], last_goals=null) {
        // Callers such as promptMemSaving pass null; treat that as "no turns".
        const turns = messages ?? [];
        const summary_turns = to_summarize ?? [];

        prompt = prompt.replaceAll('$NAME', this.agent.name);
        if (prompt.includes('$STATS')) {
            const stats = await getCommand('!stats').perform(this.agent);
            prompt = prompt.replaceAll('$STATS', stats);
        }
        if (prompt.includes('$INVENTORY')) {
            const inventory = await getCommand('!inventory').perform(this.agent);
            prompt = prompt.replaceAll('$INVENTORY', inventory);
        }
        if (prompt.includes('$COMMAND_DOCS'))
            prompt = prompt.replaceAll('$COMMAND_DOCS', getCommandDocs());
        if (prompt.includes('$CODE_DOCS')) {
            const latest_message_content = [...turns].reverse().find(msg => msg.role !== 'system')?.content || '';
            prompt = prompt.replaceAll('$CODE_DOCS', await this.getRelevantSkillDocs(latest_message_content, 5));
        }
        if (prompt.includes('$EXAMPLES') && examples !== null)
            prompt = prompt.replaceAll('$EXAMPLES', await examples.createExampleMessage(turns));
        if (prompt.includes('$MEMORY'))
            prompt = prompt.replaceAll('$MEMORY', this.agent.history.memory);
        if (prompt.includes('$TO_SUMMARIZE'))
            prompt = prompt.replaceAll('$TO_SUMMARIZE', stringifyTurns(summary_turns));
        if (prompt.includes('$CONVO'))
            prompt = prompt.replaceAll('$CONVO', 'Recent conversation:\n' + stringifyTurns(turns));
        if (prompt.includes('$SELF_PROMPT')) {
            const self_prompt = this.agent.self_prompter.on ? `YOUR CURRENT ASSIGNED GOAL: "${this.agent.self_prompter.prompt}"\n` : '';
            prompt = prompt.replaceAll('$SELF_PROMPT', self_prompt);
        }
        if (prompt.includes('$LAST_GOALS')) {
            let goal_text = '';
            for (const goal in last_goals) {
                if (last_goals[goal])
                    goal_text += `You recently successfully completed the goal ${goal}.\n`;
                else
                    goal_text += `You recently failed to complete the goal ${goal}.\n`;
            }
            prompt = prompt.replaceAll('$LAST_GOALS', goal_text.trim());
        }
        if (prompt.includes('$BLUEPRINTS')) {
            const constructions = this.agent.npc.constructions;
            if (constructions)
                prompt = prompt.replaceAll('$BLUEPRINTS', Object.keys(constructions).join(', '));
        }
        // check if there are any remaining placeholders with syntax $<word>
        const remaining = prompt.match(/\$[A-Z_]+/g);
        if (remaining !== null) {
            console.warn('Unknown prompt placeholders:', remaining.join(', '));
        }
        return prompt;
    }

    /**
     * Sleeps until at least `this.cooldown` ms have passed since the previous prompt,
     * then records the current time as the latest prompt time.
     */
    async checkCooldown() {
        const elapsed = Date.now() - this.last_prompt_time;
        if (elapsed < this.cooldown && this.cooldown > 0) {
            await new Promise(r => setTimeout(r, this.cooldown - elapsed));
        }
        this.last_prompt_time = Date.now();
    }

    /** Sends `messages` with the profile's "conversing" system prompt; returns the model reply. */
    async promptConvo(messages) {
        await this.checkCooldown();
        let prompt = this.profile.conversing;
        prompt = await this.replaceStrings(prompt, messages, this.convo_examples);
        return await this.chat_model.sendRequest(messages, prompt);
    }

    /** Sends `messages` with the profile's "coding" system prompt; returns the model reply. */
    async promptCoding(messages) {
        await this.checkCooldown();
        let prompt = this.profile.coding;
        prompt = await this.replaceStrings(prompt, messages, this.coding_examples);
        return await this.chat_model.sendRequest(messages, prompt);
    }

    /** Asks the model to summarize `to_summarize` using the profile's "saving_memory" prompt. */
    async promptMemSaving(to_summarize) {
        await this.checkCooldown();
        let prompt = this.profile.saving_memory;
        prompt = await this.replaceStrings(prompt, null, null, to_summarize);
        return await this.chat_model.sendRequest([], prompt);
    }

    /**
     * Asks the model for the next goal to pursue.
     *
     * @param {Array} messages - Recent conversation turns.
     * @param {object} last_goals - Map of goal name -> success boolean.
     * @returns {Promise<{name: string, quantity: number}|null>} The parsed goal, or null.
     */
    async promptGoalSetting(messages, last_goals) {
        // Rate-limit like every other method that calls the chat model.
        await this.checkCooldown();
        let system_message = this.profile.goal_setting;
        system_message = await this.replaceStrings(system_message, messages);
        let user_message = 'Use the below info to determine what goal to target next\n\n';
        user_message += '$LAST_GOALS\n$STATS\n$INVENTORY\n$CONVO';
        user_message = await this.replaceStrings(user_message, messages, null, null, last_goals);
        const user_messages = [{role: 'user', content: user_message}];
        const res = await this.chat_model.sendRequest(user_messages, system_message);
        return Prompter.parseGoalResponse(res);
    }

    /**
     * Extracts a goal object from the first ``` fenced block of a model reply.
     * An optional leading "json" language tag on the fence is ignored.
     *
     * @param {string} res - Raw model reply.
     * @returns {{name: string, quantity: number}|null} The goal with an integer
     *   quantity, or null when parsing fails or name/quantity are missing/invalid.
     */
    static parseGoalResponse(res) {
        let goal = null;
        try {
            // Strip only a leading fence language tag, not "json" inside the payload.
            const data = res.split('```')[1].replace(/^json/i, '').trim();
            goal = JSON.parse(data);
        } catch (err) {
            console.log('Failed to parse goal:', res, err);
        }
        const quantity = Number.parseInt(goal?.quantity, 10);
        if (!goal || !goal.name || !goal.quantity || Number.isNaN(quantity)) {
            console.log('Failed to set goal:', res);
            return null;
        }
        goal.quantity = quantity;
        return goal;
    }
}