Skip to content

Commit

Permalink
fix: azure openai support (#82)
Browse files Browse the repository at this point in the history
* fix: azure openai support

* add checks

* improve details
  • Loading branch information
j178 authored Jan 16, 2024
1 parent 114c97f commit 9882d94
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 33 deletions.
15 changes: 13 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -272,14 +272,25 @@ If you are using Azure OpenAI service, you should configure like this:
{
"api_type": "AZURE",
"api_key": "xxxx",
"api_version": "2023-05-15",
"endpoint": "https://YOUR_RESOURCE_NAME.openai.azure.com",
"engine": "YOUR_DEPLOYMENT_NAME",
"api_version": "2023-03-15-preview"
"model_mapping": {
"gpt-3.5-turbo": "your gpt-3.5-turbo deployment name",
"gpt-4": "your gpt-4 deployment name"
}
}
```

</details>

Notes:

- `api_type` should be "AZURE" or "AZURE_AD".
- `api_version` defaults to "2023-05-15" if not specified.
- Configure `model_mapping` to map model names to your deployment names. If not specified, the model name will be used as the deployment name with `.` or `:` removed (e.g. "gpt-3.5-turbo" -> "gpt-35-turbo").

Find more details about Azure OpenAI service here: https://learn.microsoft.com/en-US/azure/ai-services/openai/reference.

## Troubleshooting

1. `Error: unexpected EOF, please try again`
Expand Down
45 changes: 32 additions & 13 deletions chatgpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"regexp"

"github.com/avast/retry-go"
"github.com/sashabaranov/go-openai"
Expand All @@ -17,18 +18,31 @@ type ChatGPT struct {
}

func NewChatGPT(conf GlobalConfig) *ChatGPT {
var config openai.ClientConfig
if conf.APIType == openai.APITypeOpenAI {
config = openai.DefaultConfig(conf.APIKey)
var cc openai.ClientConfig
switch conf.APIType {
case openai.APITypeOpenAI:
cc = openai.DefaultConfig(conf.APIKey)
if conf.Endpoint != "" {
config.BaseURL = conf.Endpoint
cc.BaseURL = conf.Endpoint
}
} else {
config = openai.DefaultAzureConfig(conf.APIKey, conf.Endpoint)
config.APIVersion = conf.APIVersion
case openai.APITypeAzure, openai.APITypeAzureAD:
cc = openai.DefaultAzureConfig(conf.APIKey, conf.Endpoint)
if conf.APIVersion != "" {
cc.APIVersion = conf.APIVersion
}
cc.AzureModelMapperFunc = func(model string) string {
m, ok := conf.ModelMapping[model]
if ok {
return m
}
// Fallback to use model name (without . or : ) as deployment name.
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
}
default:
panic(fmt.Sprintf("unknown API type: %s", conf.APIType))
}
config.OrgID = conf.OrgID
client := openai.NewClientWithConfig(config)
cc.OrgID = conf.OrgID
client := openai.NewClientWithConfig(cc)
return &ChatGPT{globalConf: conf, client: client}
}

Expand Down Expand Up @@ -96,16 +110,18 @@ func (c *ChatGPT) Send(conf ConversationConfig, messages []openai.ChatCompletion
if err != nil {
return err
}
content := resp.Choices[0].Delta.Content
msg = content
if len(resp.Choices) > 0 {
msg = resp.Choices[0].Delta.Content
}
hasMore = true
} else {
resp, err := c.client.CreateChatCompletion(context.Background(), req)
if err != nil {
return err
}
content := resp.Choices[0].Message.Content
msg = content
if len(resp.Choices) > 0 {
msg = resp.Choices[0].Message.Content
}
hasMore = false
}
return nil
Expand All @@ -124,6 +140,9 @@ func (c *ChatGPT) Recv() (string, error) {
if err != nil {
return "", err
}
if len(resp.Choices) == 0 {
return "", nil
}
content := resp.Choices[0].Delta.Content
return content, nil
}
Expand Down
13 changes: 10 additions & 3 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"os"
"path/filepath"
"strings"

"github.com/mitchellh/go-homedir"
"github.com/sashabaranov/go-openai"
Expand Down Expand Up @@ -43,7 +43,8 @@ type GlobalConfig struct {
APIKey string `json:"api_key"`
Endpoint string `json:"endpoint"`
APIType openai.APIType `json:"api_type,omitempty"`
APIVersion string `json:"api_version,omitempty"` // required when APIType is APITypeAzure or APITypeAzureAD
APIVersion string `json:"api_version,omitempty"` // required when APIType is APITypeAzure or APITypeAzureAD
ModelMapping map[string]string `json:"model_mapping,omitempty"` // required when APIType is APITypeAzure or APITypeAzureAD
OrgID string `json:"org_id,omitempty"`
Prompts map[string]string `json:"prompts"`
Conversation ConversationConfig `json:"conversation"` // Default conversation config
Expand Down Expand Up @@ -146,7 +147,7 @@ func InitConfig() (GlobalConfig, error) {
}
err := readOrWriteConfig(&conf)
if err != nil {
log.Println(err)
return GlobalConfig{}, err
}
apiKey := os.Getenv("OPENAI_API_KEY")
if apiKey != "" {
Expand All @@ -159,5 +160,11 @@ func InitConfig() (GlobalConfig, error) {
if conf.APIKey == "" {
return GlobalConfig{}, errors.New("Missing API key. Set it in `~/.config/chatgpt/config.json` or by setting the `OPENAI_API_KEY` environment variable. You can find or create your API key at https://platform.openai.com/account/api-keys.")
}
conf.APIType = openai.APIType(strings.ToUpper(string(conf.APIType)))
switch conf.APIType {
default:
return GlobalConfig{}, fmt.Errorf("unknown API type: %s", conf.APIType)
case openai.APITypeOpenAI, openai.APITypeAzure, openai.APITypeAzureAD:
}
return conf, nil
}
10 changes: 5 additions & 5 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ require (
github.com/muesli/reflow v0.3.0
github.com/pkoukk/tiktoken-go v0.1.6
github.com/postfinance/single v0.0.2
github.com/sashabaranov/go-openai v1.17.9
github.com/sashabaranov/go-openai v1.18.2
)

require (
Expand All @@ -38,9 +38,9 @@ require (
github.com/rivo/uniseg v0.4.4 // indirect
github.com/yuin/goldmark v1.6.0 // indirect
github.com/yuin/goldmark-emoji v1.0.2 // indirect
golang.org/x/net v0.19.0 // indirect
golang.org/x/sync v0.5.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/term v0.15.0 // indirect
golang.org/x/net v0.20.0 // indirect
golang.org/x/sync v0.6.0 // indirect
golang.org/x/sys v0.16.0 // indirect
golang.org/x/term v0.16.0 // indirect
golang.org/x/text v0.14.0 // indirect
)
20 changes: 10 additions & 10 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis=
github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/sashabaranov/go-openai v1.17.9 h1:QEoBiGKWW68W79YIfXWEFZ7l5cEgZBV4/Ow3uy+5hNY=
github.com/sashabaranov/go-openai v1.17.9/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.18.2 h1:UnC307Mgc+fiIDUmEJCiCvRoMxdFrLtQlg8A594pnG8=
github.com/sashabaranov/go-openai v1.18.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
Expand All @@ -84,20 +84,20 @@ github.com/yuin/goldmark-emoji v1.0.1/go.mod h1:2w1E6FEWLcDQkoTE+7HU6QF1F6SLlNGj
github.com/yuin/goldmark-emoji v1.0.2 h1:c/RgTShNgHTtc6xdz2KKI74jJr6rWi7FPgnP9GAsO5s=
github.com/yuin/goldmark-emoji v1.0.2/go.mod h1:RhP/RWpexdp+KHs7ghKnifRoIs/Bq4nDS7tRbCkOwKY=
golang.org/x/net v0.0.0-20221002022538-bcab6841153b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4=
golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0=
golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE=
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
Expand Down

0 comments on commit 9882d94

Please sign in to comment.