Add HTTP HEAD request support to SSE MCP server (#533)

* Initial plan

* Add HTTP HEAD request support to SSE MCP server

- Created sseHandlerWithHeadSupport wrapper to handle HEAD requests to /sse endpoint
- HEAD requests return same headers as GET (text/event-stream, no-cache, etc.) without body
- Modified NewMcpServerWithSseTransport to use the wrapper
- Added comprehensive unit and integration tests
- Updated documentation to mention HEAD support for SSE endpoint
- Enables tools like Langchain to probe endpoint for health/capability checks

Co-authored-by: abhisek <31844+abhisek@users.noreply.github.com>

* Add HTTP HEAD request support to SSE MCP server

Co-authored-by: abhisek <31844+abhisek@users.noreply.github.com>

* Fix linter issues: remove trailing whitespace and handle w.Write error

Co-authored-by: abhisek <31844+abhisek@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: abhisek <31844+abhisek@users.noreply.github.com>
This commit is contained in:
Copilot 2025-07-05 13:41:37 +00:00 committed by GitHub
parent 548ede77b8
commit cd7caffb4a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 259 additions and 1 deletions

View File

@ -32,6 +32,16 @@ vet -s -l /tmp/vet-mcp.log server mcp --server-type stdio
> Avoid using `stdout` logging as it will interfere with the MCP server output.
### SSE Transport Features
The SSE (Server-Sent Events) transport supports:
- **GET requests**: For establishing SSE connections to receive real-time events
- **HEAD requests**: For endpoint health checks and capability probing (useful for tools like Langchain)
- **POST requests**: For sending messages to the MCP server via the message endpoint
The SSE endpoint returns appropriate headers for HEAD requests without a body, allowing tools to verify endpoint availability and capabilities.
## Configure MCP Client
> **Note:** The example below uses pre-build docker image. You can build your own by running

View File

@ -1,10 +1,32 @@
package server
import (
"net/http"
"github.com/mark3labs/mcp-go/server"
"github.com/safedep/vet/pkg/common/logger"
)
// sseHandlerWithHeadSupport wraps the SSE handler to add support for HTTP HEAD requests.
// HEAD requests will return the same headers as GET requests but without a body,
// allowing tools like Langchain to probe the endpoint for health or capability checks.
func sseHandlerWithHeadSupport(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Only handle HEAD requests to the SSE endpoint specifically
if r.Method == http.MethodHead && r.URL.Path == "/sse" {
// For HEAD requests to SSE endpoint, set the same headers as SSE connections but don't send a body
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.WriteHeader(http.StatusOK)
return
}
// For all other requests (including GET, and HEAD to other endpoints), delegate to the original handler
handler.ServeHTTP(w, r)
})
}
func NewMcpServerWithSseTransport(config McpServerConfig) (*mcpServer, error) {
srv := newServer(config)
return &mcpServer{
@ -13,7 +35,15 @@ func NewMcpServerWithSseTransport(config McpServerConfig) (*mcpServer, error) {
servingFunc: func(srv *mcpServer) error {
logger.Infof("Starting MCP server with SSE transport: %s", config.SseServerAddr)
s := server.NewSSEServer(srv.server, server.WithStaticBasePath(config.SseServerBasePath))
return s.Start(config.SseServerAddr)
// Wrap the SSE server with HEAD request support
wrappedHandler := sseHandlerWithHeadSupport(s)
httpServer := &http.Server{
Addr: config.SseServerAddr,
Handler: wrappedHandler,
}
return httpServer.ListenAndServe()
},
}, nil
}

View File

@ -0,0 +1,100 @@
package server
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/mark3labs/mcp-go/server"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSSEServerIntegration(t *testing.T) {
// Create a test MCP server
mcpServer := server.NewMCPServer("test-vet-mcp", "0.0.1",
server.WithInstructions("Test MCP server for integration testing"))
// Create SSE server with our custom handler
sseServer := server.NewSSEServer(mcpServer, server.WithStaticBasePath(""))
wrappedHandler := sseHandlerWithHeadSupport(sseServer)
// Create test server
testServer := httptest.NewServer(wrappedHandler)
defer testServer.Close()
t.Run("HEAD request to SSE endpoint", func(t *testing.T) {
req, err := http.NewRequest(http.MethodHead, testServer.URL+"/sse", nil)
require.NoError(t, err)
client := &http.Client{Timeout: 5 * time.Second}
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
// Check status code
assert.Equal(t, http.StatusOK, resp.StatusCode)
// Check SSE headers are present
assert.Equal(t, "text/event-stream", resp.Header.Get("Content-Type"))
assert.Equal(t, "no-cache", resp.Header.Get("Cache-Control"))
assert.Equal(t, "keep-alive", resp.Header.Get("Connection"))
assert.Equal(t, "*", resp.Header.Get("Access-Control-Allow-Origin"))
// Verify no body was returned for HEAD request (ContentLength -1 is expected for HEAD)
assert.True(t, resp.ContentLength <= 0, "HEAD request should not have content length > 0")
})
t.Run("GET request to SSE endpoint", func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, testServer.URL+"/sse", nil)
require.NoError(t, err)
// Use a context with timeout to avoid hanging
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
req = req.WithContext(ctx)
client := &http.Client{Timeout: 3 * time.Second}
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
// Check status code
assert.Equal(t, http.StatusOK, resp.StatusCode)
// Check SSE headers are present
assert.Equal(t, "text/event-stream", resp.Header.Get("Content-Type"))
assert.Equal(t, "no-cache", resp.Header.Get("Cache-Control"))
assert.Equal(t, "keep-alive", resp.Header.Get("Connection"))
assert.Equal(t, "*", resp.Header.Get("Access-Control-Allow-Origin"))
})
t.Run("POST request to SSE endpoint should be handled by original handler", func(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, testServer.URL+"/sse", nil)
require.NoError(t, err)
client := &http.Client{Timeout: 5 * time.Second}
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
// POST to SSE endpoint should return 405 Method Not Allowed since SSE only accepts GET/HEAD
assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
})
t.Run("HEAD request to message endpoint should not be handled specially", func(t *testing.T) {
req, err := http.NewRequest(http.MethodHead, testServer.URL+"/message", nil)
require.NoError(t, err)
client := &http.Client{Timeout: 5 * time.Second}
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
// HEAD requests to message endpoint should be handled by original SSE server handler
// which returns 400 Bad Request because message handler expects POST with sessionId parameter
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
})
}

