diff --git a/llmproxymetrics.go b/llmproxymetrics.go index 3d498b9..6d86188 100644 --- a/llmproxymetrics.go +++ b/llmproxymetrics.go @@ -4,13 +4,12 @@ import ( "bufio" "encoding/json" "fmt" + "io" "log" "net/http" - "net/http/httptest" "net/http/httputil" "net/url" "strconv" - "strings" "github.com/caarlos0/env/v11" ) @@ -28,116 +27,60 @@ func createProxy(target *url.URL) func(http.ResponseWriter, *http.Request) { r.URL.Scheme = target.Scheme r.URL.Host = target.Host - // startTime := time.Now() - proxy := httputil.NewSingleHostReverseProxy(target) - proxy.Director = func(req *http.Request) { + director := func(req *http.Request) { req.Header.Set("X-Forwarded-For", r.RemoteAddr) req.Host = target.Host + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host } - recorder := httptest.NewRecorder() + modifyResponse := func(response *http.Response) error { + pr, pw := io.Pipe() + body := response.Body + response.Body = pr - proxy.ServeHTTP(recorder, r) + go func() { + defer pw.Close() - responseBody := recorder.Body.Bytes() - contentType := recorder.Result().Header.Get("Content-Type") - - log.Printf("Response Content-Type: %s", contentType) - log.Printf("Response Body: %s", string(responseBody)) - - if strings.Contains(contentType, "application/json") { - var jsonResponse map[string]interface{} - err := json.Unmarshal(responseBody, &jsonResponse) - if err != nil { - log.Printf("Error unmarshalling JSON response: %v", err) - w.WriteHeader(recorder.Code) - w.Write(responseBody) - return - } - - // jsonResponse["metrics"] = map[string]interface{}{ - // "requestPath": r.URL.Path, - // "statusCode": recorder.Code, - // "responseTime": time.Since(startTime).Milliseconds(), - // } - - modifiedResponseBody, err := json.Marshal(jsonResponse) - if err != nil { - log.Printf("Error marshalling modified JSON response: %v", err) - w.WriteHeader(recorder.Code) - w.Write(responseBody) - return - } - - for name, values := range recorder.Header() { - for _, value := range values { - w.Header().Add(name, value) - } - } - - w.WriteHeader(recorder.Code) - w.Write(modifiedResponseBody) - - // log.Printf("Modified Response Body: %s", string(modifiedResponseBody)) - } else if strings.Contains(contentType, "application/x-ndjson") { - var modifiedResponseBody []string - scanner := bufio.NewScanner(strings.NewReader(string(responseBody))) - for scanner.Scan() { - line := scanner.Text() - if line != "" { - var jsonResponse map[string]interface{} - err := json.Unmarshal([]byte(line), &jsonResponse) + reader := bufio.NewReader(body) + for { + line, err := reader.ReadBytes('\n') if err != nil { - log.Printf("Error unmarshalling NDJSON line: %v", err) - modifiedResponseBody = append(modifiedResponseBody, line) - continue + if err == io.EOF { + handleJsonLine([]byte(string(line))) + pw.Write(line) + break + } + return } - - // jsonResponse["metrics"] = map[string]interface{}{ - // "requestPath": r.URL.Path, - // "statusCode": recorder.Code, - // "responseTime": time.Since(startTime).Milliseconds(), - // } - - modifiedLine, err := json.Marshal(jsonResponse) - if err != nil { - log.Printf("Error marshalling modified NDJSON line: %v", err) - modifiedResponseBody = append(modifiedResponseBody, string([]byte(line))) - continue - } - modifiedResponseBody = append(modifiedResponseBody, string(modifiedLine)) + handleJsonLine(line) + pw.Write(line) } - } - if err := scanner.Err(); err != nil { - log.Printf("Error scanning NDJSON: %v", err) - w.WriteHeader(recorder.Code) - w.Write(responseBody) - return - } + }() - for name, values := range recorder.Header() { - for _, value := range values { - w.Header().Add(name, value) - } - } - - w.WriteHeader(recorder.Code) - w.Write([]byte(strings.Join(modifiedResponseBody, "\n"))) - - // log.Printf("Modified Response Body: %s", strings.Join(modifiedResponseBody, "\n")) - } else { - for name, values := range recorder.Header() { - for _, value := range values { - w.Header().Add(name, value) - } - } - - w.WriteHeader(recorder.Code) - w.Write(responseBody) - - // log.Printf("Response without metrics: %s", string(responseBody)) + return nil } + + proxy := httputil.NewSingleHostReverseProxy(target) + proxy.Director = director + proxy.ModifyResponse = modifyResponse + + proxy.ServeHTTP(w, r) + } +} + +func handleJsonLine(line []byte) { + var jsonData map[string]interface{} + err := json.Unmarshal([]byte(line), &jsonData) + if err != nil { + fmt.Println("Error parsing JSON:", err) + return + } + + if jsonData["done"].(bool) { + duration := jsonData["eval_duration"].(float64) + fmt.Printf("Duration: %.2f seconds\n", duration/1000000000.0) } }