// File: llm-proxy-go/main.go (311 lines, 7.5 KiB, Go)
package main
import (
"bufio"
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/google/uuid"
"github.com/joho/godotenv"
"github.com/pelletier/go-toml/v2"
)
// Config holds the proxy's runtime settings, resolved by loadConfig from
// built-in defaults, an optional llm-proxy.toml file, and environment
// variables (highest priority).
type Config struct {
	UpstreamURL string // full URL of the upstream LLM API endpoint
	ListenAddr  string // host:port the proxy listens on
	APIKey      string // bearer token injected into upstream requests (required at startup)
	Insecure    bool   // when true, skip upstream TLS certificate verification
}
// main loads configuration, starts the proxy server, and shuts it down
// gracefully on SIGINT/SIGTERM (draining in-flight requests for up to 5s).
func main() {
	if len(os.Args) > 1 && (os.Args[1] == "-h" || os.Args[1] == "--help") {
		printHelp()
		os.Exit(0)
	}
	// Best-effort: a missing .env file is not an error.
	godotenv.Load()
	cfg := loadConfig()
	printConfig(cfg)
	if cfg.APIKey == "" {
		fmt.Fprintln(os.Stderr, "error: API_KEY is required")
		os.Exit(1)
	}
	if cfg.Insecure {
		fmt.Fprintln(os.Stderr, "WARNING: TLS verification disabled")
	}
	mux := http.NewServeMux()
	mux.HandleFunc("/", handleProxy(cfg))
	srv := &http.Server{
		Addr:    cfg.ListenAddr,
		Handler: mux,
	}
	go func() {
		fmt.Printf("LLM Proxy listening on %s proxy to upstream: %s\n", cfg.ListenAddr, cfg.UpstreamURL)
		// ErrServerClosed is the normal result of Shutdown; anything else
		// (e.g. a failed bind) must not be silently discarded.
		if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			log.Fatalf("listen error: %v", err)
		}
	}()
	// Block until an interrupt/terminate signal arrives, then shut down.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	<-sigCh
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := srv.Shutdown(ctx); err != nil {
		log.Println("shutdown error:", err)
	}
}
// handleProxy returns the main proxy handler. The HTTP client (and its
// transport) is created once per handler — not once per request as before —
// so upstream TCP/TLS connections can be pooled and reused.
func handleProxy(cfg Config) http.HandlerFunc {
	tr := &http.Transport{
		// InsecureSkipVerify is config-driven; the user is warned at startup.
		TLSClientConfig: &tls.Config{InsecureSkipVerify: cfg.Insecure},
	}
	client := &http.Client{Transport: tr}
	return func(w http.ResponseWriter, r *http.Request) {
		requestId := uuid.New().String()
		log.Println(requestId, "Handle proxy:", r.Method, r.URL)
		// Local health endpoint; never forwarded upstream.
		if r.URL.Path == "/health" {
			w.WriteHeader(http.StatusOK)
			return
		}
		proxyReq := cloneRequest(requestId, r, cfg.UpstreamURL)
		if cfg.APIKey != "" {
			proxyReq.Header.Set("Authorization", "Bearer "+cfg.APIKey)
		}
		resp, err := client.Do(proxyReq)
		if err != nil {
			log.Println(requestId, "Upstream error:", err)
			http.Error(w, fmt.Sprintf("upstream error: %v", err), http.StatusBadGateway)
			return
		}
		defer resp.Body.Close()
		// Mirror all upstream response headers before writing the status code.
		for k, v := range resp.Header {
			log.Println(requestId, "Header:", k, v)
			w.Header()[k] = v
		}
		w.WriteHeader(resp.StatusCode)
		isStreamingRequest := isStreaming(requestId, resp)
		log.Println(requestId, "Request streaming:", isStreamingRequest)
		if isStreamingRequest {
			handleStream(requestId, w, resp.Body)
			return
		}
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			// The status line is already sent; nothing sensible left to forward.
			log.Println(requestId, "Read response error:", err)
			return
		}
		logResponseBody(requestId, resp.Header.Get("Content-Encoding"), body)
		w.Write(body)
		log.Println(requestId, "Response end.")
	}
}

