package main import ( "bufio" "bytes" "compress/gzip" "context" "crypto/tls" "fmt" "io" "log" "net/http" "net/url" "os" "os/signal" "strings" "syscall" "time" "github.com/google/uuid" "github.com/pelletier/go-toml/v2" ) type Config struct { UpstreamURL string ListenAddr string APIKey string Insecure bool } func main() { if len(os.Args) > 1 && (os.Args[1] == "-h" || os.Args[1] == "--help") { printHelp() os.Exit(0) } cfg := loadConfig() printConfig(cfg) if cfg.APIKey == "" { fmt.Fprintln(os.Stderr, "error: API_KEY is required") os.Exit(1) } if cfg.Insecure { fmt.Fprintln(os.Stderr, "WARNING: TLS verification disabled") } mux := http.NewServeMux() mux.HandleFunc("/", handleProxy(cfg)) srv := &http.Server{ Addr: cfg.ListenAddr, Handler: mux, } go func() { fmt.Printf("LLM Proxy listening on %s proxy to upstream: %s\n", cfg.ListenAddr, cfg.UpstreamURL) srv.ListenAndServe() }() sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) <-sigCh ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() srv.Shutdown(ctx) } func handleProxy(cfg Config) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { requestId := uuid.New().String() log.Println(requestId, "Handle proxy:", r.Method, r.URL) if r.URL.Path == "/health" { w.WriteHeader(http.StatusOK) return } // Read and print request body if available if r.Body != nil { reqBody, err := io.ReadAll(r.Body) if err != nil { log.Println(requestId, "Error reading request body:", err) } else { log.Println(requestId, "Request body:", string(reqBody)) // Restore the body for further processing r.Body = io.NopCloser(bytes.NewBuffer(reqBody)) } } proxyReq := cloneRequest(requestId, r, cfg.UpstreamURL) if cfg.APIKey != "" { proxyReq.Header.Set("Authorization", "Bearer "+cfg.APIKey) } tr := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: cfg.Insecure}, } client := &http.Client{Transport: tr} resp, err := client.Do(proxyReq) if err != nil { log.Println(requestId, "Upstream error:", err) http.Error(w, fmt.Sprintf("upstream error: %v", err), http.StatusBadGateway) return } defer resp.Body.Close() for k, v := range resp.Header { log.Println(requestId, "Header:", k, v) w.Header()[k] = v } w.WriteHeader(resp.StatusCode) isStreamingRequest := isStreaming(requestId, resp) log.Println(requestId, "Request streaming:", isStreamingRequest) if !isStreamingRequest { log.Println(requestId, "Request streaming:", false) body, err := io.ReadAll(resp.Body) if err != nil { log.Println(requestId, "Read response error:", err) io.Copy(w, resp.Body) return } printRawBody := true contentEncoding := resp.Header.Get("Content-Encoding") if contentEncoding == "gzip" { gr, err := gzip.NewReader(bytes.NewReader(body)) if err != nil { log.Println(requestId, "Decompress error:", err) } else { decodedBody, err := io.ReadAll(gr) if err != nil { log.Println(requestId, "Decompress error:", err) } gr.Close() printRawBody = false log.Println(requestId, "Response[decoded]:", contentEncoding, string(decodedBody)) } } if printRawBody { log.Println(requestId, "Response[raw]:", string(body)) } w.Write(body) log.Println(requestId, "Reponse end.") return } handleStream(requestId, w, resp.Body) } } func cloneRequest(requestId string, r *http.Request, upstreamURL string) *http.Request { upstream, _ := url.Parse(upstreamURL) proxyReq := r.Clone(context.Background()) proxyReq.URL.Scheme = upstream.Scheme proxyReq.URL.Host = upstream.Host proxyReq.URL.Path = strings.ReplaceAll(r.URL.Path, upstream.Path, "") if upstream.Path != "" && !strings.HasSuffix(proxyReq.URL.Path, "/") { proxyReq.URL.Path = upstream.Path + proxyReq.URL.Path } proxyReq.Host = upstream.Host log.Println(requestId, "Upstream proxy:", proxyReq.URL) if val := r.Header.Get("Content-Type"); val != "" { proxyReq.Header.Set("Content-Type", val) } proxyReq.Header.Del("Host") proxyReq.Header.Del("Authorization") proxyReq.RequestURI = "" return proxyReq } func isStreaming(requestId string, resp *http.Response) bool { ct := strings.ToLower(resp.Header.Get("Content-Type")) log.Println(requestId, "Content-Type:", ct) return strings.Contains(ct, "text/event-stream") || strings.Contains(ct, "application/x-ndjson") || strings.Contains(ct, "stream") } func handleStream(requestId string, w io.Writer, body io.Reader) { // Cast to http.ResponseWriter to access Header and Flush methods rw, ok := w.(http.ResponseWriter) if !ok { log.Println(requestId, "Error: ResponseWriter is not an http.ResponseWriter") return } // // Set headers for streaming // rw.Header().Set("Content-Type", "text/event-stream") // rw.Header().Set("Cache-Control", "no-cache") // rw.Header().Set("Connection", "keep-alive") // rw.Header().Set("Transfer-Encoding", "chunked") reader := bufio.NewReader(body) for { // log.Println("Sleep:", 4) // time.Sleep(time.Duration(4) * time.Second) line, err := reader.ReadString('\n') if err != nil { if err != io.EOF { log.Println(requestId, "Stream error:", err) fmt.Fprintf(os.Stderr, "stream error: %v\n", err) } else { log.Println(requestId, "Stream end.") } break } _, writeErr := rw.Write([]byte(line)) if writeErr != nil { log.Println(requestId, "Write error:", writeErr) break } if len(line) == 1 && line[len(line)-1] == 10 { // SKIP empty line continue } else { log.Println(requestId, "Process chunk:", fmt.Sprintf("%d bytes", len(line)), strings.TrimSpace(line)) } // Flush the response to ensure the chunk is sent immediately if flusher, ok := rw.(http.Flusher); ok { flusher.Flush() } else { log.Println(requestId, "Warning: ResponseWriter does not support flushing") } } } func printHelp() { fmt.Println(`LLM Proxy - HTTP proxy for LLM APIs Usage: llm-proxy Start the proxy llm-proxy -h Show this help Config: Config file (optional): llm-proxy.toml Environment variables take priority over config file. Environment Variables: UPSTREAM_URL Upstream LLM API URL (default: https://api.openai.com/v1/chat/completions) LISTEN_ADDR Listen address (default: :8080) API_KEY Upstream API key (required) INSECURE Skip TLS verification (default: false)`) } func printConfig(cfg Config) { masked := cfg.APIKey if len(masked) > 4 { masked = "****" + masked[len(masked)-4:] } else { masked = "****" } fmt.Printf("Upstream URL: %s\n", cfg.UpstreamURL) fmt.Printf("Listen Addr: %s\n", cfg.ListenAddr) fmt.Printf("API Key: %s\n", masked) fmt.Printf("Insecure: %v\n", cfg.Insecure) } func loadConfig() Config { cfg := Config{ UpstreamURL: "https://api.openai.com/v1/chat/completions", ListenAddr: "127.0.0.1:8080", } if data, err := os.ReadFile("llm-proxy.toml"); err == nil { var tomlCfg struct { UpstreamURL string `toml:"upstream_url"` ListenAddr string `toml:"listen_addr"` APIKey string `toml:"api_key"` Insecure bool `toml:"insecure"` } if err := toml.Unmarshal(data, &tomlCfg); err == nil { cfg.UpstreamURL = tomlCfg.UpstreamURL cfg.ListenAddr = tomlCfg.ListenAddr cfg.APIKey = tomlCfg.APIKey cfg.Insecure = tomlCfg.Insecure fmt.Println("Loaded config from llm-proxy.toml") } } if val := getEnv("UPSTREAM_URL", "OPENAI_API_BASE"); val != "" { cfg.UpstreamURL = val } if val := getEnv("API_KEY", "OPENAI_API_KEY"); val != "" { cfg.APIKey = val } if val := getEnv("LISTEN_ADDR"); val != "" { cfg.ListenAddr = val } if val := getEnv("INSECURE"); val != "" { cfg.Insecure = val == "true" } return cfg } func getEnv(keys ...string) string { for _, key := range keys { if val := os.Getenv(key); val != "" { return val } } return "" }