Introduced per-process DNS cache segregation for all A records not returned by local resolver.

Cached DNS name lookups now failover to global cache only populated by local resolver.
Added proc-coroner module for detecting process deaths.
procsnitch updated to handle multiple levels of "strictness" (necessary to lookup processes generating certain UDP data).
shw_dev
shw 8 years ago
parent 51c181a881
commit c3635093fa

@ -0,0 +1,154 @@
package pcoroner
import (
"fmt"
"time"
"strings"
"strconv"
"sync"
"os"
"syscall"
)
type WatchProcess struct {
Pid int
Inode uint64
Ppid int
Stime int
}
type procCB func(int, interface{})
var pmutex = &sync.Mutex{}
var pidMap map[int]WatchProcess = make(map[int]WatchProcess)
func MonitorProcess(pid int) bool {
pmutex.Lock()
defer pmutex.Unlock()
_, ok := pidMap[pid]
if ok {
return false
}
watcher := WatchProcess{Pid: pid}
watcher.Inode = 0
res := checkProcess(&watcher, true)
if res {
pidMap[pid] = watcher
}
return res
}
func UnmonitorProcess(pid int) {
pmutex.Lock()
defer pmutex.Unlock()
delete(pidMap, pid)
return
}
func MonitorThread(cbfunc procCB, param interface{}) {
for {
/* if len(pidMap) == 0 {
fmt.Println("TICK")
} else { fmt.Println("len = ", len(pidMap)) } */
pmutex.Lock()
pmutex.Unlock()
for pkey, pval := range pidMap {
// fmt.Printf("PID %v -> %v\n", pkey, pval)
res := checkProcess(&pval, false)
if !res {
delete(pidMap, pkey)
if cbfunc != nil {
cbfunc(pkey, param)
}
continue
}
}
time.Sleep(2 * time.Second)
}
}
func checkProcess(proc *WatchProcess, init bool) bool {
ppath := fmt.Sprintf("/proc/%d/stat", proc.Pid)
f, err := os.Open(ppath)
if err != nil {
// fmt.Printf("Error opening path %s: %s\n", ppath, err)
return false
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
fmt.Printf("Error calling stat on file %s: %s\n", ppath, err)
return false
}
sb, ok := fi.Sys().(*syscall.Stat_t)
if !ok {
fmt.Println("Unexpected error reading stat information from proc file")
} else if init {
proc.Inode = sb.Ino
} else {
if sb.Ino != proc.Inode {
fmt.Printf("/proc inode mismatch for process %d: %v vs %v\n", proc.Pid, sb.Ino, proc.Inode)
return false
}
}
var buf [512]byte
nread, err := f.Read(buf[:])
if err != nil {
fmt.Printf("Error reading stat for process %d: %v", proc.Pid, err)
return true
} else if nread <= 0 {
fmt.Printf("Unexpected error reading stat for process %d", proc.Pid)
return true
}
bstr := string(buf[:])
// fmt.Println("sstr = ", bstr)
fields := strings.Split(bstr, " ")
if len(fields) < 22 {
fmt.Printf("Unexpected error reading data from /proc stat for process %d", proc.Pid)
return true
}
ppid, err := strconv.Atoi(fields[3])
if err != nil {
ppid = -1
}
if init {
proc.Ppid = ppid
} else if proc.Ppid != ppid {
fmt.Printf("Cached process ppid did not match value in /proc: %v vs %v\n", proc.Ppid, ppid)
return false
}
stime, err := strconv.Atoi(fields[21])
if err != nil {
stime = -1
}
if init {
proc.Stime = stime
} else if proc.Stime != stime {
fmt.Printf("Cached process start time did not match value in /proc: %v vs %v\n", proc.Stime, stime)
return false
}
return true
}

