diff --git a/pkg/api/request_size_test.go b/pkg/api/request_size_test.go new file mode 100644 index 000000000..895c88716 --- /dev/null +++ b/pkg/api/request_size_test.go @@ -0,0 +1,110 @@ +package api + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRequestBodySizeLimitMiddleware(t *testing.T) { + t.Parallel() + // Define the limit (1MB) + const maxBodySize = 1 << 20 // 1MB + + // Helper to create the middleware handler + createHandler := func(next http.Handler) http.Handler { + return requestBodySizeLimitMiddleware(maxBodySize)(next) + } + + t.Run("Request body within limit", func(t *testing.T) { + t.Parallel() + // Create a request with a body smaller than the limit + body := bytes.NewBuffer(make([]byte, maxBodySize-1)) + req := httptest.NewRequest(http.MethodPost, "/test", body) + rec := httptest.NewRecorder() + + // Dummy handler that reads the body to trigger MaxBytesReader + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + buf := new(bytes.Buffer) + _, err := buf.ReadFrom(r.Body) + assert.NoError(t, err) + w.WriteHeader(http.StatusOK) + }) + + handler := createHandler(nextHandler) + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("Request body exactly at limit", func(t *testing.T) { + t.Parallel() + // Create a request with a body exactly at the limit + body := bytes.NewBuffer(make([]byte, maxBodySize)) + req := httptest.NewRequest(http.MethodPost, "/test", body) + rec := httptest.NewRecorder() + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + buf := new(bytes.Buffer) + _, err := buf.ReadFrom(r.Body) + assert.NoError(t, err) + w.WriteHeader(http.StatusOK) + }) + + handler := createHandler(nextHandler) + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("Request body exceeds limit via Content-Length", func(t *testing.T) { + t.Parallel() + // Create a request with a body larger than the limit + body := bytes.NewBuffer(make([]byte, maxBodySize+1)) + req := httptest.NewRequest(http.MethodPost, "/test", body) + rec := httptest.NewRecorder() + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + handler := createHandler(nextHandler) + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Code) + assert.Contains(t, rec.Body.String(), "Request Entity Too Large") + }) + + t.Run("MaxBytesReader enforces limit when Content-Length is misleading", func(t *testing.T) { + t.Parallel() + // Create a request where Content-Length is set incorrectly to test the MaxBytesReader safety net + oversizedBody := make([]byte, maxBodySize+100) + body := bytes.NewBuffer(oversizedBody) + req := httptest.NewRequest(http.MethodPost, "/api/v1beta/test", body) + + // Manually set Content-Length to be within limit to bypass early check + req.ContentLength = maxBodySize - 1 + + rec := httptest.NewRecorder() + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Try to read the entire body - MaxBytesReader should return an error + buf := new(bytes.Buffer) + _, err := buf.ReadFrom(r.Body) + if err != nil { + w.Header().Set("Content-Type", "application/json") + http.Error(w, "Request Entity Too Large", http.StatusRequestEntityTooLarge) + return + } + w.WriteHeader(http.StatusOK) + }) + + handler := createHandler(nextHandler) + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Code) + }) +} diff --git a/pkg/api/server.go b/pkg/api/server.go index 72819b490..3e566d571 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -39,9 +39,10 @@ import ( // Not sure if these values need to be configurable. const ( - middlewareTimeout = 60 * time.Second - readHeaderTimeout = 10 * time.Second - socketPermissions = 0660 // Socket file permissions (owner/group read-write) + middlewareTimeout = 60 * time.Second + readHeaderTimeout = 10 * time.Second + socketPermissions = 0660 // Socket file permissions (owner/group read-write) + maxRequestBodySize = 1 << 20 // 1MB - Maximum request body size ) // ServerBuilder provides a fluent interface for building and configuring the API server @@ -142,6 +143,7 @@ func (b *ServerBuilder) Build(ctx context.Context) (*chi.Mux, error) { middleware.RequestID, // TODO: Figure out logging middleware. We may want to use a different logger. middleware.Timeout(middlewareTimeout), + requestBodySizeLimitMiddleware(maxRequestBodySize), headersMiddleware, ) @@ -312,6 +314,27 @@ func updateCheckMiddleware() func(next http.Handler) http.Handler { } } +// requestBodySizeLimitMiddleware limits request body size, returns a 413 for request bodies larger than maxSize. +func requestBodySizeLimitMiddleware(maxSize int64) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check Content-Length header first for early rejection + if r.ContentLength > maxSize { + // Set Content-Type for API endpoints to match headersMiddleware behavior + if strings.HasPrefix(r.URL.Path, "/api/") { + w.Header().Set("Content-Type", "application/json") + } + http.Error(w, "Request Entity Too Large", http.StatusRequestEntityTooLarge) + return + } + + // Also set MaxBytesReader as a safety net for requests without Content-Length + r.Body = http.MaxBytesReader(w, r.Body, maxSize) + next.ServeHTTP(w, r) + }) + } +} + // getComponentAndVersionFromRequest determines the component name, version, and ui release build from the request func getComponentAndVersionFromRequest(r *http.Request) (string, string, bool) { clientType := r.Header.Get("X-Client-Type")