2024-02-18 22:56:38 -06:00
|
|
|
import { cosineSimilarity } from './math.js';
|
2024-01-25 13:25:36 -08:00
|
|
|
import { stringifyTurns } from './text.js';
|
2024-01-15 11:04:50 -06:00
|
|
|
|
|
|
|
/**
 * Build the text that is sent to the embedding model for a list of turns.
 *
 * For each turn accepted by `includeTurn`, everything up to and including the
 * first ':' in its content is dropped (presumably a speaker prefix such as
 * "name: "). When the content has no ':', indexOf returns -1 and the whole
 * content is kept. The pieces are trimmed, joined with newlines, and the
 * result is trimmed again.
 *
 * @param {Array<{role: string, content: string}>} turns - Conversation turns.
 * @param {(turn: {role: string}) => boolean} includeTurn - Which turns contribute.
 * @returns {string} Newline-joined text of the selected turns.
 */
function turnsToEmbeddingText(turns, includeTurn) {
    const pieces = [];
    for (const turn of turns) {
        if (includeTurn(turn)) {
            pieces.push(turn.content.substring(turn.content.indexOf(':') + 1).trim());
        }
    }
    return pieces.join('\n').trim();
}

/**
 * Store of few-shot examples. Each example is a list of conversation turns.
 * Examples are embedded with `model.embed`, and the ones most similar (by
 * cosine similarity) to the current conversation are selected for a prompt.
 */
export class Examples {
    /**
     * @param {object} model - Object exposing an async `embed(text)` method.
     *   Its result is assumed to be a vector accepted by `cosineSimilarity`
     *   (TODO confirm against the model implementations).
     * @param {number} [select_num=2] - Maximum number of examples that
     *   `getRelevant` returns.
     */
    constructor(model, select_num=2) {
        this.examples = [];
        this.model = model;
        this.select_num = select_num;
    }

    /**
     * Replace the stored examples, embedding each one concurrently.
     * Only turns with role 'user' contribute to an example's embedding text.
     *
     * @param {Array<Array<{role: string, content: string}>>} examples - Each
     *   entry is the list of turns making up one example.
     * @returns {Promise<void>}
     */
    async load(examples) {
        this.examples = [];
        const pending = examples.map(async (turns) => {
            const text = turnsToEmbeddingText(turns, (turn) => turn.role === 'user');
            const embedding = await this.model.embed(text);
            return {embedding, turns};
        });
        this.examples = await Promise.all(pending);
    }

    /**
     * Return up to `select_num` stored examples, most similar to `turns` first.
     *
     * The query text includes every turn except the assistant's own replies,
     * so non-user roles (e.g. 'system') count here even though `load` embeds
     * only 'user' turns.
     *
     * @param {Array<{role: string, content: string}>} turns - Current conversation.
     * @returns {Promise<Array<{embedding: *, turns: Array}>>} A deep (JSON)
     *   copy of the selected examples, so callers cannot mutate the store.
     */
    async getRelevant(turns) {
        const text = turnsToEmbeddingText(turns, (turn) => turn.role !== 'assistant');
        const embedding = await this.model.embed(text);
        // Score each example once and sort a copy, so this.examples keeps its
        // load order and is not reordered as a side effect of a query.
        // Array.prototype.sort is stable, so ties keep load order.
        const ranked = this.examples
            .map((example) => ({example, score: cosineSimilarity(example.embedding, embedding)}))
            .sort((a, b) => b.score - a.score);
        const selected = ranked.slice(0, this.select_num).map(({example}) => example);
        return JSON.parse(JSON.stringify(selected)); // deep copy
    }

    /**
     * Build a prompt fragment listing the examples most relevant to `turns`.
     * The first turn of each selected example is logged to the console.
     *
     * @param {Array<{role: string, content: string}>} turns - Current conversation.
     * @returns {Promise<string>} "Examples of how to respond:" followed by one
     *   numbered, stringified example per selection.
     */
    async createExampleMessage(turns) {
        const selected_examples = await this.getRelevant(turns);

        console.log('selected examples:');
        for (const example of selected_examples) {
            // Optional chaining: an example with no turns must not crash logging.
            console.log(example.turns[0]?.content);
        }

        let msg = 'Examples of how to respond:\n';
        selected_examples.forEach((example, i) => {
            msg += `Example ${i+1}:\n${stringifyTurns(example.turns)}\n\n`;
        });
        return msg;
    }
}
|