Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

net/http middleware functions #47

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions example/zalando_nethttp/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Zalando specific example.
package main

import (
"flag"
"fmt"
"net/http"
"time"

"github.com/golang/glog"
"github.com/zalando/gin-oauth2"
"github.com/zalando/gin-oauth2/zalando"
"goji.io"
"goji.io/pat"
)

var USERS []zalando.AccessTuple = []zalando.AccessTuple{
{"/employees", "sszuecs", "Sandor Szücs"},
{"/employees", "njuettner", "Nick Jüttner"},
}

var TEAMS []zalando.AccessTuple = []zalando.AccessTuple{
{"teams", "opensourceguild", "OpenSource"},
{"teams", "tm", "Platform Engineering / System"},
{"teams", "teapot", "Platform / Cloud API"},
}
var SERVICES []zalando.AccessTuple = []zalando.AccessTuple{
{"services", "foo", "Fooservice"},
}

func loggerMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
glog.Infof("loggerMiddleware: Got request: %s", req.URL)
next.ServeHTTP(rw, req)
})
}

func hello(w http.ResponseWriter, r *http.Request) {
name := pat.Param(r, "name")
fmt.Fprintf(w, "Hello, %s!\n", name)
}

func main() {
flag.Parse()
// start glog flusher
go func() {
for range time.Tick(1 * time.Second) {
glog.Flush()
}
}()

mux := goji.NewMux()
mux.Use(loggerMiddleware)
mux.Use(ginoauth2.RequestLoggerNetHTTP([]string{"uid"}, "data"))
ginoauth2.VarianceTimer = 3000 * time.Millisecond // defaults to 30s

public := goji.SubMux()
mux.Handle(pat.New("/api/*"), public)
public.HandleFunc(pat.Get("/:name"), hello)

private := goji.SubMux()
mux.Handle(pat.New("/private/*"), private)
privateGroup := goji.SubMux()
mux.Handle(pat.New("/privateGroup/*"), privateGroup)
privateUser := goji.SubMux()
mux.Handle(pat.New("/privateUser/*"), privateUser)
privateService := goji.SubMux()
mux.Handle(pat.New("/privateService/*"), privateService)
glog.Infof("Register allowed users: %+v and groups: %+v and services: %+v", USERS, TEAMS, SERVICES)

private.Use(ginoauth2.AuthChainNetHTTP(zalando.OAuth2Endpoint, zalando.UidCheckNetHTTP(USERS), zalando.GroupCheckNetHTTP(TEAMS), zalando.UidCheckNetHTTP(SERVICES)))
privateGroup.Use(ginoauth2.AuthNetHTTP(zalando.GroupCheckNetHTTP(TEAMS), zalando.OAuth2Endpoint))
privateUser.Use(ginoauth2.AuthNetHTTP(zalando.UidCheckNetHTTP(USERS), zalando.OAuth2Endpoint))
privateService.Use(ginoauth2.AuthNetHTTP(zalando.ScopeAndCheckNetHTTP("uidcheck", "uid", "bar"), zalando.OAuth2Endpoint))

private.HandleFunc(pat.Get("/"), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
uid := h.Get("uid")
fmt.Fprintf(w, "Hello from private for groups and users: %s\n", uid)
}))

privateGroup.HandleFunc(pat.Get("/"), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
uid := h.Get("uid")
team := h.Get("team")
fmt.Fprintf(w, "Hello from private group: uid: %s, team: %s\n", uid, team)
}))

privateUser.HandleFunc(pat.Get("/"), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
uid := h.Get("uid")
fmt.Fprintf(w, "Hello from private user: uid: %s\n", uid)
}))

privateService.HandleFunc(pat.Get("/"), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
cn := h.Get("cn")
fmt.Fprintf(w, "Hello from private service cn: %s\n", cn)
}))

