Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

net/http middleware functions #47

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions example/zalando_nethttp/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Zalando specific example.
package main

import (
"flag"
"fmt"
"net/http"
"time"

"github.com/golang/glog"
"github.com/zalando/gin-oauth2"
"github.com/zalando/gin-oauth2/zalando"
"goji.io"
"goji.io/pat"
)

var USERS []zalando.AccessTuple = []zalando.AccessTuple{
{"/employees", "sszuecs", "Sandor Szücs"},
{"/employees", "njuettner", "Nick Jüttner"},
}

var TEAMS []zalando.AccessTuple = []zalando.AccessTuple{
{"teams", "opensourceguild", "OpenSource"},
{"teams", "tm", "Platform Engineering / System"},
{"teams", "teapot", "Platform / Cloud API"},
}
var SERVICES []zalando.AccessTuple = []zalando.AccessTuple{
{"services", "foo", "Fooservice"},
}

func loggerMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
glog.Infof("loggerMiddleware: Got request: %s", req.URL)
next.ServeHTTP(rw, req)
})
}

func hello(w http.ResponseWriter, r *http.Request) {
name := pat.Param(r, "name")
fmt.Fprintf(w, "Hello, %s!\n", name)
}

func main() {
flag.Parse()
// start glog flusher
go func() {
for range time.Tick(1 * time.Second) {
glog.Flush()
}
}()

mux := goji.NewMux()
mux.Use(loggerMiddleware)
mux.Use(ginoauth2.RequestLoggerNetHTTP([]string{"uid"}, "data"))
ginoauth2.VarianceTimer = 3000 * time.Millisecond // defaults to 30s

public := goji.SubMux()
mux.Handle(pat.New("/api/*"), public)
public.HandleFunc(pat.Get("/:name"), hello)

private := goji.SubMux()
mux.Handle(pat.New("/private/*"), private)
privateGroup := goji.SubMux()
mux.Handle(pat.New("/privateGroup/*"), privateGroup)
privateUser := goji.SubMux()
mux.Handle(pat.New("/privateUser/*"), privateUser)
privateService := goji.SubMux()
mux.Handle(pat.New("/privateService/*"), privateService)
glog.Infof("Register allowed users: %+v and groups: %+v and services: %+v", USERS, TEAMS, SERVICES)

private.Use(ginoauth2.AuthChainNetHTTP(zalando.OAuth2Endpoint, zalando.UidCheckNetHTTP(USERS), zalando.GroupCheckNetHTTP(TEAMS), zalando.UidCheckNetHTTP(SERVICES)))
privateGroup.Use(ginoauth2.AuthNetHTTP(zalando.GroupCheckNetHTTP(TEAMS), zalando.OAuth2Endpoint))
privateUser.Use(ginoauth2.AuthNetHTTP(zalando.UidCheckNetHTTP(USERS), zalando.OAuth2Endpoint))
privateService.Use(ginoauth2.AuthNetHTTP(zalando.ScopeAndCheckNetHTTP("uidcheck", "uid", "bar"), zalando.OAuth2Endpoint))

private.HandleFunc(pat.Get("/"), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
uid := h.Get("uid")
fmt.Fprintf(w, "Hello from private for groups and users: %s\n", uid)
}))

privateGroup.HandleFunc(pat.Get("/"), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
uid := h.Get("uid")
team := h.Get("team")
fmt.Fprintf(w, "Hello from private group: uid: %s, team: %s\n", uid, team)
}))

privateUser.HandleFunc(pat.Get("/"), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
uid := h.Get("uid")
fmt.Fprintf(w, "Hello from private user: uid: %s\n", uid)
}))

privateService.HandleFunc(pat.Get("/"), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
cn := h.Get("cn")
fmt.Fprintf(w, "Hello from private service cn: %s\n", cn)
}))

