Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion internal/swmcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,25 @@ func WithSkyWalkingAuth(ctx context.Context, username, password string) context.
// The value is sourced from the CLI/config binding for `--sw-url`,
// falling back to the built-in default when unset.
func configuredSkyWalkingURL() string {
resolvedURL, err := resolvedConfiguredSkyWalkingURL()
if err != nil {
logrus.WithError(err).Warn("invalid SkyWalking OAP URL configuration; falling back to default URL")
return config.DefaultSWURL
}
return resolvedURL
}

func resolvedConfiguredSkyWalkingURL() (string, error) {
urlStr := viper.GetString("url")
if urlStr == "" {
urlStr = config.DefaultSWURL
}
return tools.FinalizeURL(urlStr)
return tools.NormalizeOAPURL(urlStr)
}

func validateConfiguredSkyWalkingURL() error {
_, err := resolvedConfiguredSkyWalkingURL()
return err
}

// resolveEnvVar resolves a value that may contain an environment variable reference
Expand Down
38 changes: 38 additions & 0 deletions internal/swmcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"testing"

"github.com/apache/skywalking-cli/pkg/contextkey"
"github.com/spf13/cobra"
"github.com/spf13/viper"

"github.com/apache/skywalking-mcp/internal/config"
Expand Down Expand Up @@ -52,6 +53,43 @@ func TestConfiguredSkyWalkingURLFinalizesConfiguredValue(t *testing.T) {
}
}

func TestConfiguredSkyWalkingURLFallsBackToDefaultOnInvalidValue(t *testing.T) {
t.Cleanup(viper.Reset)
viper.Set("url", "ftp://configured-oap.example.com:12800")

got := configuredSkyWalkingURL()
if got != config.DefaultSWURL {
t.Fatalf("configuredSkyWalkingURL() = %q, want %q", got, config.DefaultSWURL)
}
}

func TestValidateConfiguredSkyWalkingURLRejectsUnsupportedScheme(t *testing.T) {
t.Cleanup(viper.Reset)
viper.Set("url", "ftp://configured-oap.example.com:12800")

err := validateConfiguredSkyWalkingURL()
if err == nil {
t.Fatal("validateConfiguredSkyWalkingURL() error = nil, want error")
}
}

func TestTransportCommandsRejectInvalidSWURL(t *testing.T) {
t.Cleanup(viper.Reset)
viper.Set("url", "ftp://configured-oap.example.com:12800")

for name, cmd := range map[string]*cobra.Command{
"stdio": NewStdioServer(),
"sse": NewSSEServer(),
"streamable": NewStreamable(),
} {
t.Run(name, func(t *testing.T) {
if err := cmd.RunE(cmd, nil); err == nil {
t.Fatal("RunE() error = nil, want invalid sw-url error")
}
})
}
}