118
mcp/server/sse_test.go Normal file
View File

@ -0,0 +1,118 @@
package server
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestSSEHandlerWithHeadSupport(t *testing.T) {
// Create a mock SSE handler that would normally handle GET requests
mockSSEHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("event: endpoint\ndata: /message?sessionId=test\n\n"))
} else {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
})
// Wrap the mock handler with HEAD support
wrappedHandler := sseHandlerWithHeadSupport(mockSSEHandler)
tests := []struct {
name string
method string
path string
expectedStatus int
expectedHeaders map[string]string
expectBody bool
}{
{
name: "HEAD request to SSE endpoint should return SSE headers without body",
method: http.MethodHead,
path: "/sse",
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
},
expectBody: false,
},
{
name: "GET request to SSE endpoint should work normally",
method: http.MethodGet,
path: "/sse",
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
},
expectBody: true,
},
{
name: "POST request to SSE endpoint should be rejected",
method: http.MethodPost,
path: "/sse",
expectedStatus: http.StatusMethodNotAllowed,
expectedHeaders: map[string]string{},
expectBody: true, // Error message body
},
{
name: "HEAD request to non-SSE endpoint should be passed through",
method: http.MethodHead,
path: "/message",
expectedStatus: http.StatusMethodNotAllowed,
expectedHeaders: map[string]string{},
expectBody: true, // Error message body
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(tt.method, tt.path, nil)
w := httptest.NewRecorder()
wrappedHandler.ServeHTTP(w, req)
assert.Equal(t, tt.expectedStatus, w.Code)
// Check expected headers
for key, expectedValue := range tt.expectedHeaders {
assert.Equal(t, expectedValue, w.Header().Get(key), "Header %s mismatch", key)
}
// Check body presence
body := w.Body.String()
if tt.expectBody {
assert.NotEmpty(t, body, "Expected body to be present")
} else {
assert.Empty(t, body, "Expected body to be empty for HEAD request")
}
})
}
}
func TestMcpServerWithSseTransport(t *testing.T) {
config := DefaultMcpServerConfig()
config.SseServerAddr = "localhost:0" // Use random available port for testing
srv, err := NewMcpServerWithSseTransport(config)
assert.NoError(t, err)
assert.NotNil(t, srv)
// Verify that the server instance is properly configured
assert.Equal(t, config, srv.config)
assert.NotNil(t, srv.server)
assert.NotNil(t, srv.servingFunc)
}