2024-02-18 22:56:38 -06:00
|
|
|
import { cosineSimilarity } from './math.js';
|
2024-01-25 13:25:36 -08:00
|
|
|
import { stringifyTurns } from './text.js';
|
2024-01-15 11:04:50 -06:00
|
|
|
|
|
|
|
/**
 * Holds a set of example conversations and picks the ones most similar to a
 * given conversation, so they can be shown as few-shot examples.
 *
 * Similarity is the cosine similarity of embeddings when an embedding model is
 * supplied, and a word-overlap (Jaccard) score otherwise.
 */
export class Examples {
    /**
     * @param {?object} model - Object exposing `async embed(text)`, or
     *     null/undefined to fall back to word-overlap similarity.
     * @param {number} [select_num=2] - How many examples getRelevant returns.
     */
    constructor(model, select_num=2) {
        this.examples = [];
        this.model = model;
        this.select_num = select_num;
        // Null-prototype object: keys are arbitrary user text, so a text such
        // as "constructor" must not resolve to something on Object.prototype.
        this.embeddings = Object.create(null);
    }

    /**
     * Flattens the user turns of a conversation into newline-separated text.
     * Everything up to and including the first ':' of each user message is
     * dropped (the whole message is kept when it contains no ':').
     *
     * @param {Array<{role: string, content: string}>} turns
     * @returns {string}
     */
    turnsToText(turns) {
        let messages = '';
        for (const turn of turns) {
            if (turn.role === 'user')
                messages += turn.content.substring(turn.content.indexOf(':')+1).trim() + '\n';
        }
        return messages.trim();
    }

    /**
     * Splits text into lowercase words made of ASCII letters only.
     *
     * @param {string} text
     * @returns {string[]} Words in order of appearance; never contains ''.
     */
    getWords(text) {
        return text
            // Newlines/tabs separate words too; turn them into spaces first so
            // stripping non-letters does not glue neighbouring words together.
            .replace(/\s+/g, ' ')
            .replace(/[^a-zA-Z ]/g, '')
            .toLowerCase()
            .split(' ')
            .filter(word => word.length > 0);
    }

    /**
     * Jaccard similarity of the distinct words of two texts.
     *
     * @param {string} text1
     * @param {string} text2
     * @returns {number} Value in [0, 1]; 0 when neither text has any word.
     */
    wordOverlap(text1, text2) {
        const words1 = new Set(this.getWords(text1));
        const words2 = new Set(this.getWords(text2));
        let shared = 0;
        for (const word of words1) {
            if (words2.has(word)) shared++;
        }
        const union = words1.size + words2.size - shared;
        return union === 0 ? 0 : shared / union;
    }

    /**
     * Returns the cached embedding for a text, or asks the model for one.
     * Freshly computed embeddings are not stored, so the cache only ever
     * holds what load() put there.
     *
     * @param {string} text
     * @returns {Promise<*>} Whatever `model.embed` resolves to.
     */
    async getEmbedding(text) {
        const cached = Object.prototype.hasOwnProperty.call(this.embeddings, text)
            ? this.embeddings[text]
            : undefined;
        if (cached) return cached;
        return await this.model.embed(text);
    }

    /**
     * Scores how similar two texts are.
     *
     * @param {string} text1
     * @param {string} text2
     * @returns {Promise<number>} Cosine similarity of the embeddings when a
     *     model is set, otherwise the word-overlap score.
     */
    async getSimilarity(text1, text2) {
        // `!= null` on purpose: an undefined model also means "no model".
        if (this.model != null) {
            const [embeddings1, embeddings2] = await Promise.all([
                this.getEmbedding(text1),
                this.getEmbedding(text2),
            ]);
            return cosineSimilarity(embeddings1, embeddings2);
        }
        return this.wordOverlap(text1, text2);
    }

    /**
     * Stores the examples and, when a model is set, pre-computes and caches
     * the embedding of each example's user text.
     *
     * @param {Array<Array<{role: string, content: string}>>} examples
     * @returns {Promise<void>}
     */
    async load(examples) {
        this.examples = examples;
        if (this.model != null) {
            // Sequential on purpose: one embed request at a time.
            for (const example of this.examples) {
                const turn_text = this.turnsToText(example);
                this.embeddings[turn_text] = await this.model.embed(turn_text);
            }
        }
    }

    /**
     * Selects the `select_num` stored examples most similar to `turns`.
     *
     * Similarities are awaited BEFORE sorting: Array.prototype.sort calls its
     * comparator synchronously, so comparing the promises returned by an
     * async scorer yields NaN and leaves the order untouched.
     *
     * @param {Array<{role: string, content: string}>} turns
     * @returns {Promise<Array>} Deep copies of the selected examples in
     *     ascending similarity (most similar last). Empty when `select_num`
     *     is not positive or no examples are loaded. `this.examples` is not
     *     reordered.
     */
    async getRelevant(turns) {
        const turn_text = this.turnsToText(turns);
        // slice(-0) would return every example, so handle "none" explicitly.
        if (!(this.select_num > 0) || this.examples.length === 0) return [];

        let score;
        if (this.model != null) {
            // Embed the query once instead of once per comparison.
            const query = await this.getEmbedding(turn_text);
            score = async (text) => cosineSimilarity(query, await this.getEmbedding(text));
        } else {
            score = async (text) => this.wordOverlap(turn_text, text);
        }

        const scored = [];
        for (const example of this.examples) {
            const similarity = await score(this.turnsToText(example));
            scored.push({ example, similarity });
        }
        scored.sort((a, b) => a.similarity - b.similarity);

        const selected = scored.slice(-this.select_num).map(entry => entry.example);
        return JSON.parse(JSON.stringify(selected)); // deep copy
    }

    /**
     * Builds the prompt section listing the examples relevant to `turns`,
     * and logs the first message of each selected example.
     *
     * @param {Array<{role: string, content: string}>} turns
     * @returns {Promise<string>} Header line followed by one numbered entry
     *     per selected example, each formatted with stringifyTurns.
     */
    async createExampleMessage(turns) {
        const selected_examples = await this.getRelevant(turns);

        console.log('selected examples:');
        for (const example of selected_examples) {
            // An empty example has no first turn to print.
            if (example.length > 0) console.log(example[0].content);
        }

        let msg = 'Examples of how to respond:\n';
        for (let i=0; i<selected_examples.length; i++) {
            const example = selected_examples[i];
            msg += `Example ${i+1}:\n${stringifyTurns(example)}\n\n`;
        }
        return msg;
    }
}
|