From 2fc7525cc773a3a04374af8f10cfefe0f97334a7 Mon Sep 17 00:00:00 2001 From: Stephen Watt Date: Mon, 25 Sep 2017 18:34:16 -0400 Subject: [PATCH] Added new RemovePrompt DBus call to complement RequestPrompt (GUID-based prompt removal). The addition of a rule matching multiple pending connections in fw-prompt now removes all of them. fw-prompter now increments ref# column for identical prompt requests. Fixed/cleaned up/updated TLSGuard code. Added TLSGuard toggle option to fw-prompt GUI (default for SOCKS connections). fw-prompt now displays icon of filtered application. DBus RequestPrompt() now "works" asynchronously. TLSGuard fixed under certain conditions but still very buggy. Fixed some fw-prompt crash conditions with treeview mutex locking. Fixed SOCKS connection panic condition linked to closed channel. Cleanup of unused data structures/values. --- fw-prompt/dbus.go | 35 +--- fw-prompt/fw-prompt.go | 301 ++++++++++++++++++++++++------ sgfw/dns.go | 4 +- sgfw/policy.go | 54 +++++- sgfw/prompt.go | 129 +++++++++++-- sgfw/rules.go | 6 +- sgfw/socks_server_chain.go | 24 ++- sgfw/tlsguard.go | 364 ++++++++++++++++++++++++------------- 8 files changed, 674 insertions(+), 243 deletions(-) diff --git a/fw-prompt/dbus.go b/fw-prompt/dbus.go index 3344a62..1104c2e 100644 --- a/fw-prompt/dbus.go +++ b/fw-prompt/dbus.go @@ -4,7 +4,6 @@ import ( "errors" "github.com/godbus/dbus" "log" - // "github.com/gotk3/gotk3/glib" ) type dbusServer struct { @@ -12,27 +11,6 @@ type dbusServer struct { run bool } -type promptData struct { - Application string - Icon string - Path string - Address string - Port int - IP string - Origin string - Proto string - UID int - GID int - Username string - Groupname string - Pid int - Sandbox string - OptString string - Expanded bool - Expert bool - Action int -} - func newDbusServer() (*dbusServer, error) { conn, err := dbus.SystemBus() @@ -62,10 +40,10 @@ func newDbusServer() (*dbusServer, error) { return ds, nil } -func (ds *dbusServer) RequestPrompt(application, icon, path, address string, port int32, ip, origin, proto string, uid, gid int32, username, groupname string, pid int32, sandbox string, +func (ds *dbusServer) RequestPrompt(guid, application, icon, path, address string, port int32, ip, origin, proto string, uid, gid int32, username, groupname string, pid int32, sandbox string, is_socks bool, optstring string, expanded, expert bool, action int32) (int32, string, *dbus.Error) { - log.Printf("request prompt: app = %s, icon = %s, path = %s, address = %s, is_socks = %v, action = %v\n", application, icon, path, address, is_socks, action) - decision := addRequest(nil, path, proto, int(pid), ip, address, int(port), int(uid), int(gid), origin, is_socks, optstring, sandbox) + log.Printf("request prompt: app = %s, icon = %s, path = %s, address = %s / ip = %s, is_socks = %v, action = %v\n", application, icon, path, address, ip, is_socks, action) + decision := addRequest(nil, guid, path, icon, proto, int(pid), ip, address, int(port), int(uid), int(gid), origin, is_socks, optstring, sandbox) log.Print("Waiting on decision...") decision.Cond.L.Lock() for !decision.Ready { @@ -73,6 +51,11 @@ func (ds *dbusServer) RequestPrompt(application, icon, path, address string, por } log.Print("Decision returned: ", decision.Rule) decision.Cond.L.Unlock() - // glib.IdleAdd(func, data) return int32(decision.Scope), decision.Rule, nil } + +func (ds *dbusServer) RemovePrompt(guid string) *dbus.Error { + log.Printf("++++++++ Cancelling prompt: %s\n", guid) + removeRequest(nil, guid) + return nil +} diff --git a/fw-prompt/fw-prompt.go b/fw-prompt/fw-prompt.go index 36fecf7..adc215d 100644 --- a/fw-prompt/fw-prompt.go +++ b/fw-prompt/fw-prompt.go @@ -34,7 +34,10 @@ type decisionWaiter struct { } type ruleColumns struct { + nrefs int Path string + GUID string + Icon string Proto string Pid int Target string @@ -45,21 +48,25 @@ type ruleColumns struct { Uname string Gname string Origin string + IsSocks bool + ForceTLS bool Scope int } var userPrefs fpPreferences var mainWin *gtk.Window var Notebook *gtk.Notebook -var globalLS *gtk.ListStore +var globalLS *gtk.ListStore = nil var globalTV *gtk.TreeView +var globalPromptLock = &sync.Mutex{} +var globalIcon *gtk.Image var decisionWaiters []*decisionWaiter var editApp, editTarget, editPort, editUser, editGroup *gtk.Entry var comboProto *gtk.ComboBoxText var radioOnce, radioProcess, radioParent, radioSession, radioPermanent *gtk.RadioButton var btnApprove, btnDeny, btnIgnore *gtk.Button -var chkUser, chkGroup *gtk.CheckButton +var chkTLS, chkUser, chkGroup *gtk.CheckButton func dumpDecisions() { fmt.Println("XXX Total of decisions pending: ", len(decisionWaiters)) @@ -306,7 +313,8 @@ func createColumn(title string, id int) *gtk.TreeViewColumn { } func createListStore(general bool) *gtk.ListStore { - colData := []glib.Type{glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING} + colData := []glib.Type{glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, + glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_STRING} listStore, err := gtk.ListStoreNew(colData...) if err != nil { @@ -316,7 +324,66 @@ func createListStore(general bool) *gtk.ListStore { return listStore } -func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, hostname string, port, uid, gid int, origin string, is_socks bool, optstring string, sandbox string) *decisionWaiter { +func removeRequest(listStore *gtk.ListStore, guid string) { + removed := false + globalPromptLock.Lock() + + /* XXX: This is horrible. Figure out how to do this properly. */ + for ridx := 0; ridx < 2000; ridx++ { + + rule, _, err := getRuleByIdx(ridx) + if err != nil { + break + } else if rule.GUID == guid { + removeSelectedRule(ridx, true) + removed = true + break + } + + } + + globalPromptLock.Unlock() + + if !removed { + log.Printf("Unexpected condition: SGFW requested prompt removal for non-existent GUID %v\n", guid) + } + +} + +func addRequestInc(listStore *gtk.ListStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int, origin string, is_socks bool, optstring string, sandbox string) bool { + duplicated := false + + globalPromptLock.Lock() + + for ridx := 0; ridx < 2000; ridx++ { + + /* XXX: This is horrible. Figure out how to do this properly. */ + rule, iter, err := getRuleByIdx(ridx) + if err != nil { + break + // XXX: not compared: optstring/sandbox + } else if (rule.Path == path) && (rule.Proto == proto) && (rule.Pid == pid) && (rule.Target == ipaddr) && (rule.Hostname == hostname) && + (rule.Port == port) && (rule.UID == uid) && (rule.GID == gid) && (rule.Origin == origin) && (rule.IsSocks == is_socks) { + rule.nrefs++ + + err := globalLS.SetValue(iter, 0, rule.nrefs) + if err != nil { + log.Print("Error creating duplicate firewall prompt entry:", err) + break + } + + fmt.Println("YES REALLY DUPLICATE: ", rule.nrefs) + duplicated = true + break + } + + } + + globalPromptLock.Unlock() + return duplicated +} + +func addRequest(listStore *gtk.ListStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int, origin string, is_socks bool, optstring string, sandbox string) *decisionWaiter { if listStore == nil { listStore = globalLS waitTimes := []int{1, 2, 5, 10} @@ -342,6 +409,16 @@ func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, h log.Fatal("SGFW prompter GUI failed to load for unknown reasons") } + if addRequestInc(listStore, guid, path, icon, proto, pid, ipaddr, hostname, port, uid, gid, origin, is_socks, optstring, sandbox) { + fmt.Println("REQUEST WAS DUPLICATE") + decision := addDecision() + toggleHover() + return decision + } else { + fmt.Println("NOT DUPLICATE") + } + + globalPromptLock.Lock() iter := listStore.Append() if is_socks { @@ -352,24 +429,32 @@ func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, h } } - colVals := make([]interface{}, 11) + colVals := make([]interface{}, 14) colVals[0] = 1 - colVals[1] = path - colVals[2] = proto - colVals[3] = pid + colVals[1] = guid + colVals[2] = path + colVals[3] = icon + colVals[4] = proto + colVals[5] = pid if ipaddr == "" { - colVals[4] = "---" + colVals[6] = "---" } else { - colVals[4] = ipaddr + colVals[6] = ipaddr } - colVals[5] = hostname - colVals[6] = port - colVals[7] = uid - colVals[8] = gid - colVals[9] = origin - colVals[10] = optstring + colVals[7] = hostname + colVals[8] = port + colVals[9] = uid + colVals[10] = gid + colVals[11] = origin + colVals[12] = 0 + + if is_socks { + colVals[12] = 1 + } + + colVals[13] = optstring colNums := make([]int, len(colVals)) @@ -378,6 +463,7 @@ func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, h } err := listStore.Set(iter, colNums, colVals) + globalPromptLock.Unlock() if err != nil { log.Fatal("Unable to add row:", err) @@ -495,6 +581,8 @@ func toggleHover() { func toggleValidRuleState() { ok := true + globalPromptLock.Lock() + if numSelections() <= 0 { ok = false } @@ -537,6 +625,7 @@ func toggleValidRuleState() { btnApprove.SetSensitive(ok) btnDeny.SetSensitive(ok) btnIgnore.SetSensitive(ok) + globalPromptLock.Unlock() } func createCurrentRule() (ruleColumns, error) { @@ -579,6 +668,8 @@ func createCurrentRule() (ruleColumns, error) { rule.UID, rule.GID = 0, 0 rule.Uname, rule.Gname = "", "" + + rule.ForceTLS = chkTLS.GetActive() /* Pid int Origin string */ @@ -586,6 +677,7 @@ func createCurrentRule() (ruleColumns, error) { } func clearEditor() { + globalIcon.Clear() editApp.SetText("") editTarget.SetText("") editPort.SetText("") @@ -599,6 +691,7 @@ func clearEditor() { radioPermanent.SetActive(false) chkUser.SetActive(false) chkGroup.SetActive(false) + chkTLS.SetActive(false) } func removeSelectedRule(idx int, rmdecision bool) error { @@ -634,78 +727,116 @@ func numSelections() int { return int(rows.Length()) } -func getSelectedRule() (ruleColumns, int, error) { +// Needs to be locked by the caller +func getRuleByIdx(idx int) (ruleColumns, *gtk.TreeIter, error) { rule := ruleColumns{} - sel, err := globalTV.GetSelection() + path, err := gtk.TreePathNewFromString(fmt.Sprintf("%d", idx)) if err != nil { - return rule, -1, err + return rule, nil, err } - rows := sel.GetSelectedRows(globalLS) + iter, err := globalLS.GetIter(path) + if err != nil { + return rule, nil, err + } - if rows.Length() <= 0 { - return rule, -1, errors.New("No selection was made") + rule.nrefs, err = lsGetInt(globalLS, iter, 0) + if err != nil { + return rule, nil, err } - rdata := rows.NthData(0) - lIndex, err := strconv.Atoi(rdata.(*gtk.TreePath).String()) + rule.GUID, err = lsGetStr(globalLS, iter, 1) if err != nil { - return rule, -1, err + return rule, nil, err } - fmt.Println("lindex = ", lIndex) - path, err := gtk.TreePathNewFromString(fmt.Sprintf("%d", lIndex)) + rule.Path, err = lsGetStr(globalLS, iter, 2) if err != nil { - return rule, -1, err + return rule, nil, err } - iter, err := globalLS.GetIter(path) + rule.Icon, err = lsGetStr(globalLS, iter, 3) if err != nil { - return rule, -1, err + return rule, nil, err } - rule.Path, err = lsGetStr(globalLS, iter, 1) + rule.Proto, err = lsGetStr(globalLS, iter, 4) if err != nil { - return rule, -1, err + return rule, nil, err } - rule.Proto, err = lsGetStr(globalLS, iter, 2) + rule.Pid, err = lsGetInt(globalLS, iter, 5) if err != nil { - return rule, -1, err + return rule, nil, err } - rule.Pid, err = lsGetInt(globalLS, iter, 3) + rule.Target, err = lsGetStr(globalLS, iter, 6) if err != nil { - return rule, -1, err + return rule, nil, err } - rule.Target, err = lsGetStr(globalLS, iter, 4) + rule.Hostname, err = lsGetStr(globalLS, iter, 7) if err != nil { - return rule, -1, err + return rule, nil, err } - rule.Hostname, err = lsGetStr(globalLS, iter, 5) + rule.Port, err = lsGetInt(globalLS, iter, 8) if err != nil { - return rule, -1, err + return rule, nil, err } - rule.Port, err = lsGetInt(globalLS, iter, 6) + rule.UID, err = lsGetInt(globalLS, iter, 9) if err != nil { - return rule, -1, err + return rule, nil, err + } + + rule.GID, err = lsGetInt(globalLS, iter, 10) + if err != nil { + return rule, nil, err + } + + rule.Origin, err = lsGetStr(globalLS, iter, 11) + if err != nil { + return rule, nil, err } - rule.UID, err = lsGetInt(globalLS, iter, 7) + rule.IsSocks = false + is_socks, err := lsGetInt(globalLS, iter, 12) + if err != nil { + return rule, nil, err + } + + if is_socks != 0 { + rule.IsSocks = true + } + + return rule, iter, nil +} + +// Needs to be locked by the caller +func getSelectedRule() (ruleColumns, int, error) { + rule := ruleColumns{} + + sel, err := globalTV.GetSelection() if err != nil { return rule, -1, err } - rule.GID, err = lsGetInt(globalLS, iter, 8) + rows := sel.GetSelectedRows(globalLS) + + if rows.Length() <= 0 { + return rule, -1, errors.New("No selection was made") + } + + rdata := rows.NthData(0) + lIndex, err := strconv.Atoi(rdata.(*gtk.TreePath).String()) if err != nil { return rule, -1, err } - rule.Origin, err = lsGetStr(globalLS, iter, 9) + fmt.Println("lindex = ", lIndex) + rule, _, err = getRuleByIdx(lIndex) if err != nil { return rule, -1, err } @@ -811,10 +942,18 @@ func main() { editbox := get_vbox() hbox := get_hbox() lbl := get_label("Application path:") + + globalIcon, err = gtk.ImageNew() + if err != nil { + log.Fatal("Unable to create image:", err) + } + + // globalIcon.SetFromIconName("firefox", gtk.ICON_SIZE_DND) editApp = get_entry("") editApp.Connect("changed", toggleValidRuleState) hbox.PackStart(lbl, false, false, 10) - hbox.PackStart(editApp, true, true, 50) + hbox.PackStart(editApp, true, true, 10) + hbox.PackStart(globalIcon, false, false, 10) editbox.PackStart(hbox, false, false, 5) hbox = get_hbox() @@ -842,7 +981,9 @@ func main() { radioSession = get_radiobutton(radioOnce, "Session", false) radioPermanent = get_radiobutton(radioOnce, "Permanent", false) radioParent.SetSensitive(false) - hbox.PackStart(lbl, false, false, 10) + chkTLS = get_checkbox("Require TLS", false) + hbox.PackStart(chkTLS, false, false, 10) + hbox.PackStart(lbl, false, false, 20) hbox.PackStart(radioOnce, false, false, 5) hbox.PackStart(radioProcess, false, false, 5) hbox.PackStart(radioParent, false, false, 5) @@ -872,16 +1013,31 @@ func main() { box.PackStart(scrollbox, false, true, 5) tv.AppendColumn(createColumn("#", 0)) - tv.AppendColumn(createColumn("Path", 1)) - tv.AppendColumn(createColumn("Protocol", 2)) - tv.AppendColumn(createColumn("PID", 3)) - tv.AppendColumn(createColumn("IP Address", 4)) - tv.AppendColumn(createColumn("Hostname", 5)) - tv.AppendColumn(createColumn("Port", 6)) - tv.AppendColumn(createColumn("UID", 7)) - tv.AppendColumn(createColumn("GID", 8)) - tv.AppendColumn(createColumn("Origin", 9)) - tv.AppendColumn(createColumn("Details", 10)) + + guidcol := createColumn("GUID", 1) + guidcol.SetVisible(false) + tv.AppendColumn(guidcol) + + tv.AppendColumn(createColumn("Path", 2)) + + icol := createColumn("Icon", 3) + icol.SetVisible(false) + tv.AppendColumn(icol) + + tv.AppendColumn(createColumn("Protocol", 4)) + tv.AppendColumn(createColumn("PID", 5)) + tv.AppendColumn(createColumn("IP Address", 6)) + tv.AppendColumn(createColumn("Hostname", 7)) + tv.AppendColumn(createColumn("Port", 8)) + tv.AppendColumn(createColumn("UID", 9)) + tv.AppendColumn(createColumn("GID", 10)) + tv.AppendColumn(createColumn("Origin", 11)) + + scol := createColumn("Is SOCKS", 12) + scol.SetVisible(false) + tv.AppendColumn(scol) + + tv.AppendColumn(createColumn("Details", 13)) listStore := createListStore(true) globalLS = listStore @@ -889,23 +1045,33 @@ func main() { tv.SetModel(listStore) btnApprove.Connect("clicked", func() { + // globalPromptLock.Lock() rule, idx, err := getSelectedRule() if err != nil { + // globalPromptLock.Unlock() promptError("Error occurred processing request: " + err.Error()) return } rule, err = createCurrentRule() if err != nil { + // globalPromptLock.Unlock() promptError("Error occurred constructing new rule: " + err.Error()) return } fmt.Println("rule = ", rule) - rulestr := "ALLOW|" + rule.Proto + ":" + rule.Target + ":" + strconv.Itoa(rule.Port) + rulestr := "ALLOW" + + if rule.ForceTLS { + rulestr += "_TLSONLY" + } + + rulestr += "|" + rule.Proto + ":" + rule.Target + ":" + strconv.Itoa(rule.Port) fmt.Println("RULESTR = ", rulestr) makeDecision(idx, rulestr, int(rule.Scope)) fmt.Println("Decision made.") + // globalPromptLock.Unlock() err = removeSelectedRule(idx, true) if err == nil { clearEditor() @@ -915,14 +1081,17 @@ func main() { }) btnDeny.Connect("clicked", func() { + // globalPromptLock.Lock() rule, idx, err := getSelectedRule() if err != nil { + // globalPromptLock.Unlock() promptError("Error occurred processing request: " + err.Error()) return } rule, err = createCurrentRule() if err != nil { + // globalPromptLock.Unlock() promptError("Error occurred constructing new rule: " + err.Error()) return } @@ -932,6 +1101,7 @@ func main() { fmt.Println("RULESTR = ", rulestr) makeDecision(idx, rulestr, int(rule.Scope)) fmt.Println("Decision made.") + // globalPromptLock.Unlock() err = removeSelectedRule(idx, true) if err == nil { clearEditor() @@ -941,14 +1111,17 @@ func main() { }) btnIgnore.Connect("clicked", func() { + // globalPromptLock.Lock() _, idx, err := getSelectedRule() if err != nil { + // globalPromptLock.Unlock() promptError("Error occurred processing request: " + err.Error()) return } makeDecision(idx, "", 0) fmt.Println("Decision made.") + // globalPromptLock.Unlock() err = removeSelectedRule(idx, true) if err == nil { clearEditor() @@ -959,14 +1132,22 @@ func main() { // tv.SetActivateOnSingleClick(true) tv.Connect("row-activated", func() { + // globalPromptLock.Lock() seldata, _, err := getSelectedRule() if err != nil { + // globalPromptLock.Unlock() promptError("Unexpected error reading selected rule: " + err.Error()) return } editApp.SetText(seldata.Path) + if seldata.Icon != "" { + globalIcon.SetFromIconName(seldata.Icon, gtk.ICON_SIZE_DND) + } else { + globalIcon.Clear() + } + if seldata.Hostname != "" { editTarget.SetText(seldata.Hostname) } else { @@ -981,6 +1162,7 @@ func main() { radioSession.SetActive(false) radioPermanent.SetActive(false) comboProto.SetActiveID(seldata.Proto) + chkTLS.SetActive(seldata.IsSocks) if seldata.Uname != "" { editUser.SetText(seldata.Uname) @@ -1001,6 +1183,7 @@ func main() { chkUser.SetActive(false) chkGroup.SetActive(false) + // globalPromptLock.Unlock() return }) @@ -1011,7 +1194,7 @@ func main() { mainWin.Add(Notebook) if userPrefs.Winheight > 0 && userPrefs.Winwidth > 0 { - // fmt.Printf("height was %d, width was %d\n", userPrefs.Winheight, userPrefs.Winwidth) + // fmt.Printf("height was %d, width was %d\n", userPrefs.Winheight, userPrefs.Winwidth) mainWin.Resize(int(userPrefs.Winwidth), int(userPrefs.Winheight)) } else { mainWin.SetDefaultSize(850, 450) diff --git a/sgfw/dns.go b/sgfw/dns.go index 05df439..b68e1c2 100644 --- a/sgfw/dns.go +++ b/sgfw/dns.go @@ -166,7 +166,7 @@ func (dc *dnsCache) Lookup(ip net.IP, pid int) string { entry, ok := dc.ipMap[pid][ip.String()] if ok { if now.Before(entry.exp) { - // log.Noticef("XXX: LOOKUP on %v / %v = %v, ttl = %v / %v\n", pid, ip.String(), entry.name, entry.ttl, entry.exp) + // log.Noticef("XXX: LOOKUP on %v / %v = %v, ttl = %v / %v\n", pid, ip.String(), entry.name, entry.ttl, entry.exp) return entry.name } else { log.Warningf("Skipping expired per-pid (%d) DNS cache entry: %s -> %s / exp. %v (%ds)\n", @@ -180,7 +180,7 @@ func (dc *dnsCache) Lookup(ip net.IP, pid int) string { if ok { if now.Before(entry.exp) { str = entry.name - // log.Noticef("XXX: LOOKUP on %v / 0 RETURNING %v, ttl = %v / %v\n", ip.String(), str, entry.ttl, entry.exp) + // log.Noticef("XXX: LOOKUP on %v / 0 RETURNING %v, ttl = %v / %v\n", ip.String(), str, entry.ttl, entry.exp) } else { log.Warningf("Skipping expired global DNS cache entry: %s -> %s / exp. %v (%ds)\n", ip.String(), entry.name, entry.exp, entry.ttl) diff --git a/sgfw/policy.go b/sgfw/policy.go index cd82843..7863e61 100644 --- a/sgfw/policy.go +++ b/sgfw/policy.go @@ -52,6 +52,9 @@ type pendingConnection interface { drop() setPrompting(bool) getPrompting() bool + setPrompter(*dbusObjectP) + getPrompter() *dbusObjectP + getGUID() string print() string } @@ -62,6 +65,23 @@ type pendingPkt struct { pinfo *procsnitch.Info optstring string prompting bool + prompter *dbusObjectP + guid string +} + +/* Not a *REAL* GUID */ +func genGUID() string { + frnd, err := os.Open("/dev/urandom") + if err != nil { + log.Fatal("Error reading random data source:", err) + } + + rndb := make([]byte, 16) + frnd.Read(rndb) + frnd.Close() + + guid := fmt.Sprintf("%x-%x-%x-%x", rndb[0:4], rndb[4:8], rndb[8:12], rndb[12:]) + return guid } func getEmptyPInfo() *procsnitch.Info { @@ -165,6 +185,22 @@ func (pp *pendingPkt) drop() { pp.pkt.Accept() } +func (pp *pendingPkt) setPrompter(val *dbusObjectP) { + pp.prompter = val +} + +func (pp *pendingPkt) getPrompter() *dbusObjectP { + return pp.prompter +} + +func (pp *pendingPkt) getGUID() string { + if pp.guid == "" { + pp.guid = genGUID() + } + + return pp.guid +} + func (pp *pendingPkt) getPrompting() bool { return pp.prompting } @@ -265,7 +301,7 @@ func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, pinfo *procsnitch.Info, o case FILTER_ALLOW: pkt.Accept() case FILTER_PROMPT: - p.processPromptResult(&pendingPkt{pol: p, name: name, pkt: pkt, pinfo: pinfo, optstring: optstr, prompting: false}) + p.processPromptResult(&pendingPkt{pol: p, name: name, pkt: pkt, pinfo: pinfo, optstring: optstr, prompter: nil, prompting: false}) default: log.Warningf("Unexpected filter result: %d", result) } @@ -327,6 +363,7 @@ func (p *Policy) processNewRule(r *Rule, scope FilterScope) bool { if scope != APPLY_ONCE { p.rules = append(p.rules, r) } + fmt.Println("----------------------- processNewRule()") p.filterPending(r) if len(p.pendingQueue) == 0 { p.promptInProgress = false @@ -370,8 +407,19 @@ func (p *Policy) filterPending(rule *Rule) { remaining := []pendingConnection{} for _, pc := range p.pendingQueue { if rule.match(pc.src(), pc.dst(), pc.dstPort(), pc.hostname(), pc.proto(), pc.procInfo().UID, pc.procInfo().GID, uidToUser(pc.procInfo().UID), gidToGroup(pc.procInfo().GID)) { + prompter := pc.getPrompter() + + if prompter == nil { + fmt.Println("-------- prompter = NULL") + } else { + fmt.Println("---------- could send prompter") + + call := prompter.Call("com.subgraph.FirewallPrompt.RemovePrompt", 0, pc.getGUID()) + fmt.Println("CAAAAAAAAAAAAAAALL = ", call) + } + log.Infof("Adding rule for: %s", rule.getString(FirewallConfig.LogRedact)) - // log.Noticef("%s > %s", rule.getString(FirewallConfig.LogRedact), pc.print()) + // log.Noticef("%s > %s", rule.getString(FirewallConfig.LogRedact), pc.print()) if rule.rtype == RULE_ACTION_ALLOW { pc.accept() } else if rule.rtype == RULE_ACTION_ALLOW_TLSONLY { @@ -649,7 +697,7 @@ func LookupSandboxProc(srcip net.IP, srcp uint16, dstip net.IP, dstp uint16, pro rlines = append(rlines, strings.Join(ssplit, ":")) } - // log.Warningf("Looking for %s:%d => %s:%d \n %s\n******\n", srcip, srcp, dstip, dstp, data) + // log.Warningf("Looking for %s:%d => %s:%d \n %s\n******\n", srcip, srcp, dstip, dstp, data) if proto == "tcp" { res = procsnitch.LookupTCPSocketProcessAll(srcip, srcp, dstip, dstp, rlines) diff --git a/sgfw/prompt.go b/sgfw/prompt.go index 4c03c6c..1061d5b 100644 --- a/sgfw/prompt.go +++ b/sgfw/prompt.go @@ -18,7 +18,9 @@ var DoMultiPrompt = true const MAX_PROMPTS = 5 var outstandingPrompts = 0 +var outstandingPromptChans [](chan *dbus.Call) var promptLock = &sync.Mutex{} +var promptChanLock = &sync.Mutex{} func newPrompter(conn *dbus.Conn) *prompter { p := new(prompter) @@ -37,6 +39,30 @@ type prompter struct { policyQueue []*Policy } +func saveChannel(ch chan *dbus.Call, add bool, do_close bool) { + promptChanLock.Lock() + + if add { + outstandingPromptChans = append(outstandingPromptChans, ch) + } else { + + for idx, och := range outstandingPromptChans { + if och == ch { + outstandingPromptChans = append(outstandingPromptChans[:idx], outstandingPromptChans[idx+1:]...) + break + } + } + + } + + if !add && do_close { + close(ch) + } + + promptChanLock.Unlock() + return +} + func (p *prompter) prompt(policy *Policy) { p.lock.Lock() defer p.lock.Unlock() @@ -53,11 +79,11 @@ func (p *prompter) prompt(policy *Policy) { func (p *prompter) promptLoop() { p.lock.Lock() for { - // fmt.Println("XXX: promptLoop() outer") + // fmt.Println("XXX: promptLoop() outer") for p.processNextPacket() { - // fmt.Println("XXX: promptLoop() inner") + // fmt.Println("XXX: promptLoop() inner") } - // fmt.Println("promptLoop() wait") + // fmt.Println("promptLoop() wait") p.cond.Wait() } } @@ -79,7 +105,7 @@ func (p *prompter) processNextPacket() bool { empty := true for { pc, empty = p.nextConnection() - // fmt.Println("XXX: processNextPacket() loop; empty = ", empty, " / pc = ", pc) + // fmt.Println("XXX: processNextPacket() loop; empty = ", empty, " / pc = ", pc) if pc == nil && empty { return false } else if pc == nil { @@ -90,7 +116,7 @@ func (p *prompter) processNextPacket() bool { } p.lock.Unlock() defer p.lock.Lock() - // fmt.Println("XXX: Waiting for prompt lock go...") + // fmt.Println("XXX: Waiting for prompt lock go...") for { promptLock.Lock() if outstandingPrompts >= MAX_PROMPTS { @@ -106,9 +132,9 @@ func (p *prompter) processNextPacket() bool { break } - // fmt.Println("XXX: Passed prompt lock!") + // fmt.Println("XXX: Passed prompt lock!") outstandingPrompts++ - // fmt.Println("XXX: Incremented outstanding to ", outstandingPrompts) + // fmt.Println("XXX: Incremented outstanding to ", outstandingPrompts) promptLock.Unlock() // if !pc.getPrompting() { pc.setPrompting(true) @@ -120,15 +146,34 @@ func (p *prompter) processNextPacket() bool { func processReturn(pc pendingConnection) { promptLock.Lock() outstandingPrompts-- - // fmt.Println("XXX: Return decremented outstanding to ", outstandingPrompts) + // fmt.Println("XXX: Return decremented outstanding to ", outstandingPrompts) promptLock.Unlock() pc.setPrompting(false) } +func alertChannel(chidx int, scope int32, rule string) { + defer func() { + if r := recover(); r != nil { + log.Warning("SGFW recovered from panic while delivering out of band rule:", r) + } + }() + + promptData := make([]interface{}, 3) + promptData[0] = scope + promptData[1] = rule + promptData[2] = 666 + + outstandingPromptChans[chidx] <- &dbus.Call{Body: promptData} +} + func (p *prompter) processConnection(pc pendingConnection) { var scope int32 var rule string + if pc.getPrompter() == nil { + pc.setPrompter(&dbusObjectP{p.dbusObj}) + } + if DoMultiPrompt { defer processReturn(pc) } @@ -144,10 +189,14 @@ func (p *prompter) processConnection(pc pendingConnection) { if pc.dst() != nil { dststr = pc.dst().String() } else { - dststr = addr + " (proxy to resolve)" + dststr = addr + " (via proxy resolver)" } - call := p.dbusObj.Call("com.subgraph.FirewallPrompt.RequestPrompt", 0, + callChan := make(chan *dbus.Call, 10) + saveChannel(callChan, true, false) + fmt.Println("# outstanding prompt chans = ", len(outstandingPromptChans)) + p.dbusObj.Go("com.subgraph.FirewallPrompt.RequestPrompt", 0, callChan, + pc.getGUID(), policy.application, policy.icon, policy.path, @@ -167,14 +216,62 @@ func (p *prompter) processConnection(pc pendingConnection) { FirewallConfig.PromptExpanded, FirewallConfig.PromptExpert, int32(FirewallConfig.DefaultActionID)) - err := call.Store(&scope, &rule) - if err != nil { - log.Warningf("Error sending dbus RequestPrompt message: %v", err) - policy.removePending(pc) - pc.drop() - return + + select { + case call := <-callChan: + + if call.Err != nil { + fmt.Println("Error reading DBus channel (accepting packet): ", call.Err) + policy.removePending(pc) + pc.accept() + saveChannel(callChan, false, true) + time.Sleep(1 * time.Second) + return + } + + if len(call.Body) != 2 { + log.Warning("SGFW got back response in unrecognized format, len = ", len(call.Body)) + saveChannel(callChan, false, true) + + if (len(call.Body) == 3) && (call.Body[2] == 666) { + fmt.Printf("+++++++++ AWESOME: %v | %v | %v\n", call.Body[0], call.Body[1], call.Body[2]) + scope = call.Body[0].(int32) + rule = call.Body[1].(string) + } + + return + } + + fmt.Printf("DBUS GOT BACK: %v, %v\n", call.Body[0], call.Body[1]) + scope = call.Body[0].(int32) + rule = call.Body[1].(string) } + saveChannel(callChan, false, true) + + // Try alerting every other channel + promptData := make([]interface{}, 3) + promptData[0] = scope + promptData[1] = rule + promptData[2] = 666 + promptChanLock.Lock() + fmt.Println("# channels to alert: ", len(outstandingPromptChans)) + + for chidx, _ := range outstandingPromptChans { + alertChannel(chidx, scope, rule) + // ch <- &dbus.Call{Body: promptData} + } + + promptChanLock.Unlock() + + /* err := call.Store(&scope, &rule) + if err != nil { + log.Warningf("Error sending dbus RequestPrompt message: %v", err) + policy.removePending(pc) + pc.drop() + return + } */ + // the prompt sends: // ALLOW|dest or DENY|dest // diff --git a/sgfw/rules.go b/sgfw/rules.go index 7a512dd..db237d4 100644 --- a/sgfw/rules.go +++ b/sgfw/rules.go @@ -184,7 +184,7 @@ func (rl *RuleList) filter(pkt *nfqueue.NFQPacket, src, dst net.IP, dstPort uint nfqproto = getNFQProto(pkt) } else { if r.saddr == nil && src == nil && sandboxed == false && (r.port == dstPort || r.port == matchAny) && (r.addr.Equal(anyAddress) || r.hostname == "" || r.hostname == hostname) { - // log.Notice("+ Socks5 MATCH SUCCEEDED") + // log.Notice("+ Socks5 MATCH SUCCEEDED") if r.rtype == RULE_ACTION_DENY { return FILTER_DENY } else if r.rtype == RULE_ACTION_ALLOW { @@ -203,7 +203,7 @@ func (rl *RuleList) filter(pkt *nfqueue.NFQPacket, src, dst net.IP, dstPort uint continue } if r.match(src, dst, dstPort, hostname, nfqproto, pinfo.UID, pinfo.GID, uidToUser(pinfo.UID), gidToGroup(pinfo.GID)) { - // log.Notice("+ MATCH SUCCEEDED") + // log.Notice("+ MATCH SUCCEEDED") dstStr := dst.String() if FirewallConfig.LogRedact { dstStr = STR_REDACTED @@ -214,7 +214,7 @@ func (rl *RuleList) filter(pkt *nfqueue.NFQPacket, src, dst net.IP, dstPort uint srcp, _ := getPacketPorts(pkt) srcStr = fmt.Sprintf("%s:%d", srcip, srcp) } - // log.Noticef("%s > %s %s %s -> %s:%d", + // log.Noticef("%s > %s %s %s -> %s:%d", //r.getString(FirewallConfig.LogRedact), pinfo.ExePath, r.proto, srcStr, dstStr, dstPort) if r.rtype == RULE_ACTION_DENY { //TODO: Optionally redact below log entry diff --git a/sgfw/socks_server_chain.go b/sgfw/socks_server_chain.go index 6836d58..9c35540 100644 --- a/sgfw/socks_server_chain.go +++ b/sgfw/socks_server_chain.go @@ -56,6 +56,8 @@ type pendingSocksConnection struct { pinfo *procsnitch.Info verdict chan int prompting bool + prompter *dbusObjectP + guid string optstr string } @@ -107,8 +109,11 @@ func (sc *pendingSocksConnection) deliverVerdict(v int) { } }() - sc.verdict <- v - close(sc.verdict) + if sc.verdict != nil { + sc.verdict <- v + close(sc.verdict) + sc.verdict = nil + } } func (sc *pendingSocksConnection) accept() { sc.deliverVerdict(socksVerdictAccept) } @@ -119,6 +124,18 @@ func (sc *pendingSocksConnection) acceptTLSOnly() { sc.deliverVerdict(socksVerdi func (sc *pendingSocksConnection) drop() { sc.deliverVerdict(socksVerdictDrop) } +func (sc *pendingSocksConnection) setPrompter(val *dbusObjectP) { sc.prompter = val } + +func (sc *pendingSocksConnection) getPrompter() *dbusObjectP { return sc.prompter } + +func (sc *pendingSocksConnection) getGUID() string { + if sc.guid == "" { + sc.guid = genGUID() + } + + return sc.guid +} + func (sc *pendingSocksConnection) getPrompting() bool { return sc.prompting } func (sc *pendingSocksConnection) setPrompting(val bool) { sc.prompting = val } @@ -364,6 +381,7 @@ func (c *socksChainSession) filterConnect() (bool, bool) { pinfo: pinfo, verdict: make(chan int), prompting: false, + prompter: nil, optstr: optstr, } policy.processPromptResult(pending) @@ -409,7 +427,7 @@ func (c *socksChainSession) forwardTraffic(tls bool) { if c.pinfo.Sandbox != "" { log.Errorf("TLSGuard violation: Dropping traffic from %s (sandbox: %s) to %s: %v", c.pinfo.ExePath, c.pinfo.Sandbox, c.req.Addr.addrStr, err) } else { - log.Errorf("TLSGuard violation: Dropping traffic from %s (unsandboxed) to %s: %v", c.pinfo.ExePath, c.req.Addr.addrStr, err) + log.Errorf("TLSGuard violation: Dropping traffic from %s (un-sandboxed) to %s: %v", c.pinfo.ExePath, c.req.Addr.addrStr, err) } return } else { diff --git a/sgfw/tlsguard.go b/sgfw/tlsguard.go index 0fe2781..37289e5 100644 --- a/sgfw/tlsguard.go +++ b/sgfw/tlsguard.go @@ -3,177 +3,279 @@ package sgfw import ( "crypto/x509" "errors" + "fmt" "io" "net" + "time" ) -func TLSGuard(conn, conn2 net.Conn, fqdn string) error { - // Should this be a requirement? - // if strings.HasSuffix(request.DestAddr.FQDN, "onion") { +const TLSGUARD_READ_TIMEOUT = 5 * time.Second +const TLSGUARD_MIN_TLS_VER_MAJ = 3 +const TLSGUARD_MIN_TLS_VER_MIN = 1 - handshakeByte, err := readNBytes(conn, 1) - if err != nil { - return err - } +const SSL3_RT_CHANGE_CIPHER_SPEC = 20 +const SSL3_RT_ALERT = 21 +const SSL3_RT_HANDSHAKE = 22 +const SSL3_RT_APPLICATION_DATA = 23 - if handshakeByte[0] != 0x16 { - return errors.New("Blocked client from attempting non-TLS connection") - } +const SSL3_MT_SERVER_HELLO = 2 +const SSL3_MT_CERTIFICATE = 11 +const SSL3_MT_CERTIFICATE_REQUEST = 13 +const SSL3_MT_SERVER_DONE = 14 - vers, err := readNBytes(conn, 2) - if err != nil { - return err - } +type connReader struct { + client bool + data []byte + rtype int + err error +} - length, err := readNBytes(conn, 2) - if err != nil { - return err - } +func connectionReader(conn net.Conn, is_client bool, c chan connReader, done chan bool) { + var ret_error error = nil + buffered := []byte{} + mlen := 0 + rtype := 0 + stage := 1 - ffslen := int(int(length[0])<<8 | int(length[1])) + for { + if ret_error != nil { + cr := connReader{client: is_client, data: nil, rtype: 0, err: ret_error} + c <- cr + break + } - ffs, err := readNBytes(conn, ffslen) - if err != nil { - return err - } + select { + case <-done: + fmt.Println("++ DONE: ", is_client) + if len(buffered) > 0 { + //fmt.Println("++ DONE BUT DISPOSING OF BUFFERED DATA") + c <- connReader{client: is_client, data: buffered, rtype: 0, err: nil} + } - // Transmit client hello - conn2.Write(handshakeByte) - conn2.Write(vers) - conn2.Write(length) - conn2.Write(ffs) - - // Read ServerHello - bytesRead := 0 - var s byte // 0x0e is done - var responseBuf []byte = []byte{} - valid := false - sendToClient := false - - for sendToClient == false { - // Handshake byte - serverhandshakeByte, err := readNBytes(conn2, 1) - if err != nil { - return nil - } + c <- connReader{client: is_client, data: nil, rtype: 0, err: nil} + return + default: + if stage == 1 { + header := make([]byte, 5) + conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT)) + _, err := io.ReadFull(conn, header) + conn.SetReadDeadline(time.Time{}) + if err != nil { + ret_error = err + continue + } + + if int(header[1]) < TLSGUARD_MIN_TLS_VER_MAJ { + ret_error = errors.New("TLS protocol major version less than expected minimum") + continue + } else if int(header[2]) < TLSGUARD_MIN_TLS_VER_MIN { + ret_error = errors.New("TLS protocol minor version less than expected minimum") + continue + } - responseBuf = append(responseBuf, serverhandshakeByte[0]) - bytesRead += 1 + rtype = int(header[0]) + mlen = int(int(header[3])<<8 | int(header[4])) + fmt.Printf("TLS data chunk header read: type = %#x, maj = %v, min = %v, len = %v\n", rtype, header[1], header[2], mlen) + buffered = header - if serverhandshakeByte[0] != 0x16 { - return errors.New("Expected TLS server handshake byte was not received") - } + stage++ + } else if stage == 2 { + remainder := make([]byte, mlen) + conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT)) + _, err := io.ReadFull(conn, remainder) + conn.SetReadDeadline(time.Time{}) + if err != nil { + ret_error = err + continue + } - // Protocol version, 2 bytes - serverProtocolVer, err := readNBytes(conn2, 2) - if err != nil { - return err - } + buffered = append(buffered, remainder...) + fmt.Printf("------- CHUNK READ: client: %v, err = %v, bytes = %v\n", is_client, err, len(buffered)) + cr := connReader{client: is_client, data: buffered, rtype: rtype, err: err} + c <- cr - bytesRead += 2 - responseBuf = append(responseBuf, serverProtocolVer...) + buffered = []byte{} + rtype = 0 + mlen = 0 + stage = 1 + } - // Record length, 2 bytes - serverRecordLen, err := readNBytes(conn2, 2) - if err != nil { - return err } - bytesRead += 2 - responseBuf = append(responseBuf, serverRecordLen...) - serverRecordLenInt := int(int(serverRecordLen[0])<<8 | int(serverRecordLen[1])) + } - // Record type byte - serverMsg, err := readNBytes(conn2, serverRecordLenInt) - if err != nil { - return err - } +} - bytesRead += len(serverMsg) - responseBuf = append(responseBuf, serverMsg...) - s = serverMsg[0] +func TLSGuard(conn, conn2 net.Conn, fqdn string) error { + x509Valid := false + ndone := 0 + // Should this be a requirement? + // if strings.HasSuffix(request.DestAddr.FQDN, "onion") { - // Message len, 3 bytes - serverMessageLen := serverMsg[1:4] - serverMessageLenInt := int(int(serverMessageLen[0])<<16 | int(serverMessageLen[1])<<8 | int(serverMessageLen[2])) + //conn client + //conn2 server - // serverHelloBody, err := readNBytes(conn2, serverMessageLenInt) - serverHelloBody := serverMsg[4 : 4+serverMessageLenInt] + fmt.Println("-------- STARTING HANDSHAKE LOOP") + crChan := make(chan connReader) + dChan := make(chan bool, 10) + go connectionReader(conn, true, crChan, dChan) + go connectionReader(conn2, false, crChan, dChan) - if s == 0x0b { - certChainLen := int(int(serverHelloBody[0])<<16 | int(serverHelloBody[1])<<8 | int(serverHelloBody[2])) - remaining := certChainLen - pos := serverHelloBody[3:certChainLen] +select_loop: + for { + if ndone == 2 { + fmt.Println("DONE channel got both notifications. Terminating loop.") + close(dChan) + close(crChan) + break + } - // var certChain []*x509.Certificate - var verifyOptions x509.VerifyOptions + select { + case cr := <-crChan: + other := conn - if fqdn != "" { - verifyOptions.DNSName = fqdn + if cr.client { + other = conn2 } - pool := x509.NewCertPool() - var c *x509.Certificate + fmt.Printf("++++ SELECT: %v, %v, %v\n", cr.client, cr.err, len(cr.data)) + if cr.err == nil && cr.data == nil { + fmt.Println("DONE channel notification received") + ndone++ + continue + } - for remaining > 0 { - certLen := int(int(pos[0])<<16 | int(pos[1])<<8 | int(pos[2])) - // fmt.Printf("Certs chain len %d, cert 1 len %d:\n", certChainLen, certLen) - cert := pos[3 : 3+certLen] - certs, err := x509.ParseCertificates(cert) - if remaining == certChainLen { - c = certs[0] - } else { - pool.AddCert(certs[0]) + if cr.err == nil { + if cr.rtype == SSL3_RT_CHANGE_CIPHER_SPEC || cr.rtype == SSL3_RT_APPLICATION_DATA || + cr.rtype == SSL3_RT_ALERT { + // fmt.Println("OTHER DATA; PASSING THRU") + if cr.rtype == SSL3_RT_ALERT { + fmt.Println("ALERT = ", cr.data) + } + other.Write(cr.data) + continue + } else if cr.client { + other.Write(cr.data) + continue + } else if cr.rtype != SSL3_RT_HANDSHAKE { + return errors.New(fmt.Sprintf("Expected TLS server handshake byte was not received [%#x vs 0x16]", cr.rtype)) } - // certChain = append(certChain, certs[0]) - if err != nil { - return err + + serverMsg := cr.data[5:] + s := serverMsg[0] + fmt.Printf("s = %#x\n", s) + + if s > 0x22 { + fmt.Println("WTF: ", cr.data) + } + + if s == SSL3_MT_CERTIFICATE { + fmt.Println("HMM") + // Message len, 3 bytes + serverMessageLen := serverMsg[1:4] + serverMessageLenInt := int(int(serverMessageLen[0])<<16 | int(serverMessageLen[1])<<8 | int(serverMessageLen[2])) + // fmt.Printf("chunk len = %v, serverMsgLen = %v, slint = %v\n", len(chunk), len(serverMsg), serverMessageLenInt) + if len(serverMsg) < serverMessageLenInt { + return errors.New(fmt.Sprintf("len(serverMsg) %v < serverMessageLenInt %v!\n", len(serverMsg), serverMessageLenInt)) + } + serverHelloBody := serverMsg[4 : 4+serverMessageLenInt] + certChainLen := int(int(serverHelloBody[0])<<16 | int(serverHelloBody[1])<<8 | int(serverHelloBody[2])) + remaining := certChainLen + pos := serverHelloBody[3:certChainLen] + + // var certChain []*x509.Certificate + var verifyOptions x509.VerifyOptions + + //fqdn = "www.reddit.com" + if fqdn != "" { + verifyOptions.DNSName = fqdn + } + + pool := x509.NewCertPool() + var c *x509.Certificate + + for remaining > 0 { + certLen := int(int(pos[0])<<16 | int(pos[1])<<8 | int(pos[2])) + // fmt.Printf("Certs chain len %d, cert 1 len %d:\n", certChainLen, certLen) + cert := pos[3 : 3+certLen] + certs, err := x509.ParseCertificates(cert) + if remaining == certChainLen { + c = certs[0] + } else { + pool.AddCert(certs[0]) + } + // certChain = append(certChain, certs[0]) + if err != nil { + return err + } + remaining = remaining - certLen - 3 + if remaining > 0 { + pos = pos[3+certLen:] + } + } + + verifyOptions.Intermediates = pool + fmt.Println("ATTEMPTING TO VERIFY: ", fqdn) + _, err := c.Verify(verifyOptions) + fmt.Println("ATTEMPTING TO VERIFY RESULT: ", err) + if err != nil { + return err + } else { + x509Valid = true + } } - remaining = remaining - certLen - 3 - if remaining > 0 { - pos = pos[3+certLen:] + + other.Write(cr.data) + + if x509Valid || (s == SSL3_MT_SERVER_DONE) || (s == SSL3_MT_CERTIFICATE_REQUEST) { + fmt.Println("BREAKING OUT OF LOOP 1") + dChan <- true + fmt.Println("BREAKING OUT OF LOOP 2") + break select_loop } - } - verifyOptions.Intermediates = pool - _, err = c.Verify(verifyOptions) - if err != nil { - return err - } else { - valid = true + // fmt.Printf("Sending chunk of type %d to client.\n", s) + } else if cr.err != nil { + ndone++ + + if cr.client { + fmt.Println("Client read error: ", cr.err) + } else { + fmt.Println("Server read error: ", cr.err) + } + + return cr.err } - // else if s == 0x0d { fmt.Printf("found a client cert request, sending buf to client\n") } - } else if s == 0x0e { - sendToClient = true - } else if s == 0x0d { - sendToClient = true + } + } + + fmt.Println("WAITING; ndone = ", ndone) + for ndone < 2 { + fmt.Println("WAITING; ndone = ", ndone) + select { + case cr := <-crChan: + fmt.Printf("CHAN DATA: %v, %v, %v\n", cr.client, cr.err, len(cr.data)) + if cr.err != nil || cr.data == nil { + ndone++ + } else if cr.client { + conn2.Write(cr.data) + } else if !cr.client { + conn.Write(cr.data) + } - // fmt.Printf("Version bytes: %x %x\n", responseBuf[1], responseBuf[2]) - // fmt.Printf("Len bytes: %x %x\n", responseBuf[3], responseBuf[4]) - // fmt.Printf("Message type: %x\n", responseBuf[5]) - // fmt.Printf("Message len: %x %x %x\n", responseBuf[6], responseBuf[7], responseBuf[8]) - // fmt.Printf("Message body: %v\n", responseBuf[9:]) - conn.Write(responseBuf) - responseBuf = []byte{} + } } - if !valid { + fmt.Println("______ ndone = 2\n") + + // dChan <- true + close(dChan) + + if !x509Valid { return errors.New("Unknown error: TLS connection could not be validated") } return nil -} -func readNBytes(conn net.Conn, numBytes int) ([]byte, error) { - res := make([]byte, 0) - temp := make([]byte, 1) - for i := 0; i < numBytes; i++ { - _, err := io.ReadAtLeast(conn, temp, 1) - if err != nil { - return res, err - } - res = append(res, temp[0]) - } - return res, nil }