Added new RemovePrompt DBus call to complement RequestPrompt (GUID-based prompt removal).

The addition of a rule matching multiple pending connections in fw-prompt now removes all of them.
fw-prompter now increments ref# column for identical prompt requests.
Fixed/cleaned up/updated TLSGuard code.
Added TLSGuard toggle option to fw-prompt GUI (default for SOCKS connections).
fw-prompt now displays icon of filtered application.
DBus RequestPrompt() now "works" asynchronously.
TLSGuard fixed under certain conditions but still very buggy.
Fixed some fw-prompt crash conditions with treeview mutex locking.
Fixed SOCKS connection panic condition linked to closed channel.
Cleanup of unused data structures/values.
shw_dev
Stephen Watt 7 years ago
parent a8f61a2d4e
commit 2fc7525cc7

@ -4,7 +4,6 @@ import (
"errors" "errors"
"github.com/godbus/dbus" "github.com/godbus/dbus"
"log" "log"
// "github.com/gotk3/gotk3/glib"
) )
type dbusServer struct { type dbusServer struct {
@ -12,27 +11,6 @@ type dbusServer struct {
run bool run bool
} }
type promptData struct {
Application string
Icon string
Path string
Address string
Port int
IP string
Origin string
Proto string
UID int
GID int
Username string
Groupname string
Pid int
Sandbox string
OptString string
Expanded bool
Expert bool
Action int
}
func newDbusServer() (*dbusServer, error) { func newDbusServer() (*dbusServer, error) {
conn, err := dbus.SystemBus() conn, err := dbus.SystemBus()
@ -62,10 +40,10 @@ func newDbusServer() (*dbusServer, error) {
return ds, nil return ds, nil
} }
func (ds *dbusServer) RequestPrompt(application, icon, path, address string, port int32, ip, origin, proto string, uid, gid int32, username, groupname string, pid int32, sandbox string, func (ds *dbusServer) RequestPrompt(guid, application, icon, path, address string, port int32, ip, origin, proto string, uid, gid int32, username, groupname string, pid int32, sandbox string,
is_socks bool, optstring string, expanded, expert bool, action int32) (int32, string, *dbus.Error) { is_socks bool, optstring string, expanded, expert bool, action int32) (int32, string, *dbus.Error) {
log.Printf("request prompt: app = %s, icon = %s, path = %s, address = %s, is_socks = %v, action = %v\n", application, icon, path, address, is_socks, action) log.Printf("request prompt: app = %s, icon = %s, path = %s, address = %s / ip = %s, is_socks = %v, action = %v\n", application, icon, path, address, ip, is_socks, action)
decision := addRequest(nil, path, proto, int(pid), ip, address, int(port), int(uid), int(gid), origin, is_socks, optstring, sandbox) decision := addRequest(nil, guid, path, icon, proto, int(pid), ip, address, int(port), int(uid), int(gid), origin, is_socks, optstring, sandbox)
log.Print("Waiting on decision...") log.Print("Waiting on decision...")
decision.Cond.L.Lock() decision.Cond.L.Lock()
for !decision.Ready { for !decision.Ready {
@ -73,6 +51,11 @@ func (ds *dbusServer) RequestPrompt(application, icon, path, address string, por
} }
log.Print("Decision returned: ", decision.Rule) log.Print("Decision returned: ", decision.Rule)
decision.Cond.L.Unlock() decision.Cond.L.Unlock()
// glib.IdleAdd(func, data)
return int32(decision.Scope), decision.Rule, nil return int32(decision.Scope), decision.Rule, nil
} }
func (ds *dbusServer) RemovePrompt(guid string) *dbus.Error {
log.Printf("++++++++ Cancelling prompt: %s\n", guid)
removeRequest(nil, guid)
return nil
}

@ -34,7 +34,10 @@ type decisionWaiter struct {
} }
type ruleColumns struct { type ruleColumns struct {
nrefs int
Path string Path string
GUID string
Icon string
Proto string Proto string
Pid int Pid int
Target string Target string
@ -45,21 +48,25 @@ type ruleColumns struct {
Uname string Uname string
Gname string Gname string
Origin string Origin string
IsSocks bool
ForceTLS bool
Scope int Scope int
} }
var userPrefs fpPreferences var userPrefs fpPreferences
var mainWin *gtk.Window var mainWin *gtk.Window
var Notebook *gtk.Notebook var Notebook *gtk.Notebook
var globalLS *gtk.ListStore var globalLS *gtk.ListStore = nil
var globalTV *gtk.TreeView var globalTV *gtk.TreeView
var globalPromptLock = &sync.Mutex{}
var globalIcon *gtk.Image
var decisionWaiters []*decisionWaiter var decisionWaiters []*decisionWaiter
var editApp, editTarget, editPort, editUser, editGroup *gtk.Entry var editApp, editTarget, editPort, editUser, editGroup *gtk.Entry
var comboProto *gtk.ComboBoxText var comboProto *gtk.ComboBoxText
var radioOnce, radioProcess, radioParent, radioSession, radioPermanent *gtk.RadioButton var radioOnce, radioProcess, radioParent, radioSession, radioPermanent *gtk.RadioButton
var btnApprove, btnDeny, btnIgnore *gtk.Button var btnApprove, btnDeny, btnIgnore *gtk.Button
var chkUser, chkGroup *gtk.CheckButton var chkTLS, chkUser, chkGroup *gtk.CheckButton
func dumpDecisions() { func dumpDecisions() {
fmt.Println("XXX Total of decisions pending: ", len(decisionWaiters)) fmt.Println("XXX Total of decisions pending: ", len(decisionWaiters))
@ -306,7 +313,8 @@ func createColumn(title string, id int) *gtk.TreeViewColumn {
} }
func createListStore(general bool) *gtk.ListStore { func createListStore(general bool) *gtk.ListStore {
colData := []glib.Type{glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING} colData := []glib.Type{glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT,
glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_STRING}
listStore, err := gtk.ListStoreNew(colData...) listStore, err := gtk.ListStoreNew(colData...)
if err != nil { if err != nil {
@ -316,7 +324,66 @@ func createListStore(general bool) *gtk.ListStore {
return listStore return listStore
} }
func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, hostname string, port, uid, gid int, origin string, is_socks bool, optstring string, sandbox string) *decisionWaiter { func removeRequest(listStore *gtk.ListStore, guid string) {
removed := false
globalPromptLock.Lock()
/* XXX: This is horrible. Figure out how to do this properly. */
for ridx := 0; ridx < 2000; ridx++ {
rule, _, err := getRuleByIdx(ridx)
if err != nil {
break
} else if rule.GUID == guid {
removeSelectedRule(ridx, true)
removed = true
break
}
}
globalPromptLock.Unlock()
if !removed {
log.Printf("Unexpected condition: SGFW requested prompt removal for non-existent GUID %v\n", guid)
}
}
func addRequestInc(listStore *gtk.ListStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int, origin string, is_socks bool, optstring string, sandbox string) bool {
duplicated := false
globalPromptLock.Lock()
for ridx := 0; ridx < 2000; ridx++ {
/* XXX: This is horrible. Figure out how to do this properly. */
rule, iter, err := getRuleByIdx(ridx)
if err != nil {
break
// XXX: not compared: optstring/sandbox
} else if (rule.Path == path) && (rule.Proto == proto) && (rule.Pid == pid) && (rule.Target == ipaddr) && (rule.Hostname == hostname) &&
(rule.Port == port) && (rule.UID == uid) && (rule.GID == gid) && (rule.Origin == origin) && (rule.IsSocks == is_socks) {
rule.nrefs++
err := globalLS.SetValue(iter, 0, rule.nrefs)
if err != nil {
log.Print("Error creating duplicate firewall prompt entry:", err)
break
}
fmt.Println("YES REALLY DUPLICATE: ", rule.nrefs)
duplicated = true
break
}
}
globalPromptLock.Unlock()
return duplicated
}
func addRequest(listStore *gtk.ListStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int, origin string, is_socks bool, optstring string, sandbox string) *decisionWaiter {
if listStore == nil { if listStore == nil {
listStore = globalLS listStore = globalLS
waitTimes := []int{1, 2, 5, 10} waitTimes := []int{1, 2, 5, 10}
@ -342,6 +409,16 @@ func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, h
log.Fatal("SGFW prompter GUI failed to load for unknown reasons") log.Fatal("SGFW prompter GUI failed to load for unknown reasons")
} }
if addRequestInc(listStore, guid, path, icon, proto, pid, ipaddr, hostname, port, uid, gid, origin, is_socks, optstring, sandbox) {
fmt.Println("REQUEST WAS DUPLICATE")
decision := addDecision()
toggleHover()
return decision
} else {
fmt.Println("NOT DUPLICATE")
}
globalPromptLock.Lock()
iter := listStore.Append() iter := listStore.Append()
if is_socks { if is_socks {
@ -352,24 +429,32 @@ func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, h
} }
} }
colVals := make([]interface{}, 11) colVals := make([]interface{}, 14)
colVals[0] = 1 colVals[0] = 1
colVals[1] = path colVals[1] = guid
colVals[2] = proto colVals[2] = path
colVals[3] = pid colVals[3] = icon
colVals[4] = proto
colVals[5] = pid
if ipaddr == "" { if ipaddr == "" {
colVals[4] = "---" colVals[6] = "---"
} else { } else {
colVals[4] = ipaddr colVals[6] = ipaddr
} }
colVals[5] = hostname colVals[7] = hostname
colVals[6] = port colVals[8] = port
colVals[7] = uid colVals[9] = uid
colVals[8] = gid colVals[10] = gid
colVals[9] = origin colVals[11] = origin
colVals[10] = optstring colVals[12] = 0
if is_socks {
colVals[12] = 1
}
colVals[13] = optstring
colNums := make([]int, len(colVals)) colNums := make([]int, len(colVals))
@ -378,6 +463,7 @@ func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, h
} }
err := listStore.Set(iter, colNums, colVals) err := listStore.Set(iter, colNums, colVals)
globalPromptLock.Unlock()
if err != nil { if err != nil {
log.Fatal("Unable to add row:", err) log.Fatal("Unable to add row:", err)
@ -495,6 +581,8 @@ func toggleHover() {
func toggleValidRuleState() { func toggleValidRuleState() {
ok := true ok := true
globalPromptLock.Lock()
if numSelections() <= 0 { if numSelections() <= 0 {
ok = false ok = false
} }
@ -537,6 +625,7 @@ func toggleValidRuleState() {
btnApprove.SetSensitive(ok) btnApprove.SetSensitive(ok)
btnDeny.SetSensitive(ok) btnDeny.SetSensitive(ok)
btnIgnore.SetSensitive(ok) btnIgnore.SetSensitive(ok)
globalPromptLock.Unlock()
} }
func createCurrentRule() (ruleColumns, error) { func createCurrentRule() (ruleColumns, error) {
@ -579,6 +668,8 @@ func createCurrentRule() (ruleColumns, error) {
rule.UID, rule.GID = 0, 0 rule.UID, rule.GID = 0, 0
rule.Uname, rule.Gname = "", "" rule.Uname, rule.Gname = "", ""
rule.ForceTLS = chkTLS.GetActive()
/* Pid int /* Pid int
Origin string */ Origin string */
@ -586,6 +677,7 @@ func createCurrentRule() (ruleColumns, error) {
} }
func clearEditor() { func clearEditor() {
globalIcon.Clear()
editApp.SetText("") editApp.SetText("")
editTarget.SetText("") editTarget.SetText("")
editPort.SetText("") editPort.SetText("")
@ -599,6 +691,7 @@ func clearEditor() {
radioPermanent.SetActive(false) radioPermanent.SetActive(false)
chkUser.SetActive(false) chkUser.SetActive(false)
chkGroup.SetActive(false) chkGroup.SetActive(false)
chkTLS.SetActive(false)
} }
func removeSelectedRule(idx int, rmdecision bool) error { func removeSelectedRule(idx int, rmdecision bool) error {
@ -634,78 +727,116 @@ func numSelections() int {
return int(rows.Length()) return int(rows.Length())
} }
func getSelectedRule() (ruleColumns, int, error) { // Needs to be locked by the caller
func getRuleByIdx(idx int) (ruleColumns, *gtk.TreeIter, error) {
rule := ruleColumns{} rule := ruleColumns{}
sel, err := globalTV.GetSelection() path, err := gtk.TreePathNewFromString(fmt.Sprintf("%d", idx))
if err != nil { if err != nil {
return rule, -1, err return rule, nil, err
} }
rows := sel.GetSelectedRows(globalLS) iter, err := globalLS.GetIter(path)
if err != nil {
return rule, nil, err
}
if rows.Length() <= 0 { rule.nrefs, err = lsGetInt(globalLS, iter, 0)
return rule, -1, errors.New("No selection was made") if err != nil {
return rule, nil, err
} }
rdata := rows.NthData(0) rule.GUID, err = lsGetStr(globalLS, iter, 1)
lIndex, err := strconv.Atoi(rdata.(*gtk.TreePath).String())
if err != nil { if err != nil {
return rule, -1, err return rule, nil, err
} }
fmt.Println("lindex = ", lIndex) rule.Path, err = lsGetStr(globalLS, iter, 2)
path, err := gtk.TreePathNewFromString(fmt.Sprintf("%d", lIndex))
if err != nil { if err != nil {
return rule, -1, err return rule, nil, err
} }
iter, err := globalLS.GetIter(path) rule.Icon, err = lsGetStr(globalLS, iter, 3)
if err != nil { if err != nil {
return rule, -1, err return rule, nil, err
} }
rule.Path, err = lsGetStr(globalLS, iter, 1) rule.Proto, err = lsGetStr(globalLS, iter, 4)
if err != nil { if err != nil {
return rule, -1, err return rule, nil, err
} }
rule.Proto, err = lsGetStr(globalLS, iter, 2) rule.Pid, err = lsGetInt(globalLS, iter, 5)
if err != nil { if err != nil {
return rule, -1, err return rule, nil, err
} }
rule.Pid, err = lsGetInt(globalLS, iter, 3) rule.Target, err = lsGetStr(globalLS, iter, 6)
if err != nil { if err != nil {
return rule, -1, err return rule, nil, err
} }
rule.Target, err = lsGetStr(globalLS, iter, 4) rule.Hostname, err = lsGetStr(globalLS, iter, 7)
if err != nil { if err != nil {
return rule, -1, err return rule, nil, err
} }
rule.Hostname, err = lsGetStr(globalLS, iter, 5) rule.Port, err = lsGetInt(globalLS, iter, 8)
if err != nil { if err != nil {
return rule, -1, err return rule, nil, err
} }
rule.Port, err = lsGetInt(globalLS, iter, 6) rule.UID, err = lsGetInt(globalLS, iter, 9)
if err != nil { if err != nil {
return rule, -1, err return rule, nil, err
}
rule.GID, err = lsGetInt(globalLS, iter, 10)
if err != nil {
return rule, nil, err
}
rule.Origin, err = lsGetStr(globalLS, iter, 11)
if err != nil {
return rule, nil, err
} }
rule.UID, err = lsGetInt(globalLS, iter, 7) rule.IsSocks = false
is_socks, err := lsGetInt(globalLS, iter, 12)
if err != nil {
return rule, nil, err
}
if is_socks != 0 {
rule.IsSocks = true
}
return rule, iter, nil
}
// Needs to be locked by the caller
func getSelectedRule() (ruleColumns, int, error) {
rule := ruleColumns{}
sel, err := globalTV.GetSelection()
if err != nil { if err != nil {
return rule, -1, err return rule, -1, err
} }
rule.GID, err = lsGetInt(globalLS, iter, 8) rows := sel.GetSelectedRows(globalLS)
if rows.Length() <= 0 {
return rule, -1, errors.New("No selection was made")
}
rdata := rows.NthData(0)
lIndex, err := strconv.Atoi(rdata.(*gtk.TreePath).String())
if err != nil { if err != nil {
return rule, -1, err return rule, -1, err
} }
rule.Origin, err = lsGetStr(globalLS, iter, 9) fmt.Println("lindex = ", lIndex)
rule, _, err = getRuleByIdx(lIndex)
if err != nil { if err != nil {
return rule, -1, err return rule, -1, err
} }
@ -811,10 +942,18 @@ func main() {
editbox := get_vbox() editbox := get_vbox()
hbox := get_hbox() hbox := get_hbox()
lbl := get_label("Application path:") lbl := get_label("Application path:")
globalIcon, err = gtk.ImageNew()
if err != nil {
log.Fatal("Unable to create image:", err)
}
// globalIcon.SetFromIconName("firefox", gtk.ICON_SIZE_DND)
editApp = get_entry("") editApp = get_entry("")
editApp.Connect("changed", toggleValidRuleState) editApp.Connect("changed", toggleValidRuleState)
hbox.PackStart(lbl, false, false, 10) hbox.PackStart(lbl, false, false, 10)
hbox.PackStart(editApp, true, true, 50) hbox.PackStart(editApp, true, true, 10)
hbox.PackStart(globalIcon, false, false, 10)
editbox.PackStart(hbox, false, false, 5) editbox.PackStart(hbox, false, false, 5)
hbox = get_hbox() hbox = get_hbox()
@ -842,7 +981,9 @@ func main() {
radioSession = get_radiobutton(radioOnce, "Session", false) radioSession = get_radiobutton(radioOnce, "Session", false)
radioPermanent = get_radiobutton(radioOnce, "Permanent", false) radioPermanent = get_radiobutton(radioOnce, "Permanent", false)
radioParent.SetSensitive(false) radioParent.SetSensitive(false)
hbox.PackStart(lbl, false, false, 10) chkTLS = get_checkbox("Require TLS", false)
hbox.PackStart(chkTLS, false, false, 10)
hbox.PackStart(lbl, false, false, 20)
hbox.PackStart(radioOnce, false, false, 5) hbox.PackStart(radioOnce, false, false, 5)
hbox.PackStart(radioProcess, false, false, 5) hbox.PackStart(radioProcess, false, false, 5)
hbox.PackStart(radioParent, false, false, 5) hbox.PackStart(radioParent, false, false, 5)
@ -872,16 +1013,31 @@ func main() {
box.PackStart(scrollbox, false, true, 5) box.PackStart(scrollbox, false, true, 5)
tv.AppendColumn(createColumn("#", 0)) tv.AppendColumn(createColumn("#", 0))
tv.AppendColumn(createColumn("Path", 1))
tv.AppendColumn(createColumn("Protocol", 2)) guidcol := createColumn("GUID", 1)
tv.AppendColumn(createColumn("PID", 3)) guidcol.SetVisible(false)
tv.AppendColumn(createColumn("IP Address", 4)) tv.AppendColumn(guidcol)
tv.AppendColumn(createColumn("Hostname", 5))
tv.AppendColumn(createColumn("Port", 6)) tv.AppendColumn(createColumn("Path", 2))
tv.AppendColumn(createColumn("UID", 7))
tv.AppendColumn(createColumn("GID", 8)) icol := createColumn("Icon", 3)
tv.AppendColumn(createColumn("Origin", 9)) icol.SetVisible(false)
tv.AppendColumn(createColumn("Details", 10)) tv.AppendColumn(icol)
tv.AppendColumn(createColumn("Protocol", 4))
tv.AppendColumn(createColumn("PID", 5))
tv.AppendColumn(createColumn("IP Address", 6))
tv.AppendColumn(createColumn("Hostname", 7))
tv.AppendColumn(createColumn("Port", 8))
tv.AppendColumn(createColumn("UID", 9))
tv.AppendColumn(createColumn("GID", 10))
tv.AppendColumn(createColumn("Origin", 11))
scol := createColumn("Is SOCKS", 12)
scol.SetVisible(false)
tv.AppendColumn(scol)
tv.AppendColumn(createColumn("Details", 13))
listStore := createListStore(true) listStore := createListStore(true)
globalLS = listStore globalLS = listStore
@ -889,23 +1045,33 @@ func main() {
tv.SetModel(listStore) tv.SetModel(listStore)
btnApprove.Connect("clicked", func() { btnApprove.Connect("clicked", func() {
// globalPromptLock.Lock()
rule, idx, err := getSelectedRule() rule, idx, err := getSelectedRule()
if err != nil { if err != nil {
// globalPromptLock.Unlock()
promptError("Error occurred processing request: " + err.Error()) promptError("Error occurred processing request: " + err.Error())
return return
} }
rule, err = createCurrentRule() rule, err = createCurrentRule()
if err != nil { if err != nil {
// globalPromptLock.Unlock()
promptError("Error occurred constructing new rule: " + err.Error()) promptError("Error occurred constructing new rule: " + err.Error())
return return
} }
fmt.Println("rule = ", rule) fmt.Println("rule = ", rule)
rulestr := "ALLOW|" + rule.Proto + ":" + rule.Target + ":" + strconv.Itoa(rule.Port) rulestr := "ALLOW"
if rule.ForceTLS {
rulestr += "_TLSONLY"
}
rulestr += "|" + rule.Proto + ":" + rule.Target + ":" + strconv.Itoa(rule.Port)
fmt.Println("RULESTR = ", rulestr) fmt.Println("RULESTR = ", rulestr)
makeDecision(idx, rulestr, int(rule.Scope)) makeDecision(idx, rulestr, int(rule.Scope))
fmt.Println("Decision made.") fmt.Println("Decision made.")
// globalPromptLock.Unlock()
err = removeSelectedRule(idx, true) err = removeSelectedRule(idx, true)
if err == nil { if err == nil {
clearEditor() clearEditor()
@ -915,14 +1081,17 @@ func main() {
}) })
btnDeny.Connect("clicked", func() { btnDeny.Connect("clicked", func() {
// globalPromptLock.Lock()
rule, idx, err := getSelectedRule() rule, idx, err := getSelectedRule()
if err != nil { if err != nil {
// globalPromptLock.Unlock()
promptError("Error occurred processing request: " + err.Error()) promptError("Error occurred processing request: " + err.Error())
return return
} }
rule, err = createCurrentRule() rule, err = createCurrentRule()
if err != nil { if err != nil {
// globalPromptLock.Unlock()
promptError("Error occurred constructing new rule: " + err.Error()) promptError("Error occurred constructing new rule: " + err.Error())
return return
} }
@ -932,6 +1101,7 @@ func main() {
fmt.Println("RULESTR = ", rulestr) fmt.Println("RULESTR = ", rulestr)
makeDecision(idx, rulestr, int(rule.Scope)) makeDecision(idx, rulestr, int(rule.Scope))
fmt.Println("Decision made.") fmt.Println("Decision made.")
// globalPromptLock.Unlock()
err = removeSelectedRule(idx, true) err = removeSelectedRule(idx, true)
if err == nil { if err == nil {
clearEditor() clearEditor()
@ -941,14 +1111,17 @@ func main() {
}) })
btnIgnore.Connect("clicked", func() { btnIgnore.Connect("clicked", func() {
// globalPromptLock.Lock()
_, idx, err := getSelectedRule() _, idx, err := getSelectedRule()
if err != nil { if err != nil {
// globalPromptLock.Unlock()
promptError("Error occurred processing request: " + err.Error()) promptError("Error occurred processing request: " + err.Error())
return return
} }
makeDecision(idx, "", 0) makeDecision(idx, "", 0)
fmt.Println("Decision made.") fmt.Println("Decision made.")
// globalPromptLock.Unlock()
err = removeSelectedRule(idx, true) err = removeSelectedRule(idx, true)
if err == nil { if err == nil {
clearEditor() clearEditor()
@ -959,14 +1132,22 @@ func main() {
// tv.SetActivateOnSingleClick(true) // tv.SetActivateOnSingleClick(true)
tv.Connect("row-activated", func() { tv.Connect("row-activated", func() {
// globalPromptLock.Lock()
seldata, _, err := getSelectedRule() seldata, _, err := getSelectedRule()
if err != nil { if err != nil {
// globalPromptLock.Unlock()
promptError("Unexpected error reading selected rule: " + err.Error()) promptError("Unexpected error reading selected rule: " + err.Error())
return return
} }
editApp.SetText(seldata.Path) editApp.SetText(seldata.Path)
if seldata.Icon != "" {
globalIcon.SetFromIconName(seldata.Icon, gtk.ICON_SIZE_DND)
} else {
globalIcon.Clear()
}
if seldata.Hostname != "" { if seldata.Hostname != "" {
editTarget.SetText(seldata.Hostname) editTarget.SetText(seldata.Hostname)
} else { } else {
@ -981,6 +1162,7 @@ func main() {
radioSession.SetActive(false) radioSession.SetActive(false)
radioPermanent.SetActive(false) radioPermanent.SetActive(false)
comboProto.SetActiveID(seldata.Proto) comboProto.SetActiveID(seldata.Proto)
chkTLS.SetActive(seldata.IsSocks)
if seldata.Uname != "" { if seldata.Uname != "" {
editUser.SetText(seldata.Uname) editUser.SetText(seldata.Uname)
@ -1001,6 +1183,7 @@ func main() {
chkUser.SetActive(false) chkUser.SetActive(false)
chkGroup.SetActive(false) chkGroup.SetActive(false)
// globalPromptLock.Unlock()
return return
}) })
@ -1011,7 +1194,7 @@ func main() {
mainWin.Add(Notebook) mainWin.Add(Notebook)
if userPrefs.Winheight > 0 && userPrefs.Winwidth > 0 { if userPrefs.Winheight > 0 && userPrefs.Winwidth > 0 {
// fmt.Printf("height was %d, width was %d\n", userPrefs.Winheight, userPrefs.Winwidth) // fmt.Printf("height was %d, width was %d\n", userPrefs.Winheight, userPrefs.Winwidth)
mainWin.Resize(int(userPrefs.Winwidth), int(userPrefs.Winheight)) mainWin.Resize(int(userPrefs.Winwidth), int(userPrefs.Winheight))
} else { } else {
mainWin.SetDefaultSize(850, 450) mainWin.SetDefaultSize(850, 450)

@ -166,7 +166,7 @@ func (dc *dnsCache) Lookup(ip net.IP, pid int) string {
entry, ok := dc.ipMap[pid][ip.String()] entry, ok := dc.ipMap[pid][ip.String()]
if ok { if ok {
if now.Before(entry.exp) { if now.Before(entry.exp) {
// log.Noticef("XXX: LOOKUP on %v / %v = %v, ttl = %v / %v\n", pid, ip.String(), entry.name, entry.ttl, entry.exp) // log.Noticef("XXX: LOOKUP on %v / %v = %v, ttl = %v / %v\n", pid, ip.String(), entry.name, entry.ttl, entry.exp)
return entry.name return entry.name
} else { } else {
log.Warningf("Skipping expired per-pid (%d) DNS cache entry: %s -> %s / exp. %v (%ds)\n", log.Warningf("Skipping expired per-pid (%d) DNS cache entry: %s -> %s / exp. %v (%ds)\n",
@ -180,7 +180,7 @@ func (dc *dnsCache) Lookup(ip net.IP, pid int) string {
if ok { if ok {
if now.Before(entry.exp) { if now.Before(entry.exp) {
str = entry.name str = entry.name
// log.Noticef("XXX: LOOKUP on %v / 0 RETURNING %v, ttl = %v / %v\n", ip.String(), str, entry.ttl, entry.exp) // log.Noticef("XXX: LOOKUP on %v / 0 RETURNING %v, ttl = %v / %v\n", ip.String(), str, entry.ttl, entry.exp)
} else { } else {
log.Warningf("Skipping expired global DNS cache entry: %s -> %s / exp. %v (%ds)\n", log.Warningf("Skipping expired global DNS cache entry: %s -> %s / exp. %v (%ds)\n",
ip.String(), entry.name, entry.exp, entry.ttl) ip.String(), entry.name, entry.exp, entry.ttl)

@ -52,6 +52,9 @@ type pendingConnection interface {
drop() drop()
setPrompting(bool) setPrompting(bool)
getPrompting() bool getPrompting() bool
setPrompter(*dbusObjectP)
getPrompter() *dbusObjectP
getGUID() string
print() string print() string
} }
@ -62,6 +65,23 @@ type pendingPkt struct {
pinfo *procsnitch.Info pinfo *procsnitch.Info
optstring string optstring string
prompting bool prompting bool
prompter *dbusObjectP
guid string
}
/* Not a *REAL* GUID */
func genGUID() string {
frnd, err := os.Open("/dev/urandom")
if err != nil {
log.Fatal("Error reading random data source:", err)
}
rndb := make([]byte, 16)
frnd.Read(rndb)
frnd.Close()
guid := fmt.Sprintf("%x-%x-%x-%x", rndb[0:4], rndb[4:8], rndb[8:12], rndb[12:])
return guid
} }
func getEmptyPInfo() *procsnitch.Info { func getEmptyPInfo() *procsnitch.Info {
@ -165,6 +185,22 @@ func (pp *pendingPkt) drop() {
pp.pkt.Accept() pp.pkt.Accept()
} }
func (pp *pendingPkt) setPrompter(val *dbusObjectP) {
pp.prompter = val
}
func (pp *pendingPkt) getPrompter() *dbusObjectP {
return pp.prompter
}
func (pp *pendingPkt) getGUID() string {
if pp.guid == "" {
pp.guid = genGUID()
}
return pp.guid
}
func (pp *pendingPkt) getPrompting() bool { func (pp *pendingPkt) getPrompting() bool {
return pp.prompting return pp.prompting
} }
@ -265,7 +301,7 @@ func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, pinfo *procsnitch.Info, o
case FILTER_ALLOW: case FILTER_ALLOW:
pkt.Accept() pkt.Accept()
case FILTER_PROMPT: case FILTER_PROMPT:
p.processPromptResult(&pendingPkt{pol: p, name: name, pkt: pkt, pinfo: pinfo, optstring: optstr, prompting: false}) p.processPromptResult(&pendingPkt{pol: p, name: name, pkt: pkt, pinfo: pinfo, optstring: optstr, prompter: nil, prompting: false})
default: default:
log.Warningf("Unexpected filter result: %d", result) log.Warningf("Unexpected filter result: %d", result)
} }
@ -327,6 +363,7 @@ func (p *Policy) processNewRule(r *Rule, scope FilterScope) bool {
if scope != APPLY_ONCE { if scope != APPLY_ONCE {
p.rules = append(p.rules, r) p.rules = append(p.rules, r)
} }
fmt.Println("----------------------- processNewRule()")
p.filterPending(r) p.filterPending(r)
if len(p.pendingQueue) == 0 { if len(p.pendingQueue) == 0 {
p.promptInProgress = false p.promptInProgress = false
@ -370,8 +407,19 @@ func (p *Policy) filterPending(rule *Rule) {
remaining := []pendingConnection{} remaining := []pendingConnection{}
for _, pc := range p.pendingQueue { for _, pc := range p.pendingQueue {
if rule.match(pc.src(), pc.dst(), pc.dstPort(), pc.hostname(), pc.proto(), pc.procInfo().UID, pc.procInfo().GID, uidToUser(pc.procInfo().UID), gidToGroup(pc.procInfo().GID)) { if rule.match(pc.src(), pc.dst(), pc.dstPort(), pc.hostname(), pc.proto(), pc.procInfo().UID, pc.procInfo().GID, uidToUser(pc.procInfo().UID), gidToGroup(pc.procInfo().GID)) {
prompter := pc.getPrompter()
if prompter == nil {
fmt.Println("-------- prompter = NULL")
} else {
fmt.Println("---------- could send prompter")
call := prompter.Call("com.subgraph.FirewallPrompt.RemovePrompt", 0, pc.getGUID())
fmt.Println("CAAAAAAAAAAAAAAALL = ", call)
}
log.Infof("Adding rule for: %s", rule.getString(FirewallConfig.LogRedact)) log.Infof("Adding rule for: %s", rule.getString(FirewallConfig.LogRedact))
// log.Noticef("%s > %s", rule.getString(FirewallConfig.LogRedact), pc.print()) // log.Noticef("%s > %s", rule.getString(FirewallConfig.LogRedact), pc.print())
if rule.rtype == RULE_ACTION_ALLOW { if rule.rtype == RULE_ACTION_ALLOW {
pc.accept() pc.accept()
} else if rule.rtype == RULE_ACTION_ALLOW_TLSONLY { } else if rule.rtype == RULE_ACTION_ALLOW_TLSONLY {
@ -649,7 +697,7 @@ func LookupSandboxProc(srcip net.IP, srcp uint16, dstip net.IP, dstp uint16, pro
rlines = append(rlines, strings.Join(ssplit, ":")) rlines = append(rlines, strings.Join(ssplit, ":"))
} }
// log.Warningf("Looking for %s:%d => %s:%d \n %s\n******\n", srcip, srcp, dstip, dstp, data) // log.Warningf("Looking for %s:%d => %s:%d \n %s\n******\n", srcip, srcp, dstip, dstp, data)
if proto == "tcp" { if proto == "tcp" {
res = procsnitch.LookupTCPSocketProcessAll(srcip, srcp, dstip, dstp, rlines) res = procsnitch.LookupTCPSocketProcessAll(srcip, srcp, dstip, dstp, rlines)

@ -18,7 +18,9 @@ var DoMultiPrompt = true
const MAX_PROMPTS = 5 const MAX_PROMPTS = 5
var outstandingPrompts = 0 var outstandingPrompts = 0
var outstandingPromptChans [](chan *dbus.Call)
var promptLock = &sync.Mutex{} var promptLock = &sync.Mutex{}
var promptChanLock = &sync.Mutex{}
func newPrompter(conn *dbus.Conn) *prompter { func newPrompter(conn *dbus.Conn) *prompter {
p := new(prompter) p := new(prompter)
@ -37,6 +39,30 @@ type prompter struct {
policyQueue []*Policy policyQueue []*Policy
} }
func saveChannel(ch chan *dbus.Call, add bool, do_close bool) {
promptChanLock.Lock()
if add {
outstandingPromptChans = append(outstandingPromptChans, ch)
} else {
for idx, och := range outstandingPromptChans {
if och == ch {
outstandingPromptChans = append(outstandingPromptChans[:idx], outstandingPromptChans[idx+1:]...)
break
}
}
}
if !add && do_close {
close(ch)
}
promptChanLock.Unlock()
return
}
func (p *prompter) prompt(policy *Policy) { func (p *prompter) prompt(policy *Policy) {
p.lock.Lock() p.lock.Lock()
defer p.lock.Unlock() defer p.lock.Unlock()
@ -53,11 +79,11 @@ func (p *prompter) prompt(policy *Policy) {
func (p *prompter) promptLoop() { func (p *prompter) promptLoop() {
p.lock.Lock() p.lock.Lock()
for { for {
// fmt.Println("XXX: promptLoop() outer") // fmt.Println("XXX: promptLoop() outer")
for p.processNextPacket() { for p.processNextPacket() {
// fmt.Println("XXX: promptLoop() inner") // fmt.Println("XXX: promptLoop() inner")
} }
// fmt.Println("promptLoop() wait") // fmt.Println("promptLoop() wait")
p.cond.Wait() p.cond.Wait()
} }
} }
@ -79,7 +105,7 @@ func (p *prompter) processNextPacket() bool {
empty := true empty := true
for { for {
pc, empty = p.nextConnection() pc, empty = p.nextConnection()
// fmt.Println("XXX: processNextPacket() loop; empty = ", empty, " / pc = ", pc) // fmt.Println("XXX: processNextPacket() loop; empty = ", empty, " / pc = ", pc)
if pc == nil && empty { if pc == nil && empty {
return false return false
} else if pc == nil { } else if pc == nil {
@ -90,7 +116,7 @@ func (p *prompter) processNextPacket() bool {
} }
p.lock.Unlock() p.lock.Unlock()
defer p.lock.Lock() defer p.lock.Lock()
// fmt.Println("XXX: Waiting for prompt lock go...") // fmt.Println("XXX: Waiting for prompt lock go...")
for { for {
promptLock.Lock() promptLock.Lock()
if outstandingPrompts >= MAX_PROMPTS { if outstandingPrompts >= MAX_PROMPTS {
@ -106,9 +132,9 @@ func (p *prompter) processNextPacket() bool {
break break
} }
// fmt.Println("XXX: Passed prompt lock!") // fmt.Println("XXX: Passed prompt lock!")
outstandingPrompts++ outstandingPrompts++
// fmt.Println("XXX: Incremented outstanding to ", outstandingPrompts) // fmt.Println("XXX: Incremented outstanding to ", outstandingPrompts)
promptLock.Unlock() promptLock.Unlock()
// if !pc.getPrompting() { // if !pc.getPrompting() {
pc.setPrompting(true) pc.setPrompting(true)
@ -120,15 +146,34 @@ func (p *prompter) processNextPacket() bool {
func processReturn(pc pendingConnection) { func processReturn(pc pendingConnection) {
promptLock.Lock() promptLock.Lock()
outstandingPrompts-- outstandingPrompts--
// fmt.Println("XXX: Return decremented outstanding to ", outstandingPrompts) // fmt.Println("XXX: Return decremented outstanding to ", outstandingPrompts)
promptLock.Unlock() promptLock.Unlock()
pc.setPrompting(false) pc.setPrompting(false)
} }
func alertChannel(chidx int, scope int32, rule string) {
defer func() {
if r := recover(); r != nil {
log.Warning("SGFW recovered from panic while delivering out of band rule:", r)
}
}()
promptData := make([]interface{}, 3)
promptData[0] = scope
promptData[1] = rule
promptData[2] = 666
outstandingPromptChans[chidx] <- &dbus.Call{Body: promptData}
}
func (p *prompter) processConnection(pc pendingConnection) { func (p *prompter) processConnection(pc pendingConnection) {
var scope int32 var scope int32
var rule string var rule string
if pc.getPrompter() == nil {
pc.setPrompter(&dbusObjectP{p.dbusObj})
}
if DoMultiPrompt { if DoMultiPrompt {
defer processReturn(pc) defer processReturn(pc)
} }
@ -144,10 +189,14 @@ func (p *prompter) processConnection(pc pendingConnection) {
if pc.dst() != nil { if pc.dst() != nil {
dststr = pc.dst().String() dststr = pc.dst().String()
} else { } else {
dststr = addr + " (proxy to resolve)" dststr = addr + " (via proxy resolver)"
} }
call := p.dbusObj.Call("com.subgraph.FirewallPrompt.RequestPrompt", 0, callChan := make(chan *dbus.Call, 10)
saveChannel(callChan, true, false)
fmt.Println("# outstanding prompt chans = ", len(outstandingPromptChans))
p.dbusObj.Go("com.subgraph.FirewallPrompt.RequestPrompt", 0, callChan,
pc.getGUID(),
policy.application, policy.application,
policy.icon, policy.icon,
policy.path, policy.path,
@ -167,14 +216,62 @@ func (p *prompter) processConnection(pc pendingConnection) {
FirewallConfig.PromptExpanded, FirewallConfig.PromptExpanded,
FirewallConfig.PromptExpert, FirewallConfig.PromptExpert,
int32(FirewallConfig.DefaultActionID)) int32(FirewallConfig.DefaultActionID))
err := call.Store(&scope, &rule)
if err != nil { select {
log.Warningf("Error sending dbus RequestPrompt message: %v", err) case call := <-callChan:
policy.removePending(pc)
pc.drop() if call.Err != nil {
return fmt.Println("Error reading DBus channel (accepting packet): ", call.Err)
policy.removePending(pc)
pc.accept()
saveChannel(callChan, false, true)
time.Sleep(1 * time.Second)
return
}
if len(call.Body) != 2 {
log.Warning("SGFW got back response in unrecognized format, len = ", len(call.Body))
saveChannel(callChan, false, true)
if (len(call.Body) == 3) && (call.Body[2] == 666) {
fmt.Printf("+++++++++ AWESOME: %v | %v | %v\n", call.Body[0], call.Body[1], call.Body[2])
scope = call.Body[0].(int32)
rule = call.Body[1].(string)
}
return
}
fmt.Printf("DBUS GOT BACK: %v, %v\n", call.Body[0], call.Body[1])
scope = call.Body[0].(int32)
rule = call.Body[1].(string)
} }
saveChannel(callChan, false, true)
// Try alerting every other channel
promptData := make([]interface{}, 3)
promptData[0] = scope
promptData[1] = rule
promptData[2] = 666
promptChanLock.Lock()
fmt.Println("# channels to alert: ", len(outstandingPromptChans))
for chidx, _ := range outstandingPromptChans {
alertChannel(chidx, scope, rule)
// ch <- &dbus.Call{Body: promptData}
}
promptChanLock.Unlock()
/* err := call.Store(&scope, &rule)
if err != nil {
log.Warningf("Error sending dbus RequestPrompt message: %v", err)
policy.removePending(pc)
pc.drop()
return
} */
// the prompt sends: // the prompt sends:
// ALLOW|dest or DENY|dest // ALLOW|dest or DENY|dest
// //

@ -184,7 +184,7 @@ func (rl *RuleList) filter(pkt *nfqueue.NFQPacket, src, dst net.IP, dstPort uint
nfqproto = getNFQProto(pkt) nfqproto = getNFQProto(pkt)
} else { } else {
if r.saddr == nil && src == nil && sandboxed == false && (r.port == dstPort || r.port == matchAny) && (r.addr.Equal(anyAddress) || r.hostname == "" || r.hostname == hostname) { if r.saddr == nil && src == nil && sandboxed == false && (r.port == dstPort || r.port == matchAny) && (r.addr.Equal(anyAddress) || r.hostname == "" || r.hostname == hostname) {
// log.Notice("+ Socks5 MATCH SUCCEEDED") // log.Notice("+ Socks5 MATCH SUCCEEDED")
if r.rtype == RULE_ACTION_DENY { if r.rtype == RULE_ACTION_DENY {
return FILTER_DENY return FILTER_DENY
} else if r.rtype == RULE_ACTION_ALLOW { } else if r.rtype == RULE_ACTION_ALLOW {
@ -203,7 +203,7 @@ func (rl *RuleList) filter(pkt *nfqueue.NFQPacket, src, dst net.IP, dstPort uint
continue continue
} }
if r.match(src, dst, dstPort, hostname, nfqproto, pinfo.UID, pinfo.GID, uidToUser(pinfo.UID), gidToGroup(pinfo.GID)) { if r.match(src, dst, dstPort, hostname, nfqproto, pinfo.UID, pinfo.GID, uidToUser(pinfo.UID), gidToGroup(pinfo.GID)) {
// log.Notice("+ MATCH SUCCEEDED") // log.Notice("+ MATCH SUCCEEDED")
dstStr := dst.String() dstStr := dst.String()
if FirewallConfig.LogRedact { if FirewallConfig.LogRedact {
dstStr = STR_REDACTED dstStr = STR_REDACTED
@ -214,7 +214,7 @@ func (rl *RuleList) filter(pkt *nfqueue.NFQPacket, src, dst net.IP, dstPort uint
srcp, _ := getPacketPorts(pkt) srcp, _ := getPacketPorts(pkt)
srcStr = fmt.Sprintf("%s:%d", srcip, srcp) srcStr = fmt.Sprintf("%s:%d", srcip, srcp)
} }
// log.Noticef("%s > %s %s %s -> %s:%d", // log.Noticef("%s > %s %s %s -> %s:%d",
//r.getString(FirewallConfig.LogRedact), pinfo.ExePath, r.proto, srcStr, dstStr, dstPort) //r.getString(FirewallConfig.LogRedact), pinfo.ExePath, r.proto, srcStr, dstStr, dstPort)
if r.rtype == RULE_ACTION_DENY { if r.rtype == RULE_ACTION_DENY {
//TODO: Optionally redact below log entry //TODO: Optionally redact below log entry

@ -56,6 +56,8 @@ type pendingSocksConnection struct {
pinfo *procsnitch.Info pinfo *procsnitch.Info
verdict chan int verdict chan int
prompting bool prompting bool
prompter *dbusObjectP
guid string
optstr string optstr string
} }
@ -107,8 +109,11 @@ func (sc *pendingSocksConnection) deliverVerdict(v int) {
} }
}() }()
sc.verdict <- v if sc.verdict != nil {
close(sc.verdict) sc.verdict <- v
close(sc.verdict)
sc.verdict = nil
}
} }
func (sc *pendingSocksConnection) accept() { sc.deliverVerdict(socksVerdictAccept) } func (sc *pendingSocksConnection) accept() { sc.deliverVerdict(socksVerdictAccept) }
@ -119,6 +124,18 @@ func (sc *pendingSocksConnection) acceptTLSOnly() { sc.deliverVerdict(socksVerdi
func (sc *pendingSocksConnection) drop() { sc.deliverVerdict(socksVerdictDrop) } func (sc *pendingSocksConnection) drop() { sc.deliverVerdict(socksVerdictDrop) }
func (sc *pendingSocksConnection) setPrompter(val *dbusObjectP) { sc.prompter = val }
func (sc *pendingSocksConnection) getPrompter() *dbusObjectP { return sc.prompter }
func (sc *pendingSocksConnection) getGUID() string {
if sc.guid == "" {
sc.guid = genGUID()
}
return sc.guid
}
func (sc *pendingSocksConnection) getPrompting() bool { return sc.prompting } func (sc *pendingSocksConnection) getPrompting() bool { return sc.prompting }
func (sc *pendingSocksConnection) setPrompting(val bool) { sc.prompting = val } func (sc *pendingSocksConnection) setPrompting(val bool) { sc.prompting = val }
@ -364,6 +381,7 @@ func (c *socksChainSession) filterConnect() (bool, bool) {
pinfo: pinfo, pinfo: pinfo,
verdict: make(chan int), verdict: make(chan int),
prompting: false, prompting: false,
prompter: nil,
optstr: optstr, optstr: optstr,
} }
policy.processPromptResult(pending) policy.processPromptResult(pending)
@ -409,7 +427,7 @@ func (c *socksChainSession) forwardTraffic(tls bool) {
if c.pinfo.Sandbox != "" { if c.pinfo.Sandbox != "" {
log.Errorf("TLSGuard violation: Dropping traffic from %s (sandbox: %s) to %s: %v", c.pinfo.ExePath, c.pinfo.Sandbox, c.req.Addr.addrStr, err) log.Errorf("TLSGuard violation: Dropping traffic from %s (sandbox: %s) to %s: %v", c.pinfo.ExePath, c.pinfo.Sandbox, c.req.Addr.addrStr, err)
} else { } else {
log.Errorf("TLSGuard violation: Dropping traffic from %s (unsandboxed) to %s: %v", c.pinfo.ExePath, c.req.Addr.addrStr, err) log.Errorf("TLSGuard violation: Dropping traffic from %s (un-sandboxed) to %s: %v", c.pinfo.ExePath, c.req.Addr.addrStr, err)
} }
return return
} else { } else {

@ -3,177 +3,279 @@ package sgfw
import ( import (
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt"
"io" "io"
"net" "net"
"time"
) )
func TLSGuard(conn, conn2 net.Conn, fqdn string) error { const TLSGUARD_READ_TIMEOUT = 5 * time.Second
// Should this be a requirement? const TLSGUARD_MIN_TLS_VER_MAJ = 3
// if strings.HasSuffix(request.DestAddr.FQDN, "onion") { const TLSGUARD_MIN_TLS_VER_MIN = 1
handshakeByte, err := readNBytes(conn, 1) const SSL3_RT_CHANGE_CIPHER_SPEC = 20
if err != nil { const SSL3_RT_ALERT = 21
return err const SSL3_RT_HANDSHAKE = 22
} const SSL3_RT_APPLICATION_DATA = 23
if handshakeByte[0] != 0x16 { const SSL3_MT_SERVER_HELLO = 2
return errors.New("Blocked client from attempting non-TLS connection") const SSL3_MT_CERTIFICATE = 11
} const SSL3_MT_CERTIFICATE_REQUEST = 13
const SSL3_MT_SERVER_DONE = 14
vers, err := readNBytes(conn, 2) type connReader struct {
if err != nil { client bool
return err data []byte
} rtype int
err error
}
length, err := readNBytes(conn, 2) func connectionReader(conn net.Conn, is_client bool, c chan connReader, done chan bool) {
if err != nil { var ret_error error = nil
return err buffered := []byte{}
} mlen := 0
rtype := 0
stage := 1
ffslen := int(int(length[0])<<8 | int(length[1])) for {
if ret_error != nil {
cr := connReader{client: is_client, data: nil, rtype: 0, err: ret_error}
c <- cr
break
}
ffs, err := readNBytes(conn, ffslen) select {
if err != nil { case <-done:
return err fmt.Println("++ DONE: ", is_client)
} if len(buffered) > 0 {
//fmt.Println("++ DONE BUT DISPOSING OF BUFFERED DATA")
c <- connReader{client: is_client, data: buffered, rtype: 0, err: nil}
}
// Transmit client hello c <- connReader{client: is_client, data: nil, rtype: 0, err: nil}
conn2.Write(handshakeByte) return
conn2.Write(vers) default:
conn2.Write(length) if stage == 1 {
conn2.Write(ffs) header := make([]byte, 5)
conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT))
// Read ServerHello _, err := io.ReadFull(conn, header)
bytesRead := 0 conn.SetReadDeadline(time.Time{})
var s byte // 0x0e is done if err != nil {
var responseBuf []byte = []byte{} ret_error = err
valid := false continue
sendToClient := false }
for sendToClient == false { if int(header[1]) < TLSGUARD_MIN_TLS_VER_MAJ {
// Handshake byte ret_error = errors.New("TLS protocol major version less than expected minimum")
serverhandshakeByte, err := readNBytes(conn2, 1) continue
if err != nil { } else if int(header[2]) < TLSGUARD_MIN_TLS_VER_MIN {
return nil ret_error = errors.New("TLS protocol minor version less than expected minimum")
} continue
}
responseBuf = append(responseBuf, serverhandshakeByte[0]) rtype = int(header[0])
bytesRead += 1 mlen = int(int(header[3])<<8 | int(header[4]))
fmt.Printf("TLS data chunk header read: type = %#x, maj = %v, min = %v, len = %v\n", rtype, header[1], header[2], mlen)
buffered = header
if serverhandshakeByte[0] != 0x16 { stage++
return errors.New("Expected TLS server handshake byte was not received") } else if stage == 2 {
} remainder := make([]byte, mlen)
conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT))
_, err := io.ReadFull(conn, remainder)
conn.SetReadDeadline(time.Time{})
if err != nil {
ret_error = err
continue
}
// Protocol version, 2 bytes buffered = append(buffered, remainder...)
serverProtocolVer, err := readNBytes(conn2, 2) fmt.Printf("------- CHUNK READ: client: %v, err = %v, bytes = %v\n", is_client, err, len(buffered))
if err != nil { cr := connReader{client: is_client, data: buffered, rtype: rtype, err: err}
return err c <- cr
}
bytesRead += 2 buffered = []byte{}
responseBuf = append(responseBuf, serverProtocolVer...) rtype = 0
mlen = 0
stage = 1
}
// Record length, 2 bytes
serverRecordLen, err := readNBytes(conn2, 2)
if err != nil {
return err
} }
bytesRead += 2 }
responseBuf = append(responseBuf, serverRecordLen...)
serverRecordLenInt := int(int(serverRecordLen[0])<<8 | int(serverRecordLen[1]))
// Record type byte }
serverMsg, err := readNBytes(conn2, serverRecordLenInt)
if err != nil {
return err
}
bytesRead += len(serverMsg) func TLSGuard(conn, conn2 net.Conn, fqdn string) error {
responseBuf = append(responseBuf, serverMsg...) x509Valid := false
s = serverMsg[0] ndone := 0
// Should this be a requirement?
// if strings.HasSuffix(request.DestAddr.FQDN, "onion") {
// Message len, 3 bytes //conn client
serverMessageLen := serverMsg[1:4] //conn2 server
serverMessageLenInt := int(int(serverMessageLen[0])<<16 | int(serverMessageLen[1])<<8 | int(serverMessageLen[2]))
// serverHelloBody, err := readNBytes(conn2, serverMessageLenInt) fmt.Println("-------- STARTING HANDSHAKE LOOP")
serverHelloBody := serverMsg[4 : 4+serverMessageLenInt] crChan := make(chan connReader)
dChan := make(chan bool, 10)
go connectionReader(conn, true, crChan, dChan)
go connectionReader(conn2, false, crChan, dChan)
if s == 0x0b { select_loop:
certChainLen := int(int(serverHelloBody[0])<<16 | int(serverHelloBody[1])<<8 | int(serverHelloBody[2])) for {
remaining := certChainLen if ndone == 2 {
pos := serverHelloBody[3:certChainLen] fmt.Println("DONE channel got both notifications. Terminating loop.")
close(dChan)
close(crChan)
break
}
// var certChain []*x509.Certificate select {
var verifyOptions x509.VerifyOptions case cr := <-crChan:
other := conn
if fqdn != "" { if cr.client {
verifyOptions.DNSName = fqdn other = conn2
} }
pool := x509.NewCertPool() fmt.Printf("++++ SELECT: %v, %v, %v\n", cr.client, cr.err, len(cr.data))
var c *x509.Certificate if cr.err == nil && cr.data == nil {
fmt.Println("DONE channel notification received")
ndone++
continue
}
for remaining > 0 { if cr.err == nil {
certLen := int(int(pos[0])<<16 | int(pos[1])<<8 | int(pos[2])) if cr.rtype == SSL3_RT_CHANGE_CIPHER_SPEC || cr.rtype == SSL3_RT_APPLICATION_DATA ||
// fmt.Printf("Certs chain len %d, cert 1 len %d:\n", certChainLen, certLen) cr.rtype == SSL3_RT_ALERT {
cert := pos[3 : 3+certLen] // fmt.Println("OTHER DATA; PASSING THRU")
certs, err := x509.ParseCertificates(cert) if cr.rtype == SSL3_RT_ALERT {
if remaining == certChainLen { fmt.Println("ALERT = ", cr.data)
c = certs[0] }
} else { other.Write(cr.data)
pool.AddCert(certs[0]) continue
} else if cr.client {
other.Write(cr.data)
continue
} else if cr.rtype != SSL3_RT_HANDSHAKE {
return errors.New(fmt.Sprintf("Expected TLS server handshake byte was not received [%#x vs 0x16]", cr.rtype))
} }
// certChain = append(certChain, certs[0])
if err != nil { serverMsg := cr.data[5:]
return err s := serverMsg[0]
fmt.Printf("s = %#x\n", s)
if s > 0x22 {
fmt.Println("WTF: ", cr.data)
}
if s == SSL3_MT_CERTIFICATE {
fmt.Println("HMM")
// Message len, 3 bytes
serverMessageLen := serverMsg[1:4]
serverMessageLenInt := int(int(serverMessageLen[0])<<16 | int(serverMessageLen[1])<<8 | int(serverMessageLen[2]))
// fmt.Printf("chunk len = %v, serverMsgLen = %v, slint = %v\n", len(chunk), len(serverMsg), serverMessageLenInt)
if len(serverMsg) < serverMessageLenInt {
return errors.New(fmt.Sprintf("len(serverMsg) %v < serverMessageLenInt %v!\n", len(serverMsg), serverMessageLenInt))
}
serverHelloBody := serverMsg[4 : 4+serverMessageLenInt]
certChainLen := int(int(serverHelloBody[0])<<16 | int(serverHelloBody[1])<<8 | int(serverHelloBody[2]))
remaining := certChainLen
pos := serverHelloBody[3:certChainLen]
// var certChain []*x509.Certificate
var verifyOptions x509.VerifyOptions
//fqdn = "www.reddit.com"
if fqdn != "" {
verifyOptions.DNSName = fqdn
}
pool := x509.NewCertPool()
var c *x509.Certificate
for remaining > 0 {
certLen := int(int(pos[0])<<16 | int(pos[1])<<8 | int(pos[2]))
// fmt.Printf("Certs chain len %d, cert 1 len %d:\n", certChainLen, certLen)
cert := pos[3 : 3+certLen]
certs, err := x509.ParseCertificates(cert)
if remaining == certChainLen {
c = certs[0]
} else {
pool.AddCert(certs[0])
}
// certChain = append(certChain, certs[0])
if err != nil {
return err
}
remaining = remaining - certLen - 3
if remaining > 0 {
pos = pos[3+certLen:]
}
}
verifyOptions.Intermediates = pool
fmt.Println("ATTEMPTING TO VERIFY: ", fqdn)
_, err := c.Verify(verifyOptions)
fmt.Println("ATTEMPTING TO VERIFY RESULT: ", err)
if err != nil {
return err
} else {
x509Valid = true
}
} }
remaining = remaining - certLen - 3
if remaining > 0 { other.Write(cr.data)
pos = pos[3+certLen:]
if x509Valid || (s == SSL3_MT_SERVER_DONE) || (s == SSL3_MT_CERTIFICATE_REQUEST) {
fmt.Println("BREAKING OUT OF LOOP 1")
dChan <- true
fmt.Println("BREAKING OUT OF LOOP 2")
break select_loop
} }
}
verifyOptions.Intermediates = pool
_, err = c.Verify(verifyOptions) // fmt.Printf("Sending chunk of type %d to client.\n", s)
if err != nil { } else if cr.err != nil {
return err ndone++
} else {
valid = true if cr.client {
fmt.Println("Client read error: ", cr.err)
} else {
fmt.Println("Server read error: ", cr.err)
}
return cr.err
} }
// else if s == 0x0d { fmt.Printf("found a client cert request, sending buf to client\n") }
} else if s == 0x0e {
sendToClient = true
} else if s == 0x0d {
sendToClient = true
} }
}
fmt.Println("WAITING; ndone = ", ndone)
for ndone < 2 {
fmt.Println("WAITING; ndone = ", ndone)
select {
case cr := <-crChan:
fmt.Printf("CHAN DATA: %v, %v, %v\n", cr.client, cr.err, len(cr.data))
if cr.err != nil || cr.data == nil {
ndone++
} else if cr.client {
conn2.Write(cr.data)
} else if !cr.client {
conn.Write(cr.data)
}
// fmt.Printf("Version bytes: %x %x\n", responseBuf[1], responseBuf[2]) }
// fmt.Printf("Len bytes: %x %x\n", responseBuf[3], responseBuf[4])
// fmt.Printf("Message type: %x\n", responseBuf[5])
// fmt.Printf("Message len: %x %x %x\n", responseBuf[6], responseBuf[7], responseBuf[8])
// fmt.Printf("Message body: %v\n", responseBuf[9:])
conn.Write(responseBuf)
responseBuf = []byte{}
} }
if !valid { fmt.Println("______ ndone = 2\n")
// dChan <- true
close(dChan)
if !x509Valid {
return errors.New("Unknown error: TLS connection could not be validated") return errors.New("Unknown error: TLS connection could not be validated")
} }
return nil return nil
}
func readNBytes(conn net.Conn, numBytes int) ([]byte, error) {
res := make([]byte, 0)
temp := make([]byte, 1)
for i := 0; i < numBytes; i++ {
_, err := io.ReadAtLeast(conn, temp, 1)
if err != nil {
return res, err
}
res = append(res, temp[0])
}
return res, nil
} }

Loading…
Cancel
Save