Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 48 additions & 11 deletions helper/github/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package github

import (
"crypto/tls"
"net/http"

"github.com/bradleyfalzon/ghinstallation"
Expand All @@ -10,26 +11,62 @@ import (
// Client is a wrapper around GitHub client that supports GitHub App authentication for multiple installations.
type Client struct {
*github.Client

transport *ghinstallation.AppsTransport
transport *ghinstallation.AppsTransport
skipTLSVerify bool
}

// NewClient creates a new Client.
func NewClient(appID int64, appPrivateKey string) (*Client, error) {
transport, err := ghinstallation.NewAppsTransport(http.DefaultTransport, appID, []byte(appPrivateKey))
// NewClient creates a new GitHub App client that supports both GitHub.com and GitHub Enterprise API URLs.
// githubEnterpriseApiUrl should be a full API base URL (e.g. https://api.githubenterprise.example.com/).
func NewClient(appID int64, appPrivateKey string, githubEnterpriseApiUrl string, skipTLSVerify bool) (*Client, error) {
baseTransport := http.DefaultTransport.(*http.Transport).Clone()
if skipTLSVerify {
if baseTransport.TLSClientConfig == nil {
baseTransport.TLSClientConfig = &tls.Config{}
}
baseTransport.TLSClientConfig.InsecureSkipVerify = true
// Prefer TLS 1.2+ even when skipping verification.
baseTransport.TLSClientConfig.MinVersion = tls.VersionTLS12
}

transport, err := ghinstallation.NewAppsTransport(baseTransport, appID, []byte(appPrivateKey))
if err != nil {
return nil, err
}

client := &Client{
Client: github.NewClient(&http.Client{Transport: transport}),
transport: transport,
httpClient := &http.Client{Transport: transport}
ghClient := github.NewClient(httpClient)

if githubEnterpriseApiUrl != "" {
ghClient, err = ghClient.WithEnterpriseURLs(githubEnterpriseApiUrl, "")
if err != nil {
return nil, err
}

transport.BaseURL = githubEnterpriseApiUrl
}
Comment on lines +39 to 46
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Normalize enterprise URLs and set upload base.

WithEnterpriseURLs expects the API base (.../api/v3/) and an upload base (.../api/uploads/). Empty upload URL may break asset uploads; lack of trailing slash can break path joins.

Apply:

- if githubEnterpriseApiUrl != "" {
-     ghClient, err = ghClient.WithEnterpriseURLs(githubEnterpriseApiUrl, "")
+ if githubEnterpriseApiUrl != "" {
+     base := ensureTrailingSlash(githubEnterpriseApiUrl)
+     // If base is the host root, append /api/v3/
+     if !strings.Contains(base, "/api/") {
+         base = strings.TrimRight(base, "/") + "/api/v3/"
+     }
+     upload := strings.Replace(base, "/api/v3/", "/api/uploads/", 1)
+     ghClient, err = ghClient.WithEnterpriseURLs(base, upload)
      if err != nil {
          return nil, err
      }
-     transport.BaseURL = githubEnterpriseApiUrl
+     transport.BaseURL = base
 }

And add helpers/imports:

+import "strings"
+
+func ensureTrailingSlash(u string) string {
+    if strings.HasSuffix(u, "/") { return u }
+    return u + "/"
+}


return client, nil
return &Client{
Client: ghClient,
transport: transport,
skipTLSVerify: skipTLSVerify,
}, nil
}

// Installation returns a new GitHub client for the given installation ID.
func (c *Client) Installation(installationID int64) *github.Client {
return github.NewClient(&http.Client{Transport: ghinstallation.NewFromAppsTransport(c.transport, installationID)})
installationTransport := ghinstallation.NewFromAppsTransport(c.transport, installationID)
httpClient := &http.Client{Transport: installationTransport}

installationClient := github.NewClient(httpClient)

// Detect if we're using GitHub Enterprise (since DefaultBaseURL is no longer exported)
if c.Client.BaseURL != nil && c.Client.BaseURL.String() != "https://api.github.com/" {
if enterpriseClient, err := installationClient.WithEnterpriseURLs(
c.Client.BaseURL.String(),
c.Client.UploadURL.String(),
); err == nil {
return enterpriseClient
}
}

return installationClient
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
79 changes: 75 additions & 4 deletions helper/github/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,67 @@ import (
"github.com/stretchr/testify/assert"
)

func TestNewClient_Success(t *testing.T) {
func TestNewClient_Success_GitHubCom(t *testing.T) {
key, err := os.ReadFile("testdata/test.key")
if err != nil {
t.Fatal(err)
}

client, err := NewClient(12345, string(key))
client, err := NewClient(12345, string(key), "", false)
assert.NoError(t, err)
assert.NotNil(t, client)
assert.NotNil(t, client.Client)
assert.NotNil(t, client.transport)

// Should default to GitHub.com
assert.Contains(t, client.Client.BaseURL.String(), "https://api.github.com/")
}

func TestNewClient_Success_GitHubCom_TLS(t *testing.T) {
key, err := os.ReadFile("testdata/test.key")
if err != nil {
t.Fatal(err)
}

client, err := NewClient(12345, string(key), "", true)
assert.NoError(t, err)
assert.NotNil(t, client)
}

func TestNewClient_Success_Enterprise(t *testing.T) {
key, err := os.ReadFile("testdata/test.key")
if err != nil {
t.Fatal(err)
}

fakeEnterpriseURL := "https://api.githubenterprise.example.com/"

client, err := NewClient(12345, string(key), fakeEnterpriseURL, false)
assert.NoError(t, err)
assert.NotNil(t, client)
assert.NotNil(t, client.transport)

// Ensure the Enterprise base URL was applied
assert.Contains(t, client.Client.BaseURL.String(), fakeEnterpriseURL)
assert.Equal(t, client.transport.BaseURL, fakeEnterpriseURL)
}
Comment on lines +43 to +53
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Use the API base, not host root, and assert normalized URL.

Tests pass a host root (https://api.githubenterprise.example.com/), but config and most clients expect the API base (.../api/v3/). Align to avoid false positives that would fail at runtime when hitting endpoints.

Apply:

- fakeEnterpriseURL := "https://api.githubenterprise.example.com/"
+ fakeEnterpriseURL := "https://api.githubenterprise.example.com/api/v3/"

Then assert exact equality (with trailing slash) on client.Client.BaseURL.String() and client.transport.BaseURL.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
fakeEnterpriseURL := "https://api.githubenterprise.example.com/"
client, err := NewClient(12345, string(key), fakeEnterpriseURL, false)
assert.NoError(t, err)
assert.NotNil(t, client)
assert.NotNil(t, client.transport)
// Ensure the Enterprise base URL was applied
assert.Contains(t, client.Client.BaseURL.String(), fakeEnterpriseURL)
assert.Equal(t, client.transport.BaseURL, fakeEnterpriseURL)
}
fakeEnterpriseURL := "https://api.githubenterprise.example.com/api/v3/"
client, err := NewClient(12345, string(key), fakeEnterpriseURL, false)
assert.NoError(t, err)
assert.NotNil(t, client)
assert.NotNil(t, client.transport)
// Ensure the Enterprise base URL was applied
assert.Contains(t, client.Client.BaseURL.String(), fakeEnterpriseURL)
assert.Equal(t, client.transport.BaseURL, fakeEnterpriseURL)
}
🤖 Prompt for AI Agents
In helper/github/client_test.go around lines 43 to 53, the test currently passes
the host root URL but the code and real clients expect the API base (including
/api/v3/), which can hide runtime failures; update fakeEnterpriseURL to use the
API base (e.g. "https://api.githubenterprise.example.com/api/v3/"), then change
the assertions to check exact equality (including the trailing slash) for both
client.Client.BaseURL.String() and client.transport.BaseURL so the test verifies
normalized API base URLs precisely.


func TestNewClient_Success_Enterprise_TLS(t *testing.T) {
key, err := os.ReadFile("testdata/test.key")
if err != nil {
t.Fatal(err)
}

fakeEnterpriseURL := "https://api.githubenterprise.example.com/"

client, err := NewClient(12345, string(key), fakeEnterpriseURL, true)
assert.NoError(t, err)
assert.NotNil(t, client)
assert.Contains(t, client.Client.BaseURL.String(), fakeEnterpriseURL)
}

func TestNewClient_Failure(t *testing.T) {
client, err := NewClient(12345, "")
client, err := NewClient(12345, "", "", false)
assert.Error(t, err)
assert.Nil(t, client)
}
Expand All @@ -30,9 +78,32 @@ func TestClientInstallation(t *testing.T) {
t.Fatal(err)
}

client, err := NewClient(12345, string(key))
client, err := NewClient(12345, string(key), "", false)
assert.NoError(t, err)
assert.NotNil(t, client)

installation := client.Installation(12345)
assert.NotNil(t, installation)

// The installation client should still have a valid BaseURL
assert.Contains(t, installation.BaseURL.String(), "https://api.github.com/")
}

func TestClientInstallation_Enterprise(t *testing.T) {
key, err := os.ReadFile("testdata/test.key")
if err != nil {
t.Fatal(err)
}

fakeEnterpriseURL := "https://api.githubenterprise.example.com/"

client, err := NewClient(12345, string(key), fakeEnterpriseURL, true)
assert.NoError(t, err)
assert.NotNil(t, client)

installation := client.Installation(12345)
assert.NotNil(t, installation)

// Should retain enterprise base URL
assert.Contains(t, installation.BaseURL.String(), fakeEnterpriseURL)
}
48 changes: 31 additions & 17 deletions server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,23 @@ type MetricsConfig struct {
}

type GitHubConfig struct {
AppPrivateKey string `yaml:"app_private_key" validate:"required"`
AppID int64 `yaml:"app_id" validate:"required"`
AppPrivateKey string `yaml:"app_private_key" validate:"required"`
AppID int64 `yaml:"app_id" validate:"required"`
EnterpriseApiUrl string `yaml:"enterprise_api_url" validate:"omitempty,url"`
SkipTLSVerify bool `yaml:"skip_tls_verify" validate:"skiptls_if_enterprise"`
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

type RunnerConfig struct {
Name string `yaml:"name" validate:"required"`
ImagePullPolicy string `yaml:"image_pull_policy" validate:"required,oneof=always never ifnotpresent"`
ImagePullPolicy string `yaml:"image_pull_policy" validate:"required,oneof=Always Never IfNotPresent"`
Image string `yaml:"image" validate:"required"`
Organization string `yaml:"organization" validate:"required"`
GroupID int64 `yaml:"group_id" validate:"required"`
Labels []string `yaml:"labels" validate:"required"`
}

type FirecrackerConfig struct {
BinaryPath string `yaml:"binary_path" `
BinaryPath string `yaml:"binary_path"`
KernelImagePath string `yaml:"kernel_image_path"`
KernelArgs string `yaml:"kernel_args"`
MachineConfig FirecrackerMachineConfig `yaml:"machine_config"`
Expand All @@ -56,32 +58,33 @@ type FirecrackerMachineConfig struct {

// DefaultConfig creates a new Config with default values.
func DefaultConfig() *Config {
c := &Config{
return &Config{
BindAddress: ":8080",
Metrics: &MetricsConfig{Enabled: true, Address: ":8081"},
BasicAuthEnabled: false,
BasicAuthUsers: map[string]string{},
GitHub: &GitHubConfig{AppPrivateKey: "", AppID: 0},
Pools: []*PoolConfig{},
LogLevel: "debug",
Debug: false,
GitHub: &GitHubConfig{
AppPrivateKey: "",
AppID: 0,
EnterpriseApiUrl: "", // empty = GitHub.com
SkipTLSVerify: false, // default: do not skip TLS
},
Pools: []*PoolConfig{},
LogLevel: "debug",
Debug: false,
}

return c
}

// NewConfigFromFile creates a new Config from a file.
func NewConfig(path string) (*Config, error) {
c := DefaultConfig()
c.path = path

err := c.Load()
if err != nil {
if err := c.Load(); err != nil {
return nil, err
}

err = c.Validate()
if err != nil {
if err := c.Validate(); err != nil {
return nil, fmt.Errorf("validate: %w", err)
}

Expand All @@ -94,13 +97,24 @@ func (c *Config) Load() error {
if err != nil {
return fmt.Errorf("open file: %w", err)
}

defer file.Close()

return yaml.NewDecoder(file).Decode(c)
}

// Validate validates the configuration.
func (c *Config) Validate() error {
return validator.New().Struct(c)
v := validator.New()

// Custom validation: SkipTLSVerify can only be true if EnterpriseApiUrl is set
_ = v.RegisterValidation("skiptls_if_enterprise", func(fl validator.FieldLevel) bool {
cfg := fl.Parent().Interface().(GitHubConfig)
if cfg.SkipTLSVerify && cfg.EnterpriseApiUrl == "" {
return false
}
return true
})

// Apply struct validation
return v.Struct(c)
}
9 changes: 8 additions & 1 deletion server/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,15 @@ import (
func TestNewConfig(t *testing.T) {
config, err := NewConfig("testdata/config1.yaml")
if err != nil {
t.Errorf("unexpected error: %v", err)
t.Fatalf("unexpected error: %v", err)
}

assert.Equal(t, "testdata/config1.yaml", config.path)

// Check GitHub config values
assert.NotNil(t, config.GitHub)
assert.NotEmpty(t, config.GitHub.AppPrivateKey)
assert.NotZero(t, config.GitHub.AppID)
assert.Equal(t, "https://api.githubenterprise.example.com/api/v3", config.GitHub.EnterpriseApiUrl)
assert.True(t, config.GitHub.SkipTLSVerify)
}
Loading