Merge pull request #274 from Shandelier/gemini-embedding-fix

gemini gains 10 IQ points (embedding fix)
This commit is contained in:
Max Robinson 2024-11-02 22:59:27 -05:00 committed by GitHub
commit b6eb9da29d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -6,6 +6,28 @@ export class Gemini {
constructor(model_name, url) { constructor(model_name, url) {
this.model_name = model_name; this.model_name = model_name;
this.url = url; this.url = url;
this.safetySettings = [
{
"category": "HARM_CATEGORY_DANGEROUS",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE",
},
];
this.genAI = new GoogleGenerativeAI(getKey('GEMINI_API_KEY')); this.genAI = new GoogleGenerativeAI(getKey('GEMINI_API_KEY'));
} }
@ -14,12 +36,14 @@ export class Gemini {
let model; let model;
if (this.url) { if (this.url) {
model = this.genAI.getGenerativeModel( model = this.genAI.getGenerativeModel(
{model: this.model_name || "gemini-pro"}, { model: this.model_name || "gemini-1.5-flash" },
{baseUrl: this.url} { baseUrl: this.url },
{ safetySettings: this.safetySettings }
); );
} else { } else {
model = this.genAI.getGenerativeModel( model = this.genAI.getGenerativeModel(
{model: this.model_name || "gemini-pro"} { model: this.model_name || "gemini-1.5-flash" },
{ safetySettings: this.safetySettings }
); );
} }
@ -39,16 +63,16 @@ export class Gemini {
let model; let model;
if (this.url) { if (this.url) {
model = this.genAI.getGenerativeModel( model = this.genAI.getGenerativeModel(
{model: this.model_name || "embedding-001"}, { model: "text-embedding-004" },
{baseUrl: this.url} { baseUrl: this.url }
); );
} else { } else {
model = this.genAI.getGenerativeModel( model = this.genAI.getGenerativeModel(
{model: this.model_name || "embedding-001"} { model: "text-embedding-004" }
); );
} }
const result = await model.embedContent(text); const result = await model.embedContent(text);
return result.embedding; return result.embedding.values;
} }
} }