From 0953af3350263ae20c093331e10539bd0d7b31dc Mon Sep 17 00:00:00 2001 From: razorao Date: Tue, 17 Aug 2021 15:50:20 +0530 Subject: [PATCH] Issue #228: Fix authorization in GRPC calls --Wrote a custom auth interceptor which calls custom AuthFuncOverride with request body --Custom version of AuthFuncOverride accepting request body --Modified implementations of AuthFuncOverride in admin, publisher and subscriber server to accept request body --New file: metro/service/web/server.go having the implementation of fetching the project id from request payload --Fetch project-id from request payload using resources' name and extracted using regex capturing group --- internal/interceptors/auth.go | 69 ++++++++++++++++++++------------- service/web/adminserver.go | 2 +- service/web/publisherserver.go | 9 ++++- service/web/server.go | 66 +++++++++++++++++++++++++++++++ service/web/service.go | 7 ++-- service/web/subscriberserver.go | 9 ++++- 6 files changed, 127 insertions(+), 35 deletions(-) create mode 100644 service/web/server.go diff --git a/internal/interceptors/auth.go b/internal/interceptors/auth.go index 83e970154..ce948a890 100644 --- a/internal/interceptors/auth.go +++ b/internal/interceptors/auth.go @@ -19,12 +19,29 @@ const ( authorizationHeaderKey = "authorization" ) -// UnaryServerAuthInterceptor creates an authenticator interceptor with the given AuthFunc +// serviceAuthFuncOverride - An interface to check if the server implements the authFuncOveride method +type serviceAuthFuncOverride interface { + AuthFuncOverride(ctx context.Context, fullMethodName string, req interface{}) (context.Context, error) +} + +// UnaryServerAuthInterceptor - creates an authenticator interceptor with the given AuthFunc func UnaryServerAuthInterceptor(authFunc grpc_auth.AuthFunc) grpc.UnaryServerInterceptor { - return grpc_auth.UnaryServerInterceptor(authFunc) + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + var newCtx context.Context + var err error + if overrideSrv, ok := info.Server.(serviceAuthFuncOverride); ok { + newCtx, err = overrideSrv.AuthFuncOverride(ctx, info.FullMethod, req) + } else { + newCtx, err = authFunc(ctx) + } + if err != nil { + return nil, err + } + return handler(newCtx, req) + } } -func getUserPasswordProjectID(ctx context.Context) (user string, password []byte, projectID string, err error) { +func getUserPassword(ctx context.Context) (user string, password []byte, err error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { err = status.Error(codes.Unauthenticated, "could not parse incoming context") @@ -33,9 +50,9 @@ func getUserPasswordProjectID(ctx context.Context) (user string, password []byte // `uri` gets set inside `runtime.WithMetadata()`. check `server.go` for implementation // this is ideally a single-valued slice - if val := md.Get("uri"); val != nil && len(val) > 0 { - projectID = extractProjectIDFromURI(val[0]) - } + // if val := md.Get("uri"); val != nil && len(val) > 0 { + // projectID = extractProjectIDFromURI(val[0]) + // } headers := md.Get(authorizationHeaderKey) if len(headers) != 1 { @@ -78,21 +95,21 @@ func secureCompare(expected, actual string) bool { // Example1: extractProjectIDFromURI("/v1/projects/project1/topics/t123") = project1 // Example2: extractProjectIDFromURI("/v1/projects/project7/subscriptions/s123") = project7 // Example3: extractProjectIDFromURI("/v1/admin/topic/t987") = "" -func extractProjectIDFromURI(uri string) string { - if uri == "" { - return "" - } - - parts := strings.Split(uri, "/") - if len(parts) >= 4 && parts[2] == "projects" && parts[3] != "" { - return parts[3] - } - return "" -} +// func extractProjectIDFromURI(uri string) string { +// if uri == "" { +// return "" +// } + +// parts := strings.Split(uri, "/") +// if len(parts) >= 4 && parts[2] == "projects" && parts[3] != "" { +// return parts[3] +// } +// return "" +// } // AppAuth implements app project based basic auth validations -func AppAuth(ctx context.Context, credentialCore credentials.ICore) (context.Context, error) { - user, password, uriProjectID, err := getUserPasswordProjectID(ctx) +func AppAuth(ctx context.Context, credentialCore credentials.ICore, resourceProjectID string) (context.Context, error) { + user, password, err := getUserPassword(ctx) if err != nil { return ctx, err } @@ -108,25 +125,25 @@ func AppAuth(ctx context.Context, credentialCore credentials.ICore) (context.Con return nil, status.Error(codes.Unauthenticated, "Unauthenticated") } - // match the credential projectID and the uri projectID - // this way we enforce that the credential is accessing only its own projectID's resources - if !strings.EqualFold(uriProjectID, credential.GetProjectID()) { - return nil, status.Error(codes.Unauthenticated, "Unauthenticated") - } - expectedPassword := credential.GetPassword() // check the header password matches the expected password if !secureCompare(expectedPassword, string(password)) { return nil, status.Error(codes.Unauthenticated, "Unauthenticated") } + // match the credential projectID and the resource projectID + // this way we enforce that the credential is accessing only its own projectID's resources + if !strings.EqualFold(resourceProjectID, credential.GetProjectID()) { + return nil, status.Error(codes.PermissionDenied, "Unauthorized") + } + newCtx := context.WithValue(ctx, credentials.CtxKey.String(), credentials.NewCredential(user, string(password))) return newCtx, nil } // AdminAuth implements admin credentials based basic auth validations func AdminAuth(ctx context.Context, admin *credentials.Model) (context.Context, error) { - user, password, _, err := getUserPasswordProjectID(ctx) + user, password, err := getUserPassword(ctx) if err != nil { return ctx, err } diff --git a/service/web/adminserver.go b/service/web/adminserver.go index 4755828e7..32833c6cc 100644 --- a/service/web/adminserver.go +++ b/service/web/adminserver.go @@ -172,6 +172,6 @@ func (s adminServer) DeleteProjectCredentials(ctx context.Context, req *metrov1. return &emptypb.Empty{}, nil } -func (s adminServer) AuthFuncOverride(ctx context.Context, _ string) (context.Context, error) { +func (s adminServer) AuthFuncOverride(ctx context.Context, _ string, _ interface{}) (context.Context, error) { return interceptors.AdminAuth(ctx, s.admin) } diff --git a/service/web/publisherserver.go b/service/web/publisherserver.go index 3f33f3f06..43dcd145a 100644 --- a/service/web/publisherserver.go +++ b/service/web/publisherserver.go @@ -94,6 +94,11 @@ func (s publisherServer) DeleteTopic(ctx context.Context, req *metrov1.DeleteTop return &emptypb.Empty{}, nil } -func (s publisherServer) AuthFuncOverride(ctx context.Context, _ string) (context.Context, error) { - return interceptors.AppAuth(ctx, s.credentialsCore) +//AuthFuncOverride - Override function called by the auth interceptor +func (s publisherServer) AuthFuncOverride(ctx context.Context, _ string, req interface{}) (context.Context, error) { + projectID, err := getProjectIDFromRequest(ctx, req) + if err != nil { + return ctx, err + } + return interceptors.AppAuth(ctx, s.credentialsCore, projectID) } diff --git a/service/web/server.go b/service/web/server.go new file mode 100644 index 000000000..ed6eb1d17 --- /dev/null +++ b/service/web/server.go @@ -0,0 +1,66 @@ +package web + +import ( + "context" + "regexp" + + "github.com/razorpay/metro/internal/merror" + "github.com/razorpay/metro/pkg/logger" + metrov1 "github.com/razorpay/metro/rpc/proto/v1" +) + +var projectIDRegex *regexp.Regexp + +func init() { + // Regex to capture project id from resource names + // Resources are in the form /projects/// + projectIDRegex = regexp.MustCompile(`^projects\/([^\/]+)\/.*`) +} + +func getProjectIDFromRequest(ctx context.Context, req interface{}) (string, error) { + resourceName, err := getResourceNameFromRequest(ctx, req) + if err != nil { + return "", nil + } + return getProjectIDFromResourceName(ctx, resourceName) +} + +func getResourceNameFromRequest(ctx context.Context, req interface{}) (string, error) { + switch t := req.(type) { + case *metrov1.PublishRequest: + return req.(*metrov1.PublishRequest).Topic, nil + case *metrov1.Topic: + return req.(*metrov1.Topic).Name, nil + case *metrov1.DeleteTopicRequest: + return req.(*metrov1.DeleteTopicRequest).Topic, nil + case *metrov1.Subscription: + return req.(*metrov1.Subscription).Name, nil + case *metrov1.UpdateSubscriptionRequest: + return req.(*metrov1.UpdateSubscriptionRequest).Subscription.Name, nil + case *metrov1.AcknowledgeRequest: + return req.(*metrov1.AcknowledgeRequest).Subscription, nil + case *metrov1.PullRequest: + return req.(*metrov1.PullRequest).Subscription, nil + case *metrov1.DeleteSubscriptionRequest: + return req.(*metrov1.DeleteSubscriptionRequest).Subscription, nil + case *metrov1.ModifyAckDeadlineRequest: + return req.(*metrov1.ModifyAckDeadlineRequest).Subscription, nil + default: + logger.Ctx(ctx).Infof("unknown request type: %v", t) + err := merror.New(merror.Unimplemented, "unknown resource type") + return "", err + } +} + +// getProjectIDFromResourceName - Fetches the project id from resource name using regex capturing group +// Example: projects/project001/subscriptions/subscription001 -> project001 +// projects/project001/topics/topic001 -> project001 +// topics/topic001 -> invalid +func getProjectIDFromResourceName(ctx context.Context, resourceName string) (string, error) { + matches := projectIDRegex.FindStringSubmatch(resourceName) + if len(matches) < 2 { + logger.Ctx(ctx).Warnw("could not extract project id from resource name", "resource name", resourceName) + return "", nil + } + return matches[1], nil +} diff --git a/service/web/service.go b/service/web/service.go index 709c397af..5caa9e436 100644 --- a/service/web/service.go +++ b/service/web/service.go @@ -7,7 +7,6 @@ import ( "golang.org/x/sync/errgroup" "google.golang.org/grpc" - "github.com/razorpay/metro/internal/app" "github.com/razorpay/metro/internal/brokerstore" "github.com/razorpay/metro/internal/credentials" "github.com/razorpay/metro/internal/health" @@ -143,9 +142,9 @@ func (svc *Service) Start(ctx context.Context) error { func getInterceptors() []grpc.UnaryServerInterceptor { // skip auth from test mode executions - if app.IsTestMode() { - return []grpc.UnaryServerInterceptor{} - } + // if app.IsTestMode() { + // return []grpc.UnaryServerInterceptor{} + // } return []grpc.UnaryServerInterceptor{ interceptors.UnaryServerAuthInterceptor(func(ctx context.Context) (context.Context, error) { diff --git a/service/web/subscriberserver.go b/service/web/subscriberserver.go index e5cfcf180..f955904b9 100644 --- a/service/web/subscriberserver.go +++ b/service/web/subscriberserver.go @@ -229,6 +229,11 @@ func (s subscriberserver) ModifyAckDeadline(ctx context.Context, req *metrov1.Mo return new(emptypb.Empty), nil } -func (s subscriberserver) AuthFuncOverride(ctx context.Context, _ string) (context.Context, error) { - return interceptors.AppAuth(ctx, s.credentialCore) +//AuthFuncOverride - Override function called by the auth interceptor +func (s subscriberserver) AuthFuncOverride(ctx context.Context, _ string, req interface{}) (context.Context, error) { + projectID, err := getProjectIDFromRequest(ctx, req) + if err != nil { + return ctx, err + } + return interceptors.AppAuth(ctx, s.credentialCore, projectID) }