Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(trace): add drop-only option #632

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cmd/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ var (
L4Proto string
Port int
OutputFile string
DropOnly bool
)

func init() {
Expand Down Expand Up @@ -56,7 +57,7 @@ func init() {

ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
if err := trace.StartTrace(ctx, IPVersion, L4ProtoNo, Port, OutputFile); err != nil {
if err := trace.StartTrace(ctx, IPVersion, L4ProtoNo, Port, DropOnly, OutputFile); err != nil {
logrus.Fatalln(err)
}
},
Expand All @@ -66,6 +67,7 @@ func init() {
traceCmd.PersistentFlags().BoolVarP(&IPv6, "ipv6", "6", false, "Capture IPv6 traffic")
traceCmd.PersistentFlags().StringVarP(&L4Proto, "l4-proto", "p", "tcp", "Layer 4 protocol")
traceCmd.PersistentFlags().IntVarP(&Port, "port", "P", 80, "Port")
traceCmd.PersistentFlags().BoolVarP(&DropOnly, "drop-only", "", true, "only trace the dropped package")
traceCmd.PersistentFlags().StringVarP(&OutputFile, "output", "o", "/dev/stdout", "Output file")

rootCmd.AddCommand(traceCmd)
Expand Down
62 changes: 45 additions & 17 deletions trace/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"slices"
"net"
"os"
"syscall"
Expand Down Expand Up @@ -43,7 +44,7 @@ func init() {
}
}

func StartTrace(ctx context.Context, ipVersion int, l4ProtoNo uint16, port int, outputFile string) (err error) {
func StartTrace(ctx context.Context, ipVersion int, l4ProtoNo uint16, port int, dropOnly bool, outputFile string) (err error) {
kernelVersion, err := internal.KernelVersion()
if err != nil {
return fmt.Errorf("failed to get kernel version: %w", err)
Expand Down Expand Up @@ -80,7 +81,7 @@ func StartTrace(ctx context.Context, ipVersion int, l4ProtoNo uint16, port int,
}()

fmt.Printf("\nstart tracing\n")
if err = handleEvents(ctx, objs, outputFile, kfreeSkbReasons); err != nil {
if err = handleEvents(ctx, objs, outputFile, kfreeSkbReasons, dropOnly); err != nil {
return
}
return
Expand Down Expand Up @@ -221,7 +222,7 @@ func attachBpfToTargets(objs *bpfObjects, targets map[string]int) (links []link.
return links, nil
}

func handleEvents(ctx context.Context, objs *bpfObjects, outputFile string, kfreeSkbReasons map[uint64]string) (err error) {
func handleEvents(ctx context.Context, objs *bpfObjects, outputFile string, kfreeSkbReasons map[uint64]string, dropOnly bool) (err error) {
writer, err := os.Create(outputFile)
if err != nil {
return
Expand Down Expand Up @@ -258,6 +259,12 @@ func handleEvents(ctx context.Context, objs *bpfObjects, outputFile string, kfre
PayloadLen uint16
}

var skb2events map[uint64][]bpfEvent
skb2events = make(map[uint64][]bpfEvent)
// a map to save slices of bpfEvent of the Skb
var skb2symNames map[uint64][]string
skb2symNames = make(map[uint64][]string)
// a map to save slices of function name called with the Skb
for {
rec, err := eventsReader.Read()
if err != nil {
Expand All @@ -273,22 +280,43 @@ func handleEvents(ctx context.Context, objs *bpfObjects, outputFile string, kfre
logrus.Debugf("failed to parse ringbuf event: %+v", err)
continue
}

fmt.Fprintf(writer, "%x mark=%x netns=%010d if=%d(%s) proc=%d(%s) ", event.Skb, event.Mark, event.Netns, event.Ifindex, TrimNull(string(event.Ifname[:])), event.Pid, TrimNull(string(event.Pname[:])))
if event.L3Proto == syscall.ETH_P_IP {
fmt.Fprintf(writer, "%s:%d > %s:%d ", net.IP(event.Saddr[:4]).String(), Ntohs(event.Sport), net.IP(event.Daddr[:4]).String(), Ntohs(event.Dport))
} else {
fmt.Fprintf(writer, "[%s]:%d > [%s]:%d ", net.IP(event.Saddr[:]).String(), Ntohs(event.Sport), net.IP(event.Daddr[:]).String(), Ntohs(event.Dport))
if skb2events[event.Skb]==nil {
skb2events[event.Skb] = []bpfEvent{}
}
if event.L4Proto == syscall.IPPROTO_TCP {
fmt.Fprintf(writer, "tcp_flags=%s ", TcpFlags(event.TcpFlags))
skb2events[event.Skb] = append(skb2events[event.Skb],event)


sym := NearestSymbol(event.Pc);
if skb2symNames[event.Skb]==nil {
skb2symNames[event.Skb] = []string{}
}
fmt.Fprintf(writer, "payload_len=%d ", event.PayloadLen)
sym := NearestSymbol(event.Pc)
fmt.Fprintf(writer, "%s", sym.Name)
if sym.Name == "kfree_skb_reason" {
fmt.Fprintf(writer, "(%s)", kfreeSkbReasons[event.SecondParam])
skb2symNames[event.Skb] = append(skb2symNames[event.Skb],sym.Name)
switch sym.Name {
case "__kfree_skb","kfree_skbmem":
// most skb end in the call of kfree_skbmem
if !dropOnly || slices.Contains(skb2symNames[event.Skb],"kfree_skb_reason") {
// trace dropOnly with drop reason or all skb
for _,skb_ev := range skb2events[event.Skb] {
fmt.Fprintf(writer, "%x mark=%x netns=%010d if=%d(%s) proc=%d(%s) ", skb_ev.Skb, skb_ev.Mark, skb_ev.Netns, skb_ev.Ifindex, TrimNull(string(skb_ev.Ifname[:])), skb_ev.Pid, TrimNull(string(skb_ev.Pname[:])))
if event.L3Proto == syscall.ETH_P_IP {
fmt.Fprintf(writer, "%s:%d > %s:%d ", net.IP(skb_ev.Saddr[:4]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:4]).String(), Ntohs(skb_ev.Dport))
} else {
fmt.Fprintf(writer, "[%s]:%d > [%s]:%d ", net.IP(skb_ev.Saddr[:]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:]).String(), Ntohs(skb_ev.Dport))
}
if event.L4Proto == syscall.IPPROTO_TCP {
fmt.Fprintf(writer, "tcp_flags=%s ", TcpFlags(skb_ev.TcpFlags))
}
fmt.Fprintf(writer, "payload_len=%d ", event.PayloadLen)
sym := NearestSymbol(skb_ev.Pc)
fmt.Fprintf(writer, "%s", sym.Name)
if sym.Name == "kfree_skb_reason" {
fmt.Fprintf(writer, "(%s)", kfreeSkbReasons[skb_ev.SecondParam])
}
fmt.Fprintf(writer, "\n")
}
delete(skb2events, event.Skb)
delete(skb2symNames, event.Skb)
}
}
fmt.Fprintf(writer, "\n")
}
}
Loading