From 18300a8f13905527886aeca15561630f56626dca Mon Sep 17 00:00:00 2001 From: Zgrill2 Date: Thu, 6 Feb 2025 00:20:52 -0500 Subject: [PATCH] Add support for GPT models in Azure AI Studio --- profiles/azure.json | 15 +++++++++++++++ src/models/azure.js | 23 +++++++++++++++++++++++ src/models/prompter.js | 5 +++++ 3 files changed, 43 insertions(+) create mode 100644 profiles/azure.json create mode 100644 src/models/azure.js diff --git a/profiles/azure.json b/profiles/azure.json new file mode 100644 index 0000000..fbd382d --- /dev/null +++ b/profiles/azure.json @@ -0,0 +1,15 @@ +{ + "name": "azure", + "model": { + "api": "azure", + "url": "", + "model": "gpt-4o", + "api_version": "2024-08-01-preview" + }, + "embedding": { + "api": "azure", + "url": "", + "model": "text-embedding-ada-002", + "api_version": "2024-08-01-preview" + } +} \ No newline at end of file diff --git a/src/models/azure.js b/src/models/azure.js new file mode 100644 index 0000000..d8e2f4a --- /dev/null +++ b/src/models/azure.js @@ -0,0 +1,23 @@ +import { AzureOpenAI } from "openai"; +import { getKey } from '../utils/keys.js'; +import { GPT } from './gpt.js' + +export class AzureGPT extends GPT { + constructor(model_name, url, api_version, params) { + super(model_name, url) + + this.model_name = model_name; + this.params = params; + + let config = {} + + if (url) + config.endpoint = url; + + config.apiKey = getKey('OPENAI_API_KEY'); + config.deployment = model_name; // This must be what you named the deployment in Azure, not the model version itself + config.apiVersion = api_version; // This is required for Azure + + this.openai = new AzureOpenAI(config) + } +} \ No newline at end of file diff --git a/src/models/prompter.js b/src/models/prompter.js index 5295653..2260ead 100644 --- a/src/models/prompter.js +++ b/src/models/prompter.js @@ -19,6 +19,7 @@ import { HuggingFace } from './huggingface.js'; import { Qwen } from "./qwen.js"; import { Grok } from "./grok.js"; import { DeepSeek } from './deepseek.js'; +import { AzureGPT } from './azure.js'; export class Prompter { constructor(agent, fp) { @@ -72,6 +73,8 @@ export class Prompter { this.embedding_model = new Gemini(embedding.model, embedding.url); else if (embedding.api === 'openai') this.embedding_model = new GPT(embedding.model, embedding.url); + else if (embedding.api === 'azure') + this.embedding_model = new AzureGPT(embedding.model, embedding.url, embedding.api_version); else if (embedding.api === 'replicate') this.embedding_model = new ReplicateAPI(embedding.model, embedding.url); else if (embedding.api === 'ollama') @@ -139,6 +142,8 @@ export class Prompter { model = new Gemini(profile.model, profile.url, profile.params); else if (profile.api === 'openai') model = new GPT(profile.model, profile.url, profile.params); + else if (profile.api === 'azure') + model = new AzureGPT(profile.model, profile.url, profile.api_version, profile.params); else if (profile.api === 'anthropic') model = new Claude(profile.model, profile.url, profile.params); else if (profile.api === 'replicate')