You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
fw-daemon/sgfw/dns.go

227 lines
5.6 KiB

package sgfw
import (
"encoding/binary"
"fmt"
"net"
"strings"
"sync"
"time"
// "github.com/subgraph/go-nfnetlink"
"github.com/google/gopacket/layers"
"github.com/subgraph/fw-daemon/proc-coroner"
nfqueue "github.com/subgraph/go-nfnetlink/nfqueue"
"github.com/subgraph/go-procsnitch"
)
// dnsEntry is one cached name resolution together with its expiry data.
type dnsEntry struct {
	// name is the hostname stored for an address.
	name string
	// ttl is the TTL the entry was created with; newDNSEntry interprets it
	// as a number of seconds.
	ttl uint32
	// exp is the instant the entry stops being valid; Lookup only returns
	// entries whose exp is still in the future.
	exp time.Time
}
// dnsCache stores DNS answers per process so that an IP address can later be
// mapped back to the hostname it was resolved from.
type dnsCache struct {
	// ipMap is keyed first by pid and then by the textual IP address.
	// Key 0 is the slot created by newDNSCache; Lookup falls back to it
	// when no valid per-pid entry exists.
	ipMap map[int]map[string]dnsEntry
	// lock is held by the methods in this file whenever ipMap is read or
	// modified.
	lock sync.Mutex
	// done is created by newDNSCache.
	// NOTE(review): no use of this channel is visible in this part of the
	// file — presumably a shutdown signal; confirm against its consumers.
	done chan struct{}
}
// newDNSEntry builds a cache entry for hostname that expires ttl seconds
// after the moment of the call.
func newDNSEntry(hostname string, ttl uint32) dnsEntry {
	lifetime := time.Duration(ttl) * time.Second
	return dnsEntry{
		name: hostname,
		ttl:  ttl,
		exp:  time.Now().Add(lifetime),
	}
}
// newDNSCache returns an empty cache whose pid-0 slot already exists, so
// that slot can be read and written without a prior existence check.
func newDNSCache() *dnsCache {
	return &dnsCache{
		ipMap: map[int]map[string]dnsEntry{
			0: {}, // empty, non-nil map for the pid-0 slot
		},
		done: make(chan struct{}),
	}
}
// isNSTrusted reports whether src is a trusted name-server address. Only
// loopback addresses are trusted.
func isNSTrusted(src net.IP) bool {
	trusted := src.IsLoopback()
	return trusted
}
// processDNS examines a queued packet that carries a DNS message and, when
// it is a response to a single A or AAAA question, hands the answer records
// to processRecordAddress so they are cached for the owning process.
//
// Packets are ignored (with a log line where noted) when:
//   - the packet has no decoded DNS layer or the message fails to unpack;
//   - the message is not a response;
//   - the question section does not contain exactly one question;
//   - the question type is neither A nor AAAA;
//   - the source is not trusted (see isNSTrusted) and no process can be
//     found for the packet.
//
// For a trusted source the pid of getEmptyPInfo() is used; otherwise the pid
// of the process found for the packet.
func (dc *dnsCache) processDNS(pkt *nfqueue.NFQPacket) {
	// Layer returns nil when the packet holds no layer of the requested
	// type; calling LayerContents on that nil value would panic.
	dnsLayer := pkt.Packet.Layer(layers.LayerTypeDNS)
	if dnsLayer == nil {
		log.Warning("Packet has no DNS layer; cannot unpack DNS message")
		return
	}

	dns := &dnsMsg{}
	if !dns.Unpack(dnsLayer.LayerContents()) {
		log.Warning("Failed to Unpack DNS message")
		return
	}
	if !dns.response {
		return
	}
	if len(dns.question) != 1 {
		log.Warningf("Length of DNS Question section is not 1 as expected: %d", len(dns.question))
		return
	}

	q := dns.question[0]
	if q.Qtype != dnsTypeA && q.Qtype != dnsTypeAAAA {
		if !FirewallConfig.LogRedact {
			log.Infof("Unhandled DNS message: %v", dns)
		} else {
			log.Infof("Unhandled DNS message: %s", STR_REDACTED)
			dbLogger.logRedacted("default", fmt.Sprintf("Unhandled DNS message: %v", dns))
		}
		return
	}

	// Only the first address returned is needed here.
	srcip, _ := getPacketIPAddrs(pkt)
	pinfo := getEmptyPInfo()
	if !isNSTrusted(srcip) {
		// A nil pinfo is the failure signal used below, so the second
		// return value of findProcessForPacket is deliberately ignored.
		pinfo, _ = findProcessForPacket(pkt, true, procsnitch.MATCH_LOOSEST)
		if pinfo == nil {
			// Retry with the boolean argument flipped. The result must be
			// assigned: previously it was discarded, so this fallback could
			// never succeed.
			pinfo, _ = findProcessForPacket(pkt, false, procsnitch.MATCH_LOOSEST)
		}
		if pinfo == nil {
			if !FirewallConfig.LogRedact {
				log.Warningf("Skipping attempted DNS cache entry for process that can't be found: %v -> %v\n", q.Name, dns.answer)
			} else {
				dbLogger.logRedacted("default", fmt.Sprintf("Skipping attempted DNS cache entry for process that can't be found: %v -> %v\n", q.Name, dns.answer))
			}
			return
		}
		log.Warningf("%v", pinfo)
	}

	//log.Notice("XXX: PROCESS LOOKUP -> ", pinfo)
	dc.processRecordAddress(q.Name, dns.answer, pinfo.Pid)
}
/*func checker(c *dnsCache) {
for {
log.Error("CACHE CHECKER")
c.lock.Lock()
for k, v := range c.ipMap {
log.Errorf("IN CACHE: %v -> %v\n", k, v)
}
c.lock.Unlock()
time.Sleep(2 * time.Second)
}
} */
// procDeathCallbackDNS drops every cached entry belonging to pid. It has the
// shape of a process-death callback: param must hold the *dnsCache to clean
// up (any other dynamic type makes the type assertion panic).
//
// A pid of 0 is ignored, so the pid-0 slot of the cache is never removed
// here.
func procDeathCallbackDNS(pid int, param interface{}) {
	if pid == 0 {
		return
	}

	dc := param.(*dnsCache)
	dc.lock.Lock()
	defer dc.lock.Unlock()
	delete(dc.ipMap, pid)
}
// processRecordAddress stores every A and AAAA record in answers under pid,
// mapping the record's textual IP address to name with the record's TTL.
//
//   - name has one trailing "." removed before it is stored.
//   - A negative pid is treated as 0.
//   - CNAME records are skipped silently; any other record type is skipped
//     with a warning.
//   - For pid > 0, the process is handed to pcoroner.MonitorProcess each
//     time a record is stored.
//
// The cache lock is held for the whole call.
func (dc *dnsCache) processRecordAddress(name string, answers []dnsRR, pid int) {
	// Normalise once, up front. Doing this inside the loop re-trimmed name
	// on every record, so a name with several trailing dots was stored
	// differently for each answer.
	name = strings.TrimSuffix(name, ".")
	// Just in case.
	if pid < 0 {
		pid = 0
	}

	dc.lock.Lock()
	defer dc.lock.Unlock()

	for _, rr := range answers {
		var aBytes []byte
		switch rec := rr.(type) {
		case *dnsRR_A:
			// The A record holds the address as a uint32; lay it out in
			// network byte order to obtain the 4-byte IP.
			var ipA [4]byte
			aBytes = ipA[:]
			binary.BigEndian.PutUint32(aBytes, rec.A)
		case *dnsRR_AAAA:
			aBytes = rec.AAAA[:]
		case *dnsRR_CNAME:
			// Not that exotic; just ignore it
		default:
			if !FirewallConfig.LogRedact {
				log.Warningf("Unexpected RR type in answer section of A response: %v", rec)
			} else {
				log.Warningf("Unexpected RR type in answer section of A response: [redacted]")
				dbLogger.logRedacted("default", fmt.Sprintf("Unexpected RR type in answer section of A response: %v", rec))
			}
		}
		if aBytes == nil {
			continue
		}

		ip := net.IP(aBytes).String()

		// log.Noticef("______ Adding to dns map: %s: %s -> pid %d", name, ip, pid)
		// A nil check also covers a key that is present with a nil map,
		// which would otherwise panic on the write below.
		if dc.ipMap[pid] == nil {
			dc.ipMap[pid] = make(map[string]dnsEntry)
		}
		dc.ipMap[pid][ip] = newDNSEntry(name, rr.Header().TTL)

		if pid > 0 {
			log.Warning("Adding process to be monitored by DNS cache: ", pid)
			pcoroner.MonitorProcess(pid)
		}

		if !FirewallConfig.LogRedact {
			log.Infof("Adding %s: %s", name, ip)
		} else {
			dbLogger.logRedacted("default", fmt.Sprintf("Adding %s: %s", name, ip))
		}
	}
}
// Lookup returns the hostname cached for ip, or "" when there is no
// unexpired entry.
//
// When pid is positive the entry stored for that pid is consulted first and
// returned if it has not expired. Otherwise — including for pid <= 0, which
// an empty process info can produce (-1) — the pid-0 slot is consulted.
// Expired entries are reported in the log and skipped; they are not removed.
func (dc *dnsCache) Lookup(ip net.IP, pid int) string {
	const (
		expiredPidFmt    = "Skipping expired per-pid (%d) DNS cache entry: %s -> %s / exp. %v (%ds)\n"
		expiredGlobalFmt = "Skipping expired global DNS cache entry: %s -> %s / exp. %v (%ds)\n"
	)

	now := time.Now()
	addr := ip.String()

	dc.lock.Lock()
	defer dc.lock.Unlock()

	if pid > 0 {
		if hit, found := dc.ipMap[pid][addr]; found {
			if now.Before(hit.exp) {
				// log.Noticef("XXX: LOOKUP on %v / %v = %v, ttl = %v / %v\n", pid, addr, hit.name, hit.ttl, hit.exp)
				return hit.name
			}
			if FirewallConfig.LogRedact {
				dbLogger.logRedacted("default", fmt.Sprintf(expiredPidFmt,
					pid, addr, hit.name, hit.exp, hit.ttl))
			} else {
				log.Warningf(expiredPidFmt,
					pid, addr, hit.name, hit.exp, hit.ttl)
			}
		}
	}

	// Fall back to the pid-0 slot.
	hit, found := dc.ipMap[0][addr]
	switch {
	case !found:
		return ""
	case now.Before(hit.exp):
		// log.Noticef("XXX: LOOKUP on %v / 0 RETURNING %v, ttl = %v / %v\n", addr, hit.name, hit.ttl, hit.exp)
		return hit.name
	}

	if FirewallConfig.LogRedact {
		dbLogger.logRedacted("default", fmt.Sprintf(expiredGlobalFmt,
			addr, hit.name, hit.exp, hit.ttl))
	} else {
		log.Warningf(expiredGlobalFmt,
			addr, hit.name, hit.exp, hit.ttl)
	}
	//log.Noticef("XXX: LOOKUP on %v / 0 RETURNING %v\n", addr, "")
	return ""
}