Moved /proc parsing code into a new package

pull/16/head
Bruce Leidl 9 years ago
parent bb2a0f96a4
commit 30a21eaa57

@ -11,6 +11,7 @@ import (
"sync" "sync"
"syscall" "syscall"
"unsafe" "unsafe"
"github.com/subgraph/fw-daemon/proc"
) )
var log = logging.MustGetLogger("sgfw") var log = logging.MustGetLogger("sgfw")
@ -71,6 +72,7 @@ func (fw *Firewall) runFilter() {
} }
func main() { func main() {
proc.SetLogger(log)
if os.Geteuid() != 0 { if os.Geteuid() != 0 {
log.Error("Must be run as root") log.Error("Must be run as root")

@ -5,13 +5,14 @@ import (
"sync" "sync"
"github.com/subgraph/fw-daemon/nfqueue" "github.com/subgraph/fw-daemon/nfqueue"
"github.com/subgraph/fw-daemon/proc"
) )
type pendingPkt struct { type pendingPkt struct {
policy *Policy policy *Policy
hostname string hostname string
pkt *nfqueue.Packet pkt *nfqueue.Packet
proc *ProcInfo pinfo *proc.ProcInfo
} }
type Policy struct { type Policy struct {
@ -42,12 +43,12 @@ func (fw *Firewall) policyForPath(path string) *Policy {
return fw.policyMap[path] return fw.policyMap[path]
} }
func (p *Policy) processPacket(pkt *nfqueue.Packet, proc *ProcInfo) { func (p *Policy) processPacket(pkt *nfqueue.Packet, pinfo *proc.ProcInfo) {
p.lock.Lock() p.lock.Lock()
defer p.lock.Unlock() defer p.lock.Unlock()
name := p.fw.dns.Lookup(pkt.Dst) name := p.fw.dns.Lookup(pkt.Dst)
log.Info("Lookup(%s): %s", pkt.Dst.String(), name) log.Info("Lookup(%s): %s", pkt.Dst.String(), name)
result := p.rules.filter(pkt, proc, name) result := p.rules.filter(pkt, pinfo, name)
switch result { switch result {
case FILTER_DENY: case FILTER_DENY:
pkt.Mark = 1 pkt.Mark = 1
@ -55,7 +56,7 @@ func (p *Policy) processPacket(pkt *nfqueue.Packet, proc *ProcInfo) {
case FILTER_ALLOW: case FILTER_ALLOW:
pkt.Accept() pkt.Accept()
case FILTER_PROMPT: case FILTER_PROMPT:
p.processPromptResult(&pendingPkt{policy: p, hostname: name, pkt: pkt, proc: proc}) p.processPromptResult(&pendingPkt{policy: p, hostname: name, pkt: pkt, pinfo: pinfo})
default: default:
log.Warning("Unexpected filter result: %d", result) log.Warning("Unexpected filter result: %d", result)
} }
@ -162,21 +163,21 @@ func (fw *Firewall) filterPacket(pkt *nfqueue.Packet) {
fw.dns.processDNS(pkt) fw.dns.processDNS(pkt)
return return
} }
proc := findProcessForPacket(pkt) pinfo := proc.FindProcessForPacket(pkt)
if proc == nil { if pinfo == nil {
log.Warning("No proc found for %s", printPacket(pkt, fw.dns.Lookup(pkt.Dst))) log.Warning("No proc found for %s", printPacket(pkt, fw.dns.Lookup(pkt.Dst)))
pkt.Accept() pkt.Accept()
return return
} }
log.Debug("filterPacket [%s] %s", proc.exePath, printPacket(pkt, fw.dns.Lookup(pkt.Dst))) log.Debug("filterPacket [%s] %s", pinfo.ExePath, printPacket(pkt, fw.dns.Lookup(pkt.Dst)))
if basicAllowPacket(pkt) { if basicAllowPacket(pkt) {
pkt.Accept() pkt.Accept()
return return
} }
fw.lock.Lock() fw.lock.Lock()
policy := fw.policyForPath(proc.exePath) policy := fw.policyForPath(pinfo.ExePath)
fw.lock.Unlock() fw.lock.Unlock()
policy.processPacket(pkt, proc) policy.processPacket(pkt, pinfo)
} }
func basicAllowPacket(pkt *nfqueue.Packet) bool { func basicAllowPacket(pkt *nfqueue.Packet) bool {

@ -1,4 +1,4 @@
package main package proc
import ( import (
"encoding/hex" "encoding/hex"
@ -10,10 +10,15 @@ import (
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"github.com/subgraph/fw-daemon/Godeps/_workspace/src/github.com/op/go-logging"
"github.com/subgraph/fw-daemon/nfqueue" "github.com/subgraph/fw-daemon/nfqueue"
) )
var log = logging.MustGetLogger("proc")
func SetLogger(logger *logging.Logger) {
log = logger
}
type socketAddr struct { type socketAddr struct {
ip net.IP ip net.IP
port uint16 port uint16
@ -47,7 +52,7 @@ func (ci *ConnectionInfo) String() string {
return fmt.Sprintf("%v %s %s", ci.proc, ci.local, ci.remote) return fmt.Sprintf("%v %s %s", ci.proc, ci.local, ci.remote)
} }
func findProcessForPacket(pkt *nfqueue.Packet) *ProcInfo { func FindProcessForPacket(pkt *nfqueue.Packet) *ProcInfo {
ss := getSocketForPacket(pkt) ss := getSocketForPacket(pkt)
if ss == nil { if ss == nil {
return nil return nil
@ -79,10 +84,10 @@ func findProcessForSocket(ss *socketStatus) *ProcInfo {
} }
finfo.Sys() finfo.Sys()
return &ProcInfo{ return &ProcInfo{
pid: ss.pid, Pid: ss.pid,
uid: ss.uid, Uid: ss.uid,
exePath: exePath, ExePath: exePath,
cmdLine: string(bs), CmdLine: string(bs),
} }
} }
@ -153,6 +158,20 @@ func (ss *socketStatus) parseLine(line string) error {
return nil return nil
} }
func printPacket(pkt *nfqueue.Packet) string {
proto := func() string {
switch pkt.Protocol {
case nfqueue.TCP:
return "TCP"
case nfqueue.UDP:
return "UDP"
default:
return "???"
}
}()
return fmt.Sprintf("(%s %s:%d --> %s:%d)", proto, pkt.Src, pkt.SrcPort, pkt.Dst.String(), pkt.DstPort)
}
func getSocketForPacket(pkt *nfqueue.Packet) *socketStatus { func getSocketForPacket(pkt *nfqueue.Packet) *socketStatus {
ss := findSocket(pkt) ss := findSocket(pkt)
if ss == nil { if ss == nil {
@ -163,7 +182,7 @@ func getSocketForPacket(pkt *nfqueue.Packet) *socketStatus {
ss.pid = pid ss.pid = pid
return ss return ss
} }
log.Info("Unable to find socket link socket:[%d] %s", ss.inode, printPacket(pkt, "")) log.Info("Unable to find socket link socket:[%d] %s", ss.inode, printPacket(pkt))
log.Info("Line was %s", ss.line) log.Info("Line was %s", ss.line)
return nil return nil
} }
@ -181,7 +200,7 @@ func findSocket(pkt *nfqueue.Packet) *socketStatus {
return &status return &status
} }
} }
log.Info("Failed to find socket for packet: %s", printPacket(pkt, "")) log.Info("Failed to find socket for packet: %s", printPacket(pkt))
return nil return nil
} }

@ -1,4 +1,5 @@
package main package proc
import ( import (
"os" "os"
"strconv" "strconv"
@ -10,13 +11,13 @@ import (
type ProcInfo struct { type ProcInfo struct {
pid int Uid int
Pid int
loaded bool loaded bool
exePath string ExePath string
cmdLine string CmdLine string
} }
var cacheMap = make(map[uint64]*ProcInfo) var cacheMap = make(map[uint64]*ProcInfo)
func pidCacheLookup(inode uint64) *ProcInfo { func pidCacheLookup(inode uint64) *ProcInfo {
@ -79,10 +80,10 @@ func extractSocket(name string, pid int) {
func cacheAddPid(inode uint64, pid int) { func cacheAddPid(inode uint64, pid int) {
pi,ok := cacheMap[inode] pi,ok := cacheMap[inode]
if ok && pi.pid == pid { if ok && pi.Pid == pid {
return return
} }
cacheMap[inode] = &ProcInfo{ pid: pid } cacheMap[inode] = &ProcInfo{ Pid: pid }
} }
func readdir(dir string) []string { func readdir(dir string) []string {
@ -105,14 +106,14 @@ func (pi *ProcInfo) loadProcessInfo() bool {
return true return true
} }
exePath, err := os.Readlink(fmt.Sprintf("/proc/%d/exe", pi.pid)) exePath, err := os.Readlink(fmt.Sprintf("/proc/%d/exe", pi.Pid))
if err != nil { if err != nil {
log.Warning("Error reading exe link for pid %d: %v", pi.pid, err) log.Warning("Error reading exe link for pid %d: %v", pi.Pid, err)
return false return false
} }
bs, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pi.pid)) bs, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pi.Pid))
if err != nil { if err != nil {
log.Warning("Error reading cmdline for pid %d: %v", pi.pid, err) log.Warning("Error reading cmdline for pid %d: %v", pi.Pid, err)
return false return false
} }
for i, b := range bs { for i, b := range bs {
@ -121,14 +122,14 @@ func (pi *ProcInfo) loadProcessInfo() bool {
} }
} }
finfo, err := os.Stat(fmt.Sprintf("/proc/%d", pi.pid)) finfo, err := os.Stat(fmt.Sprintf("/proc/%d", pi.Pid))
if err != nil { if err != nil {
log.Warning("Could not stat /proc/%d: %v", pi.pid, err) log.Warning("Could not stat /proc/%d: %v", pi.Pid, err)
return false return false
} }
finfo.Sys() finfo.Sys()
pi.exePath = exePath pi.ExePath = exePath
pi.cmdLine = string(bs) pi.CmdLine = string(bs)
pi.loaded = true pi.loaded = true
return true return true
} }

@ -92,8 +92,8 @@ func (p *prompter) processPacket(pp *pendingPkt) {
addr, addr,
int32(pp.pkt.DstPort), int32(pp.pkt.DstPort),
pp.pkt.Dst.String(), pp.pkt.Dst.String(),
uidToUser(pp.proc.uid), uidToUser(pp.pinfo.Uid),
int32(pp.proc.pid)) int32(pp.pinfo.Pid))
err := call.Store(&scope, &rule) err := call.Store(&scope, &rule)
if err != nil { if err != nil {
log.Warning("Error sending dbus RequestPrompt message: %v", err) log.Warning("Error sending dbus RequestPrompt message: %v", err)

@ -12,6 +12,7 @@ import (
"os" "os"
"strconv" "strconv"
"path" "path"
"github.com/subgraph/fw-daemon/proc"
) )
const ( const (
@ -77,14 +78,14 @@ const (
FILTER_PROMPT FILTER_PROMPT
) )
func (rl *RuleList) filter(p *nfqueue.Packet, proc *ProcInfo, hostname string) FilterResult { func (rl *RuleList) filter(p *nfqueue.Packet, pinfo *proc.ProcInfo, hostname string) FilterResult {
if rl == nil { if rl == nil {
return FILTER_PROMPT return FILTER_PROMPT
} }
result := FILTER_PROMPT result := FILTER_PROMPT
for _, r := range *rl { for _, r := range *rl {
if r.match(p, hostname) { if r.match(p, hostname) {
log.Info("%s (%s -> %s:%d)", r, proc.exePath, p.Dst.String(), p.DstPort) log.Info("%s (%s -> %s:%d)", r, pinfo.ExePath, p.Dst.String(), p.DstPort)
if r.rtype == RULE_DENY { if r.rtype == RULE_DENY {
return FILTER_DENY return FILTER_DENY
} else if r.rtype == RULE_ALLOW { } else if r.rtype == RULE_ALLOW {

Loading…
Cancel
Save