@ -4,28 +4,55 @@ import (
"net"
"strings"
"sync"
"time"
// "github.com/subgraph/go-nfnetlink"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
nfqueue "github.com/subgraph/go-nfnetlink/nfqueue"
"github.com/subgraph/go-procsnitch"
"github.com/subgraph/fw-daemon/proc-coroner"
)
var monitoring = false
var mlock = sync.Mutex{}
type dnsEntry struct {
name string
ttl uint32
exp time.Time
}
type dnsCache struct {
ipMap map[string]string
ipMap map[int]map[string]dnsEntry
lock sync.Mutex
done chan struct{}
}
func newDNSEntry(hostname string, ttl uint32) dnsEntry {
newEntry := dnsEntry{
name: hostname,
ttl: ttl,
exp: time.Now().Add(time.Second * time.Duration(ttl)),
}
return newEntry
}
func newDNSCache() *dnsCache {
return &dnsCache{
ipMap: make(map[string]string),
newCache := &dnsCache{
ipMap: make(map[int]map[string]dnsEntry),
done: make(chan struct{}),
}
newCache.ipMap[0] = make(map[string]dnsEntry)
return newCache
}
func isNSTrusted(src net.IP) bool {
return src.IsLoopback()
}
func (dc *dnsCache) processDNS(pkt gopacket.Packet) {
func (dc *dnsCache) processDNS(pkt *nfqueue.NFQPacket) {
dns := &dnsMsg{}
if !dns.Unpack(pkt.Layer(layers.LayerTypeDNS).LayerContents()) {
if !dns.Unpack(pkt.Packet.Layer(layers.LayerTypeDNS).LayerContents()) {
log.Warning("Failed to Unpack DNS message")
return
}
@ -38,14 +65,48 @@ func (dc *dnsCache) processDNS(pkt gopacket.Packet) {
}
q := dns.question[0]
if q.Qtype == dnsTypeA {
dc.processRecordA(q.Name, dns.answer)
srcip, _ := getPacketIP4Addrs(pkt)
pinfo := getEmptyPInfo()
if !isNSTrusted(srcip) {
pinfo, _ = findProcessForPacket(pkt, true, procsnitch.MATCH_LOOSEST)
if pinfo == nil {
log.Warningf("Skipping attempted DNS cache entry for process that can't be found: %v -> %v\n", q.Name, dns.answer)
return
}
}
//log.Notice("XXX: PROCESS LOOKUP -> ", pinfo)
dc.processRecordA(q.Name, dns.answer, pinfo.Pid)
return
}
log.Infof("Unhandled DNS message: %v", dns)
}
func (dc *dnsCache) processRecordA(name string, answers []dnsRR) {
/*func checker(c *dnsCache) {
for {
log.Error("CACHE CHECKER")
c.lock.Lock()
for k, v := range c.ipMap {
log.Errorf("IN CACHE: %v -> %v\n", k, v)
}
c.lock.Unlock()
time.Sleep(2 * time.Second)
}
} */
func procDeathCallback(pid int, param interface{}) {
// log.Warning("XXX: IN CALLBACK for pid: ", pid, " / param = ", param)
if pid != 0 {
cache := param.(*dnsCache)
cache.lock.Lock()
delete(cache.ipMap, pid)
cache.lock.Unlock()
}
}
func (dc *dnsCache) processRecordA(name string, answers []dnsRR, pid int) {
dc.lock.Lock()
defer dc.lock.Unlock()
for _, rr := range answers {
@ -55,19 +116,78 @@ func (dc *dnsCache) processRecordA(name string, answers []dnsRR) {
if strings.HasSuffix(name, ".") {
name = name[:len(name)-1]
}
log.Notice("______ Adding to dns map: %s: %s", name, ip)
dc.ipMap[ip] = name
// Just in case.
if pid < 0 {
pid = 0
}
log.Noticef("______ Adding to dns map: %s: %s -> pid %d", name, ip, pid)
_, ok := dc.ipMap[pid]
if !ok {
dc.ipMap[pid] = make(map[string]dnsEntry)
}
dc.ipMap[pid][ip] = newDNSEntry(name, rr.Header().TTL)
if pid > 0 {
log.Warning("Adding process to be monitored by DNS cache: ", pid)
if !monitoring {
mlock.Lock()
if !monitoring {
monitoring = true
// go checker(dc)
go pcoroner.MonitorThread(procDeathCallback, dc)
}
mlock.Unlock()
}
pcoroner.MonitorProcess(pid)
}
if !FirewallConfig.LogRedact {
log.Infof("Adding %s: %s", name, ip)
}
case *dnsRR_CNAME:
// Not that exotic; just ignore it
default:
log.Warningf("Unexpected RR type in answer section of A response: %v", rec)
}
}
}
func (dc *dnsCache) Lookup(ip net.IP) string {
func (dc *dnsCache) Lookup(ip net.IP, pid int) string {
now := time.Now()
dc.lock.Lock()
defer dc.lock.Unlock()
return dc.ipMap[ip.String()]
// empty procinfo can set this to -1
if pid < 0 {
pid = 0
}
if pid > 0 {
entry, ok := dc.ipMap[pid][ip.String()]
if ok {
if now.Before(entry.exp) {
// log.Noticef("XXX: LOOKUP on %v / %v = %v, ttl = %v / %v\n", pid, ip.String(), entry.name, entry.ttl, entry.exp)
return entry.name
} else {
log.Warningf("Skipping expired per-pid (%d) DNS cache entry: %s -> %s / exp. %v (%ds)\n",
pid, ip.String(), entry.name, entry.exp, entry.ttl)
}
}
}
str := ""
entry, ok := dc.ipMap[0][ip.String()]
if ok {
if now.Before(entry.exp) {
str = entry.name
// log.Noticef("XXX: LOOKUP on %v / 0 RETURNING %v, ttl = %v / %v\n", ip.String(), str, entry.ttl, entry.exp)
} else {
log.Warningf("Skipping expired global DNS cache entry: %s -> %s / exp. %v (%ds)\n",
ip.String(), entry.name, entry.exp, entry.ttl)
}
}
//log.Noticef("XXX: LOOKUP on %v / 0 RETURNING %v\n", ip.String(), str)
return str
}

@ -10,7 +10,6 @@ import (
// nfnetlink "github.com/subgraph/go-nfnetlink"
nfqueue "github.com/subgraph/go-nfnetlink/nfqueue"
// "github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/subgraph/go-procsnitch"
"net"
@ -203,21 +202,12 @@ func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, pinfo *procsnitch.Info, o
dstb := pkt.Packet.NetworkLayer().NetworkFlow().Dst().Raw()
dstip := net.IP(dstb)
srcip := net.IP(pkt.Packet.NetworkLayer().NetworkFlow().Src().Raw())
// _, dstp := getPacketPorts(pkt)
name := p.fw.dns.Lookup(dstip)
name := p.fw.dns.Lookup(dstip, pinfo.Pid)
if !FirewallConfig.LogRedact {
log.Infof("Lookup(%s): %s", dstip.String(), name)
}
// fwo := matchAgainstOzRules(srcip, dstip, dstp)
if name == "" {
/* log.Notice("XXXXXXXXXXXXx trying better rev lookup:")
net.LookupAddr(dstip.String())
name = p.fw.dns.Lookup(dstip)
log.Notice("NOW ITS: ", name) */
}
//log.Notice("XXX: Attempting to filter packet on rules -> ", fwo, " / rev lookup = ", name)
result := p.rules.filterPacket(pkt, pinfo, srcip, name, optstr)
switch result {
case FILTER_DENY:
@ -384,12 +374,13 @@ func printPacket(pkt *nfqueue.NFQPacket, hostname string, pinfo *procsnitch.Info
}
func (fw *Firewall) filterPacket(pkt *nfqueue.NFQPacket) {
if pkt.Packet.Layer(layers.LayerTypeUDP) != nil {
isudp := pkt.Packet.Layer(layers.LayerTypeUDP) != nil
if isudp {
srcport, _ := getPacketUDPPorts(pkt)
if srcport == 53 {
fw.dns.processDNS(pkt)
pkt.Accept()
fw.dns.processDNS(pkt.Packet)
return
}
@ -412,13 +403,18 @@ func (fw *Firewall) filterPacket(pkt *nfqueue.NFQPacket) {
ppath := "*"
strictness := procsnitch.MATCH_STRICT
if isudp {
strictness = procsnitch.MATCH_LOOSE
}
pinfo, optstring := findProcessForPacket(pkt)
pinfo, optstring := findProcessForPacket(pkt, false, strictness)
if pinfo == nil {
pinfo = getEmptyPInfo()
ppath = "[unknown]"
optstring = "[Connection could not be mapped]"
log.Warningf("No proc found for %s", printPacket(pkt, fw.dns.Lookup(dstip), nil))
log.Warningf("No proc found for %s", printPacket(pkt, fw.dns.Lookup(dstip, pinfo.Pid), nil))
// pkt.Accept()
// return
} else {
@ -433,7 +429,7 @@ func (fw *Firewall) filterPacket(pkt *nfqueue.NFQPacket) {
}
}
}
log.Debugf("filterPacket [%s] %s", ppath, printPacket(pkt, fw.dns.Lookup(dstip), nil))
log.Debugf("filterPacket [%s] %s", ppath, printPacket(pkt, fw.dns.Lookup(dstip, pinfo.Pid), nil))
if basicAllowPacket(pkt) {
pkt.Accept()
//log.Notice("XXX: passed basicallowpacket")
@ -481,7 +477,7 @@ func getAllProcNetDataLocal() ([]string, error) {
for i := 0; i < len(OzInitPids); i++ {
fname := fmt.Sprintf("/proc/%d/net/tcp", OzInitPids[i])
fmt.Println("XXX: opening: ", fname)
//fmt.Println("XXX: opening: ", fname)
bdata, err := readFileDirect(fname)
if err != nil {
@ -528,13 +524,18 @@ func getRealRoot(pathname string, pid int) string {
return pathname
}
func findProcessForPacket(pkt *nfqueue.NFQPacket) (*procsnitch.Info, string) {
func findProcessForPacket(pkt *nfqueue.NFQPacket, reverse bool, strictness int) (*procsnitch.Info, string) {
srcip, dstip := getPacketIP4Addrs(pkt)
srcp, dstp := getPacketPorts(pkt)
proto := ""
optstr := ""
icode := -1
if reverse {
dstip, srcip = getPacketIP4Addrs(pkt)
dstp, srcp = getPacketPorts(pkt)
}
if pkt.Packet.Layer(layers.LayerTypeTCP) != nil {
proto = "tcp"
} else if pkt.Packet.Layer(layers.LayerTypeUDP) != nil {
@ -549,13 +550,15 @@ func findProcessForPacket(pkt *nfqueue.NFQPacket) (*procsnitch.Info, string) {
return nil, optstr
}
//log.Noticef("XXX proto = %s, from %v : %v -> %v : %v\n", proto, srcip, srcp, dstip, dstp)
var res *procsnitch.Info = nil
// Try normal way first, before the more resource intensive/invasive way.
if proto == "tcp" {
res = procsnitch.LookupTCPSocketProcessAll(srcip, srcp, dstip, dstp, nil)
} else if proto == "udp" {
res = procsnitch.LookupUDPSocketProcessAll(srcip, srcp, dstip, dstp, nil, true)
res = procsnitch.LookupUDPSocketProcessAll(srcip, srcp, dstip, dstp, nil, strictness)
} else if proto == "icmp" {
res = procsnitch.LookupICMPSocketProcessAll(srcip, dstip, icode, nil)
}
@ -566,7 +569,7 @@ func findProcessForPacket(pkt *nfqueue.NFQPacket) (*procsnitch.Info, string) {
for i := 0; i < len(OzInitPids); i++ {
data := ""
fname := fmt.Sprintf("/proc/%d/net/%s", OzInitPids[i].Pid, proto)
fmt.Println("XXX: opening: ", fname)
//fmt.Println("XXX: opening: ", fname)
bdata, err := readFileDirect(fname)
if err != nil {
@ -596,7 +599,7 @@ fmt.Println("XXX: opening: ", fname)
if proto == "tcp" {
res = procsnitch.LookupTCPSocketProcessAll(srcip, srcp, dstip, dstp, rlines)
} else if proto == "udp" {
res = procsnitch.LookupUDPSocketProcessAll(srcip, srcp, dstip, dstp, rlines, true)
res = procsnitch.LookupUDPSocketProcessAll(srcip, srcp, dstip, dstp, rlines, strictness)
} else if proto == "icmp" {
res = procsnitch.LookupICMPSocketProcessAll(srcip, dstip, icode, rlines)
}

@ -76,8 +76,8 @@ func LookupICMPSocketProcessAll(srcAddr net.IP, dstAddr net.IP, code int, custda
}
// LookupUDPSocketProcessAll searches for a UDP socket a given source port, destination IP, and destination port - AND source destination
func LookupUDPSocketProcessAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort uint16, custdata []string, loose bool) *Info {
ss := findUDPSocketAll(srcAddr, srcPort, dstAddr, dstPort, custdata, loose)
func LookupUDPSocketProcessAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort uint16, custdata []string, strictness int) *Info {
ss := findUDPSocketAll(srcAddr, srcPort, dstAddr, dstPort, custdata, strictness)
if ss == nil {
return nil
}

@ -31,6 +31,12 @@ type socketStatus struct {
type ConnectionStatus int
const (
MATCH_STRICT = iota
MATCH_LOOSE
MATCH_LOOSEST
)
const (
ESTABLISHED ConnectionStatus = iota
SYN_SENT
@ -90,30 +96,41 @@ func findICMPSocketAll(srcAddr net.IP, dstAddr net.IP, code int, custdata []stri
})
}
func findUDPSocketAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort uint16, custdata []string, loose bool) *socketStatus {
func findUDPSocketAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort uint16, custdata []string, strictness int) *socketStatus {
wildcard := net.IP{0,0,0,0}
if custdata == nil {
if !loose {
if strictness == MATCH_STRICT {
return findSocket("udp", func(ss socketStatus) bool {
// return ss.remote.port == dstPort && ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr)
return ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr)
})
} else if strictness == MATCH_LOOSE {
return findSocket("udp", func(ss socketStatus) bool {
return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(wildcard)) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) ||
(ss.local.ip.Equal(dstAddr) || ss.local.ip.Equal(wildcard)) && ss.remote.port == srcPort && ss.remote.ip.Equal(srcAddr)
})
}
return findSocket("udp", func(ss socketStatus) bool {
return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(wildcard)) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr)
return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(wildcard)) && ss.local.port == srcPort && (ss.local.ip.Equal(srcAddr) || ss.local.ip.Equal(wildcard)) ||
(ss.local.ip.Equal(dstAddr) || ss.local.ip.Equal(wildcard)) && ss.remote.port == srcPort && (ss.remote.ip.Equal(srcAddr) || ss.remote.ip.Equal(srcAddr))
})
}
if !loose {
if strictness == MATCH_STRICT {
return findSocketCustom("udp", custdata, func(ss socketStatus) bool {
return ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr)
})
} else if strictness == MATCH_LOOSE {
return findSocketCustom("udp", custdata, func(ss socketStatus) bool {
return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(wildcard)) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) ||
(ss.local.ip.Equal(dstAddr) || ss.local.ip.Equal(wildcard)) && ss.remote.port == srcPort && ss.remote.ip.Equal(srcAddr)
})
}
return findSocketCustom("udp", custdata, func(ss socketStatus) bool {
return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(wildcard)) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr)
return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(wildcard)) && ss.local.port == srcPort && (ss.local.ip.Equal(srcAddr) || ss.local.ip.Equal(wildcard)) ||
(ss.local.ip.Equal(dstAddr) || ss.local.ip.Equal(wildcard)) && ss.remote.port == srcPort && (ss.remote.ip.Equal(srcAddr) || ss.remote.ip.Equal(srcAddr))
})
}

Loading…
Cancel
Save