diff --git a/pkg/kthena-router/backend/sglang/metrics.go b/pkg/kthena-router/backend/sglang/metrics.go index a4a3891fb..e71c3043a 100644 --- a/pkg/kthena-router/backend/sglang/metrics.go +++ b/pkg/kthena-router/backend/sglang/metrics.go @@ -17,7 +17,10 @@ limitations under the License. package sglang import ( + "encoding/json" "fmt" + "io" + "net/http" dto "github.com/prometheus/client_model/go" corev1 "k8s.io/api/core/v1" @@ -29,6 +32,7 @@ import ( var ( GPUCacheUsage = "sglang:token_usage" RequestWaitingNum = "sglang:num_queue_reqs" + RequestRunningNum = "sglang:num_running_reqs" TPOT = "sglang:time_per_output_token_seconds" TTFT = "sglang:time_to_first_token_seconds" ) @@ -37,6 +41,7 @@ var ( CounterAndGaugeMetrics = []string{ GPUCacheUsage, RequestWaitingNum, + RequestRunningNum, } HistogramMetrics = []string{ @@ -47,11 +52,20 @@ var ( mapOfMetricsName = map[string]string{ GPUCacheUsage: utils.GPUCacheUsage, RequestWaitingNum: utils.RequestWaitingNum, + RequestRunningNum: utils.RequestRunningNum, TPOT: utils.TPOT, TTFT: utils.TTFT, } ) +type Model struct { + ID string `json:"id"` +} + +type ModelList struct { + Data []Model `json:"data"` +} + type sglangEngine struct { // The address of sglang's query metrics is http://{model server}:MetricPort/metrics // Default is 30000 @@ -83,7 +97,15 @@ func (engine *sglangEngine) GetCountMetricsInfo(allMetrics map[string]*dto.Metri continue } for _, metric := range metricInfo.Metric { - metricValue := metric.GetCounter().GetValue() + var metricValue float64 + switch metricInfo.GetType() { + case dto.MetricType_GAUGE: + metricValue = metric.GetGauge().GetValue() + case dto.MetricType_COUNTER: + metricValue = metric.GetCounter().GetValue() + default: + continue + } wantMetrics[mapOfMetricsName[metricName]] = metricValue } } @@ -115,7 +137,33 @@ func (engine *sglangEngine) GetHistogramPodMetrics(allMetrics map[string]*dto.Me return wantMetrics, histogramMetrics } -// TODO: Methods to get Models from sglang +// GetPodModels retrieves the list of models from a pod running the sglang engine. func (engine *sglangEngine) GetPodModels(pod *corev1.Pod) ([]string, error) { - return nil, nil + url := fmt.Sprintf("http://%s:%d/v1/models", pod.Status.PodIP, engine.MetricPort) + resp, err := http.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get models from pod %s/%s: HTTP %d", pod.GetNamespace(), pod.GetName(), resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + var modelList ModelList + err = json.Unmarshal(body, &modelList) + if err != nil { + return nil, err + } + + models := make([]string, 0, len(modelList.Data)) + for _, model := range modelList.Data { + models = append(models, model.ID) + } + return models, nil }