diff --git a/cmd/provider-services/cmd/run.go b/cmd/provider-services/cmd/run.go index 519e2752..89d67f54 100644 --- a/cmd/provider-services/cmd/run.go +++ b/cmd/provider-services/cmd/run.go @@ -691,6 +691,10 @@ func doRunCmd(ctx context.Context, cmd *cobra.Command, _ []string) error { config.BidDeposit = bidDeposit config.RPCQueryTimeout = rpcQueryTimeout config.CachedResultMaxAge = cachedResultMaxAge + config.ProviderSigner, err = provider.NewProviderSigner(cctx, cl.Tx()) + if err != nil { + return err + } // This value can be nil, the operator is not mandatory var ipOperatorClient cip.Client diff --git a/config.go b/config.go index e2248d27..cae87050 100644 --- a/config.go +++ b/config.go @@ -11,6 +11,7 @@ import ( "github.com/akash-network/provider/bidengine" "github.com/akash-network/provider/cluster" + ptypes "github.com/akash-network/provider/types" ) type Config struct { @@ -26,6 +27,7 @@ type Config struct { MaxGroupVolumes int RPCQueryTimeout time.Duration CachedResultMaxAge time.Duration + ProviderSigner ptypes.ProviderSigner cluster.Config } diff --git a/gateway/grpc/server.go b/gateway/grpc/server.go index 02af7964..fed1ee87 100644 --- a/gateway/grpc/server.go +++ b/gateway/grpc/server.go @@ -117,9 +117,9 @@ func authInterceptor() grpc.UnaryServerInterceptor { } if md, ok := metadata.FromIncomingContext(ctx); ok { - tokens := md["authorization"] - if len(tokens) == 1 { - tokString = tokens[1] + tokString, err = gwutils.AuthHeaderToken(md.Get("authorization")) + if err != nil { + return nil, err } } diff --git a/gateway/grpc/server_test.go b/gateway/grpc/server_test.go new file mode 100644 index 00000000..24df665e --- /dev/null +++ b/gateway/grpc/server_test.go @@ -0,0 +1,46 @@ +package grpc + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + grpcpkg "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + ajwt "pkg.akt.dev/go/util/jwt" +) + +func TestAuthInterceptorAllowsNoAuthorizationHeader(t *testing.T) { + interceptor := authInterceptor() + handlerCalled := false + + resp, err := interceptor(context.Background(), nil, &grpcpkg.UnaryServerInfo{}, func(ctx context.Context, _ interface{}) (interface{}, error) { + handlerCalled = true + require.Equal(t, ajwt.AccessTypeNone, ClaimsFromCtx(ctx).Leases.Access) + + return "ok", nil + }) + + require.NoError(t, err) + require.Equal(t, "ok", resp) + require.True(t, handlerCalled) +} + +func TestAuthInterceptorSingleAuthorizationHeaderDoesNotPanic(t *testing.T) { + interceptor := authInterceptor() + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("authorization", "Bearer not-a-jwt")) + handlerCalled := false + var err error + + require.NotPanics(t, func() { + _, err = interceptor(ctx, nil, &grpcpkg.UnaryServerInfo{}, func(context.Context, interface{}) (interface{}, error) { + handlerCalled = true + + return nil, nil + }) + }) + + require.Error(t, err) + require.False(t, handlerCalled) +} diff --git a/gateway/rest/auth.go b/gateway/rest/auth.go index 19c6aec6..4fbbc6f7 100644 --- a/gateway/rest/auth.go +++ b/gateway/rest/auth.go @@ -5,7 +5,6 @@ import ( "errors" "log" "net/http" - "strings" gcontext "github.com/gorilla/context" @@ -16,17 +15,7 @@ import ( // AuthHeaderTokenExtractor is a TokenExtractor that takes a request // and extracts the token from the Authorization header. func AuthHeaderTokenExtractor(r *http.Request) (string, error) { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - return "", nil // No error, just no JWT. - } - - authHeaderParts := strings.Fields(authHeader) - if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { - return "", httperror.ErrInvalidAuthHeader - } - - return authHeaderParts[1], nil + return gwutils.AuthHeaderToken(r.Header.Values("Authorization")) } func DefaultErrorHandler(w http.ResponseWriter, _ *http.Request, err error) { diff --git a/gateway/utils/auth.go b/gateway/utils/auth.go new file mode 100644 index 00000000..53ebb8e6 --- /dev/null +++ b/gateway/utils/auth.go @@ -0,0 +1,30 @@ +package utils + +import ( + "strings" + + "github.com/akash-network/provider/utils/httperror" +) + +// AuthHeaderToken extracts a bearer token from Authorization header values. +func AuthHeaderToken(authHeaders []string) (string, error) { + if len(authHeaders) == 0 { + return "", nil + } + + if len(authHeaders) != 1 { + return "", httperror.ErrInvalidAuthHeader + } + + authHeader := authHeaders[0] + if authHeader == "" { + return "", nil + } + + authHeaderParts := strings.Fields(authHeader) + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { + return "", httperror.ErrInvalidAuthHeader + } + + return authHeaderParts[1], nil +} diff --git a/gateway/utils/utils_test.go b/gateway/utils/utils_test.go index 4e913d1d..59bae198 100644 --- a/gateway/utils/utils_test.go +++ b/gateway/utils/utils_test.go @@ -116,6 +116,52 @@ func TestAuthProcess_NoTokenReturnsClaims(t *testing.T) { require.Equal(t, ajwt.AccessTypeNone, claims.Leases.Access) } +func TestAuthHeaderToken(t *testing.T) { + tests := []struct { + name string + headers []string + want string + wantErr error + }{ + { + name: "empty", + }, + { + name: "bearer", + headers: []string{"Bearer token"}, + want: "token", + }, + { + name: "lowercase bearer", + headers: []string{"bearer token"}, + want: "token", + }, + { + name: "malformed", + headers: []string{"token"}, + wantErr: httperror.ErrInvalidAuthHeader, + }, + { + name: "multiple headers", + headers: []string{"Bearer token-a", "Bearer token-b"}, + wantErr: httperror.ErrInvalidAuthHeader, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + token, err := AuthHeaderToken(test.headers) + if test.wantErr != nil { + require.ErrorIs(t, err, test.wantErr) + return + } + + require.NoError(t, err) + require.Equal(t, test.want, token) + }) + } +} + func mustCreateCertWithCN(t *testing.T, cn string) *x509.Certificate { t.Helper() key, err := rsa.GenerateKey(rand.Reader, 2048) diff --git a/signer.go b/signer.go new file mode 100644 index 00000000..e96a42ac --- /dev/null +++ b/signer.go @@ -0,0 +1,82 @@ +package provider + +import ( + "context" + "errors" + "fmt" + + "github.com/cosmos/cosmos-sdk/client" + "github.com/cosmos/cosmos-sdk/crypto/keyring" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/tx/signing" + + aclient "pkg.akt.dev/go/node/client/v1beta3" + + ptypes "github.com/akash-network/provider/types" +) + +var ( + errProviderSignerMissingAddress = errors.New("provider signer missing address") + errProviderSignerMissingKeyring = errors.New("provider signer missing keyring") + errProviderSignerMissingTxClient = errors.New("provider signer missing tx client") + errProviderSignerUnexpectedBroadcastResponse = errors.New("provider signer unexpected broadcast response") +) + +type providerSigner struct { + address sdk.AccAddress + keyring keyring.Signer + tx aclient.TxClient +} + +var _ ptypes.ProviderSigner = (*providerSigner)(nil) + +func NewProviderSigner(cctx client.Context, tx aclient.TxClient) (ptypes.ProviderSigner, error) { + if len(cctx.FromAddress) == 0 { + return nil, errProviderSignerMissingAddress + } + + if cctx.Keyring == nil { + return nil, errProviderSignerMissingKeyring + } + + if tx == nil { + return nil, errProviderSignerMissingTxClient + } + + return &providerSigner{ + address: cctx.FromAddress, + keyring: cctx.Keyring, + tx: tx, + }, nil +} + +func (s *providerSigner) Address() sdk.AccAddress { + return s.address +} + +func (s *providerSigner) Sign(ctx context.Context, payload []byte) ([]byte, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + signature, _, err := s.keyring.SignByAddress(s.address, payload, signing.SignMode_SIGN_MODE_DIRECT) + if err != nil { + return nil, err + } + + return signature, nil +} + +func (s *providerSigner) Broadcast(ctx context.Context, msgs ...sdk.Msg) (*sdk.TxResponse, error) { + resp, err := s.tx.BroadcastMsgs(ctx, msgs, aclient.WithResultCodeAsError()) + if err != nil { + return nil, err + } + + txResp, ok := resp.(*sdk.TxResponse) + if !ok || txResp == nil { + return nil, fmt.Errorf("%w: %T", errProviderSignerUnexpectedBroadcastResponse, resp) + } + + return txResp, nil +} diff --git a/signer_test.go b/signer_test.go new file mode 100644 index 00000000..8264f67f --- /dev/null +++ b/signer_test.go @@ -0,0 +1,213 @@ +package provider + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" + + sdkclient "github.com/cosmos/cosmos-sdk/client" + "github.com/cosmos/cosmos-sdk/crypto/hd" + "github.com/cosmos/cosmos-sdk/crypto/keyring" + sdk "github.com/cosmos/cosmos-sdk/types" + + aclient "pkg.akt.dev/go/node/client/v1beta3" + "pkg.akt.dev/go/sdkutil" + "pkg.akt.dev/go/testutil" +) + +type recordingTxClient struct { + response interface{} + err error + + called bool + msgs []sdk.Msg + opts []aclient.BroadcastOption +} + +func (c *recordingTxClient) BroadcastMsgs(_ context.Context, msgs []sdk.Msg, opts ...aclient.BroadcastOption) (interface{}, error) { + c.called = true + c.msgs = msgs + c.opts = opts + + return c.response, c.err +} + +func (c *recordingTxClient) BroadcastTx(context.Context, sdk.Tx, ...aclient.BroadcastOption) (interface{}, error) { + return nil, errors.New("unused") +} + +func newProviderSignerTestKeyring(t *testing.T) (testutil.Keyring, sdk.AccAddress) { + t.Helper() + + encCfg := sdkutil.MakeEncodingConfig() + kr := testutil.NewTestKeyring(encCfg.Codec) + record, _, err := kr.NewMnemonic( + "provider", + keyring.English, + sdk.FullFundraiserPath, + keyring.DefaultBIP39Passphrase, + hd.Secp256k1, + ) + require.NoError(t, err) + + addr, err := record.GetAddress() + require.NoError(t, err) + + return kr, addr +} + +func TestNewProviderSignerValidatesInputs(t *testing.T) { + kr, addr := newProviderSignerTestKeyring(t) + tx := &recordingTxClient{response: &sdk.TxResponse{}} + + tests := []struct { + name string + cctx sdkclient.Context + tx aclient.TxClient + wantErr error + }{ + { + name: "success", + cctx: sdkclient.Context{ + FromAddress: addr, + Keyring: kr, + }, + tx: tx, + }, + { + name: "missing address", + cctx: sdkclient.Context{ + Keyring: kr, + }, + tx: tx, + wantErr: errProviderSignerMissingAddress, + }, + { + name: "missing keyring", + cctx: sdkclient.Context{ + FromAddress: addr, + }, + tx: tx, + wantErr: errProviderSignerMissingKeyring, + }, + { + name: "missing tx client", + cctx: sdkclient.Context{ + FromAddress: addr, + Keyring: kr, + }, + wantErr: errProviderSignerMissingTxClient, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + signer, err := NewProviderSigner(test.cctx, test.tx) + if test.wantErr != nil { + require.ErrorIs(t, err, test.wantErr) + require.Nil(t, signer) + return + } + + require.NoError(t, err) + require.Equal(t, addr.String(), signer.Address().String()) + }) + } +} + +func TestProviderSignerSign(t *testing.T) { + kr, addr := newProviderSignerTestKeyring(t) + payload := []byte("inventory snapshot payload") + + signer, err := NewProviderSigner(sdkclient.Context{ + FromAddress: addr, + Keyring: kr, + }, &recordingTxClient{response: &sdk.TxResponse{}}) + require.NoError(t, err) + + signature, err := signer.Sign(context.Background(), payload) + require.NoError(t, err) + require.NotEmpty(t, signature) + + key, err := kr.KeyByAddress(addr) + require.NoError(t, err) + pubKey, err := key.GetPubKey() + require.NoError(t, err) + require.True(t, pubKey.VerifySignature(payload, signature)) +} + +func TestProviderSignerSignCanceledContext(t *testing.T) { + kr, addr := newProviderSignerTestKeyring(t) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + signer, err := NewProviderSigner(sdkclient.Context{ + FromAddress: addr, + Keyring: kr, + }, &recordingTxClient{response: &sdk.TxResponse{}}) + require.NoError(t, err) + + signature, err := signer.Sign(ctx, []byte("payload")) + require.ErrorIs(t, err, context.Canceled) + require.Nil(t, signature) +} + +func TestProviderSignerBroadcast(t *testing.T) { + kr, addr := newProviderSignerTestKeyring(t) + tx := &recordingTxClient{response: &sdk.TxResponse{TxHash: "txhash"}} + + signer, err := NewProviderSigner(sdkclient.Context{ + FromAddress: addr, + Keyring: kr, + }, tx) + require.NoError(t, err) + + response, err := signer.Broadcast(context.Background()) + require.NoError(t, err) + require.Equal(t, "txhash", response.TxHash) + require.True(t, tx.called) + require.Empty(t, tx.msgs) + require.Len(t, tx.opts, 1) +} + +func TestProviderSignerBroadcastErrors(t *testing.T) { + kr, addr := newProviderSignerTestKeyring(t) + broadcastErr := errors.New("broadcast failed") + + tests := []struct { + name string + tx *recordingTxClient + wantErr error + }{ + { + name: "broadcast error", + tx: &recordingTxClient{ + err: broadcastErr, + }, + wantErr: broadcastErr, + }, + { + name: "unexpected response", + tx: &recordingTxClient{ + response: struct{}{}, + }, + wantErr: errProviderSignerUnexpectedBroadcastResponse, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + signer, err := NewProviderSigner(sdkclient.Context{ + FromAddress: addr, + Keyring: kr, + }, test.tx) + require.NoError(t, err) + + response, err := signer.Broadcast(context.Background()) + require.ErrorIs(t, err, test.wantErr) + require.Nil(t, response) + }) + } +} diff --git a/types/types.go b/types/types.go index 4993b2df..0d8c2f61 100644 --- a/types/types.go +++ b/types/types.go @@ -19,3 +19,10 @@ const ( type AccountQuerier interface { GetAccountPublicKey(context.Context, sdk.Address) (cryptotypes.PubKey, error) } + +// ProviderSigner is the narrow signing and broadcast surface needed by AEP-86 provider components. +type ProviderSigner interface { + Address() sdk.AccAddress + Sign(context.Context, []byte) ([]byte, error) + Broadcast(context.Context, ...sdk.Msg) (*sdk.TxResponse, error) +}