mindcraft/src/utils/examples.js

83 lines
2.8 KiB
JavaScript
Raw Normal View History

2024-02-18 22:56:38 -06:00
import { cosineSimilarity } from './math.js';
2024-01-25 13:25:36 -08:00
import { stringifyTurns } from './text.js';
2024-01-15 11:04:50 -06:00
/**
 * Stores few-shot example conversations and selects the ones most relevant
 * to a given conversation. Ranking uses cosine similarity of embeddings when a
 * model is available, and word overlap otherwise.
 */
export class Examples {
    /**
     * @param {object|null} model - Object with an async `embed(text)` method used
     *   for similarity ranking, or null to rank by word overlap. An omitted
     *   (undefined) model is treated as null.
     * @param {number} [select_num=2] - Number of examples returned by getRelevant().
     */
    constructor(model, select_num=2) {
        this.examples = [];
        // The rest of the class tests `this.model !== null`. Without this,
        // an undefined model would be treated as real and fail on `.embed`.
        this.model = model === undefined ? null : model;
        this.select_num = select_num;
        // Example text (as produced by turnsToText) -> embedding; filled by load().
        this.embeddings = {};
    }

    /**
     * Joins the content of every non-assistant turn into one newline-separated
     * string. In each message, text up to and including the first ':' is
     * dropped (presumably a "speaker:" prefix; confirm with callers). A message
     * without ':' is kept whole.
     * @param {Array<{role: string, content: string}>} turns
     * @returns {string}
     */
    turnsToText(turns) {
        return turns
            .filter(turn => turn.role !== 'assistant')
            .map(turn => turn.content.substring(turn.content.indexOf(':') + 1).trim())
            .join('\n')
            .trim();
    }

    /**
     * Splits text into lowercase ASCII-letter words. Characters other than
     * letters and whitespace are removed. Any run of whitespace, including
     * newlines, separates words, and empty strings are never returned.
     * @param {string} text
     * @returns {string[]}
     */
    getWords(text) {
        return text
            .replace(/[^a-zA-Z\s]/g, '')
            .toLowerCase()
            .split(/\s+/)
            .filter(word => word.length > 0);
    }

    /**
     * Jaccard similarity of the distinct words in two texts:
     * |shared words| / |all words|, always in [0, 1].
     * @param {string} text1
     * @param {string} text2
     * @returns {number} 0 when neither text has any words. This avoids a NaN,
     *   which would break the sort in getRelevant().
     */
    wordOverlapScore(text1, text2) {
        const words1 = new Set(this.getWords(text1));
        const words2 = new Set(this.getWords(text2));
        let shared = 0;
        for (const word of words1) {
            if (words2.has(word))
                shared++;
        }
        const union = words1.size + words2.size - shared;
        return union === 0 ? 0 : shared / union;
    }

    /**
     * Replaces the example pool. If a model is set, embeds each distinct
     * example text into this.embeddings. If any embedding call fails, logs a
     * warning and sets this.model to null, so later ranking uses word overlap.
     * @param {Array<Array<{role: string, content: string}>>} examples
     */
    async load(examples) {
        this.examples = examples;
        try {
            if (this.model !== null) {
                // Embed each distinct text once; identical examples share a cache key.
                const texts = new Set(this.examples.map(example => this.turnsToText(example)));
                await Promise.all([...texts].map(async (text) => {
                    this.embeddings[text] = await this.model.embed(text);
                }));
            }
        } catch (err) {
            console.warn('Error with embedding model, using word overlap instead.', err);
            this.model = null;
        }
    }

    /**
     * Picks the examples most similar to a conversation.
     * @param {Array<{role: string, content: string}>} turns
     * @returns {Promise<Array>} Deep copies of the `select_num` best-scoring
     *   examples, best first. Ties keep the order the examples were passed to
     *   load(). this.examples itself is not reordered.
     */
    async getRelevant(turns) {
        const turn_text = this.turnsToText(turns);
        let scoreText;
        if (this.model !== null) {
            const embedding = await this.model.embed(turn_text);
            scoreText = (text) => cosineSimilarity(embedding, this.embeddings[text]);
        } else {
            scoreText = (text) => this.wordOverlapScore(turn_text, text);
        }
        // Score each example once, then rank a copy. Sorting this.examples in
        // place would make tie-breaking depend on the order left by earlier
        // queries, and would mutate the array the caller passed to load().
        const ranked = this.examples
            .map((example, index) => ({ example, index, score: scoreText(this.turnsToText(example)) }))
            .sort((a, b) => (b.score - a.score) || (a.index - b.index));
        const selected = ranked.slice(0, this.select_num).map(entry => entry.example);
        return JSON.parse(JSON.stringify(selected)); // deep copy
    }

    /**
     * Builds a prompt section listing the most relevant examples. It also logs
     * the first turn's content of each selected example to the console.
     * @param {Array<{role: string, content: string}>} turns
     * @returns {Promise<string>}
     */
    async createExampleMessage(turns) {
        const selected_examples = await this.getRelevant(turns);

        console.log('selected examples:');
        for (const example of selected_examples) {
            console.log(example[0].content);
        }

        let msg = 'Examples of how to respond:\n';
        selected_examples.forEach((example, i) => {
            msg += `Example ${i+1}:\n${stringifyTurns(example)}\n\n`;
        });
        return msg;
    }
}