Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement simple rfc2136 listener middlewares #5

Merged
merged 4 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions internal/listener/lrfc2136/dnserr/dns_error.errs.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 38 additions & 0 deletions internal/listener/lrfc2136/dnserr/dns_error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package dnserr

import (
"fmt"
)

//go:generate go run ./errsgen --pkg dnserr --out dns_error.errs.go

type DNSError struct {
RCode int
Nested error
}

func (e DNSError) Error() string {
if e.Nested == nil {
return ""
}
return fmt.Sprintf("dns error[%d]: %v", e.RCode, e.Nested)
}

func (e *DNSError) Unwrap() error {
return e.Nested
}

func (e *DNSError) Is(target error) bool {
other, ok := target.(*DNSError)
if !ok {
return false
}
if e == nil && other == nil {
return true
}
if e == nil || other == nil {
return false
}

return e.RCode == other.RCode
}
68 changes: 68 additions & 0 deletions internal/listener/lrfc2136/dnserr/errsgen/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package main

import (
"bytes"
"flag"
"fmt"
"go/format"
"os"
"strings"
)

var errNames = []string{
"RcodeFormatError",
"RcodeServerFailure",
"RcodeNameError",
"RcodeNotImplemented",
"RcodeRefused",
"RcodeYXDomain",
"RcodeYXRrset",
"RcodeNXRrset",
"RcodeNotAuth",
"RcodeNotZone",
}

func errVar(rcode string) string {
return "Err" + strings.ReplaceAll(strings.TrimPrefix(rcode, "Rcode"), "Error", "")
}

func fatalf(msg string, a ...interface{}) {
_, _ = fmt.Fprintf(os.Stderr, "errsgen: "+msg+"\n", a...)
os.Exit(1)
}

func main() {
var pkg, out string
flag.StringVar(&pkg, "pkg", "errors", "package to use")
flag.StringVar(&out, "out", "errors.go", "output filename")
flag.Parse()

var code bytes.Buffer
_, _ = fmt.Fprintln(&code, "// Code generated by errsgen; DO NOT EDIT. ")
_, _ = fmt.Fprintf(&code, "package %s\n", pkg)

_, _ = fmt.Fprintln(&code, "import (")
_, _ = fmt.Fprintln(&code, `"errors"`)
_, _ = fmt.Fprintln(&code, `"github.com/miekg/dns"`)
_, _ = fmt.Fprintln(&code, ")")

_, _ = fmt.Fprintln(&code, "var (")
for _, errName := range errNames {
errCode := fmt.Sprintf("dns.%s", errName)
_, _ = fmt.Fprintf(&code,
"%s = &DNSError{RCode: %s, Nested: errors.New(dns.RcodeToString[%s])}\n",
errVar(errName), errCode, errCode,
)
}
_, _ = fmt.Fprintln(&code, ")")

formatted, err := format.Source(code.Bytes())
if err != nil {
fatalf("parse generated code: %v", err)
}

err = os.WriteFile(out, formatted, 0o644)
if err != nil {
fatalf("save code: %v", err)
}
}
42 changes: 22 additions & 20 deletions internal/listener/lrfc2136/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/rs/zerolog/log"
"golang.org/x/sync/errgroup"

"github.com/buglloc/DNSGateway/internal/listener/lrfc2136/middlewares"
"github.com/buglloc/DNSGateway/internal/upstream"
)

Expand Down Expand Up @@ -109,33 +110,41 @@ func (a *Listener) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
l.Error().Err(err).Msg("request failed")

m := new(dns.Msg)
m.SetReply(r)
m.SetRcode(r, dns.RcodeServerFailure)
_ = w.WriteMsg(m)
}
}

func (a *Listener) lockedServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) error {
tsig := r.IsTsig()
if tsig == nil {
return errors.New("missing TSIG")
}

if err := w.TsigStatus(); err != nil {
return fmt.Errorf("invalid TSIG: %w", err)
}
var responser middlewares.NextFn

switch r.Opcode {
case dns.OpcodeQuery:
if isXRFRequest(r) {
return a.logRequest(ctx, w, r, a.lockedServeXFR)
responser = middlewares.NopResponser(a.lockedServeXFR)
break
}
return a.logRequest(ctx, w, r, a.lockedServeQuery)

responser = middlewares.Responser(a.lockedServeQuery)

case dns.OpcodeUpdate:
return a.logRequest(ctx, w, r, a.lockedServeUpdate)
responser = middlewares.Responser(a.lockedServeUpdate)

default:
responser = middlewares.Responser(func(_ context.Context, _ dns.ResponseWriter, _ *dns.Msg) error {
return fmt.Errorf("unsupported opcode: %s", dns.OpcodeToString[r.Opcode])
})
}

return fmt.Errorf("unsupported opcode: %s", dns.OpcodeToString[r.Opcode])
middlewares.Logger(
middlewares.Recoverer(
middlewares.TSIGChecker(
responser,
),
),
)(ctx, w, r)

return nil
}