func TestResolveEnvVar(t *testing.T) {
t.Setenv("SW_TEST_SECRET", "resolved-secret")

Expand Down
4 changes: 4 additions & 0 deletions internal/swmcp/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ func NewSSEServer() *cobra.Command {
Short: "Start SSE server",
Long: `Start a server that listens for Server-Sent Events (SSE) on the specified address.`,
RunE: func(_ *cobra.Command, _ []string) error {
if err := validateConfiguredSkyWalkingURL(); err != nil {
return err
}

sseServerConfig := config.SSEServerConfig{
Address: viper.GetString("sse-address"),
BasePath: viper.GetString("base-path"),
Expand Down
4 changes: 4 additions & 0 deletions internal/swmcp/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ func NewStdioServer() *cobra.Command {
Short: "Start stdio server",
Long: `Start a server that communicates via standard input/output streams using JSON-RPC messages.`,
RunE: func(_ *cobra.Command, _ []string) error {
if err := validateConfiguredSkyWalkingURL(); err != nil {
return err
}

stdioServerConfig := config.StdioServerConfig{
URL: viper.GetString("url"),
ReadOnly: viper.GetBool("read-only"),
Expand Down
4 changes: 4 additions & 0 deletions internal/swmcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ func NewStreamable() *cobra.Command {
Short: "Start Streamable server",
Long: `Starting SkyWalking MCP server with Streamable HTTP transport.`,
RunE: func(_ *cobra.Command, _ []string) error {
if err := validateConfiguredSkyWalkingURL(); err != nil {
return err
}

streamableConfig := config.StreamableServerConfig{
Address: viper.GetString("address"),
EndpointPath: viper.GetString("endpoint-path"),
Expand Down
29 changes: 24 additions & 5 deletions internal/tools/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,37 @@ const (

// FinalizeURL ensures the URL ends with "/graphql".
func FinalizeURL(urlStr string) string {
if !strings.HasSuffix(urlStr, "/graphql") {
urlStr = strings.TrimRight(urlStr, "/") + "/graphql"
normalizedURL, err := NormalizeOAPURL(urlStr)
if err == nil {
return normalizedURL
}
return urlStr
}

// validateURLScheme ensures the URL uses http or https.
func validateURLScheme(rawURL string) error {
// NormalizeOAPURL parses and validates the OAP URL, then ensures the path ends with /graphql.
func NormalizeOAPURL(rawURL string) (string, error) {
u, err := url.Parse(rawURL)
if err != nil {
return fmt.Errorf("invalid OAP URL: %w", err)
return "", fmt.Errorf("invalid OAP URL: %w", err)
}
if err := validateURLScheme(u); err != nil {
return "", err
}
if u.Host == "" {
return "", fmt.Errorf("invalid OAP URL %q: host is required", rawURL)
}

if u.Path == "" || u.Path == "/" {
u.Path = "/graphql"
} else if !strings.HasSuffix(u.Path, "/graphql") {
u.Path = strings.TrimRight(u.Path, "/") + "/graphql"
}

return u.String(), nil
}

// validateURLScheme ensures the URL uses http or https.
func validateURLScheme(u *url.URL) error {
if u.Scheme != "http" && u.Scheme != "https" {
return fmt.Errorf("unsupported OAP URL scheme %q: only http and https are allowed", u.Scheme)
}
Expand Down
40 changes: 40 additions & 0 deletions internal/tools/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package tools

import (
"strings"
"testing"
"time"

Expand All @@ -39,6 +40,7 @@ func TestFinalizeURL(t *testing.T) {
{name: "adds graphql suffix", in: "http://localhost:12800", want: "http://localhost:12800/graphql"},
{name: "trims trailing slash", in: "http://localhost:12800/", want: "http://localhost:12800/graphql"},
{name: "keeps existing graphql", in: "http://localhost:12800/graphql", want: "http://localhost:12800/graphql"},
{name: "preserves query string", in: "http://localhost:12800?x=1", want: "http://localhost:12800/graphql?x=1"},
}

for _, tc := range tests {
Expand All @@ -50,6 +52,44 @@ func TestFinalizeURL(t *testing.T) {
}
}

func TestNormalizeOAPURL(t *testing.T) {
tests := []struct {
name string
in string
want string
wantErr string
}{
{name: "http", in: "http://localhost:12800", want: "http://localhost:12800/graphql"},
{name: "https", in: "https://localhost:12800/graphql", want: "https://localhost:12800/graphql"},
{name: "preserves query and fragment", in: "https://localhost:12800/oap?debug=1#frag", want: "https://localhost:12800/oap/graphql?debug=1#frag"},
{name: "rejects unsupported scheme", in: "ftp://localhost:12800", wantErr: "unsupported OAP URL scheme \"ftp\""},
{name: "rejects missing host", in: "http://", wantErr: "host is required"},
{name: "rejects malformed hostless path", in: "http:/foo", wantErr: "host is required"},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := NormalizeOAPURL(tc.in)
if tc.wantErr != "" {
if err == nil {
t.Fatalf("NormalizeOAPURL(%q) error = nil, want %q", tc.in, tc.wantErr)
}
if !strings.Contains(err.Error(), tc.wantErr) {
t.Fatalf("NormalizeOAPURL(%q) error = %q, want substring %q", tc.in, err.Error(), tc.wantErr)
}
return
}

if err != nil {
t.Fatalf("NormalizeOAPURL(%q) unexpected error: %v", tc.in, err)
}
if got != tc.want {
t.Fatalf("NormalizeOAPURL(%q) = %q, want %q", tc.in, got, tc.want)
}
})
}
}

func TestParseTimezoneOffset(t *testing.T) {
loc, ok := parseTimezoneOffset("+0830")
if !ok {
Expand Down
7 changes: 3 additions & 4 deletions internal/tools/mqe.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,8 @@ func getContextBool(ctx context.Context, key any) bool {
// executeGraphQLWithContext executes a GraphQL query using URL and auth from context.
func executeGraphQLWithContext(ctx context.Context, query string, variables map[string]interface{}) (*GraphQLResponse, error) {
rawURL := getContextString(ctx, contextkey.BaseURL{})
rawURL = FinalizeURL(rawURL)

if err := validateURLScheme(rawURL); err != nil {
normalizedURL, err := NormalizeOAPURL(rawURL)
if err != nil {
return nil, err
}

Expand All @@ -107,7 +106,7 @@ func executeGraphQLWithContext(ctx context.Context, query string, variables map[
return nil, fmt.Errorf("failed to marshal GraphQL request: %w", err)
}

req, err := http.NewRequestWithContext(ctx, "POST", rawURL, bytes.NewBuffer(jsonData))
req, err := http.NewRequestWithContext(ctx, "POST", normalizedURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}
Expand Down
Loading