Various code cleanups (still buggy/WIP).

Fixed lock/race condition in fw-prompt; consolidated redundant rule action code.
Started fuller TLS implementation in TLSGuard; probably broke a lot of stuff in the process.
Removal/reorganization of old/stale/unused code.
shw_dev
Stephen Watt 7 years ago
parent 0d13c7bb9c
commit 0bda150abc

20
TODO

@ -1,6 +1,26 @@
fw-daemon:
pc.socks() an getOptString() return overlapping information
remove all stale references to SANDBOX: rules/policyForPathAndSandbox()
fw-prompt: fw-prompt:
scope returned by new rules is bad (always set to process)
prompter should have a timestamp field
Iteration through fw-prompt choices can't brute force by index #
This function needs to be updated because it no longer works: This function needs to be updated because it no longer works:
func toggleHover() { mainWin.SetKeepAbove(len(decisionWaiters) > 0) } func toggleHover() { mainWin.SetKeepAbove(len(decisionWaiters) > 0) }
Each duplicate prompt needs to be expandable into individual items
gnome-shell:
Start using new async DBus methods
new go-procsnitch vendor package changes should be pushed into main project new go-procsnitch vendor package changes should be pushed into main project

@ -734,7 +734,8 @@ func toggleValidRuleState() {
btnApprove.SetSensitive(ok) btnApprove.SetSensitive(ok)
btnDeny.SetSensitive(ok) btnDeny.SetSensitive(ok)
btnIgnore.SetSensitive(ok) // btnIgnore.SetSensitive(ok)
btnIgnore.SetSensitive(false)
} }
func createCurrentRule() (ruleColumns, error) { func createCurrentRule() (ruleColumns, error) {
@ -1023,6 +1024,44 @@ func addPendingPrompts(rules []string) {
} }
func buttonAction(action string) {
globalPromptLock.Lock()
rule, idx, err := getSelectedRule()
if err != nil {
globalPromptLock.Unlock()
promptError("Error occurred processing request: " + err.Error())
return
}
rule, err = createCurrentRule()
if err != nil {
globalPromptLock.Unlock()
promptError("Error occurred constructing new rule: " + err.Error())
return
}
fmt.Println("rule = ", rule)
rulestr := action
if action == "ALLOW" && rule.ForceTLS {
rulestr += "_TLSONLY"
}
rulestr += "|" + rule.Proto + ":" + rule.Target + ":" + strconv.Itoa(rule.Port)
rulestr += "|" + sgfw.RuleModeString[sgfw.RuleMode(rule.Scope)]
fmt.Println("RULESTR = ", rulestr)
makeDecision(idx, rulestr, int(rule.Scope))
fmt.Println("Decision made.")
err = removeSelectedRule(idx, true)
globalPromptLock.Unlock()
if err == nil {
clearEditor()
} else {
promptError("Error setting new rule: " + err.Error())
}
}
func main() { func main() {
decisionWaiters = make([]*decisionWaiter, 0) decisionWaiters = make([]*decisionWaiter, 0)
_, err := newDbusServer() _, err := newDbusServer()
@ -1233,90 +1272,12 @@ func main() {
tv.SetModel(listStore) tv.SetModel(listStore)
btnApprove.Connect("clicked", func() { btnApprove.Connect("clicked", func() {
globalPromptLock.Lock() buttonAction("ALLOW")
rule, idx, err := getSelectedRule()
if err != nil {
globalPromptLock.Unlock()
promptError("Error occurred processing request: " + err.Error())
return
}
rule, err = createCurrentRule()
if err != nil {
globalPromptLock.Unlock()
promptError("Error occurred constructing new rule: " + err.Error())
return
}
fmt.Println("rule = ", rule)
rulestr := "ALLOW"
if rule.ForceTLS {
rulestr += "_TLSONLY"
}
rulestr += "|" + rule.Proto + ":" + rule.Target + ":" + strconv.Itoa(rule.Port)
fmt.Println("RULESTR = ", rulestr)
makeDecision(idx, rulestr, int(rule.Scope))
fmt.Println("Decision made.")
globalPromptLock.Unlock()
err = removeSelectedRule(idx, true)
if err == nil {
clearEditor()
} else {
promptError("Error setting new rule: " + err.Error())
}
}) })
btnDeny.Connect("clicked", func() { btnDeny.Connect("clicked", func() {
globalPromptLock.Lock() buttonAction("DENY")
rule, idx, err := getSelectedRule()
if err != nil {
globalPromptLock.Unlock()
promptError("Error occurred processing request: " + err.Error())
return
}
rule, err = createCurrentRule()
if err != nil {
globalPromptLock.Unlock()
promptError("Error occurred constructing new rule: " + err.Error())
return
}
fmt.Println("rule = ", rule)
rulestr := "DENY|" + rule.Proto + ":" + rule.Target + ":" + strconv.Itoa(rule.Port)
fmt.Println("RULESTR = ", rulestr)
makeDecision(idx, rulestr, int(rule.Scope))
fmt.Println("Decision made.")
globalPromptLock.Unlock()
err = removeSelectedRule(idx, true)
if err == nil {
clearEditor()
} else {
promptError("Error setting new rule: " + err.Error())
}
})
btnIgnore.Connect("clicked", func() {
globalPromptLock.Lock()
_, idx, err := getSelectedRule()
if err != nil {
globalPromptLock.Unlock()
promptError("Error occurred processing request: " + err.Error())
return
}
makeDecision(idx, "", 0)
fmt.Println("Decision made.")
globalPromptLock.Unlock()
err = removeSelectedRule(idx, true)
if err == nil {
clearEditor()
} else {
promptError("Error setting new rule: " + err.Error())
}
}) })
// btnIgnore.Connect("clicked", buttonAction)
// tv.SetActivateOnSingleClick(true) // tv.SetActivateOnSingleClick(true)
tv.Connect("row-activated", func() { tv.Connect("row-activated", func() {

@ -252,7 +252,7 @@ func (ds *dbusServer) GetPendingRequests(policy string) ([]string, *dbus.Error)
} }
func (ds *dbusServer) AddRuleAsync(scope uint32, rule string, policy string) (bool, *dbus.Error) { func (ds *dbusServer) AddRuleAsync(scope uint32, rule string, policy string) (bool, *dbus.Error) {
log.Debugf("AddRuleAsync %v, %v / %v\n", scope, rule, policy) log.Warningf("AddRuleAsync %v, %v / %v\n", scope, rule, policy)
ds.fw.lock.Lock() ds.fw.lock.Lock()
defer ds.fw.lock.Unlock() defer ds.fw.lock.Unlock()
@ -335,11 +335,6 @@ func (ds *dbusServer) SetConfig(key string, val dbus.Variant) *dbus.Error {
return nil return nil
} }
/*func (ds *dbusServer) prompt(p *Policy) {
log.Info("prompting...")
ds.prompter.prompt(p)
} */
func (ob *dbusObjectP) alertRule(data string) { func (ob *dbusObjectP) alertRule(data string) {
ob.Call("com.subgraph.fwprompt.EventNotifier.Alert", 0, data) ob.Call("com.subgraph.fwprompt.EventNotifier.Alert", 0, data)
} }

@ -283,6 +283,7 @@ func (fw *Firewall) policyForPath(path string) *Policy {
func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, pinfo *procsnitch.Info, optstr string) { func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, pinfo *procsnitch.Info, optstr string) {
fmt.Println("policy processPacket()")
/* hbytes, err := pkt.GetHWAddr() /* hbytes, err := pkt.GetHWAddr()
if err != nil { if err != nil {
log.Notice("Failed to get HW address underlying packet: ", err) log.Notice("Failed to get HW address underlying packet: ", err)
@ -292,6 +293,17 @@ func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, pinfo *procsnitch.Info, o
dstb := pkt.Packet.NetworkLayer().NetworkFlow().Dst().Raw() dstb := pkt.Packet.NetworkLayer().NetworkFlow().Dst().Raw()
dstip := net.IP(dstb) dstip := net.IP(dstb)
srcip := net.IP(pkt.Packet.NetworkLayer().NetworkFlow().Src().Raw()) srcip := net.IP(pkt.Packet.NetworkLayer().NetworkFlow().Src().Raw())
/* Can we pass this through quickly? */
/* this probably isn't a performance enhancement. */
/*_, dstp := getPacketPorts(pkt)
fres := p.rules.filter(pkt, srcip, dstip, dstp, dstip.String(), pinfo, optstr)
if fres == FILTER_ALLOW {
fmt.Printf("Packet passed wildcard rules without requiring DNS lookup; accepting: %s:%d\n", dstip, dstp)
pkt.Accept()
return
}*/
name := p.fw.dns.Lookup(dstip, pinfo.Pid) name := p.fw.dns.Lookup(dstip, pinfo.Pid)
log.Infof("Lookup(%s): %s", dstip.String(), name) log.Infof("Lookup(%s): %s", dstip.String(), name)
@ -333,7 +345,7 @@ func (p *Policy) nextPending() (pendingConnection, bool) {
} }
for i := 0; i < len(p.pendingQueue); i++ { for i := 0; i < len(p.pendingQueue); i++ {
fmt.Printf("pendingQueue %v of %v: %v\n", i, len(p.pendingQueue), p.pendingQueue[i]) // fmt.Printf("XXX: pendingQueue %v of %v: %v\n", i, len(p.pendingQueue), p.pendingQueue[i])
if !p.pendingQueue[i].getPrompting() { if !p.pendingQueue[i].getPrompting() {
return p.pendingQueue[i], false return p.pendingQueue[i], false
} }
@ -488,6 +500,7 @@ func printPacket(pkt *nfqueue.NFQPacket, hostname string, pinfo *procsnitch.Info
} }
func (fw *Firewall) filterPacket(pkt *nfqueue.NFQPacket) { func (fw *Firewall) filterPacket(pkt *nfqueue.NFQPacket) {
fmt.Println("firewall: filterPacket()")
isudp := pkt.Packet.Layer(layers.LayerTypeUDP) != nil isudp := pkt.Packet.Layer(layers.LayerTypeUDP) != nil
if basicAllowPacket(pkt) { if basicAllowPacket(pkt) {

@ -2,7 +2,6 @@ package sgfw
import ( import (
"fmt" "fmt"
"net"
"os" "os"
"os/user" "os/user"
"strconv" "strconv"
@ -49,36 +48,20 @@ func (p *prompter) prompt(policy *Policy) {
func (p *prompter) promptLoop() { func (p *prompter) promptLoop() {
// p.lock.Lock() // p.lock.Lock()
for { for {
// fmt.Println("XXX: promptLoop() outer") p.processNextPacket()
p.lock.Lock()
for p.processNextPacket() {
// fmt.Println("XXX: promptLoop() inner")
}
p.lock.Unlock()
// fmt.Println("promptLoop() wait")
// p.cond.Wait()
} }
} }
func (p *prompter) processNextPacket() bool { func (p *prompter) processNextPacket() bool {
//fmt.Println("processNextPacket()")
var pc pendingConnection = nil var pc pendingConnection = nil
/* if 1 == 2 {
// if !DoMultiPrompt {
pc, _ = p.nextConnection()
if pc == nil {
return false
}
p.lock.Unlock()
defer p.lock.Lock()
p.processConnection(pc)
return true
} */
empty := true empty := true
for { for {
p.lock.Lock()
pc, empty = p.nextConnection() pc, empty = p.nextConnection()
fmt.Println("XXX: processNextPacket() loop; empty = ", empty, " / pc = ", pc) p.lock.Unlock()
//fmt.Println("XXX: processNextPacket() loop; empty = ", empty, " / pc = ", pc)
if pc == nil && empty { if pc == nil && empty {
return false return false
} else if pc == nil { } else if pc == nil {
@ -87,14 +70,14 @@ func (p *prompter) processNextPacket() bool {
break break
} }
} }
p.lock.Unlock()
defer p.lock.Lock()
// fmt.Println("XXX: Waiting for prompt lock go...")
if pc.getPrompting() { if pc.getPrompting() {
log.Debugf("Skipping over already prompted connection") log.Debugf("Skipping over already prompted connection")
return false
} }
pc.setPrompting(true) pc.setPrompting(true)
fmt.Println("processConnection")
go p.processConnection(pc) go p.processConnection(pc)
return true return true
} }
@ -223,6 +206,8 @@ func (p *prompter) processConnection(pc pendingConnection) {
PC2FDMapRunning = true PC2FDMapRunning = true
PC2FDMapLock.Unlock() PC2FDMapLock.Unlock()
go monitorPromptFDLoop() go monitorPromptFDLoop()
} else {
PC2FDMapLock.Unlock()
} }
} }
@ -245,11 +230,6 @@ func (p *prompter) processConnection(pc pendingConnection) {
dststr = addr + " (via proxy resolver)" dststr = addr + " (via proxy resolver)"
} }
// callChan := make(chan *dbus.Call, 10)
// saveChannel(callChan, true, false)
// fmt.Println("# outstanding prompt chans = ", len(outstandingPromptChans))
// fmt.Println("ABOUT TO CALL ASYNC PROMPT")
monitorPromptFDs(pc) monitorPromptFDs(pc)
call := p.dbusObj.Call("com.subgraph.FirewallPrompt.RequestPromptAsync", 0, call := p.dbusObj.Call("com.subgraph.FirewallPrompt.RequestPromptAsync", 0,
pc.getGUID(), pc.getGUID(),
@ -309,53 +289,8 @@ func (p *prompter) processConnection(pc pendingConnection) {
FirewallConfig.PromptExpert, FirewallConfig.PromptExpert,
int32(FirewallConfig.DefaultActionID)) int32(FirewallConfig.DefaultActionID))
select {
case call := <-callChan:
if call.Err != nil {
fmt.Println("Error reading DBus channel (accepting packet): ", call.Err)
policy.removePending(pc)
pc.accept()
saveChannel(callChan, false, true)
time.Sleep(1 * time.Second)
return
}
if len(call.Body) != 2 {
log.Warning("SGFW got back response in unrecognized format, len = ", len(call.Body))
saveChannel(callChan, false, true)
if (len(call.Body) == 3) && (call.Body[2] == 666) {
fmt.Printf("+++++++++ AWESOME: %v | %v | %v\n", call.Body[0], call.Body[1], call.Body[2])
scope = call.Body[0].(int32)
rule = call.Body[1].(string)
}
return
}
fmt.Printf("DBUS GOT BACK: %v, %v\n", call.Body[0], call.Body[1])
scope = call.Body[0].(int32)
rule = call.Body[1].(string)
}
saveChannel(callChan, false, true) saveChannel(callChan, false, true)
// Try alerting every other channel
promptData := make([]interface{}, 3)
promptData[0] = scope
promptData[1] = rule
promptData[2] = 666
promptChanLock.Lock()
fmt.Println("# channels to alert: ", len(outstandingPromptChans))
for chidx, _ := range outstandingPromptChans {
alertChannel(chidx, scope, rule)
// ch <- &dbus.Call{Body: promptData}
}
promptChanLock.Unlock() */
/* err := call.Store(&scope, &rule) /* err := call.Store(&scope, &rule)
if err != nil { if err != nil {
log.Warningf("Error sending dbus RequestPrompt message: %v", err) log.Warningf("Error sending dbus RequestPrompt message: %v", err)
@ -383,17 +318,19 @@ func (p *prompter) processConnection(pc pendingConnection) {
} }
tempRule := fmt.Sprintf("%s|%s", toks[0], toks[1]) tempRule := fmt.Sprintf("%s|%s", toks[0], toks[1])
tempRule += "||-1:-1|" + sandbox + "|"
if pc.src() != nil && !pc.src().Equal(net.ParseIP("127.0.0.1")) && sandbox != "" { if pc.src() != nil && !pc.src().IsLoopback() && sandbox != "" {
//if !strings.HasSuffix(rule, "SYSTEM") && !strings.HasSuffix(rule, "||") { //if !strings.HasSuffix(rule, "SYSTEM") && !strings.HasSuffix(rule, "||") {
//rule += "||" //rule += "||"
//} //}
//ule += "|||" + pc.src().String() //ule += "|||" + pc.src().String()
tempRule += "||-1:-1|" + sandbox + "|" + pc.src().String() // tempRule += "||-1:-1|" + sandbox + "|" + pc.src().String()
tempRule += pc.src().String()
} else { } else {
tempRule += "||-1:-1|" + sandbox + "|" // tempRule += "||-1:-1|" + sandbox + "|"
} }
r, err := policy.parseRule(tempRule, false) r, err := policy.parseRule(tempRule, false)
if err != nil { if err != nil {
@ -432,7 +369,7 @@ func (p *prompter) nextConnection() (pendingConnection, bool) {
fmt.Println("policy queue len = ", len(p.policyQueue)) fmt.Println("policy queue len = ", len(p.policyQueue))
for pind < len(p.policyQueue) { for pind < len(p.policyQueue) {
fmt.Printf("pind = %v of %v\n", pind, len(p.policyQueue)) //fmt.Printf("XXX: pind = %v of %v\n", pind, len(p.policyQueue))
policy := p.policyQueue[pind] policy := p.policyQueue[pind]
pc, qempty := policy.nextPending() pc, qempty := policy.nextPending()
@ -455,18 +392,21 @@ func (p *prompter) nextConnection() (pendingConnection, bool) {
if len(toks) > 2 { if len(toks) > 2 {
sandbox = toks[2] sandbox = toks[2]
} }
sandbox += ""
tempRule := fmt.Sprintf("%s|%s", toks[0], toks[1]) tempRule := fmt.Sprintf("%s|%s", toks[0], toks[1])
tempRule += "||-1:-1|" + sandbox + "|"
/* if pc.src() != nil && !pc.src().Equal(net.ParseIP("127.0.0.1")) && sandbox != "" { /*if pc.src() != nil && !pc.src().IsLoopback() && sandbox != "" {
tempRule += "||-1:-1|" + sandbox + "|" + pc.src().String() tempRule += "||-1:-1|" + sandbox + "|" + pc.src().String()
} else {*/ } else {
tempRule += "||-1:-1|" + sandbox + "|" tempRule += "||-1:-1|" + sandbox + "|"
// } }*/
r, err := policy.parseRule(tempRule, false) r, err := policy.parseRule(tempRule, false)
if err != nil { if err != nil {
log.Warningf("Error parsing rule string returned from dbus RequestPrompt: %v", err) log.Warningf("Error parsing rule string returned from dbus RequestPrompt: %v", err)
continue
// policy.removePending(pc) // policy.removePending(pc)
// pc.drop() // pc.drop()
// return // return
@ -476,8 +416,8 @@ func (p *prompter) nextConnection() (pendingConnection, bool) {
r.mode = RULE_MODE_SESSION r.mode = RULE_MODE_SESSION
} else if fscope == APPLY_PROCESS { } else if fscope == APPLY_PROCESS {
r.mode = RULE_MODE_PROCESS r.mode = RULE_MODE_PROCESS
// r.pid = pc.procInfo().Pid /*r.pid = pc.procInfo().Pid
// pcoroner.MonitorProcess(r.pid) pcoroner.MonitorProcess(r.pid)*/
} }
if !policy.processNewRule(r, fscope) { if !policy.processNewRule(r, fscope) {
// p.lock.Lock() // p.lock.Lock()

@ -47,13 +47,11 @@ func (r *Rule) String() string {
func (r *Rule) getString(redact bool) string { func (r *Rule) getString(redact bool) string {
rtype := RuleActionString[RULE_ACTION_DENY] rtype := RuleActionString[RULE_ACTION_DENY]
if r.rtype == RULE_ACTION_ALLOW { if r.rtype == RULE_ACTION_ALLOW || r.rtype == RULE_ACTION_ALLOW_TLSONLY {
rtype = RuleActionString[RULE_ACTION_ALLOW] rtype = RuleActionString[r.rtype]
} else if r.rtype == RULE_ACTION_ALLOW_TLSONLY {
rtype = RuleActionString[RULE_ACTION_ALLOW_TLSONLY]
} }
rmode := "|" + RuleModeString[r.mode]
rmode := "|" + RuleModeString[r.mode]
protostr := "" protostr := ""
if r.proto != "tcp" { if r.proto != "tcp" {
@ -258,17 +256,17 @@ func (r *Rule) parse(s string) bool {
} else if parts[2] == "PERMANENT" { } else if parts[2] == "PERMANENT" {
r.mode = RULE_MODE_PERMANENT r.mode = RULE_MODE_PERMANENT
} else if parts[2] != "" { } else if parts[2] != "" {
log.Notice("invalid rule mode ", parts[2], " in line ", s) log.Warning("Error: invalid rule mode ", parts[2], " in line ", s)
return false return false
} }
if !r.parsePrivs(parts[3]) { if !r.parsePrivs(parts[3]) {
log.Notice("invalid privs ", parts[3], " in line ", s) log.Warning("Error: invalid privs ", parts[3], " in line ", s)
return false return false
} }
if !r.parseSandbox(parts[4]) { if !r.parseSandbox(parts[4]) {
log.Notice("invalid sandbox ", parts[4], "in line ", s) log.Warning("invalid sandbox ", parts[4], "in line ", s)
return false return false
} }

@ -18,10 +18,42 @@ const SSL3_RT_ALERT = 21
const SSL3_RT_HANDSHAKE = 22 const SSL3_RT_HANDSHAKE = 22
const SSL3_RT_APPLICATION_DATA = 23 const SSL3_RT_APPLICATION_DATA = 23
const SSL3_MT_HELLO_REQUEST = 0
const SSL3_MT_CLIENT_HELLO = 1
const SSL3_MT_SERVER_HELLO = 2 const SSL3_MT_SERVER_HELLO = 2
const SSL3_MT_CERTIFICATE = 11 const SSL3_MT_CERTIFICATE = 11
const SSL3_MT_CERTIFICATE_REQUEST = 13 const SSL3_MT_CERTIFICATE_REQUEST = 13
const SSL3_MT_SERVER_DONE = 14 const SSL3_MT_SERVER_DONE = 14
const SSL3_MT_CERTIFICATE_STATUS = 22
const SSL3_AL_WARNING = 1
const SSL3_AL_FATAL = 2
const SSL3_AD_CLOSE_NOTIFY = 0
const SSL3_AD_UNEXPECTED_MESSAGE = 10
const SSL3_AD_BAD_RECORD_MAC = 20
const TLS1_AD_DECRYPTION_FAILED = 21
const TLS1_AD_RECORD_OVERFLOW = 22
const SSL3_AD_DECOMPRESSION_FAILURE = 30
const SSL3_AD_HANDSHAKE_FAILURE = 40
const SSL3_AD_NO_CERTIFICATE = 41
const SSL3_AD_BAD_CERTIFICATE = 42
const SSL3_AD_UNSUPPORTED_CERTIFICATE = 43
const SSL3_AD_CERTIFICATE_REVOKED = 44
const SSL3_AD_CERTIFICATE_EXPIRED = 45
const SSL3_AD_CERTIFICATE_UNKNOWN = 46
const SSL3_AD_ILLEGAL_PARAMETER = 47
const TLS1_AD_UNKNOWN_CA = 48
const TLS1_AD_ACCESS_DENIED = 49
const TLS1_AD_DECODE_ERROR = 50
const TLS1_AD_DECRYPT_ERROR = 51
const TLS1_AD_EXPORT_RESTRICTION = 60
const TLS1_AD_PROTOCOL_VERSION = 70
const TLS1_AD_INSUFFICIENT_SECURITY = 71
const TLS1_AD_INTERNAL_ERROR = 80
const TLS1_AD_INAPPROPRIATE_FALLBACK = 86
const TLS1_AD_USER_CANCELLED = 90
const TLS1_AD_NO_RENEGOTIATION = 100
const TLS1_AD_UNSUPPORTED_EXTENSION = 110
type connReader struct { type connReader struct {
client bool client bool
@ -76,8 +108,15 @@ func connectionReader(conn net.Conn, is_client bool, c chan connReader, done cha
rtype = int(header[0]) rtype = int(header[0])
mlen = int(int(header[3])<<8 | int(header[4])) mlen = int(int(header[3])<<8 | int(header[4]))
fmt.Printf("TLS data chunk header read: type = %#x, maj = %v, min = %v, len = %v\n", rtype, header[1], header[2], mlen) fmt.Printf("TLS data chunk header read: type = %#x, maj = %v, min = %v, len = %v\n", rtype, header[1], header[2], mlen)
buffered = header
/* 16384+1024 if compression is not null */
/* or 16384+2048 if ciphertext */
if mlen > 16384 {
ret_error = errors.New(fmt.Sprintf("TLSGuard read TLS plaintext record of excessively large length; dropping (%v bytes)", mlen))
continue
}
buffered = header
stage++ stage++
} else if stage == 2 { } else if stage == 2 {
remainder := make([]byte, mlen) remainder := make([]byte, mlen)
@ -121,6 +160,9 @@ func TLSGuard(conn, conn2 net.Conn, fqdn string) error {
go connectionReader(conn, true, crChan, dChan) go connectionReader(conn, true, crChan, dChan)
go connectionReader(conn2, false, crChan, dChan) go connectionReader(conn2, false, crChan, dChan)
client_expected := SSL3_MT_CLIENT_HELLO
server_expected := SSL3_MT_SERVER_HELLO
select_loop: select_loop:
for { for {
if ndone == 2 { if ndone == 2 {
@ -148,6 +190,35 @@ select_loop:
if cr.err == nil { if cr.err == nil {
if cr.rtype == SSL3_RT_CHANGE_CIPHER_SPEC || cr.rtype == SSL3_RT_APPLICATION_DATA || if cr.rtype == SSL3_RT_CHANGE_CIPHER_SPEC || cr.rtype == SSL3_RT_APPLICATION_DATA ||
cr.rtype == SSL3_RT_ALERT { cr.rtype == SSL3_RT_ALERT {
/* We expect only a single byte of data */
if cr.rtype == SSL3_RT_CHANGE_CIPHER_SPEC {
if len(cr.data) != 6 {
return errors.New(fmt.Sprintf("TLSGuard dropped connection with strange change cipher spec data length (%v bytes)", len(cr.data)))
}
if cr.data[5] != 1 {
return errors.New(fmt.Sprintf("TLSGuard dropped connection with strange change cipher spec data (%#x bytes)", cr.data[5]))
}
} else if cr.rtype == SSL3_RT_ALERT {
if cr.data[5] == SSL3_AL_WARNING {
fmt.Println("SSL ALERT TYPE: warning")
} else if cr.data[5] == SSL3_AL_FATAL {
fmt.Println("SSL ALERT TYPE: fatal")
} else {
fmt.Println("SSL ALERT TYPE UNKNOWN")
}
alert_desc := int(int(cr.data[6])<<8 | int(cr.data[7]))
fmt.Println("ALERT DESCRIPTION: ", alert_desc)
if cr.data[5] == SSL3_AL_FATAL {
return errors.New(fmt.Sprintf("TLSGuard dropped connection after fatal error alert detected"))
} else if alert_desc == SSL3_AD_CLOSE_NOTIFY {
return errors.New(fmt.Sprintf("TLSGuard dropped connection after close_notify alert detected"))
}
}
// fmt.Println("OTHER DATA; PASSING THRU") // fmt.Println("OTHER DATA; PASSING THRU")
if cr.rtype == SSL3_RT_ALERT { if cr.rtype == SSL3_RT_ALERT {
fmt.Println("ALERT = ", cr.data) fmt.Println("ALERT = ", cr.data)
@ -161,19 +232,35 @@ select_loop:
return errors.New(fmt.Sprintf("Expected TLS server handshake byte was not received [%#x vs 0x16]", cr.rtype)) return errors.New(fmt.Sprintf("Expected TLS server handshake byte was not received [%#x vs 0x16]", cr.rtype))
} }
if cr.rtype < SSL3_RT_CHANGE_CIPHER_SPEC || cr.rtype > SSL3_RT_APPLICATION_DATA {
return errors.New(fmt.Sprintf("TLSGuard dropping connection with unknown content type: %#x", cr.rtype))
}
serverMsg := cr.data[5:] serverMsg := cr.data[5:]
s := serverMsg[0] s := uint(serverMsg[0])
fmt.Printf("s = %#x\n", s) fmt.Printf("s = %#x\n", s)
if s > 0x22 { if cr.client && s != uint(client_expected) {
return errors.New(fmt.Sprintf("Client sent handshake type %#x but expected %#x", s, client_expected))
} else if !cr.client && s != uint(server_expected) {
return errors.New(fmt.Sprintf("Server sent handshake type %#x but expected %#x", s, server_expected))
}
if !cr.client && s == SSL3_MT_HELLO_REQUEST {
fmt.Println("Server sent hello request")
continue
}
if s > SSL3_MT_CERTIFICATE_STATUS {
fmt.Println("WTF: ", cr.data) fmt.Println("WTF: ", cr.data)
} }
if s == SSL3_MT_CERTIFICATE {
fmt.Println("HMM")
// Message len, 3 bytes // Message len, 3 bytes
serverMessageLen := serverMsg[1:4] serverMessageLen := serverMsg[1:4]
serverMessageLenInt := int(int(serverMessageLen[0])<<16 | int(serverMessageLen[1])<<8 | int(serverMessageLen[2])) serverMessageLenInt := int(int(serverMessageLen[0])<<16 | int(serverMessageLen[1])<<8 | int(serverMessageLen[2]))
if s == SSL3_MT_CERTIFICATE {
fmt.Println("HMM")
// fmt.Printf("chunk len = %v, serverMsgLen = %v, slint = %v\n", len(chunk), len(serverMsg), serverMessageLenInt) // fmt.Printf("chunk len = %v, serverMsgLen = %v, slint = %v\n", len(chunk), len(serverMsg), serverMessageLenInt)
if len(serverMsg) < serverMessageLenInt { if len(serverMsg) < serverMessageLenInt {
return errors.New(fmt.Sprintf("len(serverMsg) %v < serverMessageLenInt %v!\n", len(serverMsg), serverMessageLenInt)) return errors.New(fmt.Sprintf("len(serverMsg) %v < serverMessageLenInt %v!\n", len(serverMsg), serverMessageLenInt))

@ -111,7 +111,7 @@ func findUDPSocketAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort ui
if custdata == nil { if custdata == nil {
if strictness == MATCH_STRICT { if strictness == MATCH_STRICT {
return findSocket(proto, func(ss socketStatus) bool { return findSocket(proto, func(ss socketStatus) bool {
fmt.Println("Match strict") // fmt.Println("Match strict")
return ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) return ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr)
//return ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) //return ss.local.port == srcPort && ss.local.ip.Equal(srcAddr)
}) })
@ -124,27 +124,29 @@ func findUDPSocketAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort ui
fmt.Printf("local ip: %v\n source ip: %v\n", ss.local.ip, srcAddr) fmt.Printf("local ip: %v\n source ip: %v\n", ss.local.ip, srcAddr)
*/ */
if ss.local.port == srcPort && (ss.local.ip.Equal(net.IPv4(0, 0, 0, 0)) && ss.remote.ip.Equal(net.IPv4(0, 0, 0, 0))) { if (ss.local.port == srcPort) && addrMatchesAny(ss.local.ip) && addrMatchesAny(ss.remote.ip) {
fmt.Printf("Matching for UDP socket bound to *:%d\n", ss.local.port) fmt.Printf("Loose match for UDP socket bound to *:%d\n", ss.local.port)
return true return true
} else if ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) { } else if ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) {
return true return true
} }
// Finally, loop through all interfaces if src port matches // Finally, loop through all interfaces if src port matches
if ss.local.port == srcPort { if ss.local.port == srcPort {
ifs, err := net.Interfaces() ifs, err := net.Interfaces()
if err != nil { if err != nil {
log.Warningf("Error on net.Interfaces(): %v", err) log.Warning("Error retrieving list of network interfaces for UDP socket lookup:", err)
return false return false
} }
for _, i := range ifs { for _, i := range ifs {
addrs, err := i.Addrs() addrs, err := i.Addrs()
if err != nil { if err != nil {
log.Warningf("Error on Interface.Addrs(): %v", err) log.Warning("Error retrieving network interface for UDP socket lookup:", err)
return false return false
} }
for _, addr := range addrs { for _, addr := range addrs {
var ifip net.IP var ifip net.IP
switch x := addr.(type) { switch x := addr.(type) {
@ -153,13 +155,16 @@ func findUDPSocketAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort ui
case *net.IPAddr: case *net.IPAddr:
ifip = x.IP ifip = x.IP
} }
if ss.local.ip.Equal(ifip) { if ss.local.ip.Equal(ifip) {
fmt.Printf("Matched on UDP socket bound to %v:%d\n", ifip, srcPort) fmt.Printf("Matched on UDP socket bound to %v:%d\n", ifip, srcPort)
return true return true
} }
} }
} }
} }
return false return false
//return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(net.IPv4(0,0,0,0))) && ss.local.port == srcPort && (ss.local.ip.Equal(srcAddr) || ss.local.ip.Equal(net.IPv4(0,0,0,0))) //return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(net.IPv4(0,0,0,0))) && ss.local.port == srcPort && (ss.local.ip.Equal(srcAddr) || ss.local.ip.Equal(net.IPv4(0,0,0,0)))
/* /*

Loading…
Cancel
Save