-
Notifications
You must be signed in to change notification settings - Fork 160
Add request body size limit middleware function #3114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
24a4436
3d69d56
453d830
9cab326
c4bc14c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,110 @@ | ||
| package api | ||
|
|
||
| import ( | ||
| "bytes" | ||
| "net/http" | ||
| "net/http/httptest" | ||
| "testing" | ||
|
|
||
| "github.com/stretchr/testify/assert" | ||
| ) | ||
|
|
||
| func TestRequestBodySizeLimitMiddleware(t *testing.T) { | ||
| t.Parallel() | ||
| // Define the limit (1MB) | ||
| const maxBodySize = 1 << 20 // 1MB | ||
|
|
||
| // Helper to create the middleware handler | ||
| createHandler := func(next http.Handler) http.Handler { | ||
| return requestBodySizeLimitMiddleware(maxBodySize)(next) | ||
| } | ||
|
|
||
| t.Run("Request body within limit", func(t *testing.T) { | ||
| t.Parallel() | ||
| // Create a request with a body smaller than the limit | ||
| body := bytes.NewBuffer(make([]byte, maxBodySize-1)) | ||
| req := httptest.NewRequest(http.MethodPost, "/test", body) | ||
| rec := httptest.NewRecorder() | ||
|
|
||
| // Dummy handler that reads the body to trigger MaxBytesReader | ||
| nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
| buf := new(bytes.Buffer) | ||
| _, err := buf.ReadFrom(r.Body) | ||
| assert.NoError(t, err) | ||
| w.WriteHeader(http.StatusOK) | ||
| }) | ||
|
|
||
| handler := createHandler(nextHandler) | ||
| handler.ServeHTTP(rec, req) | ||
|
|
||
| assert.Equal(t, http.StatusOK, rec.Code) | ||
| }) | ||
|
|
||
| t.Run("Request body exactly at limit", func(t *testing.T) { | ||
| t.Parallel() | ||
| // Create a request with a body exactly at the limit | ||
| body := bytes.NewBuffer(make([]byte, maxBodySize)) | ||
| req := httptest.NewRequest(http.MethodPost, "/test", body) | ||
| rec := httptest.NewRecorder() | ||
|
|
||
| nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
| buf := new(bytes.Buffer) | ||
| _, err := buf.ReadFrom(r.Body) | ||
| assert.NoError(t, err) | ||
| w.WriteHeader(http.StatusOK) | ||
| }) | ||
|
|
||
| handler := createHandler(nextHandler) | ||
| handler.ServeHTTP(rec, req) | ||
|
|
||
| assert.Equal(t, http.StatusOK, rec.Code) | ||
| }) | ||
|
|
||
| t.Run("Request body exceeds limit via Content-Length", func(t *testing.T) { | ||
| t.Parallel() | ||
| // Create a request with a body larger than the limit | ||
| body := bytes.NewBuffer(make([]byte, maxBodySize+1)) | ||
| req := httptest.NewRequest(http.MethodPost, "/test", body) | ||
| rec := httptest.NewRecorder() | ||
|
|
||
| nextHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { | ||
| w.WriteHeader(http.StatusOK) | ||
| }) | ||
|
|
||
| handler := createHandler(nextHandler) | ||
| handler.ServeHTTP(rec, req) | ||
|
|
||
| assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Code) | ||
| assert.Contains(t, rec.Body.String(), "Request Entity Too Large") | ||
| }) | ||
|
|
||
| t.Run("MaxBytesReader enforces limit when Content-Length is misleading", func(t *testing.T) { | ||
| t.Parallel() | ||
| // Create a request where Content-Length is set incorrectly to test the MaxBytesReader safety net | ||
| oversizedBody := make([]byte, maxBodySize+100) | ||
| body := bytes.NewBuffer(oversizedBody) | ||
| req := httptest.NewRequest(http.MethodPost, "/api/v1beta/test", body) | ||
|
|
||
| // Manually set Content-Length to be within limit to bypass early check | ||
| req.ContentLength = maxBodySize - 1 | ||
|
|
||
| rec := httptest.NewRecorder() | ||
|
|
||
| nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
| // Try to read the entire body - MaxBytesReader should return an error | ||
| buf := new(bytes.Buffer) | ||
| _, err := buf.ReadFrom(r.Body) | ||
| if err != nil { | ||
| w.Header().Set("Content-Type", "application/json") | ||
| http.Error(w, "Request Entity Too Large", http.StatusRequestEntityTooLarge) | ||
| return | ||
| } | ||
| w.WriteHeader(http.StatusOK) | ||
| }) | ||
|
|
||
| handler := createHandler(nextHandler) | ||
| handler.ServeHTTP(rec, req) | ||
|
|
||
| assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Code) | ||
| }) | ||
| } | ||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -39,9 +39,10 @@ import ( | |||||||
|
|
||||||||
| // Not sure if these values need to be configurable. | ||||||||
| const ( | ||||||||
| middlewareTimeout = 60 * time.Second | ||||||||
| readHeaderTimeout = 10 * time.Second | ||||||||
| socketPermissions = 0660 // Socket file permissions (owner/group read-write) | ||||||||
| middlewareTimeout = 60 * time.Second | ||||||||
| readHeaderTimeout = 10 * time.Second | ||||||||
| socketPermissions = 0660 // Socket file permissions (owner/group read-write) | ||||||||
| maxRequestBodySize = 1 << 20 // 1MB - Maximum request body size | ||||||||
| ) | ||||||||
|
|
||||||||
| // ServerBuilder provides a fluent interface for building and configuring the API server | ||||||||
|
|
@@ -142,6 +143,7 @@ func (b *ServerBuilder) Build(ctx context.Context) (*chi.Mux, error) { | |||||||
| middleware.RequestID, | ||||||||
| // TODO: Figure out logging middleware. We may want to use a different logger. | ||||||||
| middleware.Timeout(middlewareTimeout), | ||||||||
| requestBodySizeLimitMiddleware(maxRequestBodySize), | ||||||||
| headersMiddleware, | ||||||||
| ) | ||||||||
|
|
||||||||
|
|
@@ -312,6 +314,27 @@ func updateCheckMiddleware() func(next http.Handler) http.Handler { | |||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| // requestBodySizeLimitMiddleware limits request body size, returns a 413 for request bodies larger than maxSize. | ||||||||
| func requestBodySizeLimitMiddleware(maxSize int64) func(http.Handler) http.Handler { | ||||||||
| return func(next http.Handler) http.Handler { | ||||||||
| return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||||
| // Check Content-Length header first for early rejection | ||||||||
| if r.ContentLength > maxSize { | ||||||||
| // Set Content-Type for API endpoints to match headersMiddleware behavior | ||||||||
| if strings.HasPrefix(r.URL.Path, "/api/") { | ||||||||
| w.Header().Set("Content-Type", "application/json") | ||||||||
| } | ||||||||
|
Comment on lines
+324
to
+326
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looking at this again, let's remove this section. There is a larger issue in the codebase where errors are returned as plaintext, but we can deal with it separately.
Suggested change
|
||||||||
| http.Error(w, "Request Entity Too Large", http.StatusRequestEntityTooLarge) | ||||||||
muzman123 marked this conversation as resolved.
Show resolved
Hide resolved
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nice to have a log message before the error is returned to help with diagnostics. |
||||||||
| return | ||||||||
| } | ||||||||
|
|
||||||||
| // Also set MaxBytesReader as a safety net for requests without Content-Length | ||||||||
| r.Body = http.MaxBytesReader(w, r.Body, maxSize) | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm concerned that if this fails, it won't return a 413 status code, and cause a different error. Do you have any thoughts around this? |
||||||||
| next.ServeHTTP(w, r) | ||||||||
| }) | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| // getComponentAndVersionFromRequest determines the component name, version, and ui release build from the request | ||||||||
| func getComponentAndVersionFromRequest(r *http.Request) (string, string, bool) { | ||||||||
| clientType := r.Header.Get("X-Client-Type") | ||||||||
|
|
||||||||
Uh oh!
There was an error while loading. Please reload this page.