diff --git a/dbus.go b/dbus.go index f70b628..9d024de 100644 --- a/dbus.go +++ b/dbus.go @@ -3,12 +3,13 @@ package main import ( "errors" "fmt" + "path" "strings" + "github.com/op/go-logging" "github.com/subgraph/fw-daemon/Godeps/_workspace/src/github.com/godbus/dbus" "github.com/subgraph/fw-daemon/Godeps/_workspace/src/github.com/godbus/dbus/introspect" - "github.com/subgraph/fw-daemon/Godeps/_workspace/src/github.com/op/go-logging" - "path" + // "github.com/subgraph/fw-daemon/Godeps/_workspace/src/github.com/op/go-logging" ) const introspectXml = ` diff --git a/main.go b/main.go index 257decb..77a2ea3 100644 --- a/main.go +++ b/main.go @@ -4,14 +4,15 @@ import ( // _ "net/http/pprof" "os" "os/signal" - "time" - - "github.com/subgraph/fw-daemon/Godeps/_workspace/src/github.com/op/go-logging" - "github.com/subgraph/fw-daemon/nfqueue" - "github.com/subgraph/fw-daemon/proc" "sync" "syscall" + "time" "unsafe" + + "github.com/op/go-logging" + "github.com/subgraph/fw-daemon/nfqueue" + "github.com/subgraph/go-procsnitch" + // "github.com/subgraph/fw-daemon/Godeps/_workspace/src/github.com/op/go-logging" ) var log = logging.MustGetLogger("sgfw") @@ -136,7 +137,7 @@ func (fw *Firewall) runFilter() { func main() { logBackend := setupLoggerBackend() log.SetBackend(logBackend) - proc.SetLogger(log) + procsnitch.SetLogger(log) if os.Geteuid() != 0 { log.Error("Must be run as root") diff --git a/policy.go b/policy.go index bec5f0f..f76f3b5 100644 --- a/policy.go +++ b/policy.go @@ -5,14 +5,14 @@ import ( "sync" "github.com/subgraph/fw-daemon/nfqueue" - "github.com/subgraph/fw-daemon/proc" + "github.com/subgraph/go-procsnitch" ) type pendingPkt struct { policy *Policy hostname string pkt *nfqueue.Packet - pinfo *proc.ProcInfo + pinfo *procsnitch.Info } type Policy struct { @@ -43,7 +43,7 @@ func (fw *Firewall) policyForPath(path string) *Policy { return fw.policyMap[path] } -func (p *Policy) processPacket(pkt *nfqueue.Packet, pinfo *proc.ProcInfo) { +func (p *Policy) processPacket(pkt *nfqueue.Packet, pinfo *procsnitch.Info) { p.lock.Lock() defer p.lock.Unlock() name := p.fw.dns.Lookup(pkt.Dst) @@ -214,12 +214,12 @@ func (fw *Firewall) filterPacket(pkt *nfqueue.Packet) { policy.processPacket(pkt, pinfo) } -func findProcessForPacket(pkt *nfqueue.Packet) *proc.ProcInfo { +func findProcessForPacket(pkt *nfqueue.Packet) *procsnitch.Info { switch pkt.Protocol { case nfqueue.TCP: - return proc.LookupTCPSocketProcess(pkt.SrcPort, pkt.Dst, pkt.DstPort) + return procsnitch.LookupTCPSocketProcess(pkt.SrcPort, pkt.Dst, pkt.DstPort) case nfqueue.UDP: - return proc.LookupUDPSocketProcess(pkt.SrcPort) + return procsnitch.LookupUDPSocketProcess(pkt.SrcPort) default: log.Warning("Packet has unknown protocol: %d", pkt.Protocol) return nil diff --git a/proc/proc.go b/proc/proc.go deleted file mode 100644 index 939ebc4..0000000 --- a/proc/proc.go +++ /dev/null @@ -1,200 +0,0 @@ -package proc - -import ( - "encoding/hex" - "errors" - "fmt" - "github.com/subgraph/fw-daemon/Godeps/_workspace/src/github.com/op/go-logging" - "io/ioutil" - "net" - "strconv" - "strings" -) - -var log = logging.MustGetLogger("proc") - -func SetLogger(logger *logging.Logger) { - log = logger -} - -var pcache = &pidCache{} - -func LookupUDPSocketProcess(srcPort uint16) *ProcInfo { - ss := findUDPSocket(srcPort) - if ss == nil { - return nil - } - return pcache.lookup(ss.inode) -} - -func LookupTCPSocketProcess(srcPort uint16, dstAddr net.IP, dstPort uint16) *ProcInfo { - ss := findTCPSocket(srcPort, dstAddr, dstPort) - if ss == nil { - return nil - } - return pcache.lookup(ss.inode) -} - -type ConnectionInfo struct { - pinfo *ProcInfo - local *socketAddr - remote *socketAddr -} - -func (ci *ConnectionInfo) String() string { - return fmt.Sprintf("%v %s %s", ci.pinfo, ci.local, ci.remote) -} - -func (sa *socketAddr) parse(s string) error { - ipPort := strings.Split(s, ":") - if len(ipPort) != 2 { - return fmt.Errorf("badly formatted socket address field: %s", s) - } - ip, err := ParseIp(ipPort[0]) - if err != nil { - return fmt.Errorf("error parsing ip field [%s]: %v", ipPort[0], err) - } - port, err := ParsePort(ipPort[1]) - if err != nil { - return fmt.Errorf("error parsing port field [%s]: %v", ipPort[1], err) - } - sa.ip = ip - sa.port = port - return nil -} - -func ParseIp(ip string) (net.IP, error) { - var result net.IP - dst, err := hex.DecodeString(ip) - if err != nil { - return result, fmt.Errorf("Error parsing IP: %s", err) - } - // Reverse byte order -- /proc/net/tcp etc. is little-endian - // TODO: Does this vary by architecture? - for i, j := 0, len(dst)-1; i < j; i, j = i+1, j-1 { - dst[i], dst[j] = dst[j], dst[i] - } - result = net.IP(dst) - return result, nil -} - -func ParsePort(port string) (uint16, error) { - p64, err := strconv.ParseInt(port, 16, 32) - if err != nil { - return 0, fmt.Errorf("Error parsing port: %s", err) - } - return uint16(p64), nil -} - -func getConnections() ([]*ConnectionInfo, error) { - conns, err := readConntrack() - if err != nil { - return nil, err - } - resolveProcinfo(conns) - return conns, nil -} - -func resolveProcinfo(conns []*ConnectionInfo) { - var sockets []*socketStatus - for _, line := range getSocketLines("tcp") { - if len(strings.TrimSpace(line)) == 0 { - continue - } - ss := new(socketStatus) - if err := ss.parseLine(line); err != nil { - log.Warning("Unable to parse line [%s]: %v", line, err) - } else { - /* - pid := findPidForInode(ss.inode) - if pid > 0 { - ss.pid = pid - fmt.Println("Socket", ss) - sockets = append(sockets, ss) - } - */ - } - } - for _, ci := range conns { - ss := findContrackSocket(ci, sockets) - if ss == nil { - continue - } - pinfo := pcache.lookup(ss.inode) - if pinfo != nil { - ci.pinfo = pinfo - } - } -} - -func findContrackSocket(ci *ConnectionInfo, sockets []*socketStatus) *socketStatus { - for _, ss := range sockets { - if ss.local.port == ci.local.port && ss.remote.ip.Equal(ci.remote.ip) && ss.remote.port == ci.remote.port { - return ss - } - } - return nil -} - -func readConntrack() ([]*ConnectionInfo, error) { - path := fmt.Sprintf("/proc/net/ip_conntrack") - data, err := ioutil.ReadFile(path) - if err != nil { - return nil, err - } - var result []*ConnectionInfo - lines := strings.Split(string(data), "\n") - for _, line := range lines { - ci, err := parseConntrackLine(line) - if err != nil { - return nil, err - } - if ci != nil { - result = append(result, ci) - } - } - return result, nil -} - -func parseConntrackLine(line string) (*ConnectionInfo, error) { - parts := strings.Fields(line) - if len(parts) < 8 || parts[0] != "tcp" || parts[3] != "ESTABLISHED" { - return nil, nil - } - - local, err := conntrackAddr(parts[4], parts[6]) - if err != nil { - return nil, err - } - remote, err := conntrackAddr(parts[5], parts[7]) - if err != nil { - return nil, err - } - return &ConnectionInfo{ - local: local, - remote: remote, - }, nil -} - -func conntrackAddr(ip_str, port_str string) (*socketAddr, error) { - ip := net.ParseIP(stripLabel(ip_str)) - if ip == nil { - return nil, errors.New("Could not parse IP: " + ip_str) - } - i64, err := strconv.Atoi(stripLabel(port_str)) - if err != nil { - return nil, err - } - return &socketAddr{ - ip: ip, - port: uint16(i64), - }, nil -} - -func stripLabel(s string) string { - idx := strings.Index(s, "=") - if idx == -1 { - return s - } - return s[idx+1:] -} diff --git a/proc/proc_pid.go b/proc/proc_pid.go deleted file mode 100644 index 2000643..0000000 --- a/proc/proc_pid.go +++ /dev/null @@ -1,149 +0,0 @@ -package proc - -import ( - "fmt" - "io/ioutil" - "os" - "path" - "strconv" - "strings" - "sync" - "syscall" -) - -type ProcInfo struct { - Uid int - Pid int - loaded bool - ExePath string - CmdLine string -} - -type pidCache struct { - cacheMap map[uint64]*ProcInfo - lock sync.Mutex -} - -func (pc *pidCache) lookup(inode uint64) *ProcInfo { - pc.lock.Lock() - defer pc.lock.Unlock() - pi, ok := pc.cacheMap[inode] - if ok && pi.loadProcessInfo() { - return pi - } - pc.cacheMap = loadCache() - pi, ok = pc.cacheMap[inode] - if ok && pi.loadProcessInfo() { - return pi - } - return nil -} - -func loadCache() map[uint64]*ProcInfo { - cmap := make(map[uint64]*ProcInfo) - for _, n := range readdir("/proc") { - pid := toPid(n) - if pid != 0 { - pinfo := &ProcInfo{Pid: pid} - for _, inode := range inodesFromPid(pid) { - cmap[inode] = pinfo - } - } - } - return cmap -} - -func toPid(name string) int { - pid, err := strconv.ParseUint(name, 10, 32) - if err != nil { - return 0 - } - fdpath := fmt.Sprintf("/proc/%d/fd", pid) - fi, err := os.Stat(fdpath) - if err != nil { - return 0 - } - if !fi.IsDir() { - return 0 - } - return (int)(pid) -} - -func inodesFromPid(pid int) []uint64 { - var inodes []uint64 - fdpath := fmt.Sprintf("/proc/%d/fd", pid) - for _, n := range readdir(fdpath) { - if link, err := os.Readlink(path.Join(fdpath, n)); err != nil { - if !os.IsNotExist(err) { - log.Warning("Error reading link %s: %v", n, err) - } - } else { - if inode := extractSocket(link); inode > 0 { - inodes = append(inodes, inode) - } - } - } - return inodes -} - -func extractSocket(name string) uint64 { - if !strings.HasPrefix(name, "socket:[") || !strings.HasSuffix(name, "]") { - return 0 - } - val := name[8 : len(name)-1] - inode, err := strconv.ParseUint(val, 10, 64) - if err != nil { - log.Warning("Error parsing inode value from %s: %v", name, err) - return 0 - } - return inode -} - -func readdir(dir string) []string { - d, err := os.Open(dir) - if err != nil { - log.Warning("Error opening directory %s: %v", dir, err) - return nil - } - defer d.Close() - names, err := d.Readdirnames(0) - if err != nil { - log.Warning("Error reading directory names from %s: %v", dir, err) - return nil - } - return names -} - -func (pi *ProcInfo) loadProcessInfo() bool { - if pi.loaded { - return true - } - - exePath, err := os.Readlink(fmt.Sprintf("/proc/%d/exe", pi.Pid)) - if err != nil { - log.Warning("Error reading exe link for pid %d: %v", pi.Pid, err) - return false - } - bs, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pi.Pid)) - if err != nil { - log.Warning("Error reading cmdline for pid %d: %v", pi.Pid, err) - return false - } - for i, b := range bs { - if b == 0 { - bs[i] = byte(' ') - } - } - - finfo, err := os.Stat(fmt.Sprintf("/proc/%d", pi.Pid)) - if err != nil { - log.Warning("Could not stat /proc/%d: %v", pi.Pid, err) - return false - } - sys := finfo.Sys().(*syscall.Stat_t) - pi.Uid = int(sys.Uid) - pi.ExePath = exePath - pi.CmdLine = string(bs) - pi.loaded = true - return true -} diff --git a/proc/socket.go b/proc/socket.go deleted file mode 100644 index 7792ab9..0000000 --- a/proc/socket.go +++ /dev/null @@ -1,99 +0,0 @@ -package proc - -import ( - "errors" - "fmt" - "io/ioutil" - "net" - "strconv" - "strings" -) - -type socketAddr struct { - ip net.IP - port uint16 -} - -func (sa socketAddr) String() string { - return fmt.Sprintf("%v:%d", sa.ip, sa.port) -} - -type socketStatus struct { - local socketAddr - remote socketAddr - uid int - inode uint64 - line string -} - -func (ss *socketStatus) String() string { - return fmt.Sprintf("%s -> %s uid=%d inode=%d", ss.local, ss.remote, ss.uid, ss.inode) -} - -func findUDPSocket(srcPort uint16) *socketStatus { - return findSocket("udp", func(ss socketStatus) bool { - return ss.local.port == srcPort - }) -} - -func findTCPSocket(srcPort uint16, dstAddr net.IP, dstPort uint16) *socketStatus { - return findSocket("tcp", func(ss socketStatus) bool { - return ss.remote.port == dstPort && ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort - }) -} - -func findSocket(proto string, matcher func(socketStatus) bool) *socketStatus { - var ss socketStatus - for _, line := range getSocketLines(proto) { - if len(line) == 0 { - continue - } - if err := ss.parseLine(line); err != nil { - log.Warning("Unable to parse line from /proc/net/%s [%s]: %v", proto, line, err) - continue - } - if matcher(ss) { - ss.line = line - return &ss - } - } - return nil -} - -func (ss *socketStatus) parseLine(line string) error { - fs := strings.Fields(line) - if len(fs) < 10 { - return errors.New("insufficient fields") - } - if err := ss.local.parse(fs[1]); err != nil { - return err - } - if err := ss.remote.parse(fs[2]); err != nil { - return err - } - uid, err := strconv.ParseUint(fs[7], 10, 32) - if err != nil { - return err - } - ss.uid = int(uid) - inode, err := strconv.ParseUint(fs[9], 10, 64) - if err != nil { - return err - } - ss.inode = inode - return nil -} - -func getSocketLines(proto string) []string { - path := fmt.Sprintf("/proc/net/%s", proto) - data, err := ioutil.ReadFile(path) - if err != nil { - log.Warning("Error reading %s: %v", path, err) - return nil - } - lines := strings.Split(string(data), "\n") - if len(lines) > 0 { - lines = lines[1:] - } - return lines -} diff --git a/prompt.go b/prompt.go index f82ae08..6d52d9c 100644 --- a/prompt.go +++ b/prompt.go @@ -92,7 +92,7 @@ func (p *prompter) processPacket(pp *pendingPkt) { addr, int32(pp.pkt.DstPort), pp.pkt.Dst.String(), - uidToUser(pp.pinfo.Uid), + uidToUser(pp.pinfo.UID), int32(pp.pinfo.Pid)) err := call.Store(&scope, &rule) if err != nil { diff --git a/rules.go b/rules.go index 418264a..6767478 100644 --- a/rules.go +++ b/rules.go @@ -3,16 +3,16 @@ package main import ( "encoding/binary" "fmt" + "io/ioutil" "net" + "os" + "path" + "strconv" "strings" "unicode" "github.com/subgraph/fw-daemon/nfqueue" - "github.com/subgraph/fw-daemon/proc" - "io/ioutil" - "os" - "path" - "strconv" + "github.com/subgraph/go-procsnitch" ) const ( @@ -91,7 +91,7 @@ const ( FILTER_PROMPT ) -func (rl *RuleList) filter(p *nfqueue.Packet, pinfo *proc.ProcInfo, hostname string) FilterResult { +func (rl *RuleList) filter(p *nfqueue.Packet, pinfo *procsnitch.Info, hostname string) FilterResult { if rl == nil { return FILTER_PROMPT }