diff --git a/proc.go b/proc.go index 28c2db3..4ae0497 100644 --- a/proc.go +++ b/proc.go @@ -14,18 +14,15 @@ import ( "github.com/subgraph/fw-daemon/nfqueue" ) -type ProcInfo struct { - pid int - uid int - exePath string - cmdLine string -} - type socketAddr struct { ip net.IP port uint16 } +func (sa socketAddr) String() string { + return fmt.Sprintf("%v:%d", sa.ip, sa.port) +} + type socketStatus struct { local socketAddr remote socketAddr @@ -36,11 +33,29 @@ type socketStatus struct { line string } +func (ss *socketStatus) String() string { + return fmt.Sprintf("%s -> %s uid=%d inode=%d pid=%d", ss.local, ss.remote, ss.uid, ss.inode, ss.pid) +} + +type ConnectionInfo struct { + proc *ProcInfo + local *socketAddr + remote *socketAddr +} + +func (ci *ConnectionInfo) String() string { + return fmt.Sprintf("%v %s %s", ci.proc, ci.local, ci.remote) +} + func findProcessForPacket(pkt *nfqueue.Packet) *ProcInfo { ss := getSocketForPacket(pkt) if ss == nil { return nil } + return findProcessForSocket(ss) +} +func findProcessForSocket(ss *socketStatus) *ProcInfo { + exePath, err := os.Readlink(fmt.Sprintf("/proc/%d/exe", ss.pid)) if err != nil { log.Warning("Error reading exe link for pid %d: %v", ss.pid, err) @@ -234,3 +249,114 @@ func getAllPids() []int { } return pids } + +func getConnections() ([]*ConnectionInfo, error) { + conns,err := readConntrack() + if err != nil { + return nil, err + } + resolveProcinfo(conns) + return conns, nil +} + +func resolveProcinfo(conns []*ConnectionInfo) { + var sockets []*socketStatus + for _,line := range getSocketLines("tcp") { + if len(strings.TrimSpace(line)) == 0 { + continue + } + ss := new(socketStatus) + if err := ss.parseLine(line); err != nil { + log.Warning("Unable to parse line [%s]: %v", line, err) + } else { + pid := findPidForInode(ss.inode) + if pid > 0 { + ss.pid = pid + fmt.Println("Socket", ss) + sockets = append(sockets, ss) + } + } + } + for _,ci := range conns { + ss := findContrackSocket(ci, sockets) + if ss == nil { + continue + } + proc := findProcessForSocket(ss) + if proc != nil { + ci.proc = proc + } + } +} + +func findContrackSocket(ci *ConnectionInfo, sockets []*socketStatus) *socketStatus { + for _,ss := range sockets { + if ss.local.port == ci.local.port && ss.remote.ip.Equal(ci.remote.ip) && ss.remote.port == ci.remote.port { + return ss + } + } + return nil +} + +func readConntrack() ([]*ConnectionInfo, error) { + path := fmt.Sprintf("/proc/net/ip_conntrack") + data, err := ioutil.ReadFile(path) + if err != nil { + return nil, err + } + var result []*ConnectionInfo + lines := strings.Split(string(data), "\n") + for _,line := range(lines) { + ci,err := parseConntrackLine(line) + if err != nil { + return nil, err + } + if ci != nil { + result = append(result, ci) + } + } + return result, nil +} + +func parseConntrackLine(line string) (*ConnectionInfo, error) { + parts := strings.Fields(line) + if len(parts) < 8 || parts[0] != "tcp" || parts[3] != "ESTABLISHED" { + return nil, nil + } + + local,err := conntrackAddr(parts[4], parts[6]) + if err != nil { + return nil, err + } + remote,err := conntrackAddr(parts[5], parts[7]) + if err != nil { + return nil, err + } + return &ConnectionInfo{ + local: local, + remote: remote, + },nil +} + +func conntrackAddr(ip_str, port_str string) (*socketAddr, error) { + ip := net.ParseIP(stripLabel(ip_str)) + if ip == nil { + return nil, errors.New("Could not parse IP: "+ip_str) + } + i64, err := strconv.Atoi(stripLabel(port_str)) + if err != nil { + return nil, err + } + return &socketAddr{ + ip: ip, + port: uint16(i64), + },nil +} + +func stripLabel(s string) string { + idx := strings.Index(s, "=") + if idx == -1 { + return s + } + return s[idx+1:] +} diff --git a/proc_pid.go b/proc_pid.go new file mode 100644 index 0000000..3783359 --- /dev/null +++ b/proc_pid.go @@ -0,0 +1,129 @@ +package main +import ( + "os" + "strconv" + "fmt" + "strings" + "path" + "io/ioutil" +) + + +type ProcInfo struct { + pid int + uid int + exePath string + cmdLine string +} + + +var cacheMap = make(map[uint64]*ProcInfo) + +func pidCacheLookup(inode uint64) *ProcInfo { + pi,ok := cacheMap[inode] + if ok { + return pi + } + pidCacheReload() + return cacheMap[inode] +} + +func pidCacheReload() { + for _, n := range readdir("/proc") { + pid := toPid(n) + if pid != 0 { + scrapePid(pid) + } + } +} + +func toPid(name string) int { + pid, err := strconv.ParseUint(name, 10, 32) + if err != nil { + return 0 + } + fdpath := fmt.Sprintf("/proc/%d/fd", pid) + fi,err := os.Stat(fdpath) + if err != nil { + return 0 + } + if !fi.IsDir() { + return 0 + } + return (int)(pid) +} + +func scrapePid(pid int) { + fdpath := fmt.Sprintf("/proc/%d/fd", pid) + for _, n := range readdir(fdpath) { + if link, err := os.Readlink(path.Join(fdpath, n)); err != nil { + log.Warning("Error reading link %s: %v", n, err) + } else { + extractSocket(link, pid) + } + } +} + +func extractSocket(name string, pid int) { + if !strings.HasPrefix(name, "socket:[") || !strings.HasSuffix(name, "]") { + return + } + val := name[8:len(name)-1] + inode,err := strconv.ParseUint(val, 10, 64) + if err != nil { + log.Warning("Error parsing inode value from %s: %v", name, err) + return + } + cacheAddPid(inode, pid) +} + +func cacheAddPid(inode uint64, pid int) { + pi,ok := cacheMap[inode] + if ok && pi.pid == pid { + return + } + cacheMap[inode] = &ProcInfo{ pid: pid } +} + +func readdir(dir string) []string { + d,err := os.Open(dir) + if err != nil { + log.Warning("Error opening directory %s: %v", dir, err) + return nil + } + defer d.Close() + names, err := d.Readdirnames(0) + if err != nil { + log.Warning("Error reading directory names from %s: %v", dir, err) + return nil + } + return names +} + +func (pi *ProcInfo) loadProcessInfo() bool { + + exePath, err := os.Readlink(fmt.Sprintf("/proc/%d/exe", pi.pid)) + if err != nil { + log.Warning("Error reading exe link for pid %d: %v", pi.pid, err) + return false + } + bs, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pi.pid)) + if err != nil { + log.Warning("Error reading cmdline for pid %d: %v", pi.pid, err) + return false + } + for i, b := range bs { + if b == 0 { + bs[i] = byte(' ') + } + } + + finfo, err := os.Stat(fmt.Sprintf("/proc/%d", pi.pid)) + if err != nil { + log.Warning("Could not stat /proc/%d: %v", pi.pid, err) + return false + } + finfo.Sys() + pi.exePath = exePath + pi.cmdLine = string(bs) +}