Different matching strategies for TCP and UDP sockets

pull/16/head
Bruce Leidl 9 years ago
parent b4fb258d0d
commit 65f65d2a42

@ -181,17 +181,15 @@ func (fw *Firewall) filterPacket(pkt *nfqueue.Packet) {
} }
func findProcessForPacket(pkt *nfqueue.Packet) *proc.ProcInfo { func findProcessForPacket(pkt *nfqueue.Packet) *proc.ProcInfo {
proto := ""
switch pkt.Protocol { switch pkt.Protocol {
case nfqueue.TCP: case nfqueue.TCP:
proto = "tcp" return proc.LookupTCPSocketProcess(pkt.SrcPort, pkt.Dst, pkt.DstPort)
case nfqueue.UDP: case nfqueue.UDP:
proto = "udp" return proc.LookupUDPSocketProcess(pkt.SrcPort)
default: default:
log.Warning("Packet has unknown protocol: %d", pkt.Protocol) log.Warning("Packet has unknown protocol: %d", pkt.Protocol)
return nil return nil
} }
return proc.LookupSocketProcess(proto, pkt.SrcPort, pkt.Dst, pkt.DstPort)
} }
func basicAllowPacket(pkt *nfqueue.Packet) bool { func basicAllowPacket(pkt *nfqueue.Packet) bool {

@ -9,7 +9,6 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/subgraph/fw-daemon/Godeps/_workspace/src/github.com/op/go-logging" "github.com/subgraph/fw-daemon/Godeps/_workspace/src/github.com/op/go-logging"
"github.com/subgraph/fw-daemon/nfqueue"
) )
var log = logging.MustGetLogger("proc") var log = logging.MustGetLogger("proc")
@ -19,8 +18,17 @@ func SetLogger(logger *logging.Logger) {
var pcache = &pidCache{} var pcache = &pidCache{}
func LookupSocketProcess(proto string, srcPort uint16, dstAddr net.IP, dstPort uint16) *ProcInfo {
ss := findSocket(proto, srcPort, dstAddr, dstPort) func LookupUDPSocketProcess(srcPort uint16) *ProcInfo {
ss := findUDPSocket(srcPort)
if ss == nil {
return nil
}
return pcache.lookup(ss.inode)
}
func LookupTCPSocketProcess(srcPort uint16, dstAddr net.IP, dstPort uint16) *ProcInfo {
ss := findTCPSocket(srcPort, dstAddr, dstPort)
if ss == nil { if ss == nil {
return nil return nil
} }
@ -56,20 +64,6 @@ func (sa *socketAddr) parse(s string) error {
} }
func printPacket(pkt *nfqueue.Packet) string {
proto := func() string {
switch pkt.Protocol {
case nfqueue.TCP:
return "TCP"
case nfqueue.UDP:
return "UDP"
default:
return "???"
}
}()
return fmt.Sprintf("(%s %s:%d --> %s:%d)", proto, pkt.Src, pkt.SrcPort, pkt.Dst.String(), pkt.DstPort)
}
func ParseIp(ip string) (net.IP, error) { func ParseIp(ip string) (net.IP, error) {
var result net.IP var result net.IP
dst, err := hex.DecodeString(ip) dst, err := hex.DecodeString(ip)

@ -29,15 +29,19 @@ func (ss *socketStatus) String() string {
return fmt.Sprintf("%s -> %s uid=%d inode=%d", ss.local, ss.remote, ss.uid, ss.inode) return fmt.Sprintf("%s -> %s uid=%d inode=%d", ss.local, ss.remote, ss.uid, ss.inode)
} }
func findUDPSocket(srcPort uint16, dstAddr net.IP, dstPort uint16) *socketStatus { func findUDPSocket(srcPort uint16) *socketStatus {
return findSocket("udp", srcPort, dstAddr, dstPort) return findSocket("udp", func(ss socketStatus) bool {
return ss.local.port == srcPort
})
} }
func findTCPSocket(srcPort uint16, dstAddr net.IP, dstPort uint16) *socketStatus { func findTCPSocket(srcPort uint16, dstAddr net.IP, dstPort uint16) *socketStatus {
return findSocket("tcp", srcPort, dstAddr, dstPort) return findSocket("tcp", func(ss socketStatus) bool {
return ss.remote.port == dstPort && ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort
})
} }
func findSocket(proto string, srcPort uint16, dstAddr net.IP, dstPort uint16) *socketStatus { func findSocket(proto string, matcher func(socketStatus) bool) *socketStatus {
var ss socketStatus var ss socketStatus
for _,line := range getSocketLines(proto) { for _,line := range getSocketLines(proto) {
if len(line) == 0 { if len(line) == 0 {
@ -47,7 +51,7 @@ func findSocket(proto string, srcPort uint16, dstAddr net.IP, dstPort uint16) *s
log.Warning("Unable to parse line from /proc/net/%s [%s]: %v", proto, line, err) log.Warning("Unable to parse line from /proc/net/%s [%s]: %v", proto, line, err)
continue continue
} }
if ss.remote.port == dstPort && ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort { if matcher(ss) {
ss.line = line ss.line = line
return &ss return &ss
} }

Loading…
Cancel
Save