diff --git a/proc-coroner/pcoroner.go b/proc-coroner/pcoroner.go new file mode 100644 index 0000000..a012af5 --- /dev/null +++ b/proc-coroner/pcoroner.go @@ -0,0 +1,154 @@ +package pcoroner + +import ( + "fmt" + "time" + "strings" + "strconv" + "sync" + "os" + "syscall" +) + + +type WatchProcess struct { + Pid int + Inode uint64 + Ppid int + Stime int +} + +type procCB func(int, interface{}) + + +var pmutex = &sync.Mutex{} +var pidMap map[int]WatchProcess = make(map[int]WatchProcess) + + +func MonitorProcess(pid int) bool { + pmutex.Lock() + defer pmutex.Unlock() + + _, ok := pidMap[pid] + + if ok { + return false + } + + watcher := WatchProcess{Pid: pid} + watcher.Inode = 0 + res := checkProcess(&watcher, true) + + if res { + pidMap[pid] = watcher + } + + return res +} + +func UnmonitorProcess(pid int) { + pmutex.Lock() + defer pmutex.Unlock() + delete(pidMap, pid) + return +} + +func MonitorThread(cbfunc procCB, param interface{}) { + for { +/* if len(pidMap) == 0 { + fmt.Println("TICK") + } else { fmt.Println("len = ", len(pidMap)) } */ + pmutex.Lock() + pmutex.Unlock() + + for pkey, pval := range pidMap { +// fmt.Printf("PID %v -> %v\n", pkey, pval) + res := checkProcess(&pval, false) + + if !res { + delete(pidMap, pkey) + + if cbfunc != nil { + cbfunc(pkey, param) + } + continue + } + + } + + time.Sleep(2 * time.Second) + } +} + +func checkProcess(proc *WatchProcess, init bool) bool { + ppath := fmt.Sprintf("/proc/%d/stat", proc.Pid) + f, err := os.Open(ppath) + if err != nil { +// fmt.Printf("Error opening path %s: %s\n", ppath, err) + return false + } + defer f.Close() + + fi, err := f.Stat() + if err != nil { + fmt.Printf("Error calling stat on file %s: %s\n", ppath, err) + return false + } + sb, ok := fi.Sys().(*syscall.Stat_t) + if !ok { + fmt.Println("Unexpected error reading stat information from proc file") + } else if init { + proc.Inode = sb.Ino + } else { + if sb.Ino != proc.Inode { + fmt.Printf("/proc inode mismatch for process %d: %v vs %v\n", proc.Pid, sb.Ino, proc.Inode) + return false + } + } + + var buf [512]byte + nread, err := f.Read(buf[:]) + if err != nil { + fmt.Printf("Error reading stat for process %d: %v", proc.Pid, err) + return true + } else if nread <= 0 { + fmt.Printf("Unexpected error reading stat for process %d", proc.Pid) + return true + } + + bstr := string(buf[:]) +// fmt.Println("sstr = ", bstr) + + fields := strings.Split(bstr, " ") + + if len(fields) < 22 { + fmt.Printf("Unexpected error reading data from /proc stat for process %d", proc.Pid) + return true + } + + ppid, err := strconv.Atoi(fields[3]) + if err != nil { + ppid = -1 + } + + if init { + proc.Ppid = ppid + } else if proc.Ppid != ppid { + fmt.Printf("Cached process ppid did not match value in /proc: %v vs %v\n", proc.Ppid, ppid) + return false + } + + stime, err := strconv.Atoi(fields[21]) + if err != nil { + stime = -1 + } + + if init { + proc.Stime = stime + } else if proc.Stime != stime { + fmt.Printf("Cached process start time did not match value in /proc: %v vs %v\n", proc.Stime, stime) + return false + } + + return true +} diff --git a/sgfw/dns.go b/sgfw/dns.go index 98faa0d..823ecc7 100644 --- a/sgfw/dns.go +++ b/sgfw/dns.go @@ -4,28 +4,55 @@ import ( "net" "strings" "sync" + "time" // "github.com/subgraph/go-nfnetlink" - "github.com/google/gopacket" "github.com/google/gopacket/layers" + nfqueue "github.com/subgraph/go-nfnetlink/nfqueue" + "github.com/subgraph/go-procsnitch" + "github.com/subgraph/fw-daemon/proc-coroner" ) +var monitoring = false +var mlock = sync.Mutex{} + +type dnsEntry struct { + name string + ttl uint32 + exp time.Time +} + type dnsCache struct { - ipMap map[string]string + ipMap map[int]map[string]dnsEntry lock sync.Mutex done chan struct{} } +func newDNSEntry(hostname string, ttl uint32) dnsEntry { + newEntry := dnsEntry{ + name: hostname, + ttl: ttl, + exp: time.Now().Add(time.Second * time.Duration(ttl)), + } + return newEntry +} + func newDNSCache() *dnsCache { - return &dnsCache{ - ipMap: make(map[string]string), + newCache := &dnsCache{ + ipMap: make(map[int]map[string]dnsEntry), done: make(chan struct{}), } + newCache.ipMap[0] = make(map[string]dnsEntry) + return newCache +} + +func isNSTrusted(src net.IP) bool { + return src.IsLoopback() } -func (dc *dnsCache) processDNS(pkt gopacket.Packet) { +func (dc *dnsCache) processDNS(pkt *nfqueue.NFQPacket) { dns := &dnsMsg{} - if !dns.Unpack(pkt.Layer(layers.LayerTypeDNS).LayerContents()) { + if !dns.Unpack(pkt.Packet.Layer(layers.LayerTypeDNS).LayerContents()) { log.Warning("Failed to Unpack DNS message") return } @@ -38,14 +65,48 @@ func (dc *dnsCache) processDNS(pkt gopacket.Packet) { } q := dns.question[0] if q.Qtype == dnsTypeA { - dc.processRecordA(q.Name, dns.answer) + srcip, _ := getPacketIP4Addrs(pkt) + pinfo := getEmptyPInfo() + if !isNSTrusted(srcip) { + pinfo, _ = findProcessForPacket(pkt, true, procsnitch.MATCH_LOOSEST) + + if pinfo == nil { + log.Warningf("Skipping attempted DNS cache entry for process that can't be found: %v -> %v\n", q.Name, dns.answer) + return + } + } +//log.Notice("XXX: PROCESS LOOKUP -> ", pinfo) + dc.processRecordA(q.Name, dns.answer, pinfo.Pid) return } log.Infof("Unhandled DNS message: %v", dns) } -func (dc *dnsCache) processRecordA(name string, answers []dnsRR) { +/*func checker(c *dnsCache) { + for { + log.Error("CACHE CHECKER") + c.lock.Lock() + for k, v := range c.ipMap { + log.Errorf("IN CACHE: %v -> %v\n", k, v) + } + c.lock.Unlock() + time.Sleep(2 * time.Second) + } +} */ + +func procDeathCallback(pid int, param interface{}) { +// log.Warning("XXX: IN CALLBACK for pid: ", pid, " / param = ", param) + + if pid != 0 { + cache := param.(*dnsCache) + cache.lock.Lock() + delete(cache.ipMap, pid) + cache.lock.Unlock() + } +} + +func (dc *dnsCache) processRecordA(name string, answers []dnsRR, pid int) { dc.lock.Lock() defer dc.lock.Unlock() for _, rr := range answers { @@ -55,19 +116,78 @@ func (dc *dnsCache) processRecordA(name string, answers []dnsRR) { if strings.HasSuffix(name, ".") { name = name[:len(name)-1] } - log.Notice("______ Adding to dns map: %s: %s", name, ip) - dc.ipMap[ip] = name + + // Just in case. + if pid < 0 { + pid = 0 + } + log.Noticef("______ Adding to dns map: %s: %s -> pid %d", name, ip, pid) + + _, ok := dc.ipMap[pid] + if !ok { + dc.ipMap[pid] = make(map[string]dnsEntry) + } + dc.ipMap[pid][ip] = newDNSEntry(name, rr.Header().TTL) + + if pid > 0 { + log.Warning("Adding process to be monitored by DNS cache: ", pid) + if !monitoring { + mlock.Lock() + if !monitoring { + monitoring = true +// go checker(dc) + go pcoroner.MonitorThread(procDeathCallback, dc) + } + mlock.Unlock() + } + pcoroner.MonitorProcess(pid) + } if !FirewallConfig.LogRedact { log.Infof("Adding %s: %s", name, ip) } + case *dnsRR_CNAME: + // Not that exotic; just ignore it default: log.Warningf("Unexpected RR type in answer section of A response: %v", rec) } } } -func (dc *dnsCache) Lookup(ip net.IP) string { +func (dc *dnsCache) Lookup(ip net.IP, pid int) string { + now := time.Now() dc.lock.Lock() defer dc.lock.Unlock() - return dc.ipMap[ip.String()] + + // empty procinfo can set this to -1 + if pid < 0 { + pid = 0 + } + + if pid > 0 { + entry, ok := dc.ipMap[pid][ip.String()] + if ok { + if now.Before(entry.exp) { +// log.Noticef("XXX: LOOKUP on %v / %v = %v, ttl = %v / %v\n", pid, ip.String(), entry.name, entry.ttl, entry.exp) + return entry.name + } else { + log.Warningf("Skipping expired per-pid (%d) DNS cache entry: %s -> %s / exp. %v (%ds)\n", + pid, ip.String(), entry.name, entry.exp, entry.ttl) + } + } + } + + str := "" + entry, ok := dc.ipMap[0][ip.String()] + if ok { + if now.Before(entry.exp) { + str = entry.name +// log.Noticef("XXX: LOOKUP on %v / 0 RETURNING %v, ttl = %v / %v\n", ip.String(), str, entry.ttl, entry.exp) + } else { + log.Warningf("Skipping expired global DNS cache entry: %s -> %s / exp. %v (%ds)\n", + ip.String(), entry.name, entry.exp, entry.ttl) + } + } + +//log.Noticef("XXX: LOOKUP on %v / 0 RETURNING %v\n", ip.String(), str) + return str } diff --git a/sgfw/policy.go b/sgfw/policy.go index b54a310..ae76142 100644 --- a/sgfw/policy.go +++ b/sgfw/policy.go @@ -10,7 +10,6 @@ import ( // nfnetlink "github.com/subgraph/go-nfnetlink" nfqueue "github.com/subgraph/go-nfnetlink/nfqueue" -// "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/subgraph/go-procsnitch" "net" @@ -203,21 +202,12 @@ func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, pinfo *procsnitch.Info, o dstb := pkt.Packet.NetworkLayer().NetworkFlow().Dst().Raw() dstip := net.IP(dstb) srcip := net.IP(pkt.Packet.NetworkLayer().NetworkFlow().Src().Raw()) -// _, dstp := getPacketPorts(pkt) - name := p.fw.dns.Lookup(dstip) + name := p.fw.dns.Lookup(dstip, pinfo.Pid) if !FirewallConfig.LogRedact { log.Infof("Lookup(%s): %s", dstip.String(), name) } // fwo := matchAgainstOzRules(srcip, dstip, dstp) -if name == "" { -/* log.Notice("XXXXXXXXXXXXx trying better rev lookup:") - net.LookupAddr(dstip.String()) - name = p.fw.dns.Lookup(dstip) - log.Notice("NOW ITS: ", name) */ -} - -//log.Notice("XXX: Attempting to filter packet on rules -> ", fwo, " / rev lookup = ", name) result := p.rules.filterPacket(pkt, pinfo, srcip, name, optstr) switch result { case FILTER_DENY: @@ -384,12 +374,13 @@ func printPacket(pkt *nfqueue.NFQPacket, hostname string, pinfo *procsnitch.Info } func (fw *Firewall) filterPacket(pkt *nfqueue.NFQPacket) { - if pkt.Packet.Layer(layers.LayerTypeUDP) != nil { + isudp := pkt.Packet.Layer(layers.LayerTypeUDP) != nil + if isudp { srcport, _ := getPacketUDPPorts(pkt) if srcport == 53 { + fw.dns.processDNS(pkt) pkt.Accept() - fw.dns.processDNS(pkt.Packet) return } @@ -412,13 +403,18 @@ func (fw *Firewall) filterPacket(pkt *nfqueue.NFQPacket) { ppath := "*" + strictness := procsnitch.MATCH_STRICT + + if isudp { + strictness = procsnitch.MATCH_LOOSE + } - pinfo, optstring := findProcessForPacket(pkt) + pinfo, optstring := findProcessForPacket(pkt, false, strictness) if pinfo == nil { pinfo = getEmptyPInfo() ppath = "[unknown]" optstring = "[Connection could not be mapped]" - log.Warningf("No proc found for %s", printPacket(pkt, fw.dns.Lookup(dstip), nil)) + log.Warningf("No proc found for %s", printPacket(pkt, fw.dns.Lookup(dstip, pinfo.Pid), nil)) // pkt.Accept() // return } else { @@ -433,7 +429,7 @@ func (fw *Firewall) filterPacket(pkt *nfqueue.NFQPacket) { } } } - log.Debugf("filterPacket [%s] %s", ppath, printPacket(pkt, fw.dns.Lookup(dstip), nil)) + log.Debugf("filterPacket [%s] %s", ppath, printPacket(pkt, fw.dns.Lookup(dstip, pinfo.Pid), nil)) if basicAllowPacket(pkt) { pkt.Accept() //log.Notice("XXX: passed basicallowpacket") @@ -481,7 +477,7 @@ func getAllProcNetDataLocal() ([]string, error) { for i := 0; i < len(OzInitPids); i++ { fname := fmt.Sprintf("/proc/%d/net/tcp", OzInitPids[i]) -fmt.Println("XXX: opening: ", fname) +//fmt.Println("XXX: opening: ", fname) bdata, err := readFileDirect(fname) if err != nil { @@ -528,13 +524,18 @@ func getRealRoot(pathname string, pid int) string { return pathname } -func findProcessForPacket(pkt *nfqueue.NFQPacket) (*procsnitch.Info, string) { +func findProcessForPacket(pkt *nfqueue.NFQPacket, reverse bool, strictness int) (*procsnitch.Info, string) { srcip, dstip := getPacketIP4Addrs(pkt) srcp, dstp := getPacketPorts(pkt) proto := "" optstr := "" icode := -1 + if reverse { + dstip, srcip = getPacketIP4Addrs(pkt) + dstp, srcp = getPacketPorts(pkt) + } + if pkt.Packet.Layer(layers.LayerTypeTCP) != nil { proto = "tcp" } else if pkt.Packet.Layer(layers.LayerTypeUDP) != nil { @@ -549,13 +550,15 @@ func findProcessForPacket(pkt *nfqueue.NFQPacket) (*procsnitch.Info, string) { return nil, optstr } +//log.Noticef("XXX proto = %s, from %v : %v -> %v : %v\n", proto, srcip, srcp, dstip, dstp) + var res *procsnitch.Info = nil // Try normal way first, before the more resource intensive/invasive way. if proto == "tcp" { res = procsnitch.LookupTCPSocketProcessAll(srcip, srcp, dstip, dstp, nil) } else if proto == "udp" { - res = procsnitch.LookupUDPSocketProcessAll(srcip, srcp, dstip, dstp, nil, true) + res = procsnitch.LookupUDPSocketProcessAll(srcip, srcp, dstip, dstp, nil, strictness) } else if proto == "icmp" { res = procsnitch.LookupICMPSocketProcessAll(srcip, dstip, icode, nil) } @@ -566,7 +569,7 @@ func findProcessForPacket(pkt *nfqueue.NFQPacket) (*procsnitch.Info, string) { for i := 0; i < len(OzInitPids); i++ { data := "" fname := fmt.Sprintf("/proc/%d/net/%s", OzInitPids[i].Pid, proto) -fmt.Println("XXX: opening: ", fname) +//fmt.Println("XXX: opening: ", fname) bdata, err := readFileDirect(fname) if err != nil { @@ -596,7 +599,7 @@ fmt.Println("XXX: opening: ", fname) if proto == "tcp" { res = procsnitch.LookupTCPSocketProcessAll(srcip, srcp, dstip, dstp, rlines) } else if proto == "udp" { - res = procsnitch.LookupUDPSocketProcessAll(srcip, srcp, dstip, dstp, rlines, true) + res = procsnitch.LookupUDPSocketProcessAll(srcip, srcp, dstip, dstp, rlines, strictness) } else if proto == "icmp" { res = procsnitch.LookupICMPSocketProcessAll(srcip, dstip, icode, rlines) } diff --git a/vendor/github.com/subgraph/go-procsnitch/proc.go b/vendor/github.com/subgraph/go-procsnitch/proc.go index e30bab7..e867e86 100644 --- a/vendor/github.com/subgraph/go-procsnitch/proc.go +++ b/vendor/github.com/subgraph/go-procsnitch/proc.go @@ -76,8 +76,8 @@ func LookupICMPSocketProcessAll(srcAddr net.IP, dstAddr net.IP, code int, custda } // LookupUDPSocketProcessAll searches for a UDP socket a given source port, destination IP, and destination port - AND source destination -func LookupUDPSocketProcessAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort uint16, custdata []string, loose bool) *Info { - ss := findUDPSocketAll(srcAddr, srcPort, dstAddr, dstPort, custdata, loose) +func LookupUDPSocketProcessAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort uint16, custdata []string, strictness int) *Info { + ss := findUDPSocketAll(srcAddr, srcPort, dstAddr, dstPort, custdata, strictness) if ss == nil { return nil } diff --git a/vendor/github.com/subgraph/go-procsnitch/socket.go b/vendor/github.com/subgraph/go-procsnitch/socket.go index 851c140..9c18ce3 100644 --- a/vendor/github.com/subgraph/go-procsnitch/socket.go +++ b/vendor/github.com/subgraph/go-procsnitch/socket.go @@ -31,6 +31,12 @@ type socketStatus struct { type ConnectionStatus int +const ( + MATCH_STRICT = iota + MATCH_LOOSE + MATCH_LOOSEST +) + const ( ESTABLISHED ConnectionStatus = iota SYN_SENT @@ -90,30 +96,41 @@ func findICMPSocketAll(srcAddr net.IP, dstAddr net.IP, code int, custdata []stri }) } -func findUDPSocketAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort uint16, custdata []string, loose bool) *socketStatus { +func findUDPSocketAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort uint16, custdata []string, strictness int) *socketStatus { wildcard := net.IP{0,0,0,0} if custdata == nil { - if !loose { + if strictness == MATCH_STRICT { return findSocket("udp", func(ss socketStatus) bool { -// return ss.remote.port == dstPort && ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) return ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) }) + } else if strictness == MATCH_LOOSE { + return findSocket("udp", func(ss socketStatus) bool { + return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(wildcard)) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) || + (ss.local.ip.Equal(dstAddr) || ss.local.ip.Equal(wildcard)) && ss.remote.port == srcPort && ss.remote.ip.Equal(srcAddr) + }) } - return findSocket("udp", func(ss socketStatus) bool { - return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(wildcard)) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) + return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(wildcard)) && ss.local.port == srcPort && (ss.local.ip.Equal(srcAddr) || ss.local.ip.Equal(wildcard)) || + (ss.local.ip.Equal(dstAddr) || ss.local.ip.Equal(wildcard)) && ss.remote.port == srcPort && (ss.remote.ip.Equal(srcAddr) || ss.remote.ip.Equal(srcAddr)) }) + } - if !loose { + if strictness == MATCH_STRICT { return findSocketCustom("udp", custdata, func(ss socketStatus) bool { return ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) }) + } else if strictness == MATCH_LOOSE { + return findSocketCustom("udp", custdata, func(ss socketStatus) bool { + return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(wildcard)) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) || + (ss.local.ip.Equal(dstAddr) || ss.local.ip.Equal(wildcard)) && ss.remote.port == srcPort && ss.remote.ip.Equal(srcAddr) + }) } return findSocketCustom("udp", custdata, func(ss socketStatus) bool { - return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(wildcard)) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) + return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(wildcard)) && ss.local.port == srcPort && (ss.local.ip.Equal(srcAddr) || ss.local.ip.Equal(wildcard)) || + (ss.local.ip.Equal(dstAddr) || ss.local.ip.Equal(wildcard)) && ss.remote.port == srcPort && (ss.remote.ip.Equal(srcAddr) || ss.remote.ip.Equal(srcAddr)) }) }