Refactor to support more than one type of 'connection'

Created and interface pendingConnection which represents either a
pending packet or socks connection.
socks-filter
Bruce Leidl 9 years ago
parent bb71d8309d
commit 68218b4e83

@ -6,26 +6,77 @@ import (
"github.com/subgraph/fw-daemon/nfqueue" "github.com/subgraph/fw-daemon/nfqueue"
"github.com/subgraph/go-procsnitch" "github.com/subgraph/go-procsnitch"
"net"
) )
type pendingConnection interface {
policy() *Policy
procInfo() *procsnitch.Info
hostname() string
dst() net.IP
dstPort() uint16
accept()
drop()
print() string
}
type pendingPkt struct { type pendingPkt struct {
policy *Policy pol *Policy
hostname string name string
pkt *nfqueue.Packet pkt *nfqueue.Packet
pinfo *procsnitch.Info pinfo *procsnitch.Info
} }
func (pp *pendingPkt) policy() *Policy {
return pp.pol
}
func (pp *pendingPkt) procInfo() *procsnitch.Info {
return pp.pinfo
}
func (pp *pendingPkt) hostname() string {
return pp.name
}
func (pp *pendingPkt) dst() net.IP {
return pp.pkt.Dst
}
func (pp *pendingPkt) dstPort() uint16 {
return pp.pkt.DstPort
}
func (pp *pendingPkt) accept() {
pp.pkt.Accept()
}
func (pp *pendingPkt) drop() {
pp.pkt.Mark = 1
pp.pkt.Accept()
}
func (pp *pendingPkt) print() string {
return printPacket(pp.pkt, pp.name)
}
type Policy struct { type Policy struct {
fw *Firewall fw *Firewall
path string path string
application string application string
icon string icon string
rules RuleList rules RuleList
pendingQueue []*pendingPkt pendingQueue []pendingConnection
promptInProgress bool promptInProgress bool
lock sync.Mutex lock sync.Mutex
} }
func (fw *Firewall) PolicyForPath(path string) *Policy {
fw.lock.Lock()
defer fw.lock.Unlock()
return fw.policyForPath(path)
}
func (fw *Firewall) policyForPath(path string) *Policy { func (fw *Firewall) policyForPath(path string) *Policy {
if _, ok := fw.policyMap[path]; !ok { if _, ok := fw.policyMap[path]; !ok {
p := new(Policy) p := new(Policy)
@ -50,7 +101,7 @@ func (p *Policy) processPacket(pkt *nfqueue.Packet, pinfo *procsnitch.Info) {
if !logRedact { if !logRedact {
log.Info("Lookup(%s): %s", pkt.Dst.String(), name) log.Info("Lookup(%s): %s", pkt.Dst.String(), name)
} }
result := p.rules.filter(pkt, pinfo, name) result := p.rules.filterPacket(pkt, pinfo, name)
switch result { switch result {
case FILTER_DENY: case FILTER_DENY:
pkt.Mark = 1 pkt.Mark = 1
@ -58,21 +109,21 @@ func (p *Policy) processPacket(pkt *nfqueue.Packet, pinfo *procsnitch.Info) {
case FILTER_ALLOW: case FILTER_ALLOW:
pkt.Accept() pkt.Accept()
case FILTER_PROMPT: case FILTER_PROMPT:
p.processPromptResult(&pendingPkt{policy: p, hostname: name, pkt: pkt, pinfo: pinfo}) p.processPromptResult(&pendingPkt{pol: p, name: name, pkt: pkt, pinfo: pinfo})
default: default:
log.Warning("Unexpected filter result: %d", result) log.Warning("Unexpected filter result: %d", result)
} }
} }
func (p *Policy) processPromptResult(pp *pendingPkt) { func (p *Policy) processPromptResult(pc pendingConnection) {
p.pendingQueue = append(p.pendingQueue, pp) p.pendingQueue = append(p.pendingQueue, pc)
if !p.promptInProgress { if !p.promptInProgress {
p.promptInProgress = true p.promptInProgress = true
go p.fw.dbus.prompt(p) go p.fw.dbus.prompt(p)
} }
} }
func (p *Policy) nextPending() *pendingPkt { func (p *Policy) nextPending() pendingConnection {
p.lock.Lock() p.lock.Lock()
defer p.lock.Unlock() defer p.lock.Unlock()
if len(p.pendingQueue) == 0 { if len(p.pendingQueue) == 0 {
@ -81,14 +132,14 @@ func (p *Policy) nextPending() *pendingPkt {
return p.pendingQueue[0] return p.pendingQueue[0]
} }
func (p *Policy) removePending(pp *pendingPkt) { func (p *Policy) removePending(pc pendingConnection) {
p.lock.Lock() p.lock.Lock()
defer p.lock.Unlock() defer p.lock.Unlock()
remaining := []*pendingPkt{} remaining := []pendingConnection{}
for _, pkt := range p.pendingQueue { for _, c := range p.pendingQueue {
if pkt != pp { if c != pc {
remaining = append(remaining, pkt) remaining = append(remaining, c)
} }
} }
if len(remaining) != len(p.pendingQueue) { if len(remaining) != len(p.pendingQueue) {
@ -141,18 +192,17 @@ func (p *Policy) removeRule(r *Rule) {
} }
func (p *Policy) filterPending(rule *Rule) { func (p *Policy) filterPending(rule *Rule) {
remaining := []*pendingPkt{} remaining := []pendingConnection{}
for _, pp := range p.pendingQueue { for _, pc := range p.pendingQueue {
if rule.match(pp.pkt, pp.hostname) { if rule.match(pc.dst(), pc.dstPort(), pc.hostname()) {
log.Info("Also applying %s to %s", rule.getString(logRedact), printPacket(pp.pkt, pp.hostname)) log.Info("Also applying %s to %s", rule.getString(logRedact), pc.print())
if rule.rtype == RULE_ALLOW { if rule.rtype == RULE_ALLOW {
pp.pkt.Accept() pc.accept()
} else { } else {
pp.pkt.Mark = 1 pc.drop()
pp.pkt.Accept()
} }
} else { } else {
remaining = append(remaining, pp) remaining = append(remaining, pc)
} }
} }
if len(remaining) != len(p.pendingQueue) { if len(remaining) != len(p.pendingQueue) {
@ -208,9 +258,7 @@ func (fw *Firewall) filterPacket(pkt *nfqueue.Packet) {
pkt.Accept() pkt.Accept()
return return
} }
fw.lock.Lock() policy := fw.PolicyForPath(pinfo.ExePath)
policy := fw.policyForPath(pinfo.ExePath)
fw.lock.Unlock()
policy.processPacket(pkt, pinfo) policy.processPacket(pkt, pinfo)
} }

@ -53,13 +53,13 @@ func (p *prompter) promptLoop() {
} }
func (p *prompter) processNextPacket() bool { func (p *prompter) processNextPacket() bool {
pp := p.nextPacket() pc := p.nextConnection()
if pp == nil { if pc == nil {
return false return false
} }
p.lock.Unlock() p.lock.Unlock()
defer p.lock.Lock() defer p.lock.Lock()
p.processPacket(pp) p.processConnection(pc)
return true return true
} }
@ -76,65 +76,64 @@ func printScope(scope int32) string {
} }
} }
func (p *prompter) processPacket(pp *pendingPkt) { func (p *prompter) processConnection(pc pendingConnection) {
var scope int32 var scope int32
var rule string var rule string
addr := pp.hostname addr := pc.hostname()
if addr == "" { if addr == "" {
addr = pp.pkt.Dst.String() addr = pc.dst().String()
} }
policy := pc.policy()
call := p.dbusObj.Call("com.subgraph.FirewallPrompt.RequestPrompt", 0, call := p.dbusObj.Call("com.subgraph.FirewallPrompt.RequestPrompt", 0,
pp.policy.application, policy.application,
pp.policy.icon, policy.icon,
pp.policy.path, policy.path,
addr, addr,
int32(pp.pkt.DstPort), int32(pc.dstPort()),
pp.pkt.Dst.String(), pc.dst().String(),
uidToUser(pp.pinfo.UID), uidToUser(pc.procInfo().UID),
int32(pp.pinfo.Pid)) int32(pc.procInfo().Pid))
err := call.Store(&scope, &rule) err := call.Store(&scope, &rule)
if err != nil { if err != nil {
log.Warning("Error sending dbus RequestPrompt message: %v", err) log.Warning("Error sending dbus RequestPrompt message: %v", err)
pp.policy.removePending(pp) policy.removePending(pc)
pp.pkt.Mark = 1 pc.drop()
pp.pkt.Accept()
return return
} }
r, err := pp.policy.parseRule(rule, false) r, err := policy.parseRule(rule, false)
if err != nil { if err != nil {
log.Warning("Error parsing rule string returned from dbus RequestPrompt: %v", err) log.Warning("Error parsing rule string returned from dbus RequestPrompt: %v", err)
pp.policy.removePending(pp) policy.removePending(pc)
pp.pkt.Mark = 1 pc.drop()
pp.pkt.Accept()
return return
} }
if scope == APPLY_SESSION { if scope == APPLY_SESSION {
r.sessionOnly = true r.sessionOnly = true
} }
if !pp.policy.processNewRule(r, scope) { if !policy.processNewRule(r, scope) {
p.lock.Lock() p.lock.Lock()
defer p.lock.Unlock() defer p.lock.Unlock()
p.removePolicy(pp.policy) p.removePolicy(pc.policy())
} }
if scope == APPLY_FOREVER { if scope == APPLY_FOREVER {
pp.policy.fw.saveRules() policy.fw.saveRules()
} }
} }
func (p *prompter) nextPacket() *pendingPkt { func (p *prompter) nextConnection() pendingConnection {
for { for {
if len(p.policyQueue) == 0 { if len(p.policyQueue) == 0 {
return nil return nil
} }
policy := p.policyQueue[0] policy := p.policyQueue[0]
pp := policy.nextPending() pc := policy.nextPending()
if pp == nil { if pc == nil {
p.removePolicy(policy) p.removePolicy(policy)
} else { } else {
return pp return pc
} }
} }
} }

Loading…
Cancel
Save