glog.Info("bootstrapped application")
http.ListenAndServe("localhost:8081", mux)

}
108 changes: 101 additions & 7 deletions ginoauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ type TokenContainer struct {
// access.
type AccessCheckFunction func(tc *TokenContainer, ctx *gin.Context) bool

// AccessCheckFunctionNetHTTP is a function that checks if a given token grants
// access.
type AccessCheckFunctionNetHTTP func(tc *TokenContainer, w http.ResponseWriter, r *http.Request) bool

func extractToken(r *http.Request) (*oauth2.Token, error) {
hdr := r.Header.Get("Authorization")
if hdr == "" {
Expand Down Expand Up @@ -180,12 +184,12 @@ func GetTokenContainer(token *oauth2.Token) (*TokenContainer, error) {
return ParseTokenContainer(token, data)
}

func getTokenContainer(ctx *gin.Context) (*TokenContainer, bool) {
func getTokenContainer(r *http.Request) (*TokenContainer, bool) {
var oauthToken *oauth2.Token
var tc *TokenContainer
var err error

if oauthToken, err = extractToken(ctx.Request); err != nil {
if oauthToken, err = extractToken(r); err != nil {
glog.Errorf("[Gin-OAuth] Can not extract oauth2.Token, caused by: %s", err)
return nil, false
}
Expand Down Expand Up @@ -232,6 +236,11 @@ func Auth(accessCheckFunction AccessCheckFunction, endpoints oauth2.Endpoint) gi
return AuthChain(endpoints, accessCheckFunction)
}

// AuthNetHTTP is the net/http version of Auth
func AuthNetHTTP(accessCheckFunction AccessCheckFunctionNetHTTP, endpoints oauth2.Endpoint) func(http.Handler) http.Handler {
return AuthChainNetHTTP(endpoints, accessCheckFunction)
}

// AuthChain is a router middleware that can be used to get an authenticated
// and authorized service for the whole router group. Similar to Auth, but
// takes a chain of AccessCheckFunctions and only fails if all of them fails.
Expand Down Expand Up @@ -262,7 +271,7 @@ func AuthChain(endpoints oauth2.Endpoint, accessCheckFunctions ...AccessCheckFun
varianceControl := make(chan bool, 1)

go func() {
tokenContainer, ok := getTokenContainer(ctx)
tokenContainer, ok := getTokenContainer(ctx.Request)
if !ok {
// set LOCATION header to auth endpoint such that the user can easily get a new access-token
ctx.Writer.Header().Set("Location", endpoints.AuthURL)
Expand Down Expand Up @@ -309,6 +318,70 @@ func AuthChain(endpoints oauth2.Endpoint, accessCheckFunctions ...AccessCheckFun
}
}

// AuthChainNetHTTP is the net/http version of AuthChain
func AuthChainNetHTTP(endpoints oauth2.Endpoint, accessCheckFunctions ...AccessCheckFunctionNetHTTP) func(http.Handler) http.Handler {
// init
AuthInfoURL = endpoints.TokenURL
// middleware
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
t := time.Now()
varianceControl := make(chan bool, 1)

go func() {
tokenContainer, ok := getTokenContainer(request)
if !ok {
// set LOCATION header to auth endpoint such that the user can easily get a new access-token
writer.Header().Set("Location", endpoints.AuthURL)
writer.WriteHeader(http.StatusUnauthorized)
writer.Write([]byte("No token in context"))
varianceControl <- false
return
}

if !tokenContainer.Valid() {
// set LOCATION header to auth endpoint such that the user can easily get a new access-token
writer.Header().Set("Location", endpoints.AuthURL)
writer.WriteHeader(http.StatusUnauthorized)
writer.Write([]byte("Invalid Token"))
varianceControl <- false
return
}

for i, fn := range accessCheckFunctions {
if fn(tokenContainer, writer, request) {
varianceControl <- true
break
}

if len(accessCheckFunctions)-1 == i {
writer.WriteHeader(http.StatusForbidden)
writer.Write([]byte("Access to the Resource is fobidden"))
varianceControl <- false
return
}
}
}()

select {
case ok := <-varianceControl:
if !ok {
glog.V(2).Infof("[Gin-OAuth] %12v %s access not allowed", time.Since(t), request.URL.Path)
return
}
case <-time.After(VarianceTimer):
writer.WriteHeader(http.StatusGatewayTimeout)
writer.Write([]byte("Authorization check overtime"))
glog.V(2).Infof("[Gin-OAuth] %12v %s overtime", time.Since(t), request.URL.Path)
return
}

glog.V(2).Infof("[Gin-OAuth] %12v %s access allowed", time.Since(t), request.URL.Path)
next.ServeHTTP(writer, request)
})
}
}

// RequestLogger is a middleware that logs all the request and prints
// relevant information. This can be used for logging all the
// requests that contain important information and are authorized.
Expand All @@ -333,12 +406,10 @@ func RequestLogger(keys []string, contentKey string) gin.HandlerFunc {
c.Next()
err := c.Errors
if request.Method != "GET" && err == nil {
data, e := c.Get(contentKey)
if e != false { //key is non existent
if data, ok := c.Get(contentKey); ok {
values := make([]string, 0)
for _, key := range keys {
val, keyPresent := c.Get(key)
if keyPresent {
if val, ok := c.Get(key); ok {
values = append(values, val.(string))
}
}
Expand All @@ -348,4 +419,27 @@ func RequestLogger(keys []string, contentKey string) gin.HandlerFunc {
}
}

// RequestLoggerNetHTTP is the net/http version of RequestLogger.
func RequestLoggerNetHTTP(keys []string, contentKey string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
request := r
next.ServeHTTP(w, r.WithContext(ctx))
if request.Method != "GET" {
if data, ok := ctx.Value(contentKey).(string); ok {
values := make([]string, 0)
for _, key := range keys {
s, ok := ctx.Value(key).(string)
if ok {
values = append(values, s)
}
}
glog.Infof("[Gin-OAuth] Request: %+v for %s", data, strings.Join(values, "-"))
}
}
})
}
}

// vim: ts=4 sw=4 noexpandtab nolist syn=go
91 changes: 90 additions & 1 deletion zalando/zalando.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,45 @@ func GroupCheck(at []AccessTuple) func(tc *ginoauth2.TokenContainer, ctx *gin.Co
}
}

// GroupCheckNetHTTP is the net/http version of GroupCheck
func GroupCheckNetHTTP(at []AccessTuple) func(tc *ginoauth2.TokenContainer, w http.ResponseWriter, r *http.Request) bool {
ats := at
return func(tc *ginoauth2.TokenContainer, w http.ResponseWriter, r *http.Request) bool {
blob, err := RequestTeamInfo(tc, TeamAPI)
if err != nil {
glog.Errorf("[Gin-OAuth] failed to get team info, caused by: %s", err)
return false
}
var data []TeamInfo
err = json.Unmarshal(blob, &data)
if err != nil {
glog.Errorf("[Gin-OAuth] JSON.Unmarshal failed, caused by: %s", err)
return false
}
granted := false
for _, teamInfo := range data {
for idx := range ats {
at := ats[idx]
if teamInfo.Id == at.Uid {
granted = true
glog.Infof("[Gin-OAuth] Grant access to %s as team member of \"%s\"\n", tc.Scopes["uid"].(string), teamInfo.Id)
}
if teamInfo.Type == "official" {
if uid, ok := tc.Scopes["uid"].(string); ok {
w.Header().Set("uid", uid)
w.Header().Set("team", teamInfo.Id)
}
}
}
}
return granted
}
}

// UidCheck is an authorization function that checks UID scope
// TokenContainer must be Valid. As side effect it sets "uid" and
// "cn" in the gin.Context to the authorized uid and cn (Realname).
func UidCheck(at []AccessTuple) func(tc *ginoauth2.TokenContainer, ctx *gin.Context) bool {
func UidCheck(at []AccessTuple) ginoauth2.AccessCheckFunction {
ats := at
return func(tc *ginoauth2.TokenContainer, ctx *gin.Context) bool {
uid := tc.Scopes["uid"].(string)
Expand All @@ -122,6 +157,24 @@ func UidCheck(at []AccessTuple) func(tc *ginoauth2.TokenContainer, ctx *gin.Cont
}
}

// UidCheckNetHTTP is the net/http version of UidCheck
func UidCheckNetHTTP(at []AccessTuple) ginoauth2.AccessCheckFunctionNetHTTP {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to put this under a different package to not have the NetHTTP part in the function name?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this would be a bigger refactoring, because of the dependencies, but yes you are right would be good to do

ats := at
return func(tc *ginoauth2.TokenContainer, w http.ResponseWriter, r *http.Request) bool {
uid := tc.Scopes["uid"].(string)
for idx := range ats {
at := ats[idx]
if tc.Realm == at.Realm && uid == at.Uid {
w.Header().Set("uid", uid)
w.Header().Set("cn", at.Cn)
glog.Infof("[Gin-OAuth] Grant access to %s\n", uid)
return true
}
}
return false
}
}

// ScopeCheck does an OR check of scopes given from token of the
// request to all provided scopes. If one of provided scopes is in the
// Scopes of the token it grants access to the resource.
Expand All @@ -141,6 +194,23 @@ func ScopeCheck(name string, scopes ...string) func(tc *ginoauth2.TokenContainer
}
}

// ScopeCheckNetHTTP is the net/http version of ScopeCheck
func ScopeCheckNetHTTP(name string, scopes ...string) func(tc *ginoauth2.TokenContainer, w http.ResponseWriter, r *http.Request) bool {
glog.Infof("ScopeCheck %s configured to grant access for scopes: %v", name, scopes)
configuredScopes := scopes
return func(tc *ginoauth2.TokenContainer, w http.ResponseWriter, r *http.Request) bool {
scopesFromToken := make([]string, 0)
for _, s := range configuredScopes {
if cur, ok := tc.Scopes[s].(string); ok {
glog.V(2).Infof("Found configured scope %s", cur)
scopesFromToken = append(scopesFromToken, cur)
w.Header().Add(s, cur)
}
}
return len(scopesFromToken) > 0
}
}

// ScopeAndCheck does an AND check of scopes given from token of the
// request to all provided scopes. Only if all of provided scopes are found in the
// Scopes of the token it grants access to the resource.
Expand All @@ -162,6 +232,25 @@ func ScopeAndCheck(name string, scopes ...string) func(tc *ginoauth2.TokenContai
}
}

// ScopeAndCheckNetHTTP is the net/http version of ScopeAndCheck
func ScopeAndCheckNetHTTP(name string, scopes ...string) func(tc *ginoauth2.TokenContainer, w http.ResponseWriter, r *http.Request) bool {
glog.Infof("ScopeCheck %s configured to grant access only if scopes: %v are present", name, scopes)
configuredScopes := scopes
return func(tc *ginoauth2.TokenContainer, w http.ResponseWriter, r *http.Request) bool {
scopesFromToken := make([]string, 0)
for _, s := range configuredScopes {
if cur, ok := tc.Scopes[s].(string); ok {
glog.V(2).Infof("Found configured scope %s", cur)
scopesFromToken = append(scopesFromToken, cur)
w.Header().Add(s, cur)
} else {
return false
}
}
return true
}
}

// NoAuthorization sets "team" and "uid" in the context without
// checking if the user/team is authorized.
func NoAuthorization() func(tc *ginoauth2.TokenContainer, ctx *gin.Context) bool {
Expand Down