Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions pkg/api/request_size_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package api

import (
"bytes"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

func TestRequestBodySizeLimitMiddleware(t *testing.T) {
t.Parallel()
// Define the limit (1MB)
const maxBodySize = 1 << 20 // 1MB

// Helper to create the middleware handler
createHandler := func(next http.Handler) http.Handler {
return requestBodySizeLimitMiddleware(maxBodySize)(next)
}

t.Run("Request body within limit", func(t *testing.T) {
t.Parallel()
// Create a request with a body smaller than the limit
body := bytes.NewBuffer(make([]byte, maxBodySize-1))
req := httptest.NewRequest(http.MethodPost, "/test", body)
rec := httptest.NewRecorder()

// Dummy handler that reads the body to trigger MaxBytesReader
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
buf := new(bytes.Buffer)
_, err := buf.ReadFrom(r.Body)
assert.NoError(t, err)
w.WriteHeader(http.StatusOK)
})

handler := createHandler(nextHandler)
handler.ServeHTTP(rec, req)

assert.Equal(t, http.StatusOK, rec.Code)
})

t.Run("Request body exactly at limit", func(t *testing.T) {
t.Parallel()
// Create a request with a body exactly at the limit
body := bytes.NewBuffer(make([]byte, maxBodySize))
req := httptest.NewRequest(http.MethodPost, "/test", body)
rec := httptest.NewRecorder()

nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
buf := new(bytes.Buffer)
_, err := buf.ReadFrom(r.Body)
assert.NoError(t, err)
w.WriteHeader(http.StatusOK)
})

handler := createHandler(nextHandler)
handler.ServeHTTP(rec, req)

assert.Equal(t, http.StatusOK, rec.Code)
})

t.Run("Request body exceeds limit via Content-Length", func(t *testing.T) {
t.Parallel()
// Create a request with a body larger than the limit
body := bytes.NewBuffer(make([]byte, maxBodySize+1))
req := httptest.NewRequest(http.MethodPost, "/test", body)
rec := httptest.NewRecorder()

nextHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})

handler := createHandler(nextHandler)
handler.ServeHTTP(rec, req)

assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Code)
assert.Contains(t, rec.Body.String(), "Request Entity Too Large")
})

t.Run("MaxBytesReader enforces limit when Content-Length is misleading", func(t *testing.T) {
t.Parallel()
// Create a request where Content-Length is set incorrectly to test the MaxBytesReader safety net
oversizedBody := make([]byte, maxBodySize+100)
body := bytes.NewBuffer(oversizedBody)
req := httptest.NewRequest(http.MethodPost, "/api/v1beta/test", body)

// Manually set Content-Length to be within limit to bypass early check
req.ContentLength = maxBodySize - 1

rec := httptest.NewRecorder()

nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Try to read the entire body - MaxBytesReader should return an error
buf := new(bytes.Buffer)
_, err := buf.ReadFrom(r.Body)
if err != nil {
w.Header().Set("Content-Type", "application/json")
http.Error(w, "Request Entity Too Large", http.StatusRequestEntityTooLarge)
return
}
w.WriteHeader(http.StatusOK)
})

handler := createHandler(nextHandler)
handler.ServeHTTP(rec, req)

assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Code)
})
}
29 changes: 26 additions & 3 deletions pkg/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ import (

// Not sure if these values need to be configurable.
const (
middlewareTimeout = 60 * time.Second
readHeaderTimeout = 10 * time.Second
socketPermissions = 0660 // Socket file permissions (owner/group read-write)
middlewareTimeout = 60 * time.Second
readHeaderTimeout = 10 * time.Second
socketPermissions = 0660 // Socket file permissions (owner/group read-write)
maxRequestBodySize = 1 << 20 // 1MB - Maximum request body size
)

// ServerBuilder provides a fluent interface for building and configuring the API server
Expand Down Expand Up @@ -142,6 +143,7 @@ func (b *ServerBuilder) Build(ctx context.Context) (*chi.Mux, error) {
middleware.RequestID,
// TODO: Figure out logging middleware. We may want to use a different logger.
middleware.Timeout(middlewareTimeout),
requestBodySizeLimitMiddleware(maxRequestBodySize),
headersMiddleware,
)

Expand Down Expand Up @@ -312,6 +314,27 @@ func updateCheckMiddleware() func(next http.Handler) http.Handler {
}
}

// requestBodySizeLimitMiddleware limits request body size, returns a 413 for request bodies larger than maxSize.
func requestBodySizeLimitMiddleware(maxSize int64) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check Content-Length header first for early rejection
if r.ContentLength > maxSize {
// Set Content-Type for API endpoints to match headersMiddleware behavior
if strings.HasPrefix(r.URL.Path, "/api/") {
w.Header().Set("Content-Type", "application/json")
}
Comment on lines +324 to +326
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at this again, let's remove this section. There is a larger issue in the codebase where errors are returned as plaintext, but we can deal with it separately.

Suggested change
if strings.HasPrefix(r.URL.Path, "/api/") {
w.Header().Set("Content-Type", "application/json")
}

http.Error(w, "Request Entity Too Large", http.StatusRequestEntityTooLarge)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to have a log message before the error is returned to help with diagnostics.

return
}

// Also set MaxBytesReader as a safety net for requests without Content-Length
r.Body = http.MaxBytesReader(w, r.Body, maxSize)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm concerned that if this fails, it won't return a 413 status code, and cause a different error. Do you have any thoughts around this?

next.ServeHTTP(w, r)
})
}
}

// getComponentAndVersionFromRequest determines the component name, version, and ui release build from the request
func getComponentAndVersionFromRequest(r *http.Request) (string, string, bool) {
clientType := r.Header.Get("X-Client-Type")
Expand Down
Loading