Skip to content

Commit

Permalink
Issue #228: Fix authorization in GRPC calls
Browse files Browse the repository at this point in the history
	--Wrote a custom auth interceptor which calls custom AuthFuncOverride with request body
	--Custom version of AuthFuncOverride accepting request body
	--Modified implementations of AuthFuncOverride in admin, publisher and subscriber server to accept request body
	--New file: metro/service/web/server.go having the implementation of fetching the project id from request payload
	--Fetch project-id from request payload using resources' name and extracted using regex capturing group
  • Loading branch information
razorao committed Aug 17, 2021
1 parent aaf256a commit 0953af3
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 35 deletions.
69 changes: 43 additions & 26 deletions internal/interceptors/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,29 @@ const (
authorizationHeaderKey = "authorization"
)

// UnaryServerAuthInterceptor creates an authenticator interceptor with the given AuthFunc
// serviceAuthFuncOverride - An interface to check if the server implements the authFuncOveride method
type serviceAuthFuncOverride interface {
AuthFuncOverride(ctx context.Context, fullMethodName string, req interface{}) (context.Context, error)
}

// UnaryServerAuthInterceptor - creates an authenticator interceptor with the given AuthFunc
func UnaryServerAuthInterceptor(authFunc grpc_auth.AuthFunc) grpc.UnaryServerInterceptor {
return grpc_auth.UnaryServerInterceptor(authFunc)
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
var newCtx context.Context
var err error
if overrideSrv, ok := info.Server.(serviceAuthFuncOverride); ok {
newCtx, err = overrideSrv.AuthFuncOverride(ctx, info.FullMethod, req)
} else {
newCtx, err = authFunc(ctx)
}
if err != nil {
return nil, err
}
return handler(newCtx, req)
}
}