glog.Info("bootstrapped application")
http.ListenAndServe("localhost:8081", mux)

}
120 changes: 105 additions & 15 deletions ginoauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ import (
"errors"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"

Expand Down Expand Up @@ -87,6 +86,10 @@ type TokenContainer struct {
// access.
type AccessCheckFunction func(tc *TokenContainer, ctx *gin.Context) bool

// AccessCheckFunctionNetHTTP is a function that checks if a given token grants
// access.
type AccessCheckFunctionNetHTTP func(tc *TokenContainer, w http.ResponseWriter, r *http.Request) bool

func extractToken(r *http.Request) (*oauth2.Token, error) {
hdr := r.Header.Get("Authorization")
if hdr == "" {
Expand All @@ -102,15 +105,12 @@ func extractToken(r *http.Request) (*oauth2.Token, error) {
}

func RequestAuthInfo(t *oauth2.Token) ([]byte, error) {
var uv = make(url.Values)
// uv.Set("realm", o.Realm)
uv.Set("access_token", t.AccessToken)
infoURL := AuthInfoURL + "?" + uv.Encode()
client := &http.Client{Transport: &Transport}
req, err := http.NewRequest("GET", infoURL, nil)
req, err := http.NewRequest("GET", AuthInfoURL, nil)
if err != nil {
return nil, err
}
req.Header.Add("Authorization", "Bearer "+t.AccessToken)

resp, err := client.Do(req)
if err != nil {
Expand Down Expand Up @@ -171,21 +171,21 @@ func GetTokenContainer(token *oauth2.Token) (*TokenContainer, error) {
glog.Errorf("[Gin-OAuth] JSON.Unmarshal failed caused by: %s", err)
return nil, err
}
if _, ok := data["error_description"]; ok {
if ed, ok := data["error_description"]; ok {
var s string
s = data["error_description"].(string)
s = ed.(string)
glog.Errorf("[Gin-OAuth] RequestAuthInfo returned an error: %s", s)
return nil, errors.New(s)
}
return ParseTokenContainer(token, data)
}

func getTokenContainer(ctx *gin.Context) (*TokenContainer, bool) {
func getTokenContainer(r *http.Request) (*TokenContainer, bool) {
var oauthToken *oauth2.Token
var tc *TokenContainer
var err error

if oauthToken, err = extractToken(ctx.Request); err != nil {
if oauthToken, err = extractToken(r); err != nil {
glog.Errorf("[Gin-OAuth] Can not extract oauth2.Token, caused by: %s", err)
return nil, false
}
Expand Down Expand Up @@ -232,6 +232,11 @@ func Auth(accessCheckFunction AccessCheckFunction, endpoints oauth2.Endpoint) gi
return AuthChain(endpoints, accessCheckFunction)
}

// AuthNetHTTP is the net/http version of Auth
func AuthNetHTTP(accessCheckFunction AccessCheckFunctionNetHTTP, endpoints oauth2.Endpoint) func(http.Handler) http.Handler {
return AuthChainNetHTTP(endpoints, accessCheckFunction)
}

// AuthChain is a router middleware that can be used to get an authenticated
// and authorized service for the whole router group. Similar to Auth, but
// takes a chain of AccessCheckFunctions and only fails if all of them fails.
Expand Down Expand Up @@ -262,7 +267,7 @@ func AuthChain(endpoints oauth2.Endpoint, accessCheckFunctions ...AccessCheckFun
varianceControl := make(chan bool, 1)

go func() {
tokenContainer, ok := getTokenContainer(ctx)
tokenContainer, ok := getTokenContainer(ctx.Request)
if !ok {
// set LOCATION header to auth endpoint such that the user can easily get a new access-token
ctx.Writer.Header().Set("Location", endpoints.AuthURL)
Expand Down Expand Up @@ -309,6 +314,70 @@ func AuthChain(endpoints oauth2.Endpoint, accessCheckFunctions ...AccessCheckFun
}
}

// AuthChainNetHTTP is the net/http version of AuthChain
func AuthChainNetHTTP(endpoints oauth2.Endpoint, accessCheckFunctions ...AccessCheckFunctionNetHTTP) func(http.Handler) http.Handler {
// init
AuthInfoURL = endpoints.TokenURL
// middleware
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
t := time.Now()
varianceControl := make(chan bool, 1)

go func() {
tokenContainer, ok := getTokenContainer(request)
if !ok {
// set LOCATION header to auth endpoint such that the user can easily get a new access-token
writer.Header().Set("Location", endpoints.AuthURL)
writer.WriteHeader(http.StatusUnauthorized)
writer.Write([]byte("No token in context"))
varianceControl <- false
return
}

if !tokenContainer.Valid() {
// set LOCATION header to auth endpoint such that the user can easily get a new access-token
writer.Header().Set("Location", endpoints.AuthURL)
writer.WriteHeader(http.StatusUnauthorized)
writer.Write([]byte("Invalid Token"))
varianceControl <- false
return
}

for i, fn := range accessCheckFunctions {
if fn(tokenContainer, writer, request) {
varianceControl <- true
break
}

if len(accessCheckFunctions)-1 == i {
writer.WriteHeader(http.StatusForbidden)
writer.Write([]byte("Access to the Resource is fobidden"))
varianceControl <- false
return
}
}
}()

select {
case ok := <-varianceControl:
if !ok {
glog.V(2).Infof("[Gin-OAuth] %12v %s access not allowed", time.Since(t), request.URL.Path)
return
}
case <-time.After(VarianceTimer):
writer.WriteHeader(http.StatusGatewayTimeout)
writer.Write([]byte("Authorization check overtime"))
glog.V(2).Infof("[Gin-OAuth] %12v %s overtime", time.Since(t), request.URL.Path)
return
}

glog.V(2).Infof("[Gin-OAuth] %12v %s access allowed", time.Since(t), request.URL.Path)
next.ServeHTTP(writer, request)
})
}
}

// RequestLogger is a middleware that logs all the request and prints
// relevant information. This can be used for logging all the
// requests that contain important information and are authorized.
Expand All @@ -333,12 +402,10 @@ func RequestLogger(keys []string, contentKey string) gin.HandlerFunc {
c.Next()
err := c.Errors
if request.Method != "GET" && err == nil {
data, e := c.Get(contentKey)
if e != false { //key is non existent
if data, ok := c.Get(contentKey); ok {
values := make([]string, 0)
for _, key := range keys {
val, keyPresent := c.Get(key)
if keyPresent {
if val, ok := c.Get(key); ok {
values = append(values, val.(string))
}
}
Expand All @@ -348,4 +415,27 @@ func RequestLogger(keys []string, contentKey string) gin.HandlerFunc {
}
}

// RequestLoggerNetHTTP is the net/http version of RequestLogger.
func RequestLoggerNetHTTP(keys []string, contentKey string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
request := r
next.ServeHTTP(w, r.WithContext(ctx))
if request.Method != "GET" {
if data, ok := ctx.Value(contentKey).(string); ok {
values := make([]string, 0)
for _, key := range keys {
s, ok := ctx.Value(key).(string)
if ok {
values = append(values, s)
}
}
glog.Infof("[Gin-OAuth] Request: %+v for %s", data, strings.Join(values, "-"))
}
}
})
}
}

// vim: ts=4 sw=4 noexpandtab nolist syn=go
Loading