func (a *Listener) lockedServeXFR(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) error {
Expand Down Expand Up @@ -190,13 +199,6 @@ func (a *Listener) lockedServeUpdate(ctx context.Context, w dns.ResponseWriter,
return nil
}

func (a *Listener) logRequest(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, fn func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) error) error {
now := time.Now()
err := fn(ctx, w, r)
log.Ctx(ctx).Info().Dur("elapsed", time.Since(now)).Msg("finished")
return err
}

func (a *Listener) handleXFR(ctx context.Context, q dns.Question, out chan *dns.Envelope) {
log.Ctx(ctx).Info().Str("name", q.Name).Msg("handle XFR request")

Expand Down
17 changes: 17 additions & 0 deletions internal/listener/lrfc2136/middlewares/log.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package middlewares

import (
"context"
"time"

"github.com/miekg/dns"
"github.com/rs/zerolog/log"
)

func Logger(next NextFn) NextFn {
return func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) {
now := time.Now()
next(ctx, w, r)
log.Ctx(ctx).Info().Dur("elapsed", time.Since(now)).Msg("finished")
}
}
9 changes: 9 additions & 0 deletions internal/listener/lrfc2136/middlewares/next.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package middlewares

import (
"context"

"github.com/miekg/dns"
)

type NextFn func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
22 changes: 22 additions & 0 deletions internal/listener/lrfc2136/middlewares/recovery.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package middlewares

import (
"context"

"github.com/miekg/dns"
"github.com/rs/zerolog/log"
)

func Recoverer(next NextFn) NextFn {
return func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) {
defer func() {
if rvr := recover(); rvr != nil {
log.Ctx(ctx).Panic().Any("error", rvr).Msg("panic occurred")

WriteResponse(ctx, w, r, dns.RcodeServerFailure)
}
}()

next(ctx, w, r)
}
}
55 changes: 55 additions & 0 deletions internal/listener/lrfc2136/middlewares/response.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package middlewares

import (
"context"
"errors"
"time"

"github.com/miekg/dns"
"github.com/rs/zerolog/log"

"github.com/buglloc/DNSGateway/internal/listener/lrfc2136/dnserr"
)

type HandleFn func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) error

func NopResponser(fn HandleFn) NextFn {
return func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) {
err := fn(ctx, w, r)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("request failed")
}
}
}

func Responser(fn HandleFn) NextFn {
return func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) {
err := fn(ctx, w, r)
if err == nil {
WriteResponse(ctx, w, r, dns.RcodeSuccess)
return
}

log.Ctx(ctx).Error().Err(err).Msg("request failed")
var dnsErr *dnserr.DNSError
if errors.As(err, &dnsErr) {
WriteResponse(ctx, w, r, dnsErr.RCode)
return
}

WriteResponse(ctx, w, r, dns.RcodeServerFailure)
}
}

func WriteResponse(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, rcode int) {
m := new(dns.Msg)
m.SetRcode(r, rcode)

if tsig := r.IsTsig(); tsig != nil {
m.SetTsig(tsig.Hdr.Name, dns.HmacSHA256, 300, time.Now().Unix())
}

if err := w.WriteMsg(m); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("write failed")
}
}
34 changes: 34 additions & 0 deletions internal/listener/lrfc2136/middlewares/tsig.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package middlewares

import (
"context"

"github.com/miekg/dns"
"github.com/rs/zerolog/log"
)

func TSIGChecker(next NextFn) NextFn {
return func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) {
writeRefuse := func() {
m := new(dns.Msg)
m.SetRcode(r, dns.RcodeRefused)

if err := w.WriteMsg(m); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("write failed")
}
}

tsig := r.IsTsig()
if tsig == nil {
writeRefuse()
return
}

if err := w.TsigStatus(); err != nil {
writeRefuse()
return
}

next(ctx, w, r)
}
}
Loading