Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions packages/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ const (
operationCallGetMFASessionStatus = "CallGetMFASessionStatus"
operationCallOrgRelayHeartBeat = "CallOrgRelayHeartBeat"
operationCallInstanceRelayHeartBeat = "CallInstanceRelayHeartBeat"
operationCallRelayLogin = "CallRelayLogin"
operationCallRelayConnect = "CallRelayConnect"
operationCallRelayHeartbeatV2 = "CallRelayHeartbeatV2"
operationCallIssueCertificate = "CallIssueCertificate"
operationCallRetrieveCertificate = "CallRetrieveCertificate"
operationCallGetCertificateBundle = "CallGetCertificateBundle"
Expand Down Expand Up @@ -901,6 +904,62 @@ func CallGetRelays(httpClient *resty.Client) (GetRelaysResponse, error) {
return resBody, nil
}

func CallRelayLogin(httpClient *resty.Client, request RelayLoginRequest) (RelayLoginResponse, error) {
var resBody RelayLoginResponse
response, err := httpClient.
R().
SetResult(&resBody).
SetHeader("User-Agent", USER_AGENT).
SetBody(request).
Post(fmt.Sprintf("%v/v2/relays/login", config.INFISICAL_URL))

if err != nil {
return RelayLoginResponse{}, NewGenericRequestError(operationCallRelayLogin, err)
}

if response.IsError() {
return RelayLoginResponse{}, NewAPIErrorWithResponse(operationCallRelayLogin, response, nil)
}

return resBody, nil
}

func CallRelayConnect(httpClient *resty.Client) (RelayConnectResponse, error) {
var resBody RelayConnectResponse
response, err := httpClient.
R().
SetResult(&resBody).
SetHeader("User-Agent", USER_AGENT).
Post(fmt.Sprintf("%v/v2/relays/connect", config.INFISICAL_URL))

if err != nil {
return RelayConnectResponse{}, NewGenericRequestError(operationCallRelayConnect, err)
}

if response.IsError() {
return RelayConnectResponse{}, NewAPIErrorWithResponse(operationCallRelayConnect, response, nil)
}

return resBody, nil
}

func CallRelayHeartbeatV2(httpClient *resty.Client) error {
response, err := httpClient.
R().
SetHeader("User-Agent", USER_AGENT).
Post(fmt.Sprintf("%v/v2/relays/heartbeat", config.INFISICAL_URL))

if err != nil {
return NewGenericRequestError(operationCallRelayHeartbeatV2, err)
}

if response.IsError() {
return NewAPIErrorWithResponse(operationCallRelayHeartbeatV2, response, nil)
}

return nil
}

func CallConnectGateway(httpClient *resty.Client, request ConnectGatewayRequest) (RegisterGatewayResponse, error) {
var resBody RegisterGatewayResponse
response, err := httpClient.
Expand Down
29 changes: 29 additions & 0 deletions packages/api/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,35 @@ type RelayHeartbeatRequest struct {
Name string `json:"name"`
}

type RelayLoginRequest struct {
Method string `json:"method"`
Token string `json:"token,omitempty"`
RelayID string `json:"relayId,omitempty"`
HTTPRequestMethod string `json:"iamHttpRequestMethod,omitempty"`
IamRequestBody string `json:"iamRequestBody,omitempty"`
IamRequestHeaders string `json:"iamRequestHeaders,omitempty"`
}

type RelayLoginResponse struct {
AccessToken string `json:"accessToken"`
RelayID string `json:"relayId"`
TokenType string `json:"tokenType"`
}

type RelayConnectResponse struct {
RelayID string `json:"relayId"`
PKI struct {
ServerCertificate string `json:"serverCertificate"`
ServerPrivateKey string `json:"serverPrivateKey"`
ClientCertificateChain string `json:"clientCertificateChain"`
} `json:"pki"`
SSH struct {
ServerCertificate string `json:"serverCertificate"`
ServerPrivateKey string `json:"serverPrivateKey"`
ClientCAPublicKey string `json:"clientCAPublicKey"`
} `json:"ssh"`
}

type AltName struct {
Type string `json:"type"`
Value string `json:"value"`
Expand Down
164 changes: 152 additions & 12 deletions packages/cmd/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
"syscall"
"time"

"github.com/Infisical/infisical-merge/packages/api"
"github.com/Infisical/infisical-merge/packages/config"
gatewayv2 "github.com/Infisical/infisical-merge/packages/gateway-v2"
"github.com/Infisical/infisical-merge/packages/relay"
"github.com/Infisical/infisical-merge/packages/util"
Expand Down Expand Up @@ -39,29 +41,166 @@
util.HandleError(err, fmt.Sprintf("unable to get name flag or %s env", gatewayv2.RELAY_NAME_ENV_NAME))
}

host, err := util.GetCmdFlagOrEnv(cmd, "host", []string{gatewayv2.RELAY_HOST_ENV_NAME})
if err != nil || host == "" {
util.HandleError(err, fmt.Sprintf("unable to get host flag or %s env", gatewayv2.RELAY_HOST_ENV_NAME))
enrollMethod, _ := cmd.Flags().GetString("enroll-method")
if enrollMethod == "" {
enrollMethod = os.Getenv("INFISICAL_RELAY_ENROLL_METHOD")
}

host, _ := util.GetCmdFlagOrEnv(cmd, "host", []string{gatewayv2.RELAY_HOST_ENV_NAME})
if host == "" && enrollMethod == "" {
util.HandleError(fmt.Errorf("please provide host flag"), fmt.Sprintf("unable to get host flag or %s env", gatewayv2.RELAY_HOST_ENV_NAME))

Check warning on line 51 in packages/cmd/relay.go

View check run for this annotation

Claude / Claude Code Review

Invalid --enroll-method value silently falls through

Invalid `--enroll-method` values (e.g. a typo like `tokn`) silently bypass the host requirement at line 50, skip both enrollment branches, and fall through to the legacy machine-identity path, where the user sees a cryptic `no access token found` error far from the actual mistake. Since the flag help advertises `[token, aws]`, consider validating that a non-empty value is one of those two and erroring out early with a clear message.
Comment thread
saifsmailbox98 marked this conversation as resolved.
}

instanceType, err := util.GetCmdFlagOrEnvWithDefaultValue(cmd, "type", []string{gatewayv2.RELAY_TYPE_ENV_NAME}, "org")
if err != nil {
util.HandleError(err, fmt.Sprintf("unable to get type flag or %s env", gatewayv2.RELAY_TYPE_ENV_NAME))
}

var enrolledAccessToken string

// --- AWS Auth path ---
if enrollMethod == relay.EnrollMethodAws {
relayID, _ := cmd.Flags().GetString("relay-id")
if relayID == "" {
relayID = os.Getenv(relay.INFISICAL_RELAY_ID_KEY)
}
if relayID == "" {
stored, _ := relay.LoadStoredRelayID(relayName)
relayID = stored
}
if relayID == "" {
util.HandleError(errors.New("--relay-id is required when --enroll-method=aws"))
}

domain, _ := cmd.Flags().GetString("domain")
if domain != "" {
config.INFISICAL_URL = util.AppendAPIEndpoint(domain)
} else if storedDomain, _ := relay.LoadStoredDomain(relayName); storedDomain != "" {
config.INFISICAL_URL = util.AppendAPIEndpoint(storedDomain)
}

httpClient, err := util.GetRestyClientWithCustomHeaders()
if err != nil {
util.HandleError(err, "unable to create HTTP client")
}

log.Info().Msg("Authenticating relay via AWS Auth (STS GetCallerIdentity)...")
accessTokenStr, err := relay.LoginRelayWithAws(cmd.Context(), httpClient, relayID)
if err != nil {
util.HandleError(err, "AWS Auth login failed")
}

enrolledAccessToken = accessTokenStr

if err := relay.SaveRelayID(relayName, relayID); err != nil {
util.HandleError(err, "failed to save relay id to config")
}

effectiveDomain := domain
if effectiveDomain == "" {
effectiveDomain = config.INFISICAL_URL
}
if effectiveDomain != "" {
if err := relay.SaveDomain(relayName, effectiveDomain); err != nil {
util.HandleError(err, "failed to save domain to config")
}
}

log.Info().Msgf("Relay authenticated via AWS Auth. State saved to %s", relay.GetConfPathDisplay(relayName))
log.Info().Msg("Starting relay...")
}

// --- Enrollment token path ---
if enrollMethod == relay.EnrollMethodToken {
enrollToken, _ := cmd.Flags().GetString("token")
if enrollToken == "" {
util.HandleError(errors.New("--token is required when --enroll-method=token"))
}

storedEnrollToken, _ := relay.LoadStoredEnrollmentToken(relayName)
alreadyEnrolled := storedEnrollToken != "" && storedEnrollToken == enrollToken

if alreadyEnrolled {
log.Info().Msg("Enrollment token matches stored token. Skipping enrollment.")
} else {
domain, _ := cmd.Flags().GetString("domain")
if domain != "" {
config.INFISICAL_URL = util.AppendAPIEndpoint(domain)
}

Check failure on line 129 in packages/cmd/relay.go

View check run for this annotation

Claude / Claude Code Review

Token re-enrollment ignores stored domain

The token-enrollment path at lines 125-129 only honors `--domain` from the flag and does not fall back to `LoadStoredDomain` the way the AWS path at lines 75-80 does. As a result, a self-hosted user who initially enrolled with `--domain=https://corp.infisical.io` and later re-runs with a new enrollment token but no `--domain` will hit the default `https://app.infisical.com` at `CallRelayLogin` (line 137) and get a confusing `enrollment failed` error against the wrong backend. Fix: mirror the AWS
Comment thread
saifsmailbox98 marked this conversation as resolved.

httpClient, err := util.GetRestyClientWithCustomHeaders()
if err != nil {
util.HandleError(err, "unable to create HTTP client")
}

log.Info().Msg("Enrolling relay with enrollment token...")
enrollResp, err := api.CallRelayLogin(httpClient, api.RelayLoginRequest{
Method: "token",
Token: enrollToken,
})
if err != nil {
util.HandleError(err, "enrollment failed")
}

enrolledAccessToken = enrollResp.AccessToken
if err := relay.SaveAccessToken(relayName, enrollResp.AccessToken); err != nil {
util.HandleError(err, "failed to save relay access token")
}
if err := relay.SaveEnrollmentToken(relayName, enrollToken); err != nil {
util.HandleError(err, "failed to save enrollment token to config")
}

effectiveDomain := domain
if effectiveDomain == "" {
effectiveDomain = config.INFISICAL_URL
}
if effectiveDomain != "" {
if err := relay.SaveDomain(relayName, effectiveDomain); err != nil {
util.HandleError(err, "failed to save domain to config")
}
}

log.Info().Msgf("Relay enrolled successfully. Access token saved to %s", relay.GetConfPathDisplay(relayName))
}

log.Info().Msg("Starting relay...")
}

// --- Domain resolution for resource auth / stored token ---
isResourceAuth := enrollMethod == relay.EnrollMethodToken || enrollMethod == relay.EnrollMethodAws
if isResourceAuth {
if flagDomain, _ := cmd.Flags().GetString("domain"); flagDomain != "" {
config.INFISICAL_URL = util.AppendAPIEndpoint(flagDomain)
} else if storedDomain, _ := relay.LoadStoredDomain(relayName); storedDomain != "" {
config.INFISICAL_URL = util.AppendAPIEndpoint(storedDomain)
}
}

relayInstance, err := relay.NewRelay(&relay.RelayConfig{
RelayName: relayName,
SSHPort: "2222",
TLSPort: "8443",
Host: host,
Type: instanceType,
RelayName: relayName,
SSHPort: "2222",
TLSPort: "8443",
Host: host,
Type: instanceType,
EnrollMethod: enrollMethod,
})

if err != nil {
util.HandleError(err, "unable to create relay instance")
}

if instanceType == "instance" {
if isResourceAuth {
// Use the freshly enrolled token, or load the stored one.
if enrolledAccessToken != "" {
relayInstance.SetToken(enrolledAccessToken)
} else {
storedToken, err := relay.LoadStoredAccessToken(relayName)
if err != nil || storedToken == "" {
util.HandleError(errors.New("no stored access token found — re-run with enrollment token"))
}
relayInstance.SetToken(storedToken)
}
} else if instanceType == "instance" {
relayAuthSecret := os.Getenv(gatewayv2.RELAY_AUTH_SECRET_ENV_NAME)
if relayAuthSecret == "" {
util.HandleError(fmt.Errorf("%s is not set", gatewayv2.RELAY_AUTH_SECRET_ENV_NAME), "unable to get relay auth secret")
Expand Down Expand Up @@ -96,7 +235,6 @@
cancelCmd()
cancelSdk()

// Give graceful shutdown 10 seconds, then force exit on second signal
select {
case <-sigCh:
log.Warn().Msg("Second signal received, force exit triggered")
Expand All @@ -107,7 +245,6 @@
}
}()

// Token refresh goroutine - runs every 10 seconds
go func() {
tokenRefreshTicker := time.NewTicker(10 * time.Second)
defer tokenRefreshTicker.Stop()
Expand Down Expand Up @@ -259,8 +396,11 @@
relayStartCmd.Flags().String("type", "", "The type of relay to run. Defaults to 'org'")
relayStartCmd.Flags().String("host", "", "The IP or hostname for the relay")
relayStartCmd.Flags().String("name", "", "The name of the relay")
relayStartCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag")
relayStartCmd.Flags().String("token", "", "connect with Infisical using machine identity access token, or a one-time enrollment token when --enroll-method=token")
relayStartCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag")
relayStartCmd.Flags().String("enroll-method", "", "relay auth method [token, aws]. when set to 'token', uses --token as a one-time enrollment token. when set to 'aws', authenticates via signed STS GetCallerIdentity using --relay-id")
relayStartCmd.Flags().String("relay-id", "", "relay id (required when --enroll-method=aws)")
relayStartCmd.Flags().String("domain", "", "domain of your self-hosted Infisical instance (used with --enroll-method)")
relayStartCmd.Flags().String("client-id", "", "client id for universal auth")
relayStartCmd.Flags().String("client-secret", "", "client secret for universal auth")
relayStartCmd.Flags().String("machine-identity-id", "", "machine identity id for kubernetes, azure, gcp-id-token, gcp-iam, and aws-iam auth methods")
Expand Down
75 changes: 75 additions & 0 deletions packages/relay/aws_auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package relay

import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"

"github.com/Infisical/infisical-merge/packages/api"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/go-resty/resty/v2"
infisicalSdkUtil "github.com/infisical/go-sdk/packages/util"
)

func LoginRelayWithAws(ctx context.Context, httpClient *resty.Client, relayID string) (string, error) {
if relayID == "" {
return "", errors.New("--relay-id is required when --enroll-method=aws")
}

awsCredentials, awsRegion, err := infisicalSdkUtil.RetrieveAwsCredentials()
if err != nil {
return "", fmt.Errorf("unable to retrieve AWS credentials: %w", err)
}

iamRequestURL := fmt.Sprintf("https://sts.%s.amazonaws.com/", awsRegion)
iamRequestBody := "Action=GetCallerIdentity&Version=2011-06-15"

req, err := http.NewRequest(http.MethodPost, iamRequestURL, strings.NewReader(iamRequestBody))
if err != nil {
return "", fmt.Errorf("error building STS request: %w", err)
}

req.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")

hash := sha256.New()
hash.Write([]byte(iamRequestBody))
payloadHash := fmt.Sprintf("%x", hash.Sum(nil))

signer := v4.NewSigner()
if err := signer.SignHTTP(ctx, awsCredentials, req, payloadHash, "sts", awsRegion, time.Now()); err != nil {
return "", fmt.Errorf("error signing STS request: %w", err)
}

headers := make(map[string]string)
for name, values := range req.Header {
if strings.ToLower(name) == "content-length" {
continue
}
headers[name] = values[0]
}
headers["Host"] = fmt.Sprintf("sts.%s.amazonaws.com", awsRegion)

headersJSON, err := json.Marshal(headers)
if err != nil {
return "", fmt.Errorf("error marshalling headers: %w", err)
}

resp, err := api.CallRelayLogin(httpClient, api.RelayLoginRequest{
Method: EnrollMethodAws,
RelayID: relayID,
HTTPRequestMethod: req.Method,
IamRequestBody: base64.StdEncoding.EncodeToString([]byte(iamRequestBody)),
IamRequestHeaders: base64.StdEncoding.EncodeToString(headersJSON),
})
if err != nil {
return "", err
}

return resp.AccessToken, nil
}
Loading
Loading