From afc10cb5b6fe94670b4f60c7e19d43a3e98c0bce Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Fri, 26 Jan 2024 15:56:54 -0800
Subject: [PATCH 01/71] add key auth for admin endpoints
---
.gitignore | 3 ++-
CHANGELOG.md | 4 ++++
README.md | 10 +++++++++-
cmd/bricksllm/main.go | 2 +-
cmd/tool/main.go | 2 +-
internal/config/config.go | 1 +
internal/server/web/admin/admin.go | 4 ++--
internal/server/web/admin/middleware.go | 8 +++++++-
8 files changed, 27 insertions(+), 7 deletions(-)
diff --git a/.gitignore b/.gitignore
index 5b3abb0..41fb96b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
release_notes.md
-target
\ No newline at end of file
+target
+.DS_STORE
\ No newline at end of file
diff --git a/CHANGELOG.md b/CHANGELOG.md
index b2d56b7..c97dd5e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,7 @@
+## 1.8.0 - 2024-01-17
+### Added
+- Added key authentication for admin endpoints
+
## 1.7.6 - 2024-01-17
### Fixed
- Changed code to string in OpenAI error response
diff --git a/README.md b/README.md
index 895c8ba..981df6d 100644
--- a/README.md
+++ b/README.md
@@ -145,10 +145,18 @@ docker pull luyuanxin1995/bricksllm:1.4.0
> | `REDIS_WRITE_TIME_OUT` | optional | Timeout for Redis write operations | `500ms`
> | `IN_MEMORY_DB_UPDATE_INTERVAL` | optional | The interval BricksLLM API gateway polls Postgresql DB for latest key configurations | `1s`
> | `STATS_PROVIDER` | optional | This value can only be datadog. Required for integration with Datadog. |
-> | `PROXY_TIMEOUT` | optional | This value can only be datadog. Required for integration with Datadog. |
+> | `PROXY_TIMEOUT` | optional | Timeout for proxy HTTP requests. |
+> | `ADMIN_PASS` | optional | Simple password authentication for admin endpoints. |
## Configuration Endpoints
The configuration server runs on Port `8001`.
+
+##### Headers
+> | name | type | data type | description |
+> |--------|------------|----------------|------------------------------------------------------|
+> | `X-API-KEY` | optional | `string` | Key authentication header.
+
+
Get keys: GET
/api/key-management/keys
diff --git a/cmd/bricksllm/main.go b/cmd/bricksllm/main.go
index 9f5a5b9..c8dba0d 100644
--- a/cmd/bricksllm/main.go
+++ b/cmd/bricksllm/main.go
@@ -182,7 +182,7 @@ func main() {
cpm := manager.NewCustomProvidersManager(store, cpMemStore)
rm := manager.NewRouteManager(store, store, rMemStore, psMemStore)
- as, err := admin.NewAdminServer(log, *modePtr, m, krm, psm, cpm, rm)
+ as, err := admin.NewAdminServer(log, *modePtr, m, krm, psm, cpm, rm, cfg.AdminPass)
if err != nil {
log.Sugar().Fatalf("error creating admin http server: %v", err)
}
diff --git a/cmd/tool/main.go b/cmd/tool/main.go
index 756551d..11158b8 100644
--- a/cmd/tool/main.go
+++ b/cmd/tool/main.go
@@ -181,7 +181,7 @@ func main() {
cpm := manager.NewCustomProvidersManager(store, cpMemStore)
rm := manager.NewRouteManager(store, store, rMemStore, psMemStore)
- as, err := admin.NewAdminServer(log, *modePtr, m, krm, psm, cpm, rm)
+ as, err := admin.NewAdminServer(log, *modePtr, m, krm, psm, cpm, rm, cfg.AdminPass)
if err != nil {
log.Sugar().Fatalf("error creating admin http server: %v", err)
}
diff --git a/internal/config/config.go b/internal/config/config.go
index 9b675ea..f152055 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -24,6 +24,7 @@ type Config struct {
InMemoryDbUpdateInterval time.Duration `env:"IN_MEMORY_DB_UPDATE_INTERVAL" envDefault:"5s"`
OpenAiKey string `env:"OPENAI_API_KEY"`
StatsProvider string `env:"STATS_PROVIDER"`
+ AdminPass string `env:"ADMIN_PASS"`
ProxyTimeout time.Duration `env:"PROXY_TIMEOUT" envDefault:"180s"`
}
diff --git a/internal/server/web/admin/admin.go b/internal/server/web/admin/admin.go
index dc67778..771dc95 100644
--- a/internal/server/web/admin/admin.go
+++ b/internal/server/web/admin/admin.go
@@ -55,11 +55,11 @@ type AdminServer struct {
m KeyManager
}
-func NewAdminServer(log *zap.Logger, mode string, m KeyManager, krm KeyReportingManager, psm ProviderSettingsManager, cpm CustomProvidersManager, rm RouteManager) (*AdminServer, error) {
+func NewAdminServer(log *zap.Logger, mode string, m KeyManager, krm KeyReportingManager, psm ProviderSettingsManager, cpm CustomProvidersManager, rm RouteManager, adminPass string) (*AdminServer, error) {
router := gin.New()
prod := mode == "production"
- router.Use(getAdminLoggerMiddleware(log, "admin", prod))
+ router.Use(getAdminLoggerMiddleware(log, "admin", prod, adminPass))
router.GET("/api/health", getGetHealthCheckHandler())
diff --git a/internal/server/web/admin/middleware.go b/internal/server/web/admin/middleware.go
index ad6ad85..d4089a4 100644
--- a/internal/server/web/admin/middleware.go
+++ b/internal/server/web/admin/middleware.go
@@ -8,8 +8,14 @@ import (
"go.uber.org/zap"
)
-func getAdminLoggerMiddleware(log *zap.Logger, prefix string, prod bool) gin.HandlerFunc {
+func getAdminLoggerMiddleware(log *zap.Logger, prefix string, prod bool, adminPass string) gin.HandlerFunc {
return func(c *gin.Context) {
+ if len(adminPass) != 0 && c.Request.Header.Get("X-API-KEY") != adminPass {
+ c.Status(200)
+ c.Abort()
+ return
+ }
+
c.Set(correlationId, util.NewUuid())
start := time.Now()
c.Next()
From 94978e239604c97d28b5cdb582bf5fd1e39e9d73 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Fri, 26 Jan 2024 15:57:30 -0800
Subject: [PATCH 02/71] update CHANGELOG
---
CHANGELOG.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index c97dd5e..3459264 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,4 +1,4 @@
-## 1.8.0 - 2024-01-17
+## 1.8.0 - 2024-01-26
### Added
- Added key authentication for admin endpoints
From ad91ea17619153c748d70df842066205a8528666 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Tue, 30 Jan 2024 20:22:51 -0800
Subject: [PATCH 03/71] add support querying events by key ids and query data
points by custom ids
---
internal/event/reporting.go | 2 ++
internal/manager/reporting.go | 14 ++++-----
internal/server/web/admin/admin.go | 38 ++++++++++++++++++-----
internal/storage/postgresql/postgresql.go | 37 ++++++++++++++++++++--
4 files changed, 74 insertions(+), 17 deletions(-)
diff --git a/internal/event/reporting.go b/internal/event/reporting.go
index 7ccb61c..5fab273 100644
--- a/internal/event/reporting.go
+++ b/internal/event/reporting.go
@@ -10,6 +10,7 @@ type DataPoint struct {
SuccessCount int `json:"successCount"`
Model string `json:"model"`
KeyId string `json:"keyId"`
+ CustomId string `json:"customId"`
}
type ReportingResponse struct {
@@ -21,6 +22,7 @@ type ReportingResponse struct {
type ReportingRequest struct {
KeyIds []string `json:"keyIds"`
Tags []string `json:"tags"`
+ CustomIds []string `json:"customIds"`
Start int64 `json:"start"`
End int64 `json:"end"`
Increment int64 `json:"increment"`
diff --git a/internal/manager/reporting.go b/internal/manager/reporting.go
index 57a7ae6..79fe066 100644
--- a/internal/manager/reporting.go
+++ b/internal/manager/reporting.go
@@ -18,8 +18,8 @@ type keyStorage interface {
}
type eventStorage interface {
- GetEvents(customId string) ([]*event.Event, error)
- GetEventDataPoints(start, end, increment int64, tags, keyIds []string, filters []string) ([]*event.DataPoint, error)
+ GetEvents(customId string, keyIds []string) ([]*event.Event, error)
+ GetEventDataPoints(start, end, increment int64, tags, keyIds, customIds []string, filters []string) ([]*event.DataPoint, error)
GetLatencyPercentiles(start, end int64, tags, keyIds []string) ([]float64, error)
}
@@ -38,7 +38,7 @@ func NewReportingManager(cs costStorage, ks keyStorage, es eventStorage) *Report
}
func (rm *ReportingManager) GetEventReporting(e *event.ReportingRequest) (*event.ReportingResponse, error) {
- dataPoints, err := rm.es.GetEventDataPoints(e.Start, e.End, e.Increment, e.Tags, e.KeyIds, e.Filters)
+ dataPoints, err := rm.es.GetEventDataPoints(e.Start, e.End, e.Increment, e.Tags, e.KeyIds, e.CustomIds, e.Filters)
if err != nil {
return nil, err
}
@@ -80,12 +80,12 @@ func (rm *ReportingManager) GetKeyReporting(keyId string) (*key.KeyReporting, er
}, err
}
-func (rm *ReportingManager) GetEvent(customId string) (*event.Event, error) {
+func (rm *ReportingManager) GetEvent(customId string, keyIds []string) (*event.Event, error) {
if len(customId) == 0 {
return nil, errors.New("customId cannot be empty")
}
- events, err := rm.es.GetEvents(customId)
+ events, err := rm.es.GetEvents(customId, keyIds)
if err != nil {
return nil, err
}
@@ -97,8 +97,8 @@ func (rm *ReportingManager) GetEvent(customId string) (*event.Event, error) {
return nil, internal_errors.NewNotFoundError(fmt.Sprintf("event is not found for customId: %s", customId))
}
-func (rm *ReportingManager) GetEvents(customId string) ([]*event.Event, error) {
- events, err := rm.es.GetEvents(customId)
+func (rm *ReportingManager) GetEvents(customId string, keyIds []string) ([]*event.Event, error) {
+ events, err := rm.es.GetEvents(customId, keyIds)
if err != nil {
return nil, err
}
diff --git a/internal/server/web/admin/admin.go b/internal/server/web/admin/admin.go
index 771dc95..0184c34 100644
--- a/internal/server/web/admin/admin.go
+++ b/internal/server/web/admin/admin.go
@@ -37,7 +37,8 @@ type KeyManager interface {
type KeyReportingManager interface {
GetKeyReporting(keyId string) (*key.KeyReporting, error)
- GetEvent(customId string) (*event.Event, error)
+ GetEvents(customId string, keyIds []string) ([]*event.Event, error)
+ GetEvent(customId string, keyIds []string) (*event.Event, error)
GetEventReporting(e *event.ReportingRequest) (*event.ReportingResponse, error)
}
@@ -819,20 +820,43 @@ func getGetEventsHandler(m KeyReportingManager, log *zap.Logger, prod bool) gin.
}
cid := c.GetString(correlationId)
- customId, ok := c.GetQuery("customId")
- if !ok {
+ customId, ciok := c.GetQuery("customId")
+ keyIds, kiok := c.GetQueryArray("keyIds")
+ if !ciok && !kiok {
c.JSON(http.StatusBadRequest, &ErrorResponse{
- Type: "/errors/custom-id-empty",
- Title: "custom id is empty",
+ Type: "/errors/no-filters-empty",
+ Title: "neither customId nor keyIds are specified",
Status: http.StatusBadRequest,
- Detail: "query param customId is empty. it is required for retrieving an event.",
+ Detail: "both query params customId and keyIds are empty. either of them is required for retrieving events.",
Instance: path,
})
return
}
- ev, err := m.GetEvent(customId)
+ if kiok {
+ evs, err := m.GetEvents(customId, keyIds)
+ if err != nil {
+ stats.Incr("bricksllm.admin.get_get_events_handler.get_events_error", nil, 1)
+
+ logError(log, "error when getting events", prod, cid, err)
+ c.JSON(http.StatusInternalServerError, &ErrorResponse{
+ Type: "/errors/event-manager",
+ Title: "getting events error",
+ Status: http.StatusInternalServerError,
+ Detail: err.Error(),
+ Instance: path,
+ })
+ return
+ }
+
+ stats.Incr("bricksllm.admin.get_get_events_handler.success", nil, 1)
+
+ c.JSON(http.StatusOK, evs)
+ return
+ }
+
+ ev, err := m.GetEvent(customId, keyIds)
if err != nil {
stats.Incr("bricksllm.admin.get_get_events_handler.get_event_error", nil, 1)
diff --git a/internal/storage/postgresql/postgresql.go b/internal/storage/postgresql/postgresql.go
index 8de0710..68cf21e 100644
--- a/internal/storage/postgresql/postgresql.go
+++ b/internal/storage/postgresql/postgresql.go
@@ -241,11 +241,27 @@ func (s *Store) InsertEvent(e *event.Event) error {
return nil
}
-func (s *Store) GetEvents(customId string) ([]*event.Event, error) {
+func (s *Store) GetEvents(customId string, keyIds []string) ([]*event.Event, error) {
+ if len(customId) == 0 && len(keyIds) == 0 {
+ return nil, errors.New("neither customId nor keyIds are specified")
+ }
+
query := `
- SELECT * FROM events WHERE $1 = custom_id
+ SELECT * FROM events WHERE
`
+ if len(customId) != 0 {
+ query += " custom_id = $1"
+ }
+
+ if len(customId) != 0 && len(keyIds) != 0 {
+ query += " AND"
+ }
+
+ if len(keyIds) != 0 {
+ query += fmt.Sprintf(" key_id = ANY('%s')", sliceToSqlStringArray(keyIds))
+ }
+
ctxTimeout, cancel := context.WithTimeout(context.Background(), s.rt)
defer cancel()
@@ -356,7 +372,7 @@ func (s *Store) GetLatencyPercentiles(start, end int64, tags, keyIds []string) (
return data, nil
}
-func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds []string, filters []string) ([]*event.DataPoint, error) {
+func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds, customIds []string, filters []string) ([]*event.DataPoint, error) {
groupByQuery := "GROUP BY time_series_table.series"
selectQuery := "SELECT series AS time_stamp, COALESCE(COUNT(events_table.event_id),0) AS num_of_requests, COALESCE(SUM(events_table.cost_in_usd),0) AS cost_in_usd, COALESCE(SUM(events_table.latency_in_ms),0) AS latency_in_ms, COALESCE(SUM(events_table.prompt_token_count),0) AS prompt_token_count, COALESCE(SUM(events_table.completion_token_count),0) AS completion_token_count, COALESCE(SUM(CASE WHEN status_code = 200 THEN 1 END),0) AS success_count"
@@ -371,6 +387,11 @@ func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds []s
groupByQuery += ",events_table.key_id"
selectQuery += ",events_table.key_id as keyId"
}
+
+ if filter == "customId" {
+ groupByQuery += ",events_table.custom_id"
+ selectQuery += ",events_table.custom_id as customId"
+ }
}
}
@@ -406,6 +427,10 @@ func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds []s
conditionBlock += fmt.Sprintf("AND key_id = ANY('%s')", sliceToSqlStringArray(keyIds))
}
+ if len(customIds) != 0 {
+ conditionBlock += fmt.Sprintf("AND custom_id = ANY('%s')", sliceToSqlStringArray(customIds))
+ }
+
eventSelectionBlock += conditionBlock
eventSelectionBlock += ")"
@@ -425,6 +450,7 @@ func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds []s
var e event.DataPoint
var model sql.NullString
var keyId sql.NullString
+ var customId sql.NullString
additional := []any{
&e.TimeStamp,
@@ -445,6 +471,10 @@ func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds []s
if filter == "keyId" {
additional = append(additional, &keyId)
}
+
+ if filter == "customId" {
+ additional = append(additional, &customId)
+ }
}
}
@@ -457,6 +487,7 @@ func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds []s
pe := &e
pe.Model = model.String
pe.KeyId = keyId.String
+ pe.CustomId = customId.String
data = append(data, pe)
}
From 03cab2bffafe4696871c477575711b9deb861a1e Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Tue, 30 Jan 2024 21:17:28 -0800
Subject: [PATCH 04/71] fix bugs with retrieving events
---
internal/storage/postgresql/postgresql.go | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/internal/storage/postgresql/postgresql.go b/internal/storage/postgresql/postgresql.go
index 68cf21e..94e8c6f 100644
--- a/internal/storage/postgresql/postgresql.go
+++ b/internal/storage/postgresql/postgresql.go
@@ -251,7 +251,7 @@ func (s *Store) GetEvents(customId string, keyIds []string) ([]*event.Event, err
`
if len(customId) != 0 {
- query += " custom_id = $1"
+ query += fmt.Sprintf(" custom_id = '%s'", customId)
}
if len(customId) != 0 && len(keyIds) != 0 {
@@ -266,7 +266,7 @@ func (s *Store) GetEvents(customId string, keyIds []string) ([]*event.Event, err
defer cancel()
events := []*event.Event{}
- rows, err := s.db.QueryContext(ctxTimeout, query, customId)
+ rows, err := s.db.QueryContext(ctxTimeout, query)
if err != nil {
if err == sql.ErrNoRows {
return events, nil
From 785880e8bdd1a76f026a9c01fa0cd44f73667e22 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Tue, 30 Jan 2024 21:32:21 -0800
Subject: [PATCH 05/71] remove restrictions on number of events
---
internal/server/web/admin/admin.go | 32 +++++-------------------------
1 file changed, 5 insertions(+), 27 deletions(-)
diff --git a/internal/server/web/admin/admin.go b/internal/server/web/admin/admin.go
index 0184c34..396a063 100644
--- a/internal/server/web/admin/admin.go
+++ b/internal/server/web/admin/admin.go
@@ -834,36 +834,14 @@ func getGetEventsHandler(m KeyReportingManager, log *zap.Logger, prod bool) gin.
return
}
- if kiok {
- evs, err := m.GetEvents(customId, keyIds)
- if err != nil {
- stats.Incr("bricksllm.admin.get_get_events_handler.get_events_error", nil, 1)
-
- logError(log, "error when getting events", prod, cid, err)
- c.JSON(http.StatusInternalServerError, &ErrorResponse{
- Type: "/errors/event-manager",
- Title: "getting events error",
- Status: http.StatusInternalServerError,
- Detail: err.Error(),
- Instance: path,
- })
- return
- }
-
- stats.Incr("bricksllm.admin.get_get_events_handler.success", nil, 1)
-
- c.JSON(http.StatusOK, evs)
- return
- }
-
- ev, err := m.GetEvent(customId, keyIds)
+ evs, err := m.GetEvents(customId, keyIds)
if err != nil {
- stats.Incr("bricksllm.admin.get_get_events_handler.get_event_error", nil, 1)
+ stats.Incr("bricksllm.admin.get_get_events_handler.get_events_error", nil, 1)
- logError(log, "error when getting an event", prod, cid, err)
+ logError(log, "error when getting events", prod, cid, err)
c.JSON(http.StatusInternalServerError, &ErrorResponse{
Type: "/errors/event-manager",
- Title: "getting an event error",
+ Title: "getting events error",
Status: http.StatusInternalServerError,
Detail: err.Error(),
Instance: path,
@@ -873,7 +851,7 @@ func getGetEventsHandler(m KeyReportingManager, log *zap.Logger, prod bool) gin.
stats.Incr("bricksllm.admin.get_get_events_handler.success", nil, 1)
- c.JSON(http.StatusOK, []*event.Event{ev})
+ c.JSON(http.StatusOK, evs)
}
}
From 83c5e7aa7b0a0e1ba5772469b5e545a600697c1c Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 31 Jan 2024 14:39:42 -0800
Subject: [PATCH 06/71] Extended default proxy request timeout to 10m
---
CHANGELOG.md | 4 ++++
internal/config/config.go | 2 +-
2 files changed, 5 insertions(+), 1 deletion(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3459264..31e00c7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,7 @@
+## 1.8.1 - 2024-01-31
+### Fixed
+- Extended default proxy request timeout to 10m
+
## 1.8.0 - 2024-01-26
### Added
- Added key authentication for admin endpoints
diff --git a/internal/config/config.go b/internal/config/config.go
index f152055..784bc2b 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -25,7 +25,7 @@ type Config struct {
OpenAiKey string `env:"OPENAI_API_KEY"`
StatsProvider string `env:"STATS_PROVIDER"`
AdminPass string `env:"ADMIN_PASS"`
- ProxyTimeout time.Duration `env:"PROXY_TIMEOUT" envDefault:"180s"`
+ ProxyTimeout time.Duration `env:"PROXY_TIMEOUT" envDefault:"600s"`
}
func ParseEnvVariables() (*Config, error) {
From 4a441422e2d93b8bd00e46490f07406f477de6f1 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 31 Jan 2024 14:53:23 -0800
Subject: [PATCH 07/71] start handling context deadline exceeded error in
streaming mode
---
internal/server/web/proxy/anthropic.go | 8 ++++++++
internal/server/web/proxy/azure_chat_completion.go | 8 ++++++++
internal/server/web/proxy/custom_provider.go | 7 +++++++
internal/server/web/proxy/proxy.go | 7 +++++++
4 files changed, 30 insertions(+)
diff --git a/internal/server/web/proxy/anthropic.go b/internal/server/web/proxy/anthropic.go
index fa96eb2..a9aab8a 100644
--- a/internal/server/web/proxy/anthropic.go
+++ b/internal/server/web/proxy/anthropic.go
@@ -5,6 +5,7 @@ import (
"bytes"
"context"
"encoding/json"
+ "errors"
"io"
"net/http"
"strings"
@@ -191,6 +192,13 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km
return false
}
+ if errors.Is(err, context.DeadlineExceeded) {
+ stats.Incr("bricksllm.proxy.get_completion_handler.context_deadline_exceeded_error", nil, 1)
+ logError(log, "context deadline exceeded when reading bytes from anthropic completion response", prod, cid, err)
+
+ return false
+ }
+
stats.Incr("bricksllm.proxy.get_completion_handler.read_bytes_error", nil, 1)
logError(log, "error when reading bytes from anthropic streaming response", prod, cid, err)
diff --git a/internal/server/web/proxy/azure_chat_completion.go b/internal/server/web/proxy/azure_chat_completion.go
index d0b09aa..970237a 100644
--- a/internal/server/web/proxy/azure_chat_completion.go
+++ b/internal/server/web/proxy/azure_chat_completion.go
@@ -5,6 +5,7 @@ import (
"bytes"
"context"
"encoding/json"
+ "errors"
"fmt"
"io"
"net/http"
@@ -183,6 +184,13 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS
return false
}
+ if errors.Is(err, context.DeadlineExceeded) {
+ stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.context_deadline_exceeded_error", nil, 1)
+ logError(log, "context deadline exceeded when reading bytes from azure openai chat completion response", prod, cid, err)
+
+ return false
+ }
+
stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.read_bytes_error", nil, 1)
logError(log, "error when reading bytes from azure openai chat completion response", prod, cid, err)
diff --git a/internal/server/web/proxy/custom_provider.go b/internal/server/web/proxy/custom_provider.go
index 8ebcb6e..bbc8291 100644
--- a/internal/server/web/proxy/custom_provider.go
+++ b/internal/server/web/proxy/custom_provider.go
@@ -167,6 +167,13 @@ func getCustomProviderHandler(prod, private bool, psm ProviderSettingsManager, c
return false
}
+ if errors.Is(err, context.DeadlineExceeded) {
+ stats.Incr("bricksllm.proxy.get_custom_provider_handler.context_deadline_exceeded_error", nil, 1)
+ logError(log, "context deadline exceeded when reading bytes from custom provider response", prod, cid, err)
+
+ return false
+ }
+
stats.Incr("bricksllm.proxy.get_custom_provider_handler.read_bytes_error", nil, 1)
logError(log, "error when reading bytes from custom provider response", prod, cid, err)
diff --git a/internal/server/web/proxy/proxy.go b/internal/server/web/proxy/proxy.go
index 2c1a07e..dbcb5eb 100644
--- a/internal/server/web/proxy/proxy.go
+++ b/internal/server/web/proxy/proxy.go
@@ -1222,6 +1222,13 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin
return false
}
+ if errors.Is(err, context.DeadlineExceeded) {
+ stats.Incr("bricksllm.proxy.get_chat_completion_handler.context_deadline_exceeded_error", nil, 1)
+ logError(log, "context deadline exceeded when reading bytes from openai chat completion response", prod, cid, err)
+
+ return false
+ }
+
stats.Incr("bricksllm.proxy.get_chat_completion_handler.read_bytes_error", nil, 1)
logError(log, "error when reading bytes from openai chat completion response", prod, cid, err)
From af2e0145984722c8e1ada294b92f3f3e95589f5c Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 31 Jan 2024 14:56:36 -0800
Subject: [PATCH 08/71] started handling context deadline exceeded error
---
CHANGELOG.md | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 31e00c7..dcbdbe3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,7 +1,10 @@
## 1.8.1 - 2024-01-31
-### Fixed
+### Changed
- Extended default proxy request timeout to 10m
+### Fixed
+- Fixed streaming response stuck at context deadline exceeded error
+
## 1.8.0 - 2024-01-26
### Added
- Added key authentication for admin endpoints
From 5ec1edcd668b592a8cf0469be0887020fa4ac895 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 31 Jan 2024 16:10:55 -0800
Subject: [PATCH 09/71] add start and end query param
---
internal/manager/reporting.go | 26 +--------
internal/server/web/admin/admin.go | 67 ++++++++++++++++++++++-
internal/storage/postgresql/postgresql.go | 8 ++-
3 files changed, 73 insertions(+), 28 deletions(-)
diff --git a/internal/manager/reporting.go b/internal/manager/reporting.go
index 79fe066..b78d359 100644
--- a/internal/manager/reporting.go
+++ b/internal/manager/reporting.go
@@ -1,9 +1,6 @@
package manager
import (
- "errors"
- "fmt"
-
internal_errors "github.com/bricks-cloud/bricksllm/internal/errors"
"github.com/bricks-cloud/bricksllm/internal/event"
"github.com/bricks-cloud/bricksllm/internal/key"
@@ -18,7 +15,7 @@ type keyStorage interface {
}
type eventStorage interface {
- GetEvents(customId string, keyIds []string) ([]*event.Event, error)
+ GetEvents(customId string, keyIds []string, start, end int64) ([]*event.Event, error)
GetEventDataPoints(start, end, increment int64, tags, keyIds, customIds []string, filters []string) ([]*event.DataPoint, error)
GetLatencyPercentiles(start, end int64, tags, keyIds []string) ([]float64, error)
}
@@ -80,25 +77,8 @@ func (rm *ReportingManager) GetKeyReporting(keyId string) (*key.KeyReporting, er
}, err
}
-func (rm *ReportingManager) GetEvent(customId string, keyIds []string) (*event.Event, error) {
- if len(customId) == 0 {
- return nil, errors.New("customId cannot be empty")
- }
-
- events, err := rm.es.GetEvents(customId, keyIds)
- if err != nil {
- return nil, err
- }
-
- if len(events) >= 1 {
- return events[0], nil
- }
-
- return nil, internal_errors.NewNotFoundError(fmt.Sprintf("event is not found for customId: %s", customId))
-}
-
-func (rm *ReportingManager) GetEvents(customId string, keyIds []string) ([]*event.Event, error) {
- events, err := rm.es.GetEvents(customId, keyIds)
+func (rm *ReportingManager) GetEvents(customId string, keyIds []string, start, end int64) ([]*event.Event, error) {
+ events, err := rm.es.GetEvents(customId, keyIds, start, end)
if err != nil {
return nil, err
}
diff --git a/internal/server/web/admin/admin.go b/internal/server/web/admin/admin.go
index 396a063..7766e66 100644
--- a/internal/server/web/admin/admin.go
+++ b/internal/server/web/admin/admin.go
@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
+ "strconv"
"time"
"github.com/bricks-cloud/bricksllm/internal/event"
@@ -37,8 +38,7 @@ type KeyManager interface {
type KeyReportingManager interface {
GetKeyReporting(keyId string) (*key.KeyReporting, error)
- GetEvents(customId string, keyIds []string) ([]*event.Event, error)
- GetEvent(customId string, keyIds []string) (*event.Event, error)
+ GetEvents(customId string, keyIds []string, start int64, end int64) ([]*event.Event, error)
GetEventReporting(e *event.ReportingRequest) (*event.ReportingResponse, error)
}
@@ -834,7 +834,68 @@ func getGetEventsHandler(m KeyReportingManager, log *zap.Logger, prod bool) gin.
return
}
- evs, err := m.GetEvents(customId, keyIds)
+ var qstart int64 = 0
+ var qend int64 = 0
+
+ if kiok {
+ startstr, sok := c.GetQuery("start")
+ if !sok {
+ c.JSON(http.StatusBadRequest, &ErrorResponse{
+ Type: "/errors/query-param-start-missing",
+ Title: "query param start is missing",
+ Status: http.StatusBadRequest,
+ Detail: "start query param is not provided",
+ Instance: path,
+ })
+
+ return
+ }
+
+ parsedStart, err := strconv.ParseInt(startstr, 10, 64)
+ if err != nil {
+ c.JSON(http.StatusBadRequest, &ErrorResponse{
+ Type: "/errors/bad-start-query-param",
+ Title: "start query cannot be parsed",
+ Status: http.StatusBadRequest,
+ Detail: "start query param must be int64",
+ Instance: path,
+ })
+
+ return
+ }
+
+ qstart = parsedStart
+
+ endstr, eoi := c.GetQuery("end")
+ if !eoi {
+ c.JSON(http.StatusBadRequest, &ErrorResponse{
+ Type: "/errors/query-param-end-missing",
+ Title: "query param end is missing",
+ Status: http.StatusBadRequest,
+ Detail: "end query param is not provided",
+ Instance: path,
+ })
+
+ return
+ }
+
+ parsedEnd, err := strconv.ParseInt(endstr, 10, 64)
+ if err != nil {
+ c.JSON(http.StatusBadRequest, &ErrorResponse{
+ Type: "/errors/bad-end-query-param",
+ Title: "end query cannot be parsed",
+ Status: http.StatusBadRequest,
+ Detail: "end query param must be int64",
+ Instance: path,
+ })
+
+ return
+ }
+
+ qend = parsedEnd
+ }
+
+ evs, err := m.GetEvents(customId, keyIds, qstart, qend)
if err != nil {
stats.Incr("bricksllm.admin.get_get_events_handler.get_events_error", nil, 1)
diff --git a/internal/storage/postgresql/postgresql.go b/internal/storage/postgresql/postgresql.go
index 94e8c6f..1a1467b 100644
--- a/internal/storage/postgresql/postgresql.go
+++ b/internal/storage/postgresql/postgresql.go
@@ -241,11 +241,15 @@ func (s *Store) InsertEvent(e *event.Event) error {
return nil
}
-func (s *Store) GetEvents(customId string, keyIds []string) ([]*event.Event, error) {
+func (s *Store) GetEvents(customId string, keyIds []string, start int64, end int64) ([]*event.Event, error) {
if len(customId) == 0 && len(keyIds) == 0 {
return nil, errors.New("neither customId nor keyIds are specified")
}
+ if len(keyIds) != 0 && (start == 0 || end == 0) {
+ return nil, errors.New("keyIds are provided but either start or end is not specified")
+ }
+
query := `
SELECT * FROM events WHERE
`
@@ -259,7 +263,7 @@ func (s *Store) GetEvents(customId string, keyIds []string) ([]*event.Event, err
}
if len(keyIds) != 0 {
- query += fmt.Sprintf(" key_id = ANY('%s')", sliceToSqlStringArray(keyIds))
+ query += fmt.Sprintf(" key_id = ANY('%s') AND created_at >= %d AND created_at <= %d", sliceToSqlStringArray(keyIds), start, end)
}
ctxTimeout, cancel := context.WithTimeout(context.Background(), s.rt)
From 33ac5959376baa016a71ab92b9150aecb3b691b8 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 31 Jan 2024 23:07:29 -0800
Subject: [PATCH 10/71] update cost mapping
---
README.md | 9 +++++++--
internal/provider/openai/cost.go | 22 +++++-----------------
2 files changed, 12 insertions(+), 19 deletions(-)
diff --git a/README.md b/README.md
index 981df6d..132c15d 100644
--- a/README.md
+++ b/README.md
@@ -497,7 +497,9 @@ This endpoint is retrieving aggregated metrics given an array of key ids and tag
> | Field | required | type | example | description |
> |---------------|-----------------------------------|-|-|-|
> | keyIds | required | `[]string` | `["key-1", "key-2", "key-3" ]` | Array of ids that specicify the keys that you want to aggregate stats from. |
-> | tags | required | `[]string` | `["tag-1", "tag-2"]` | Array of tags that specicify the keys that you want to aggregate stats from. |
+> | tags | required | `[]string` | `["tag-1", "tag-2"]` | Array of tags that specify the key tags that you want to aggregate stats from. |
+> | customIds | required | `[]string` | `["customId-1", "customId-2"]` | A list of custom IDs that you want to aggregate stats from. |
+> | filters | required | `[]string` | `["model", "keyId"]` | Group by data points through different filters(`model`,`keyId` or `customId`). |
> | start | required | `int64` | `1699933571` | Start timestamp for the requested timeseries data. |
> | end | required | `int64` | `1699933571` | End timestamp for the requested timeseries data. |
> | increment | required | `int` | `60` | This field is the increment in seconds for the requested timeseries data. |
@@ -546,7 +548,10 @@ This endpoint is for getting events.
##### Query Parameters
> | name | type | data type | description |
> |--------|------------|----------------|------------------------------------------------------|
-> | `customId` | optional | string | Custom identifier attached to an event |
+> | `customId` | optional | `string` | Custom identifier attached to an event. |
+> | `keyIds` | optional | `[]string` | A list of key IDs. |
+> | `start` | required if `keyIds` is specified | `int64` | Start timestamp. |
+> | `end` | required if `keyIds` is specified | `int64` | End timestamp. |
##### Error Response
> | http code | content-type |
diff --git a/internal/provider/openai/cost.go b/internal/provider/openai/cost.go
index 96e35c4..2bec29b 100644
--- a/internal/provider/openai/cost.go
+++ b/internal/provider/openai/cost.go
@@ -12,6 +12,7 @@ import (
var OpenAiPerThousandTokenCost = map[string]map[string]float64{
"prompt": {
"gpt-4-1106-preview": 0.01,
+ "gpt-4-0125-preview": 0.01,
"gpt-4-1106-vision-preview": 0.01,
"gpt-4": 0.03,
"gpt-4-0314": 0.03,
@@ -50,27 +51,14 @@ var OpenAiPerThousandTokenCost = map[string]map[string]float64{
"ada": 0.0004,
},
"embeddings": {
- "text-embedding-ada-002": 0.0001,
- "text-similarity-ada-001": 0.004,
- "text-search-ada-doc-001": 0.004,
- "text-search-ada-query-001": 0.004,
- "code-search-ada-code-001": 0.004,
- "code-search-ada-text-001": 0.004,
- "code-search-babbage-code-001": 0.005,
- "code-search-babbage-text-001": 0.005,
- "text-similarity-babbage-001": 0.005,
- "text-search-babbage-doc-001": 0.005,
- "text-search-babbage-query-001": 0.005,
- "text-similarity-curie-001": 0.02,
- "text-search-curie-doc-001": 0.02,
- "text-search-curie-query-001": 0.02,
- "text-search-davinci-doc-001": 0.2,
- "text-search-davinci-query-001": 0.2,
- "text-similarity-davinci-001": 0.2,
+ "text-embedding-ada-002": 0.0001,
+ "text-embedding-3-small": 0.00002,
+ "text-embedding-3-large": 0.00013,
},
"completion": {
"gpt-3.5-turbo-1106": 0.002,
"gpt-4-1106-preview": 0.03,
+ "gpt-4-0125-preview": 0.03,
"gpt-4-1106-vision-preview": 0.03,
"gpt-4": 0.06,
"gpt-4-0314": 0.06,
From b789c11a0b6e01568ed07b8e308447a048783d8b Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 31 Jan 2024 23:22:18 -0800
Subject: [PATCH 11/71] update CHANGELOG
---
CHANGELOG.md | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index dcbdbe3..8370740 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,8 @@
+## 1.8.2 - 2024-01-31
+### Added
+- Added support for new chat completion models
+- Added new quering options for metrics and events API
+
## 1.8.1 - 2024-01-31
### Changed
- Extended default proxy request timeout to 10m
From a84aae2cc87d8d923d37e66c224cece1a65367ee Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 31 Jan 2024 23:41:43 -0800
Subject: [PATCH 12/71] update doc
---
CHANGELOG.md | 2 +-
README.md | 1 +
2 files changed, 2 insertions(+), 1 deletion(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8370740..dca46b1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,7 +1,7 @@
## 1.8.2 - 2024-01-31
### Added
- Added support for new chat completion models
-- Added new quering options for metrics and events API
+- Added new querying options for metrics and events API
## 1.8.1 - 2024-01-31
### Changed
diff --git a/README.md b/README.md
index 132c15d..7236b82 100644
--- a/README.md
+++ b/README.md
@@ -536,6 +536,7 @@ Datapoint
> | successCount | `int` | `555` | Aggregated number of successful http requests over the given time increment. |
> | keyId | `int` | `555.7` | key Id associated with the event. |
> | model | `string` | `gpt-3.5-turbo` | model associated with the event. |
+> | customId | `string` | `customId` | customId associated with the event. |
From 7c09a50488e606c0d5b46ff2458042a671384cf5 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Tue, 6 Feb 2024 10:39:18 -0800
Subject: [PATCH 13/71] update goopenai
---
go.mod | 2 +-
go.sum | 2 ++
2 files changed, 3 insertions(+), 1 deletion(-)
diff --git a/go.mod b/go.mod
index 1b4a791..9c500a9 100644
--- a/go.mod
+++ b/go.mod
@@ -12,7 +12,7 @@ require (
github.com/mattn/go-colorable v0.1.13
github.com/pkoukk/tiktoken-go-loader v0.0.1
github.com/redis/go-redis/v9 v9.0.5
- github.com/sashabaranov/go-openai v1.17.7
+ github.com/sashabaranov/go-openai v1.19.2
github.com/stretchr/testify v1.8.4
go.uber.org/zap v1.24.0
)
diff --git a/go.sum b/go.sum
index fc57456..b71b7c1 100644
--- a/go.sum
+++ b/go.sum
@@ -85,6 +85,8 @@ github.com/sashabaranov/go-openai v1.17.1 h1:tapFKbKE8ep0/qGkKp5Q3TtxWUD7m9VIFe9
github.com/sashabaranov/go-openai v1.17.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.17.7 h1:MPcAwlwbeo7ZmhQczoOgZBHtIBY1TfZqsdx6+/ndloM=
github.com/sashabaranov/go-openai v1.17.7/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
+github.com/sashabaranov/go-openai v1.19.2 h1:+dkuCADSnwXV02YVJkdphY8XD9AyHLUWwk6V7LB6EL8=
+github.com/sashabaranov/go-openai v1.19.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
From fbb4f785db4560af31a0c20c89df38e831c69bb0 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Tue, 6 Feb 2024 10:39:56 -0800
Subject: [PATCH 14/71] add env variable for number of message consumers
---
internal/config/config.go | 39 ++++++++++++++++++++-------------------
1 file changed, 20 insertions(+), 19 deletions(-)
diff --git a/internal/config/config.go b/internal/config/config.go
index 784bc2b..ced3a3b 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -7,25 +7,26 @@ import (
)
type Config struct {
- PostgresqlHosts string `env:"POSTGRESQL_HOSTS" envSeparator:":" envDefault:"localhost"`
- PostgresqlDbName string `env:"POSTGRESQL_DB_NAME"`
- PostgresqlUsername string `env:"POSTGRESQL_USERNAME"`
- PostgresqlPassword string `env:"POSTGRESQL_PASSWORD"`
- PostgresqlSslMode string `env:"POSTGRESQL_SSL_MODE" envDefault:"disable"`
- PostgresqlPort string `env:"POSTGRESQL_PORT" envDefault:"5432"`
- RedisHosts string `env:"REDIS_HOSTS" envSeparator:":" envDefault:"localhost"`
- RedisPort string `env:"REDIS_PORT" envDefault:"6379"`
- RedisUsername string `env:"REDIS_USERNAME"`
- RedisPassword string `env:"REDIS_PASSWORD"`
- RedisReadTimeout time.Duration `env:"REDIS_READ_TIME_OUT" envDefault:"1s"`
- RedisWriteTimeout time.Duration `env:"REDIS_WRITE_TIME_OUT" envDefault:"500ms"`
- PostgresqlReadTimeout time.Duration `env:"POSTGRESQL_READ_TIME_OUT" envDefault:"2s"`
- PostgresqlWriteTimeout time.Duration `env:"POSTGRESQL_WRITE_TIME_OUT" envDefault:"1s"`
- InMemoryDbUpdateInterval time.Duration `env:"IN_MEMORY_DB_UPDATE_INTERVAL" envDefault:"5s"`
- OpenAiKey string `env:"OPENAI_API_KEY"`
- StatsProvider string `env:"STATS_PROVIDER"`
- AdminPass string `env:"ADMIN_PASS"`
- ProxyTimeout time.Duration `env:"PROXY_TIMEOUT" envDefault:"600s"`
+ PostgresqlHosts string `env:"POSTGRESQL_HOSTS" envSeparator:":" envDefault:"localhost"`
+ PostgresqlDbName string `env:"POSTGRESQL_DB_NAME"`
+ PostgresqlUsername string `env:"POSTGRESQL_USERNAME"`
+ PostgresqlPassword string `env:"POSTGRESQL_PASSWORD"`
+ PostgresqlSslMode string `env:"POSTGRESQL_SSL_MODE" envDefault:"disable"`
+ PostgresqlPort string `env:"POSTGRESQL_PORT" envDefault:"5432"`
+ RedisHosts string `env:"REDIS_HOSTS" envSeparator:":" envDefault:"localhost"`
+ RedisPort string `env:"REDIS_PORT" envDefault:"6379"`
+ RedisUsername string `env:"REDIS_USERNAME"`
+ RedisPassword string `env:"REDIS_PASSWORD"`
+ RedisReadTimeout time.Duration `env:"REDIS_READ_TIME_OUT" envDefault:"1s"`
+ RedisWriteTimeout time.Duration `env:"REDIS_WRITE_TIME_OUT" envDefault:"500ms"`
+ PostgresqlReadTimeout time.Duration `env:"POSTGRESQL_READ_TIME_OUT" envDefault:"2s"`
+ PostgresqlWriteTimeout time.Duration `env:"POSTGRESQL_WRITE_TIME_OUT" envDefault:"1s"`
+ InMemoryDbUpdateInterval time.Duration `env:"IN_MEMORY_DB_UPDATE_INTERVAL" envDefault:"5s"`
+ OpenAiKey string `env:"OPENAI_API_KEY"`
+ StatsProvider string `env:"STATS_PROVIDER"`
+ AdminPass string `env:"ADMIN_PASS"`
+ ProxyTimeout time.Duration `env:"PROXY_TIMEOUT" envDefault:"600s"`
+ NumberOfEventMessageConsumers int `env:"NUMBER_OF_EVENT_MESSAGE_CONSUMERS" envDefault:"3"`
}
func ParseEnvVariables() (*Config, error) {
From 2fe437feb97d8b8677c6ed20bf272f77f8d52afe Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Tue, 6 Feb 2024 10:40:17 -0800
Subject: [PATCH 15/71] add cost limit error
---
internal/errors/cost_limit_err.go | 17 +++++++++++++++++
1 file changed, 17 insertions(+)
create mode 100644 internal/errors/cost_limit_err.go
diff --git a/internal/errors/cost_limit_err.go b/internal/errors/cost_limit_err.go
new file mode 100644
index 0000000..15c92de
--- /dev/null
+++ b/internal/errors/cost_limit_err.go
@@ -0,0 +1,17 @@
+package errors
+
+type CostLimitError struct {
+ message string
+}
+
+func NewCostLimitError(msg string) *CostLimitError {
+ return &CostLimitError{
+ message: msg,
+ }
+}
+
+func (cle *CostLimitError) Error() string {
+ return cle.message
+}
+
+func (rle *CostLimitError) CostLimit() {}
From b9bb1a538163733e5428ff034f7863be39657376 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Tue, 6 Feb 2024 10:45:20 -0800
Subject: [PATCH 16/71] add event
---
.../event/event_with_request_and_response.go | 16 ++++++++++++++++
1 file changed, 16 insertions(+)
create mode 100644 internal/event/event_with_request_and_response.go
diff --git a/internal/event/event_with_request_and_response.go b/internal/event/event_with_request_and_response.go
new file mode 100644
index 0000000..8e99f0b
--- /dev/null
+++ b/internal/event/event_with_request_and_response.go
@@ -0,0 +1,16 @@
+package event
+
+import (
+ "github.com/bricks-cloud/bricksllm/internal/key"
+ "github.com/bricks-cloud/bricksllm/internal/provider/custom"
+)
+
+type EventWithRequestAndContent struct {
+ Event *Event
+ IsEmbeddingsRequest bool
+ RouteConfig *custom.RouteConfig
+ Request interface{}
+ Content string
+ Response interface{}
+ Key *key.ResponseKey
+}
From e97ab555cec144e7756d4760dba8370574b1123b Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Tue, 6 Feb 2024 10:45:46 -0800
Subject: [PATCH 17/71] add async event driven architecture
---
internal/message/bus.go | 23 ++
internal/message/consumer.go | 58 +++++
internal/message/handler.go | 422 +++++++++++++++++++++++++++++++++++
internal/message/message.go | 6 +
4 files changed, 509 insertions(+)
create mode 100644 internal/message/bus.go
create mode 100644 internal/message/consumer.go
create mode 100644 internal/message/handler.go
create mode 100644 internal/message/message.go
diff --git a/internal/message/bus.go b/internal/message/bus.go
new file mode 100644
index 0000000..7fe0749
--- /dev/null
+++ b/internal/message/bus.go
@@ -0,0 +1,23 @@
+package message
+
+type MessageBus struct {
+ Subscribers map[string][]chan<- Message
+}
+
+func NewMessageBus() *MessageBus {
+ return &MessageBus{
+ Subscribers: make(map[string][]chan<- Message),
+ }
+}
+
+func (mb *MessageBus) Subscribe(messageType string, subscriber chan<- Message) {
+ mb.Subscribers[messageType] = append(mb.Subscribers[messageType], subscriber)
+}
+
+func (mb *MessageBus) Publish(ms Message) {
+ subscribers := mb.Subscribers[ms.Type]
+
+ for _, subscriber := range subscribers {
+ subscriber <- ms
+ }
+}
diff --git a/internal/message/consumer.go b/internal/message/consumer.go
new file mode 100644
index 0000000..e8e4e02
--- /dev/null
+++ b/internal/message/consumer.go
@@ -0,0 +1,58 @@
+package message
+
+import (
+ "github.com/bricks-cloud/bricksllm/internal/event"
+ "github.com/bricks-cloud/bricksllm/internal/key"
+ "go.uber.org/zap"
+)
+
+type Consumer struct {
+ messageChan <-chan Message
+ done chan bool
+ log *zap.Logger
+ numOfEventConsumers int
+ handle func(Message) error
+}
+
+type recorder interface {
+ RecordKeySpend(keyId string, micros int64, costLimitUnit key.TimeUnit) error
+ RecordEvent(e *event.Event) error
+}
+
+func NewConsumer(mc <-chan Message, log *zap.Logger, num int, handle func(Message) error) *Consumer {
+ return &Consumer{
+ messageChan: mc,
+ done: make(chan bool),
+ log: log,
+ numOfEventConsumers: num,
+ handle: handle,
+ }
+}
+
+func (c *Consumer) StartEventMessageConsumers() {
+ for i := 0; i < c.numOfEventConsumers; i++ {
+ go func() {
+ for {
+ select {
+ case <-c.done:
+ c.log.Info("event message consumer stoped...")
+ return
+
+ case m := <-c.messageChan:
+ err := c.handle(m)
+ if err != nil {
+ continue
+ }
+
+ continue
+ }
+ }
+ }()
+ }
+}
+
+func (c *Consumer) Stop() {
+ c.log.Info("shutting down consumer...")
+
+ c.done <- true
+}
diff --git a/internal/message/handler.go b/internal/message/handler.go
new file mode 100644
index 0000000..04c5697
--- /dev/null
+++ b/internal/message/handler.go
@@ -0,0 +1,422 @@
+package message
+
+import (
+ "errors"
+ "strings"
+ "time"
+
+ "github.com/bricks-cloud/bricksllm/internal/event"
+ "github.com/bricks-cloud/bricksllm/internal/key"
+ "github.com/bricks-cloud/bricksllm/internal/provider/anthropic"
+ "github.com/bricks-cloud/bricksllm/internal/provider/custom"
+ "github.com/bricks-cloud/bricksllm/internal/stats"
+ "github.com/tidwall/gjson"
+ "go.uber.org/zap"
+
+ goopenai "github.com/sashabaranov/go-openai"
+)
+
+type anthropicEstimator interface {
+ EstimateTotalCost(model string, promptTks, completionTks int) (float64, error)
+ EstimateCompletionCost(model string, tks int) (float64, error)
+ EstimatePromptCost(model string, tks int) (float64, error)
+ Count(input string) int
+}
+
+type estimator interface {
+ EstimateChatCompletionPromptCostWithTokenCounts(r *goopenai.ChatCompletionRequest) (int, float64, error)
+ EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (float64, error)
+ EstimateChatCompletionStreamCostWithTokenCounts(model, content string) (int, float64, error)
+ EstimateCompletionCost(model string, tks int) (float64, error)
+ EstimateTotalCost(model string, promptTks, completionTks int) (float64, error)
+ EstimateEmbeddingsInputCost(model string, tks int) (float64, error)
+ EstimateChatCompletionPromptTokenCounts(model string, r *goopenai.ChatCompletionRequest) (int, error)
+}
+
+type azureEstimator interface {
+ EstimateChatCompletionStreamCostWithTokenCounts(model, content string) (int, float64, error)
+ EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (float64, error)
+ EstimateCompletionCost(model string, tks int) (float64, error)
+ EstimatePromptCost(model string, tks int) (float64, error)
+ EstimateTotalCost(model string, promptTks, completionTks int) (float64, error)
+ EstimateEmbeddingsInputCost(model string, tks int) (float64, error)
+}
+
+type validator interface {
+ Validate(k *key.ResponseKey, promptCost float64) error
+}
+
+type keyManager interface {
+ UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error)
+}
+
+type rateLimitManager interface {
+ Increment(keyId string, timeUnit key.TimeUnit) error
+}
+
+type accessCache interface {
+ Set(key string, timeUnit key.TimeUnit) error
+}
+
+type Handler struct {
+ recorder recorder
+ log *zap.Logger
+ ae anthropicEstimator
+ e estimator
+ aze azureEstimator
+ v validator
+ km keyManager
+ rlm rateLimitManager
+ ac accessCache
+}
+
+func NewHandler(r recorder, log *zap.Logger, ae anthropicEstimator, e estimator, aze azureEstimator, v validator, km keyManager, rlm rateLimitManager, ac accessCache) *Handler {
+ return &Handler{
+ recorder: r,
+ log: log,
+ ae: ae,
+ e: e,
+ aze: aze,
+ v: v,
+ km: km,
+ rlm: rlm,
+ ac: ac,
+ }
+}
+
+func (h *Handler) HandleEvent(m Message) error {
+ stats.Incr("bricksllm.message.handler.handle_event.requests", nil, 1)
+
+ e, ok := m.Data.(*event.Event)
+ if !ok {
+ stats.Incr("bricksllm.message.handler.handle_event.event_parsing_error", nil, 1)
+ h.log.Info("message contains data that cannot be converted to event format", zap.Any("data", m.Data))
+ return errors.New("message data cannot be parsed as event")
+ }
+
+ start := time.Now()
+
+ err := h.recorder.RecordEvent(e)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.handle_event.record_event_error", nil, 1)
+ h.log.Sugar().Debugf("error when publishin event: %v", err)
+ return err
+ }
+
+ stats.Timing("bricksllm.message.handler.handle_event.record_event_latency", time.Now().Sub(start), nil, 1)
+ stats.Incr("bricksllm.message.handler.handle_event.success", nil, 1)
+
+ return nil
+}
+
+const (
+ anthropicPromptMagicNum int = 1
+ anthropicCompletionMagicNum int = 4
+)
+
+func countTokensFromJson(bytes []byte, contentLoc string) (int, error) {
+ content := getContentFromJson(bytes, contentLoc)
+ return custom.Count(content)
+}
+
+func getContentFromJson(bytes []byte, contentLoc string) string {
+ result := gjson.Get(string(bytes), contentLoc)
+ content := ""
+
+ if len(result.Str) != 0 {
+ content += result.Str
+ }
+
+ if result.IsArray() {
+ for _, val := range result.Array() {
+ if len(val.Str) != 0 {
+ content += val.Str
+ }
+ }
+ }
+
+ return content
+}
+
+type costLimitError interface {
+ Error() string
+ CostLimit()
+}
+
+type rateLimitError interface {
+ Error() string
+ RateLimit()
+}
+
+type expirationError interface {
+ Error() string
+ Reason() string
+}
+
+func (h *Handler) handleValidationResult(kc *key.ResponseKey, cost float64) error {
+ err := h.v.Validate(kc, cost)
+
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.handle_validation_result.handle_validation_result", nil, 1)
+
+ // tested
+ if _, ok := err.(expirationError); ok {
+ stats.Incr("bricksllm.message.handler.handle_validation_result.expiraton_error", nil, 1)
+
+ truePtr := true
+ _, err = h.km.UpdateKey(kc.KeyId, &key.UpdateKey{
+ Revoked: &truePtr,
+ RevokedReason: "Key has expired or exceeded set spend limit",
+ })
+
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.handle_validation_result.update_key_error", nil, 1)
+ return err
+ }
+
+ return nil
+ }
+
+ // tested
+ if _, ok := err.(rateLimitError); ok {
+ stats.Incr("bricksllm.message.handler.handle_validation_result.rate_limit_error", nil, 1)
+
+ err = h.ac.Set(kc.KeyId, kc.RateLimitUnit)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.handle_validation_result.set_rate_limit_error", nil, 1)
+ return err
+ }
+
+ return nil
+ }
+
+ // tested
+ if _, ok := err.(costLimitError); ok {
+ stats.Incr("bricksllm.message.handler.handle_validation_result.cost_limit_error", nil, 1)
+
+ err = h.ac.Set(kc.KeyId, kc.CostLimitInUsdUnit)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.handle_validation_result.set_cost_limit_error", nil, 1)
+ return err
+ }
+
+ return nil
+ }
+
+ return err
+ }
+
+ return nil
+}
+
+func (h *Handler) HandleEventWithRequestAndResponse(m Message) error {
+ e, ok := m.Data.(*event.EventWithRequestAndContent)
+ if !ok {
+ stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.message_data_parsing_error", nil, 1)
+ h.log.Debug("message contains data that cannot be converted to event with request and response format", zap.Any("data", m.Data))
+ return errors.New("message data cannot be parsed as event with request and response")
+ }
+
+ if e.Key != nil && !e.Key.Revoked && e.Event != nil {
+ err := h.decorateEvent(m)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.decorate_event_error", nil, 1)
+ h.log.Debug("error when decorating event", zap.Error(err))
+ }
+
+ // tested
+ if e.Event.CostInUsd != 0 {
+ micros := int64(e.Event.CostInUsd * 1000000)
+ err = h.recorder.RecordKeySpend(e.Event.KeyId, micros, e.Key.CostLimitInUsdUnit)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.record_key_spend_error", nil, 1)
+ h.log.Debug("error when recording key spend", zap.Error(err))
+ }
+ }
+
+ // tested
+ if len(e.Key.RateLimitUnit) != 0 {
+ if err := h.rlm.Increment(e.Key.KeyId, e.Key.RateLimitUnit); err != nil {
+ stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.rate_limit_increment_error", nil, 1)
+
+ h.log.Debug("error when incrementing rate limit", zap.Error(err))
+ }
+ }
+
+ // tested
+ err = h.handleValidationResult(e.Key, e.Event.CostInUsd)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.handle_validation_result_error", nil, 1)
+ h.log.Debug("error when handling validation result", zap.Error(err))
+ }
+
+ }
+
+ // tested
+ start := time.Now()
+ err := h.recorder.RecordEvent(e.Event)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.record_event_error", nil, 1)
+ return err
+ }
+
+ stats.Timing("bricksllm.message.handler.handle_event_with_request_and_response.latency", time.Now().Sub(start), nil, 1)
+ stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.success", nil, 1)
+
+ return nil
+}
+
+func (h *Handler) decorateEvent(m Message) error {
+ stats.Incr("bricksllm.message.handler.decorate_event.request", nil, 1)
+
+ e, ok := m.Data.(*event.EventWithRequestAndContent)
+ if !ok {
+ stats.Incr("bricksllm.message.handler.decorate_event.message_data_parsing_error", nil, 1)
+ h.log.Debug("message contains data that cannot be converted to event with request and response format", zap.Any("data", m.Data))
+ return errors.New("message data cannot be parsed as event with request and response")
+ }
+
+ // tested
+ if e.Event.Provider == "anthropic" && e.Event.Path == "/api/providers/anthropic/v1/complete" {
+ cr, ok := e.Request.(*anthropic.CompletionRequest)
+ if !ok {
+ stats.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1)
+ h.log.Debug("event contains request that cannot be converted to anthropic completion request", zap.Any("data", m.Data))
+ return errors.New("event request data cannot be parsed as anthropic completon request")
+ }
+
+ tks := h.ae.Count(cr.Prompt)
+ tks += anthropicPromptMagicNum
+
+ model := cr.Model
+ cost, err := h.ae.EstimatePromptCost(model, tks)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.decorate_event.estimate_prompt_cost", nil, 1)
+ h.log.Debug("event contains request that cannot be converted to anthropic completion request", zap.Error(err))
+ return err
+ }
+
+ completiontks := h.ae.Count(e.Content)
+ completiontks += anthropicCompletionMagicNum
+
+ completionCost, err := h.ae.EstimateCompletionCost(model, completiontks)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.decorate_event.estimate_completion_cost_error", nil, 1)
+ return err
+ }
+
+ e.Event.PromptTokenCount = tks
+ e.Event.CompletionTokenCount = completiontks
+ e.Event.CostInUsd = completionCost + cost
+ }
+
+ // tested
+ if e.Event.Provider == "azure" && e.Event.Path == "/api/providers/azure/openai/deployments/:deployment_id/chat/completions" {
+ ccr, ok := e.Request.(*goopenai.ChatCompletionRequest)
+ if !ok {
+ stats.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1)
+ h.log.Debug("event contains data that cannot be converted to azure openai completion request", zap.Any("data", m.Data))
+ return errors.New("event request data cannot be parsed as azure openai completon request")
+ }
+
+ if ccr.Stream {
+ tks, err := h.e.EstimateChatCompletionPromptTokenCounts("gpt-3.5-turbo", ccr)
+ if err != nil {
+ stats.Incr("bricksllm.message.decorate_event.estimate_chat_completion_prompt_token_counts_error", nil, 1)
+ return err
+ }
+
+ cost, err := h.aze.EstimatePromptCost(e.Event.Model, tks)
+ if err != nil {
+ stats.Incr("bricksllm.message.decorate_event.estimate_prompt_cost_error", nil, 1)
+ return err
+ }
+
+ completiontks, completionCost, err := h.aze.EstimateChatCompletionStreamCostWithTokenCounts(e.Event.Model, e.Content)
+ if err != nil {
+ stats.Incr("bricksllm.message.decorate_event.estimate_chat_completion_stream_cost_with_token_counts_error", nil, 1)
+ return err
+ }
+
+ e.Event.PromptTokenCount = tks
+ e.Event.CompletionTokenCount = completiontks
+ e.Event.CostInUsd = cost + completionCost
+ }
+ }
+
+ // tested
+ if e.Event.Provider == "openai" && e.Event.Path == "/api/providers/openai/v1/chat/completions" {
+ ccr, ok := e.Request.(*goopenai.ChatCompletionRequest)
+ if !ok {
+ stats.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1)
+ h.log.Debug("event contains data that cannot be converted to openai completion request", zap.Any("data", m.Data))
+ return errors.New("event request data cannot be parsed as oepnai completon request")
+ }
+
+ if ccr.Stream {
+ tks, cost, err := h.e.EstimateChatCompletionPromptCostWithTokenCounts(ccr)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.decorate_event.estimate_chat_completion_prompt_cost_with_token_counts", nil, 1)
+ return err
+ }
+
+ completiontks, completionCost, err := h.e.EstimateChatCompletionStreamCostWithTokenCounts(e.Event.Model, e.Content)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.decorate_event.estimate_chat_completion_stream_cost_with_token_counts", nil, 1)
+ return err
+ }
+
+ e.Event.PromptTokenCount = tks
+ e.Event.CompletionTokenCount = completiontks
+ e.Event.CostInUsd = cost + completionCost
+ }
+ }
+
+ if strings.HasPrefix(e.Event.Path, "/api/custom/providers/:provider") && e.RouteConfig != nil {
+ body, ok := e.Request.([]byte)
+ if !ok {
+ stats.Incr("bricksllm.message.handler.decorate_event.event_request_custom_provider_parsing_error", nil, 1)
+ h.log.Debug("event contains request that cannot be converted to bytes", zap.Any("data", m.Data))
+ return errors.New("event request data cannot be parsed as anthropic completon request")
+ }
+
+ content, ok := e.Response.([]byte)
+ if !ok {
+ stats.Incr("bricksllm.message.handler.decorate_event.event_response_custom_provider_parsing_error", nil, 1)
+ h.log.Debug("event contains response that cannot be converted to bytes", zap.Any("data", m.Data))
+ return errors.New("event response data cannot be converted to bytes")
+ }
+
+ tks, err := countTokensFromJson(body, e.RouteConfig.RequestPromptLocation)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.decorate_event.count_tokens_from_json_error", nil, 1)
+
+ return err
+ }
+
+ e.Event.PromptTokenCount = tks
+
+ result := gjson.Get(string(body), e.RouteConfig.StreamLocation)
+ if result.IsBool() {
+ completiontks, err := custom.Count(e.Content)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.decorate_event.custom_count_error", nil, 1)
+ return err
+ }
+
+ e.Event.CompletionTokenCount = completiontks
+ }
+
+ if !result.IsBool() {
+ completiontks, err := countTokensFromJson(content, e.RouteConfig.ResponseCompletionLocation)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.decorate_event.count_tokens_from_json_error", nil, 1)
+ return err
+ }
+
+ e.Event.CompletionTokenCount = completiontks
+ }
+ }
+
+ return nil
+}
diff --git a/internal/message/message.go b/internal/message/message.go
new file mode 100644
index 0000000..a2e97a5
--- /dev/null
+++ b/internal/message/message.go
@@ -0,0 +1,6 @@
+package message
+
+type Message struct {
+ Type string
+ Data interface{}
+}
From 838bfb92d382ec09cba90cb720f8f6bbd258cf26 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Tue, 6 Feb 2024 10:46:09 -0800
Subject: [PATCH 18/71] add redis cache for access
---
internal/storage/redis/access-cache.go | 48 ++++++++++++++++++++++++++
1 file changed, 48 insertions(+)
create mode 100644 internal/storage/redis/access-cache.go
diff --git a/internal/storage/redis/access-cache.go b/internal/storage/redis/access-cache.go
new file mode 100644
index 0000000..d8dccef
--- /dev/null
+++ b/internal/storage/redis/access-cache.go
@@ -0,0 +1,48 @@
+package redis
+
+import (
+ "context"
+ "time"
+
+ "github.com/bricks-cloud/bricksllm/internal/key"
+ "github.com/redis/go-redis/v9"
+)
+
+type AccessCache struct {
+ client *redis.Client
+ wt time.Duration
+ rt time.Duration
+}
+
+func NewAccessCache(c *redis.Client, wt time.Duration, rt time.Duration) *AccessCache {
+ return &AccessCache{
+ client: c,
+ wt: wt,
+ rt: rt,
+ }
+}
+
+func (ac *AccessCache) Set(key string, timeUnit key.TimeUnit) error {
+ ttl, err := getCounterTtl(timeUnit)
+ if err != nil {
+ return err
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), ac.wt)
+ defer cancel()
+ err = ac.client.Set(ctx, key, true, ttl.Sub(time.Now())).Err()
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (ac *AccessCache) GetAccessStatus(key string) bool {
+ ctx, cancel := context.WithTimeout(context.Background(), ac.rt)
+ defer cancel()
+
+ result := ac.client.Get(ctx, key)
+
+ return result.Err() != redis.Nil
+}
From ed7c64de0160d4a99f046c37709123cdd5d3f109 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Tue, 6 Feb 2024 10:46:51 -0800
Subject: [PATCH 19/71] update goopenai
---
internal/provider/openai/cost.go | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/internal/provider/openai/cost.go b/internal/provider/openai/cost.go
index 2bec29b..b645b13 100644
--- a/internal/provider/openai/cost.go
+++ b/internal/provider/openai/cost.go
@@ -198,7 +198,7 @@ func (ce *CostEstimator) EstimateChatCompletionStreamCostWithTokenCounts(model s
}
func (ce *CostEstimator) EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (float64, error) {
- if len(r.Model.String()) == 0 {
+ if len(string(r.Model)) == 0 {
return 0, errors.New("model is not provided")
}
@@ -210,7 +210,7 @@ func (ce *CostEstimator) EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (f
return 0, errors.New("input is not string")
}
- tks, err := ce.tc.Count(r.Model.String(), converted)
+ tks, err := ce.tc.Count(string(r.Model), converted)
if err != nil {
return 0, err
}
@@ -218,14 +218,14 @@ func (ce *CostEstimator) EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (f
total += tks
}
- return ce.EstimateEmbeddingsInputCost(r.Model.String(), total)
+ return ce.EstimateEmbeddingsInputCost(string(r.Model), total)
} else if input, ok := r.Input.(string); ok {
- tks, err := ce.tc.Count(r.Model.String(), input)
+ tks, err := ce.tc.Count(string(r.Model), input)
if err != nil {
return 0, err
}
- return ce.EstimateEmbeddingsInputCost(r.Model.String(), tks)
+ return ce.EstimateEmbeddingsInputCost(string(r.Model), tks)
}
return 0, errors.New("input format is not recognized")
From f84046e2f651fa636157f961ff1eb760a6e5e127 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Tue, 6 Feb 2024 10:47:04 -0800
Subject: [PATCH 20/71] fix validator issues
---
internal/validator/validator.go | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/internal/validator/validator.go b/internal/validator/validator.go
index 9551261..d2fd4ba 100644
--- a/internal/validator/validator.go
+++ b/internal/validator/validator.go
@@ -58,12 +58,12 @@ func (v *Validator) Validate(k *key.ResponseKey, promptCost float64) error {
return err
}
- err = v.validateCostLimitOverTime(k.KeyId, k.CostLimitInUsdOverTime, k.CostLimitInUsdUnit, promptCost)
+ err = v.validateCostLimitOverTime(k.KeyId, k.CostLimitInUsdOverTime, k.CostLimitInUsdUnit)
if err != nil {
return err
}
- err = v.validateCostLimit(k.KeyId, k.CostLimitInUsd, promptCost)
+ err = v.validateCostLimit(k.KeyId, k.CostLimitInUsd)
if err != nil {
return err
}
@@ -96,14 +96,14 @@ func (v *Validator) validateRateLimitOverTime(keyId string, rateLimitOverTime in
return errors.New("failed to get rate limit counter")
}
- if c+1 > int64(rateLimitOverTime) {
+ if c >= int64(rateLimitOverTime) {
return internal_errors.NewRateLimitError(fmt.Sprintf("key exceeded rate limit %d requests per %s", rateLimitOverTime, rateLimitUnit))
}
return nil
}
-func (v *Validator) validateCostLimitOverTime(keyId string, costLimitOverTime float64, costLimitUnit key.TimeUnit, promptCost float64) error {
+func (v *Validator) validateCostLimitOverTime(keyId string, costLimitOverTime float64, costLimitUnit key.TimeUnit) error {
if costLimitOverTime == 0 {
return nil
}
@@ -113,8 +113,8 @@ func (v *Validator) validateCostLimitOverTime(keyId string, costLimitOverTime fl
return errors.New("failed to get cached token cost")
}
- if convertDollarToMicroDollars(promptCost)+cachedCost > convertDollarToMicroDollars(costLimitOverTime) {
- return internal_errors.NewExpirationError(fmt.Sprintf("cost limit: %f has been reached for the current time period: %s", costLimitOverTime, costLimitUnit), internal_errors.CostLimitExpiration)
+ if cachedCost >= convertDollarToMicroDollars(costLimitOverTime) {
+ return internal_errors.NewCostLimitError(fmt.Sprintf("cost limit: %f has been reached for the current time period: %s", costLimitOverTime, costLimitUnit))
}
return nil
@@ -124,7 +124,7 @@ func convertDollarToMicroDollars(dollar float64) int64 {
return int64(dollar * 1000000)
}
-func (v *Validator) validateCostLimit(keyId string, costLimit float64, promptCost float64) error {
+func (v *Validator) validateCostLimit(keyId string, costLimit float64) error {
if costLimit == 0 {
return nil
}
@@ -134,7 +134,7 @@ func (v *Validator) validateCostLimit(keyId string, costLimit float64, promptCos
return errors.New("failed to get total token cost")
}
- if convertDollarToMicroDollars(promptCost)+existingTotalCost > convertDollarToMicroDollars(costLimit) {
+ if existingTotalCost >= convertDollarToMicroDollars(costLimit) {
return internal_errors.NewExpirationError(fmt.Sprintf("total cost limit: %f has been reached", costLimit), internal_errors.CostLimitExpiration)
}
From 7001d142131b1fe82a94435e7b720be56e0b16ff Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Tue, 6 Feb 2024 10:48:52 -0800
Subject: [PATCH 21/71] integrate event driven architecture
---
cmd/bricksllm/main.go | 26 +-
cmd/tool/main.go | 269 ------------------
internal/server/web/proxy/anthropic.go | 97 +++----
.../server/web/proxy/azure_chat_completion.go | 56 ++--
internal/server/web/proxy/azure_embedding.go | 27 +-
internal/server/web/proxy/custom_provider.go | 26 +-
internal/server/web/proxy/middleware.go | 200 ++++++-------
internal/server/web/proxy/proxy.go | 86 +++---
internal/server/web/proxy/route.go | 20 +-
9 files changed, 286 insertions(+), 521 deletions(-)
delete mode 100644 cmd/tool/main.go
diff --git a/cmd/bricksllm/main.go b/cmd/bricksllm/main.go
index c8dba0d..69d3354 100644
--- a/cmd/bricksllm/main.go
+++ b/cmd/bricksllm/main.go
@@ -14,6 +14,7 @@ import (
"github.com/bricks-cloud/bricksllm/internal/config"
"github.com/bricks-cloud/bricksllm/internal/logger/zap"
"github.com/bricks-cloud/bricksllm/internal/manager"
+ "github.com/bricks-cloud/bricksllm/internal/message"
"github.com/bricks-cloud/bricksllm/internal/provider/anthropic"
"github.com/bricks-cloud/bricksllm/internal/provider/azure"
"github.com/bricks-cloud/bricksllm/internal/provider/custom"
@@ -171,10 +172,23 @@ func main() {
log.Sugar().Fatalf("error connecting to api redis cache: %v", err)
}
+ accessRedisCache := redis.NewClient(&redis.Options{
+ Addr: fmt.Sprintf("%s:%s", cfg.RedisHosts, cfg.RedisPort),
+ Password: cfg.RedisPassword,
+ DB: 4,
+ })
+
+ ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+ if err := accessRedisCache.Ping(ctx).Err(); err != nil {
+ log.Sugar().Fatalf("error connecting to access redis cache: %v", err)
+ }
+
rateLimitCache := redisStorage.NewCache(rateLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
costLimitCache := redisStorage.NewCache(costLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
costStorage := redisStorage.NewStore(costRedisStorage, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
apiCache := redisStorage.NewCache(apiRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
+ accessCache := redisStorage.NewAccessCache(accessRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
m := manager.NewManager(store)
krm := manager.NewReportingManager(costStorage, store, store)
@@ -209,7 +223,16 @@ func main() {
c := cache.NewCache(apiCache)
- ps, err := proxy.NewProxyServer(log, *modePtr, *privacyPtr, c, m, rm, a, psm, cpm, store, memStore, ce, ace, aoe, v, rec, rlm, cfg.ProxyTimeout)
+ messageBus := message.NewMessageBus()
+ eventMessageChan := make(chan message.Message)
+ messageBus.Subscribe("event", eventMessageChan)
+
+ handler := message.NewHandler(rec, log, ace, ce, aoe, v, m, rlm, accessCache)
+
+ eventConsumer := message.NewConsumer(eventMessageChan, log, 4, handler.HandleEventWithRequestAndResponse)
+ eventConsumer.StartEventMessageConsumers()
+
+ ps, err := proxy.NewProxyServer(log, *modePtr, *privacyPtr, c, m, rm, a, psm, cpm, store, memStore, ce, ace, aoe, v, rec, messageBus, rlm, cfg.ProxyTimeout, accessCache)
if err != nil {
log.Sugar().Fatalf("error creating proxy http server: %v", err)
}
@@ -220,6 +243,7 @@ func main() {
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
+ eventConsumer.Stop()
memStore.Stop()
psMemStore.Stop()
cpMemStore.Stop()
diff --git a/cmd/tool/main.go b/cmd/tool/main.go
deleted file mode 100644
index 11158b8..0000000
--- a/cmd/tool/main.go
+++ /dev/null
@@ -1,269 +0,0 @@
-package main
-
-import (
- "context"
- "flag"
- "fmt"
- "os"
- "os/signal"
- "syscall"
- "time"
-
- auth "github.com/bricks-cloud/bricksllm/internal/authenticator"
- "github.com/bricks-cloud/bricksllm/internal/cache"
- "github.com/bricks-cloud/bricksllm/internal/config"
- logger "github.com/bricks-cloud/bricksllm/internal/logger/zap"
- "github.com/bricks-cloud/bricksllm/internal/manager"
- "github.com/bricks-cloud/bricksllm/internal/provider/anthropic"
- "github.com/bricks-cloud/bricksllm/internal/provider/azure"
- "github.com/bricks-cloud/bricksllm/internal/provider/custom"
- "github.com/bricks-cloud/bricksllm/internal/provider/openai"
- "github.com/bricks-cloud/bricksllm/internal/recorder"
- "github.com/bricks-cloud/bricksllm/internal/server/web/admin"
- "github.com/bricks-cloud/bricksllm/internal/server/web/proxy"
- "github.com/bricks-cloud/bricksllm/internal/stats"
- "github.com/bricks-cloud/bricksllm/internal/storage/memdb"
- "github.com/bricks-cloud/bricksllm/internal/storage/postgresql"
- redisStorage "github.com/bricks-cloud/bricksllm/internal/storage/redis"
- "github.com/bricks-cloud/bricksllm/internal/validator"
- "github.com/gin-gonic/gin"
- "github.com/redis/go-redis/v9"
-)
-
-func main() {
- modePtr := flag.String("m", "dev", "select the mode that bricksllm runs in")
- privacyPtr := flag.String("p", "strict", "select the privacy mode that bricksllm runs in")
- flag.Parse()
-
- log := logger.NewZapLogger(*modePtr)
-
- gin.SetMode(gin.ReleaseMode)
-
- cfg, err := config.ParseEnvVariables()
- if err != nil {
- log.Sugar().Fatalf("cannot parse environment variables: %v", err)
- }
-
- err = stats.InitializeClient(cfg.StatsProvider)
- if err != nil {
- log.Sugar().Fatalf("cannot connect to telemetry provider: %v", err)
- }
-
- store, err := postgresql.NewStore(
- fmt.Sprintf("postgresql:///%s?sslmode=%s&user=%s&password=%s&host=%s&port=%s", cfg.PostgresqlDbName, cfg.PostgresqlSslMode, cfg.PostgresqlUsername, cfg.PostgresqlPassword, cfg.PostgresqlHosts, cfg.PostgresqlPort),
- cfg.PostgresqlWriteTimeout,
- cfg.PostgresqlReadTimeout,
- )
-
- if err != nil {
- log.Sugar().Fatalf("cannot connect to postgresql: %v", err)
- }
-
- err = store.CreateCustomProvidersTable()
- if err != nil {
- log.Sugar().Fatalf("error creating custom providers table: %v", err)
- }
-
- err = store.CreateRoutesTable()
- if err != nil {
- log.Sugar().Fatalf("error creating routes table: %v", err)
- }
-
- err = store.CreateKeysTable()
- if err != nil {
- log.Sugar().Fatalf("error creating keys table: %v", err)
- }
-
- err = store.AlterKeysTable()
- if err != nil {
- log.Sugar().Fatalf("error altering keys table: %v", err)
- }
-
- err = store.CreateEventsTable()
- if err != nil {
- log.Sugar().Fatalf("error creating events table: %v", err)
- }
-
- err = store.AlterEventsTable()
- if err != nil {
- log.Sugar().Fatalf("error altering events table: %v", err)
- }
-
- err = store.CreateProviderSettingsTable()
- if err != nil {
- log.Sugar().Fatalf("error creating provider settings table: %v", err)
- }
-
- err = store.AlterProviderSettingsTable()
- if err != nil {
- log.Sugar().Fatalf("error altering provider settings table: %v", err)
- }
-
- memStore, err := memdb.NewMemDb(store, log, cfg.InMemoryDbUpdateInterval)
- if err != nil {
- log.Sugar().Fatalf("cannot initialize memdb: %v", err)
- }
- memStore.Listen()
-
- psMemStore, err := memdb.NewProviderSettingsMemDb(store, log, cfg.InMemoryDbUpdateInterval)
- if err != nil {
- log.Sugar().Fatalf("cannot initialize provider settings memdb: %v", err)
- }
- psMemStore.Listen()
-
- cpMemStore, err := memdb.NewCustomProvidersMemDb(store, log, cfg.InMemoryDbUpdateInterval)
- if err != nil {
- log.Sugar().Fatalf("cannot initialize custom providers memdb: %v", err)
- }
- cpMemStore.Listen()
-
- rMemStore, err := memdb.NewRoutesMemDb(store, log, cfg.InMemoryDbUpdateInterval)
- if err != nil {
- log.Sugar().Fatalf("cannot initialize routes memdb: %v", err)
- }
- rMemStore.Listen()
-
- rateLimitRedisCache := redis.NewClient(&redis.Options{
- Addr: fmt.Sprintf("%s:%s", cfg.RedisHosts, cfg.RedisPort),
- Password: cfg.RedisPassword,
- DB: 0,
- })
- ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
- defer cancel()
- if err := rateLimitRedisCache.Ping(ctx).Err(); err != nil {
- log.Sugar().Fatalf("error connecting to rate limit redis cache: %v", err)
- }
-
- costLimitRedisCache := redis.NewClient(&redis.Options{
- Addr: fmt.Sprintf("%s:%s", cfg.RedisHosts, cfg.RedisPort),
- Password: cfg.RedisPassword,
- DB: 1,
- })
-
- ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
- defer cancel()
- if err := costLimitRedisCache.Ping(ctx).Err(); err != nil {
- log.Sugar().Fatalf("error connecting to cost limit redis cache: %v", err)
- }
-
- costRedisStorage := redis.NewClient(&redis.Options{
- Addr: fmt.Sprintf("%s:%s", cfg.RedisHosts, cfg.RedisPort),
- Password: cfg.RedisPassword,
- DB: 2,
- })
-
- ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
- defer cancel()
- if err := costRedisStorage.Ping(ctx).Err(); err != nil {
- log.Sugar().Fatalf("error connecting to cost limit redis storage: %v", err)
- }
-
- apiRedisCache := redis.NewClient(&redis.Options{
- Addr: fmt.Sprintf("%s:%s", cfg.RedisHosts, cfg.RedisPort),
- Password: cfg.RedisPassword,
- DB: 3,
- })
-
- ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
- defer cancel()
- if err := apiRedisCache.Ping(ctx).Err(); err != nil {
- log.Sugar().Fatalf("error connecting to api redis cache: %v", err)
- }
-
- rateLimitCache := redisStorage.NewCache(rateLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
- costLimitCache := redisStorage.NewCache(costLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
- costStorage := redisStorage.NewStore(costRedisStorage, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
- apiCache := redisStorage.NewCache(apiRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
-
- m := manager.NewManager(store)
- krm := manager.NewReportingManager(costStorage, store, store)
- psm := manager.NewProviderSettingsManager(store, psMemStore)
- cpm := manager.NewCustomProvidersManager(store, cpMemStore)
- rm := manager.NewRouteManager(store, store, rMemStore, psMemStore)
-
- as, err := admin.NewAdminServer(log, *modePtr, m, krm, psm, cpm, rm, cfg.AdminPass)
- if err != nil {
- log.Sugar().Fatalf("error creating admin http server: %v", err)
- }
- as.Run()
-
- tc := openai.NewTokenCounter()
- custom.NewTokenCounter()
- atc, err := anthropic.NewTokenCounter()
- if err != nil {
- log.Sugar().Fatalf("error creating anthropic token counter: %v", err)
- }
-
- ae := anthropic.NewCostEstimator(atc)
-
- ce := openai.NewCostEstimator(openai.OpenAiPerThousandTokenCost, tc)
- v := validator.NewValidator(costLimitCache, rateLimitCache, costStorage)
- rec := recorder.NewRecorder(costStorage, costLimitCache, ce, store)
- rlm := manager.NewRateLimitManager(rateLimitCache)
- a := auth.NewAuthenticator(psm, memStore, rm)
-
- c := cache.NewCache(apiCache)
-
- aoe := azure.NewCostEstimator()
-
- ps, err := proxy.NewProxyServer(log, *modePtr, *privacyPtr, c, m, rm, a, psm, cpm, store, memStore, ce, ae, aoe, v, rec, rlm, cfg.ProxyTimeout)
- if err != nil {
- log.Sugar().Fatalf("error creating proxy http server: %v", err)
- }
-
- ps.Run()
-
- quit := make(chan os.Signal)
- signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
- <-quit
-
- memStore.Stop()
- psMemStore.Stop()
- cpMemStore.Stop()
-
- log.Sugar().Info("shutting down server...")
-
- ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
- if err := as.Shutdown(ctx); err != nil {
- log.Sugar().Debugf("admin server shutdown: %v", err)
- }
-
- ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
- if err := ps.Shutdown(ctx); err != nil {
- log.Sugar().Debugf("proxy server shutdown: %v", err)
- }
-
- select {
- case <-ctx.Done():
- log.Sugar().Infof("timeout of 5 seconds")
- }
-
- err = store.DropKeysTable()
- if err != nil {
- log.Sugar().Fatalf("error dropping keys table: %v", err)
- }
-
- err = store.DropEventsTable()
- if err != nil {
- log.Sugar().Fatalf("error dropping events table: %v", err)
- }
-
- err = store.DropCustomProvidersTable()
- if err != nil {
- log.Sugar().Fatalf("error dropping custom providers table: %v", err)
- }
-
- err = store.DropProviderSettingsTable()
- if err != nil {
- log.Sugar().Fatalf("error dropping provider settings table: %v", err)
- }
-
- err = store.DropRoutesTable()
- if err != nil {
- log.Sugar().Fatalf("error dropping routes table: %v", err)
- }
-
- log.Sugar().Infof("server exited")
-}
diff --git a/internal/server/web/proxy/anthropic.go b/internal/server/web/proxy/anthropic.go
index a9aab8a..7f6a607 100644
--- a/internal/server/web/proxy/anthropic.go
+++ b/internal/server/web/proxy/anthropic.go
@@ -11,7 +11,6 @@ import (
"strings"
"time"
- "github.com/bricks-cloud/bricksllm/internal/key"
"github.com/bricks-cloud/bricksllm/internal/provider/anthropic"
"github.com/bricks-cloud/bricksllm/internal/stats"
"github.com/gin-gonic/gin"
@@ -61,13 +60,13 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km
return
}
- raw, exists := c.Get("key")
- kc, ok := raw.(*key.ResponseKey)
- if !exists || !ok {
- stats.Incr("bricksllm.proxy.get_completion_handler.api_key_not_registered", nil, 1)
- JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered")
- return
- }
+ // raw, exists := c.Get("key")
+ // kc, ok := raw.(*key.ResponseKey)
+ // if !exists || !ok {
+ // stats.Incr("bricksllm.proxy.get_completion_handler.api_key_not_registered", nil, 1)
+ // JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered")
+ // return
+ // }
copyHttpHeaders(c.Request, req)
@@ -96,7 +95,7 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km
}
}
- model := c.GetString("model")
+ // model := c.GetString("model")
if !isStreaming && res.StatusCode == http.StatusOK {
dur := time.Now().Sub(start)
@@ -109,8 +108,8 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km
return
}
- var cost float64 = 0
- completionTokens := 0
+ // var cost float64 = 0
+ // completionTokens := 0
completionRes := &anthropic.CompletionResponse{}
stats.Incr("bricksllm.proxy.get_completion_handler.success", nil, 1)
stats.Timing("bricksllm.proxy.get_completion_handler.success_latency", dur, nil, 1)
@@ -120,27 +119,29 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km
logError(log, "error when unmarshalling anthropic http completion response body", prod, cid, err)
}
- if err == nil {
- logCompletionResponse(log, bytes, prod, private, cid)
- completionTokens = e.Count(completionRes.Completion)
- completionTokens += anthropicCompletionMagicNum
- promptTokens := c.GetInt("promptTokenCount")
- cost, err = e.EstimateTotalCost(model, promptTokens, completionTokens)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_completion_handler.estimate_total_cost_error", nil, 1)
- logError(log, "error when estimating anthropic cost", prod, cid, err)
- }
-
- micros := int64(cost * 1000000)
- err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_completion_handler.record_key_spend_error", nil, 1)
- logError(log, "error when recording anthropic spend", prod, cid, err)
- }
- }
-
- c.Set("costInUsd", cost)
- c.Set("completionTokenCount", completionTokens)
+ c.Set("content", completionRes.Completion)
+
+ // if err == nil {
+ // logCompletionResponse(log, bytes, prod, private, cid)
+ // completionTokens = e.Count(completionRes.Completion)
+ // completionTokens += anthropicCompletionMagicNum
+ // promptTokens := c.GetInt("promptTokenCount")
+ // cost, err = e.EstimateTotalCost(model, promptTokens, completionTokens)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_completion_handler.estimate_total_cost_error", nil, 1)
+ // logError(log, "error when estimating anthropic cost", prod, cid, err)
+ // }
+
+ // micros := int64(cost * 1000000)
+ // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_completion_handler.record_key_spend_error", nil, 1)
+ // logError(log, "error when recording anthropic spend", prod, cid, err)
+ // }
+ // }
+
+ // c.Set("costInUsd", cost)
+ // c.Set("completionTokenCount", completionTokens)
c.Data(res.StatusCode, "application/json", bytes)
return
@@ -163,24 +164,24 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km
}
buffer := bufio.NewReader(res.Body)
- var totalCost float64 = 0
+ // var totalCost float64 = 0
content := ""
- defer func() {
- tks := e.Count(content)
- model := c.GetString("model")
- cost, err := e.EstimateCompletionCost(model, tks)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_completion_handler.estimate_completion_cost_error", nil, 1)
- logError(log, "error when estimating anthropic completion stream cost", prod, cid, err)
- }
-
- estimatedPromptCost := c.GetFloat64("estimatedPromptCostInUsd")
- totalCost = cost + estimatedPromptCost
-
- c.Set("costInUsd", totalCost)
- c.Set("completionTokenCount", tks+anthropicCompletionMagicNum)
- }()
+ // defer func() {
+ // tks := e.Count(content)
+ // model := c.GetString("model")
+ // cost, err := e.EstimateCompletionCost(model, tks)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_completion_handler.estimate_completion_cost_error", nil, 1)
+ // logError(log, "error when estimating anthropic completion stream cost", prod, cid, err)
+ // }
+
+ // estimatedPromptCost := c.GetFloat64("estimatedPromptCostInUsd")
+ // totalCost = cost + estimatedPromptCost
+
+ // c.Set("costInUsd", totalCost)
+ // c.Set("completionTokenCount", tks+anthropicCompletionMagicNum)
+ // }()
stats.Incr("bricksllm.proxy.get_completion_handler.streaming_requests", nil, 1)
diff --git a/internal/server/web/proxy/azure_chat_completion.go b/internal/server/web/proxy/azure_chat_completion.go
index 970237a..2807071 100644
--- a/internal/server/web/proxy/azure_chat_completion.go
+++ b/internal/server/web/proxy/azure_chat_completion.go
@@ -11,7 +11,6 @@ import (
"net/http"
"time"
- "github.com/bricks-cloud/bricksllm/internal/key"
"github.com/bricks-cloud/bricksllm/internal/stats"
"github.com/gin-gonic/gin"
goopenai "github.com/sashabaranov/go-openai"
@@ -36,13 +35,6 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS
}
cid := c.GetString(correlationId)
- raw, exists := c.Get("key")
- kc, ok := raw.(*key.ResponseKey)
- if !exists || !ok {
- stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.api_key_not_registered", nil, 1)
- JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered")
- return
- }
ctx, cancel := context.WithTimeout(context.Background(), timeOut)
defer cancel()
@@ -111,12 +103,12 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS
logError(log, "error when estimating azure openai cost", prod, cid, err)
}
- micros := int64(cost * 1000000)
- err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.record_key_spend_error", nil, 1)
- logError(log, "error when recording azure openai spend", prod, cid, err)
- }
+ // micros := int64(cost * 1000000)
+ // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.record_key_spend_error", nil, 1)
+ // logError(log, "error when recording azure openai spend", prod, cid, err)
+ // }
}
c.Set("costInUsd", cost)
@@ -145,8 +137,8 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS
}
buffer := bufio.NewReader(res.Body)
- var totalCost float64 = 0
- var totalTokens int = 0
+ // var totalCost float64 = 0
+ // var totalTokens int = 0
content := ""
model := ""
@@ -155,24 +147,26 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS
c.Set("model", model)
}
- tks, cost, err := aoe.EstimateChatCompletionStreamCostWithTokenCounts(model, content)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1)
- logError(log, "error when estimating azure openai chat completion stream cost with token counts", prod, cid, err)
- }
+ c.Set("content", content)
- estimatedPromptTokenCounts := c.GetInt("promptTokenCount")
- promptCost, err := aoe.EstimatePromptCost(model, estimatedPromptTokenCounts)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1)
- logError(log, "error when estimating azure openai chat completion stream cost with token counts", prod, cid, err)
- }
+ // tks, cost, err := aoe.EstimateChatCompletionStreamCostWithTokenCounts(model, content)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1)
+ // logError(log, "error when estimating azure openai chat completion stream cost with token counts", prod, cid, err)
+ // }
+
+ // estimatedPromptTokenCounts := c.GetInt("promptTokenCount")
+ // promptCost, err := aoe.EstimatePromptCost(model, estimatedPromptTokenCounts)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1)
+ // logError(log, "error when estimating azure openai chat completion stream cost with token counts", prod, cid, err)
+ // }
- totalCost = cost + promptCost
- totalTokens += tks
+ // totalCost = cost + promptCost
+ // totalTokens += tks
- c.Set("costInUsd", totalCost)
- c.Set("completionTokenCount", totalTokens)
+ // c.Set("costInUsd", totalCost)
+ // c.Set("completionTokenCount", totalTokens)
}()
stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.streaming_requests", nil, 1)
diff --git a/internal/server/web/proxy/azure_embedding.go b/internal/server/web/proxy/azure_embedding.go
index c3c2830..f9ac0de 100644
--- a/internal/server/web/proxy/azure_embedding.go
+++ b/internal/server/web/proxy/azure_embedding.go
@@ -7,7 +7,6 @@ import (
"net/http"
"time"
- "github.com/bricks-cloud/bricksllm/internal/key"
"github.com/bricks-cloud/bricksllm/internal/stats"
"github.com/gin-gonic/gin"
goopenai "github.com/sashabaranov/go-openai"
@@ -23,13 +22,13 @@ func getAzureEmbeddingsHandler(r recorder, prod, private bool, psm ProviderSetti
}
cid := c.GetString(correlationId)
- raw, exists := c.Get("key")
- kc, ok := raw.(*key.ResponseKey)
- if !exists || !ok {
- stats.Incr("bricksllm.proxy.get_azure_embeddings_handler.api_key_not_registered", nil, 1)
- JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered")
- return
- }
+ // raw, exists := c.Get("key")
+ // kc, ok := raw.(*key.ResponseKey)
+ // if !exists || !ok {
+ // stats.Incr("bricksllm.proxy.get_azure_embeddings_handler.api_key_not_registered", nil, 1)
+ // JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered")
+ // return
+ // }
ctx, cancel := context.WithTimeout(context.Background(), timeOut)
defer cancel()
@@ -111,12 +110,12 @@ func getAzureEmbeddingsHandler(r recorder, prod, private bool, psm ProviderSetti
logError(log, "error when estimating azure openai cost for embedding", prod, cid, err)
}
- micros := int64(cost * 1000000)
- err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_azure_embeddings_handler.record_key_spend_error", nil, 1)
- logError(log, "error when recording azure openai spend for embedding", prod, cid, err)
- }
+ // micros := int64(cost * 1000000)
+ // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_azure_embeddings_handler.record_key_spend_error", nil, 1)
+ // logError(log, "error when recording azure openai spend for embedding", prod, cid, err)
+ // }
}
}
diff --git a/internal/server/web/proxy/custom_provider.go b/internal/server/web/proxy/custom_provider.go
index bbc8291..1f9007f 100644
--- a/internal/server/web/proxy/custom_provider.go
+++ b/internal/server/web/proxy/custom_provider.go
@@ -120,12 +120,14 @@ func getCustomProviderHandler(prod, private bool, psm ProviderSettingsManager, c
return
}
- tks, err := countTokensFromJson(bytes, rc.ResponseCompletionLocation)
- if err != nil {
- logError(log, "error when counting tokens for custom provider completion response", prod, cid, err)
- }
+ c.Set("response", bytes)
+
+ // tks, err := countTokensFromJson(bytes, rc.ResponseCompletionLocation)
+ // if err != nil {
+ // logError(log, "error when counting tokens for custom provider completion response", prod, cid, err)
+ // }
- c.Set("completionTokenCount", tks)
+ // c.Set("completionTokenCount", tks)
c.Data(res.StatusCode, "application/json", bytes)
return
}
@@ -149,13 +151,15 @@ func getCustomProviderHandler(prod, private bool, psm ProviderSettingsManager, c
buffer := bufio.NewReader(res.Body)
aggregated := ""
defer func() {
- tks, err := custom.Count(aggregated)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_custom_provider_handler.count_error", nil, 1)
- logError(log, "error when counting tokens for custom provider streaming response", prod, cid, err)
- }
+ c.Set("content", aggregated)
+
+ // tks, err := custom.Count(aggregated)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_custom_provider_handler.count_error", nil, 1)
+ // logError(log, "error when counting tokens for custom provider streaming response", prod, cid, err)
+ // }
- c.Set("completionTokenCount", tks)
+ // c.Set("completionTokenCount", tks)
}()
stats.Incr("bricksllm.proxy.get_custom_provider_handler.streaming_requests", nil, 1)
diff --git a/internal/server/web/proxy/middleware.go b/internal/server/web/proxy/middleware.go
index fa5916d..e8b268d 100644
--- a/internal/server/web/proxy/middleware.go
+++ b/internal/server/web/proxy/middleware.go
@@ -11,6 +11,7 @@ import (
"github.com/bricks-cloud/bricksllm/internal/event"
"github.com/bricks-cloud/bricksllm/internal/key"
+ "github.com/bricks-cloud/bricksllm/internal/message"
"github.com/bricks-cloud/bricksllm/internal/provider"
"github.com/bricks-cloud/bricksllm/internal/provider/anthropic"
"github.com/bricks-cloud/bricksllm/internal/route"
@@ -72,6 +73,10 @@ type rateLimitManager interface {
Increment(keyId string, timeUnit key.TimeUnit) error
}
+type accessCache interface {
+ GetAccessStatus(key string) bool
+}
+
type encrypter interface {
Encrypt(secret string) string
}
@@ -93,7 +98,40 @@ type notFoundError interface {
NotFound()
}
-func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManager, a authenticator, prod, private bool, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, ks keyStorage, log *zap.Logger, rlm rateLimitManager, r recorder, prefix string) gin.HandlerFunc {
+type publisher interface {
+ Publish(message.Message)
+}
+
+func getProvider(c *gin.Context) string {
+ existing := c.GetString("provider")
+ if len(existing) != 0 {
+ return existing
+ }
+
+ parts := strings.Split(c.FullPath(), "/")
+
+ spaceRemoved := []string{}
+
+ for _, part := range parts {
+ if len(part) != 0 {
+ spaceRemoved = append(spaceRemoved, part)
+ }
+ }
+
+ if strings.HasPrefix(c.FullPath(), "/api/providers/") {
+ if len(spaceRemoved) >= 3 {
+ return spaceRemoved[2]
+ }
+ }
+
+ if strings.HasPrefix(c.FullPath(), "/api/custom/providers/") {
+ return c.Param("provider")
+ }
+
+ return ""
+}
+
+func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManager, a authenticator, prod, private bool, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, ks keyStorage, log *zap.Logger, rlm rateLimitManager, pub publisher, prefix string, ac accessCache) gin.HandlerFunc {
return func(c *gin.Context) {
if c == nil || c.Request == nil {
JSON(c, http.StatusInternalServerError, "[BricksLLM] request is empty")
@@ -110,17 +148,12 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
c.Set(correlationId, cid)
start := time.Now()
- selectedProvider := "openai"
+ enrichedEvent := &event.EventWithRequestAndContent{}
customId := c.Request.Header.Get("X-CUSTOM-EVENT-ID")
defer func() {
dur := time.Now().Sub(start)
latency := int(dur.Milliseconds())
- raw, exists := c.Get("key")
- var kc *key.ResponseKey
- if exists {
- kc = raw.(*key.ResponseKey)
- }
if !prod {
log.Sugar().Infof("%s | %d | %s | %s | %dms", prefix, c.Writer.Status(), c.Request.Method, c.FullPath(), latency)
@@ -129,16 +162,14 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
keyId := ""
tags := []string{}
- if kc != nil {
- keyId = kc.KeyId
- tags = kc.Tags
+ if enrichedEvent.Key != nil {
+ keyId = enrichedEvent.Key.KeyId
+ tags = enrichedEvent.Key.Tags
}
stats.Timing("bricksllm.proxy.get_middleware.proxy_latency_in_ms", dur, nil, 1)
- if len(c.GetString("provider")) != 0 {
- selectedProvider = c.GetString("provider")
- }
+ selectedProvider := getProvider(c)
if prod {
log.Info("response to proxy",
@@ -173,12 +204,16 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
CustomId: customId,
}
- err := r.RecordEvent(evt)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_middleware.record_event_error", nil, 1)
-
- logError(log, "error when recording openai event", prod, cid, err)
+ enrichedEvent.Event = evt
+ content := c.GetString("content")
+ if len(content) != 0 {
+ enrichedEvent.Content = content
}
+
+ pub.Publish(message.Message{
+ Type: "event",
+ Data: enrichedEvent,
+ })
}()
if len(c.FullPath()) == 0 {
@@ -189,6 +224,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
kc, settings, err := a.AuthenticateHttpRequest(c.Request)
+ enrichedEvent.Key = kc
_, ok := err.(notAuthorizedError)
if ok {
stats.Incr("bricksllm.proxy.get_middleware.authentication_error", nil, 1)
@@ -236,12 +272,11 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
c.Request.Body = io.NopCloser(bytes.NewReader(body))
}
- var cost float64 = 0
+ // var cost float64 = 0
if c.FullPath() == "/api/providers/anthropic/v1/complete" {
logCompletionRequest(log, body, prod, private, cid)
- selectedProvider = "anthropic"
cr := &anthropic.CompletionRequest{}
err = json.Unmarshal(body, cr)
if err != nil {
@@ -249,19 +284,21 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
- tks := ae.Count(cr.Prompt)
- tks += anthropicPromptMagicNum
- c.Set("promptTokenCount", tks)
+ enrichedEvent.Request = cr
- model := cr.Model
- cost, err = ae.EstimatePromptCost(model, tks)
- if err != nil {
- logError(log, "error when estimating anthropic completion prompt cost", prod, cid, err)
- }
+ // tks := ae.Count(cr.Prompt)
+ // tks += anthropicPromptMagicNum
+ // c.Set("promptTokenCount", tks)
+
+ // model := cr.Model
+ // cost, err = ae.EstimatePromptCost(model, tks)
+ // if err != nil {
+ // logError(log, "error when estimating anthropic completion prompt cost", prod, cid, err)
+ // }
if cr.Stream {
c.Set("stream", cr.Stream)
- c.Set("estimatedPromptCostInUsd", cost)
+ // c.Set("estimatedPromptCostInUsd", cost)
}
if len(cr.Model) != 0 {
@@ -288,17 +325,22 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
- selectedProvider = cp.Provider
-
c.Set("provider", cp)
c.Set("route_config", rc)
- tks, err := countTokensFromJson(body, rc.RequestPromptLocation)
- if err != nil {
- logError(log, "error when counting tokens for custom provider request", prod, cid, err)
+ enrichedEvent.Request = body
+
+ customResponse, ok := c.Get("response")
+ if ok {
+ enrichedEvent.Response = customResponse
}
- c.Set("promptTokenCount", tks)
+ // tks, err := countTokensFromJson(body, rc.RequestPromptLocation)
+ // if err != nil {
+ // logError(log, "error when counting tokens for custom provider request", prod, cid, err)
+ // }
+
+ // c.Set("promptTokenCount", tks)
result := gjson.Get(string(body), rc.StreamLocation)
@@ -344,12 +386,15 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
if !rc.ShouldRunEmbeddings() {
ccr := &goopenai.ChatCompletionRequest{}
+
err = json.Unmarshal(body, ccr)
if err != nil {
logError(log, "error when unmarshalling route chat completion request", prod, cid, err)
return
}
+ enrichedEvent.Request = ccr
+
logRequest(log, prod, private, cid, ccr)
if ccr.Stream {
@@ -366,8 +411,6 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/azure/openai/deployments/:deployment_id/chat/completions" {
- selectedProvider = "azure"
-
ccr := &goopenai.ChatCompletionRequest{}
err = json.Unmarshal(body, ccr)
if err != nil {
@@ -375,23 +418,23 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
+ enrichedEvent.Request = ccr
+
logRequest(log, prod, private, cid, ccr)
- tks, err := e.EstimateChatCompletionPromptTokenCounts("gpt-3.5-turbo", ccr)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_middleware.estimate_chat_completion_prompt_token_counts_error", nil, 1)
- logError(log, "error when estimating prompt cost", prod, cid, err)
- }
+ // tks, err := e.EstimateChatCompletionPromptTokenCounts("gpt-3.5-turbo", ccr)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_middleware.estimate_chat_completion_prompt_token_counts_error", nil, 1)
+ // logError(log, "error when estimating prompt cost", prod, cid, err)
+ // }
if ccr.Stream {
c.Set("stream", true)
- c.Set("promptTokenCount", tks)
+ // c.Set("promptTokenCount", tks)
}
}
if c.FullPath() == "/api/providers/azure/openai/deployments/:deployment_id/embeddings" {
- selectedProvider = "azure"
-
er := &goopenai.EmbeddingRequest{}
err = json.Unmarshal(body, er)
if err != nil {
@@ -404,11 +447,11 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
logEmbeddingRequest(log, prod, private, cid, er)
- cost, err = aoe.EstimateEmbeddingsCost(er)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_middleware.estimate_azure_openai_embeddings_cost_error", nil, 1)
- logError(log, "error when estimating azure openai embeddings cost", prod, cid, err)
- }
+ // cost, err = aoe.EstimateEmbeddingsCost(er)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_middleware.estimate_azure_openai_embeddings_cost_error", nil, 1)
+ // logError(log, "error when estimating azure openai embeddings cost", prod, cid, err)
+ // }
}
if c.FullPath() == "/api/providers/openai/v1/chat/completions" {
@@ -419,6 +462,8 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
+ enrichedEvent.Request = ccr
+
c.Set("model", ccr.Model)
logRequest(log, prod, private, cid, ccr)
@@ -445,16 +490,16 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
- c.Set("model", er.Model.String())
+ c.Set("model", string(er.Model))
c.Set("encoding_format", string(er.EncodingFormat))
logEmbeddingRequest(log, prod, private, cid, er)
- cost, err = e.EstimateEmbeddingsCost(er)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_middleware.estimate_embeddings_cost_error", nil, 1)
- logError(log, "error when estimating embeddings cost", prod, cid, err)
- }
+ // cost, err = e.EstimateEmbeddingsCost(er)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_middleware.estimate_embeddings_cost_error", nil, 1)
+ // logError(log, "error when estimating embeddings cost", prod, cid, err)
+ // }
}
if c.FullPath() == "/api/providers/openai/v1/images/generations" && c.Request.Method == http.MethodPost {
@@ -735,48 +780,13 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
logRetrieveFileContentRequest(log, body, prod, cid, fid)
}
- err = v.Validate(kc, cost)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_middleware.validation_error", nil, 1)
-
- if _, ok := err.(expirationError); ok {
- stats.Incr("bricksllm.proxy.get_middleware.key_expired", nil, 1)
-
- truePtr := true
- _, err = ks.UpdateKey(kc.KeyId, &key.UpdateKey{
- Revoked: &truePtr,
- RevokedReason: "Key has expired or exceeded set spend limit",
- })
-
- if err != nil {
- stats.Incr("bricksllm.proxy.get_middleware.update_key_error", nil, 1)
- log.Sugar().Debugf("error when updating revoking the api key %s: %v", kc.KeyId, err)
- }
-
- JSON(c, http.StatusUnauthorized, "[BricksLLM] key has expired")
- c.Abort()
- return
- }
-
- if _, ok := err.(rateLimitError); ok {
- stats.Incr("bricksllm.proxy.get_middleware.rate_limited", nil, 1)
- JSON(c, http.StatusTooManyRequests, "[BricksLLM] too many requests")
- c.Abort()
- return
- }
-
- logError(log, "error when validating api key", prod, cid, err)
+ if ac.GetAccessStatus(kc.KeyId) {
+ stats.Incr("bricksllm.proxy.get_middleware.rate_limited", nil, 1)
+ JSON(c, http.StatusTooManyRequests, "[BricksLLM] too many requests")
+ c.Abort()
return
}
- if len(kc.RateLimitUnit) != 0 {
- if err := rlm.Increment(kc.KeyId, kc.RateLimitUnit); err != nil {
- stats.Incr("bricksllm.proxy.get_middleware.rate_limit_increment_error", nil, 1)
-
- logError(log, "error when incrementing rate limit counter", prod, cid, err)
- }
- }
-
c.Next()
}
}
diff --git a/internal/server/web/proxy/proxy.go b/internal/server/web/proxy/proxy.go
index dbcb5eb..221dac8 100644
--- a/internal/server/web/proxy/proxy.go
+++ b/internal/server/web/proxy/proxy.go
@@ -39,7 +39,7 @@ type ProxyServer struct {
}
type recorder interface {
- RecordKeySpend(keyId string, micros int64, costLimitUnit key.TimeUnit) error
+ // RecordKeySpend(keyId string, micros int64, costLimitUnit key.TimeUnit) error
RecordEvent(e *event.Event) error
}
@@ -55,12 +55,12 @@ type CustomProvidersManager interface {
GetCustomProviderFromMem(name string) *custom.Provider
}
-func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyManager, rm routeManager, a authenticator, psm ProviderSettingsManager, cpm CustomProvidersManager, ks keyStorage, kms keyMemStorage, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, r recorder, rlm rateLimitManager, timeOut time.Duration) (*ProxyServer, error) {
+func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyManager, rm routeManager, a authenticator, psm ProviderSettingsManager, cpm CustomProvidersManager, ks keyStorage, kms keyMemStorage, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, r recorder, pub publisher, rlm rateLimitManager, timeOut time.Duration, ac accessCache) (*ProxyServer, error) {
router := gin.New()
prod := mode == "production"
private := privacyMode == "strict"
- router.Use(getMiddleware(kms, cpm, rm, a, prod, private, e, ae, aoe, v, ks, log, rlm, r, "proxy"))
+ router.Use(getMiddleware(kms, cpm, rm, a, prod, private, e, ae, aoe, v, ks, log, rlm, pub, "proxy", ac))
client := http.Client{}
@@ -942,13 +942,13 @@ func getEmbeddingHandler(r recorder, prod, private bool, psm ProviderSettingsMan
return
}
- raw, exists := c.Get("key")
- kc, ok := raw.(*key.ResponseKey)
- if !exists || !ok {
- stats.Incr("bricksllm.proxy.get_embedding_handler.api_key_not_registered", nil, 1)
- JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered")
- return
- }
+ // raw, exists := c.Get("key")
+ // kc, ok := raw.(*key.ResponseKey)
+ // if !exists || !ok {
+ // stats.Incr("bricksllm.proxy.get_embedding_handler.api_key_not_registered", nil, 1)
+ // JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered")
+ // return
+ // }
id := c.GetString(correlationId)
@@ -1032,12 +1032,12 @@ func getEmbeddingHandler(r recorder, prod, private bool, psm ProviderSettingsMan
logError(log, "error when estimating openai cost for embedding", prod, id, err)
}
- micros := int64(cost * 1000000)
- err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_embedding_handler.record_key_spend_error", nil, 1)
- logError(log, "error when recording openai spend for embedding", prod, id, err)
- }
+ // micros := int64(cost * 1000000)
+ // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_embedding_handler.record_key_spend_error", nil, 1)
+ // logError(log, "error when recording openai spend for embedding", prod, id, err)
+ // }
}
}
@@ -1085,13 +1085,13 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin
}
cid := c.GetString(correlationId)
- raw, exists := c.Get("key")
- kc, ok := raw.(*key.ResponseKey)
- if !exists || !ok {
- stats.Incr("bricksllm.proxy.get_chat_completion_handler.api_key_not_registered", nil, 1)
- JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered")
- return
- }
+ // raw, exists := c.Get("key")
+ // kc, ok := raw.(*key.ResponseKey)
+ // if !exists || !ok {
+ // stats.Incr("bricksllm.proxy.get_chat_completion_handler.api_key_not_registered", nil, 1)
+ // JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered")
+ // return
+ // }
ctx, cancel := context.WithTimeout(context.Background(), timeOut)
defer cancel()
@@ -1161,12 +1161,12 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin
logError(log, "error when estimating openai cost", prod, cid, err)
}
- micros := int64(cost * 1000000)
- err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_chat_completion_handler.record_key_spend_error", nil, 1)
- logError(log, "error when recording openai spend", prod, cid, err)
- }
+ // micros := int64(cost * 1000000)
+ // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_chat_completion_handler.record_key_spend_error", nil, 1)
+ // logError(log, "error when recording openai spend", prod, cid, err)
+ // }
}
c.Set("costInUsd", cost)
@@ -1195,22 +1195,24 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin
}
buffer := bufio.NewReader(res.Body)
- var totalCost float64 = 0
- var totalTokens int = 0
+ // var totalCost float64 = 0
+ // var totalTokens int = 0
content := ""
defer func() {
- tks, cost, err := e.EstimateChatCompletionStreamCostWithTokenCounts(model, content)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1)
- logError(log, "error when estimating chat completion stream cost with token counts", prod, cid, err)
- }
+ c.Set("content", content)
+
+ // tks, cost, err := e.EstimateChatCompletionStreamCostWithTokenCounts(model, content)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1)
+ // logError(log, "error when estimating chat completion stream cost with token counts", prod, cid, err)
+ // }
- estimatedPromptCost := c.GetFloat64("estimatedPromptCostInUsd")
- totalCost = cost + estimatedPromptCost
- totalTokens += tks
+ // estimatedPromptCost := c.GetFloat64("estimatedPromptCostInUsd")
+ // totalCost = cost + estimatedPromptCost
+ // totalTokens += tks
- c.Set("costInUsd", totalCost)
- c.Set("completionTokenCount", totalTokens)
+ // c.Set("costInUsd", totalCost)
+ // c.Set("completionTokenCount", totalTokens)
}()
stats.Incr("bricksllm.proxy.get_chat_completion_handler.streaming_requests", nil, 1)
@@ -1522,7 +1524,7 @@ func logEmbeddingRequest(log *zap.Logger, prod, private bool, id string, r *goop
if prod {
fields := []zapcore.Field{
zap.String(correlationId, id),
- zap.String("model", r.Model.String()),
+ zap.String("model", string(r.Model)),
zap.String("encoding_format", string(r.EncodingFormat)),
zap.String("user", r.User),
}
diff --git a/internal/server/web/proxy/route.go b/internal/server/web/proxy/route.go
index 13357a9..90c46af 100644
--- a/internal/server/web/proxy/route.go
+++ b/internal/server/web/proxy/route.go
@@ -231,12 +231,12 @@ func parseResult(c *gin.Context, ca cache, kc *key.ResponseKey, runEmbeddings bo
cost = ecost
}
- micros := int64(cost * 1000000)
+ // micros := int64(cost * 1000000)
- err := r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
- if err != nil {
- return err
- }
+ // err := r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
+ // if err != nil {
+ // return err
+ // }
}
if !runEmbeddings {
@@ -262,11 +262,11 @@ func parseResult(c *gin.Context, ca cache, kc *key.ResponseKey, runEmbeddings bo
}
}
- micros := int64(cost * 1000000)
- err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
- if err != nil {
- return err
- }
+ // micros := int64(cost * 1000000)
+ // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
+ // if err != nil {
+ // return err
+ // }
}
return nil
From 88dfa53f2d7d7109a84a4db190023fbec81f85c2 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Tue, 6 Feb 2024 11:45:15 -0800
Subject: [PATCH 22/71] improve performance
---
internal/server/web/proxy/middleware.go | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/internal/server/web/proxy/middleware.go b/internal/server/web/proxy/middleware.go
index e8b268d..78ccb00 100644
--- a/internal/server/web/proxy/middleware.go
+++ b/internal/server/web/proxy/middleware.go
@@ -468,17 +468,17 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
logRequest(log, prod, private, cid, ccr)
- tks, cost, err := e.EstimateChatCompletionPromptCostWithTokenCounts(ccr)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_middleware.estimate_chat_completion_prompt_cost_with_token_counts_error", nil, 1)
+ // tks, cost, err := e.EstimateChatCompletionPromptCostWithTokenCounts(ccr)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_middleware.estimate_chat_completion_prompt_cost_with_token_counts_error", nil, 1)
- logError(log, "error when estimating prompt cost", prod, cid, err)
- }
+ // logError(log, "error when estimating prompt cost", prod, cid, err)
+ // }
if ccr.Stream {
c.Set("stream", true)
- c.Set("estimatedPromptCostInUsd", cost)
- c.Set("promptTokenCount", tks)
+ // c.Set("estimatedPromptCostInUsd", cost)
+ // c.Set("promptTokenCount", tks)
}
}
From 6b4ef3a7d31899a66aed65c75eeda182c7af92e0 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Tue, 6 Feb 2024 14:40:03 -0800
Subject: [PATCH 23/71] update CHANGELOG
---
CHANGELOG.md | 11 +++++++++++
1 file changed, 11 insertions(+)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index dca46b1..d8eda6e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,14 @@
+## 1.9.1 - 2024-02-06
+### Fixed
+- Fixed OpenAI chat completion endpoint being slow
+
+## 1.9.0 - 2024-02-06
+### Changed
+- Drastically improved performance through an event-driven architecture
+
+### Fixed
+- Fixed bug where API calls that exceed the cost limit were not being blocked
+
## 1.8.2 - 2024-01-31
### Added
- Added support for new chat completion models
From edf706fdd9e5ea03540111e2c4d0bba4d5541a0e Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Tue, 6 Feb 2024 15:35:58 -0800
Subject: [PATCH 24/71] update CHANGELOG
---
CHANGELOG.md | 4 ++++
internal/message/handler.go | 14 +++++++-------
internal/server/web/proxy/middleware.go | 7 +++++++
3 files changed, 18 insertions(+), 7 deletions(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index d8eda6e..e61c15c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,7 @@
+## 1.9.2 - 2024-02-06
+### Fixed
+- Fixed custom route token recording issue introduced by the new architecture
+
## 1.9.1 - 2024-02-06
### Fixed
- Fixed OpenAI chat completion endpoint being slow
diff --git a/internal/message/handler.go b/internal/message/handler.go
index 04c5697..3725711 100644
--- a/internal/message/handler.go
+++ b/internal/message/handler.go
@@ -380,13 +380,6 @@ func (h *Handler) decorateEvent(m Message) error {
return errors.New("event request data cannot be parsed as anthropic completon request")
}
- content, ok := e.Response.([]byte)
- if !ok {
- stats.Incr("bricksllm.message.handler.decorate_event.event_response_custom_provider_parsing_error", nil, 1)
- h.log.Debug("event contains response that cannot be converted to bytes", zap.Any("data", m.Data))
- return errors.New("event response data cannot be converted to bytes")
- }
-
tks, err := countTokensFromJson(body, e.RouteConfig.RequestPromptLocation)
if err != nil {
stats.Incr("bricksllm.message.handler.decorate_event.count_tokens_from_json_error", nil, 1)
@@ -408,6 +401,13 @@ func (h *Handler) decorateEvent(m Message) error {
}
if !result.IsBool() {
+ content, ok := e.Response.([]byte)
+ if !ok {
+ stats.Incr("bricksllm.message.handler.decorate_event.event_response_custom_provider_parsing_error", nil, 1)
+ h.log.Debug("event contains response that cannot be converted to bytes", zap.Any("data", m.Data))
+ return errors.New("event response data cannot be converted to bytes")
+ }
+
completiontks, err := countTokensFromJson(content, e.RouteConfig.ResponseCompletionLocation)
if err != nil {
stats.Incr("bricksllm.message.handler.decorate_event.count_tokens_from_json_error", nil, 1)
diff --git a/internal/server/web/proxy/middleware.go b/internal/server/web/proxy/middleware.go
index 78ccb00..7558ff4 100644
--- a/internal/server/web/proxy/middleware.go
+++ b/internal/server/web/proxy/middleware.go
@@ -210,6 +210,11 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
enrichedEvent.Content = content
}
+ resp, ok := c.Get("response")
+ if ok {
+ enrichedEvent.Response = resp
+ }
+
pub.Publish(message.Message{
Type: "event",
Data: enrichedEvent,
@@ -310,6 +315,8 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
providerName := c.Param("provider")
rc := cpm.GetRouteConfigFromMem(providerName, c.Param("wildcard"))
+ enrichedEvent.RouteConfig = rc
+
cp := cpm.GetCustomProviderFromMem(providerName)
if cp == nil {
stats.Incr("bricksllm.proxy.get_middleware.provider_not_found", nil, 1)
From 0364b9df1837849fcb7a7cd07da0bb05ec9dc98e Mon Sep 17 00:00:00 2001
From: Donald Buczek
Date: Sat, 10 Feb 2024 21:41:24 +0100
Subject: [PATCH 25/71] proxy: Implement CORS handler for XMLHttpRequest
This commit introduces a middleware that adds CORS headers and handles
CORS preflight requests. This enhancement allows the proxy to be used
with XMLHttpRequest from a browser, where such cross-origin requests
would otherwise be blocked by the browser's same-origin policy.
---
internal/server/web/proxy/proxy.go | 20 ++++++++++++++++++++
1 file changed, 20 insertions(+)
diff --git a/internal/server/web/proxy/proxy.go b/internal/server/web/proxy/proxy.go
index 221dac8..efee23a 100644
--- a/internal/server/web/proxy/proxy.go
+++ b/internal/server/web/proxy/proxy.go
@@ -55,11 +55,31 @@ type CustomProvidersManager interface {
GetCustomProviderFromMem(name string) *custom.Provider
}
+func CorsMiddleware() gin.HandlerFunc {
+ return func(c *gin.Context) {
+ a_or_b := func(a, b string) string {
+ if a != "" {
+ return a
+ } else {
+ return b
+ }
+ }
+ c.Header("Access-Control-Allow-Origin", a_or_b(c.GetHeader("Origin"), "*"))
+ if c.Request.Method == "OPTIONS" {
+ c.Header("Access-Control-Allow-Methods", a_or_b(c.GetHeader("Access-Control-Request-Method"), "*"))
+ c.Header("Access-Control-Allow-Headers", a_or_b(c.GetHeader("Access-Control-Request-Headers"), "*"))
+ c.Header("Access-Control-Max-Age", "3600")
+ c.AbortWithStatus(204)
+ }
+ }
+}
+
func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyManager, rm routeManager, a authenticator, psm ProviderSettingsManager, cpm CustomProvidersManager, ks keyStorage, kms keyMemStorage, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, r recorder, pub publisher, rlm rateLimitManager, timeOut time.Duration, ac accessCache) (*ProxyServer, error) {
router := gin.New()
prod := mode == "production"
private := privacyMode == "strict"
+ router.Use(CorsMiddleware())
router.Use(getMiddleware(kms, cpm, rm, a, prod, private, e, ae, aoe, v, ks, log, rlm, pub, "proxy", ac))
client := http.Client{}
From 9e93ec984a0596769a0ab344f734f449534db29e Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Tue, 13 Feb 2024 13:19:46 -0800
Subject: [PATCH 26/71] update CHANGELOG
---
CHANGELOG.md | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index e61c15c..e238d00 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,7 @@
+## 1.9.3 - 2024-02-13
+### Added
+- Added CORS support in the proxy
+
## 1.9.2 - 2024-02-06
### Fixed
- Fixed custom route tokens recording issue incurred by the new architecture
From a5bfe3758d4dd33b2599f0f57514e32dce86f4b2 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Fri, 16 Feb 2024 10:36:57 -0800
Subject: [PATCH 27/71] add revoked reasons
---
internal/key/key.go | 2 ++
1 file changed, 2 insertions(+)
diff --git a/internal/key/key.go b/internal/key/key.go
index 7405e30..3841184 100644
--- a/internal/key/key.go
+++ b/internal/key/key.go
@@ -9,6 +9,8 @@ import (
internal_errors "github.com/bricks-cloud/bricksllm/internal/errors"
)
+const RevokedReasonExpired string = "expired"
+
type UpdateKey struct {
Name string `json:"name"`
UpdatedAt int64 `json:"updatedAt"`
From c679cbf120bd9f03715f846c067d16c612178c97 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Fri, 16 Feb 2024 10:38:11 -0800
Subject: [PATCH 28/71] add validation to revoked field update
---
internal/manager/key.go | 10 ++++++++++
1 file changed, 10 insertions(+)
diff --git a/internal/manager/key.go b/internal/manager/key.go
index 185fec1..884cd8e 100644
--- a/internal/manager/key.go
+++ b/internal/manager/key.go
@@ -19,6 +19,7 @@ type Storage interface {
DeleteKey(id string) error
GetProviderSetting(id string) (*provider.Setting, error)
GetProviderSettings(withSecret bool, ids []string) ([]*provider.Setting, error)
+ GetKey(keyId string) (*key.ResponseKey, error)
}
type Encrypter interface {
@@ -95,6 +96,15 @@ func (m *Manager) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, err
return nil, err
}
+ existingKey, err := m.s.GetKey(id)
+ if err != nil {
+ return nil, err
+ }
+
+ if uk.Revoked != nil && !*uk.Revoked && existingKey.RevokedReason == key.RevokedReasonExpired {
+ return nil, internal_errors.NewValidationError("cannot reenable an expired key")
+ }
+
if len(uk.SettingId) != 0 {
if _, err := m.s.GetProviderSetting(uk.SettingId); err != nil {
return nil, err
From b95060dfd29feca32d974b88a2b70df9027ed3e1 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Fri, 16 Feb 2024 10:38:49 -0800
Subject: [PATCH 29/71] remove comments and update revoked reason
---
internal/message/handler.go | 12 +-----------
1 file changed, 1 insertion(+), 11 deletions(-)
diff --git a/internal/message/handler.go b/internal/message/handler.go
index 3725711..67de0c8 100644
--- a/internal/message/handler.go
+++ b/internal/message/handler.go
@@ -159,14 +159,13 @@ func (h *Handler) handleValidationResult(kc *key.ResponseKey, cost float64) erro
if err != nil {
stats.Incr("bricksllm.message.handler.handle_validation_result.handle_validation_result", nil, 1)
- // tested
if _, ok := err.(expirationError); ok {
stats.Incr("bricksllm.message.handler.handle_validation_result.expiraton_error", nil, 1)
truePtr := true
_, err = h.km.UpdateKey(kc.KeyId, &key.UpdateKey{
Revoked: &truePtr,
- RevokedReason: "Key has expired or exceeded set spend limit",
+ RevokedReason: key.RevokedReasonExpired,
})
if err != nil {
@@ -177,7 +176,6 @@ func (h *Handler) handleValidationResult(kc *key.ResponseKey, cost float64) erro
return nil
}
- // tested
if _, ok := err.(rateLimitError); ok {
stats.Incr("bricksllm.message.handler.handle_validation_result.rate_limit_error", nil, 1)
@@ -190,7 +188,6 @@ func (h *Handler) handleValidationResult(kc *key.ResponseKey, cost float64) erro
return nil
}
- // tested
if _, ok := err.(costLimitError); ok {
stats.Incr("bricksllm.message.handler.handle_validation_result.cost_limit_error", nil, 1)
@@ -224,7 +221,6 @@ func (h *Handler) HandleEventWithRequestAndResponse(m Message) error {
h.log.Debug("error when decorating event", zap.Error(err))
}
- // tested
if e.Event.CostInUsd != 0 {
micros := int64(e.Event.CostInUsd * 1000000)
err = h.recorder.RecordKeySpend(e.Event.KeyId, micros, e.Key.CostLimitInUsdUnit)
@@ -234,7 +230,6 @@ func (h *Handler) HandleEventWithRequestAndResponse(m Message) error {
}
}
- // tested
if len(e.Key.RateLimitUnit) != 0 {
if err := h.rlm.Increment(e.Key.KeyId, e.Key.RateLimitUnit); err != nil {
stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.rate_limit_increment_error", nil, 1)
@@ -243,7 +238,6 @@ func (h *Handler) HandleEventWithRequestAndResponse(m Message) error {
}
}
- // tested
err = h.handleValidationResult(e.Key, e.Event.CostInUsd)
if err != nil {
stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.handle_validation_result_error", nil, 1)
@@ -252,7 +246,6 @@ func (h *Handler) HandleEventWithRequestAndResponse(m Message) error {
}
- // tested
start := time.Now()
err := h.recorder.RecordEvent(e.Event)
if err != nil {
@@ -276,7 +269,6 @@ func (h *Handler) decorateEvent(m Message) error {
return errors.New("message data cannot be parsed as event with request and response")
}
- // tested
if e.Event.Provider == "anthropic" && e.Event.Path == "/api/providers/anthropic/v1/complete" {
cr, ok := e.Request.(*anthropic.CompletionRequest)
if !ok {
@@ -310,7 +302,6 @@ func (h *Handler) decorateEvent(m Message) error {
e.Event.CostInUsd = completionCost + cost
}
- // tested
if e.Event.Provider == "azure" && e.Event.Path == "/api/providers/azure/openai/deployments/:deployment_id/chat/completions" {
ccr, ok := e.Request.(*goopenai.ChatCompletionRequest)
if !ok {
@@ -344,7 +335,6 @@ func (h *Handler) decorateEvent(m Message) error {
}
}
- // tested
if e.Event.Provider == "openai" && e.Event.Path == "/api/providers/openai/v1/chat/completions" {
ccr, ok := e.Request.(*goopenai.ChatCompletionRequest)
if !ok {
From 51807fb8da9cf7c0fb91a1b184dbd025e50a88d9 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Fri, 16 Feb 2024 10:39:09 -0800
Subject: [PATCH 30/71] add new model
---
internal/provider/openai/cost.go | 2 ++
1 file changed, 2 insertions(+)
diff --git a/internal/provider/openai/cost.go b/internal/provider/openai/cost.go
index b645b13..3f36d6f 100644
--- a/internal/provider/openai/cost.go
+++ b/internal/provider/openai/cost.go
@@ -22,6 +22,7 @@ var OpenAiPerThousandTokenCost = map[string]map[string]float64{
"gpt-4-32k-0314": 0.06,
"gpt-3.5-turbo": 0.0015,
"gpt-3.5-turbo-1106": 0.001,
+ "gpt-3.5-turbo-0125": 0.0005,
"gpt-3.5-turbo-0301": 0.0015,
"gpt-3.5-turbo-instruct": 0.0015,
"gpt-3.5-turbo-0613": 0.0015,
@@ -67,6 +68,7 @@ var OpenAiPerThousandTokenCost = map[string]map[string]float64{
"gpt-4-32k-0613": 0.12,
"gpt-4-32k-0314": 0.12,
"gpt-3.5-turbo": 0.002,
+ "gpt-3.5-turbo-0125": 0.0015,
"gpt-3.5-turbo-0301": 0.002,
"gpt-3.5-turbo-0613": 0.002,
"gpt-3.5-turbo-instruct": 0.002,
From 1da3a82ccb6d515f8db525f6c6cec52ece0d1447 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Fri, 16 Feb 2024 10:52:15 -0800
Subject: [PATCH 31/71] update CHANGELOG
---
CHANGELOG.md | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index e238d00..b710811 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,8 @@
+## 1.9.4 - 2024-02-16
+### Added
+- Added support for calculating cost for the cheaper 3.5 turbo model
+- Added validation for updating the revoked field of a key
+
## 1.9.3 - 2024-02-13
### Added
- Added CORS support in the proxy
From 9b988e1704986fac2e21dabe66a12161f965c055 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Sun, 18 Feb 2024 21:51:30 -0800
Subject: [PATCH 32/71] update cost map
---
internal/provider/openai/cost.go | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/internal/provider/openai/cost.go b/internal/provider/openai/cost.go
index 3f36d6f..b0e6eac 100644
--- a/internal/provider/openai/cost.go
+++ b/internal/provider/openai/cost.go
@@ -12,8 +12,10 @@ import (
var OpenAiPerThousandTokenCost = map[string]map[string]float64{
"prompt": {
"gpt-4-1106-preview": 0.01,
+ "gpt-4-turbo-preview": 0.01,
"gpt-4-0125-preview": 0.01,
"gpt-4-1106-vision-preview": 0.01,
+ "gpt-4-vision-preview": 0.01,
"gpt-4": 0.03,
"gpt-4-0314": 0.03,
"gpt-4-0613": 0.03,
@@ -58,9 +60,11 @@ var OpenAiPerThousandTokenCost = map[string]map[string]float64{
},
"completion": {
"gpt-3.5-turbo-1106": 0.002,
+ "gpt-4-turbo-preview": 0.03,
"gpt-4-1106-preview": 0.03,
"gpt-4-0125-preview": 0.03,
"gpt-4-1106-vision-preview": 0.03,
+ "gpt-4-vision-preview": 0.03,
"gpt-4": 0.06,
"gpt-4-0314": 0.06,
"gpt-4-0613": 0.06,
From b82206e8cc00549a0a332a0703055d07bceaad43 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Sun, 18 Feb 2024 22:02:24 -0800
Subject: [PATCH 33/71] support updating key limits
---
cmd/bricksllm/main.go | 2 +-
internal/key/key.go | 53 +++++++++++++++++++----
internal/manager/key.go | 45 +++++++++++++++++--
internal/storage/postgresql/postgresql.go | 42 ++++++++++++++++--
internal/storage/redis/access-cache.go | 11 +++++
internal/storage/redis/cache.go | 11 +++++
6 files changed, 149 insertions(+), 15 deletions(-)
diff --git a/cmd/bricksllm/main.go b/cmd/bricksllm/main.go
index 69d3354..9cfd0f5 100644
--- a/cmd/bricksllm/main.go
+++ b/cmd/bricksllm/main.go
@@ -190,7 +190,7 @@ func main() {
apiCache := redisStorage.NewCache(apiRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
accessCache := redisStorage.NewAccessCache(accessRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
- m := manager.NewManager(store)
+ m := manager.NewManager(store, costLimitCache, rateLimitCache, accessCache)
krm := manager.NewReportingManager(costStorage, store, store)
psm := manager.NewProviderSettingsManager(store, psMemStore)
cpm := manager.NewCustomProvidersManager(store, cpMemStore)
diff --git a/internal/key/key.go b/internal/key/key.go
index 3841184..81319e3 100644
--- a/internal/key/key.go
+++ b/internal/key/key.go
@@ -12,14 +12,19 @@ import (
const RevokedReasonExpired string = "expired"
type UpdateKey struct {
- Name string `json:"name"`
- UpdatedAt int64 `json:"updatedAt"`
- Tags []string `json:"tags"`
- Revoked *bool `json:"revoked"`
- RevokedReason string `json:"revokedReason"`
- SettingId string `json:"settingId"`
- SettingIds []string `json:"settingIds"`
- AllowedPaths *[]PathConfig `json:"allowedPaths,omitempty"`
+ Name string `json:"name"`
+ UpdatedAt int64 `json:"updatedAt"`
+ Tags []string `json:"tags"`
+ Revoked *bool `json:"revoked"`
+ RevokedReason string `json:"revokedReason"`
+ SettingId string `json:"settingId"`
+ SettingIds []string `json:"settingIds"`
+ CostLimitInUsd float64 `json:"costLimitInUsd"`
+ CostLimitInUsdOverTime float64 `json:"costLimitInUsdOverTime"`
+ CostLimitInUsdUnit TimeUnit `json:"costLimitInUsdUnit"`
+ RateLimitOverTime int `json:"rateLimitOverTime"`
+ RateLimitUnit TimeUnit `json:"rateLimitUnit"`
+ AllowedPaths *[]PathConfig `json:"allowedPaths,omitempty"`
}
func (uk *UpdateKey) Validate() error {
@@ -36,6 +41,10 @@ func (uk *UpdateKey) Validate() error {
}
}
+ if uk.CostLimitInUsd < 0 {
+ invalid = append(invalid, "costLimitInUsd")
+ }
+
if uk.UpdatedAt <= 0 {
invalid = append(invalid, "updatedAt")
}
@@ -66,6 +75,34 @@ func (uk *UpdateKey) Validate() error {
return internal_errors.NewValidationError(fmt.Sprintf("fields [%s] are invalid", strings.Join(invalid, ", ")))
}
+ if len(uk.RateLimitUnit) != 0 && uk.RateLimitOverTime == 0 {
+ return internal_errors.NewValidationError("rate limit over time can not be empty if rate limit unit is specified")
+ }
+
+ if len(uk.CostLimitInUsdUnit) != 0 && uk.CostLimitInUsdOverTime == 0 {
+ return internal_errors.NewValidationError("cost limit over time can not be empty if cost limit unit is specified")
+ }
+
+ if uk.RateLimitOverTime != 0 {
+ if len(uk.RateLimitUnit) == 0 {
+ return internal_errors.NewValidationError("rate limit unit can not be empty if rate limit over time is specified")
+ }
+
+ if uk.RateLimitUnit != HourTimeUnit && uk.RateLimitUnit != MinuteTimeUnit && uk.RateLimitUnit != SecondTimeUnit && uk.RateLimitUnit != DayTimeUnit {
+ return internal_errors.NewValidationError("rate limit unit can not be identified")
+ }
+ }
+
+ if uk.CostLimitInUsdOverTime != 0 {
+ if len(uk.CostLimitInUsdUnit) == 0 {
+ return internal_errors.NewValidationError("cost limit unit can not be empty if cost limit over time is specified")
+ }
+
+ if uk.CostLimitInUsdUnit != DayTimeUnit && uk.CostLimitInUsdUnit != HourTimeUnit && uk.CostLimitInUsdUnit != MonthTimeUnit && uk.CostLimitInUsdUnit != MinuteTimeUnit {
+ return internal_errors.NewValidationError("cost limit unit can not be identified")
+ }
+ }
+
return nil
}
diff --git a/internal/manager/key.go b/internal/manager/key.go
index 884cd8e..0bf1f90 100644
--- a/internal/manager/key.go
+++ b/internal/manager/key.go
@@ -22,17 +22,35 @@ type Storage interface {
GetKey(keyId string) (*key.ResponseKey, error)
}
+type costLimitCache interface {
+ Delete(keyId string) error
+}
+
+type rateLimitCache interface {
+ Delete(keyId string) error
+}
+
+type accessCache interface {
+ Delete(keyId string) error
+}
+
type Encrypter interface {
Encrypt(secret string) string
}
type Manager struct {
- s Storage
+ s Storage
+ clc costLimitCache
+ rlc rateLimitCache
+ ac accessCache
}
-func NewManager(s Storage) *Manager {
+func NewManager(s Storage, clc costLimitCache, rlc rateLimitCache, ac accessCache) *Manager {
return &Manager{
- s: s,
+ s: s,
+ clc: clc,
+ rlc: rlc,
+ ac: ac,
}
}
@@ -126,6 +144,27 @@ func (m *Manager) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, err
}
}
+ if len(uk.CostLimitInUsdUnit) != 0 {
+ err := m.clc.Delete(id)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if len(uk.RateLimitUnit) != 0 {
+ err := m.rlc.Delete(id)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if len(uk.CostLimitInUsdUnit) != 0 || len(uk.RateLimitUnit) != 0 {
+ err := m.ac.Delete(id)
+ if err != nil {
+ return nil, err
+ }
+ }
+
return m.s.UpdateKey(id, uk)
}
diff --git a/internal/storage/postgresql/postgresql.go b/internal/storage/postgresql/postgresql.go
index 1a1467b..87542ad 100644
--- a/internal/storage/postgresql/postgresql.go
+++ b/internal/storage/postgresql/postgresql.go
@@ -962,14 +962,50 @@ func (s *Store) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error
}
if uk.Revoked != nil {
+ if *uk.Revoked && len(uk.RevokedReason) != 0 {
+ values = append(values, uk.RevokedReason)
+ fields = append(fields, fmt.Sprintf("revoked_reason = $%d", counter))
+ counter++
+ }
+
+ if !*uk.Revoked {
+ values = append(values, "")
+ fields = append(fields, fmt.Sprintf("revoked_reason = $%d", counter))
+ counter++
+ }
+
values = append(values, uk.Revoked)
fields = append(fields, fmt.Sprintf("revoked = $%d", counter))
counter++
}
- if len(uk.RevokedReason) != 0 {
- values = append(values, uk.RevokedReason)
- fields = append(fields, fmt.Sprintf("revoked_reason = $%d", counter))
+ if uk.CostLimitInUsd != 0 {
+ values = append(values, uk.CostLimitInUsd)
+ fields = append(fields, fmt.Sprintf("cost_limit_in_usd = $%d", counter))
+ counter++
+ }
+
+ if uk.CostLimitInUsdOverTime != 0 {
+ values = append(values, uk.CostLimitInUsdOverTime)
+ fields = append(fields, fmt.Sprintf("cost_limit_in_usd_over_time = $%d", counter))
+ counter++
+ }
+
+ if len(uk.CostLimitInUsdUnit) != 0 {
+ values = append(values, uk.CostLimitInUsdUnit)
+ fields = append(fields, fmt.Sprintf("cost_limit_in_usd_unit = $%d", counter))
+ counter++
+ }
+
+ if uk.RateLimitOverTime != 0 {
+ values = append(values, uk.RateLimitOverTime)
+ fields = append(fields, fmt.Sprintf("rate_limit_over_time = $%d", counter))
+ counter++
+ }
+
+ if len(uk.RateLimitUnit) != 0 {
+ values = append(values, uk.RateLimitUnit)
+ fields = append(fields, fmt.Sprintf("rate_limit_unit = $%d", counter))
counter++
}
diff --git a/internal/storage/redis/access-cache.go b/internal/storage/redis/access-cache.go
index d8dccef..6faa660 100644
--- a/internal/storage/redis/access-cache.go
+++ b/internal/storage/redis/access-cache.go
@@ -22,6 +22,17 @@ func NewAccessCache(c *redis.Client, wt time.Duration, rt time.Duration) *Access
}
}
+func (ac *AccessCache) Delete(key string) error {
+ ctx, cancel := context.WithTimeout(context.Background(), ac.wt)
+ defer cancel()
+ err := ac.client.Del(ctx, key).Err()
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
func (ac *AccessCache) Set(key string, timeUnit key.TimeUnit) error {
ttl, err := getCounterTtl(timeUnit)
if err != nil {
diff --git a/internal/storage/redis/cache.go b/internal/storage/redis/cache.go
index 73e887d..a60f13b 100644
--- a/internal/storage/redis/cache.go
+++ b/internal/storage/redis/cache.go
@@ -35,6 +35,17 @@ func (c *Cache) Set(key string, value interface{}, ttl time.Duration) error {
return nil
}
+func (c *Cache) Delete(key string) error {
+ ctx, cancel := context.WithTimeout(context.Background(), c.wt)
+ defer cancel()
+ err := c.client.Del(ctx, key).Err()
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
func (c *Cache) GetBytes(key string) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), c.rt)
defer cancel()
From 72c5cb5fa40f744874303677253ceecec2759e27 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Sun, 18 Feb 2024 22:14:45 -0800
Subject: [PATCH 34/71] fixed log
---
internal/message/handler.go | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/internal/message/handler.go b/internal/message/handler.go
index 67de0c8..4bf23ad 100644
--- a/internal/message/handler.go
+++ b/internal/message/handler.go
@@ -340,7 +340,7 @@ func (h *Handler) decorateEvent(m Message) error {
if !ok {
stats.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1)
h.log.Debug("event contains data that cannot be converted to openai completion request", zap.Any("data", m.Data))
- return errors.New("event request data cannot be parsed as oepnai completon request")
+ return errors.New("event request data cannot be parsed as openai completon request")
}
if ccr.Stream {
From ffca578f1fc5c232d3311ff476a5dc34b9322a5f Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Sun, 18 Feb 2024 22:19:06 -0800
Subject: [PATCH 35/71] update CHANGELOG
---
CHANGELOG.md | 11 +++++++++++
1 file changed, 11 insertions(+)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index b710811..ef8522c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,14 @@
+## 1.9.6 - 2024-02-18
+### Added
+- Added support for updating key cost limit and rate limit
+
+### Changed
+- Removed validation for updating the revoked key field
+
+## 1.9.5 - 2024-02-18
+### Added
+- Added new models "gpt-4-turbo-preview" and "gpt-4-vision-preview" to the cost map
+
## 1.9.4 - 2024-02-16
### Added
- Added support for calculating cost for the cheaper 3.5 turbo model
From 14fa60c0745df659ee7d511f90ada0978ee96784 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Sun, 18 Feb 2024 22:19:16 -0800
Subject: [PATCH 36/71] update DOC
---
README.md | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index 7236b82..68b22cc 100644
--- a/README.md
+++ b/README.md
@@ -310,8 +310,13 @@ PathConfig
> | name | optional | `string` | spike's developer key | Name of the API key. |
> | tags | optional | `[]string` | `["org-tag-12345"]` | Identifiers associated with the key. |
> | revoked | optional | `boolean` | `true` | Indicator for whether the key is revoked. |
-> | revokedReason| optional | `string` | The key has expired | Reason for why the key is revoked. |
-> | allowedPaths | optional | `[]PathConfig` | 2d | Pathes allowed for access. |
+> | revokedReason | optional | `string` | The key has expired | Reason for why the key is revoked. |
+> | costLimitInUsd | optional | `float64` | `5.5` | Total spend limit of the API key.
+> | costLimitInUsdOverTime | optional | `float64` | `2` | Total spend within period of time. This field is required if costLimitInUsdUnit is specified. |
+> | costLimitInUsdUnit | optional | `enum` | `d` | Time unit for costLimitInUsdOverTime. Possible values are [`m`, `h`, `d`, `mo`]. |
+> | rateLimitOverTime | optional | `int` | `2` | Rate limit over a period of time. This field is required if rateLimitUnit is specified. |
+> | rateLimitUnit | optional | `string` | `m` | Time unit for rateLimitOverTime. Possible values are [`h`, `m`, `s`, `d`] |
+> | allowedPaths | optional | `[]PathConfig` | `[{ "path": "/api/providers/openai/v1/chat/completions", "method": "POST"}]` | Paths allowed for access. |
##### Error Response
From b8435ed0372db54e5b610aff08d728b29dcc4a81 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Sun, 18 Feb 2024 22:19:48 -0800
Subject: [PATCH 37/71] remove revoked field update validation
---
internal/manager/key.go | 9 ---------
1 file changed, 9 deletions(-)
diff --git a/internal/manager/key.go b/internal/manager/key.go
index 0bf1f90..2427912 100644
--- a/internal/manager/key.go
+++ b/internal/manager/key.go
@@ -114,15 +114,6 @@ func (m *Manager) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, err
return nil, err
}
- existingKey, err := m.s.GetKey(id)
- if err != nil {
- return nil, err
- }
-
- if uk.Revoked != nil && !*uk.Revoked && existingKey.RevokedReason == key.RevokedReasonExpired {
- return nil, internal_errors.NewValidationError("cannot reenable an expired key")
- }
-
if len(uk.SettingId) != 0 {
if _, err := m.s.GetProviderSetting(uk.SettingId); err != nil {
return nil, err
From dcf07a3996ce38ddc74313648f688944504e36fe Mon Sep 17 00:00:00 2001
From: Donovan So
Date: Tue, 20 Feb 2024 19:14:43 -0800
Subject: [PATCH 38/71] add link to bricksai page
---
README.md | 3 +++
1 file changed, 3 insertions(+)
diff --git a/README.md b/README.md
index 68b22cc..e4d8038 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,9 @@
+> [!TIP]
+> A [managed version of BricksLLM](https://www.trybricks.ai?utm_source=github&utm_medium=repo&utm_campaign=bricksllm) is also available! It is production ready, and comes with a dashboard to make interacting with BricksLLM easier. Try us out for free today!
+
**BricksLLM** is a cloud native AI gateway written in Go. Currently, it provide native support for OpenAI, Anthropic and Azure OpenAI. We let you create API keys that have rate limits, cost limits and TTLs. The API keys can be used in both development and production to achieve fine-grained access control that is not provided by any of the foundational model providers. The proxy is designed to be 100% compatible with existing SDKs.
## Features
From 9e719b8421baa2acb6cab75d2530f4489de5e8cb Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 21 Feb 2024 13:26:38 -0800
Subject: [PATCH 39/71] add new fields
---
internal/event/event.go | 3 +++
1 file changed, 3 insertions(+)
diff --git a/internal/event/event.go b/internal/event/event.go
index c0f0152..7f299ab 100644
--- a/internal/event/event.go
+++ b/internal/event/event.go
@@ -15,4 +15,7 @@ type Event struct {
Path string `json:"path"`
Method string `json:"method"`
CustomId string `json:"custom_id"`
+ Request []byte `json:"request"`
+ Response []byte `json:"response"`
+ UserId string `json:"userId"`
}
From c68da7ef2c82905ef602de17dcd93589b6346604 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 21 Feb 2024 13:31:56 -0800
Subject: [PATCH 40/71] add new fields
---
internal/key/key.go | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/internal/key/key.go b/internal/key/key.go
index 81319e3..444df2e 100644
--- a/internal/key/key.go
+++ b/internal/key/key.go
@@ -25,6 +25,8 @@ type UpdateKey struct {
RateLimitOverTime int `json:"rateLimitOverTime"`
RateLimitUnit TimeUnit `json:"rateLimitUnit"`
AllowedPaths *[]PathConfig `json:"allowedPaths,omitempty"`
+ ShouldLogRequest *bool `json:"shouldLogRequest"`
+ ShouldLogResponse *bool `json:"shouldLogResponse"`
}
func (uk *UpdateKey) Validate() error {
@@ -127,6 +129,8 @@ type RequestKey struct {
SettingId string `json:"settingId"`
AllowedPaths []PathConfig `json:"allowedPaths"`
SettingIds []string `json:"settingIds"`
+ ShouldLogRequest bool `json:"shouldLogRequest"`
+ ShouldLogResponse bool `json:"shouldLogResponse"`
}
func (rk *RequestKey) Validate() error {
@@ -271,6 +275,8 @@ type ResponseKey struct {
SettingId string `json:"settingId"`
AllowedPaths []PathConfig `json:"allowedPaths"`
SettingIds []string `json:"settingIds"`
+ ShouldLogRequest bool `json:"shouldLogRequest"`
+ ShouldLogResponse bool `json:"shouldLogResponse"`
}
func (rk *ResponseKey) GetSettingIds() []string {
From be93ad83e386da96c041a07a1a3c4c09685f1a81 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 21 Feb 2024 13:32:36 -0800
Subject: [PATCH 41/71] add userId as a new filter for retrieving events
---
internal/manager/reporting.go | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/internal/manager/reporting.go b/internal/manager/reporting.go
index b78d359..e944ecc 100644
--- a/internal/manager/reporting.go
+++ b/internal/manager/reporting.go
@@ -15,7 +15,7 @@ type keyStorage interface {
}
type eventStorage interface {
- GetEvents(customId string, keyIds []string, start, end int64) ([]*event.Event, error)
+ GetEvents(userId, customId string, keyIds []string, start, end int64) ([]*event.Event, error)
GetEventDataPoints(start, end, increment int64, tags, keyIds, customIds []string, filters []string) ([]*event.DataPoint, error)
GetLatencyPercentiles(start, end int64, tags, keyIds []string) ([]float64, error)
}
@@ -77,8 +77,8 @@ func (rm *ReportingManager) GetKeyReporting(keyId string) (*key.KeyReporting, er
}, err
}
-func (rm *ReportingManager) GetEvents(customId string, keyIds []string, start, end int64) ([]*event.Event, error) {
- events, err := rm.es.GetEvents(customId, keyIds, start, end)
+func (rm *ReportingManager) GetEvents(userId, customId string, keyIds []string, start, end int64) ([]*event.Event, error) {
+ events, err := rm.es.GetEvents(userId, customId, keyIds, start, end)
if err != nil {
return nil, err
}
From 29d599ac5f8177698476602d5516cfd0239e7e0d Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 21 Feb 2024 13:33:06 -0800
Subject: [PATCH 42/71] update logs
---
internal/message/handler.go | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/internal/message/handler.go b/internal/message/handler.go
index 4bf23ad..d39482c 100644
--- a/internal/message/handler.go
+++ b/internal/message/handler.go
@@ -99,7 +99,7 @@ func (h *Handler) HandleEvent(m Message) error {
err := h.recorder.RecordEvent(e)
if err != nil {
stats.Incr("bricksllm.message.handler.handle_event.record_event_error", nil, 1)
- h.log.Sugar().Debugf("error when publishin event: %v", err)
+ h.log.Sugar().Debugf("error when publish in event: %v", err)
return err
}
@@ -243,12 +243,12 @@ func (h *Handler) HandleEventWithRequestAndResponse(m Message) error {
stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.handle_validation_result_error", nil, 1)
h.log.Debug("error when handling validation result", zap.Error(err))
}
-
}
start := time.Now()
err := h.recorder.RecordEvent(e.Event)
if err != nil {
+ h.log.Debug("error when recording event", zap.Error(err))
stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.record_event_error", nil, 1)
return err
}
From 9343c29684c941415adbddd05e07954340e5e9e6 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 21 Feb 2024 13:36:25 -0800
Subject: [PATCH 43/71] update logs
---
internal/message/handler.go | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/internal/message/handler.go b/internal/message/handler.go
index d39482c..0d1b033 100644
--- a/internal/message/handler.go
+++ b/internal/message/handler.go
@@ -248,7 +248,7 @@ func (h *Handler) HandleEventWithRequestAndResponse(m Message) error {
start := time.Now()
err := h.recorder.RecordEvent(e.Event)
if err != nil {
- h.log.Debug("error when recording event", zap.Error(err))
+ h.log.Debug("error when recording an event", zap.Error(err))
stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.record_event_error", nil, 1)
return err
}
From 6eed6790572d83ccce3764b0779376b86f3f6476 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 21 Feb 2024 13:36:44 -0800
Subject: [PATCH 44/71] update interface
---
internal/server/web/admin/admin.go | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/internal/server/web/admin/admin.go b/internal/server/web/admin/admin.go
index 7766e66..ea650b8 100644
--- a/internal/server/web/admin/admin.go
+++ b/internal/server/web/admin/admin.go
@@ -38,7 +38,7 @@ type KeyManager interface {
type KeyReportingManager interface {
GetKeyReporting(keyId string) (*key.KeyReporting, error)
- GetEvents(customId string, keyIds []string, start int64, end int64) ([]*event.Event, error)
+ GetEvents(userId, customId string, keyIds []string, start int64, end int64) ([]*event.Event, error)
GetEventReporting(e *event.ReportingRequest) (*event.ReportingResponse, error)
}
@@ -821,13 +821,14 @@ func getGetEventsHandler(m KeyReportingManager, log *zap.Logger, prod bool) gin.
cid := c.GetString(correlationId)
customId, ciok := c.GetQuery("customId")
+ userId, uiok := c.GetQuery("userId")
keyIds, kiok := c.GetQueryArray("keyIds")
- if !ciok && !kiok {
+ if !ciok && !kiok && !uiok {
c.JSON(http.StatusBadRequest, &ErrorResponse{
Type: "/errors/no-filters-empty",
- Title: "neither customId nor keyIds are specified",
+ Title: "none of customId, keyIds and userId is specified",
Status: http.StatusBadRequest,
- Detail: "both query params customId and keyIds are empty. either of them is required for retrieving events.",
+ Detail: "customId, userId and keyIds are empty. one of them is required for retrieving events.",
Instance: path,
})
@@ -895,7 +896,7 @@ func getGetEventsHandler(m KeyReportingManager, log *zap.Logger, prod bool) gin.
qend = parsedEnd
}
- evs, err := m.GetEvents(customId, keyIds, qstart, qend)
+ evs, err := m.GetEvents(userId, customId, keyIds, qstart, qend)
if err != nil {
stats.Incr("bricksllm.admin.get_get_events_handler.get_events_error", nil, 1)
From 778f14ebf601a5959c7017ff30e14574e267ccde Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 21 Feb 2024 13:37:56 -0800
Subject: [PATCH 45/71] add userId, request, and response to event
---
internal/server/web/proxy/middleware.go | 48 +++++++++++++++++++++++++
1 file changed, 48 insertions(+)
diff --git a/internal/server/web/proxy/middleware.go b/internal/server/web/proxy/middleware.go
index 7558ff4..aafd518 100644
--- a/internal/server/web/proxy/middleware.go
+++ b/internal/server/web/proxy/middleware.go
@@ -131,6 +131,16 @@ func getProvider(c *gin.Context) string {
return ""
}
+type responseWriter struct {
+ gin.ResponseWriter
+ body *bytes.Buffer
+}
+
+func (w responseWriter) Write(b []byte) (int, error) {
+ w.body.Write(b)
+ return w.ResponseWriter.Write(b)
+}
+
func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManager, a authenticator, prod, private bool, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, ks keyStorage, log *zap.Logger, rlm rateLimitManager, pub publisher, prefix string, ac accessCache) gin.HandlerFunc {
return func(c *gin.Context) {
if c == nil || c.Request == nil {
@@ -144,11 +154,17 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
+ blw := &responseWriter{body: bytes.NewBufferString(""), ResponseWriter: c.Writer}
+ c.Writer = blw
+
cid := util.NewUuid()
c.Set(correlationId, cid)
start := time.Now()
enrichedEvent := &event.EventWithRequestAndContent{}
+ requestBytes := []byte(`{}`)
+ responseBytes := []byte(`{}`)
+ userId := ""
customId := c.Request.Header.Get("X-CUSTOM-EVENT-ID")
defer func() {
@@ -202,6 +218,9 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
Path: c.FullPath(),
Method: c.Request.Method,
CustomId: customId,
+ Request: requestBytes,
+ Response: responseBytes,
+ UserId: userId,
}
enrichedEvent.Event = evt
@@ -273,6 +292,10 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
+ if kc.ShouldLogRequest {
+ requestBytes = body
+ }
+
if c.Request.Method != http.MethodGet {
c.Request.Body = io.NopCloser(bytes.NewReader(body))
}
@@ -289,6 +312,10 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
+ if cr.Metadata != nil {
+ userId = cr.Metadata.UserId
+ }
+
enrichedEvent.Request = cr
// tks := ae.Count(cr.Prompt)
@@ -382,6 +409,8 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
+ userId = er.User
+
if rc.CacheConfig != nil && rc.CacheConfig.Enabled {
c.Set("cache_key", route.ComputeCacheKeyForEmbeddingsRequest(r, er))
}
@@ -400,6 +429,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
+ userId = ccr.User
enrichedEvent.Request = ccr
logRequest(log, prod, private, cid, ccr)
@@ -425,6 +455,8 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
+ userId = ccr.User
+
enrichedEvent.Request = ccr
logRequest(log, prod, private, cid, ccr)
@@ -449,6 +481,8 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
+ userId = er.User
+
c.Set("model", "ada")
c.Set("encoding_format", string(er.EncodingFormat))
@@ -469,6 +503,8 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
+ userId = ccr.User
+
enrichedEvent.Request = ccr
c.Set("model", ccr.Model)
@@ -497,6 +533,8 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
+ userId = er.User
+
c.Set("model", string(er.Model))
c.Set("encoding_format", string(er.EncodingFormat))
@@ -532,6 +570,9 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
model := c.PostForm("model")
size := c.PostForm("size")
user := c.PostForm("user")
+
+ userId = user
+
responseFormat := c.PostForm("response_format")
n, _ := strconv.Atoi(c.PostForm("n"))
@@ -548,6 +589,9 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
model := c.PostForm("model")
size := c.PostForm("size")
user := c.PostForm("user")
+
+ userId = user
+
responseFormat := c.PostForm("response_format")
n, _ := strconv.Atoi(c.PostForm("n"))
@@ -795,6 +839,10 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
c.Next()
+
+ if kc.ShouldLogResponse {
+ responseBytes = blw.body.Bytes()
+ }
}
}
From 56eeb0929b254e827be5c6750fd00bf65c5b4948 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 21 Feb 2024 13:39:00 -0800
Subject: [PATCH 46/71] integrate new fields
---
internal/storage/postgresql/postgresql.go | 80 +++++++++++++++++++----
1 file changed, 69 insertions(+), 11 deletions(-)
diff --git a/internal/storage/postgresql/postgresql.go b/internal/storage/postgresql/postgresql.go
index 87542ad..8e35df1 100644
--- a/internal/storage/postgresql/postgresql.go
+++ b/internal/storage/postgresql/postgresql.go
@@ -101,7 +101,7 @@ func (s *Store) AlterKeysTable() error {
END IF;
END
$$;
- ALTER TABLE keys ADD COLUMN IF NOT EXISTS setting_id VARCHAR(255), ADD COLUMN IF NOT EXISTS allowed_paths JSONB, ADD COLUMN IF NOT EXISTS setting_ids VARCHAR(255)[] NOT NULL DEFAULT ARRAY[]::VARCHAR(255)[];
+ ALTER TABLE keys ADD COLUMN IF NOT EXISTS setting_id VARCHAR(255), ADD COLUMN IF NOT EXISTS allowed_paths JSONB, ADD COLUMN IF NOT EXISTS setting_ids VARCHAR(255)[] NOT NULL DEFAULT ARRAY[]::VARCHAR(255)[], ADD COLUMN IF NOT EXISTS should_log_request BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS should_log_response BOOLEAN NOT NULL DEFAULT FALSE;
`
ctxTimeout, cancel := context.WithTimeout(context.Background(), s.wt)
@@ -142,7 +142,7 @@ func (s *Store) CreateEventsTable() error {
func (s *Store) AlterEventsTable() error {
alterTableQuery := `
- ALTER TABLE events ADD COLUMN IF NOT EXISTS path VARCHAR(255), ADD COLUMN IF NOT EXISTS method VARCHAR(255), ADD COLUMN IF NOT EXISTS custom_id VARCHAR(255)
+ ALTER TABLE events ADD COLUMN IF NOT EXISTS path VARCHAR(255), ADD COLUMN IF NOT EXISTS method VARCHAR(255), ADD COLUMN IF NOT EXISTS custom_id VARCHAR(255), ADD COLUMN IF NOT EXISTS request JSONB, ADD COLUMN IF NOT EXISTS response JSONB, ADD COLUMN IF NOT EXISTS user_id VARCHAR(255) NOT NULL DEFAULT '';
`
ctxTimeout, cancel := context.WithTimeout(context.Background(), s.wt)
@@ -211,8 +211,8 @@ func (s *Store) DropKeysTable() error {
func (s *Store) InsertEvent(e *event.Event) error {
query := `
- INSERT INTO events (event_id, created_at, tags, key_id, cost_in_usd, provider, model, status_code, prompt_token_count, completion_token_count, latency_in_ms, path, method, custom_id)
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
+ INSERT INTO events (event_id, created_at, tags, key_id, cost_in_usd, provider, model, status_code, prompt_token_count, completion_token_count, latency_in_ms, path, method, custom_id, request, response, user_id)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
`
values := []any{
@@ -230,6 +230,9 @@ func (s *Store) InsertEvent(e *event.Event) error {
e.Path,
e.Method,
e.CustomId,
+ e.Request,
+ e.Response,
+ e.UserId,
}
ctx, cancel := context.WithTimeout(context.Background(), s.wt)
@@ -241,12 +244,30 @@ func (s *Store) InsertEvent(e *event.Event) error {
return nil
}
-func (s *Store) GetEvents(customId string, keyIds []string, start int64, end int64) ([]*event.Event, error) {
- if len(customId) == 0 && len(keyIds) == 0 {
- return nil, errors.New("neither customId nor keyIds are specified")
+func shouldAddAnd(userId, customId string, keyIds []string) bool {
+ num := 0
+
+ if len(userId) == 0 {
+ num++
+ }
+
+ if len(customId) == 0 {
+ num++
+ }
+
+ if len(keyIds) == 0 {
+ num++
}
- if len(keyIds) == 0 && (start == 0 || end == 0) {
+ return num >= 2
+}
+
+func (s *Store) GetEvents(userId string, customId string, keyIds []string, start int64, end int64) ([]*event.Event, error) {
+ if len(customId) == 0 && len(keyIds) == 0 && len(userId) == 0 {
+ return nil, errors.New("none of customId, keyIds and userId is specified")
+ }
+
+ if len(keyIds) != 0 && (start == 0 || end == 0) {
return nil, errors.New("keyIds are provided but either start or end is not specified")
}
@@ -258,7 +279,15 @@ func (s *Store) GetEvents(customId string, keyIds []string, start int64, end int
query += fmt.Sprintf(" custom_id = '%s'", customId)
}
- if len(customId) != 0 && len(keyIds) != 0 {
+ if len(customId) > 0 && len(userId) > 0 {
+ query += " AND"
+ }
+
+ if len(userId) != 0 {
+ query += fmt.Sprintf(" user_id = '%s'", userId)
+ }
+
+ if (len(customId) > 0 || len(userId) > 0) && len(keyIds) > 0 {
query += " AND"
}
@@ -300,6 +329,9 @@ func (s *Store) GetEvents(customId string, keyIds []string, start int64, end int
&path,
&method,
&customId,
+ &e.Request,
+ &e.Response,
+ &e.UserId,
); err != nil {
return nil, err
}
@@ -583,6 +615,8 @@ func (s *Store) GetKeys(tags, keyIds []string, provider string) ([]*key.Response
&settingId,
&data,
pq.Array(&k.SettingIds),
+ &k.ShouldLogRequest,
+ &k.ShouldLogResponse,
); err != nil {
return nil, err
}
@@ -639,6 +673,8 @@ func (s *Store) GetKey(keyId string) (*key.ResponseKey, error) {
&settingId,
&data,
pq.Array(&k.SettingIds),
+ &k.ShouldLogRequest,
+ &k.ShouldLogResponse,
); err != nil {
return nil, err
}
@@ -780,6 +816,8 @@ func (s *Store) GetAllKeys() ([]*key.ResponseKey, error) {
&settingId,
&data,
pq.Array(&k.SettingIds),
+ &k.ShouldLogRequest,
+ &k.ShouldLogResponse,
); err != nil {
return nil, err
}
@@ -874,6 +912,8 @@ func (s *Store) GetUpdatedKeys(updatedAt int64) ([]*key.ResponseKey, error) {
&settingId,
&data,
pq.Array(&k.SettingIds),
+ &k.ShouldLogRequest,
+ &k.ShouldLogResponse,
); err != nil {
return nil, err
}
@@ -1021,6 +1061,18 @@ func (s *Store) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error
counter++
}
+ if uk.ShouldLogRequest != nil {
+ values = append(values, *uk.ShouldLogRequest)
+ fields = append(fields, fmt.Sprintf("should_log_request = $%d", counter))
+ counter++
+ }
+
+ if uk.ShouldLogResponse != nil {
+ values = append(values, *uk.ShouldLogResponse)
+ fields = append(fields, fmt.Sprintf("should_log_response = $%d", counter))
+ counter++
+ }
+
if uk.AllowedPaths != nil {
data, err := json.Marshal(uk.AllowedPaths)
if err != nil {
@@ -1057,6 +1109,8 @@ func (s *Store) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error
&settingId,
&data,
pq.Array(&k.SettingIds),
+ &k.ShouldLogRequest,
+ &k.ShouldLogResponse,
); err != nil {
if err == sql.ErrNoRows {
return nil, internal_errors.NewNotFoundError(fmt.Sprintf("key not found for id: %s", id))
@@ -1196,8 +1250,8 @@ func (s *Store) CreateProviderSetting(setting *provider.Setting) (*provider.Sett
func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) {
query := `
- INSERT INTO keys (name, created_at, updated_at, tags, revoked, key_id, key, revoked_reason, cost_limit_in_usd, cost_limit_in_usd_over_time, cost_limit_in_usd_unit, rate_limit_over_time, rate_limit_unit, ttl, setting_id, allowed_paths, setting_ids)
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
+ INSERT INTO keys (name, created_at, updated_at, tags, revoked, key_id, key, revoked_reason, cost_limit_in_usd, cost_limit_in_usd_over_time, cost_limit_in_usd_unit, rate_limit_over_time, rate_limit_unit, ttl, setting_id, allowed_paths, setting_ids, should_log_request, should_log_response)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
RETURNING *;
`
@@ -1224,6 +1278,8 @@ func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) {
rk.SettingId,
rdata,
sliceToSqlStringArray(rk.SettingIds),
+ rk.ShouldLogRequest,
+ rk.ShouldLogResponse,
}
ctxTimeout, cancel := context.WithTimeout(context.Background(), s.wt)
@@ -1251,6 +1307,8 @@ func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) {
&settingId,
&data,
pq.Array(&k.SettingIds),
+ &k.ShouldLogRequest,
+ &k.ShouldLogResponse,
); err != nil {
return nil, err
}
From 16f77114929dce5fbca178c104c16aff8fd3cdce Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 21 Feb 2024 13:46:08 -0800
Subject: [PATCH 47/71] update CHANGELOG
---
CHANGELOG.md | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index ef8522c..cd95c17 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,8 @@
+## 1.10.0 - 2024-02-21
+### Added
+- Added `userId` as a new filter option for get events API endpoint
+- Added option to store request and response using keys
+
## 1.9.6 - 2024-02-18
### Added
- Added support for updating key cost limit and rate limit
From 4a25469aae514e9e1348a3858c193c4b836476cf Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 21 Feb 2024 13:59:23 -0800
Subject: [PATCH 48/71] added userId as a querying option for data points
---
internal/event/reporting.go | 2 ++
internal/manager/reporting.go | 4 ++--
internal/storage/postgresql/postgresql.go | 17 ++++++++++++++++-
3 files changed, 20 insertions(+), 3 deletions(-)
diff --git a/internal/event/reporting.go b/internal/event/reporting.go
index 5fab273..91dbf6c 100644
--- a/internal/event/reporting.go
+++ b/internal/event/reporting.go
@@ -11,6 +11,7 @@ type DataPoint struct {
Model string `json:"model"`
KeyId string `json:"keyId"`
CustomId string `json:"customId"`
+ UserId string `json:"userId"`
}
type ReportingResponse struct {
@@ -23,6 +24,7 @@ type ReportingRequest struct {
KeyIds []string `json:"keyIds"`
Tags []string `json:"tags"`
CustomIds []string `json:"customIds"`
+ UserIds []string `json:"userIds"`
Start int64 `json:"start"`
End int64 `json:"end"`
Increment int64 `json:"increment"`
diff --git a/internal/manager/reporting.go b/internal/manager/reporting.go
index e944ecc..ef9f4d6 100644
--- a/internal/manager/reporting.go
+++ b/internal/manager/reporting.go
@@ -16,7 +16,7 @@ type keyStorage interface {
type eventStorage interface {
GetEvents(userId, customId string, keyIds []string, start, end int64) ([]*event.Event, error)
- GetEventDataPoints(start, end, increment int64, tags, keyIds, customIds []string, filters []string) ([]*event.DataPoint, error)
+ GetEventDataPoints(start, end, increment int64, tags, keyIds, customIds, userIds []string, filters []string) ([]*event.DataPoint, error)
GetLatencyPercentiles(start, end int64, tags, keyIds []string) ([]float64, error)
}
@@ -35,7 +35,7 @@ func NewReportingManager(cs costStorage, ks keyStorage, es eventStorage) *Report
}
func (rm *ReportingManager) GetEventReporting(e *event.ReportingRequest) (*event.ReportingResponse, error) {
- dataPoints, err := rm.es.GetEventDataPoints(e.Start, e.End, e.Increment, e.Tags, e.KeyIds, e.CustomIds, e.Filters)
+ dataPoints, err := rm.es.GetEventDataPoints(e.Start, e.End, e.Increment, e.Tags, e.KeyIds, e.CustomIds, e.UserIds, e.Filters)
if err != nil {
return nil, err
}
diff --git a/internal/storage/postgresql/postgresql.go b/internal/storage/postgresql/postgresql.go
index 8e35df1..817c446 100644
--- a/internal/storage/postgresql/postgresql.go
+++ b/internal/storage/postgresql/postgresql.go
@@ -408,7 +408,7 @@ func (s *Store) GetLatencyPercentiles(start, end int64, tags, keyIds []string) (
return data, nil
}
-func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds, customIds []string, filters []string) ([]*event.DataPoint, error) {
+func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds, customIds, userIds []string, filters []string) ([]*event.DataPoint, error) {
groupByQuery := "GROUP BY time_series_table.series"
selectQuery := "SELECT series AS time_stamp, COALESCE(COUNT(events_table.event_id),0) AS num_of_requests, COALESCE(SUM(events_table.cost_in_usd),0) AS cost_in_usd, COALESCE(SUM(events_table.latency_in_ms),0) AS latency_in_ms, COALESCE(SUM(events_table.prompt_token_count),0) AS prompt_token_count, COALESCE(SUM(events_table.completion_token_count),0) AS completion_token_count, COALESCE(SUM(CASE WHEN status_code = 200 THEN 1 END),0) AS success_count"
@@ -428,6 +428,11 @@ func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds, cu
groupByQuery += ",events_table.custom_id"
selectQuery += ",events_table.custom_id as customId"
}
+
+ if filter == "userId" {
+ groupByQuery += ",events_table.user_id"
+ selectQuery += ",events_table.user_id as userId"
+ }
}
}
@@ -467,6 +472,10 @@ func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds, cu
conditionBlock += fmt.Sprintf("AND custom_id = ANY('%s')", sliceToSqlStringArray(customIds))
}
+ if len(userIds) != 0 {
+ conditionBlock += fmt.Sprintf("AND user_id = ANY('%s')", sliceToSqlStringArray(userIds))
+ }
+
eventSelectionBlock += conditionBlock
eventSelectionBlock += ")"
@@ -487,6 +496,7 @@ func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds, cu
var model sql.NullString
var keyId sql.NullString
var customId sql.NullString
+ var userId sql.NullString
additional := []any{
&e.TimeStamp,
@@ -511,6 +521,10 @@ func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds, cu
if filter == "customId" {
additional = append(additional, &customId)
}
+
+ if filter == "userId" {
+ additional = append(additional, &userId)
+ }
}
}
@@ -524,6 +538,7 @@ func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds, cu
pe.Model = model.String
pe.KeyId = keyId.String
pe.CustomId = customId.String
+ pe.UserId = userId.String
data = append(data, pe)
}
From cd97e692f9a44078e327c06288bc0b2551098983 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Mon, 26 Feb 2024 16:10:59 -0800
Subject: [PATCH 49/71] add new package
---
go.mod | 3 ++
go.sum | 13 +++++++
internal/provider/openai/fetcher.go | 1 +
internal/storage/redis/thread-cache.go | 53 ++++++++++++++++++++++++++
4 files changed, 70 insertions(+)
create mode 100644 internal/provider/openai/fetcher.go
create mode 100644 internal/storage/redis/thread-cache.go
diff --git a/go.mod b/go.mod
index 9c500a9..8f06052 100644
--- a/go.mod
+++ b/go.mod
@@ -19,6 +19,9 @@ require (
require (
github.com/Microsoft/go-winio v0.5.0 // indirect
+ github.com/asticode/go-astikit v0.20.0 // indirect
+ github.com/asticode/go-astisub v0.26.2 // indirect
+ github.com/asticode/go-astits v1.8.0 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cenkalti/backoff/v4 v4.2.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
diff --git a/go.sum b/go.sum
index b71b7c1..ff47f27 100644
--- a/go.sum
+++ b/go.sum
@@ -2,6 +2,12 @@ github.com/DataDog/datadog-go/v5 v5.3.0 h1:2q2qjFOb3RwAZNU+ez27ZVDwErJv5/VpbBPpr
github.com/DataDog/datadog-go/v5 v5.3.0/go.mod h1:XRDJk1pTc00gm+ZDiBKsjh7oOOtJfYfglVCmFb8C2+Q=
github.com/Microsoft/go-winio v0.5.0 h1:Elr9Wn+sGKPlkaBvwu4mTrxtmOp3F3yV9qhaHbXGjwU=
github.com/Microsoft/go-winio v0.5.0/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84=
+github.com/asticode/go-astikit v0.20.0 h1:+7N+J4E4lWx2QOkRdOf6DafWJMv6O4RRfgClwQokrH8=
+github.com/asticode/go-astikit v0.20.0/go.mod h1:h4ly7idim1tNhaVkdVBeXQZEE3L0xblP7fCWbgwipF0=
+github.com/asticode/go-astisub v0.26.2 h1:cdEXcm+SUSmYCEPTQYbbfCECnmQoIFfH6pF8wDJhfVo=
+github.com/asticode/go-astisub v0.26.2/go.mod h1:WTkuSzFB+Bp7wezuSf2Oxulj5A8zu2zLRVFf6bIFQK8=
+github.com/asticode/go-astits v1.8.0 h1:rf6aiiGn/QhlFjNON1n5plqF3Fs025XLUwiQ0NB6oZg=
+github.com/asticode/go-astits v1.8.0/go.mod h1:DkOWmBNQpnr9mv24KfZjq4JawCFX1FCqjLVGvO0DygQ=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
github.com/bsm/ginkgo/v2 v2.7.0 h1:ItPMPH90RbmZJt5GtkcNvIRuGEdwlBItdNVoyzaNQao=
github.com/bsm/gomega v1.26.0 h1:LhQm+AFcgV2M0WyKroMASzAzCAJVpAxQXv4SaI9a69Y=
@@ -71,6 +77,7 @@ github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZ
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pkg/profile v1.4.0/go.mod h1:NWz/XGvpEW1FyYQ7fCx4dqYBLlfTcE+A9FLAkNKqjFE=
github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4=
github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
@@ -94,6 +101,7 @@ github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
+github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
@@ -126,11 +134,13 @@ golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
@@ -139,6 +149,7 @@ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -150,6 +161,7 @@ golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
@@ -166,6 +178,7 @@ google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cn
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/internal/provider/openai/fetcher.go b/internal/provider/openai/fetcher.go
new file mode 100644
index 0000000..0aac709
--- /dev/null
+++ b/internal/provider/openai/fetcher.go
@@ -0,0 +1 @@
+package openai
diff --git a/internal/storage/redis/thread-cache.go b/internal/storage/redis/thread-cache.go
new file mode 100644
index 0000000..eb780cf
--- /dev/null
+++ b/internal/storage/redis/thread-cache.go
@@ -0,0 +1,51 @@
+package redis
+
+import (
+	"context"
+	"time"
+
+	"github.com/redis/go-redis/v9"
+)
+
+// ThreadCache tracks thread activity flags in Redis, using separate
+// timeouts for write (wt) and read (rt) operations.
+type ThreadCache struct {
+	client *redis.Client
+	wt     time.Duration
+	rt     time.Duration
+}
+
+// NewThreadCache creates a ThreadCache backed by the given Redis client.
+func NewThreadCache(c *redis.Client, wt time.Duration, rt time.Duration) *ThreadCache {
+	return &ThreadCache{
+		client: c,
+		wt:     wt,
+		rt:     rt,
+	}
+}
+
+// Delete removes key from the cache.
+func (ac *ThreadCache) Delete(key string) error {
+	ctx, cancel := context.WithTimeout(context.Background(), ac.wt)
+	defer cancel()
+
+	return ac.client.Del(ctx, key).Err()
+}
+
+// Set marks key as present for the duration dur.
+func (ac *ThreadCache) Set(key string, dur time.Duration) error {
+	ctx, cancel := context.WithTimeout(context.Background(), ac.wt)
+	defer cancel()
+
+	return ac.client.Set(ctx, key, true, dur).Err()
+}
+
+// GetThreadStatus reports whether key is present in the cache. Only a
+// successful lookup counts as present: transient Redis errors (timeouts,
+// connection failures) report false rather than true, which the previous
+// check (result.Err() != redis.Nil) got wrong.
+func (ac *ThreadCache) GetThreadStatus(key string) bool {
+	ctx, cancel := context.WithTimeout(context.Background(), ac.rt)
+	defer cancel()
+
+	return ac.client.Get(ctx, key).Err() == nil
+}
From 27df0d2edd18418e0f3d62b4dd7a0ab0b807d108 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Mon, 26 Feb 2024 16:11:35 -0800
Subject: [PATCH 50/71] use time since
---
internal/server/web/proxy/route.go | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/internal/server/web/proxy/route.go b/internal/server/web/proxy/route.go
index 90c46af..94e546d 100644
--- a/internal/server/web/proxy/route.go
+++ b/internal/server/web/proxy/route.go
@@ -64,7 +64,7 @@ func getRouteHandler(prod, private bool, rm routeManager, ca cache, aoe azureEst
bytes, err := ca.GetBytes(cacheKey)
if err == nil && len(bytes) != 0 {
stats.Incr("bricksllm.proxy.get_route_handeler.success", nil, 1)
- stats.Timing("bricksllm.proxy.get_route_handeler.success_latency", time.Now().Sub(trueStart), nil, 1)
+ stats.Timing("bricksllm.proxy.get_route_handeler.success_latency", time.Since(trueStart), nil, 1)
c.Set("provider", "cached")
c.Data(http.StatusOK, "application/json", bytes)
@@ -115,7 +115,7 @@ func getRouteHandler(prod, private bool, rm routeManager, ca cache, aoe azureEst
res := runRes.Response
defer res.Body.Close()
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_route_handeler.latency", dur, nil, 1)
bytes, err := io.ReadAll(res.Body)
From e728fad158257211c68b3a350d61d27819ab2800 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Mon, 26 Feb 2024 16:37:52 -0800
Subject: [PATCH 51/71] add cost tracking for speech api
---
internal/message/handler.go | 33 ++++++++++++++++++++++++++++-----
1 file changed, 28 insertions(+), 5 deletions(-)
diff --git a/internal/message/handler.go b/internal/message/handler.go
index 0d1b033..5ceea1e 100644
--- a/internal/message/handler.go
+++ b/internal/message/handler.go
@@ -2,6 +2,7 @@ package message
import (
"errors"
+ "net/http"
"strings"
"time"
@@ -24,6 +25,7 @@ type anthropicEstimator interface {
}
type estimator interface {
+ EstimateSpeechCost(input string, model string) (float64, error)
EstimateChatCompletionPromptCostWithTokenCounts(r *goopenai.ChatCompletionRequest) (int, float64, error)
EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (float64, error)
EstimateChatCompletionStreamCostWithTokenCounts(model, content string) (int, float64, error)
@@ -103,7 +105,7 @@ func (h *Handler) HandleEvent(m Message) error {
return err
}
- stats.Timing("bricksllm.message.handler.handle_event.record_event_latency", time.Now().Sub(start), nil, 1)
+ stats.Timing("bricksllm.message.handler.handle_event.record_event_latency", time.Since(start), nil, 1)
stats.Incr("bricksllm.message.handler.handle_event.success", nil, 1)
return nil
@@ -253,7 +255,7 @@ func (h *Handler) HandleEventWithRequestAndResponse(m Message) error {
return err
}
- stats.Timing("bricksllm.message.handler.handle_event_with_request_and_response.latency", time.Now().Sub(start), nil, 1)
+ stats.Timing("bricksllm.message.handler.handle_event_with_request_and_response.latency", time.Since(start), nil, 1)
stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.success", nil, 1)
return nil
@@ -269,7 +271,27 @@ func (h *Handler) decorateEvent(m Message) error {
return errors.New("message data cannot be parsed as event with request and response")
}
- if e.Event.Provider == "anthropic" && e.Event.Path == "/api/providers/anthropic/v1/complete" {
+	if e.Event.Path == "/api/providers/openai/v1/audio/speech" {
+		csr, ok := e.Request.(*goopenai.CreateSpeechRequest)
+		if !ok {
+			stats.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1)
+			h.log.Debug("event contains request that cannot be converted to openai create speech request", zap.Any("data", m.Data))
+			return errors.New("event request data cannot be parsed as openai create speech request")
+		}
+
+		if e.Event.Status == http.StatusOK {
+			cost, err := h.e.EstimateSpeechCost(csr.Input, string(csr.Model))
+			if err != nil {
+				stats.Incr("bricksllm.message.handler.decorate_event.estimate_prompt_cost", nil, 1)
+				h.log.Debug("error when estimating openai create speech cost", zap.Error(err))
+				return err
+			}
+
+			e.Event.CostInUsd = cost
+		}
+	}
+
+ if e.Event.Path == "/api/providers/anthropic/v1/complete" {
cr, ok := e.Request.(*anthropic.CompletionRequest)
if !ok {
stats.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1)
@@ -298,11 +320,12 @@ func (h *Handler) decorateEvent(m Message) error {
}
e.Event.PromptTokenCount = tks
+
e.Event.CompletionTokenCount = completiontks
e.Event.CostInUsd = completionCost + cost
}
- if e.Event.Provider == "azure" && e.Event.Path == "/api/providers/azure/openai/deployments/:deployment_id/chat/completions" {
+ if e.Event.Path == "/api/providers/azure/openai/deployments/:deployment_id/chat/completions" {
ccr, ok := e.Request.(*goopenai.ChatCompletionRequest)
if !ok {
stats.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1)
@@ -335,7 +358,7 @@ func (h *Handler) decorateEvent(m Message) error {
}
}
- if e.Event.Provider == "openai" && e.Event.Path == "/api/providers/openai/v1/chat/completions" {
+ if e.Event.Path == "/api/providers/openai/v1/chat/completions" {
ccr, ok := e.Request.(*goopenai.ChatCompletionRequest)
if !ok {
stats.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1)
From 54277407716b2da3b1b756bb609158ea5d04f908 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Mon, 26 Feb 2024 16:38:18 -0800
Subject: [PATCH 52/71] add cost calculation for audio models
---
internal/provider/openai/cost.go | 34 ++++++++++++++++++++++++++++++++
1 file changed, 34 insertions(+)
diff --git a/internal/provider/openai/cost.go b/internal/provider/openai/cost.go
index b0e6eac..c4daf11 100644
--- a/internal/provider/openai/cost.go
+++ b/internal/provider/openai/cost.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "math"
"strings"
goopenai "github.com/sashabaranov/go-openai"
@@ -58,6 +59,11 @@ var OpenAiPerThousandTokenCost = map[string]map[string]float64{
"text-embedding-3-small": 0.00002,
"text-embedding-3-large": 0.00013,
},
+ "audio": {
+ "whisper-1": 0.006,
+ "tts-1": 0.015,
+ "tts-1-hd": 0.03,
+ },
"completion": {
"gpt-3.5-turbo-1106": 0.002,
"gpt-4-turbo-preview": 0.03,
@@ -203,6 +209,34 @@ func (ce *CostEstimator) EstimateChatCompletionStreamCostWithTokenCounts(model s
return tks, cost, nil
}
+func (ce *CostEstimator) EstimateTranscriptionCost(secs float64, model string) (float64, error) {
+ costMap, ok := ce.tokenCostMap["audio"]
+ if !ok {
+ return 0, errors.New("audio cost map is not provided")
+ }
+
+ cost, ok := costMap[model]
+ if !ok {
+ return 0, errors.New("model is not present in the audio cost map")
+ }
+
+ return math.Trunc(secs) / 60 * cost, nil
+}
+
+func (ce *CostEstimator) EstimateSpeechCost(input string, model string) (float64, error) {
+ costMap, ok := ce.tokenCostMap["audio"]
+ if !ok {
+ return 0, errors.New("audio cost map is not provided")
+ }
+
+ cost, ok := costMap[model]
+ if !ok {
+ return 0, errors.New("model is not present in the audio cost map")
+ }
+
+ return float64(len(input)) / 1000 * cost, nil
+}
+
func (ce *CostEstimator) EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (float64, error) {
if len(string(r.Model)) == 0 {
return 0, errors.New("model is not provided")
From 1a61d3f6ab51d91f889f3fa79e5b24c35a0ab0a5 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Mon, 26 Feb 2024 16:44:08 -0800
Subject: [PATCH 53/71] use time since
---
internal/server/web/admin/admin.go | 24 ++++++++++++------------
1 file changed, 12 insertions(+), 12 deletions(-)
diff --git a/internal/server/web/admin/admin.go b/internal/server/web/admin/admin.go
index ea650b8..d723be1 100644
--- a/internal/server/web/admin/admin.go
+++ b/internal/server/web/admin/admin.go
@@ -143,7 +143,7 @@ func getGetKeysHandler(m KeyManager, log *zap.Logger, prod bool) gin.HandlerFunc
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_get_keys_handler.latency", dur, nil, 1)
}()
@@ -208,7 +208,7 @@ func getGetProviderSettingsHandler(m ProviderSettingsManager, log *zap.Logger, p
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_get_provider_settings.latency", dur, nil, 1)
}()
@@ -258,7 +258,7 @@ func getCreateProviderSettingHandler(m ProviderSettingsManager, log *zap.Logger,
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_create_provider_setting_handler.latency", dur, nil, 1)
}()
@@ -348,7 +348,7 @@ func getCreateKeyHandler(m KeyManager, log *zap.Logger, prod bool) gin.HandlerFu
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_create_key_handler.latency", dur, nil, 1)
}()
@@ -438,7 +438,7 @@ func getUpdateProviderSettingHandler(m ProviderSettingsManager, log *zap.Logger,
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_update_provider_setting_handler.latency", dur, nil, 1)
}()
@@ -540,7 +540,7 @@ func getUpdateKeyHandler(m KeyManager, log *zap.Logger, prod bool) gin.HandlerFu
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_update_key_handler.latency", dur, nil, 1)
}()
@@ -715,7 +715,7 @@ func getGetEventMetricsHandler(m KeyReportingManager, log *zap.Logger, prod bool
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_get_event_metrics.latency", dur, nil, 1)
}()
@@ -802,7 +802,7 @@ func getGetEventsHandler(m KeyReportingManager, log *zap.Logger, prod bool) gin.
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_get_events_handler.latency", dur, nil, 1)
}()
@@ -923,7 +923,7 @@ func getGetKeyReportingHandler(m KeyReportingManager, log *zap.Logger, prod bool
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_get_key_reporting_hanlder.latency", dur, nil, 1)
}()
@@ -1008,7 +1008,7 @@ func getCreateCustomProviderHandler(m CustomProvidersManager, log *zap.Logger, p
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_create_custom_provider_handler.latency", dur, nil, 1)
}()
@@ -1096,7 +1096,7 @@ func getGetCustomProvidersHandler(m CustomProvidersManager, log *zap.Logger, pro
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_get_custom_providers_handler.latency", dur, nil, 1)
}()
@@ -1144,7 +1144,7 @@ func getUpdateCustomProvidersHandler(m CustomProvidersManager, log *zap.Logger,
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_update_custom_providers_handler.latency", dur, nil, 1)
}()
From a5818f530847d99c39dba78ded79441bafc20523 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Mon, 26 Feb 2024 16:44:21 -0800
Subject: [PATCH 54/71] use time since
---
internal/server/web/admin/middleware.go | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/internal/server/web/admin/middleware.go b/internal/server/web/admin/middleware.go
index d4089a4..8704d44 100644
--- a/internal/server/web/admin/middleware.go
+++ b/internal/server/web/admin/middleware.go
@@ -19,7 +19,7 @@ func getAdminLoggerMiddleware(log *zap.Logger, prefix string, prod bool, adminPa
c.Set(correlationId, util.NewUuid())
start := time.Now()
c.Next()
- latency := time.Now().Sub(start).Milliseconds()
+ latency := time.Since(start).Milliseconds()
if !prod {
log.Sugar().Infof("%s | %d | %s | %s | %dms", prefix, c.Writer.Status(), c.Request.Method, c.FullPath(), latency)
}
From 04816e9f471609f6b6c8ae4a3c606369bd06bde1 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Mon, 26 Feb 2024 16:44:41 -0800
Subject: [PATCH 55/71] use time since
---
internal/server/web/admin/route.go | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/internal/server/web/admin/route.go b/internal/server/web/admin/route.go
index 6699ea8..99d3bca 100644
--- a/internal/server/web/admin/route.go
+++ b/internal/server/web/admin/route.go
@@ -24,7 +24,7 @@ func getCreateRouteHandler(m RouteManager, log *zap.Logger, prod bool) gin.Handl
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_create_route_handler.latency", dur, nil, 1)
}()
@@ -112,7 +112,7 @@ func getGetRouteHandler(m RouteManager, log *zap.Logger, prod bool) gin.HandlerF
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_get_route_handler.latency", dur, nil, 1)
}()
@@ -174,7 +174,7 @@ func getGetRoutesHandler(m RouteManager, log *zap.Logger, prod bool) gin.Handler
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_get_routes_handler.latency", dur, nil, 1)
}()
From 2a1792fcf977c40974c0b05af426ccf60728d278 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Mon, 26 Feb 2024 16:45:20 -0800
Subject: [PATCH 56/71] use time since and remove unused variables
---
internal/server/web/proxy/anthropic.go | 11 +++--------
1 file changed, 3 insertions(+), 8 deletions(-)
diff --git a/internal/server/web/proxy/anthropic.go b/internal/server/web/proxy/anthropic.go
index 7f6a607..ed9a9b4 100644
--- a/internal/server/web/proxy/anthropic.go
+++ b/internal/server/web/proxy/anthropic.go
@@ -18,11 +18,6 @@ import (
"go.uber.org/zap/zapcore"
)
-const (
- anthropicPromptMagicNum int = 1
- anthropicCompletionMagicNum int = 4
-)
-
type anthropicEstimator interface {
EstimateTotalCost(model string, promptTks, completionTks int) (float64, error)
EstimateCompletionCost(model string, tks int) (float64, error)
@@ -98,7 +93,7 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km
// model := c.GetString("model")
if !isStreaming && res.StatusCode == http.StatusOK {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_completion_handler.latency", dur, nil, 1)
bytes, err := io.ReadAll(res.Body)
@@ -148,7 +143,7 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km
}
if res.StatusCode != http.StatusOK {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_completion_handler.error_latency", dur, nil, 1)
stats.Incr("bricksllm.proxy.get_completion_handler.error_response", nil, 1)
bytes, err := io.ReadAll(res.Body)
@@ -258,7 +253,7 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km
return true
})
- stats.Timing("bricksllm.proxy.get_completion_handler.streaming_latency", time.Now().Sub(start), nil, 1)
+ stats.Timing("bricksllm.proxy.get_completion_handler.streaming_latency", time.Since(start), nil, 1)
}
}
From 847ea8cea5c4ace1173b59588e3298ed011d973d Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Mon, 26 Feb 2024 16:45:44 -0800
Subject: [PATCH 57/71] use time since and remove unused variables
---
internal/server/web/proxy/azure_chat_completion.go | 6 +++---
internal/server/web/proxy/azure_embedding.go | 2 +-
internal/server/web/proxy/custom_provider.go | 6 +++---
3 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/internal/server/web/proxy/azure_chat_completion.go b/internal/server/web/proxy/azure_chat_completion.go
index 2807071..bb1c478 100644
--- a/internal/server/web/proxy/azure_chat_completion.go
+++ b/internal/server/web/proxy/azure_chat_completion.go
@@ -73,7 +73,7 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS
}
if res.StatusCode == http.StatusOK && !isStreaming {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_azure_chat_completion_handler.latency", dur, nil, 1)
bytes, err := io.ReadAll(res.Body)
@@ -120,7 +120,7 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS
}
if res.StatusCode != http.StatusOK {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_azure_chat_completion_handler.error_latency", dur, nil, 1)
stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.error_response", nil, 1)
@@ -238,6 +238,6 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS
return true
})
- stats.Timing("bricksllm.proxy.get_azure_chat_completion_handler.streaming_latency", time.Now().Sub(start), nil, 1)
+ stats.Timing("bricksllm.proxy.get_azure_chat_completion_handler.streaming_latency", time.Since(start), nil, 1)
}
}
diff --git a/internal/server/web/proxy/azure_embedding.go b/internal/server/web/proxy/azure_embedding.go
index f9ac0de..21150b4 100644
--- a/internal/server/web/proxy/azure_embedding.go
+++ b/internal/server/web/proxy/azure_embedding.go
@@ -54,7 +54,7 @@ func getAzureEmbeddingsHandler(r recorder, prod, private bool, psm ProviderSetti
}
defer res.Body.Close()
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_azure_embeddings_handler.latency", dur, nil, 1)
bytes, err := io.ReadAll(res.Body)
diff --git a/internal/server/web/proxy/custom_provider.go b/internal/server/web/proxy/custom_provider.go
index 1f9007f..f4ec182 100644
--- a/internal/server/web/proxy/custom_provider.go
+++ b/internal/server/web/proxy/custom_provider.go
@@ -110,7 +110,7 @@ func getCustomProviderHandler(prod, private bool, psm ProviderSettingsManager, c
defer res.Body.Close()
if res.StatusCode == http.StatusOK && !isStreaming {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_custom_provider_handler.latency", dur, tags, 1)
bytes, err := io.ReadAll(res.Body)
@@ -133,7 +133,7 @@ func getCustomProviderHandler(prod, private bool, psm ProviderSettingsManager, c
}
if res.StatusCode != http.StatusOK {
- stats.Timing("bricksllm.proxy.get_custom_provider_handler.error_latency", time.Now().Sub(start), nil, 1)
+ stats.Timing("bricksllm.proxy.get_custom_provider_handler.error_latency", time.Since(start), nil, 1)
stats.Incr("bricksllm.proxy.get_custom_provider_handler.error_response", nil, 1)
bytes, err := io.ReadAll(res.Body)
@@ -217,6 +217,6 @@ func getCustomProviderHandler(prod, private bool, psm ProviderSettingsManager, c
return true
})
- stats.Timing("bricksllm.proxy.get_custom_provider_handler.streaming_latency", time.Now().Sub(start), nil, 1)
+ stats.Timing("bricksllm.proxy.get_custom_provider_handler.streaming_latency", time.Since(start), nil, 1)
}
}
From 43eed9ac7a5bed472e538ecc05e9d55d3669a211 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Mon, 26 Feb 2024 16:46:12 -0800
Subject: [PATCH 58/71] add new handlers for tracking audio API costs
---
internal/server/web/proxy/audio.go | 499 ++++++++++++++++++++++++++++-
1 file changed, 491 insertions(+), 8 deletions(-)
diff --git a/internal/server/web/proxy/audio.go b/internal/server/web/proxy/audio.go
index 33e4534..ecff42b 100644
--- a/internal/server/web/proxy/audio.go
+++ b/internal/server/web/proxy/audio.go
@@ -1,10 +1,493 @@
package proxy
import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "io"
+ "mime/multipart"
+ "net/http"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/asticode/go-astisub"
+ "github.com/bricks-cloud/bricksllm/internal/stats"
+ "github.com/gin-gonic/gin"
+ goopenai "github.com/sashabaranov/go-openai"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
+// getSpeechHandler returns a gin handler that proxies create-speech requests
+// to the OpenAI /v1/audio/speech endpoint and relays the response (typically
+// binary audio) back to the caller, recording request, latency, success and
+// error metrics along the way.
+func getSpeechHandler(r recorder, prod, private bool, client http.Client, log *zap.Logger, timeOut time.Duration) gin.HandlerFunc {
+	return func(c *gin.Context) {
+		stats.Incr("bricksllm.proxy.get_speech_handler.requests", nil, 1)
+
+		if c == nil || c.Request == nil {
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] context is empty")
+			return
+		}
+
+		cid := c.GetString(correlationId)
+
+		ctx, cancel := context.WithTimeout(context.Background(), timeOut)
+		defer cancel()
+
+		req, err := http.NewRequestWithContext(ctx, c.Request.Method, "https://api.openai.com/v1/audio/speech", c.Request.Body)
+		if err != nil {
+			logError(log, "error when creating openai http request", prod, cid, err)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create openai http request")
+			return
+		}
+
+		copyHttpHeaders(c.Request, req)
+
+		start := time.Now()
+
+		res, err := client.Do(req)
+		if err != nil {
+			stats.Incr("bricksllm.proxy.get_speech_handler.http_client_error", nil, 1)
+
+			logError(log, "error when sending create speech request to openai", prod, cid, err)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to send create speech request to openai")
+			return
+		}
+		defer res.Body.Close()
+
+		dur := time.Since(start)
+		stats.Timing("bricksllm.proxy.get_speech_handler.latency", dur, nil, 1)
+
+		// NOTE(review): the whole response body is buffered in memory via
+		// io.ReadAll; audio payloads can be large — confirm this is acceptable.
+		bytes, err := io.ReadAll(res.Body)
+		if err != nil {
+			logError(log, "error when reading openai create speech response body", prod, cid, err)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to read openai create speech response body")
+			return
+		}
+
+		if res.StatusCode == http.StatusOK {
+			stats.Incr("bricksllm.proxy.get_speech_handler.success", nil, 1)
+			// Fixed: this timing was previously emitted under the
+			// get_pass_through_handler namespace (copy/paste slip).
+			stats.Timing("bricksllm.proxy.get_speech_handler.success_latency", dur, nil, 1)
+		} else {
+			stats.Timing("bricksllm.proxy.get_speech_handler.error_latency", dur, nil, 1)
+			stats.Incr("bricksllm.proxy.get_speech_handler.error_response", nil, 1)
+
+			errorRes := &goopenai.ErrorResponse{}
+			err = json.Unmarshal(bytes, errorRes)
+			if err != nil {
+				logError(log, "error when unmarshalling openai create speech error response body", prod, cid, err)
+			}
+
+			logOpenAiError(log, prod, cid, errorRes)
+		}
+
+		// Mirror upstream response headers before writing the body.
+		for name, values := range res.Header {
+			for _, value := range values {
+				c.Header(name, value)
+			}
+		}
+
+		c.Data(res.StatusCode, res.Header.Get("Content-Type"), bytes)
+	}
+}
+
+// convertVerboseJson re-encodes an OpenAI verbose_json audio response into
+// the response_format the client originally requested: "verbose_json" or
+// "json" -> JSON bytes, "text" -> the transcript plus a newline, and
+// "srt"/"vtt" -> subtitle documents built from the response segments.
+// NOTE(review): any other format returns (nil, nil), so the caller ends up
+// writing an empty body — confirm that fallback is intended.
+func convertVerboseJson(resp *goopenai.AudioResponse, format string) ([]byte, error) {
+ if format == "verbose_json" || format == "json" {
+ selected := resp
+ if format == "json" {
+ // plain "json" exposes only the transcribed text, not segments
+ selected = &goopenai.AudioResponse{
+ Text: resp.Text,
+ }
+ }
+
+ data, err := json.Marshal(selected)
+ if err != nil {
+ return nil, err
+ }
+
+ return data, nil
+ }
+
+ if format == "text" {
+ return []byte(resp.Text + "\n"), nil
+ }
+
+ if format == "srt" || format == "vtt" {
+ // Build one subtitle item per transcription segment, converting the
+ // segment start/end (seconds as float64) into time.Duration offsets.
+ sub := astisub.NewSubtitles()
+ items := []*astisub.Item{}
+
+ for _, seg := range resp.Segments {
+ item := &astisub.Item{
+ StartAt: time.Duration(seg.Start * float64(time.Second)),
+ EndAt: time.Duration(seg.End * float64(time.Second)),
+ Lines: []astisub.Line{
+ {
+ Items: []astisub.LineItem{
+ {Text: seg.Text},
+ },
+ },
+ },
+ }
+
+ items = append(items, item)
+ }
+
+ sub.Items = items
+
+ buf := bytes.NewBuffer([]byte{})
+
+ if format == "srt" {
+ err := sub.WriteToSRT(buf)
+ if err != nil {
+ return nil, err
+ }
+
+ return buf.Bytes(), nil
+ }
+
+ if format == "vtt" {
+ err := sub.WriteToWebVTT(buf)
+ if err != nil {
+ return nil, err
+ }
+
+ return buf.Bytes(), nil
+ }
+ }
+
+ return nil, nil
+}
+
+// getContentType maps a transcription/translation response_format to the
+// Content-Type header used when relaying the converted payload: JSON for
+// the JSON formats, plain text for everything else (text, srt, vtt).
+func getContentType(format string) string {
+	switch format {
+	case "verbose_json", "json":
+		return "application/json"
+	default:
+		return "text/plain; charset=utf-8"
+	}
+}
+
+// getTranscriptionsHandler returns a gin handler that proxies audio
+// transcription requests to OpenAI. The multipart form is rebuilt with
+// response_format forced to verbose_json so the upstream response carries
+// duration/segments (used for cost estimation); the payload is then
+// converted back to the format the client originally asked for.
+func getTranscriptionsHandler(r recorder, prod, private bool, client http.Client, log *zap.Logger, timeOut time.Duration, e estimator) gin.HandlerFunc {
+	return func(c *gin.Context) {
+		stats.Incr("bricksllm.proxy.get_transcriptions_handler.requests", nil, 1)
+
+		if c == nil || c.Request == nil {
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] context is empty")
+			return
+		}
+
+		cid := c.GetString(correlationId)
+
+		ctx, cancel := context.WithTimeout(context.Background(), timeOut)
+		defer cancel()
+
+		req, err := http.NewRequestWithContext(ctx, c.Request.Method, "https://api.openai.com/v1/audio/transcriptions", c.Request.Body)
+		if err != nil {
+			logError(log, "error when creating transcriptions openai http request", prod, cid, err)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create openai transcriptions http request")
+			return
+		}
+
+		copyHttpHeaders(c.Request, req)
+
+		var b bytes.Buffer
+		writer := multipart.NewWriter(&b)
+
+		// Force verbose_json upstream; the client's requested format is
+		// restored when the response is converted below.
+		err = writeFieldToBuffer([]string{
+			"model",
+			"language",
+			"prompt",
+			"response_format",
+			"temperature",
+		}, c, writer, map[string]string{
+			"response_format": "verbose_json",
+		})
+		if err != nil {
+			stats.Incr("bricksllm.proxy.get_transcriptions_handler.write_field_to_buffer_error", nil, 1)
+			logError(log, "error when writing field to buffer", prod, cid, err)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot write field to buffer")
+			return
+		}
+
+		var form TransriptionForm
+		// Bind errors are deliberately ignored; a missing file is handled
+		// by the nil check below.
+		c.ShouldBind(&form)
+
+		if form.File != nil {
+			fieldWriter, err := writer.CreateFormFile("file", form.File.Filename)
+			if err != nil {
+				stats.Incr("bricksllm.proxy.get_transcriptions_handler.create_transcription_file_error", nil, 1)
+				logError(log, "error when creating transcription file", prod, cid, err)
+				JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot create transcription file")
+				return
+			}
+
+			opened, err := form.File.Open()
+			if err != nil {
+				stats.Incr("bricksllm.proxy.get_transcriptions_handler.open_transcription_file_error", nil, 1)
+				logError(log, "error when opening transcription file", prod, cid, err)
+				JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot open transcription file")
+				return
+			}
+			// Fixed: the opened multipart file was previously never closed.
+			defer opened.Close()
+
+			_, err = io.Copy(fieldWriter, opened)
+			if err != nil {
+				stats.Incr("bricksllm.proxy.get_transcriptions_handler.copy_transcription_file_error", nil, 1)
+				logError(log, "error when copying transcription file", prod, cid, err)
+				JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot copy transcription file")
+				return
+			}
+		}
+
+		req.Header.Set("Content-Type", writer.FormDataContentType())
+		writer.Close()
+		req.Body = io.NopCloser(&b)
+
+		start := time.Now()
+
+		res, err := client.Do(req)
+		if err != nil {
+			stats.Incr("bricksllm.proxy.get_transcriptions_handler.http_client_error", nil, 1)
+
+			logError(log, "error when sending transcriptions request to openai", prod, cid, err)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to send transcriptions request to openai")
+			return
+		}
+		defer res.Body.Close()
+
+		dur := time.Since(start)
+		stats.Timing("bricksllm.proxy.get_transcriptions_handler.latency", dur, nil, 1)
+
+		bytes, err := io.ReadAll(res.Body)
+		if err != nil {
+			logError(log, "error when reading openai transcriptions response body", prod, cid, err)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to read openai transcriptions response body")
+			return
+		}
+
+		// Relay upstream headers, but Content-Type must reflect the format
+		// the client asked for, not the forced verbose_json.
+		format := c.PostForm("response_format")
+		for name, values := range res.Header {
+			for _, value := range values {
+				if strings.ToLower(name) == "content-type" {
+					c.Header(name, getContentType(format))
+					continue
+				}
+
+				c.Header(name, value)
+			}
+		}
+
+		if res.StatusCode == http.StatusOK {
+			stats.Incr("bricksllm.proxy.get_transcriptions_handler.success", nil, 1)
+			stats.Timing("bricksllm.proxy.get_transcriptions_handler.success_latency", dur, nil, 1)
+
+			ar := &goopenai.AudioResponse{}
+			err = json.Unmarshal(bytes, ar)
+			if err != nil {
+				logError(log, "error when unmarshalling openai http audio response body", prod, cid, err)
+			}
+
+			if err == nil {
+				cost, err := e.EstimateTranscriptionCost(ar.Duration, c.GetString("model"))
+				if err != nil {
+					stats.Incr("bricksllm.proxy.get_transcriptions_handler.estimate_total_cost_error", nil, 1)
+					logError(log, "error when estimating openai cost", prod, cid, err)
+				}
+
+				// Fixed: only record a cost when estimation succeeded;
+				// previously a zero cost was stored on estimator errors.
+				if err == nil {
+					c.Set("costInUsd", cost)
+				}
+			}
+
+			data, err := convertVerboseJson(ar, format)
+			if err != nil {
+				stats.Incr("bricksllm.proxy.get_transcriptions_handler.convert_verbose_json_error", nil, 1)
+				logError(log, "error when converting verbose json", prod, cid, err)
+			}
+
+			c.Header("Content-Length", strconv.Itoa(len(data)))
+
+			c.Data(res.StatusCode, getContentType(format), data)
+			return
+		}
+
+		// Non-200: surface the upstream error payload untouched.
+		stats.Timing("bricksllm.proxy.get_transcriptions_handler.error_latency", dur, nil, 1)
+		stats.Incr("bricksllm.proxy.get_transcriptions_handler.error_response", nil, 1)
+
+		errorRes := &goopenai.ErrorResponse{}
+		err = json.Unmarshal(bytes, errorRes)
+		if err != nil {
+			logError(log, "error when unmarshalling openai transcriptions error response body", prod, cid, err)
+		}
+
+		logOpenAiError(log, prod, cid, errorRes)
+
+		c.Data(res.StatusCode, res.Header.Get("Content-Type"), bytes)
+	}
+}
+
+// getTranslationsHandler returns a gin handler that proxies audio
+// translation requests to OpenAI. Like the transcriptions handler, it
+// rebuilds the multipart form with response_format forced to verbose_json
+// (for duration-based cost estimation) and converts the response back to
+// the client's requested format.
+func getTranslationsHandler(r recorder, prod, private bool, client http.Client, log *zap.Logger, timeOut time.Duration, e estimator) gin.HandlerFunc {
+	return func(c *gin.Context) {
+		stats.Incr("bricksllm.proxy.get_translations_handler.requests", nil, 1)
+
+		if c == nil || c.Request == nil {
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] context is empty")
+			return
+		}
+
+		cid := c.GetString(correlationId)
+
+		ctx, cancel := context.WithTimeout(context.Background(), timeOut)
+		defer cancel()
+
+		req, err := http.NewRequestWithContext(ctx, c.Request.Method, "https://api.openai.com/v1/audio/translations", c.Request.Body)
+		if err != nil {
+			logError(log, "error when creating translations openai http request", prod, cid, err)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create openai translations http request")
+			return
+		}
+
+		copyHttpHeaders(c.Request, req)
+
+		var b bytes.Buffer
+		writer := multipart.NewWriter(&b)
+
+		// Force verbose_json upstream; the client's requested format is
+		// restored when the response is converted below.
+		err = writeFieldToBuffer([]string{
+			"model",
+			"prompt",
+			"response_format",
+			"temperature",
+		}, c, writer, map[string]string{
+			"response_format": "verbose_json",
+		})
+		if err != nil {
+			// Fixed: these four metrics were emitted under the
+			// get_pass_through_handler namespace (copy/paste slip).
+			stats.Incr("bricksllm.proxy.get_translations_handler.write_field_to_buffer_error", nil, 1)
+			logError(log, "error when writing field to buffer", prod, cid, err)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot write field to buffer")
+			return
+		}
+
+		var form TranslationForm
+		// Bind errors are deliberately ignored; a missing file is handled
+		// by the nil check below.
+		c.ShouldBind(&form)
+
+		if form.File != nil {
+			fieldWriter, err := writer.CreateFormFile("file", form.File.Filename)
+			if err != nil {
+				stats.Incr("bricksllm.proxy.get_translations_handler.create_translation_file_error", nil, 1)
+				logError(log, "error when creating translation file", prod, cid, err)
+				JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot create translation file")
+				return
+			}
+
+			opened, err := form.File.Open()
+			if err != nil {
+				stats.Incr("bricksllm.proxy.get_translations_handler.open_translation_file_error", nil, 1)
+				logError(log, "error when opening translation file", prod, cid, err)
+				JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot open translation file")
+				return
+			}
+			// Fixed: the opened multipart file was previously never closed.
+			defer opened.Close()
+
+			_, err = io.Copy(fieldWriter, opened)
+			if err != nil {
+				stats.Incr("bricksllm.proxy.get_translations_handler.copy_translation_file_error", nil, 1)
+				logError(log, "error when copying translation file", prod, cid, err)
+				JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot copy translation file")
+				return
+			}
+		}
+
+		req.Header.Set("Content-Type", writer.FormDataContentType())
+		writer.Close()
+		req.Body = io.NopCloser(&b)
+
+		start := time.Now()
+
+		res, err := client.Do(req)
+		if err != nil {
+			stats.Incr("bricksllm.proxy.get_translations_handler.http_client_error", nil, 1)
+
+			logError(log, "error when sending translations request to openai", prod, cid, err)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to send translations request to openai")
+			return
+		}
+		defer res.Body.Close()
+
+		dur := time.Since(start)
+		stats.Timing("bricksllm.proxy.get_translations_handler.latency", dur, nil, 1)
+
+		bytes, err := io.ReadAll(res.Body)
+		if err != nil {
+			logError(log, "error when reading openai translations response body", prod, cid, err)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to read openai translations response body")
+			return
+		}
+
+		// Relay upstream headers, but Content-Type must reflect the format
+		// the client asked for, not the forced verbose_json.
+		format := c.PostForm("response_format")
+		for name, values := range res.Header {
+			for _, value := range values {
+				if strings.ToLower(name) == "content-type" {
+					c.Header(name, getContentType(format))
+					continue
+				}
+
+				c.Header(name, value)
+			}
+		}
+
+		if res.StatusCode == http.StatusOK {
+			stats.Incr("bricksllm.proxy.get_translations_handler.success", nil, 1)
+			stats.Timing("bricksllm.proxy.get_translations_handler.success_latency", dur, nil, 1)
+
+			ar := &goopenai.AudioResponse{}
+			err = json.Unmarshal(bytes, ar)
+			if err != nil {
+				logError(log, "error when unmarshalling openai http audio response body", prod, cid, err)
+			}
+
+			if err == nil {
+				// Translations reuse the transcription estimator —
+				// presumably both are billed per audio minute; confirm
+				// against the pricing table.
+				cost, err := e.EstimateTranscriptionCost(ar.Duration, c.GetString("model"))
+				if err != nil {
+					stats.Incr("bricksllm.proxy.get_translations_handler.estimate_total_cost_error", nil, 1)
+					logError(log, "error when estimating openai cost", prod, cid, err)
+				}
+
+				// Fixed: only record a cost when estimation succeeded;
+				// previously a zero cost was stored on estimator errors.
+				if err == nil {
+					c.Set("costInUsd", cost)
+				}
+			}
+
+			data, err := convertVerboseJson(ar, format)
+			if err != nil {
+				stats.Incr("bricksllm.proxy.get_translations_handler.convert_verbose_json_error", nil, 1)
+				logError(log, "error when converting verbose json", prod, cid, err)
+			}
+
+			c.Header("Content-Length", strconv.Itoa(len(data)))
+
+			c.Data(res.StatusCode, getContentType(format), data)
+			return
+		}
+
+		// Non-200: surface the upstream error payload untouched.
+		stats.Timing("bricksllm.proxy.get_translations_handler.error_latency", dur, nil, 1)
+		stats.Incr("bricksllm.proxy.get_translations_handler.error_response", nil, 1)
+
+		errorRes := &goopenai.ErrorResponse{}
+		err = json.Unmarshal(bytes, errorRes)
+		if err != nil {
+			logError(log, "error when unmarshalling openai translations error response body", prod, cid, err)
+		}
+
+		logOpenAiError(log, prod, cid, errorRes)
+
+		c.Data(res.StatusCode, res.Header.Get("Content-Type"), bytes)
+	}
+}
+
type SpeechRequest struct {
Model string `json:"model"`
Input string `json:"input"`
@@ -13,24 +496,24 @@ type SpeechRequest struct {
Speed float64 `json:"speed"`
}
-func logCreateSpeechRequest(log *zap.Logger, sr *SpeechRequest, prod, private bool, cid string) {
+func logCreateSpeechRequest(log *zap.Logger, csr *goopenai.CreateSpeechRequest, prod, private bool, cid string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
- zap.String("model", sr.Model),
- zap.String("voice", sr.Voice),
+ zap.String("model", string(csr.Model)),
+ zap.String("voice", string(csr.Voice)),
}
if !private {
- fields = append(fields, zap.String("input", sr.Input))
+ fields = append(fields, zap.String("input", csr.Input))
}
- if len(sr.ResponseFormat) != 0 {
- fields = append(fields, zap.String("response_format", sr.ResponseFormat))
+ if len(csr.ResponseFormat) != 0 {
+ fields = append(fields, zap.String("response_format", string(csr.ResponseFormat)))
}
- if sr.Speed != 0 {
- fields = append(fields, zap.Float64("speed", sr.Speed))
+ if csr.Speed != 0 {
+ fields = append(fields, zap.Float64("speed", csr.Speed))
}
log.Info("openai create speech request", fields...)
From 4c985040fcffda3718d299413022c1f4d849da26 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Mon, 26 Feb 2024 16:46:49 -0800
Subject: [PATCH 59/71] integrate new audio handlers
---
internal/server/web/proxy/proxy.go | 143 +++++------------------------
1 file changed, 21 insertions(+), 122 deletions(-)
diff --git a/internal/server/web/proxy/proxy.go b/internal/server/web/proxy/proxy.go
index efee23a..9ee529e 100644
--- a/internal/server/web/proxy/proxy.go
+++ b/internal/server/web/proxy/proxy.go
@@ -88,9 +88,9 @@ func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyMan
router.POST("/api/health", getGetHealthCheckHandler())
// audios
- router.POST("/api/providers/openai/v1/audio/speech", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/audio/transcriptions", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/audio/translations", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/audio/speech", getSpeechHandler(r, prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/audio/transcriptions", getTranscriptionsHandler(r, prod, private, client, log, timeOut, e))
+ router.POST("/api/providers/openai/v1/audio/translations", getTranslationsHandler(r, prod, private, client, log, timeOut, e))
// completions
router.POST("/api/providers/openai/v1/chat/completions", getChatCompletionHandler(r, prod, private, psm, client, kms, log, e, timeOut))
@@ -209,9 +209,16 @@ type TranslationForm struct {
File *multipart.FileHeader `form:"file" binding:"required"`
}
-func writeFieldToBuffer(fields []string, c *gin.Context, writer *multipart.Writer) error {
+func writeFieldToBuffer(fields []string, c *gin.Context, writer *multipart.Writer, overWrites map[string]string) error {
for _, field := range fields {
val := c.PostForm(field)
+
+ if len(overWrites) != 0 {
+ if ow := overWrites[field]; len(ow) != 0 {
+ val = ow
+ }
+ }
+
if len(val) != 0 {
err := writer.WriteField(field, val)
if err != nil {
@@ -316,7 +323,7 @@ func getPassThroughHandler(r recorder, prod, private bool, client http.Client, l
"size",
"response_format",
"user",
- }, c, writer)
+ }, c, writer, nil)
if err != nil {
stats.Incr("bricksllm.proxy.get_pass_through_handler.write_field_to_buffer_error", tags, 1)
logError(log, "error when writing field to buffer", prod, cid, err)
@@ -396,7 +403,7 @@ func getPassThroughHandler(r recorder, prod, private bool, client http.Client, l
"size",
"response_format",
"user",
- }, c, writer)
+ }, c, writer, nil)
if err != nil {
stats.Incr("bricksllm.proxy.get_pass_through_handler.write_field_to_buffer_error", tags, 1)
logError(log, "error when writing field to buffer", prod, cid, err)
@@ -440,132 +447,25 @@ func getPassThroughHandler(r recorder, prod, private bool, client http.Client, l
req.Body = io.NopCloser(&b)
}
- if c.FullPath() == "/api/providers/openai/v1/audio/transcriptions" && c.Request.Method == http.MethodPost {
- var b bytes.Buffer
- writer := multipart.NewWriter(&b)
-
- err := writeFieldToBuffer([]string{
- "model",
- "language",
- "prompt",
- "response_format",
- "temperature",
- }, c, writer)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_pass_through_handler.write_field_to_buffer_error", tags, 1)
- logError(log, "error when writing field to buffer", prod, cid, err)
- JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot write field to buffer")
- return
- }
-
- var form TransriptionForm
- c.ShouldBind(&form)
-
- if form.File != nil {
- fieldWriter, err := writer.CreateFormFile("file", form.File.Filename)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_pass_through_handler.create_transcription_file_error", tags, 1)
- logError(log, "error when creating transcription file", prod, cid, err)
- JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot create transcription file")
- return
- }
-
- opened, err := form.File.Open()
- if err != nil {
- stats.Incr("bricksllm.proxy.get_pass_through_handler.open_transcription_file_error", tags, 1)
- logError(log, "error when openning transcription file", prod, cid, err)
- JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot open transcription file")
- return
- }
-
- _, err = io.Copy(fieldWriter, opened)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_pass_through_handler.copy_transcription_file_error", tags, 1)
- logError(log, "error when copying transcription file", prod, cid, err)
- JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot copy transcription file")
- return
- }
- }
-
- req.Header.Set("Content-Type", writer.FormDataContentType())
-
- writer.Close()
-
- req.Body = io.NopCloser(&b)
- }
-
- if c.FullPath() == "/api/providers/openai/v1/audio/translations" && c.Request.Method == http.MethodPost {
- var b bytes.Buffer
- writer := multipart.NewWriter(&b)
-
- err := writeFieldToBuffer([]string{
- "model",
- "prompt",
- "response_format",
- "temperature",
- }, c, writer)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_pass_through_handler.write_field_to_buffer_error", tags, 1)
- logError(log, "error when writing field to buffer", prod, cid, err)
- JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot write field to buffer")
- return
- }
-
- var form TranslationForm
- c.ShouldBind(&form)
-
- if form.File != nil {
- fieldWriter, err := writer.CreateFormFile("file", form.File.Filename)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_pass_through_handler.create_translation_file_error", tags, 1)
- logError(log, "error when creating translation file", prod, cid, err)
- JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot create translation file")
- return
- }
-
- opened, err := form.File.Open()
- if err != nil {
- stats.Incr("bricksllm.proxy.get_pass_through_handler.open_translation_file_error", tags, 1)
- logError(log, "error when openning translation file", prod, cid, err)
- JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot open translation file")
- return
- }
-
- _, err = io.Copy(fieldWriter, opened)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_pass_through_handler.copy_translation_file_error", tags, 1)
- logError(log, "error when copying translation file", prod, cid, err)
- JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot copy translation file")
- return
- }
- }
-
- req.Header.Set("Content-Type", writer.FormDataContentType())
-
- writer.Close()
-
- req.Body = io.NopCloser(&b)
- }
-
start := time.Now()
res, err := client.Do(req)
if err != nil {
stats.Incr("bricksllm.proxy.get_pass_through_handler.http_client_error", tags, 1)
- logError(log, "error when sending embedding request to openai", prod, cid, err)
- JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to send embedding request to openai")
+ logError(log, "error when sending pass through request to openai", prod, cid, err)
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to send pass through request to openai")
return
}
defer res.Body.Close()
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_pass_through_handler.latency", dur, tags, 1)
bytes, err := io.ReadAll(res.Body)
if err != nil {
logError(log, "error when reading openai embedding response body", prod, cid, err)
- JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to read openai embedding response body")
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to read openai pass through response body")
return
}
@@ -996,7 +896,7 @@ func getEmbeddingHandler(r recorder, prod, private bool, psm ProviderSettingsMan
}
defer res.Body.Close()
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_embedding_handler.latency", dur, nil, 1)
bytes, err := io.ReadAll(res.Body)
@@ -1092,7 +992,6 @@ var (
eventCompletionPrefix = []byte("event: completion")
eventPingPrefix = []byte("event: ping")
eventErrorPrefix = []byte("event: error")
- errorPrefix = []byte(`data: {"error":`)
)
func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettingsManager, client http.Client, kms keyMemStorage, log *zap.Logger, e estimator, timeOut time.Duration) gin.HandlerFunc {
@@ -1153,7 +1052,7 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin
model := c.GetString("model")
if res.StatusCode == http.StatusOK && !isStreaming {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_chat_completion_handler.latency", dur, nil, 1)
bytes, err := io.ReadAll(res.Body)
@@ -1198,7 +1097,7 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin
}
if res.StatusCode != http.StatusOK {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_chat_completion_handler.error_latency", dur, nil, 1)
stats.Incr("bricksllm.proxy.get_chat_completion_handler.error_response", nil, 1)
@@ -1300,7 +1199,7 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin
return true
})
- stats.Timing("bricksllm.proxy.get_chat_completion_handler.streaming_latency", time.Now().Sub(start), nil, 1)
+ stats.Timing("bricksllm.proxy.get_chat_completion_handler.streaming_latency", time.Since(start), nil, 1)
}
}
From 291a11e2e3da7a9603d1383d2932360de4a08fa3 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Mon, 26 Feb 2024 16:48:29 -0800
Subject: [PATCH 60/71] remove unused funcs and add cost tracking for speech
API
---
internal/server/web/proxy/middleware.go | 42 +++++--------------------
1 file changed, 7 insertions(+), 35 deletions(-)
diff --git a/internal/server/web/proxy/middleware.go b/internal/server/web/proxy/middleware.go
index aafd518..0a1e035 100644
--- a/internal/server/web/proxy/middleware.go
+++ b/internal/server/web/proxy/middleware.go
@@ -24,16 +24,6 @@ import (
goopenai "github.com/sashabaranov/go-openai"
)
-type rateLimitError interface {
- Error() string
- RateLimit()
-}
-
-type expirationError interface {
- Error() string
- Reason() string
-}
-
type keyMemStorage interface {
GetKey(hash string) *key.ResponseKey
}
@@ -43,6 +33,8 @@ type keyStorage interface {
}
type estimator interface {
+ EstimateTranscriptionCost(secs float64, model string) (float64, error)
+ EstimateSpeechCost(input string, model string) (float64, error)
EstimateChatCompletionPromptCostWithTokenCounts(r *goopenai.ChatCompletionRequest) (int, float64, error)
EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (float64, error)
EstimateChatCompletionStreamCostWithTokenCounts(model, content string) (int, float64, error)
@@ -77,10 +69,6 @@ type accessCache interface {
GetAccessStatus(key string) bool
}
-type encrypter interface {
- Encrypt(secret string) string
-}
-
func JSON(c *gin.Context, code int, message string) {
c.JSON(code, &goopenai.ErrorResponse{
Error: &goopenai.APIError{
@@ -168,7 +156,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
customId := c.Request.Header.Get("X-CUSTOM-EVENT-ID")
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
latency := int(dur.Milliseconds())
if !prod {
@@ -605,14 +593,16 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/audio/speech" && c.Request.Method == http.MethodPost {
- sr := &SpeechRequest{}
+ sr := &goopenai.CreateSpeechRequest{}
err := json.Unmarshal(body, sr)
if err != nil {
logError(log, "error when unmarshalling create speech request", prod, cid, err)
return
}
- c.Set("model", sr.Model)
+ enrichedEvent.Request = sr
+
+ c.Set("model", string(sr.Model))
logCreateSpeechRequest(log, sr, prod, private, cid)
}
@@ -883,21 +873,3 @@ func containsPath(arr []key.PathConfig, path, method string) bool {
return false
}
-
-func getAuthTokenFromHeader(c *gin.Context) string {
- if strings.HasPrefix(c.FullPath(), "/api/providers/anthropic") {
- return c.GetHeader("x-api-key")
- }
-
- if strings.HasPrefix(c.FullPath(), "/api/providers/azure") {
- return c.GetHeader("api-key")
- }
-
- split := strings.Split(c.Request.Header.Get("Authorization"), "Bearer ")
- if len(split) < 2 || len(split[1]) == 0 {
- return ""
- }
-
- return split[1]
-
-}
From 30df0e5b13aca17bbbcb093b901347ea6e5ad657 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 28 Feb 2024 10:48:43 -0800
Subject: [PATCH 61/71] update logging
---
internal/server/web/proxy/proxy.go | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/internal/server/web/proxy/proxy.go b/internal/server/web/proxy/proxy.go
index 9ee529e..ff43b52 100644
--- a/internal/server/web/proxy/proxy.go
+++ b/internal/server/web/proxy/proxy.go
@@ -490,7 +490,7 @@ func getPassThroughHandler(r recorder, prod, private bool, client http.Client, l
}
if c.FullPath() == "/api/providers/openai/v1/assistants" && c.Request.Method == http.MethodGet {
- logListAssistantFilesResponse(log, bytes, prod, cid)
+ logListAssistantsResponse(log, bytes, prod, private, cid)
}
if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id/files" && c.Request.Method == http.MethodPost {
From 0e75fece13a2b3c3b8f6fc91dc014432768bf131 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 28 Feb 2024 10:49:17 -0800
Subject: [PATCH 62/71] update logging
---
internal/server/web/proxy/middleware.go | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/internal/server/web/proxy/middleware.go b/internal/server/web/proxy/middleware.go
index 0a1e035..2a40314 100644
--- a/internal/server/web/proxy/middleware.go
+++ b/internal/server/web/proxy/middleware.go
@@ -701,7 +701,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id/files/:file_id" && c.Request.Method == http.MethodDelete {
- logRetrieveAssistantFileRequest(log, prod, cid, fid, aid)
+ logDeleteAssistantFileRequest(log, body, prod, cid, fid, aid)
}
if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id/files" && c.Request.Method == http.MethodGet {
@@ -713,7 +713,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id" && c.Request.Method == http.MethodGet {
- logCreateThreadRequest(log, body, prod, private, cid)
+ logRetrieveThreadRequest(log, prod, cid, tid)
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id" && c.Request.Method == http.MethodPost {
@@ -745,7 +745,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files" && c.Request.Method == http.MethodGet {
- logListAssistantFilesRequest(log, prod, cid, aid, qm)
+ logListMessageFilesRequest(log, body, prod, cid, tid, mid, qm)
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs" && c.Request.Method == http.MethodPost {
From 88d350556d134da21e383e37f5deccd65dacf1fb Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 28 Feb 2024 10:49:49 -0800
Subject: [PATCH 63/71] remove unused functions
---
internal/server/web/proxy/custom_provider.go | 5 -----
1 file changed, 5 deletions(-)
diff --git a/internal/server/web/proxy/custom_provider.go b/internal/server/web/proxy/custom_provider.go
index f4ec182..25b5a7b 100644
--- a/internal/server/web/proxy/custom_provider.go
+++ b/internal/server/web/proxy/custom_provider.go
@@ -18,11 +18,6 @@ import (
"go.uber.org/zap"
)
-func countTokensFromJson(bytes []byte, contentLoc string) (int, error) {
- content := getContentFromJson(bytes, contentLoc)
- return custom.Count(content)
-}
-
func getContentFromJson(bytes []byte, contentLoc string) string {
result := gjson.Get(string(bytes), contentLoc)
content := ""
From a3a6c6ce3658770d328cf6792144c43e59c17b23 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 28 Feb 2024 10:50:09 -0800
Subject: [PATCH 64/71] add logging
---
internal/server/web/proxy/anthropic.go | 2 ++
1 file changed, 2 insertions(+)
diff --git a/internal/server/web/proxy/anthropic.go b/internal/server/web/proxy/anthropic.go
index ed9a9b4..11375e7 100644
--- a/internal/server/web/proxy/anthropic.go
+++ b/internal/server/web/proxy/anthropic.go
@@ -114,6 +114,8 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km
logError(log, "error when unmarshalling anthropic http completion response body", prod, cid, err)
}
+ logCompletionResponse(log, bytes, prod, private, cid)
+
c.Set("content", completionRes.Completion)
// if err == nil {
From b23b30ba868e33b4fdfd4c358d7f65dead51b7a7 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 28 Feb 2024 10:53:54 -0800
Subject: [PATCH 65/71] add cost tracking for finetune models
---
internal/provider/openai/cost.go | 165 +++++++++++++++++++------------
1 file changed, 104 insertions(+), 61 deletions(-)
diff --git a/internal/provider/openai/cost.go b/internal/provider/openai/cost.go
index c4daf11..299c525 100644
--- a/internal/provider/openai/cost.go
+++ b/internal/provider/openai/cost.go
@@ -10,49 +10,72 @@ import (
goopenai "github.com/sashabaranov/go-openai"
)
+func useFinetuneModel(model string) string {
+ if isFinetuneModel(model) {
+ return parseFinetuneModel(model)
+ }
+
+ return model
+}
+
+func isFinetuneModel(model string) bool {
+ return strings.HasPrefix(model, "ft:")
+}
+
+func parseFinetuneModel(model string) string {
+ parts := strings.Split(model, ":")
+ if len(parts) > 2 {
+ return "finetune-" + parts[1]
+ }
+
+ return model
+}
+
var OpenAiPerThousandTokenCost = map[string]map[string]float64{
"prompt": {
- "gpt-4-1106-preview": 0.01,
- "gpt-4-turbo-preview": 0.01,
- "gpt-4-0125-preview": 0.01,
- "gpt-4-1106-vision-preview": 0.01,
- "gpt-4-vision-preview": 0.01,
- "gpt-4": 0.03,
- "gpt-4-0314": 0.03,
- "gpt-4-0613": 0.03,
- "gpt-4-32k": 0.06,
- "gpt-4-32k-0613": 0.06,
- "gpt-4-32k-0314": 0.06,
- "gpt-3.5-turbo": 0.0015,
- "gpt-3.5-turbo-1106": 0.001,
- "gpt-3.5-turbo-0125": 0.0005,
- "gpt-3.5-turbo-0301": 0.0015,
- "gpt-3.5-turbo-instruct": 0.0015,
- "gpt-3.5-turbo-0613": 0.0015,
- "gpt-3.5-turbo-16k": 0.0015,
- "gpt-3.5-turbo-16k-0613": 0.0015,
- "text-davinci-003": 0.12,
- "text-davinci-002": 0.12,
- "code-davinci-002": 0.12,
- "text-curie-001": 0.012,
- "text-babbage-001": 0.0024,
- "text-ada-001": 0.0016,
- "davinci": 0.12,
- "curie": 0.012,
- "babbage": 0.0024,
- "ada": 0.0016,
+ "gpt-4-1106-preview": 0.01,
+ "gpt-4-turbo-preview": 0.01,
+ "gpt-4-0125-preview": 0.01,
+ "gpt-4-1106-vision-preview": 0.01,
+ "gpt-4-vision-preview": 0.01,
+ "gpt-4": 0.03,
+ "gpt-4-0314": 0.03,
+ "gpt-4-0613": 0.03,
+ "gpt-4-32k": 0.06,
+ "gpt-4-32k-0613": 0.06,
+ "gpt-4-32k-0314": 0.06,
+ "gpt-3.5-turbo": 0.0015,
+ "gpt-3.5-turbo-1106": 0.001,
+ "gpt-3.5-turbo-0125": 0.0005,
+ "gpt-3.5-turbo-0301": 0.0015,
+ "gpt-3.5-turbo-instruct": 0.0015,
+ "gpt-3.5-turbo-0613": 0.0015,
+ "gpt-3.5-turbo-16k": 0.0015,
+ "gpt-3.5-turbo-16k-0613": 0.0015,
+ "text-davinci-003": 0.12,
+ "text-davinci-002": 0.12,
+ "code-davinci-002": 0.12,
+ "text-curie-001": 0.012,
+ "text-babbage-001": 0.0024,
+ "text-ada-001": 0.0016,
+ "davinci": 0.12,
+ "curie": 0.012,
+ "babbage": 0.0024,
+ "ada": 0.0016,
+ "finetune-gpt-4-0613": 0.045,
+ "finetune-gpt-3.5-turbo-0125": 0.003,
+ "finetune-gpt-3.5-turbo-1106": 0.003,
+ "finetune-gpt-3.5-turbo-0613": 0.003,
+ "finetune-babbage-002": 0.0016,
+ "finetune-davinci-002": 0.012,
},
- "fine_tune": {
- "text-davinci-003": 0.03,
- "text-davinci-002": 0.03,
- "code-davinci-002": 0.03,
- "text-curie-001": 0.03,
- "text-babbage-001": 0.0006,
- "text-ada-001": 0.0004,
- "davinci": 0.03,
- "curie": 0.03,
- "babbage": 0.0006,
- "ada": 0.0004,
+ "finetune": {
+ "gpt-4-0613": 0.09,
+ "gpt-3.5-turbo-0125": 0.008,
+ "gpt-3.5-turbo-1106": 0.008,
+ "gpt-3.5-turbo-0613": 0.008,
+ "babbage-002": 0.0004,
+ "davinci-002": 0.006,
},
"embeddings": {
"text-embedding-ada-002": 0.0001,
@@ -65,25 +88,31 @@ var OpenAiPerThousandTokenCost = map[string]map[string]float64{
"tts-1-hd": 0.03,
},
"completion": {
- "gpt-3.5-turbo-1106": 0.002,
- "gpt-4-turbo-preview": 0.03,
- "gpt-4-1106-preview": 0.03,
- "gpt-4-0125-preview": 0.03,
- "gpt-4-1106-vision-preview": 0.03,
- "gpt-4-vision-preview": 0.03,
- "gpt-4": 0.06,
- "gpt-4-0314": 0.06,
- "gpt-4-0613": 0.06,
- "gpt-4-32k": 0.12,
- "gpt-4-32k-0613": 0.12,
- "gpt-4-32k-0314": 0.12,
- "gpt-3.5-turbo": 0.002,
- "gpt-3.5-turbo-0125": 0.0015,
- "gpt-3.5-turbo-0301": 0.002,
- "gpt-3.5-turbo-0613": 0.002,
- "gpt-3.5-turbo-instruct": 0.002,
- "gpt-3.5-turbo-16k": 0.004,
- "gpt-3.5-turbo-16k-0613": 0.004,
+ "gpt-3.5-turbo-1106": 0.002,
+ "gpt-4-turbo-preview": 0.03,
+ "gpt-4-1106-preview": 0.03,
+ "gpt-4-0125-preview": 0.03,
+ "gpt-4-1106-vision-preview": 0.03,
+ "gpt-4-vision-preview": 0.03,
+ "gpt-4": 0.06,
+ "gpt-4-0314": 0.06,
+ "gpt-4-0613": 0.06,
+ "gpt-4-32k": 0.12,
+ "gpt-4-32k-0613": 0.12,
+ "gpt-4-32k-0314": 0.12,
+ "gpt-3.5-turbo": 0.002,
+ "gpt-3.5-turbo-0125": 0.0015,
+ "gpt-3.5-turbo-0301": 0.002,
+ "gpt-3.5-turbo-0613": 0.002,
+ "gpt-3.5-turbo-instruct": 0.002,
+ "gpt-3.5-turbo-16k": 0.004,
+ "gpt-3.5-turbo-16k-0613": 0.004,
+ "finetune-gpt-4-0613": 0.09,
+ "finetune-gpt-3.5-turbo-0125": 0.006,
+ "finetune-gpt-3.5-turbo-1106": 0.006,
+ "finetune-gpt-3.5-turbo-0613": 0.006,
+ "finetune-babbage-002": 0.0016,
+ "finetune-davinci-002": 0.012,
},
}
@@ -124,7 +153,7 @@ func (ce *CostEstimator) EstimatePromptCost(model string, tks int) (float64, err
}
- cost, ok := costMap[model]
+ cost, ok := costMap[useFinetuneModel(model)]
if !ok {
return 0, fmt.Errorf("%s is not present in the cost map provided", model)
}
@@ -155,7 +184,7 @@ func (ce *CostEstimator) EstimateCompletionCost(model string, tks int) (float64,
return 0, errors.New("prompt token cost is not provided")
}
- cost, ok := costMap[model]
+ cost, ok := costMap[useFinetuneModel(model)]
if !ok {
return 0, errors.New("model is not present in the cost map provided")
}
@@ -237,6 +266,20 @@ func (ce *CostEstimator) EstimateSpeechCost(input string, model string) (float64
return float64(len(input)) / 1000 * cost, nil
}
+func (ce *CostEstimator) EstimateFinetuningCost(num int, model string) (float64, error) {
+ costMap, ok := ce.tokenCostMap["finetune"]
+ if !ok {
+ return 0, errors.New("finetune cost map is not provided")
+ }
+
+ cost, ok := costMap[model]
+ if !ok {
+ return 0, errors.New("model is not present in the finetune cost map")
+ }
+
+ return cost * float64(num), nil
+}
+
func (ce *CostEstimator) EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (float64, error) {
if len(string(r.Model)) == 0 {
return 0, errors.New("model is not provided")
From 358b6f701d9e6608a12d7bc45d089db00c7ee4f3 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 28 Feb 2024 10:55:36 -0800
Subject: [PATCH 66/71] update goopenai package
---
go.mod | 2 +-
go.sum | 2 ++
2 files changed, 3 insertions(+), 1 deletion(-)
diff --git a/go.mod b/go.mod
index 8f06052..ac6d969 100644
--- a/go.mod
+++ b/go.mod
@@ -12,7 +12,7 @@ require (
github.com/mattn/go-colorable v0.1.13
github.com/pkoukk/tiktoken-go-loader v0.0.1
github.com/redis/go-redis/v9 v9.0.5
- github.com/sashabaranov/go-openai v1.19.2
+ github.com/sashabaranov/go-openai v1.20.1
github.com/stretchr/testify v1.8.4
go.uber.org/zap v1.24.0
)
diff --git a/go.sum b/go.sum
index ff47f27..49b7bf3 100644
--- a/go.sum
+++ b/go.sum
@@ -94,6 +94,8 @@ github.com/sashabaranov/go-openai v1.17.7 h1:MPcAwlwbeo7ZmhQczoOgZBHtIBY1TfZqsdx
github.com/sashabaranov/go-openai v1.17.7/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.19.2 h1:+dkuCADSnwXV02YVJkdphY8XD9AyHLUWwk6V7LB6EL8=
github.com/sashabaranov/go-openai v1.19.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
+github.com/sashabaranov/go-openai v1.20.1 h1:cFnTixAtc0I0cCBFr8gkvEbGCm6Rjf2JyoVWCjXwy9g=
+github.com/sashabaranov/go-openai v1.20.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
From fbf4e180884ea9680cc2c9bb2d3a5446cb6282c1 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 28 Feb 2024 11:14:51 -0800
Subject: [PATCH 67/71] remove unused variables
---
internal/server/web/proxy/anthropic.go | 2 +-
internal/server/web/proxy/assistant.go | 4 +-
internal/server/web/proxy/assistant_file.go | 2 +-
internal/server/web/proxy/audio.go | 7 +-
.../server/web/proxy/azure_chat_completion.go | 2 +-
internal/server/web/proxy/azure_embedding.go | 2 +-
internal/server/web/proxy/custom_provider.go | 2 +-
internal/server/web/proxy/file.go | 4 +-
internal/server/web/proxy/image.go | 2 +-
internal/server/web/proxy/message.go | 2 +-
internal/server/web/proxy/message_file.go | 2 +-
internal/server/web/proxy/middleware.go | 34 +++---
internal/server/web/proxy/models.go | 4 +-
internal/server/web/proxy/proxy.go | 110 +++++++++---------
internal/server/web/proxy/route.go | 6 +-
internal/server/web/proxy/run.go | 25 ++--
16 files changed, 107 insertions(+), 103 deletions(-)
diff --git a/internal/server/web/proxy/anthropic.go b/internal/server/web/proxy/anthropic.go
index 11375e7..0123275 100644
--- a/internal/server/web/proxy/anthropic.go
+++ b/internal/server/web/proxy/anthropic.go
@@ -35,7 +35,7 @@ func copyHttpHeaders(source *http.Request, dest *http.Request) {
dest.Header.Set("Accept-Encoding", "*")
}
-func getCompletionHandler(r recorder, prod, private bool, client http.Client, kms keyMemStorage, log *zap.Logger, e anthropicEstimator, timeOut time.Duration) gin.HandlerFunc {
+func getCompletionHandler(prod, private bool, client http.Client, log *zap.Logger, timeOut time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
stats.Incr("bricksllm.proxy.get_completion_handler.requests", nil, 1)
diff --git a/internal/server/web/proxy/assistant.go b/internal/server/web/proxy/assistant.go
index 0b20368..1a7b61f 100644
--- a/internal/server/web/proxy/assistant.go
+++ b/internal/server/web/proxy/assistant.go
@@ -63,7 +63,7 @@ func logAssistantResponse(log *zap.Logger, data []byte, prod, private bool, cid
}
}
-func logRetrieveAssistantRequest(log *zap.Logger, data []byte, prod bool, cid, assistantId string) {
+func logRetrieveAssistantRequest(log *zap.Logger, prod bool, cid, assistantId string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
@@ -102,7 +102,7 @@ func logModifyAssistantRequest(log *zap.Logger, data []byte, prod, private bool,
}
}
-func logDeleteAssistantRequest(log *zap.Logger, data []byte, prod bool, cid, assistantId string) {
+func logDeleteAssistantRequest(log *zap.Logger, prod bool, cid, assistantId string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
diff --git a/internal/server/web/proxy/assistant_file.go b/internal/server/web/proxy/assistant_file.go
index 82a45cf..32dca5e 100644
--- a/internal/server/web/proxy/assistant_file.go
+++ b/internal/server/web/proxy/assistant_file.go
@@ -60,7 +60,7 @@ func logRetrieveAssistantFileRequest(log *zap.Logger, prod bool, cid, fid, aid s
}
}
-func logDeleteAssistantFileRequest(log *zap.Logger, data []byte, prod bool, cid, fid, aid string) {
+func logDeleteAssistantFileRequest(log *zap.Logger, prod bool, cid, fid, aid string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
diff --git a/internal/server/web/proxy/audio.go b/internal/server/web/proxy/audio.go
index ecff42b..d513829 100644
--- a/internal/server/web/proxy/audio.go
+++ b/internal/server/web/proxy/audio.go
@@ -19,7 +19,7 @@ import (
"go.uber.org/zap/zapcore"
)
-func getSpeechHandler(r recorder, prod, private bool, client http.Client, log *zap.Logger, timeOut time.Duration) gin.HandlerFunc {
+func getSpeechHandler(prod bool, client http.Client, log *zap.Logger, timeOut time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
stats.Incr("bricksllm.proxy.get_speech_handler.requests", nil, 1)
@@ -167,7 +167,7 @@ func getContentType(format string) string {
return "text/plain; charset=utf-8"
}
-func getTranscriptionsHandler(r recorder, prod, private bool, client http.Client, log *zap.Logger, timeOut time.Duration, e estimator) gin.HandlerFunc {
+func getTranscriptionsHandler(prod bool, client http.Client, log *zap.Logger, timeOut time.Duration, e estimator) gin.HandlerFunc {
return func(c *gin.Context) {
stats.Incr("bricksllm.proxy.get_transcriptions_handler.requests", nil, 1)
@@ -327,7 +327,7 @@ func getTranscriptionsHandler(r recorder, prod, private bool, client http.Client
}
}
-func getTranslationsHandler(r recorder, prod, private bool, client http.Client, log *zap.Logger, timeOut time.Duration, e estimator) gin.HandlerFunc {
+func getTranslationsHandler(prod bool, client http.Client, log *zap.Logger, timeOut time.Duration, e estimator) gin.HandlerFunc {
return func(c *gin.Context) {
stats.Incr("bricksllm.proxy.get_translations_handler.requests", nil, 1)
@@ -552,6 +552,7 @@ func logCreateTranslationRequest(log *zap.Logger, model, prompt, responseFormat
fields := []zapcore.Field{
zap.String(correlationId, cid),
zap.String("model", model),
+ zap.Float64("temperature", temperature),
}
if !private && len(prompt) == 0 {
diff --git a/internal/server/web/proxy/azure_chat_completion.go b/internal/server/web/proxy/azure_chat_completion.go
index bb1c478..b5de639 100644
--- a/internal/server/web/proxy/azure_chat_completion.go
+++ b/internal/server/web/proxy/azure_chat_completion.go
@@ -25,7 +25,7 @@ func buildAzureUrl(path, deploymentId, apiVersion, resourceName string) string {
return fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/embeddings?api-version=%s", resourceName, deploymentId, apiVersion)
}
-func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettingsManager, client http.Client, kms keyMemStorage, log *zap.Logger, aoe azureEstimator, timeOut time.Duration) gin.HandlerFunc {
+func getAzureChatCompletionHandler(prod, private bool, client http.Client, log *zap.Logger, aoe azureEstimator, timeOut time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.requests", nil, 1)
diff --git a/internal/server/web/proxy/azure_embedding.go b/internal/server/web/proxy/azure_embedding.go
index 21150b4..ba963e5 100644
--- a/internal/server/web/proxy/azure_embedding.go
+++ b/internal/server/web/proxy/azure_embedding.go
@@ -13,7 +13,7 @@ import (
"go.uber.org/zap"
)
-func getAzureEmbeddingsHandler(r recorder, prod, private bool, psm ProviderSettingsManager, client http.Client, kms keyMemStorage, log *zap.Logger, aoe azureEstimator, timeOut time.Duration) gin.HandlerFunc {
+func getAzureEmbeddingsHandler(prod, private bool, client http.Client, log *zap.Logger, aoe azureEstimator, timeOut time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
stats.Incr("bricksllm.proxy.get_azure_embeddings_handler.requests", nil, 1)
if c == nil || c.Request == nil {
diff --git a/internal/server/web/proxy/custom_provider.go b/internal/server/web/proxy/custom_provider.go
index 25b5a7b..ee09c25 100644
--- a/internal/server/web/proxy/custom_provider.go
+++ b/internal/server/web/proxy/custom_provider.go
@@ -46,7 +46,7 @@ type ErrorResponse struct {
Error *Error `json:"error"`
}
-func getCustomProviderHandler(prod, private bool, psm ProviderSettingsManager, cpm CustomProvidersManager, client http.Client, log *zap.Logger, timeOut time.Duration) gin.HandlerFunc {
+func getCustomProviderHandler(prod bool, client http.Client, log *zap.Logger, timeOut time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
tags := []string{
fmt.Sprintf("path:%s", c.FullPath()),
diff --git a/internal/server/web/proxy/file.go b/internal/server/web/proxy/file.go
index cd2667d..55a1377 100644
--- a/internal/server/web/proxy/file.go
+++ b/internal/server/web/proxy/file.go
@@ -108,7 +108,7 @@ func logDeleteFileResponse(log *zap.Logger, data []byte, prod bool, cid string)
}
}
-func logRetrieveFileContentRequest(log *zap.Logger, data []byte, prod bool, cid, fid string) {
+func logRetrieveFileContentRequest(log *zap.Logger, prod bool, cid, fid string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
@@ -119,7 +119,7 @@ func logRetrieveFileContentRequest(log *zap.Logger, data []byte, prod bool, cid,
}
}
-func logRetrieveFileContentResponse(log *zap.Logger, data []byte, prod bool, cid string) {
+func logRetrieveFileContentResponse(log *zap.Logger, prod bool, cid string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
diff --git a/internal/server/web/proxy/image.go b/internal/server/web/proxy/image.go
index edba584..7b8461d 100644
--- a/internal/server/web/proxy/image.go
+++ b/internal/server/web/proxy/image.go
@@ -63,7 +63,7 @@ func logEditImageRequest(log *zap.Logger, prompt, model string, n int, size, res
}
}
-func logImageVariationsRequest(log *zap.Logger, model string, n int, size, responseFormat, user string, prod, private bool, cid string) {
+func logImageVariationsRequest(log *zap.Logger, model string, n int, size, responseFormat, user string, prod bool, cid string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
diff --git a/internal/server/web/proxy/message.go b/internal/server/web/proxy/message.go
index 737eed3..3f44113 100644
--- a/internal/server/web/proxy/message.go
+++ b/internal/server/web/proxy/message.go
@@ -100,7 +100,7 @@ func logModifyMessageRequest(log *zap.Logger, data []byte, prod, private bool, c
}
}
-func logListMessagesRequest(log *zap.Logger, data []byte, prod bool, cid, tid string) {
+func logListMessagesRequest(log *zap.Logger, prod bool, cid, tid string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
diff --git a/internal/server/web/proxy/message_file.go b/internal/server/web/proxy/message_file.go
index 23518e6..8e85b7c 100644
--- a/internal/server/web/proxy/message_file.go
+++ b/internal/server/web/proxy/message_file.go
@@ -42,7 +42,7 @@ func logRetrieveMessageFileResponse(log *zap.Logger, data []byte, prod bool, cid
}
}
-func logListMessageFilesRequest(log *zap.Logger, data []byte, prod bool, cid, tid, mid string, params map[string]string) {
+func logListMessageFilesRequest(log *zap.Logger, prod bool, cid, tid, mid string, params map[string]string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
diff --git a/internal/server/web/proxy/middleware.go b/internal/server/web/proxy/middleware.go
index 2a40314..37585c3 100644
--- a/internal/server/web/proxy/middleware.go
+++ b/internal/server/web/proxy/middleware.go
@@ -129,7 +129,7 @@ func (w responseWriter) Write(b []byte) (int, error) {
return w.ResponseWriter.Write(b)
}
-func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManager, a authenticator, prod, private bool, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, ks keyStorage, log *zap.Logger, rlm rateLimitManager, pub publisher, prefix string, ac accessCache) gin.HandlerFunc {
+func getMiddleware(cpm CustomProvidersManager, rm routeManager, a authenticator, prod, private bool, log *zap.Logger, pub publisher, prefix string, ac accessCache) gin.HandlerFunc {
return func(c *gin.Context) {
if c == nil || c.Request == nil {
JSON(c, http.StatusInternalServerError, "[BricksLLM] request is empty")
@@ -589,7 +589,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
c.Set("model", "dall-e-2")
}
- logImageVariationsRequest(log, model, n, size, responseFormat, user, prod, private, cid)
+ logImageVariationsRequest(log, model, n, size, responseFormat, user, prod, cid)
}
if c.FullPath() == "/api/providers/openai/v1/audio/speech" && c.Request.Method == http.MethodPost {
@@ -677,7 +677,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id" && c.Request.Method == http.MethodGet {
- logRetrieveAssistantRequest(log, body, prod, cid, aid)
+ logRetrieveAssistantRequest(log, prod, cid, aid)
}
if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id" && c.Request.Method == http.MethodPost {
@@ -685,7 +685,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id" && c.Request.Method == http.MethodDelete {
- logDeleteAssistantRequest(log, body, prod, cid, aid)
+ logDeleteAssistantRequest(log, prod, cid, aid)
}
if c.FullPath() == "/api/providers/openai/v1/assistants" && c.Request.Method == http.MethodGet {
@@ -701,7 +701,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id/files/:file_id" && c.Request.Method == http.MethodDelete {
- logDeleteAssistantFileRequest(log, body, prod, cid, fid, aid)
+ logDeleteAssistantFileRequest(log, prod, cid, fid, aid)
}
if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id/files" && c.Request.Method == http.MethodGet {
@@ -737,7 +737,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages" && c.Request.Method == http.MethodGet {
- logListMessagesRequest(log, body, prod, cid, aid)
+ logListMessagesRequest(log, prod, cid, aid)
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files/:file_id" && c.Request.Method == http.MethodGet {
@@ -745,15 +745,15 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files" && c.Request.Method == http.MethodGet {
- logListMessageFilesRequest(log, body, prod, cid, tid, mid, qm)
+ logListMessageFilesRequest(log, prod, cid, tid, mid, qm)
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs" && c.Request.Method == http.MethodPost {
- logCreateRunRequest(log, body, prod, cid)
+ logCreateRunRequest(log, body, prod, private, cid)
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id" && c.Request.Method == http.MethodGet {
- logRetrieveRunRequest(log, body, prod, cid, tid, rid)
+ logRetrieveRunRequest(log, prod, cid, tid, rid)
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id" && c.Request.Method == http.MethodPost {
@@ -761,7 +761,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs" && c.Request.Method == http.MethodGet {
- logListRunsRequest(log, body, prod, cid, tid, qm)
+ logListRunsRequest(log, prod, cid, tid, qm)
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/submit_tool_outputs" && c.Request.Method == http.MethodPost {
@@ -769,19 +769,19 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/cancel" && c.Request.Method == http.MethodPost {
- logCancelARunRequest(log, body, prod, cid, tid, rid)
+ logCancelARunRequest(log, prod, cid, tid, rid)
}
if c.FullPath() == "/api/providers/openai/v1/threads/runs" && c.Request.Method == http.MethodPost {
- logCreateThreadAndRunRequest(log, body, prod, private, cid, tid, rid)
+ logCreateThreadAndRunRequest(log, body, prod, private, cid)
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps/:step_id" && c.Request.Method == http.MethodGet {
- logRetrieveRunStepRequest(log, body, prod, cid, tid, rid, sid)
+ logRetrieveRunStepRequest(log, prod, cid, tid, rid, sid)
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps" && c.Request.Method == http.MethodGet {
- logListRunStepsRequest(log, body, prod, cid, tid, rid, qm)
+ logListRunStepsRequest(log, prod, cid, tid, rid, qm)
}
if c.FullPath() == "/api/providers/openai/v1/moderations" && c.Request.Method == http.MethodPost {
@@ -793,11 +793,11 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/models/:model" && c.Request.Method == http.MethodGet {
- logRetrieveModelRequest(log, body, prod, cid, md)
+ logRetrieveModelRequest(log, prod, cid, md)
}
if c.FullPath() == "/api/providers/openai/v1/models/:model" && c.Request.Method == http.MethodDelete {
- logDeleteModelRequest(log, body, prod, cid, md)
+ logDeleteModelRequest(log, prod, cid, md)
}
if c.FullPath() == "/api/providers/openai/v1/files" && c.Request.Method == http.MethodGet {
@@ -818,7 +818,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/files/:file_id/content" && c.Request.Method == http.MethodGet {
- logRetrieveFileContentRequest(log, body, prod, cid, fid)
+ logRetrieveFileContentRequest(log, prod, cid, fid)
}
if ac.GetAccessStatus(kc.KeyId) {
diff --git a/internal/server/web/proxy/models.go b/internal/server/web/proxy/models.go
index 640a068..ba6db48 100644
--- a/internal/server/web/proxy/models.go
+++ b/internal/server/web/proxy/models.go
@@ -26,7 +26,7 @@ func logListModelsResponse(log *zap.Logger, data []byte, prod bool, cid string)
}
}
-func logRetrieveModelRequest(log *zap.Logger, data []byte, prod bool, cid, model string) {
+func logRetrieveModelRequest(log *zap.Logger, prod bool, cid, model string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
@@ -58,7 +58,7 @@ func logRetrieveModelResponse(log *zap.Logger, data []byte, prod bool, cid strin
}
}
-func logDeleteModelRequest(log *zap.Logger, data []byte, prod bool, cid, model string) {
+func logDeleteModelRequest(log *zap.Logger, prod bool, cid, model string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
diff --git a/internal/server/web/proxy/proxy.go b/internal/server/web/proxy/proxy.go
index ff43b52..450795b 100644
--- a/internal/server/web/proxy/proxy.go
+++ b/internal/server/web/proxy/proxy.go
@@ -80,7 +80,7 @@ func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyMan
private := privacyMode == "strict"
router.Use(CorsMiddleware())
- router.Use(getMiddleware(kms, cpm, rm, a, prod, private, e, ae, aoe, v, ks, log, rlm, pub, "proxy", ac))
+ router.Use(getMiddleware(cpm, rm, a, prod, private, log, pub, "proxy", ac))
client := http.Client{}
@@ -88,88 +88,88 @@ func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyMan
router.POST("/api/health", getGetHealthCheckHandler())
// audios
- router.POST("/api/providers/openai/v1/audio/speech", getSpeechHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/audio/transcriptions", getTranscriptionsHandler(r, prod, private, client, log, timeOut, e))
- router.POST("/api/providers/openai/v1/audio/translations", getTranslationsHandler(r, prod, private, client, log, timeOut, e))
+ router.POST("/api/providers/openai/v1/audio/speech", getSpeechHandler(prod, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/audio/transcriptions", getTranscriptionsHandler(prod, client, log, timeOut, e))
+ router.POST("/api/providers/openai/v1/audio/translations", getTranslationsHandler(prod, client, log, timeOut, e))
// completions
- router.POST("/api/providers/openai/v1/chat/completions", getChatCompletionHandler(r, prod, private, psm, client, kms, log, e, timeOut))
+ router.POST("/api/providers/openai/v1/chat/completions", getChatCompletionHandler(prod, private, client, log, e, timeOut))
// embeddings
- router.POST("/api/providers/openai/v1/embeddings", getEmbeddingHandler(r, prod, private, psm, client, kms, log, e, timeOut))
+ router.POST("/api/providers/openai/v1/embeddings", getEmbeddingHandler(prod, private, client, log, e, timeOut))
// moderations
- router.POST("/api/providers/openai/v1/moderations", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/moderations", getPassThroughHandler(prod, private, client, log, timeOut))
// models
- router.GET("/api/providers/openai/v1/models", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/models/:model", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.DELETE("/api/providers/openai/v1/models/:model", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/models", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/models/:model", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.DELETE("/api/providers/openai/v1/models/:model", getPassThroughHandler(prod, private, client, log, timeOut))
// assistants
- router.POST("/api/providers/openai/v1/assistants", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/assistants/:assistant_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/assistants/:assistant_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.DELETE("/api/providers/openai/v1/assistants/:assistant_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/assistants", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/assistants", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/assistants/:assistant_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/assistants/:assistant_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.DELETE("/api/providers/openai/v1/assistants/:assistant_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/assistants", getPassThroughHandler(prod, private, client, log, timeOut))
// assistant files
- router.POST("/api/providers/openai/v1/assistants/:assistant_id/files", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/assistants/:assistant_id/files/:file_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.DELETE("/api/providers/openai/v1/assistants/:assistant_id/files/:file_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/assistants/:assistant_id/files", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/assistants/:assistant_id/files", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/assistants/:assistant_id/files/:file_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.DELETE("/api/providers/openai/v1/assistants/:assistant_id/files/:file_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/assistants/:assistant_id/files", getPassThroughHandler(prod, private, client, log, timeOut))
// threads
- router.POST("/api/providers/openai/v1/threads", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/threads/:thread_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/threads/:thread_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.DELETE("/api/providers/openai/v1/threads/:thread_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/threads", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/threads/:thread_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/threads/:thread_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.DELETE("/api/providers/openai/v1/threads/:thread_id", getPassThroughHandler(prod, private, client, log, timeOut))
// messages
- router.POST("/api/providers/openai/v1/threads/:thread_id/messages", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/threads/:thread_id/messages/:message_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/threads/:thread_id/messages/:message_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/threads/:thread_id/messages", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/threads/:thread_id/messages", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/threads/:thread_id/messages/:message_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/threads/:thread_id/messages/:message_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/threads/:thread_id/messages", getPassThroughHandler(prod, private, client, log, timeOut))
// message files
- router.GET("/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files/:file_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files/:file_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files", getPassThroughHandler(prod, private, client, log, timeOut))
// runs
- router.POST("/api/providers/openai/v1/threads/:thread_id/runs", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/threads/:thread_id/runs/:run_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/threads/:thread_id/runs/:run_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/threads/:thread_id/runs", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/threads/:thread_id/runs/:run_id/submit_tool_outputs", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/threads/:thread_id/runs/:run_id/cancel", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/threads/runs", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps/:step_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/threads/:thread_id/runs", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/threads/:thread_id/runs/:run_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/threads/:thread_id/runs/:run_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/threads/:thread_id/runs", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/threads/:thread_id/runs/:run_id/submit_tool_outputs", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/threads/:thread_id/runs/:run_id/cancel", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/threads/runs", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps/:step_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps", getPassThroughHandler(prod, private, client, log, timeOut))
// files
- router.GET("/api/providers/openai/v1/files", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/files", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.DELETE("/api/providers/openai/v1/files/:file_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/files/:file_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/files/:file_id/content", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/files", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/files", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.DELETE("/api/providers/openai/v1/files/:file_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/files/:file_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/files/:file_id/content", getPassThroughHandler(prod, private, client, log, timeOut))
// images
- router.POST("/api/providers/openai/v1/images/generations", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/images/edits", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/images/variations", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/images/generations", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/images/edits", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/images/variations", getPassThroughHandler(prod, private, client, log, timeOut))
// azure
- router.POST("/api/providers/azure/openai/deployments/:deployment_id/chat/completions", getAzureChatCompletionHandler(r, prod, private, psm, client, kms, log, aoe, timeOut))
- router.POST("/api/providers/azure/openai/deployments/:deployment_id/embeddings", getAzureEmbeddingsHandler(r, prod, private, psm, client, kms, log, aoe, timeOut))
+ router.POST("/api/providers/azure/openai/deployments/:deployment_id/chat/completions", getAzureChatCompletionHandler(prod, private, client, log, aoe, timeOut))
+ router.POST("/api/providers/azure/openai/deployments/:deployment_id/embeddings", getAzureEmbeddingsHandler(prod, private, client, log, aoe, timeOut))
// anthropic
- router.POST("/api/providers/anthropic/v1/complete", getCompletionHandler(r, prod, private, client, kms, log, ae, timeOut))
+ router.POST("/api/providers/anthropic/v1/complete", getCompletionHandler(prod, private, client, log, timeOut))
// custom provider
- router.POST("/api/custom/providers/:provider/*wildcard", getCustomProviderHandler(prod, private, psm, cpm, client, log, timeOut))
+ router.POST("/api/custom/providers/:provider/*wildcard", getCustomProviderHandler(prod, client, log, timeOut))
// custom route
- router.POST("/api/routes/*route", getRouteHandler(prod, private, rm, c, aoe, e, r, client, log, timeOut))
+ router.POST("/api/routes/*route", getRouteHandler(prod, c, aoe, e, client, log))
srv := &http.Server{
Addr: ":8002",
@@ -230,7 +230,7 @@ func writeFieldToBuffer(fields []string, c *gin.Context, writer *multipart.Write
return nil
}
-func getPassThroughHandler(r recorder, prod, private bool, client http.Client, log *zap.Logger, timeOut time.Duration) gin.HandlerFunc {
+func getPassThroughHandler(prod, private bool, client http.Client, log *zap.Logger, timeOut time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
tags := []string{
fmt.Sprintf("path:%s", c.FullPath()),
@@ -618,7 +618,7 @@ func getPassThroughHandler(r recorder, prod, private bool, client http.Client, l
}
if c.FullPath() == "/api/providers/openai/v1/files/:file_id/content" && c.Request.Method == http.MethodGet {
- logRetrieveFileContentResponse(log, bytes, prod, cid)
+ logRetrieveFileContentResponse(log, prod, cid)
}
if c.FullPath() == "/api/providers/openai/v1/images/generations" && c.Request.Method == http.MethodPost {
@@ -854,7 +854,7 @@ type EmbeddingResponseBase64 struct {
Usage goopenai.Usage `json:"usage"`
}
-func getEmbeddingHandler(r recorder, prod, private bool, psm ProviderSettingsManager, client http.Client, kms keyMemStorage, log *zap.Logger, e estimator, timeOut time.Duration) gin.HandlerFunc {
+func getEmbeddingHandler(prod, private bool, client http.Client, log *zap.Logger, e estimator, timeOut time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
stats.Incr("bricksllm.proxy.get_embedding_handler.requests", nil, 1)
if c == nil || c.Request == nil {
@@ -994,7 +994,7 @@ var (
eventErrorPrefix = []byte("event: error")
)
-func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettingsManager, client http.Client, kms keyMemStorage, log *zap.Logger, e estimator, timeOut time.Duration) gin.HandlerFunc {
+func getChatCompletionHandler(prod, private bool, client http.Client, log *zap.Logger, e estimator, timeOut time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
stats.Incr("bricksllm.proxy.get_chat_completion_handler.requests", nil, 1)
diff --git a/internal/server/web/proxy/route.go b/internal/server/web/proxy/route.go
index 94e546d..2b56741 100644
--- a/internal/server/web/proxy/route.go
+++ b/internal/server/web/proxy/route.go
@@ -27,7 +27,7 @@ type cache interface {
GetBytes(key string) ([]byte, error)
}
-func getRouteHandler(prod, private bool, rm routeManager, ca cache, aoe azureEstimator, e estimator, r recorder, client http.Client, log *zap.Logger, timeOut time.Duration) gin.HandlerFunc {
+func getRouteHandler(prod bool, ca cache, aoe azureEstimator, e estimator, client http.Client, log *zap.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
trueStart := time.Now()
@@ -144,7 +144,7 @@ func getRouteHandler(prod, private bool, rm routeManager, ca cache, aoe azureEst
}
- err := parseResult(c, ca, kc, rc.ShouldRunEmbeddings(), bytes, e, aoe, r, runRes.Model, runRes.Provider)
+ err := parseResult(c, rc.ShouldRunEmbeddings(), bytes, e, aoe, runRes.Model, runRes.Provider)
if err != nil {
logError(log, "error when parsing run steps result", prod, cid, err)
}
@@ -173,7 +173,7 @@ func getRouteHandler(prod, private bool, rm routeManager, ca cache, aoe azureEst
}
}
-func parseResult(c *gin.Context, ca cache, kc *key.ResponseKey, runEmbeddings bool, bytes []byte, e estimator, aoe azureEstimator, r recorder, model, provider string) error {
+func parseResult(c *gin.Context, runEmbeddings bool, bytes []byte, e estimator, aoe azureEstimator, model, provider string) error {
base64ChatRes := &EmbeddingResponseBase64{}
chatRes := &EmbeddingResponse{}
diff --git a/internal/server/web/proxy/run.go b/internal/server/web/proxy/run.go
index de4cfa5..a8558ff 100644
--- a/internal/server/web/proxy/run.go
+++ b/internal/server/web/proxy/run.go
@@ -8,7 +8,7 @@ import (
"go.uber.org/zap/zapcore"
)
-func logCreateRunRequest(log *zap.Logger, data []byte, prod bool, cid string) {
+func logCreateRunRequest(log *zap.Logger, data []byte, prod, private bool, cid string) {
rr := &goopenai.RunRequest{}
err := json.Unmarshal(data, rr)
if err != nil {
@@ -20,12 +20,15 @@ func logCreateRunRequest(log *zap.Logger, data []byte, prod bool, cid string) {
fields := []zapcore.Field{
zap.String(correlationId, cid),
zap.String("assistant_id", rr.AssistantID),
- zap.Stringp("instruction", rr.Instructions),
- zap.Stringp("model", rr.Model),
+ zap.String("model", rr.Model),
zap.Any("tools", rr.Tools),
zap.Any("metadata", rr.Metadata),
}
+ if !private {
+ fields = append(fields, zap.String("instruction", rr.Instructions))
+ }
+
log.Info("openai create run request", fields...)
}
}
@@ -68,7 +71,7 @@ func logRunResponse(log *zap.Logger, data []byte, prod, private bool, cid string
}
}
-func logRetrieveRunRequest(log *zap.Logger, data []byte, prod bool, cid, tid, rid string) {
+func logRetrieveRunRequest(log *zap.Logger, prod bool, cid, tid, rid string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
@@ -100,7 +103,7 @@ func logModifyRunRequest(log *zap.Logger, data []byte, prod bool, cid, tid, rid
}
}
-func logListRunsRequest(log *zap.Logger, data []byte, prod bool, cid, tid string, params map[string]string) {
+func logListRunsRequest(log *zap.Logger, prod bool, cid, tid string, params map[string]string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
@@ -171,7 +174,7 @@ func logSubmitToolOutputsRequest(log *zap.Logger, data []byte, prod bool, cid, t
}
}
-func logCancelARunRequest(log *zap.Logger, data []byte, prod bool, cid, tid, rid string) {
+func logCancelARunRequest(log *zap.Logger, prod bool, cid, tid, rid string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
@@ -183,7 +186,7 @@ func logCancelARunRequest(log *zap.Logger, data []byte, prod bool, cid, tid, rid
}
}
-func logCreateThreadAndRunRequest(log *zap.Logger, data []byte, prod, private bool, cid, tid, rid string) {
+func logCreateThreadAndRunRequest(log *zap.Logger, data []byte, prod, private bool, cid string) {
r := &goopenai.CreateThreadAndRunRequest{}
err := json.Unmarshal(data, r)
if err != nil {
@@ -196,20 +199,20 @@ func logCreateThreadAndRunRequest(log *zap.Logger, data []byte, prod, private bo
zap.String(correlationId, cid),
zap.String("assistant_id", r.AssistantID),
zap.Any("thread", r.Thread),
- zap.Stringp("model", r.Model),
+ zap.String("model", r.Model),
zap.Any("tools", r.Tools),
zap.Any("metadata", r.Metadata),
}
if !private {
- fields = append(fields, zap.Stringp("instructions", r.Instructions))
+ fields = append(fields, zap.String("instructions", r.Instructions))
}
log.Info("openai create thread and run request", fields...)
}
}
-func logRetrieveRunStepRequest(log *zap.Logger, data []byte, prod bool, cid, tid, rid, sid string) {
+func logRetrieveRunStepRequest(log *zap.Logger, prod bool, cid, tid, rid, sid string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
@@ -254,7 +257,7 @@ func logRetrieveRunStepResponse(log *zap.Logger, data []byte, prod bool, cid str
}
}
-func logListRunStepsRequest(log *zap.Logger, data []byte, prod bool, cid, tid, rid string, params map[string]string) {
+func logListRunStepsRequest(log *zap.Logger, prod bool, cid, tid, rid string, params map[string]string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
From 993dfb20d35fb9852b4a8069cd9b916109bdff5c Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 28 Feb 2024 11:27:24 -0800
Subject: [PATCH 68/71] update CHANGELOG
---
CHANGELOG.md | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index cd95c17..c913442 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,8 @@
+## 1.11.0 - 2024-02-28
+### Added
+- Added cost tracking for OpenAI audio endpoints
+- Added inference cost tracking for OpenAI finetune models
+
## 1.10.0 - 2024-02-21
### Added
- Added `userId` as a new filter option for get events API endpoint
From 5f75d863299e766f853bd9d5792f0001ba5fd864 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 28 Feb 2024 14:47:13 -0800
Subject: [PATCH 69/71] add key rotation feature
---
CHANGELOG.md | 4 ++++
internal/authenticator/authenticator.go | 7 ++++++-
internal/key/key.go | 3 +++
internal/manager/key.go | 25 -----------------------
internal/manager/provider_setting.go | 8 ++++----
internal/manager/route.go | 14 ++++++-------
internal/storage/postgresql/postgresql.go | 19 ++++++++++++++---
7 files changed, 40 insertions(+), 40 deletions(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index c913442..b443d09 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,7 @@
+## 1.12.0 - 2024-02-28
+### Added
+- Added provider setting rotation feature to keys
+
## 1.11.0 - 2024-02-28
### Added
- Added cost tracking for OpenAI audio endpoints
diff --git a/internal/authenticator/authenticator.go b/internal/authenticator/authenticator.go
index f95fca9..94aa3c9 100644
--- a/internal/authenticator/authenticator.go
+++ b/internal/authenticator/authenticator.go
@@ -3,6 +3,7 @@ package auth
import (
"errors"
"fmt"
+ "math/rand"
"net/http"
"strings"
@@ -196,8 +197,12 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request) (*key.Respons
}
if len(selected) != 0 {
- err := rewriteHttpAuthHeader(req, selected[0])
+ used := selected[0]
+ if key.RotationEnabled {
+ used = selected[rand.Intn(len(selected))]
+ }
+ err := rewriteHttpAuthHeader(req, used)
if err != nil {
return nil, nil, err
}
diff --git a/internal/key/key.go b/internal/key/key.go
index 444df2e..a8b0be5 100644
--- a/internal/key/key.go
+++ b/internal/key/key.go
@@ -27,6 +27,7 @@ type UpdateKey struct {
AllowedPaths *[]PathConfig `json:"allowedPaths,omitempty"`
ShouldLogRequest *bool `json:"shouldLogRequest"`
ShouldLogResponse *bool `json:"shouldLogResponse"`
+ RotationEnabled *bool `json:"rotationEnabled"`
}
func (uk *UpdateKey) Validate() error {
@@ -131,6 +132,7 @@ type RequestKey struct {
SettingIds []string `json:"settingIds"`
ShouldLogRequest bool `json:"shouldLogRequest"`
ShouldLogResponse bool `json:"shouldLogResponse"`
+ RotationEnabled bool `json:"rotationEnabled"`
}
func (rk *RequestKey) Validate() error {
@@ -277,6 +279,7 @@ type ResponseKey struct {
SettingIds []string `json:"settingIds"`
ShouldLogRequest bool `json:"shouldLogRequest"`
ShouldLogResponse bool `json:"shouldLogResponse"`
+ RotationEnabled bool `json:"rotationEnabled"`
}
func (rk *ResponseKey) GetSettingIds() []string {
diff --git a/internal/manager/key.go b/internal/manager/key.go
index 2427912..59292aa 100644
--- a/internal/manager/key.go
+++ b/internal/manager/key.go
@@ -8,8 +8,6 @@ import (
"github.com/bricks-cloud/bricksllm/internal/key"
"github.com/bricks-cloud/bricksllm/internal/provider"
"github.com/bricks-cloud/bricksllm/internal/util"
-
- internal_errors "github.com/bricks-cloud/bricksllm/internal/errors"
)
type Storage interface {
@@ -58,20 +56,6 @@ func (m *Manager) GetKeys(tags, keyIds []string, provider string) ([]*key.Respon
return m.s.GetKeys(tags, keyIds, provider)
}
-func (m *Manager) areProviderSettingsUniqueness(settings []*provider.Setting) bool {
- providerMap := map[string]bool{}
-
- for _, setting := range settings {
- if providerMap[setting.Provider] {
- return false
- }
-
- providerMap[setting.Provider] = true
- }
-
- return true
-}
-
func (m *Manager) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) {
rk.CreatedAt = time.Now().Unix()
rk.UpdatedAt = time.Now().Unix()
@@ -97,11 +81,6 @@ func (m *Manager) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) {
if len(existing) == 0 {
return nil, errors.New("provider settings not found")
}
-
- if !m.areProviderSettingsUniqueness(existing) {
- return nil, internal_errors.NewValidationError("key can only be assoicated with one setting per provider")
- }
-
}
return m.s.CreateKey(rk)
@@ -129,10 +108,6 @@ func (m *Manager) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, err
if len(existing) == 0 {
return nil, errors.New("provider settings not found")
}
-
- if !m.areProviderSettingsUniqueness(existing) {
- return nil, internal_errors.NewValidationError("key can only be assoicated with one setting per provider")
- }
}
if len(uk.CostLimitInUsdUnit) != 0 {
diff --git a/internal/manager/provider_setting.go b/internal/manager/provider_setting.go
index 5ba2a7c..76299b6 100644
--- a/internal/manager/provider_setting.go
+++ b/internal/manager/provider_setting.go
@@ -44,7 +44,7 @@ func findMissingAuthParams(providerName string, params map[string]string) string
missingFields := []string{}
if providerName == "openai" || providerName == "anthropic" {
- val, _ := params["apikey"]
+ val := params["apikey"]
if len(val) == 0 {
missingFields = append(missingFields, "apikey")
}
@@ -53,12 +53,12 @@ func findMissingAuthParams(providerName string, params map[string]string) string
}
if providerName == "azure" {
- val, _ := params["resourceName"]
+ val := params["resourceName"]
if len(val) == 0 {
missingFields = append(missingFields, "resourceName")
}
- val, _ = params["apikey"]
+ val = params["apikey"]
if len(val) == 0 {
missingFields = append(missingFields, "apikey")
}
@@ -76,7 +76,7 @@ func (m *ProviderSettingsManager) validateSettings(providerName string, setting
}
if len(provider.AuthenticationParam) != 0 {
- val, _ := setting[provider.AuthenticationParam]
+ val := setting[provider.AuthenticationParam]
if len(val) == 0 {
return internal_errors.NewValidationError(fmt.Sprintf("provider %s is missing value for field %s", providerName, provider.AuthenticationParam))
}
diff --git a/internal/manager/route.go b/internal/manager/route.go
index b3c8d8b..eb3c4ed 100644
--- a/internal/manager/route.go
+++ b/internal/manager/route.go
@@ -228,16 +228,16 @@ func (m *RouteManager) validateRoute(r *route.Route) error {
}
if !contains(step.Provider, supportedProviders) {
- return errors.New(fmt.Sprintf("steps.[%d].provider is not supported. Only azure and openai are supported", index))
+ return fmt.Errorf("steps.[%d].provider is not supported. Only azure and openai are supported", index)
}
if step.Provider == "azure" {
- apiVersion, _ := step.Params["apiVersion"]
+ apiVersion := step.Params["apiVersion"]
if len(apiVersion) == 0 {
fields = append(fields, fmt.Sprintf("steps.[%d].params.apiVersion", index))
}
- deploymentId, _ := step.Params["deploymentId"]
+ deploymentId := step.Params["deploymentId"]
if len(deploymentId) == 0 {
fields = append(fields, fmt.Sprintf("steps.[%d].params.deploymentId", index))
}
@@ -248,11 +248,11 @@ func (m *RouteManager) validateRoute(r *route.Route) error {
}
if !contains(step.Model, supportedModels) {
- return errors.New(fmt.Sprintf("steps.[%d].model is not supported. Only chat completion and embeddings model are supported.", index))
+ return fmt.Errorf("steps.[%d].model is not supported. Only chat completion and embeddings model are supported", index)
}
if !checkModelValidity(step.Provider, step.Model) {
- return errors.New(fmt.Sprintf("model: %s is not supported for provider: %s.", step.Model, step.Provider))
+ return fmt.Errorf("model: %s is not supported for provider: %s", step.Model, step.Provider)
}
if !containAda && contains(step.Model, adaModels) {
@@ -262,11 +262,11 @@ func (m *RouteManager) validateRoute(r *route.Route) error {
for _, step := range r.Steps {
if containAda && !contains(step.Model, adaModels) {
- return errors.New("steps must have congruent models. Chat completion and embedding models cannot be in the same route config.")
+ return errors.New("steps must have congruent models. Chat completion and embedding models cannot be in the same route config")
}
if !containAda && !contains(step.Model, chatCompletionModels) {
- return errors.New("steps must have congruent models. Chat completion and embedding models cannot be in the same route config.")
+ return errors.New("steps must have congruent models. Chat completion and embedding models cannot be in the same route config")
}
}
diff --git a/internal/storage/postgresql/postgresql.go b/internal/storage/postgresql/postgresql.go
index 817c446..f54a4a0 100644
--- a/internal/storage/postgresql/postgresql.go
+++ b/internal/storage/postgresql/postgresql.go
@@ -101,7 +101,7 @@ func (s *Store) AlterKeysTable() error {
END IF;
END
$$;
- ALTER TABLE keys ADD COLUMN IF NOT EXISTS setting_id VARCHAR(255), ADD COLUMN IF NOT EXISTS allowed_paths JSONB, ADD COLUMN IF NOT EXISTS setting_ids VARCHAR(255)[] NOT NULL DEFAULT ARRAY[]::VARCHAR(255)[], ADD COLUMN IF NOT EXISTS should_log_request BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS should_log_response BOOLEAN NOT NULL DEFAULT FALSE;
+ ALTER TABLE keys ADD COLUMN IF NOT EXISTS setting_id VARCHAR(255), ADD COLUMN IF NOT EXISTS allowed_paths JSONB, ADD COLUMN IF NOT EXISTS setting_ids VARCHAR(255)[] NOT NULL DEFAULT ARRAY[]::VARCHAR(255)[], ADD COLUMN IF NOT EXISTS should_log_request BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS should_log_response BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS rotation_enabled BOOLEAN NOT NULL DEFAULT FALSE;
`
ctxTimeout, cancel := context.WithTimeout(context.Background(), s.wt)
@@ -632,6 +632,7 @@ func (s *Store) GetKeys(tags, keyIds []string, provider string) ([]*key.Response
pq.Array(&k.SettingIds),
&k.ShouldLogRequest,
&k.ShouldLogResponse,
+ &k.RotationEnabled,
); err != nil {
return nil, err
}
@@ -690,6 +691,7 @@ func (s *Store) GetKey(keyId string) (*key.ResponseKey, error) {
pq.Array(&k.SettingIds),
&k.ShouldLogRequest,
&k.ShouldLogResponse,
+ &k.RotationEnabled,
); err != nil {
return nil, err
}
@@ -833,6 +835,7 @@ func (s *Store) GetAllKeys() ([]*key.ResponseKey, error) {
pq.Array(&k.SettingIds),
&k.ShouldLogRequest,
&k.ShouldLogResponse,
+ &k.RotationEnabled,
); err != nil {
return nil, err
}
@@ -929,6 +932,7 @@ func (s *Store) GetUpdatedKeys(updatedAt int64) ([]*key.ResponseKey, error) {
pq.Array(&k.SettingIds),
&k.ShouldLogRequest,
&k.ShouldLogResponse,
+ &k.RotationEnabled,
); err != nil {
return nil, err
}
@@ -1088,6 +1092,12 @@ func (s *Store) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error
counter++
}
+ if uk.RotationEnabled != nil {
+ values = append(values, *uk.RotationEnabled)
+ fields = append(fields, fmt.Sprintf("rotation_enabled = $%d", counter))
+ counter++
+ }
+
if uk.AllowedPaths != nil {
data, err := json.Marshal(uk.AllowedPaths)
if err != nil {
@@ -1126,6 +1136,7 @@ func (s *Store) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error
pq.Array(&k.SettingIds),
&k.ShouldLogRequest,
&k.ShouldLogResponse,
+ &k.RotationEnabled,
); err != nil {
if err == sql.ErrNoRows {
return nil, internal_errors.NewNotFoundError(fmt.Sprintf("key not found for id: %s", id))
@@ -1265,8 +1276,8 @@ func (s *Store) CreateProviderSetting(setting *provider.Setting) (*provider.Sett
func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) {
query := `
- INSERT INTO keys (name, created_at, updated_at, tags, revoked, key_id, key, revoked_reason, cost_limit_in_usd, cost_limit_in_usd_over_time, cost_limit_in_usd_unit, rate_limit_over_time, rate_limit_unit, ttl, setting_id, allowed_paths, setting_ids, should_log_request, should_log_response)
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
+ INSERT INTO keys (name, created_at, updated_at, tags, revoked, key_id, key, revoked_reason, cost_limit_in_usd, cost_limit_in_usd_over_time, cost_limit_in_usd_unit, rate_limit_over_time, rate_limit_unit, ttl, setting_id, allowed_paths, setting_ids, should_log_request, should_log_response, rotation_enabled)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20)
RETURNING *;
`
@@ -1295,6 +1306,7 @@ func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) {
sliceToSqlStringArray(rk.SettingIds),
rk.ShouldLogRequest,
rk.ShouldLogResponse,
+ rk.RotationEnabled,
}
ctxTimeout, cancel := context.WithTimeout(context.Background(), s.wt)
@@ -1324,6 +1336,7 @@ func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) {
pq.Array(&k.SettingIds),
&k.ShouldLogRequest,
&k.ShouldLogResponse,
+ &k.RotationEnabled,
); err != nil {
return nil, err
}
From 5f86ed695351e74cf900c859d87e9c3620f2fb16 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 28 Feb 2024 18:29:06 -0800
Subject: [PATCH 70/71] add querying keys through key ids
---
README.md | 20 +++++++++++++++++++-
internal/server/web/admin/admin.go | 6 +++---
2 files changed, 22 insertions(+), 4 deletions(-)
diff --git a/README.md b/README.md
index e4d8038..de3073f 100644
--- a/README.md
+++ b/README.md
@@ -172,7 +172,8 @@ This endpoint is set up for retrieving key configurations using a query param ca
> |--------|------------|----------------|------------------------------------------------------|
> | `tag` | optional | `string` | Identifier attached to a key configuration |
> | `tags` | optional | `[]string` | Identifiers attached to a key configuration |
-> | `provider` | optional | `string` | Provider attached to a key provider configuration. Its value can only be `openai`. |
+> | `provider` | optional | `string` | Provider attached to a key configuration. Its value can only be `openai`.
+> | `keyIds` | optional | `[]string` | Unique identifiers for keys.
##### Error Response
@@ -213,6 +214,9 @@ Fields of KeyConfiguration
> | allowedPaths | `[]PathConfig` | `[{ "path": "/api/providers/openai/v1/chat/completion", "method": "POST"}]` | Allowed paths that can be accessed using the key. |
> | settingId | `string` | `98daa3ae-961d-4253-bf6a-322a32fdca3d` | This field is DEPERCATED. Use `settingIds` field instead. |
> | settingIds | `string` | `[98daa3ae-961d-4253-bf6a-322a32fdca3d]` | Setting ids associated with the key. |
+> | shouldLogRequest | `bool` | `false` | Should request be stored. |
+> | shouldLogResponse | `bool` | `true` | Should response be stored. |
+> | rotationEnabled | `bool` | `false` | Whether the key should rotate between its provider settings when accessing third party endpoints, in order to circumvent rate limits. |
@@ -247,6 +251,9 @@ PathConfig
> | rateLimitUnit | optional | `enum` | m | Time unit for rateLimitOverTime. Possible values are [`h`, `m`, `s`, `d`] |
> | ttl | optional | `string` | 2d | time to live. Available units are [`s`, `m`, `h`]. |
> | allowedPaths | optional | `[]PathConfig` | 2d | Pathes allowed for access. |
+> | shouldLogRequest | optional | `bool` | `false` | Should request be stored. |
+> | shouldLogResponse | optional | `bool` | `true` | Should response be stored. |
+> | rotationEnabled | optional | `bool` | `false` | Whether the key should rotate between its provider settings when accessing third party endpoints, in order to circumvent rate limits. |
##### Error Response
@@ -283,6 +290,9 @@ PathConfig
> | allowedPaths | `[]PathConfig` | `[{ "path": "/api/providers/openai/v1/chat/completion", method: "POST"}]` | Allowed paths that can be accessed using the key. |
> | settingId | `string` | `98daa3ae-961d-4253-bf6a-322a32fdca3d` | This field is DEPERCATED. Use `settingIds` field instead. |
> | settingIds | `string` | `[98daa3ae-961d-4253-bf6a-322a32fdca3d]` | Setting ids associated with the key. |
+> | shouldLogRequest | `bool` | `false` | Should request be stored. |
+> | shouldLogResponse | `bool` | `true` | Should response be stored. |
+> | rotationEnabled | `bool` | `false` | Whether the key should rotate between its provider settings when accessing third party endpoints, in order to circumvent rate limits. |
@@ -320,6 +330,9 @@ PathConfig
> | rateLimitOverTime | optional | `int` | `2` | rate limit over period of time. This field is required if rateLimitUnit is specified. |
> | rateLimitUnit | optional | `string` | `m` | Time unit for rateLimitOverTime. Possible values are [`h`, `m`, `s`, `d`] |
> | allowedPaths | optional | `[{ "path": "/api/providers/openai/v1/chat/completions", "method": "POST"}]` | `` | Pathes allowed for access. |
+> | shouldLogRequest | optional | `bool` | `false` | Should request be stored. |
+> | shouldLogResponse | optional | `bool` | `true` | Should response be stored. |
+> | rotationEnabled | optional | `bool` | `false` | Whether the key should rotate between its provider settings when accessing third party endpoints, in order to circumvent rate limits. |
##### Error Response
@@ -354,6 +367,9 @@ PathConfig
> | allowedPaths | `[]PathConfig` | `[{ "path": "/api/providers/openai/v1/chat/completion", method: "POST"}]` | Allowed paths that can be accessed using the key. |
> | settingId | `string` | `98daa3ae-961d-4253-bf6a-322a32fdca3d` | This field is DEPERCATED. Use `settingIds` field instead. |
> | settingIds | `string` | `[98daa3ae-961d-4253-bf6a-322a32fdca3d]` | Setting ids associated with the key. |
+> | shouldLogRequest | `bool` | `false` | Should request be stored. |
+> | shouldLogResponse | `bool` | `true` | Should response be stored. |
+> | rotationEnabled | `bool` | `false` | Whether the key should rotate between its provider settings when accessing third party endpoints, in order to circumvent rate limits. |
@@ -597,6 +613,8 @@ Event
> | path | `string` | `/api/v1/chat/completion` | Path of the associated proxy request. |
> | method | `string` | `POST` | HTTP method of the associated proxy request. |
> | custom_id | `string` | `YOUR_CUSTOM_ID` | Custom Id passed by the user in the headers of proxy requests. |
+> | request | `[]byte` | `{}` | Stored request data from the proxy request (present when `shouldLogRequest` is enabled). |
+> | response | `[]byte` | `{}` | Stored response data from the proxy request (present when `shouldLogResponse` is enabled). |
diff --git a/internal/server/web/admin/admin.go b/internal/server/web/admin/admin.go
index d723be1..eee9e82 100644
--- a/internal/server/web/admin/admin.go
+++ b/internal/server/web/admin/admin.go
@@ -149,11 +149,11 @@ func getGetKeysHandler(m KeyManager, log *zap.Logger, prod bool) gin.HandlerFunc
tag := c.Query("tag")
tags := c.QueryArray("tags")
+ keyIds := c.QueryArray("keyIds")
provider := c.Query("provider")
path := "/api/key-management/keys"
-
- if len(tags) == 0 && len(tag) == 0 && len(provider) == 0 {
+ if len(tags) == 0 && len(tag) == 0 && len(provider) == 0 && len(keyIds) == 0 {
c.JSON(http.StatusBadRequest, &ErrorResponse{
Type: "/errors/missing-filteres",
Title: "filters are not found",
@@ -177,7 +177,7 @@ func getGetKeysHandler(m KeyManager, log *zap.Logger, prod bool) gin.HandlerFunc
}
cid := c.GetString(correlationId)
- keys, err := m.GetKeys(selected, nil, provider)
+ keys, err := m.GetKeys(selected, keyIds, provider)
if err != nil {
stats.Incr("bricksllm.admin.get_get_keys_handler.get_keys_by_tag_err", nil, 1)
From b41d5fe73089e40095c108f6789d5f65bfdb4817 Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Wed, 28 Feb 2024 19:11:43 -0800
Subject: [PATCH 71/71] increase timeout
---
CHANGELOG.md | 5 +++++
internal/config/config.go | 4 ++--
2 files changed, 7 insertions(+), 2 deletions(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index b443d09..e34b8e4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,8 @@
+## 1.12.1 - 2024-02-28
+### Added
+- Added querying keys by `keyIds`
+- Increased default postgres DB read timeout to `15s` and write timeout to `5s`
+
## 1.12.0 - 2024-02-28
### Added
- Added setting rotation feature to key
diff --git a/internal/config/config.go b/internal/config/config.go
index ced3a3b..0ebb895 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -19,8 +19,8 @@ type Config struct {
RedisPassword string `env:"REDIS_PASSWORD"`
RedisReadTimeout time.Duration `env:"REDIS_READ_TIME_OUT" envDefault:"1s"`
RedisWriteTimeout time.Duration `env:"REDIS_WRITE_TIME_OUT" envDefault:"500ms"`
- PostgresqlReadTimeout time.Duration `env:"POSTGRESQL_READ_TIME_OUT" envDefault:"2s"`
- PostgresqlWriteTimeout time.Duration `env:"POSTGRESQL_WRITE_TIME_OUT" envDefault:"1s"`
+ PostgresqlReadTimeout time.Duration `env:"POSTGRESQL_READ_TIME_OUT" envDefault:"15s"`
+ PostgresqlWriteTimeout time.Duration `env:"POSTGRESQL_WRITE_TIME_OUT" envDefault:"5s"`
InMemoryDbUpdateInterval time.Duration `env:"IN_MEMORY_DB_UPDATE_INTERVAL" envDefault:"5s"`
OpenAiKey string `env:"OPENAI_API_KEY"`
StatsProvider string `env:"STATS_PROVIDER"`