diff --git a/internal/swmcp/server.go b/internal/swmcp/server.go index e5307d4..c24e888 100644 --- a/internal/swmcp/server.go +++ b/internal/swmcp/server.go @@ -93,11 +93,25 @@ func WithSkyWalkingAuth(ctx context.Context, username, password string) context. // The value is sourced from the CLI/config binding for `--sw-url`, // falling back to the built-in default when unset. func configuredSkyWalkingURL() string { + resolvedURL, err := resolvedConfiguredSkyWalkingURL() + if err != nil { + logrus.WithError(err).Warn("invalid SkyWalking OAP URL configuration; falling back to default URL") + return config.DefaultSWURL + } + return resolvedURL +} + +func resolvedConfiguredSkyWalkingURL() (string, error) { urlStr := viper.GetString("url") if urlStr == "" { urlStr = config.DefaultSWURL } - return tools.FinalizeURL(urlStr) + return tools.NormalizeOAPURL(urlStr) +} + +func validateConfiguredSkyWalkingURL() error { + _, err := resolvedConfiguredSkyWalkingURL() + return err } // resolveEnvVar resolves a value that may contain an environment variable reference diff --git a/internal/swmcp/server_test.go b/internal/swmcp/server_test.go index 1007e56..8ec433e 100644 --- a/internal/swmcp/server_test.go +++ b/internal/swmcp/server_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/apache/skywalking-cli/pkg/contextkey" + "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/apache/skywalking-mcp/internal/config" @@ -52,6 +53,43 @@ func TestConfiguredSkyWalkingURLFinalizesConfiguredValue(t *testing.T) { } } +func TestConfiguredSkyWalkingURLFallsBackToDefaultOnInvalidValue(t *testing.T) { + t.Cleanup(viper.Reset) + viper.Set("url", "ftp://configured-oap.example.com:12800") + + got := configuredSkyWalkingURL() + if got != config.DefaultSWURL { + t.Fatalf("configuredSkyWalkingURL() = %q, want %q", got, config.DefaultSWURL) + } +} + +func TestValidateConfiguredSkyWalkingURLRejectsUnsupportedScheme(t *testing.T) { + t.Cleanup(viper.Reset) + viper.Set("url", "ftp://configured-oap.example.com:12800") + + err := validateConfiguredSkyWalkingURL() + if err == nil { + t.Fatal("validateConfiguredSkyWalkingURL() error = nil, want error") + } +} + +func TestTransportCommandsRejectInvalidSWURL(t *testing.T) { + t.Cleanup(viper.Reset) + viper.Set("url", "ftp://configured-oap.example.com:12800") + + for name, cmd := range map[string]*cobra.Command{ + "stdio": NewStdioServer(), + "sse": NewSSEServer(), + "streamable": NewStreamable(), + } { + t.Run(name, func(t *testing.T) { + if err := cmd.RunE(cmd, nil); err == nil { + t.Fatal("RunE() error = nil, want invalid sw-url error") + } + }) + } +} + func TestResolveEnvVar(t *testing.T) { t.Setenv("SW_TEST_SECRET", "resolved-secret") diff --git a/internal/swmcp/sse.go b/internal/swmcp/sse.go index ae06a07..0dab635 100644 --- a/internal/swmcp/sse.go +++ b/internal/swmcp/sse.go @@ -41,6 +41,10 @@ func NewSSEServer() *cobra.Command { Short: "Start SSE server", Long: `Start a server that listens for Server-Sent Events (SSE) on the specified address.`, RunE: func(_ *cobra.Command, _ []string) error { + if err := validateConfiguredSkyWalkingURL(); err != nil { + return err + } + sseServerConfig := config.SSEServerConfig{ Address: viper.GetString("sse-address"), BasePath: viper.GetString("base-path"), diff --git a/internal/swmcp/stdio.go b/internal/swmcp/stdio.go index 02abb4a..12fe05b 100644 --- a/internal/swmcp/stdio.go +++ b/internal/swmcp/stdio.go @@ -41,6 +41,10 @@ func NewStdioServer() *cobra.Command { Short: "Start stdio server", Long: `Start a server that communicates via standard input/output streams using JSON-RPC messages.`, RunE: func(_ *cobra.Command, _ []string) error { + if err := validateConfiguredSkyWalkingURL(); err != nil { + return err + } + stdioServerConfig := config.StdioServerConfig{ URL: viper.GetString("url"), ReadOnly: viper.GetBool("read-only"), diff --git a/internal/swmcp/streamable.go b/internal/swmcp/streamable.go index 0f85fb1..5d692ad 100644 --- a/internal/swmcp/streamable.go +++ b/internal/swmcp/streamable.go @@ -36,6 +36,10 @@ func NewStreamable() *cobra.Command { Short: "Start Streamable server", Long: `Starting SkyWalking MCP server with Streamable HTTP transport.`, RunE: func(_ *cobra.Command, _ []string) error { + if err := validateConfiguredSkyWalkingURL(); err != nil { + return err + } + streamableConfig := config.StreamableServerConfig{ Address: viper.GetString("address"), EndpointPath: viper.GetString("endpoint-path"), diff --git a/internal/tools/common.go b/internal/tools/common.go index 86b8fcb..83e3db2 100644 --- a/internal/tools/common.go +++ b/internal/tools/common.go @@ -45,18 +45,37 @@ const ( // FinalizeURL ensures the URL ends with "/graphql". func FinalizeURL(urlStr string) string { - if !strings.HasSuffix(urlStr, "/graphql") { - urlStr = strings.TrimRight(urlStr, "/") + "/graphql" + normalizedURL, err := NormalizeOAPURL(urlStr) + if err == nil { + return normalizedURL } return urlStr } -// validateURLScheme ensures the URL uses http or https. -func validateURLScheme(rawURL string) error { +// NormalizeOAPURL parses and validates the OAP URL, then ensures the path ends with /graphql. +func NormalizeOAPURL(rawURL string) (string, error) { u, err := url.Parse(rawURL) if err != nil { - return fmt.Errorf("invalid OAP URL: %w", err) + return "", fmt.Errorf("invalid OAP URL: %w", err) + } + if err := validateURLScheme(u); err != nil { + return "", err + } + if u.Host == "" { + return "", fmt.Errorf("invalid OAP URL %q: host is required", rawURL) + } + + if u.Path == "" || u.Path == "/" { + u.Path = "/graphql" + } else if !strings.HasSuffix(u.Path, "/graphql") { + u.Path = strings.TrimRight(u.Path, "/") + "/graphql" } + + return u.String(), nil +} + +// validateURLScheme ensures the URL uses http or https. +func validateURLScheme(u *url.URL) error { if u.Scheme != "http" && u.Scheme != "https" { return fmt.Errorf("unsupported OAP URL scheme %q: only http and https are allowed", u.Scheme) } diff --git a/internal/tools/common_test.go b/internal/tools/common_test.go index eb60ce1..29e06b7 100644 --- a/internal/tools/common_test.go +++ b/internal/tools/common_test.go @@ -17,6 +17,7 @@ package tools import ( + "strings" "testing" "time" @@ -39,6 +40,7 @@ func TestFinalizeURL(t *testing.T) { {name: "adds graphql suffix", in: "http://localhost:12800", want: "http://localhost:12800/graphql"}, {name: "trims trailing slash", in: "http://localhost:12800/", want: "http://localhost:12800/graphql"}, {name: "keeps existing graphql", in: "http://localhost:12800/graphql", want: "http://localhost:12800/graphql"}, + {name: "preserves query string", in: "http://localhost:12800?x=1", want: "http://localhost:12800/graphql?x=1"}, } for _, tc := range tests { @@ -50,6 +52,44 @@ func TestFinalizeURL(t *testing.T) { } } +func TestNormalizeOAPURL(t *testing.T) { + tests := []struct { + name string + in string + want string + wantErr string + }{ + {name: "http", in: "http://localhost:12800", want: "http://localhost:12800/graphql"}, + {name: "https", in: "https://localhost:12800/graphql", want: "https://localhost:12800/graphql"}, + {name: "preserves query and fragment", in: "https://localhost:12800/oap?debug=1#frag", want: "https://localhost:12800/oap/graphql?debug=1#frag"}, + {name: "rejects unsupported scheme", in: "ftp://localhost:12800", wantErr: "unsupported OAP URL scheme \"ftp\""}, + {name: "rejects missing host", in: "http://", wantErr: "host is required"}, + {name: "rejects malformed hostless path", in: "http:/foo", wantErr: "host is required"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := NormalizeOAPURL(tc.in) + if tc.wantErr != "" { + if err == nil { + t.Fatalf("NormalizeOAPURL(%q) error = nil, want %q", tc.in, tc.wantErr) + } + if !strings.Contains(err.Error(), tc.wantErr) { + t.Fatalf("NormalizeOAPURL(%q) error = %q, want substring %q", tc.in, err.Error(), tc.wantErr) + } + return + } + + if err != nil { + t.Fatalf("NormalizeOAPURL(%q) unexpected error: %v", tc.in, err) + } + if got != tc.want { + t.Fatalf("NormalizeOAPURL(%q) = %q, want %q", tc.in, got, tc.want) + } + }) + } +} + func TestParseTimezoneOffset(t *testing.T) { loc, ok := parseTimezoneOffset("+0830") if !ok { diff --git a/internal/tools/mqe.go b/internal/tools/mqe.go index e9f4e5c..032d456 100644 --- a/internal/tools/mqe.go +++ b/internal/tools/mqe.go @@ -91,9 +91,8 @@ func getContextBool(ctx context.Context, key any) bool { // executeGraphQLWithContext executes a GraphQL query using URL and auth from context. func executeGraphQLWithContext(ctx context.Context, query string, variables map[string]interface{}) (*GraphQLResponse, error) { rawURL := getContextString(ctx, contextkey.BaseURL{}) - rawURL = FinalizeURL(rawURL) - - if err := validateURLScheme(rawURL); err != nil { + normalizedURL, err := NormalizeOAPURL(rawURL) + if err != nil { return nil, err } @@ -107,7 +106,7 @@ func executeGraphQLWithContext(ctx context.Context, query string, variables map[ return nil, fmt.Errorf("failed to marshal GraphQL request: %w", err) } - req, err := http.NewRequestWithContext(ctx, "POST", rawURL, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(ctx, "POST", normalizedURL, bytes.NewBuffer(jsonData)) if err != nil { return nil, fmt.Errorf("failed to create HTTP request: %w", err) }