From 0054afb826c5ba09381cb2ffeb3285b6f5b50585 Mon Sep 17 00:00:00 2001 From: Bruce Leidl Date: Tue, 26 Jan 2016 01:57:47 +0000 Subject: [PATCH] refactor of proc reading code --- policy.go | 16 +++- proc/proc.go | 200 ++++------------------------------------------- proc/proc_pid.go | 65 +++++++++------ proc/socket.go | 95 ++++++++++++++++++++++ 4 files changed, 164 insertions(+), 212 deletions(-) create mode 100644 proc/socket.go diff --git a/policy.go b/policy.go index c523fd5..0650439 100644 --- a/policy.go +++ b/policy.go @@ -163,7 +163,7 @@ func (fw *Firewall) filterPacket(pkt *nfqueue.Packet) { fw.dns.processDNS(pkt) return } - pinfo := proc.FindProcessForPacket(pkt) + pinfo := findProcessForPacket(pkt) if pinfo == nil { log.Warning("No proc found for %s", printPacket(pkt, fw.dns.Lookup(pkt.Dst))) pkt.Accept() @@ -180,6 +180,20 @@ func (fw *Firewall) filterPacket(pkt *nfqueue.Packet) { policy.processPacket(pkt, pinfo) } +func findProcessForPacket(pkt *nfqueue.Packet) *proc.ProcInfo { + proto := "" + switch pkt.Protocol { + case nfqueue.TCP: + proto = "tcp" + case nfqueue.UDP: + proto = "udp" + default: + log.Warning("Packet has unknown protocol: %d", pkt.Protocol) + return nil + } + return proc.LookupSocketProcess(proto, pkt.SrcPort, pkt.Dst, pkt.DstPort) +} + func basicAllowPacket(pkt *nfqueue.Packet) bool { return pkt.Dst.IsLoopback() || pkt.Dst.IsLinkLocalMulticast() || diff --git a/proc/proc.go b/proc/proc.go index 503d01f..3164b76 100644 --- a/proc/proc.go +++ b/proc/proc.go @@ -19,101 +19,24 @@ func SetLogger(logger *logging.Logger) { log = logger } -type socketAddr struct { - ip net.IP - port uint16 -} - -func (sa socketAddr) String() string { - return fmt.Sprintf("%v:%d", sa.ip, sa.port) -} +var pcache = &pidCache{} -type socketStatus struct { - local socketAddr - remote socketAddr - uid int - inode uint64 - pid int - // XXX debugging - line string -} - -func (ss *socketStatus) String() string { - return fmt.Sprintf("%s -> %s uid=%d inode=%d pid=%d", ss.local, ss.remote, ss.uid, ss.inode, ss.pid) +func LookupSocketProcess(proto string, srcPort uint16, dstAddr net.IP, dstPort uint16) *ProcInfo { + ss := findSocket(proto, srcPort, dstAddr, dstPort) + if ss == nil { + return nil + } + return pcache.lookup(ss.inode) } type ConnectionInfo struct { - proc *ProcInfo + pinfo *ProcInfo local *socketAddr remote *socketAddr } func (ci *ConnectionInfo) String() string { - return fmt.Sprintf("%v %s %s", ci.proc, ci.local, ci.remote) -} - -func FindProcessForPacket(pkt *nfqueue.Packet) *ProcInfo { - ss := getSocketForPacket(pkt) - if ss == nil { - return nil - } - return findProcessForSocket(ss) -} -func findProcessForSocket(ss *socketStatus) *ProcInfo { - - exePath, err := os.Readlink(fmt.Sprintf("/proc/%d/exe", ss.pid)) - if err != nil { - log.Warning("Error reading exe link for pid %d: %v", ss.pid, err) - return nil - } - bs, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/cmdline", ss.pid)) - if err != nil { - log.Warning("Error reading cmdline for pid %d: %v", ss.pid, err) - return nil - } - for i, b := range bs { - if b == 0 { - bs[i] = byte(' ') - } - } - - finfo, err := os.Stat(fmt.Sprintf("/proc/%d", ss.pid)) - if err != nil { - log.Warning("Could not stat /proc/%d: %v", ss.pid, err) - return nil - } - finfo.Sys() - return &ProcInfo{ - Pid: ss.pid, - Uid: ss.uid, - ExePath: exePath, - CmdLine: string(bs), - } -} - -func getSocketLinesForPacket(pkt *nfqueue.Packet) []string { - if pkt.Protocol == nfqueue.TCP { - return getSocketLines("tcp") - } else if pkt.Protocol == nfqueue.UDP { - return getSocketLines("udp") - } else { - log.Warning("Cannot lookup socket for protocol %s", pkt.Protocol) - return nil - } -} - -func getSocketLines(proto string) []string { - path := fmt.Sprintf("/proc/net/%s", proto) - data, err := ioutil.ReadFile(path) - if err != nil { - log.Warning("Error reading %s: %v", path, err) - return nil - } - lines := strings.Split(string(data), "\n") - if len(lines) > 0 { - lines = lines[1:] - } - return lines + return fmt.Sprintf("%v %s %s", ci.pinfo, ci.local, ci.remote) } func (sa *socketAddr) parse(s string) error { @@ -134,29 +57,6 @@ func (sa *socketAddr) parse(s string) error { return nil } -func (ss *socketStatus) parseLine(line string) error { - fs := strings.Fields(line) - if len(fs) < 10 { - return errors.New("insufficient fields") - } - if err := ss.local.parse(fs[1]); err != nil { - return err - } - if err := ss.remote.parse(fs[2]); err != nil { - return err - } - uid, err := strconv.ParseUint(fs[7], 10, 32) - if err != nil { - return err - } - ss.uid = int(uid) - inode, err := strconv.ParseUint(fs[9], 10, 64) - if err != nil { - return err - } - ss.inode = inode - return nil -} func printPacket(pkt *nfqueue.Packet) string { proto := func() string { @@ -172,38 +72,6 @@ func printPacket(pkt *nfqueue.Packet) string { return fmt.Sprintf("(%s %s:%d --> %s:%d)", proto, pkt.Src, pkt.SrcPort, pkt.Dst.String(), pkt.DstPort) } -func getSocketForPacket(pkt *nfqueue.Packet) *socketStatus { - ss := findSocket(pkt) - if ss == nil { - return nil - } - pid := findPidForInode(ss.inode) - if pid > 0 { - ss.pid = pid - return ss - } - log.Info("Unable to find socket link socket:[%d] %s", ss.inode, printPacket(pkt)) - log.Info("Line was %s", ss.line) - return nil -} - -func findSocket(pkt *nfqueue.Packet) *socketStatus { - var status socketStatus - for _, line := range getSocketLinesForPacket(pkt) { - if len(line) == 0 { - continue - } - if err := status.parseLine(line); err != nil { - log.Warning("Unable to parse line [%s]: %v", line, err) - } else if status.remote.ip.Equal(pkt.Dst) && status.remote.port == pkt.DstPort && status.local.ip.Equal(pkt.Src) && status.local.port == pkt.SrcPort { - status.line = line - return &status - } - } - log.Info("Failed to find socket for packet: %s", printPacket(pkt)) - return nil -} - func ParseIp(ip string) (net.IP, error) { var result net.IP dst, err := hex.DecodeString(ip) @@ -227,48 +95,6 @@ func ParsePort(port string) (uint16, error) { return uint16(p64), nil } -func findPidForInode(inode uint64) int { - search := fmt.Sprintf("socket:[%d]", inode) - for _, pid := range getAllPids() { - if matchesSocketLink(pid, search) { - return pid - } - } - return -1 -} - -func matchesSocketLink(pid int, search string) bool { - paths, _ := filepath.Glob(fmt.Sprintf("/proc/%d/fd/*", pid)) - for _, p := range paths { - link, err := os.Readlink(p) - if err == nil && link == search { - return true - } - } - return false -} - -func getAllPids() []int { - var pids []int - d, err := os.Open("/proc") - if err != nil { - log.Warning("Error opening /proc: %v", err) - return nil - } - defer d.Close() - names, err := d.Readdirnames(0) - if err != nil { - log.Warning("Error reading directory names from /proc: %v", err) - return nil - } - for _, n := range names { - if pid, err := strconv.ParseUint(n, 10, 32); err == nil { - pids = append(pids, int(pid)) - } - } - return pids -} - func getConnections() ([]*ConnectionInfo, error) { conns,err := readConntrack() if err != nil { @@ -288,12 +114,14 @@ func resolveProcinfo(conns []*ConnectionInfo) { if err := ss.parseLine(line); err != nil { log.Warning("Unable to parse line [%s]: %v", line, err) } else { + /* pid := findPidForInode(ss.inode) if pid > 0 { ss.pid = pid fmt.Println("Socket", ss) sockets = append(sockets, ss) } + */ } } for _,ci := range conns { @@ -301,9 +129,9 @@ func resolveProcinfo(conns []*ConnectionInfo) { if ss == nil { continue } - proc := findProcessForSocket(ss) - if proc != nil { - ci.proc = proc + pinfo := pcache.lookup(ss.inode) + if pinfo != nil { + ci.pinfo = pinfo } } } diff --git a/proc/proc_pid.go b/proc/proc_pid.go index e247153..1536b8c 100644 --- a/proc/proc_pid.go +++ b/proc/proc_pid.go @@ -7,35 +7,51 @@ import ( "strings" "path" "io/ioutil" + "sync" + "syscall" ) type ProcInfo struct { - Uid int + Uid int Pid int loaded bool ExePath string CmdLine string } -var cacheMap = make(map[uint64]*ProcInfo) +type pidCache struct { + cacheMap map[uint64]*ProcInfo + lock sync.Mutex +} -func pidCacheLookup(inode uint64) *ProcInfo { - pi,ok := cacheMap[inode] - if ok { +func (pc *pidCache) lookup(inode uint64) *ProcInfo { + pc.lock.Lock() + defer pc.lock.Unlock() + pi,ok := pc.cacheMap[inode] + if ok && pi.loadProcessInfo() { + return pi + } + pc.cacheMap = loadCache() + pi,ok = pc.cacheMap[inode] + if ok && pi.loadProcessInfo() { return pi } - pidCacheReload() - return cacheMap[inode] + return nil } -func pidCacheReload() { +func loadCache() map[uint64]*ProcInfo { + cmap := make(map[uint64]*ProcInfo) for _, n := range readdir("/proc") { pid := toPid(n) if pid != 0 { - scrapePid(pid) + pinfo := &ProcInfo{Pid: pid} + for _,inode := range inodesFromPid(pid) { + cmap[inode] = pinfo + } } } + return cmap } func toPid(name string) int { @@ -54,36 +70,34 @@ func toPid(name string) int { return (int)(pid) } -func scrapePid(pid int) { +func inodesFromPid(pid int) []uint64 { + var inodes []uint64 fdpath := fmt.Sprintf("/proc/%d/fd", pid) for _, n := range readdir(fdpath) { if link, err := os.Readlink(path.Join(fdpath, n)); err != nil { - log.Warning("Error reading link %s: %v", n, err) + if !os.IsNotExist(err) { + log.Warning("Error reading link %s: %v", n, err) + } } else { - extractSocket(link, pid) + if inode := extractSocket(link); inode > 0 { + inodes = append(inodes, inode) + } } } + return inodes } -func extractSocket(name string, pid int) { +func extractSocket(name string) uint64 { if !strings.HasPrefix(name, "socket:[") || !strings.HasSuffix(name, "]") { - return + return 0 } val := name[8:len(name)-1] inode,err := strconv.ParseUint(val, 10, 64) if err != nil { log.Warning("Error parsing inode value from %s: %v", name, err) - return - } - cacheAddPid(inode, pid) -} - -func cacheAddPid(inode uint64, pid int) { - pi,ok := cacheMap[inode] - if ok && pi.Pid == pid { - return + return 0 } - cacheMap[inode] = &ProcInfo{ Pid: pid } + return inode } func readdir(dir string) []string { @@ -127,7 +141,8 @@ func (pi *ProcInfo) loadProcessInfo() bool { log.Warning("Could not stat /proc/%d: %v", pi.Pid, err) return false } - finfo.Sys() + sys := finfo.Sys().(*syscall.Stat_t) + pi.Uid = int(sys.Uid) pi.ExePath = exePath pi.CmdLine = string(bs) pi.loaded = true diff --git a/proc/socket.go b/proc/socket.go new file mode 100644 index 0000000..7a39eb8 --- /dev/null +++ b/proc/socket.go @@ -0,0 +1,95 @@ +package proc +import ( + "net" + "fmt" + "io/ioutil" + "strings" + "errors" + "strconv" +) + +type socketAddr struct { + ip net.IP + port uint16 +} + +func (sa socketAddr) String() string { + return fmt.Sprintf("%v:%d", sa.ip, sa.port) +} + +type socketStatus struct { + local socketAddr + remote socketAddr + uid int + inode uint64 + line string +} + +func (ss *socketStatus) String() string { + return fmt.Sprintf("%s -> %s uid=%d inode=%d", ss.local, ss.remote, ss.uid, ss.inode) +} + +func findUDPSocket(srcPort uint16, dstAddr net.IP, dstPort uint16) *socketStatus { + return findSocket("udp", srcPort, dstAddr, dstPort) +} + +func findTCPSocket(srcPort uint16, dstAddr net.IP, dstPort uint16) *socketStatus { + return findSocket("tcp", srcPort, dstAddr, dstPort) +} + +func findSocket(proto string, srcPort uint16, dstAddr net.IP, dstPort uint16) *socketStatus { + var ss socketStatus + for _,line := range getSocketLines(proto) { + if len(line) == 0 { + continue + } + if err := ss.parseLine(line); err != nil { + log.Warning("Unable to parse line from /proc/net/%s [%s]: %v", proto, line, err) + continue + } + if ss.remote.port == dstPort && ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort { + ss.line = line + return &ss + } + } + return nil +} + +func (ss *socketStatus) parseLine(line string) error { + fs := strings.Fields(line) + if len(fs) < 10 { + return errors.New("insufficient fields") + } + if err := ss.local.parse(fs[1]); err != nil { + return err + } + if err := ss.remote.parse(fs[2]); err != nil { + return err + } + uid, err := strconv.ParseUint(fs[7], 10, 32) + if err != nil { + return err + } + ss.uid = int(uid) + inode, err := strconv.ParseUint(fs[9], 10, 64) + if err != nil { + return err + } + ss.inode = inode + return nil +} + + +func getSocketLines(proto string) []string { + path := fmt.Sprintf("/proc/net/%s", proto) + data, err := ioutil.ReadFile(path) + if err != nil { + log.Warning("Error reading %s: %v", path, err) + return nil + } + lines := strings.Split(string(data), "\n") + if len(lines) > 0 { + lines = lines[1:] + } + return lines +}