Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmd/provider-services/cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,10 @@ func doRunCmd(ctx context.Context, cmd *cobra.Command, _ []string) error {
config.BidDeposit = bidDeposit
config.RPCQueryTimeout = rpcQueryTimeout
config.CachedResultMaxAge = cachedResultMaxAge
config.ProviderSigner, err = provider.NewProviderSigner(cctx, cl.Tx())
if err != nil {
return err
}

// This value can be nil, the operator is not mandatory
var ipOperatorClient cip.Client
Expand Down
2 changes: 2 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/akash-network/provider/bidengine"
"github.com/akash-network/provider/cluster"
ptypes "github.com/akash-network/provider/types"
)

type Config struct {
Expand All @@ -26,6 +27,7 @@ type Config struct {
MaxGroupVolumes int
RPCQueryTimeout time.Duration
CachedResultMaxAge time.Duration
ProviderSigner ptypes.ProviderSigner
cluster.Config
}

Expand Down
6 changes: 3 additions & 3 deletions gateway/grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ func authInterceptor() grpc.UnaryServerInterceptor {
}

if md, ok := metadata.FromIncomingContext(ctx); ok {
tokens := md["authorization"]
if len(tokens) == 1 {
tokString = tokens[1]
tokString, err = gwutils.AuthHeaderToken(md.Get("authorization"))
if err != nil {
return nil, err
}
}

Expand Down
46 changes: 46 additions & 0 deletions gateway/grpc/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package grpc

import (
"context"
"testing"

"github.com/stretchr/testify/require"
grpcpkg "google.golang.org/grpc"
"google.golang.org/grpc/metadata"

ajwt "pkg.akt.dev/go/util/jwt"
)

func TestAuthInterceptorAllowsNoAuthorizationHeader(t *testing.T) {
interceptor := authInterceptor()
handlerCalled := false

resp, err := interceptor(context.Background(), nil, &grpcpkg.UnaryServerInfo{}, func(ctx context.Context, _ interface{}) (interface{}, error) {
handlerCalled = true
require.Equal(t, ajwt.AccessTypeNone, ClaimsFromCtx(ctx).Leases.Access)

return "ok", nil
})

require.NoError(t, err)
require.Equal(t, "ok", resp)
require.True(t, handlerCalled)
}

func TestAuthInterceptorSingleAuthorizationHeaderDoesNotPanic(t *testing.T) {
interceptor := authInterceptor()
ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("authorization", "Bearer not-a-jwt"))
handlerCalled := false
var err error

require.NotPanics(t, func() {
_, err = interceptor(ctx, nil, &grpcpkg.UnaryServerInfo{}, func(context.Context, interface{}) (interface{}, error) {
handlerCalled = true

return nil, nil
})
})

require.Error(t, err)
require.False(t, handlerCalled)
}
13 changes: 1 addition & 12 deletions gateway/rest/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"log"
"net/http"
"strings"

gcontext "github.com/gorilla/context"

Expand All @@ -16,17 +15,7 @@ import (
// AuthHeaderTokenExtractor is a TokenExtractor that takes a request
// and extracts the token from the Authorization header.
func AuthHeaderTokenExtractor(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return "", nil // No error, just no JWT.
}

authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", httperror.ErrInvalidAuthHeader
}

return authHeaderParts[1], nil
return gwutils.AuthHeaderToken(r.Header.Values("Authorization"))
}

func DefaultErrorHandler(w http.ResponseWriter, _ *http.Request, err error) {
Expand Down
30 changes: 30 additions & 0 deletions gateway/utils/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package utils

import (
"strings"

"github.com/akash-network/provider/utils/httperror"
)

// AuthHeaderToken extracts a bearer token from Authorization header values.
func AuthHeaderToken(authHeaders []string) (string, error) {
if len(authHeaders) == 0 {
return "", nil
}

if len(authHeaders) != 1 {
return "", httperror.ErrInvalidAuthHeader
}

authHeader := authHeaders[0]
if authHeader == "" {
return "", nil
}

authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", httperror.ErrInvalidAuthHeader
}

return authHeaderParts[1], nil
}
46 changes: 46 additions & 0 deletions gateway/utils/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,52 @@ func TestAuthProcess_NoTokenReturnsClaims(t *testing.T) {
require.Equal(t, ajwt.AccessTypeNone, claims.Leases.Access)
}

