diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 35ec6c90..b821ba0c 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -1,7 +1,6 @@ package cmd import ( - "bufio" "context" "crypto/tls" "encoding/json" @@ -21,7 +20,6 @@ import ( "github.com/github/gh-aw-mcpg/internal/difc" "github.com/github/gh-aw-mcpg/internal/envutil" "github.com/github/gh-aw-mcpg/internal/logger" - "github.com/github/gh-aw-mcpg/internal/logger/sanitize" "github.com/github/gh-aw-mcpg/internal/server" "github.com/github/gh-aw-mcpg/internal/tracing" "github.com/github/gh-aw-mcpg/internal/version" @@ -155,7 +153,7 @@ func run(cmd *cobra.Command, args []string) error { // Load .env file if specified if envFile != "" { debugLog.Printf("Loading environment from file: %s", envFile) - if err := loadEnvFile(envFile); err != nil { + if err := envutil.LoadEnvFile(envFile); err != nil { return fmt.Errorf("failed to load .env file: %w", err) } } @@ -627,51 +625,6 @@ func writeGatewayConfig(cfg *config.Config, listenAddr, mode string, tlsEnabled return nil } -// loadEnvFile reads a .env file and sets environment variables -func loadEnvFile(path string) error { - file, err := os.Open(path) - if err != nil { - return err - } - defer file.Close() - - log.Printf("Loading environment from %s...", path) - scanner := bufio.NewScanner(file) - loadedVars := 0 - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - - // Skip empty lines and comments - if line == "" || strings.HasPrefix(line, "#") { - continue - } - - // Parse KEY=VALUE - parts := strings.SplitN(line, "=", 2) - if len(parts) != 2 { - continue - } - - key := strings.TrimSpace(parts[0]) - value := strings.TrimSpace(parts[1]) - - // Expand $VAR references in value - value = os.ExpandEnv(value) - - if err := os.Setenv(key, value); err != nil { - return fmt.Errorf("failed to set %s: %w", key, err) - } - - // Log loaded variable (hide sensitive values) - log.Printf(" Loaded: %s=%s", key, sanitize.TruncateSecret(value)) - loadedVars++ - } - - log.Printf("Loaded %d environment variables from %s", loadedVars, path) - - return scanner.Err() -} - // Execute runs the root command func Execute() { if err := rootCmd.Execute(); err != nil { diff --git a/internal/cmd/root_test.go b/internal/cmd/root_test.go index c63a768b..604485e7 100644 --- a/internal/cmd/root_test.go +++ b/internal/cmd/root_test.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "os" - "path/filepath" "testing" "time" @@ -176,108 +175,6 @@ func TestPreRunValidation(t *testing.T) { }) } -func TestLoadEnvFile(t *testing.T) { - t.Run("load valid env file", func(t *testing.T) { - // Create temporary env file - tmpDir := t.TempDir() - envFile := filepath.Join(tmpDir, ".env") - content := `# Comment line -TEST_VAR1=value1 -TEST_VAR2=value2 -EMPTY_LINE= - -# Another comment -TEST_VAR3=value with spaces -` - err := os.WriteFile(envFile, []byte(content), 0644) - require.NoError(t, err) - - // Save and restore environment variables - origTestVar1, testVar1WasSet := os.LookupEnv("TEST_VAR1") - origTestVar2, testVar2WasSet := os.LookupEnv("TEST_VAR2") - origTestVar3, testVar3WasSet := os.LookupEnv("TEST_VAR3") - origEmptyLine, emptyLineWasSet := os.LookupEnv("EMPTY_LINE") - t.Cleanup(func() { - if testVar1WasSet { - require.NoError(t, os.Setenv("TEST_VAR1", origTestVar1)) - } else { - require.NoError(t, os.Unsetenv("TEST_VAR1")) - } - if testVar2WasSet { - require.NoError(t, os.Setenv("TEST_VAR2", origTestVar2)) - } else { - require.NoError(t, os.Unsetenv("TEST_VAR2")) - } - if testVar3WasSet { - require.NoError(t, os.Setenv("TEST_VAR3", origTestVar3)) - } else { - require.NoError(t, os.Unsetenv("TEST_VAR3")) - } - if emptyLineWasSet { - require.NoError(t, os.Setenv("EMPTY_LINE", origEmptyLine)) - } else { - require.NoError(t, os.Unsetenv("EMPTY_LINE")) - } - }) - - // Load env file - err = loadEnvFile(envFile) - require.NoError(t, err) - - // Verify variables are set - assert.Equal(t, "value1", os.Getenv("TEST_VAR1")) - assert.Equal(t, "value2", os.Getenv("TEST_VAR2")) - assert.Equal(t, "value with spaces", os.Getenv("TEST_VAR3")) - assert.Equal(t, "", os.Getenv("EMPTY_LINE")) - }) - - t.Run("nonexistent file", func(t *testing.T) { - err := loadEnvFile("/nonexistent/path/.env") - require.Error(t, err, "Should error on nonexistent file") - }) - - t.Run("env file with variable expansion", func(t *testing.T) { - // Save original values and set up cleanup before modifying environment - origBasePath, basePathWasSet := os.LookupEnv("BASE_PATH") - origExpandedVar, expandedVarWasSet := os.LookupEnv("EXPANDED_VAR") - t.Cleanup(func() { - if basePathWasSet { - _ = os.Setenv("BASE_PATH", origBasePath) - } else { - _ = os.Unsetenv("BASE_PATH") - } - if expandedVarWasSet { - _ = os.Setenv("EXPANDED_VAR", origExpandedVar) - } else { - _ = os.Unsetenv("EXPANDED_VAR") - } - }) - - // Set up a base variable for expansion - os.Setenv("BASE_PATH", "/home/user") - tmpDir := t.TempDir() - envFile := filepath.Join(tmpDir, ".env") - content := `EXPANDED_VAR=$BASE_PATH/subdir` - err := os.WriteFile(envFile, []byte(content), 0644) - require.NoError(t, err) - - err = loadEnvFile(envFile) - require.NoError(t, err) - - assert.Equal(t, "/home/user/subdir", os.Getenv("EXPANDED_VAR")) - }) - - t.Run("empty file", func(t *testing.T) { - tmpDir := t.TempDir() - envFile := filepath.Join(tmpDir, ".env") - err := os.WriteFile(envFile, []byte(""), 0644) - require.NoError(t, err) - - err = loadEnvFile(envFile) - require.NoError(t, err, "Empty file should not cause error") - }) -} - func TestWriteGatewayConfig(t *testing.T) { t.Run("unified mode with API key", func(t *testing.T) { cfg := &config.Config{ @@ -638,59 +535,3 @@ func TestWriteGatewayConfig_FileSync(t *testing.T) { require.True(t, ok) assert.Contains(t, mcpServers, "svc", "svc server should appear in output") } - -// TestLoadEnvFile_SkipMalformedLines verifies that lines without an '=' sign -// are silently skipped rather than causing an error. -func TestLoadEnvFile_SkipMalformedLines(t *testing.T) { - const envKey = "LOAD_ENV_VALID_KEY_SKIP_TEST" - t.Setenv(envKey, "") - - tmpDir := t.TempDir() - envFilePath := filepath.Join(tmpDir, ".env") - content := `# comment line -MALFORMED_NO_EQUALS -` + envKey + `=expected_value -ANOTHER_MALFORMED_LINE_WITHOUT_EQUALS -` - require.NoError(t, os.WriteFile(envFilePath, []byte(content), 0644)) - - err := loadEnvFile(envFilePath) - require.NoError(t, err, "Malformed lines should be silently skipped, not cause errors") - - // Only the valid KEY=VALUE line should have been applied - assert.Equal(t, "expected_value", os.Getenv(envKey)) -} - -// TestLoadEnvFile_OnlyComments verifies that a file with only comment and blank -// lines is processed without error and no env vars are modified. -func TestLoadEnvFile_OnlyComments(t *testing.T) { - tmpDir := t.TempDir() - envFilePath := filepath.Join(tmpDir, ".env") - content := `# This is a comment -# Another comment - -# Yet another -` - require.NoError(t, os.WriteFile(envFilePath, []byte(content), 0644)) - - err := loadEnvFile(envFilePath) - require.NoError(t, err, "File with only comments should be processed without error") -} - -// TestLoadEnvFile_EqualsInValue verifies that values containing '=' are -// preserved correctly (SplitN(..., 2) must not split on the second '='). -func TestLoadEnvFile_EqualsInValue(t *testing.T) { - const envKey = "LOAD_ENV_EQUALS_IN_VALUE" - t.Setenv(envKey, "") - - tmpDir := t.TempDir() - envFilePath := filepath.Join(tmpDir, ".env") - // Value intentionally contains '=' characters (e.g. base64-encoded secret) - content := envKey + `=dGVzdA==` - require.NoError(t, os.WriteFile(envFilePath, []byte(content), 0644)) - - err := loadEnvFile(envFilePath) - require.NoError(t, err) - assert.Equal(t, "dGVzdA==", os.Getenv(envKey), - "Value containing '=' signs should not be split on the second '='") -} diff --git a/internal/envutil/envfile.go b/internal/envutil/envfile.go new file mode 100644 index 00000000..27e86e57 --- /dev/null +++ b/internal/envutil/envfile.go @@ -0,0 +1,62 @@ +package envutil + +import ( + "bufio" + "fmt" + "os" + "strings" + + "github.com/github/gh-aw-mcpg/internal/logger" + "github.com/github/gh-aw-mcpg/internal/logger/sanitize" +) + +var logEnvFile = logger.New("envutil:envfile") + +// LoadEnvFile reads a .env file and sets environment variables. +// Lines beginning with '#' and blank lines are ignored. +// Each remaining line is expected in KEY=VALUE format; lines without '=' +// are silently skipped. Values may reference existing environment variables +// using $VAR or ${VAR} syntax (expanded via os.ExpandEnv). +func LoadEnvFile(path string) error { + file, err := os.Open(path) + if err != nil { + return err + } + defer file.Close() + + logEnvFile.Printf("Loading environment from %s...", path) + scanner := bufio.NewScanner(file) + loadedVars := 0 + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + // Skip empty lines and comments + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + // Parse KEY=VALUE + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + continue + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + // Expand $VAR references in value + value = os.ExpandEnv(value) + + if err := os.Setenv(key, value); err != nil { + return fmt.Errorf("failed to set %s: %w", key, err) + } + + // Log loaded variable (hide sensitive values) + logEnvFile.Printf(" Loaded: %s=%s", key, sanitize.TruncateSecret(value)) + loadedVars++ + } + + logEnvFile.Printf("Loaded %d environment variables from %s", loadedVars, path) + + return scanner.Err() +} diff --git a/internal/envutil/envfile_test.go b/internal/envutil/envfile_test.go new file mode 100644 index 00000000..073813f5 --- /dev/null +++ b/internal/envutil/envfile_test.go @@ -0,0 +1,168 @@ +package envutil + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLoadEnvFile(t *testing.T) { + t.Run("load valid env file", func(t *testing.T) { + // Create temporary env file + tmpDir := t.TempDir() + envFile := filepath.Join(tmpDir, ".env") + content := `# Comment line +TEST_VAR1=value1 +TEST_VAR2=value2 +EMPTY_LINE= + +# Another comment +TEST_VAR3=value with spaces +` + err := os.WriteFile(envFile, []byte(content), 0644) + require.NoError(t, err) + + // Save and restore environment variables + origTestVar1, testVar1WasSet := os.LookupEnv("TEST_VAR1") + origTestVar2, testVar2WasSet := os.LookupEnv("TEST_VAR2") + origTestVar3, testVar3WasSet := os.LookupEnv("TEST_VAR3") + origEmptyLine, emptyLineWasSet := os.LookupEnv("EMPTY_LINE") + t.Cleanup(func() { + if testVar1WasSet { + require.NoError(t, os.Setenv("TEST_VAR1", origTestVar1)) + } else { + require.NoError(t, os.Unsetenv("TEST_VAR1")) + } + if testVar2WasSet { + require.NoError(t, os.Setenv("TEST_VAR2", origTestVar2)) + } else { + require.NoError(t, os.Unsetenv("TEST_VAR2")) + } + if testVar3WasSet { + require.NoError(t, os.Setenv("TEST_VAR3", origTestVar3)) + } else { + require.NoError(t, os.Unsetenv("TEST_VAR3")) + } + if emptyLineWasSet { + require.NoError(t, os.Setenv("EMPTY_LINE", origEmptyLine)) + } else { + require.NoError(t, os.Unsetenv("EMPTY_LINE")) + } + }) + + // Load env file + err = LoadEnvFile(envFile) + require.NoError(t, err) + + // Verify variables are set + assert.Equal(t, "value1", os.Getenv("TEST_VAR1")) + assert.Equal(t, "value2", os.Getenv("TEST_VAR2")) + assert.Equal(t, "value with spaces", os.Getenv("TEST_VAR3")) + assert.Equal(t, "", os.Getenv("EMPTY_LINE")) + }) + + t.Run("nonexistent file", func(t *testing.T) { + err := LoadEnvFile("/nonexistent/path/.env") + require.Error(t, err, "Should error on nonexistent file") + }) + + t.Run("env file with variable expansion", func(t *testing.T) { + // Save original values and set up cleanup before modifying environment + origBasePath, basePathWasSet := os.LookupEnv("BASE_PATH") + origExpandedVar, expandedVarWasSet := os.LookupEnv("EXPANDED_VAR") + t.Cleanup(func() { + if basePathWasSet { + _ = os.Setenv("BASE_PATH", origBasePath) + } else { + _ = os.Unsetenv("BASE_PATH") + } + if expandedVarWasSet { + _ = os.Setenv("EXPANDED_VAR", origExpandedVar) + } else { + _ = os.Unsetenv("EXPANDED_VAR") + } + }) + + // Set up a base variable for expansion + os.Setenv("BASE_PATH", "/home/user") + tmpDir := t.TempDir() + envFile := filepath.Join(tmpDir, ".env") + content := `EXPANDED_VAR=$BASE_PATH/subdir` + err := os.WriteFile(envFile, []byte(content), 0644) + require.NoError(t, err) + + err = LoadEnvFile(envFile) + require.NoError(t, err) + + assert.Equal(t, "/home/user/subdir", os.Getenv("EXPANDED_VAR")) + }) + + t.Run("empty file", func(t *testing.T) { + tmpDir := t.TempDir() + envFile := filepath.Join(tmpDir, ".env") + err := os.WriteFile(envFile, []byte(""), 0644) + require.NoError(t, err) + + err = LoadEnvFile(envFile) + require.NoError(t, err, "Empty file should not cause error") + }) +} + +// TestLoadEnvFile_SkipMalformedLines verifies that lines without an '=' sign +// are silently skipped rather than causing an error. +func TestLoadEnvFile_SkipMalformedLines(t *testing.T) { + const envKey = "LOAD_ENV_VALID_KEY_SKIP_TEST" + t.Setenv(envKey, "") + + tmpDir := t.TempDir() + envFilePath := filepath.Join(tmpDir, ".env") + content := `# comment line +MALFORMED_NO_EQUALS +` + envKey + `=expected_value +ANOTHER_MALFORMED_LINE_WITHOUT_EQUALS +` + require.NoError(t, os.WriteFile(envFilePath, []byte(content), 0644)) + + err := LoadEnvFile(envFilePath) + require.NoError(t, err, "Malformed lines should be silently skipped, not cause errors") + + // Only the valid KEY=VALUE line should have been applied + assert.Equal(t, "expected_value", os.Getenv(envKey)) +} + +// TestLoadEnvFile_OnlyComments verifies that a file with only comment and blank +// lines is processed without error and no env vars are modified. +func TestLoadEnvFile_OnlyComments(t *testing.T) { + tmpDir := t.TempDir() + envFilePath := filepath.Join(tmpDir, ".env") + content := `# This is a comment +# Another comment + +# Yet another +` + require.NoError(t, os.WriteFile(envFilePath, []byte(content), 0644)) + + err := LoadEnvFile(envFilePath) + require.NoError(t, err, "File with only comments should be processed without error") +} + +// TestLoadEnvFile_EqualsInValue verifies that values containing '=' are +// preserved correctly (SplitN(..., 2) must not split on the second '='). +func TestLoadEnvFile_EqualsInValue(t *testing.T) { + const envKey = "LOAD_ENV_EQUALS_IN_VALUE" + t.Setenv(envKey, "") + + tmpDir := t.TempDir() + envFilePath := filepath.Join(tmpDir, ".env") + // Value intentionally contains '=' characters (e.g. base64-encoded secret) + content := envKey + `=dGVzdA==` + require.NoError(t, os.WriteFile(envFilePath, []byte(content), 0644)) + + err := LoadEnvFile(envFilePath) + require.NoError(t, err) + assert.Equal(t, "dGVzdA==", os.Getenv(envKey), + "Value containing '=' signs should not be split on the second '='") +} diff --git a/internal/logger/sanitize/sanitize.go b/internal/logger/sanitize/sanitize.go index 74783bca..802dd272 100644 --- a/internal/logger/sanitize/sanitize.go +++ b/internal/logger/sanitize/sanitize.go @@ -33,6 +33,8 @@ import ( "net/url" "regexp" "strings" + + "github.com/github/gh-aw-mcpg/internal/strutil" ) // SecretPatterns contains regex patterns for detecting potential secrets @@ -82,12 +84,13 @@ func SanitizeString(message string) string { // For strings with 4 or fewer characters, it returns only "...". // For empty strings, it returns an empty string. func TruncateSecret(input string) string { - if len(input) > 4 { - return input[:4] + "..." - } else if len(input) > 0 { + if len(input) == 0 { + return "" + } + if len(input) <= 4 { return "..." } - return "" + return strutil.TruncateWithSuffix(input, 4, "...") } // TruncateSecretMap returns a sanitized version of environment variables diff --git a/internal/middleware/jqschema.go b/internal/middleware/jqschema.go index c29ca7ee..8a0c7c71 100644 --- a/internal/middleware/jqschema.go +++ b/internal/middleware/jqschema.go @@ -145,10 +145,9 @@ func init() { logger.LogInfo("startup", "jq schema filter compiled successfully - native Go walk_schema, array limit: 2^29 elements, timeout: %v", DefaultJqTimeout) } -// generateRandomID generates a random ID for payload storage -func generateRandomID() string { - return strutil.RandomHexWithFallback(16) -} +// queryIDBytes is the number of random bytes used to generate a query ID. +// The resulting hex string has length 2*queryIDBytes (32 characters). +const queryIDBytes = 16 // applyJqSchema applies the jq schema transformation to JSON data // Uses pre-compiled query code for better performance (3-10x faster than parsing on each request) @@ -295,7 +294,7 @@ func WrapToolHandler( toolName, sizeThreshold, baseDir, pathPrefix != "") return func(ctx context.Context, req *sdk.CallToolRequest, args interface{}) (*sdk.CallToolResult, interface{}, error) { // Generate random query ID - queryID := generateRandomID() + queryID := strutil.RandomHexWithFallback(queryIDBytes) // Get session ID from context sessionID := getSessionID(ctx) diff --git a/internal/middleware/jqschema_test.go b/internal/middleware/jqschema_test.go index 98b07a95..477370f0 100644 --- a/internal/middleware/jqschema_test.go +++ b/internal/middleware/jqschema_test.go @@ -23,17 +23,6 @@ func testGetSessionID(ctx context.Context) string { return "test-session" } -func TestGenerateRandomID(t *testing.T) { - // Generate multiple IDs and ensure they're unique - ids := make(map[string]bool) - for i := 0; i < 100; i++ { - id := generateRandomID() - assert.NotEmpty(t, id, "ID should not be empty") - assert.False(t, ids[id], "ID should be unique") - ids[id] = true - } -} - // payloadMetadataToMap converts PayloadMetadata to map[string]interface{} for test assertions // This allows tests to remain unchanged while working with the new struct type func payloadMetadataToMap(t *testing.T, data interface{}) map[string]interface{} {