diff --git a/src/models/gemini.js b/src/models/gemini.js index ee4dfa4..1536d66 100644 --- a/src/models/gemini.js +++ b/src/models/gemini.js @@ -6,6 +6,28 @@ export class Gemini { constructor(model_name, url) { this.model_name = model_name; this.url = url; + this.safetySettings = [ + { + "category": "HARM_CATEGORY_DANGEROUS", + "threshold": "BLOCK_NONE", + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_NONE", + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_NONE", + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE", + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE", + }, + ]; this.genAI = new GoogleGenerativeAI(getKey('GEMINI_API_KEY')); } @@ -14,12 +36,14 @@ export class Gemini { let model; if (this.url) { model = this.genAI.getGenerativeModel( - {model: this.model_name || "gemini-pro"}, - {baseUrl: this.url} + { model: this.model_name || "gemini-1.5-flash" }, + { baseUrl: this.url }, + { safetySettings: this.safetySettings } ); } else { model = this.genAI.getGenerativeModel( - {model: this.model_name || "gemini-pro"} + { model: this.model_name || "gemini-1.5-flash" }, + { safetySettings: this.safetySettings } ); } @@ -39,16 +63,16 @@ export class Gemini { let model; if (this.url) { model = this.genAI.getGenerativeModel( - {model: this.model_name || "embedding-001"}, - {baseUrl: this.url} + { model: "text-embedding-004" }, + { baseUrl: this.url } ); } else { model = this.genAI.getGenerativeModel( - {model: this.model_name || "embedding-001"} + { model: "text-embedding-004" } ); } const result = await model.embedContent(text); - return result.embedding; + return result.embedding.values; } } \ No newline at end of file