diff --git a/docs/mcp.md b/docs/mcp.md index f74f720..0eef518 100644 --- a/docs/mcp.md +++ b/docs/mcp.md @@ -32,6 +32,16 @@ vet -s -l /tmp/vet-mcp.log server mcp --server-type stdio > Avoid using `stdout` logging as it will interfere with the MCP server output. +### SSE Transport Features + +The SSE (Server-Sent Events) transport supports: + +- **GET requests**: For establishing SSE connections to receive real-time events +- **HEAD requests**: For endpoint health checks and capability probing (useful for tools like Langchain) +- **POST requests**: For sending messages to the MCP server via the message endpoint + +The SSE endpoint returns appropriate headers for HEAD requests without a body, allowing tools to verify endpoint availability and capabilities. + ## Configure MCP Client > **Note:** The example below uses pre-build docker image. You can build your own by running diff --git a/mcp/server/sse.go b/mcp/server/sse.go index f8fdc98..f95167f 100644 --- a/mcp/server/sse.go +++ b/mcp/server/sse.go @@ -1,10 +1,32 @@ package server import ( + "net/http" + "github.com/mark3labs/mcp-go/server" "github.com/safedep/vet/pkg/common/logger" ) +// sseHandlerWithHeadSupport wraps the SSE handler to add support for HTTP HEAD requests. +// HEAD requests will return the same headers as GET requests but without a body, +// allowing tools like Langchain to probe the endpoint for health or capability checks. +func sseHandlerWithHeadSupport(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Only handle HEAD requests to the SSE endpoint specifically + if r.Method == http.MethodHead && r.URL.Path == "/sse" { + // For HEAD requests to SSE endpoint, set the same headers as SSE connections but don't send a body + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Access-Control-Allow-Origin", "*") + w.WriteHeader(http.StatusOK) + return + } + // For all other requests (including GET, and HEAD to other endpoints), delegate to the original handler + handler.ServeHTTP(w, r) + }) +} + func NewMcpServerWithSseTransport(config McpServerConfig) (*mcpServer, error) { srv := newServer(config) return &mcpServer{ @@ -13,7 +35,15 @@ func NewMcpServerWithSseTransport(config McpServerConfig) (*mcpServer, error) { servingFunc: func(srv *mcpServer) error { logger.Infof("Starting MCP server with SSE transport: %s", config.SseServerAddr) s := server.NewSSEServer(srv.server, server.WithStaticBasePath(config.SseServerBasePath)) - return s.Start(config.SseServerAddr) + + // Wrap the SSE server with HEAD request support + wrappedHandler := sseHandlerWithHeadSupport(s) + httpServer := &http.Server{ + Addr: config.SseServerAddr, + Handler: wrappedHandler, + } + + return httpServer.ListenAndServe() }, }, nil } diff --git a/mcp/server/sse_integration_test.go b/mcp/server/sse_integration_test.go new file mode 100644 index 0000000..eef8e7e --- /dev/null +++ b/mcp/server/sse_integration_test.go @@ -0,0 +1,100 @@ +package server + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSSEServerIntegration(t *testing.T) { + // Create a test MCP server + mcpServer := server.NewMCPServer("test-vet-mcp", "0.0.1", + server.WithInstructions("Test MCP server for integration testing")) + + // Create SSE server with our custom handler + sseServer := server.NewSSEServer(mcpServer, server.WithStaticBasePath("")) + wrappedHandler := sseHandlerWithHeadSupport(sseServer) + + // Create test server + testServer := httptest.NewServer(wrappedHandler) + defer testServer.Close() + + t.Run("HEAD request to SSE endpoint", func(t *testing.T) { + req, err := http.NewRequest(http.MethodHead, testServer.URL+"/sse", nil) + require.NoError(t, err) + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Check status code + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Check SSE headers are present + assert.Equal(t, "text/event-stream", resp.Header.Get("Content-Type")) + assert.Equal(t, "no-cache", resp.Header.Get("Cache-Control")) + assert.Equal(t, "keep-alive", resp.Header.Get("Connection")) + assert.Equal(t, "*", resp.Header.Get("Access-Control-Allow-Origin")) + + // Verify no body was returned for HEAD request (ContentLength -1 is expected for HEAD) + assert.True(t, resp.ContentLength <= 0, "HEAD request should not have content length > 0") + }) + + t.Run("GET request to SSE endpoint", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, testServer.URL+"/sse", nil) + require.NoError(t, err) + + // Use a context with timeout to avoid hanging + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + req = req.WithContext(ctx) + + client := &http.Client{Timeout: 3 * time.Second} + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Check status code + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Check SSE headers are present + assert.Equal(t, "text/event-stream", resp.Header.Get("Content-Type")) + assert.Equal(t, "no-cache", resp.Header.Get("Cache-Control")) + assert.Equal(t, "keep-alive", resp.Header.Get("Connection")) + assert.Equal(t, "*", resp.Header.Get("Access-Control-Allow-Origin")) + }) + + t.Run("POST request to SSE endpoint should be handled by original handler", func(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, testServer.URL+"/sse", nil) + require.NoError(t, err) + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // POST to SSE endpoint should return 405 Method Not Allowed since SSE only accepts GET/HEAD + assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) + }) + + t.Run("HEAD request to message endpoint should not be handled specially", func(t *testing.T) { + req, err := http.NewRequest(http.MethodHead, testServer.URL+"/message", nil) + require.NoError(t, err) + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // HEAD requests to message endpoint should be handled by original SSE server handler + // which returns 400 Bad Request because message handler expects POST with sessionId parameter + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + }) +} diff --git a/mcp/server/sse_test.go b/mcp/server/sse_test.go new file mode 100644 index 0000000..ba50b7b --- /dev/null +++ b/mcp/server/sse_test.go @@ -0,0 +1,118 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSSEHandlerWithHeadSupport(t *testing.T) { + // Create a mock SSE handler that would normally handle GET requests + mockSSEHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Access-Control-Allow-Origin", "*") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("event: endpoint\ndata: /message?sessionId=test\n\n")) + } else { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + }) + + // Wrap the mock handler with HEAD support + wrappedHandler := sseHandlerWithHeadSupport(mockSSEHandler) + + tests := []struct { + name string + method string + path string + expectedStatus int + expectedHeaders map[string]string + expectBody bool + }{ + { + name: "HEAD request to SSE endpoint should return SSE headers without body", + method: http.MethodHead, + path: "/sse", + expectedStatus: http.StatusOK, + expectedHeaders: map[string]string{ + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Access-Control-Allow-Origin": "*", + }, + expectBody: false, + }, + { + name: "GET request to SSE endpoint should work normally", + method: http.MethodGet, + path: "/sse", + expectedStatus: http.StatusOK, + expectedHeaders: map[string]string{ + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Access-Control-Allow-Origin": "*", + }, + expectBody: true, + }, + { + name: "POST request to SSE endpoint should be rejected", + method: http.MethodPost, + path: "/sse", + expectedStatus: http.StatusMethodNotAllowed, + expectedHeaders: map[string]string{}, + expectBody: true, // Error message body + }, + { + name: "HEAD request to non-SSE endpoint should be passed through", + method: http.MethodHead, + path: "/message", + expectedStatus: http.StatusMethodNotAllowed, + expectedHeaders: map[string]string{}, + expectBody: true, // Error message body + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, tt.path, nil) + w := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, w.Code) + + // Check expected headers + for key, expectedValue := range tt.expectedHeaders { + assert.Equal(t, expectedValue, w.Header().Get(key), "Header %s mismatch", key) + } + + // Check body presence + body := w.Body.String() + if tt.expectBody { + assert.NotEmpty(t, body, "Expected body to be present") + } else { + assert.Empty(t, body, "Expected body to be empty for HEAD request") + } + }) + } +} + +func TestMcpServerWithSseTransport(t *testing.T) { + config := DefaultMcpServerConfig() + config.SseServerAddr = "localhost:0" // Use random available port for testing + + srv, err := NewMcpServerWithSseTransport(config) + assert.NoError(t, err) + assert.NotNil(t, srv) + + // Verify that the server instance is properly configured + assert.Equal(t, config, srv.config) + assert.NotNil(t, srv.server) + assert.NotNil(t, srv.servingFunc) +} \ No newline at end of file