func getUserPasswordProjectID(ctx context.Context) (user string, password []byte, projectID string, err error) {
func getUserPassword(ctx context.Context) (user string, password []byte, err error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
err = status.Error(codes.Unauthenticated, "could not parse incoming context")
Expand All @@ -33,9 +50,9 @@ func getUserPasswordProjectID(ctx context.Context) (user string, password []byte

// `uri` gets set inside `runtime.WithMetadata()`. check `server.go` for implementation
// this is ideally a single-valued slice
if val := md.Get("uri"); val != nil && len(val) > 0 {
projectID = extractProjectIDFromURI(val[0])
}
// if val := md.Get("uri"); val != nil && len(val) > 0 {
// projectID = extractProjectIDFromURI(val[0])
// }

headers := md.Get(authorizationHeaderKey)
if len(headers) != 1 {
Expand Down Expand Up @@ -78,21 +95,21 @@ func secureCompare(expected, actual string) bool {
// Example1: extractProjectIDFromURI("/v1/projects/project1/topics/t123") = project1
// Example2: extractProjectIDFromURI("/v1/projects/project7/subscriptions/s123") = project7
// Example3: extractProjectIDFromURI("/v1/admin/topic/t987") = ""
func extractProjectIDFromURI(uri string) string {
if uri == "" {
return ""
}

parts := strings.Split(uri, "/")
if len(parts) >= 4 && parts[2] == "projects" && parts[3] != "" {
return parts[3]
}
return ""
}
// func extractProjectIDFromURI(uri string) string {
// if uri == "" {
// return ""
// }

// parts := strings.Split(uri, "/")
// if len(parts) >= 4 && parts[2] == "projects" && parts[3] != "" {
// return parts[3]
// }
// return ""
// }

// AppAuth implements app project based basic auth validations
func AppAuth(ctx context.Context, credentialCore credentials.ICore) (context.Context, error) {
user, password, uriProjectID, err := getUserPasswordProjectID(ctx)
func AppAuth(ctx context.Context, credentialCore credentials.ICore, resourceProjectID string) (context.Context, error) {
user, password, err := getUserPassword(ctx)
if err != nil {
return ctx, err
}
Expand All @@ -108,25 +125,25 @@ func AppAuth(ctx context.Context, credentialCore credentials.ICore) (context.Con
return nil, status.Error(codes.Unauthenticated, "Unauthenticated")
}

// match the credential projectID and the uri projectID
// this way we enforce that the credential is accessing only its own projectID's resources
if !strings.EqualFold(uriProjectID, credential.GetProjectID()) {
return nil, status.Error(codes.Unauthenticated, "Unauthenticated")
}

expectedPassword := credential.GetPassword()
// check the header password matches the expected password
if !secureCompare(expectedPassword, string(password)) {
return nil, status.Error(codes.Unauthenticated, "Unauthenticated")
}

// match the credential projectID and the resource projectID
// this way we enforce that the credential is accessing only its own projectID's resources
if !strings.EqualFold(resourceProjectID, credential.GetProjectID()) {
return nil, status.Error(codes.PermissionDenied, "Unauthorized")
}

newCtx := context.WithValue(ctx, credentials.CtxKey.String(), credentials.NewCredential(user, string(password)))
return newCtx, nil
}

// AdminAuth implements admin credentials based basic auth validations
func AdminAuth(ctx context.Context, admin *credentials.Model) (context.Context, error) {
user, password, _, err := getUserPasswordProjectID(ctx)
user, password, err := getUserPassword(ctx)
if err != nil {
return ctx, err
}
Expand Down
2 changes: 1 addition & 1 deletion service/web/adminserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,6 @@ func (s adminServer) DeleteProjectCredentials(ctx context.Context, req *metrov1.
return &emptypb.Empty{}, nil
}

func (s adminServer) AuthFuncOverride(ctx context.Context, _ string) (context.Context, error) {
func (s adminServer) AuthFuncOverride(ctx context.Context, _ string, _ interface{}) (context.Context, error) {
return interceptors.AdminAuth(ctx, s.admin)
}
9 changes: 7 additions & 2 deletions service/web/publisherserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ func (s publisherServer) DeleteTopic(ctx context.Context, req *metrov1.DeleteTop
return &emptypb.Empty{}, nil
}

func (s publisherServer) AuthFuncOverride(ctx context.Context, _ string) (context.Context, error) {
return interceptors.AppAuth(ctx, s.credentialsCore)
//AuthFuncOverride - Override function called by the auth interceptor
func (s publisherServer) AuthFuncOverride(ctx context.Context, _ string, req interface{}) (context.Context, error) {
projectID, err := getProjectIDFromRequest(ctx, req)
if err != nil {
return ctx, err
}
return interceptors.AppAuth(ctx, s.credentialsCore, projectID)
}
66 changes: 66 additions & 0 deletions service/web/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package web

import (
"context"
"regexp"

"github.com/razorpay/metro/internal/merror"
"github.com/razorpay/metro/pkg/logger"
metrov1 "github.com/razorpay/metro/rpc/proto/v1"
)

var projectIDRegex *regexp.Regexp

func init() {
// Regex to capture project id from resource names
// Resources are in the form /projects/<project-id>/<resource-type>/<resource-id>
projectIDRegex = regexp.MustCompile(`^projects\/([^\/]+)\/.*`)
}

func getProjectIDFromRequest(ctx context.Context, req interface{}) (string, error) {
resourceName, err := getResourceNameFromRequest(ctx, req)
if err != nil {
return "", nil
}
return getProjectIDFromResourceName(ctx, resourceName)
}

func getResourceNameFromRequest(ctx context.Context, req interface{}) (string, error) {
switch t := req.(type) {
case *metrov1.PublishRequest:
return req.(*metrov1.PublishRequest).Topic, nil
case *metrov1.Topic:
return req.(*metrov1.Topic).Name, nil
case *metrov1.DeleteTopicRequest:
return req.(*metrov1.DeleteTopicRequest).Topic, nil
case *metrov1.Subscription:
return req.(*metrov1.Subscription).Name, nil
case *metrov1.UpdateSubscriptionRequest:
return req.(*metrov1.UpdateSubscriptionRequest).Subscription.Name, nil
case *metrov1.AcknowledgeRequest:
return req.(*metrov1.AcknowledgeRequest).Subscription, nil
case *metrov1.PullRequest:
return req.(*metrov1.PullRequest).Subscription, nil
case *metrov1.DeleteSubscriptionRequest:
return req.(*metrov1.DeleteSubscriptionRequest).Subscription, nil
case *metrov1.ModifyAckDeadlineRequest:
return req.(*metrov1.ModifyAckDeadlineRequest).Subscription, nil
default:
logger.Ctx(ctx).Infof("unknown request type: %v", t)
err := merror.New(merror.Unimplemented, "unknown resource type")
return "", err
}
}

// getProjectIDFromResourceName - Fetches the project id from resource name using regex capturing group
// Example: projects/project001/subscriptions/subscription001 -> project001
// projects/project001/topics/topic001 -> project001
// topics/topic001 -> invalid
func getProjectIDFromResourceName(ctx context.Context, resourceName string) (string, error) {
matches := projectIDRegex.FindStringSubmatch(resourceName)
if len(matches) < 2 {
logger.Ctx(ctx).Warnw("could not extract project id from resource name", "resource name", resourceName)
return "", nil
}
return matches[1], nil
}
7 changes: 3 additions & 4 deletions service/web/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"

"github.com/razorpay/metro/internal/app"
"github.com/razorpay/metro/internal/brokerstore"
"github.com/razorpay/metro/internal/credentials"
"github.com/razorpay/metro/internal/health"
Expand Down Expand Up @@ -143,9 +142,9 @@ func (svc *Service) Start(ctx context.Context) error {

func getInterceptors() []grpc.UnaryServerInterceptor {
// skip auth from test mode executions
if app.IsTestMode() {
return []grpc.UnaryServerInterceptor{}
}
// if app.IsTestMode() {
// return []grpc.UnaryServerInterceptor{}
// }

return []grpc.UnaryServerInterceptor{
interceptors.UnaryServerAuthInterceptor(func(ctx context.Context) (context.Context, error) {
Expand Down
9 changes: 7 additions & 2 deletions service/web/subscriberserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,11 @@ func (s subscriberserver) ModifyAckDeadline(ctx context.Context, req *metrov1.Mo
return new(emptypb.Empty), nil
}

func (s subscriberserver) AuthFuncOverride(ctx context.Context, _ string) (context.Context, error) {
return interceptors.AppAuth(ctx, s.credentialCore)
//AuthFuncOverride - Override function called by the auth interceptor
func (s subscriberserver) AuthFuncOverride(ctx context.Context, _ string, req interface{}) (context.Context, error) {
projectID, err := getProjectIDFromRequest(ctx, req)
if err != nil {
return ctx, err
}
return interceptors.AppAuth(ctx, s.credentialCore, projectID)
}

0 comments on commit 0953af3

Please sign in to comment.