From ab251474d759c775f8bb8c2ca7e0f03d9496189a Mon Sep 17 00:00:00 2001 From: "sweep-ai[bot]" <128439645+sweep-ai[bot]@users.noreply.github.com> Date: Fri, 22 Mar 2024 15:43:10 +0000 Subject: [PATCH 1/9] feat: Add HealthChecker interface and default impl --- gateway/internal/router/health_checker.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 gateway/internal/router/health_checker.go diff --git a/gateway/internal/router/health_checker.go b/gateway/internal/router/health_checker.go new file mode 100644 index 0000000..8cb1f28 --- /dev/null +++ b/gateway/internal/router/health_checker.go @@ -0,0 +1,17 @@ +package router + +import ( + "log" +) + +type HealthChecker interface { + IsHealthy(providerName string) bool +} + +type DefaultHealthChecker struct{} + +func (d *DefaultHealthChecker) IsHealthy(providerName string) bool { + // Placeholder for actual health check logic + // Currently returns true, assuming all providers are healthy + return true +} From d37b54bf0bc3b9e4d336d3c99bc0fc4fbf13ac02 Mon Sep 17 00:00:00 2001 From: "sweep-ai[bot]" <128439645+sweep-ai[bot]@users.noreply.github.com> Date: Fri, 22 Mar 2024 15:44:40 +0000 Subject: [PATCH 2/9] feat: Updated gateway/internal/router/round_robin. --- gateway/internal/router/round_robin.go | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/gateway/internal/router/round_robin.go b/gateway/internal/router/round_robin.go index 60ac24c..5d88108 100644 --- a/gateway/internal/router/round_robin.go +++ b/gateway/internal/router/round_robin.go @@ -1,7 +1,9 @@ package router import ( + "log" "sync/atomic" + "gateway/internal/router" // Importing to use HealthChecker ) const ( @@ -26,9 +28,22 @@ func (r *RoundRobinRouter) Iterator() RouterIterator { func (r *RoundRobinRouter) Next() *RouterConfig { providerLen := len(r.providers) - // Todo: make a check for healthy provider - idx := r.idx.Add(1) - 1 - model := &r.providers[idx%uint64(providerLen)] + // Iterate through providers to find a healthy one + var healthyProvider *RouterConfig + originalIdx := r.idx.Load() + for i := 0; i < providerLen; i++ { + idx := (originalIdx + uint64(i)) % uint64(providerLen) + if router.DefaultHealthChecker{}.IsHealthy(r.providers[idx].Name) { + healthyProvider = &r.providers[idx] + r.idx.Add(1) + break + } + } + + if healthyProvider == nil { + log.Println("Error: No healthy providers available.") + return nil + } - return model + return healthyProvider } From 4ddece299009bf57563dbc2e19f6eafa66465ffc Mon Sep 17 00:00:00 2001 From: "sweep-ai[bot]" <128439645+sweep-ai[bot]@users.noreply.github.com> Date: Fri, 22 Mar 2024 15:46:32 +0000 Subject: [PATCH 3/9] feat: Updated gateway/internal/router/priority.go --- gateway/internal/router/priority.go | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/gateway/internal/router/priority.go b/gateway/internal/router/priority.go index 7b0f80e..71a5136 100644 --- a/gateway/internal/router/priority.go +++ b/gateway/internal/router/priority.go @@ -2,6 +2,8 @@ package router import ( "sync/atomic" + "log" + "gateway/internal/router" // Importing to use HealthChecker ) const ( @@ -21,11 +23,20 @@ func NewPriorityRouter(providers []RouterConfig) *PriorityRouter { } func (r *PriorityRouter) Next() (*RouterConfig, error) { - idx := int(r.idx.Load()) - - // Todo: make a check for healthy provider - model := &r.providers[idx] - r.idx.Add(1) - - return model, nil + providerLen := len(r.providers) + originalIdx := r.idx.Load() + var healthyProvider *RouterConfig + for i := 0; i < providerLen; i++ { + idx := (originalIdx + uint64(i)) % uint64(providerLen) + if router.DefaultHealthChecker{}.IsHealthy(r.providers[idx].Name) { + healthyProvider = &r.providers[idx] + r.idx.Store(idx + 1) + break + } + } + if healthyProvider == nil { + log.Println("Error: No healthy providers available.") + return nil, fmt.Errorf("no healthy providers available") + } + return healthyProvider, nil } From 81cf9967ab6c4eab4920a3b9e234ac7224248c04 Mon Sep 17 00:00:00 2001 From: "sweep-ai[bot]" <128439645+sweep-ai[bot]@users.noreply.github.com> Date: Fri, 22 Mar 2024 15:47:17 +0000 Subject: [PATCH 4/9] feat: Updated gateway/internal/api/v1/models.go --- gateway/internal/api/v1/models.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gateway/internal/api/v1/models.go b/gateway/internal/api/v1/models.go index 160f991..df24e33 100644 --- a/gateway/internal/api/v1/models.go +++ b/gateway/internal/api/v1/models.go @@ -13,6 +13,11 @@ func (s *V1Handler) ListModels(ctx context.Context, req *connect.Request[llmv1.M allProviderModels := map[string]*llmv1.ProviderModels{} for name := range base.ProviderRegistry { + // Check if the provider is healthy before fetching models + if !router.DefaultHealthChecker{}.IsHealthy(name) { + continue + } + provider, err := s.iProviderService.GetProvider(provider.Provider{Name: name}) if err != nil { continue From 1601da062828c257bd4cbd26ecaaa0f1a37b0bda Mon Sep 17 00:00:00 2001 From: "sweep-ai[bot]" <128439645+sweep-ai[bot]@users.noreply.github.com> Date: Fri, 22 Mar 2024 15:55:53 +0000 Subject: [PATCH 5/9] feat: Updated gateway/internal/api/v1/providers.go --- gateway/internal/api/v1/providers.go | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/gateway/internal/api/v1/providers.go b/gateway/internal/api/v1/providers.go index 54d5ddd..b65e4c5 100644 --- a/gateway/internal/api/v1/providers.go +++ b/gateway/internal/api/v1/providers.go @@ -20,6 +20,10 @@ func (s *V1Handler) ListProviders(ctx context.Context, req *connect.Request[empt data := []*llmv1.Provider{} for _, provider := range providers { + // Check if the provider is healthy before adding to the list + if !router.DefaultHealthChecker{}.IsHealthy(provider.Info().Name) { + continue + } providerInfo := provider.Info() data = append(data, &llmv1.Provider{ Title: providerInfo.Title, @@ -34,6 +38,11 @@ func (s *V1Handler) ListProviders(ctx context.Context, req *connect.Request[empt } func (s *V1Handler) GetProvider(ctx context.Context, req *connect.Request[llmv1.GetProviderRequest]) (*connect.Response[llmv1.GetProviderResponse], error) { + // First, check if the provider is healthy + if !router.DefaultHealthChecker{}.IsHealthy(req.Msg.Name) { + return nil, errors.NewNotFound("Provider is unhealthy") + } + provider, err := s.iProviderService.GetProvider(provider.Provider{Name: req.Msg.Name}) if err != nil { return nil, errors.NewNotFound(err.Error()) @@ -63,6 +72,11 @@ func (s *V1Handler) GetProvider(ctx context.Context, req *connect.Request[llmv1. } func (s *V1Handler) CreateProvider(ctx context.Context, req *connect.Request[llmv1.CreateProviderRequest]) (*connect.Response[llmv1.CreateProviderResponse], error) { + // First, check if the provider is healthy + if !router.DefaultHealthChecker{}.IsHealthy(req.Msg.Name) { + return nil, errors.NewNotFound("Provider is unhealthy") + } + provider := provider.Provider{Name: req.Msg.Name, Config: req.Msg.Config.AsMap()} p, err := s.iProviderService.GetProvider(provider) @@ -111,6 +125,11 @@ func (s *V1Handler) CreateProvider(ctx context.Context, req *connect.Request[llm } func (s *V1Handler) UpsertProvider(ctx context.Context, req *connect.Request[llmv1.UpdateProviderRequest]) (*connect.Response[llmv1.UpdateProviderResponse], error) { + // First, check if the provider is healthy + if !router.DefaultHealthChecker{}.IsHealthy(req.Msg.Name) { + return nil, errors.NewNotFound("Provider is unhealthy") + } + provider := provider.Provider{Name: req.Msg.Name, Config: req.Msg.Config.AsMap()} p, err := s.iProviderService.GetProvider(provider) @@ -172,6 +191,11 @@ func (s *V1Handler) UpsertProvider(ctx context.Context, req *connect.Request[llm } func (s *V1Handler) GetProviderConfig(ctx context.Context, req *connect.Request[llmv1.GetProviderConfigRequest]) (*connect.Response[llmv1.GetProviderConfigResponse], error) { + // First, check if the provider is healthy + if !router.DefaultHealthChecker{}.IsHealthy(req.Msg.Name) { + return nil, errors.NewNotFound("Provider is unhealthy") + } + p, err := s.iProviderService.GetProvider(provider.Provider{Name: req.Msg.Name}) if err != nil { return nil, errors.NewNotFound(err.Error()) From 60911bdcddb9e8d725b85d2cf0f8421262fce932 Mon Sep 17 00:00:00 2001 From: "sweep-ai[bot]" <128439645+sweep-ai[bot]@users.noreply.github.com> Date: Fri, 22 Mar 2024 16:26:44 +0000 Subject: [PATCH 6/9] feat: Add HuggingFace provider for transformer mod --- .../provider/huggingface/huggingface.go | 116 ++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 gateway/internal/provider/huggingface/huggingface.go diff --git a/gateway/internal/provider/huggingface/huggingface.go b/gateway/internal/provider/huggingface/huggingface.go new file mode 100644 index 0000000..dabcf50 --- /dev/null +++ b/gateway/internal/provider/huggingface/huggingface.go @@ -0,0 +1,116 @@ +package huggingface + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "github.com/missingstudio/ai/gateway/internal/provider/base" + "github.com/missingstudio/common/errors" +) + +type HuggingFaceProvider struct { + APIKey string + BaseURL string +} + +func (hfp *HuggingFaceProvider) Info() base.ProviderInfo { + return base.ProviderInfo{ + Name: "HuggingFace", + Description: "Provider for interacting with HuggingFace's transformer models", + } +} + +func (hfp *HuggingFaceProvider) Models(ctx context.Context) ([]string, error) { + url := fmt.Sprintf("%s/models", hfp.BaseURL) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + + req.Header.Add("Authorization", "Bearer "+hfp.APIKey) + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, errors.NewBadRequest("failed to fetch models from HuggingFace") + } + + var models []string + err = json.NewDecoder(resp.Body).Decode(&models) + if err != nil { + return nil, err + } + + return models, nil +} + +func (hfp *HuggingFaceProvider) InitiateFineTuning(ctx context.Context, model string, parameters map[string]interface{}) (string, error) { + url := fmt.Sprintf("%s/fine-tune", hfp.BaseURL) + payload, err := json.Marshal(map[string]interface{}{ + "model": model, + "parameters": parameters, + }) + if err != nil { + return "", err + } + + req, err := http.NewRequestWithContext(ctx, "POST", url, ioutil.NopCloser(bytes.NewReader(payload))) + if err != nil { + return "", err + } + + req.Header.Add("Authorization", "Bearer "+hfp.APIKey) + req.Header.Add("Content-Type", "application/json") + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", errors.NewBadRequest("failed to initiate fine-tuning on HuggingFace") + } + + var result map[string]string + err = json.NewDecoder(resp.Body).Decode(&result) + if err != nil { + return "", err + } + + return result["job_id"], nil +} + +func (hfp *HuggingFaceProvider) RetrieveFineTuningResults(ctx context.Context, jobID string) (map[string]interface{}, error) { + url := fmt.Sprintf("%s/fine-tune/%s", hfp.BaseURL, jobID) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + + req.Header.Add("Authorization", "Bearer "+hfp.APIKey) + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, errors.NewBadRequest("failed to retrieve fine-tuning results from HuggingFace") + } + + var result map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&result) + if err != nil { + return nil, err + } + + return result, nil +} From 82f55933d34ab4ae641421256a30790d745a3cfd Mon Sep 17 00:00:00 2001 From: "sweep-ai[bot]" <128439645+sweep-ai[bot]@users.noreply.github.com> Date: Fri, 22 Mar 2024 16:27:34 +0000 Subject: [PATCH 7/9] feat: Add endpoints for initiating and checking fi --- gateway/internal/api/v1/finetune.go | 50 +++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 gateway/internal/api/v1/finetune.go diff --git a/gateway/internal/api/v1/finetune.go b/gateway/internal/api/v1/finetune.go new file mode 100644 index 0000000..e4456bf --- /dev/null +++ b/gateway/internal/api/v1/finetune.go @@ -0,0 +1,50 @@ +package v1 + +import ( + "context" + "encoding/json" + "net/http" + + "connectrpc.com/connect" + "github.com/missingstudio/ai/gateway/core/provider" + "github.com/missingstudio/ai/gateway/internal/provider/huggingface" + "github.com/missingstudio/common/errors" + llmv1 "github.com/missingstudio/protos/pkg/llm/v1" +) + +func (s *V1Handler) InitiateFineTuning(ctx context.Context, req *connect.Request[llmv1.FineTuneRequest]) (*connect.Response[llmv1.FineTuneResponse], error) { + hfProvider, err := s.iProviderService.GetProvider(provider.Provider{Name: "HuggingFace"}) + if err != nil { + return nil, errors.NewInternal("failed to get HuggingFace provider") + } + + jobID, err := hfProvider.(*huggingface.HuggingFaceProvider).InitiateFineTuning(ctx, req.Payload.Model, req.Payload.Parameters) + if err != nil { + return nil, errors.NewInternal("failed to initiate fine-tuning: " + err.Error()) + } + + return connect.NewResponse(&llmv1.FineTuneResponse{ + JobId: jobID, + }), nil +} + +func (s *V1Handler) CheckFineTuningStatus(ctx context.Context, req *connect.Request[llmv1.FineTuneStatusRequest]) (*connect.Response[llmv1.FineTuneStatusResponse], error) { + hfProvider, err := s.iProviderService.GetProvider(provider.Provider{Name: "HuggingFace"}) + if err != nil { + return nil, errors.NewInternal("failed to get HuggingFace provider") + } + + result, err := hfProvider.(*huggingface.HuggingFaceProvider).RetrieveFineTuningResults(ctx, req.Payload.JobId) + if err != nil { + return nil, errors.NewInternal("failed to retrieve fine-tuning results: " + err.Error()) + } + + status, ok := result["status"].(string) + if !ok { + return nil, errors.NewInternal("unexpected response format from HuggingFace") + } + + return connect.NewResponse(&llmv1.FineTuneStatusResponse{ + Status: status, + }), nil +} From 3f04b91788cb909b639228a1376a27f2ebcfa2cf Mon Sep 17 00:00:00 2001 From: "sweep-ai[bot]" <128439645+sweep-ai[bot]@users.noreply.github.com> Date: Fri, 22 Mar 2024 16:31:54 +0000 Subject: [PATCH 8/9] feat: Updated playgrounds/apps/studio/app/(llm)/pl --- .../apps/studio/app/(llm)/playground/hooks/useModelFetch.tsx | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/playgrounds/apps/studio/app/(llm)/playground/hooks/useModelFetch.tsx b/playgrounds/apps/studio/app/(llm)/playground/hooks/useModelFetch.tsx index 48bcac8..b3a6955 100644 --- a/playgrounds/apps/studio/app/(llm)/playground/hooks/useModelFetch.tsx +++ b/playgrounds/apps/studio/app/(llm)/playground/hooks/useModelFetch.tsx @@ -14,11 +14,13 @@ interface Model { const BASE_URL = process.env.NEXT_PUBLIC_GATEWAY_URL ?? "http://localhost:3000"; export function useModelFetch() { const [providers, setProviders] = useState([]); +const [isFineTuning, setIsFineTuning] = useState(false); useEffect(() => { + const fetchEndpoint = isFineTuning ? `${BASE_URL}/api/v1/finetune/models` : `${BASE_URL}/api/v1/models`; async function fetchModels() { try { - const response = await fetch(`${BASE_URL}/api/v1/models`); + const response = await fetch(fetchEndpoint); const { models } = await response.json(); const fetchedProviders: ModelType[] = Object.keys(models).map( (key) => ({ From bd0d2375e35ba4f7425000471c68d136c3864e24 Mon Sep 17 00:00:00 2001 From: "sweep-ai[bot]" <128439645+sweep-ai[bot]@users.noreply.github.com> Date: Fri, 22 Mar 2024 16:33:34 +0000 Subject: [PATCH 9/9] feat: Updated playgrounds/apps/studio/app/(llm)/pl --- .../app/(llm)/playground/components/modelselector.tsx | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/playgrounds/apps/studio/app/(llm)/playground/components/modelselector.tsx b/playgrounds/apps/studio/app/(llm)/playground/components/modelselector.tsx index ad00039..f05a59f 100644 --- a/playgrounds/apps/studio/app/(llm)/playground/components/modelselector.tsx +++ b/playgrounds/apps/studio/app/(llm)/playground/components/modelselector.tsx @@ -26,12 +26,18 @@ interface ModelSelectorProps extends PopoverProps {} export default function ModelSelector(props: ModelSelectorProps) { const [open, setOpen] = React.useState(false); - const { providers } = useModelFetch(); + const [isFineTuning, setIsFineTuning] = React.useState(false); + const { providers } = useModelFetch(isFineTuning); const { model, setModel, setProvider } = useStore(); + const toggleFineTuning = () => setIsFineTuning(!isFineTuning); + return (
+