Skip to content

Commit

Permalink
Merge pull request #15 from bricks-cloud/v1.1.0
Browse files Browse the repository at this point in the history
[V1.1.0] Add support for OpenAI embeddings API
  • Loading branch information
spikelu2016 authored Nov 21, 2023
2 parents defd3c5 + 5e768c2 commit 9e3a13a
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 42 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## 1.1.0 - 2023-11-20
### Fixed
### Added
- Added support for OpenAI's embeddings API

## 1.0.4 - 2023-11-20
### Fixed
- Fixed configuration not found inconsistency with key and provider settings
Expand Down
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,11 @@ The OpenAI proxy runs on Port `8002`.
This endpoint is set up for proxying OpenAI API requests. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/chat).

</details>

<details>
<summary>Call OpenAI embeddings: <code>POST</code> <code><b>/api/providers/openai/v1/embeddings</b></code></summary>

##### Description
This endpoint is set up for proxying OpenAI embeddings API requests. Documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/embeddings/create).

</details>
9 changes: 7 additions & 2 deletions internal/provider/openai/cost.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,15 @@ func (ce *CostEstimator) EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (f
return 0, errors.New("model is not provided")
}

if inputs, ok := r.Input.([]string); ok {
if inputs, ok := r.Input.([]interface{}); ok {
total := 0
for _, input := range inputs {
tks, err := ce.tc.Count(r.Model.String(), input)
converted, ok := input.(string)
if !ok {
return 0, errors.New("input is not string")
}

tks, err := ce.tc.Count(r.Model.String(), converted)
if err != nil {
return 0, err
}
Expand Down
38 changes: 20 additions & 18 deletions internal/server/web/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,8 @@ func getMiddleware(kms keyMemStorage, prod, private bool, e estimator, v validat

var cost float64 = 0
model := ""
if strings.HasSuffix(c.FullPath(), "/chat/completions") {

if c.FullPath() == "/api/providers/openai/v1/chat/completions" {
ccr := &goopenai.ChatCompletionRequest{}
err = json.Unmarshal(body, ccr)
if err != nil {
Expand All @@ -216,7 +217,7 @@ func getMiddleware(kms keyMemStorage, prod, private bool, e estimator, v validat
}

model = ccr.Model
c.Set("model", ccr.Model)
c.Set("model", model)

logRequest(log, prod, private, cid, ccr)

Expand All @@ -229,25 +230,26 @@ func getMiddleware(kms keyMemStorage, prod, private bool, e estimator, v validat

}

// if strings.HasSuffix(c.FullPath(), "/embeddings") {
// er := &goopenai.EmbeddingRequest{}
// err = json.Unmarshal(body, er)
// if err != nil {
// logError(log, "error when unmarshalling embedding request", prod, cid, err)
// return
// }
if c.FullPath() == "/api/providers/openai/v1/embeddings" {
er := &goopenai.EmbeddingRequest{}
err = json.Unmarshal(body, er)
if err != nil {
logError(log, "error when unmarshalling embedding request", prod, cid, err)
return
}

// model = er.Model.String()
// c.Set("model", er.Model.String())
model = er.Model.String()
c.Set("model", model)
c.Set("encoding_format", string(er.EncodingFormat))

// logEmbeddingRequest(log, prod, private, cid, er)
logEmbeddingRequest(log, prod, private, cid, er)

// cost, err = e.EstimateEmbeddingsCost(er)
// if err != nil {
// stats.Incr("bricksllm.web.get_middleware.estimate_embeddings_cost_error", nil, 1)
// logError(log, "error when estimating embeddings cost", prod, cid, err)
// }
// }
cost, err = e.EstimateEmbeddingsCost(er)
if err != nil {
stats.Incr("bricksllm.web.get_middleware.estimate_embeddings_cost_error", nil, 1)
logError(log, "error when estimating embeddings cost", prod, cid, err)
}
}

// if c.FullPath() == "/assistants" && c.Request.Method == http.MethodPost {
// logCreateAssistantRequest(log, body, prod, private, cid)
Expand Down
130 changes: 108 additions & 22 deletions internal/server/web/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ type recorder interface {
func NewProxyServer(log *zap.Logger, mode, privacyMode string, m KeyManager, psm ProviderSettingsManager, ks keyStorage, kms keyMemStorage, e estimator, v validator, r recorder, credential string, enc encrypter, rlm rateLimitManager, timeOut time.Duration) (*ProxyServer, error) {
router := gin.New()
prod := mode == "production"
private := mode == "strict"
private := privacyMode == "strict"

router.Use(getMiddleware(kms, prod, private, e, v, ks, log, enc, rlm, r, "proxy"))

client := http.Client{}

router.POST("/api/health", getGetHealthCheckHandler())
router.POST("/api/providers/openai/v1/chat/completions", getChatCompletionHandler(r, prod, private, psm, client, kms, log, enc, e, timeOut))
// router.POST("/api/providers/openai/v1/embeddings", getEmbeddingHandler(r, prod, private, psm, client, kms, log, enc, e, timeOut))
router.POST("/api/providers/openai/v1/embeddings", getEmbeddingHandler(r, prod, private, psm, client, kms, log, enc, e, timeOut))

// router.GET("/api/providers/openai/v1/assistants", getChatCompletionHandler(r, prod, private, psm, client, kms, log, enc, e, timeOut))
// router.POST("/api/providers/openai/v1/assistants", getChatCompletionHandler(r, prod, private, psm, client, kms, log, enc, e, timeOut))
Expand Down Expand Up @@ -189,6 +189,22 @@ func getPassThroughHandler(r recorder, prod, private bool, psm ProviderSettingsM
}
}

// EmbeddingResponse is the response from a Create embeddings request.
// The embedding handler unmarshals the upstream response body into this type
// whenever the request's encoding format is anything other than "base64".
type EmbeddingResponse struct {
// Object is decoded from the "object" key of the response body.
Object string `json:"object"`
// Data holds the returned embeddings; each entry is logged with its index,
// object, and float32 embedding values.
Data []goopenai.Embedding `json:"data"`
// Model is kept as a plain string and logged as-is.
Model string `json:"model"`
// Usage carries the token counts; its TotalTokens value feeds the embeddings
// cost estimate.
Usage goopenai.Usage `json:"usage"`
}

// EmbeddingResponseBase64 is the response from a Create embeddings request
// whose encoding format is "base64". The embedding handler unmarshals the
// upstream response body into this type instead of EmbeddingResponse in that
// case.
type EmbeddingResponseBase64 struct {
// Object is decoded from the "object" key of the response body.
Object string `json:"object"`
// Data holds the returned embeddings; each entry's Embedding is converted
// with string(...) when logged rather than iterated as floats.
Data []goopenai.Base64Embedding `json:"data"`
// Model is kept as a plain string and logged as-is.
Model string `json:"model"`
// Usage carries the token counts; its TotalTokens value feeds the embeddings
// cost estimate.
Usage goopenai.Usage `json:"usage"`
}

func getEmbeddingHandler(r recorder, prod, private bool, psm ProviderSettingsManager, client http.Client, kms keyMemStorage, log *zap.Logger, enc encrypter, e estimator, timeOut time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
stats.Incr("bricksllm.web.get_embedding_handler.requests", nil, 1)
Expand Down Expand Up @@ -238,26 +254,50 @@ func getEmbeddingHandler(r recorder, prod, private bool, psm ProviderSettingsMan
}

var cost float64 = 0
chatRes := &goopenai.EmbeddingResponse{}
chatRes := &EmbeddingResponse{}
base64ChatRes := &EmbeddingResponseBase64{}
if res.StatusCode == http.StatusOK {
stats.Incr("bricksllm.web.get_embedding_handler.success", nil, 1)
stats.Timing("bricksllm.web.get_embedding_handler.success_latency", dur, nil, 1)

err = json.Unmarshal(bytes, chatRes)
if err != nil {
logError(log, "error when unmarshalling openai embedding response body", prod, id, err)
format := c.GetString("encoding_format")

if format == "base64" {
err = json.Unmarshal(bytes, base64ChatRes)
if err != nil {
logError(log, "error when unmarshalling openai base64 embedding response body", prod, id, err)
}
}

if format != "base64" {
err = json.Unmarshal(bytes, chatRes)
if err != nil {
logError(log, "error when unmarshalling openai embedding response body", prod, id, err)
}
}

model := c.GetString("model")

totalTokens := 0
if err == nil {
logEmbeddingResponse(log, prod, private, id, chatRes)
cost, err = e.EstimateEmbeddingsInputCost(chatRes.Model.String(), chatRes.Usage.TotalTokens)
if format == "base64" {
logBase64EmbeddingResponse(log, prod, private, id, base64ChatRes)
totalTokens = base64ChatRes.Usage.TotalTokens
}

if format != "base64" {
logEmbeddingResponse(log, prod, private, id, chatRes)
totalTokens = chatRes.Usage.TotalTokens
}

cost, err = e.EstimateEmbeddingsInputCost(model, totalTokens)
if err != nil {
stats.Incr("bricksllm.web.get_embedding_handler.estimate_total_cost_error", nil, 1)
logError(log, "error when estimating openai cost for embedding", prod, id, err)
}

micros := int64(cost * 1000000)
err = r.RecordKeySpend(kc.KeyId, chatRes.Model.String(), micros, kc.CostLimitInUsdUnit)
err = r.RecordKeySpend(kc.KeyId, model, micros, kc.CostLimitInUsdUnit)
if err != nil {
stats.Incr("bricksllm.web.get_embedding_handler.record_key_spend_error", nil, 1)
logError(log, "error when recording openai spend for embedding", prod, id, err)
Expand Down Expand Up @@ -374,16 +414,17 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin
logError(log, "error when unmarshalling openai http chat completion response body", prod, id, err)
}

model := c.GetString("model")
if err == nil {
logChatCompletionResponse(log, prod, private, id, chatRes)
cost, err = e.EstimateTotalCost(chatRes.Model, chatRes.Usage.PromptTokens, chatRes.Usage.CompletionTokens)
cost, err = e.EstimateTotalCost(model, chatRes.Usage.PromptTokens, chatRes.Usage.CompletionTokens)
if err != nil {
stats.Incr("bricksllm.web.get_chat_completion_handler.estimate_total_cost_error", nil, 1)
logError(log, "error when estimating openai cost", prod, id, err)
}

micros := int64(cost * 1000000)
err = r.RecordKeySpend(kc.KeyId, chatRes.Model, micros, kc.CostLimitInUsdUnit)
err = r.RecordKeySpend(kc.KeyId, model, micros, kc.CostLimitInUsdUnit)
if err != nil {
stats.Incr("bricksllm.web.get_chat_completion_handler.record_key_spend_error", nil, 1)
logError(log, "error when recording openai spend", prod, id, err)
Expand Down Expand Up @@ -422,7 +463,7 @@ func (ps *ProxyServer) Run() {
go func() {
ps.log.Info("proxy server listening at 8002")
ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/chat/completions is ready for forwarding chat completion requests to openai")
// ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/embeddings is ready for forwarding embeddings requests to openai")
ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/embeddings is ready for forwarding embeddings requests to openai")

if err := ps.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
ps.log.Sugar().Fatalf("error proxy server listening: %v", err)
Expand All @@ -431,29 +472,74 @@ func (ps *ProxyServer) Run() {
}()
}

func logEmbeddingResponse(log *zap.Logger, prod, private bool, cid string, r *goopenai.EmbeddingResponse) {
func logEmbeddingResponse(log *zap.Logger, prod, private bool, cid string, r *EmbeddingResponse) {
if prod {
log.Info("openai chat completion response",
log.Info("openai embeddings response",
zap.Time("createdAt", time.Now()),
zap.String(correlationId, cid),
zap.Object("response", zapcore.ObjectMarshalerFunc(
func(enc zapcore.ObjectEncoder) error {
enc.AddString("object", r.Object)
enc.AddString("model", r.Model.String())
enc.AddString("model", r.Model)
enc.AddArray("data", zapcore.ArrayMarshalerFunc(
func(enc zapcore.ArrayEncoder) error {
for _, d := range r.Data {
enc.AppendObject(zapcore.ObjectMarshalerFunc(
func(enc zapcore.ObjectEncoder) error {
enc.AddInt("index", d.Index)
enc.AddString("object", d.Object)
enc.AddArray("embedding", zapcore.ArrayMarshalerFunc(
func(enc zapcore.ArrayEncoder) error {
for _, e := range d.Embedding {
enc.AppendFloat32(e)
}
return nil
}))
if !private {
enc.AddArray("embedding", zapcore.ArrayMarshalerFunc(
func(enc zapcore.ArrayEncoder) error {
for _, e := range d.Embedding {
enc.AppendFloat32(e)
}
return nil
}))
}

return nil
},
))
}
return nil
},
))

enc.AddObject("usage", zapcore.ObjectMarshalerFunc(
func(enc zapcore.ObjectEncoder) error {
enc.AddInt("prompt_tokens", r.Usage.PromptTokens)
enc.AddInt("completion_tokens", r.Usage.CompletionTokens)
enc.AddInt("total_tokens", r.Usage.TotalTokens)
return nil
},
))
return nil
},
)),
)
}
}

func logBase64EmbeddingResponse(log *zap.Logger, prod, private bool, cid string, r *EmbeddingResponseBase64) {
if prod {
log.Info("openai embeddings response",
zap.Time("createdAt", time.Now()),
zap.String(correlationId, cid),
zap.Object("response", zapcore.ObjectMarshalerFunc(
func(enc zapcore.ObjectEncoder) error {
enc.AddString("object", r.Object)
enc.AddString("model", r.Model)
enc.AddArray("data", zapcore.ArrayMarshalerFunc(
func(enc zapcore.ArrayEncoder) error {
for _, d := range r.Data {
enc.AppendObject(zapcore.ObjectMarshalerFunc(
func(enc zapcore.ObjectEncoder) error {
enc.AddInt("index", d.Index)
enc.AddString("object", d.Object)
if !private {
enc.AddString("embedding", string(d.Embedding))
}

return nil
},
Expand Down

0 comments on commit 9e3a13a

Please sign in to comment.