mirror of
https://github.com/safedep/vet.git
synced 2025-12-10 13:43:01 -06:00
Add HTTP HEAD request support to SSE MCP server (#533)
* Initial plan * Add HTTP HEAD request support to SSE MCP server - Created sseHandlerWithHeadSupport wrapper to handle HEAD requests to /sse endpoint - HEAD requests return same headers as GET (text/event-stream, no-cache, etc.) without body - Modified NewMcpServerWithSseTransport to use the wrapper - Added comprehensive unit and integration tests - Updated documentation to mention HEAD support for SSE endpoint - Enables tools like Langchain to probe endpoint for health/capability checks Co-authored-by: abhisek <31844+abhisek@users.noreply.github.com> * Add HTTP HEAD request support to SSE MCP server Co-authored-by: abhisek <31844+abhisek@users.noreply.github.com> * Fix linter issues: remove trailing whitespace and handle w.Write error Co-authored-by: abhisek <31844+abhisek@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: abhisek <31844+abhisek@users.noreply.github.com>
This commit is contained in:
parent
548ede77b8
commit
cd7caffb4a
10
docs/mcp.md
10
docs/mcp.md
@ -32,6 +32,16 @@ vet -s -l /tmp/vet-mcp.log server mcp --server-type stdio
|
||||
|
||||
> Avoid using `stdout` logging as it will interfere with the MCP server output.
|
||||
|
||||
### SSE Transport Features
|
||||
|
||||
The SSE (Server-Sent Events) transport supports:
|
||||
|
||||
- **GET requests**: For establishing SSE connections to receive real-time events
|
||||
- **HEAD requests**: For endpoint health checks and capability probing (useful for tools like Langchain)
|
||||
- **POST requests**: For sending messages to the MCP server via the message endpoint
|
||||
|
||||
The SSE endpoint returns appropriate headers for HEAD requests without a body, allowing tools to verify endpoint availability and capabilities.
|
||||
|
||||
## Configure MCP Client
|
||||
|
||||
> **Note:** The example below uses pre-build docker image. You can build your own by running
|
||||
|
||||
@ -1,10 +1,32 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
"github.com/safedep/vet/pkg/common/logger"
|
||||
)
|
||||
|
||||
// sseHandlerWithHeadSupport wraps the SSE handler to add support for HTTP HEAD requests.
|
||||
// HEAD requests will return the same headers as GET requests but without a body,
|
||||
// allowing tools like Langchain to probe the endpoint for health or capability checks.
|
||||
func sseHandlerWithHeadSupport(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Only handle HEAD requests to the SSE endpoint specifically
|
||||
if r.Method == http.MethodHead && r.URL.Path == "/sse" {
|
||||
// For HEAD requests to SSE endpoint, set the same headers as SSE connections but don't send a body
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
// For all other requests (including GET, and HEAD to other endpoints), delegate to the original handler
|
||||
handler.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func NewMcpServerWithSseTransport(config McpServerConfig) (*mcpServer, error) {
|
||||
srv := newServer(config)
|
||||
return &mcpServer{
|
||||
@ -13,7 +35,15 @@ func NewMcpServerWithSseTransport(config McpServerConfig) (*mcpServer, error) {
|
||||
servingFunc: func(srv *mcpServer) error {
|
||||
logger.Infof("Starting MCP server with SSE transport: %s", config.SseServerAddr)
|
||||
s := server.NewSSEServer(srv.server, server.WithStaticBasePath(config.SseServerBasePath))
|
||||
return s.Start(config.SseServerAddr)
|
||||
|
||||
// Wrap the SSE server with HEAD request support
|
||||
wrappedHandler := sseHandlerWithHeadSupport(s)
|
||||
httpServer := &http.Server{
|
||||
Addr: config.SseServerAddr,
|
||||
Handler: wrappedHandler,
|
||||
}
|
||||
|
||||
return httpServer.ListenAndServe()
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
100
mcp/server/sse_integration_test.go
Normal file
100
mcp/server/sse_integration_test.go
Normal file
@ -0,0 +1,100 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSSEServerIntegration(t *testing.T) {
|
||||
// Create a test MCP server
|
||||
mcpServer := server.NewMCPServer("test-vet-mcp", "0.0.1",
|
||||
server.WithInstructions("Test MCP server for integration testing"))
|
||||
|
||||
// Create SSE server with our custom handler
|
||||
sseServer := server.NewSSEServer(mcpServer, server.WithStaticBasePath(""))
|
||||
wrappedHandler := sseHandlerWithHeadSupport(sseServer)
|
||||
|
||||
// Create test server
|
||||
testServer := httptest.NewServer(wrappedHandler)
|
||||
defer testServer.Close()
|
||||
|
||||
t.Run("HEAD request to SSE endpoint", func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodHead, testServer.URL+"/sse", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check status code
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Check SSE headers are present
|
||||
assert.Equal(t, "text/event-stream", resp.Header.Get("Content-Type"))
|
||||
assert.Equal(t, "no-cache", resp.Header.Get("Cache-Control"))
|
||||
assert.Equal(t, "keep-alive", resp.Header.Get("Connection"))
|
||||
assert.Equal(t, "*", resp.Header.Get("Access-Control-Allow-Origin"))
|
||||
|
||||
// Verify no body was returned for HEAD request (ContentLength -1 is expected for HEAD)
|
||||
assert.True(t, resp.ContentLength <= 0, "HEAD request should not have content length > 0")
|
||||
})
|
||||
|
||||
t.Run("GET request to SSE endpoint", func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, testServer.URL+"/sse", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Use a context with timeout to avoid hanging
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
client := &http.Client{Timeout: 3 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check status code
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Check SSE headers are present
|
||||
assert.Equal(t, "text/event-stream", resp.Header.Get("Content-Type"))
|
||||
assert.Equal(t, "no-cache", resp.Header.Get("Cache-Control"))
|
||||
assert.Equal(t, "keep-alive", resp.Header.Get("Connection"))
|
||||
assert.Equal(t, "*", resp.Header.Get("Access-Control-Allow-Origin"))
|
||||
})
|
||||
|
||||
t.Run("POST request to SSE endpoint should be handled by original handler", func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, testServer.URL+"/sse", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// POST to SSE endpoint should return 405 Method Not Allowed since SSE only accepts GET/HEAD
|
||||
assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("HEAD request to message endpoint should not be handled specially", func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodHead, testServer.URL+"/message", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// HEAD requests to message endpoint should be handled by original SSE server handler
|
||||
// which returns 400 Bad Request because message handler expects POST with sessionId parameter
|
||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
118
mcp/server/sse_test.go
Normal file
118
mcp/server/sse_test.go
Normal file
@ -0,0 +1,118 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSSEHandlerWithHeadSupport(t *testing.T) {
|
||||
// Create a mock SSE handler that would normally handle GET requests
|
||||
mockSSEHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodGet {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("event: endpoint\ndata: /message?sessionId=test\n\n"))
|
||||
} else {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
})
|
||||
|
||||
// Wrap the mock handler with HEAD support
|
||||
wrappedHandler := sseHandlerWithHeadSupport(mockSSEHandler)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
expectedStatus int
|
||||
expectedHeaders map[string]string
|
||||
expectBody bool
|
||||
}{
|
||||
{
|
||||
name: "HEAD request to SSE endpoint should return SSE headers without body",
|
||||
method: http.MethodHead,
|
||||
path: "/sse",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
},
|
||||
expectBody: false,
|
||||
},
|
||||
{
|
||||
name: "GET request to SSE endpoint should work normally",
|
||||
method: http.MethodGet,
|
||||
path: "/sse",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
},
|
||||
expectBody: true,
|
||||
},
|
||||
{
|
||||
name: "POST request to SSE endpoint should be rejected",
|
||||
method: http.MethodPost,
|
||||
path: "/sse",
|
||||
expectedStatus: http.StatusMethodNotAllowed,
|
||||
expectedHeaders: map[string]string{},
|
||||
expectBody: true, // Error message body
|
||||
},
|
||||
{
|
||||
name: "HEAD request to non-SSE endpoint should be passed through",
|
||||
method: http.MethodHead,
|
||||
path: "/message",
|
||||
expectedStatus: http.StatusMethodNotAllowed,
|
||||
expectedHeaders: map[string]string{},
|
||||
expectBody: true, // Error message body
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tt.method, tt.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||
|
||||
// Check expected headers
|
||||
for key, expectedValue := range tt.expectedHeaders {
|
||||
assert.Equal(t, expectedValue, w.Header().Get(key), "Header %s mismatch", key)
|
||||
}
|
||||
|
||||
// Check body presence
|
||||
body := w.Body.String()
|
||||
if tt.expectBody {
|
||||
assert.NotEmpty(t, body, "Expected body to be present")
|
||||
} else {
|
||||
assert.Empty(t, body, "Expected body to be empty for HEAD request")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMcpServerWithSseTransport(t *testing.T) {
|
||||
config := DefaultMcpServerConfig()
|
||||
config.SseServerAddr = "localhost:0" // Use random available port for testing
|
||||
|
||||
srv, err := NewMcpServerWithSseTransport(config)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, srv)
|
||||
|
||||
// Verify that the server instance is properly configured
|
||||
assert.Equal(t, config, srv.config)
|
||||
assert.NotNil(t, srv.server)
|
||||
assert.NotNil(t, srv.servingFunc)
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user