diff --git a/policy.go b/policy.go index f76f3b5..6ad7898 100644 --- a/policy.go +++ b/policy.go @@ -6,13 +6,57 @@ import ( "github.com/subgraph/fw-daemon/nfqueue" "github.com/subgraph/go-procsnitch" + "net" ) +type pendingConnection interface { + policy() *Policy + procInfo() *procsnitch.Info + hostname() string + dst() net.IP + dstPort() uint16 + accept() + drop() + print() string +} + type pendingPkt struct { - policy *Policy - hostname string - pkt *nfqueue.Packet - pinfo *procsnitch.Info + pol *Policy + name string + pkt *nfqueue.Packet + pinfo *procsnitch.Info +} + +func (pp *pendingPkt) policy() *Policy { + return pp.pol +} + +func (pp *pendingPkt) procInfo() *procsnitch.Info { + return pp.pinfo +} + +func (pp *pendingPkt) hostname() string { + return pp.name +} +func (pp *pendingPkt) dst() net.IP { + return pp.pkt.Dst +} + +func (pp *pendingPkt) dstPort() uint16 { + return pp.pkt.DstPort +} + +func (pp *pendingPkt) accept() { + pp.pkt.Accept() +} + +func (pp *pendingPkt) drop() { + pp.pkt.Mark = 1 + pp.pkt.Accept() +} + +func (pp *pendingPkt) print() string { + return printPacket(pp.pkt, pp.name) } type Policy struct { @@ -21,11 +65,18 @@ type Policy struct { application string icon string rules RuleList - pendingQueue []*pendingPkt + pendingQueue []pendingConnection promptInProgress bool lock sync.Mutex } +func (fw *Firewall) PolicyForPath(path string) *Policy { + fw.lock.Lock() + defer fw.lock.Unlock() + + return fw.policyForPath(path) +} + func (fw *Firewall) policyForPath(path string) *Policy { if _, ok := fw.policyMap[path]; !ok { p := new(Policy) @@ -50,7 +101,7 @@ func (p *Policy) processPacket(pkt *nfqueue.Packet, pinfo *procsnitch.Info) { if !logRedact { log.Info("Lookup(%s): %s", pkt.Dst.String(), name) } - result := p.rules.filter(pkt, pinfo, name) + result := p.rules.filterPacket(pkt, pinfo, name) switch result { case FILTER_DENY: pkt.Mark = 1 @@ -58,21 +109,21 @@ func (p *Policy) processPacket(pkt *nfqueue.Packet, pinfo *procsnitch.Info) { case FILTER_ALLOW: pkt.Accept() case FILTER_PROMPT: - p.processPromptResult(&pendingPkt{policy: p, hostname: name, pkt: pkt, pinfo: pinfo}) + p.processPromptResult(&pendingPkt{pol: p, name: name, pkt: pkt, pinfo: pinfo}) default: log.Warning("Unexpected filter result: %d", result) } } -func (p *Policy) processPromptResult(pp *pendingPkt) { - p.pendingQueue = append(p.pendingQueue, pp) +func (p *Policy) processPromptResult(pc pendingConnection) { + p.pendingQueue = append(p.pendingQueue, pc) if !p.promptInProgress { p.promptInProgress = true go p.fw.dbus.prompt(p) } } -func (p *Policy) nextPending() *pendingPkt { +func (p *Policy) nextPending() pendingConnection { p.lock.Lock() defer p.lock.Unlock() if len(p.pendingQueue) == 0 { @@ -81,14 +132,14 @@ func (p *Policy) nextPending() *pendingPkt { return p.pendingQueue[0] } -func (p *Policy) removePending(pp *pendingPkt) { +func (p *Policy) removePending(pc pendingConnection) { p.lock.Lock() defer p.lock.Unlock() - remaining := []*pendingPkt{} - for _, pkt := range p.pendingQueue { - if pkt != pp { - remaining = append(remaining, pkt) + remaining := []pendingConnection{} + for _, c := range p.pendingQueue { + if c != pc { + remaining = append(remaining, c) } } if len(remaining) != len(p.pendingQueue) { @@ -141,18 +192,17 @@ func (p *Policy) removeRule(r *Rule) { } func (p *Policy) filterPending(rule *Rule) { - remaining := []*pendingPkt{} - for _, pp := range p.pendingQueue { - if rule.match(pp.pkt, pp.hostname) { - log.Info("Also applying %s to %s", rule.getString(logRedact), printPacket(pp.pkt, pp.hostname)) + remaining := []pendingConnection{} + for _, pc := range p.pendingQueue { + if rule.match(pc.dst(), pc.dstPort(), pc.hostname()) { + log.Info("Also applying %s to %s", rule.getString(logRedact), pc.print()) if rule.rtype == RULE_ALLOW { - pp.pkt.Accept() + pc.accept() } else { - pp.pkt.Mark = 1 - pp.pkt.Accept() + pc.drop() } } else { - remaining = append(remaining, pp) + remaining = append(remaining, pc) } } if len(remaining) != len(p.pendingQueue) { @@ -208,9 +258,7 @@ func (fw *Firewall) filterPacket(pkt *nfqueue.Packet) { pkt.Accept() return } - fw.lock.Lock() - policy := fw.policyForPath(pinfo.ExePath) - fw.lock.Unlock() + policy := fw.PolicyForPath(pinfo.ExePath) policy.processPacket(pkt, pinfo) } diff --git a/prompt.go b/prompt.go index 6d52d9c..f13573b 100644 --- a/prompt.go +++ b/prompt.go @@ -53,13 +53,13 @@ func (p *prompter) promptLoop() { } func (p *prompter) processNextPacket() bool { - pp := p.nextPacket() - if pp == nil { + pc := p.nextConnection() + if pc == nil { return false } p.lock.Unlock() defer p.lock.Lock() - p.processPacket(pp) + p.processConnection(pc) return true } @@ -76,65 +76,64 @@ func printScope(scope int32) string { } } -func (p *prompter) processPacket(pp *pendingPkt) { +func (p *prompter) processConnection(pc pendingConnection) { var scope int32 var rule string - addr := pp.hostname + addr := pc.hostname() if addr == "" { - addr = pp.pkt.Dst.String() + addr = pc.dst().String() } + policy := pc.policy() call := p.dbusObj.Call("com.subgraph.FirewallPrompt.RequestPrompt", 0, - pp.policy.application, - pp.policy.icon, - pp.policy.path, + policy.application, + policy.icon, + policy.path, addr, - int32(pp.pkt.DstPort), - pp.pkt.Dst.String(), - uidToUser(pp.pinfo.UID), - int32(pp.pinfo.Pid)) + int32(pc.dstPort()), + pc.dst().String(), + uidToUser(pc.procInfo().UID), + int32(pc.procInfo().Pid)) err := call.Store(&scope, &rule) if err != nil { log.Warning("Error sending dbus RequestPrompt message: %v", err) - pp.policy.removePending(pp) - pp.pkt.Mark = 1 - pp.pkt.Accept() + policy.removePending(pc) + pc.drop() return } - r, err := pp.policy.parseRule(rule, false) + r, err := policy.parseRule(rule, false) if err != nil { log.Warning("Error parsing rule string returned from dbus RequestPrompt: %v", err) - pp.policy.removePending(pp) - pp.pkt.Mark = 1 - pp.pkt.Accept() + policy.removePending(pc) + pc.drop() return } if scope == APPLY_SESSION { r.sessionOnly = true } - if !pp.policy.processNewRule(r, scope) { + if !policy.processNewRule(r, scope) { p.lock.Lock() defer p.lock.Unlock() - p.removePolicy(pp.policy) + p.removePolicy(pc.policy()) } if scope == APPLY_FOREVER { - pp.policy.fw.saveRules() + policy.fw.saveRules() } } -func (p *prompter) nextPacket() *pendingPkt { +func (p *prompter) nextConnection() pendingConnection { for { if len(p.policyQueue) == 0 { return nil } policy := p.policyQueue[0] - pp := policy.nextPending() - if pp == nil { + pc := policy.nextPending() + if pc == nil { p.removePolicy(policy) } else { - return pp + return pc } } }