// logResponseBody logs a non-streaming response body for debugging,
// gzip-decoding it first when the upstream compressed it.
func logResponseBody(requestId, contentEncoding string, body []byte) {
	if contentEncoding == "gzip" {
		gr, err := gzip.NewReader(bytes.NewReader(body))
		if err == nil {
			decodedBody, readErr := io.ReadAll(gr)
			gr.Close()
			if readErr != nil {
				log.Println(requestId, "Decompress error:", readErr)
			}
			log.Println(requestId, "Response[decoded]:", contentEncoding, string(decodedBody))
			return
		}
		log.Println(requestId, "Decompress error:", err)
	}
	log.Println(requestId, "Response[raw]:", string(body))
}
// cloneRequest builds the outbound request: same method/body/headers as the
// inbound one, re-targeted at the upstream host with the upstream's base path
// prefixed onto the request path.
func cloneRequest(requestId string, r *http.Request, upstreamURL string) *http.Request {
	upstream, err := url.Parse(upstreamURL)
	if err != nil {
		// The URL comes from config; log loudly rather than panic mid-request.
		log.Println(requestId, "Invalid upstream URL:", err)
		upstream = &url.URL{}
	}
	proxyReq := r.Clone(context.Background())
	proxyReq.URL.Scheme = upstream.Scheme
	proxyReq.URL.Host = upstream.Host
	// Strip the upstream base path if the client already included it.
	// TrimPrefix, not ReplaceAll: only a LEADING match must be removed —
	// ReplaceAll would also mangle any later occurrence inside the path.
	proxyReq.URL.Path = strings.TrimPrefix(r.URL.Path, upstream.Path)
	// ...then re-prefix the upstream base path so the final path always
	// targets the configured upstream endpoint.
	if upstream.Path != "" && !strings.HasSuffix(proxyReq.URL.Path, "/") {
		proxyReq.URL.Path = upstream.Path + proxyReq.URL.Path
	}
	proxyReq.Host = upstream.Host
	log.Println(requestId, "Upstream proxy:", proxyReq.URL)
	if val := r.Header.Get("Content-Type"); val != "" {
		proxyReq.Header.Set("Content-Type", val)
	}
	// Drop inbound credentials/host; the proxy injects its own key later.
	proxyReq.Header.Del("Host")
	proxyReq.Header.Del("Authorization")
	// RequestURI must be empty on client requests (http.Client rejects it otherwise).
	proxyReq.RequestURI = ""
	return proxyReq
}
// isStreaming reports whether the upstream response looks like a streaming
// payload, judged solely by its Content-Type header.
func isStreaming(requestId string, resp *http.Response) bool {
	contentType := strings.ToLower(resp.Header.Get("Content-Type"))
	log.Println(requestId, "Content-Type:", contentType)
	for _, marker := range []string{"text/event-stream", "application/x-ndjson", "stream"} {
		if strings.Contains(contentType, marker) {
			return true
		}
	}
	return false
}
// handleStream copies a streaming response body to the client line by line,
// flushing after every write so each chunk reaches the client immediately.
func handleStream(requestId string, w io.Writer, body io.Reader) {
	// Need the concrete ResponseWriter to reach Flush support.
	rw, ok := w.(http.ResponseWriter)
	if !ok {
		log.Println(requestId, "Error: ResponseWriter is not an http.ResponseWriter")
		return
	}
	// Resolve Flush support ONCE, instead of re-asserting (and re-warning)
	// on every chunk as the previous version did.
	flusher, canFlush := rw.(http.Flusher)
	if !canFlush {
		log.Println(requestId, "Warning: ResponseWriter does not support flushing")
	}
	reader := bufio.NewReader(body)
	for {
		line, err := reader.ReadString('\n')
		// Forward whatever was read — including a final line with NO trailing
		// newline: ReadString returns that partial data together with io.EOF,
		// and the old code dropped it on the floor.
		if len(line) > 0 {
			if _, writeErr := rw.Write([]byte(line)); writeErr != nil {
				log.Println(requestId, "Write error:", writeErr)
				return
			}
			if line != "\n" {
				log.Println(requestId, "Process chunk:", fmt.Sprintf("%d bytes", len(line)), strings.TrimSpace(line))
			}
			// Flush blank lines too: for SSE a blank line terminates an event,
			// so it must not linger in the buffer.
			if canFlush {
				flusher.Flush()
			}
		}
		if err != nil {
			if err != io.EOF {
				log.Println(requestId, "Stream error:", err)
				fmt.Fprintf(os.Stderr, "stream error: %v\n", err)
			} else {
				log.Println(requestId, "Stream end.")
			}
			return
		}
	}
}
// printHelp writes the CLI usage text to stdout.
func printHelp() {
	const helpText = `LLM Proxy - HTTP proxy for LLM APIs
Usage:
llm-proxy Start the proxy
llm-proxy -h Show this help
Config:
Config file (optional): llm-proxy.toml
Environment variables take priority over config file.
Environment Variables:
UPSTREAM_URL Upstream LLM API URL (default: https://api.openai.com/v1/chat/completions)
LISTEN_ADDR Listen address (default: :8080)
API_KEY Upstream API key (required)
INSECURE Skip TLS verification (default: false)`
	fmt.Println(helpText)
}
// printConfig echoes the effective configuration to stdout, masking the API
// key down to its last four characters.
func printConfig(cfg Config) {
	key := "****"
	if n := len(cfg.APIKey); n > 4 {
		key += cfg.APIKey[n-4:]
	}
	fmt.Printf("Upstream URL: %s\n", cfg.UpstreamURL)
	fmt.Printf("Listen Addr: %s\n", cfg.ListenAddr)
	fmt.Printf("API Key: %s\n", key)
	fmt.Printf("Insecure: %v\n", cfg.Insecure)
}
// loadConfig resolves configuration in priority order:
// built-in defaults < llm-proxy.toml < environment variables.
func loadConfig() Config {
	cfg := Config{
		UpstreamURL: "https://api.openai.com/v1/chat/completions",
		ListenAddr:  "127.0.0.1:8080",
	}
	if data, err := os.ReadFile("llm-proxy.toml"); err == nil {
		var tomlCfg struct {
			UpstreamURL string `toml:"upstream_url"`
			ListenAddr  string `toml:"listen_addr"`
			APIKey      string `toml:"api_key"`
			Insecure    bool   `toml:"insecure"`
		}
		if err := toml.Unmarshal(data, &tomlCfg); err == nil {
			// Only override defaults for keys actually present in the file.
			// The old unconditional assignment blanked out UpstreamURL and
			// ListenAddr whenever a partial config omitted them.
			if tomlCfg.UpstreamURL != "" {
				cfg.UpstreamURL = tomlCfg.UpstreamURL
			}
			if tomlCfg.ListenAddr != "" {
				cfg.ListenAddr = tomlCfg.ListenAddr
			}
			if tomlCfg.APIKey != "" {
				cfg.APIKey = tomlCfg.APIKey
			}
			cfg.Insecure = tomlCfg.Insecure
			fmt.Println("Loaded config from llm-proxy.toml")
		} else {
			// Don't silently ignore a malformed config file.
			log.Println("Warning: failed to parse llm-proxy.toml:", err)
		}
	}
	// Environment variables (with OpenAI-compatible aliases) win over file values.
	if val := getEnv("UPSTREAM_URL", "OPENAI_API_BASE"); val != "" {
		cfg.UpstreamURL = val
	}
	if val := getEnv("API_KEY", "OPENAI_API_KEY"); val != "" {
		cfg.APIKey = val
	}
	if val := getEnv("LISTEN_ADDR"); val != "" {
		cfg.ListenAddr = val
	}
	if val := getEnv("INSECURE"); val != "" {
		cfg.Insecure = val == "true"
	}
	return cfg
}
// getEnv returns the value of the first environment variable in keys that is
// set to a non-empty string, or "" when none are.
func getEnv(keys ...string) string {
	var found string
	for _, name := range keys {
		if found = os.Getenv(name); found != "" {
			break
		}
	}
	return found
}