func TestAuthHeaderToken(t *testing.T) {
tests := []struct {
name string
headers []string
want string
wantErr error
}{
{
name: "empty",
},
{
name: "bearer",
headers: []string{"Bearer token"},
want: "token",
},
{
name: "lowercase bearer",
headers: []string{"bearer token"},
want: "token",
},
{
name: "malformed",
headers: []string{"token"},
wantErr: httperror.ErrInvalidAuthHeader,
},
{
name: "multiple headers",
headers: []string{"Bearer token-a", "Bearer token-b"},
wantErr: httperror.ErrInvalidAuthHeader,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
token, err := AuthHeaderToken(test.headers)
if test.wantErr != nil {
require.ErrorIs(t, err, test.wantErr)
return
}

require.NoError(t, err)
require.Equal(t, test.want, token)
})
}
}

func mustCreateCertWithCN(t *testing.T, cn string) *x509.Certificate {
t.Helper()
key, err := rsa.GenerateKey(rand.Reader, 2048)
Expand Down
82 changes: 82 additions & 0 deletions signer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package provider

import (
"context"
"errors"
"fmt"

"github.com/cosmos/cosmos-sdk/client"
"github.com/cosmos/cosmos-sdk/crypto/keyring"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/types/tx/signing"

aclient "pkg.akt.dev/go/node/client/v1beta3"

ptypes "github.com/akash-network/provider/types"
)

var (
errProviderSignerMissingAddress = errors.New("provider signer missing address")
errProviderSignerMissingKeyring = errors.New("provider signer missing keyring")
errProviderSignerMissingTxClient = errors.New("provider signer missing tx client")
errProviderSignerUnexpectedBroadcastResponse = errors.New("provider signer unexpected broadcast response")
)

type providerSigner struct {
address sdk.AccAddress
keyring keyring.Signer
tx aclient.TxClient
}

var _ ptypes.ProviderSigner = (*providerSigner)(nil)

func NewProviderSigner(cctx client.Context, tx aclient.TxClient) (ptypes.ProviderSigner, error) {
if len(cctx.FromAddress) == 0 {
return nil, errProviderSignerMissingAddress
}

if cctx.Keyring == nil {
return nil, errProviderSignerMissingKeyring
}

if tx == nil {
return nil, errProviderSignerMissingTxClient
}

return &providerSigner{
address: cctx.FromAddress,
keyring: cctx.Keyring,
tx: tx,
}, nil
}

func (s *providerSigner) Address() sdk.AccAddress {
return s.address
}

func (s *providerSigner) Sign(ctx context.Context, payload []byte) ([]byte, error) {
if err := ctx.Err(); err != nil {
return nil, err
}

signature, _, err := s.keyring.SignByAddress(s.address, payload, signing.SignMode_SIGN_MODE_DIRECT)
if err != nil {
return nil, err
}

return signature, nil
}

func (s *providerSigner) Broadcast(ctx context.Context, msgs ...sdk.Msg) (*sdk.TxResponse, error) {
resp, err := s.tx.BroadcastMsgs(ctx, msgs, aclient.WithResultCodeAsError())
if err != nil {
return nil, err
}

txResp, ok := resp.(*sdk.TxResponse)
if !ok || txResp == nil {
return nil, fmt.Errorf("%w: %T", errProviderSignerUnexpectedBroadcastResponse, resp)
}

return txResp, nil
}
Loading
Loading