refactor of proc reading code

pull/16/head
Bruce Leidl 9 years ago
parent 30a21eaa57
commit 0054afb826

@ -163,7 +163,7 @@ func (fw *Firewall) filterPacket(pkt *nfqueue.Packet) {
fw.dns.processDNS(pkt) fw.dns.processDNS(pkt)
return return
} }
pinfo := proc.FindProcessForPacket(pkt) pinfo := findProcessForPacket(pkt)
if pinfo == nil { if pinfo == nil {
log.Warning("No proc found for %s", printPacket(pkt, fw.dns.Lookup(pkt.Dst))) log.Warning("No proc found for %s", printPacket(pkt, fw.dns.Lookup(pkt.Dst)))
pkt.Accept() pkt.Accept()
@ -180,6 +180,20 @@ func (fw *Firewall) filterPacket(pkt *nfqueue.Packet) {
policy.processPacket(pkt, pinfo) policy.processPacket(pkt, pinfo)
} }
func findProcessForPacket(pkt *nfqueue.Packet) *proc.ProcInfo {
proto := ""
switch pkt.Protocol {
case nfqueue.TCP:
proto = "tcp"
case nfqueue.UDP:
proto = "udp"
default:
log.Warning("Packet has unknown protocol: %d", pkt.Protocol)
return nil
}
return proc.LookupSocketProcess(proto, pkt.SrcPort, pkt.Dst, pkt.DstPort)
}
func basicAllowPacket(pkt *nfqueue.Packet) bool { func basicAllowPacket(pkt *nfqueue.Packet) bool {
return pkt.Dst.IsLoopback() || return pkt.Dst.IsLoopback() ||
pkt.Dst.IsLinkLocalMulticast() || pkt.Dst.IsLinkLocalMulticast() ||

@ -19,101 +19,24 @@ func SetLogger(logger *logging.Logger) {
log = logger log = logger
} }
type socketAddr struct { var pcache = &pidCache{}
ip net.IP
port uint16
}
func (sa socketAddr) String() string {
return fmt.Sprintf("%v:%d", sa.ip, sa.port)
}
type socketStatus struct { func LookupSocketProcess(proto string, srcPort uint16, dstAddr net.IP, dstPort uint16) *ProcInfo {
local socketAddr ss := findSocket(proto, srcPort, dstAddr, dstPort)
remote socketAddr if ss == nil {
uid int return nil
inode uint64 }
pid int return pcache.lookup(ss.inode)
// XXX debugging
line string
}
func (ss *socketStatus) String() string {
return fmt.Sprintf("%s -> %s uid=%d inode=%d pid=%d", ss.local, ss.remote, ss.uid, ss.inode, ss.pid)
} }
type ConnectionInfo struct { type ConnectionInfo struct {
proc *ProcInfo pinfo *ProcInfo
local *socketAddr local *socketAddr
remote *socketAddr remote *socketAddr
} }
func (ci *ConnectionInfo) String() string { func (ci *ConnectionInfo) String() string {
return fmt.Sprintf("%v %s %s", ci.proc, ci.local, ci.remote) return fmt.Sprintf("%v %s %s", ci.pinfo, ci.local, ci.remote)
}
func FindProcessForPacket(pkt *nfqueue.Packet) *ProcInfo {
ss := getSocketForPacket(pkt)
if ss == nil {
return nil
}
return findProcessForSocket(ss)
}
func findProcessForSocket(ss *socketStatus) *ProcInfo {
exePath, err := os.Readlink(fmt.Sprintf("/proc/%d/exe", ss.pid))
if err != nil {
log.Warning("Error reading exe link for pid %d: %v", ss.pid, err)
return nil
}
bs, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/cmdline", ss.pid))
if err != nil {
log.Warning("Error reading cmdline for pid %d: %v", ss.pid, err)
return nil
}
for i, b := range bs {
if b == 0 {
bs[i] = byte(' ')
}
}
finfo, err := os.Stat(fmt.Sprintf("/proc/%d", ss.pid))
if err != nil {
log.Warning("Could not stat /proc/%d: %v", ss.pid, err)
return nil
}
finfo.Sys()
return &ProcInfo{
Pid: ss.pid,
Uid: ss.uid,
ExePath: exePath,
CmdLine: string(bs),
}
}
func getSocketLinesForPacket(pkt *nfqueue.Packet) []string {
if pkt.Protocol == nfqueue.TCP {
return getSocketLines("tcp")
} else if pkt.Protocol == nfqueue.UDP {
return getSocketLines("udp")
} else {
log.Warning("Cannot lookup socket for protocol %s", pkt.Protocol)
return nil
}
}
func getSocketLines(proto string) []string {
path := fmt.Sprintf("/proc/net/%s", proto)
data, err := ioutil.ReadFile(path)
if err != nil {
log.Warning("Error reading %s: %v", path, err)
return nil
}
lines := strings.Split(string(data), "\n")
if len(lines) > 0 {
lines = lines[1:]
}
return lines
} }
func (sa *socketAddr) parse(s string) error { func (sa *socketAddr) parse(s string) error {
@ -134,29 +57,6 @@ func (sa *socketAddr) parse(s string) error {
return nil return nil
} }
func (ss *socketStatus) parseLine(line string) error {
fs := strings.Fields(line)
if len(fs) < 10 {
return errors.New("insufficient fields")
}
if err := ss.local.parse(fs[1]); err != nil {
return err
}
if err := ss.remote.parse(fs[2]); err != nil {
return err
}
uid, err := strconv.ParseUint(fs[7], 10, 32)
if err != nil {
return err
}
ss.uid = int(uid)
inode, err := strconv.ParseUint(fs[9], 10, 64)
if err != nil {
return err
}
ss.inode = inode
return nil
}
func printPacket(pkt *nfqueue.Packet) string { func printPacket(pkt *nfqueue.Packet) string {
proto := func() string { proto := func() string {
@ -172,38 +72,6 @@ func printPacket(pkt *nfqueue.Packet) string {
return fmt.Sprintf("(%s %s:%d --> %s:%d)", proto, pkt.Src, pkt.SrcPort, pkt.Dst.String(), pkt.DstPort) return fmt.Sprintf("(%s %s:%d --> %s:%d)", proto, pkt.Src, pkt.SrcPort, pkt.Dst.String(), pkt.DstPort)
} }
func getSocketForPacket(pkt *nfqueue.Packet) *socketStatus {
ss := findSocket(pkt)
if ss == nil {
return nil
}
pid := findPidForInode(ss.inode)
if pid > 0 {
ss.pid = pid
return ss
}
log.Info("Unable to find socket link socket:[%d] %s", ss.inode, printPacket(pkt))
log.Info("Line was %s", ss.line)
return nil
}
func findSocket(pkt *nfqueue.Packet) *socketStatus {
var status socketStatus
for _, line := range getSocketLinesForPacket(pkt) {
if len(line) == 0 {
continue
}
if err := status.parseLine(line); err != nil {
log.Warning("Unable to parse line [%s]: %v", line, err)
} else if status.remote.ip.Equal(pkt.Dst) && status.remote.port == pkt.DstPort && status.local.ip.Equal(pkt.Src) && status.local.port == pkt.SrcPort {
status.line = line
return &status
}
}
log.Info("Failed to find socket for packet: %s", printPacket(pkt))
return nil
}
func ParseIp(ip string) (net.IP, error) { func ParseIp(ip string) (net.IP, error) {
var result net.IP var result net.IP
dst, err := hex.DecodeString(ip) dst, err := hex.DecodeString(ip)
@ -227,48 +95,6 @@ func ParsePort(port string) (uint16, error) {
return uint16(p64), nil return uint16(p64), nil
} }
func findPidForInode(inode uint64) int {
search := fmt.Sprintf("socket:[%d]", inode)
for _, pid := range getAllPids() {
if matchesSocketLink(pid, search) {
return pid
}
}
return -1
}
func matchesSocketLink(pid int, search string) bool {
paths, _ := filepath.Glob(fmt.Sprintf("/proc/%d/fd/*", pid))
for _, p := range paths {
link, err := os.Readlink(p)
if err == nil && link == search {
return true
}
}
return false
}
func getAllPids() []int {
var pids []int
d, err := os.Open("/proc")
if err != nil {
log.Warning("Error opening /proc: %v", err)
return nil
}
defer d.Close()
names, err := d.Readdirnames(0)
if err != nil {
log.Warning("Error reading directory names from /proc: %v", err)
return nil
}
for _, n := range names {
if pid, err := strconv.ParseUint(n, 10, 32); err == nil {
pids = append(pids, int(pid))
}
}
return pids
}
func getConnections() ([]*ConnectionInfo, error) { func getConnections() ([]*ConnectionInfo, error) {
conns,err := readConntrack() conns,err := readConntrack()
if err != nil { if err != nil {
@ -288,12 +114,14 @@ func resolveProcinfo(conns []*ConnectionInfo) {
if err := ss.parseLine(line); err != nil { if err := ss.parseLine(line); err != nil {
log.Warning("Unable to parse line [%s]: %v", line, err) log.Warning("Unable to parse line [%s]: %v", line, err)
} else { } else {
/*
pid := findPidForInode(ss.inode) pid := findPidForInode(ss.inode)
if pid > 0 { if pid > 0 {
ss.pid = pid ss.pid = pid
fmt.Println("Socket", ss) fmt.Println("Socket", ss)
sockets = append(sockets, ss) sockets = append(sockets, ss)
} }
*/
} }
} }
for _,ci := range conns { for _,ci := range conns {
@ -301,9 +129,9 @@ func resolveProcinfo(conns []*ConnectionInfo) {
if ss == nil { if ss == nil {
continue continue
} }
proc := findProcessForSocket(ss) pinfo := pcache.lookup(ss.inode)
if proc != nil { if pinfo != nil {
ci.proc = proc ci.pinfo = pinfo
} }
} }
} }

@ -7,6 +7,8 @@ import (
"strings" "strings"
"path" "path"
"io/ioutil" "io/ioutil"
"sync"
"syscall"
) )
@ -18,24 +20,38 @@ type ProcInfo struct {
CmdLine string CmdLine string
} }
var cacheMap = make(map[uint64]*ProcInfo) type pidCache struct {
cacheMap map[uint64]*ProcInfo
lock sync.Mutex
}
func pidCacheLookup(inode uint64) *ProcInfo { func (pc *pidCache) lookup(inode uint64) *ProcInfo {
pi,ok := cacheMap[inode] pc.lock.Lock()
if ok { defer pc.lock.Unlock()
pi,ok := pc.cacheMap[inode]
if ok && pi.loadProcessInfo() {
return pi
}
pc.cacheMap = loadCache()
pi,ok = pc.cacheMap[inode]
if ok && pi.loadProcessInfo() {
return pi return pi
} }
pidCacheReload() return nil
return cacheMap[inode]
} }
func pidCacheReload() { func loadCache() map[uint64]*ProcInfo {
cmap := make(map[uint64]*ProcInfo)
for _, n := range readdir("/proc") { for _, n := range readdir("/proc") {
pid := toPid(n) pid := toPid(n)
if pid != 0 { if pid != 0 {
scrapePid(pid) pinfo := &ProcInfo{Pid: pid}
for _,inode := range inodesFromPid(pid) {
cmap[inode] = pinfo
}
} }
} }
return cmap
} }
func toPid(name string) int { func toPid(name string) int {
@ -54,36 +70,34 @@ func toPid(name string) int {
return (int)(pid) return (int)(pid)
} }
func scrapePid(pid int) { func inodesFromPid(pid int) []uint64 {
var inodes []uint64
fdpath := fmt.Sprintf("/proc/%d/fd", pid) fdpath := fmt.Sprintf("/proc/%d/fd", pid)
for _, n := range readdir(fdpath) { for _, n := range readdir(fdpath) {
if link, err := os.Readlink(path.Join(fdpath, n)); err != nil { if link, err := os.Readlink(path.Join(fdpath, n)); err != nil {
if !os.IsNotExist(err) {
log.Warning("Error reading link %s: %v", n, err) log.Warning("Error reading link %s: %v", n, err)
}
} else { } else {
extractSocket(link, pid) if inode := extractSocket(link); inode > 0 {
inodes = append(inodes, inode)
} }
} }
}
return inodes
} }
func extractSocket(name string, pid int) { func extractSocket(name string) uint64 {
if !strings.HasPrefix(name, "socket:[") || !strings.HasSuffix(name, "]") { if !strings.HasPrefix(name, "socket:[") || !strings.HasSuffix(name, "]") {
return return 0
} }
val := name[8:len(name)-1] val := name[8:len(name)-1]
inode,err := strconv.ParseUint(val, 10, 64) inode,err := strconv.ParseUint(val, 10, 64)
if err != nil { if err != nil {
log.Warning("Error parsing inode value from %s: %v", name, err) log.Warning("Error parsing inode value from %s: %v", name, err)
return return 0
}
cacheAddPid(inode, pid)
}
func cacheAddPid(inode uint64, pid int) {
pi,ok := cacheMap[inode]
if ok && pi.Pid == pid {
return
} }
cacheMap[inode] = &ProcInfo{ Pid: pid } return inode
} }
func readdir(dir string) []string { func readdir(dir string) []string {
@ -127,7 +141,8 @@ func (pi *ProcInfo) loadProcessInfo() bool {
log.Warning("Could not stat /proc/%d: %v", pi.Pid, err) log.Warning("Could not stat /proc/%d: %v", pi.Pid, err)
return false return false
} }
finfo.Sys() sys := finfo.Sys().(*syscall.Stat_t)
pi.Uid = int(sys.Uid)
pi.ExePath = exePath pi.ExePath = exePath
pi.CmdLine = string(bs) pi.CmdLine = string(bs)
pi.loaded = true pi.loaded = true

@ -0,0 +1,95 @@
package proc
import (
"net"
"fmt"
"io/ioutil"
"strings"
"errors"
"strconv"
)
type socketAddr struct {
ip net.IP
port uint16
}
func (sa socketAddr) String() string {
return fmt.Sprintf("%v:%d", sa.ip, sa.port)
}
type socketStatus struct {
local socketAddr
remote socketAddr
uid int
inode uint64
line string
}
func (ss *socketStatus) String() string {
return fmt.Sprintf("%s -> %s uid=%d inode=%d", ss.local, ss.remote, ss.uid, ss.inode)
}
func findUDPSocket(srcPort uint16, dstAddr net.IP, dstPort uint16) *socketStatus {
return findSocket("udp", srcPort, dstAddr, dstPort)
}
func findTCPSocket(srcPort uint16, dstAddr net.IP, dstPort uint16) *socketStatus {
return findSocket("tcp", srcPort, dstAddr, dstPort)
}
func findSocket(proto string, srcPort uint16, dstAddr net.IP, dstPort uint16) *socketStatus {
var ss socketStatus
for _,line := range getSocketLines(proto) {
if len(line) == 0 {
continue
}
if err := ss.parseLine(line); err != nil {
log.Warning("Unable to parse line from /proc/net/%s [%s]: %v", proto, line, err)
continue
}
if ss.remote.port == dstPort && ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort {
ss.line = line
return &ss
}
}
return nil
}
func (ss *socketStatus) parseLine(line string) error {
fs := strings.Fields(line)
if len(fs) < 10 {
return errors.New("insufficient fields")
}
if err := ss.local.parse(fs[1]); err != nil {
return err
}
if err := ss.remote.parse(fs[2]); err != nil {
return err
}
uid, err := strconv.ParseUint(fs[7], 10, 32)
if err != nil {
return err
}
ss.uid = int(uid)
inode, err := strconv.ParseUint(fs[9], 10, 64)
if err != nil {
return err
}
ss.inode = inode
return nil
}
func getSocketLines(proto string) []string {
path := fmt.Sprintf("/proc/net/%s", proto)
data, err := ioutil.ReadFile(path)
if err != nil {
log.Warning("Error reading %s: %v", path, err)
return nil
}
lines := strings.Split(string(data), "\n")
if len(lines) > 0 {
lines = lines[1:]
}
return lines
}
Loading…
Cancel
Save