From 7472b4d828fd3285adfffffffca6c5186ecc1f90 Mon Sep 17 00:00:00 2001 From: xSmurf Date: Thu, 28 Sep 2017 20:38:51 +0000 Subject: [PATCH] Merged from shw_dev --- fw-prompt/dbus.go | 52 +- fw-prompt/fw-prompt.go | 571 +++++++++++---- sgfw/dbus.go | 73 +- sgfw/policy.go | 105 ++- sgfw/prompt.go | 365 +++++++--- sgfw/rules.go | 9 + sgfw/sgfw.go | 14 +- sgfw/socks_server_chain.go | 71 +- sgfw/tlsguard.go | 651 ++++++++++++++---- .../github.com/subgraph/go-procsnitch/proc.go | 19 +- .../subgraph/go-procsnitch/proc_pid.go | 22 +- .../subgraph/go-procsnitch/socket.go | 51 +- 12 files changed, 1532 insertions(+), 471 deletions(-) diff --git a/fw-prompt/dbus.go b/fw-prompt/dbus.go index 3344a62..8e03ade 100644 --- a/fw-prompt/dbus.go +++ b/fw-prompt/dbus.go @@ -4,33 +4,23 @@ import ( "errors" "github.com/godbus/dbus" "log" - // "github.com/gotk3/gotk3/glib" ) +type dbusObject struct { + dbus.BusObject +} + type dbusServer struct { conn *dbus.Conn run bool } -type promptData struct { - Application string - Icon string - Path string - Address string - Port int - IP string - Origin string - Proto string - UID int - GID int - Username string - Groupname string - Pid int - Sandbox string - OptString string - Expanded bool - Expert bool - Action int +func newDbusObjectAdd() (*dbusObject, error) { + conn, err := dbus.SystemBus() + if err != nil { + return nil, err + } + return &dbusObject{conn.Object("com.subgraph.Firewall", "/com/subgraph/Firewall")}, nil } func newDbusServer() (*dbusServer, error) { @@ -62,10 +52,10 @@ func newDbusServer() (*dbusServer, error) { return ds, nil } -func (ds *dbusServer) RequestPrompt(application, icon, path, address string, port int32, ip, origin, proto string, uid, gid int32, username, groupname string, pid int32, sandbox string, - is_socks bool, optstring string, expanded, expert bool, action int32) (int32, string, *dbus.Error) { - log.Printf("request prompt: app = %s, icon = %s, path = %s, address = %s, is_socks = %v, action = %v\n", application, icon, path, address, is_socks, action) - decision := addRequest(nil, path, proto, int(pid), ip, address, int(port), int(uid), int(gid), origin, is_socks, optstring, sandbox) +/*func (ds *dbusServer) RequestPrompt(guid, application, icon, path, address string, port int32, ip, origin, proto string, uid, gid int32, username, groupname string, pid int32, sandbox string, + is_socks bool, timestamp string, optstring string, expanded, expert bool, action int32) (int32, string, *dbus.Error) { + log.Printf("request prompt: app = %s, icon = %s, path = %s, address = %s / ip = %s, is_socks = %v, action = %v\n", application, icon, path, address, ip, is_socks, action) + decision := addRequest(nil, guid, path, icon, proto, int(pid), ip, address, int(port), int(uid), int(gid), origin, timestamp, is_socks, optstring, sandbox, int(action)) log.Print("Waiting on decision...") decision.Cond.L.Lock() for !decision.Ready { @@ -73,6 +63,18 @@ func (ds *dbusServer) RequestPrompt(application, icon, path, address string, por } log.Print("Decision returned: ", decision.Rule) decision.Cond.L.Unlock() - // glib.IdleAdd(func, data) return int32(decision.Scope), decision.Rule, nil +}*/ + +func (ds *dbusServer) RequestPromptAsync(guid, application, icon, path, address string, port int32, ip, origin, proto string, uid, gid int32, username, groupname string, pid int32, sandbox string, + is_socks bool, timestamp string, optstring string, expanded, expert bool, action int32) (bool, *dbus.Error) { + log.Printf("ASYNC request prompt: guid = %s, app = %s, icon = %s, path = %s, address = %s / ip = %s, is_socks = %v, action = %v\n", guid, application, icon, path, address, ip, is_socks, action) + addRequestAsync(nil, guid, path, icon, proto, int(pid), ip, address, int(port), int(uid), int(gid), origin, timestamp, is_socks, optstring, sandbox, int(action)) + return true, nil +} + +func (ds *dbusServer) RemovePrompt(guid string) *dbus.Error { + log.Printf("++++++++ Cancelling prompt: %s\n", guid) + removeRequest(nil, guid) + return nil } diff --git a/fw-prompt/fw-prompt.go b/fw-prompt/fw-prompt.go index 36fecf7..124a4fa 100644 --- a/fw-prompt/fw-prompt.go +++ b/fw-prompt/fw-prompt.go @@ -34,34 +34,44 @@ type decisionWaiter struct { } type ruleColumns struct { - Path string - Proto string - Pid int - Target string - Hostname string - Port int - UID int - GID int - Uname string - Gname string - Origin string - Scope int + nrefs int + Path string + GUID string + Icon string + Proto string + Pid int + Target string + Hostname string + Port int + UID int + GID int + Uname string + Gname string + Origin string + Timestamp string + IsSocks bool + ForceTLS bool + Scope int } +var dbuso *dbusObject var userPrefs fpPreferences var mainWin *gtk.Window var Notebook *gtk.Notebook -var globalLS *gtk.ListStore +var globalLS *gtk.ListStore = nil var globalTV *gtk.TreeView +var globalPromptLock = &sync.Mutex{} +var globalIcon *gtk.Image var decisionWaiters []*decisionWaiter var editApp, editTarget, editPort, editUser, editGroup *gtk.Entry var comboProto *gtk.ComboBoxText var radioOnce, radioProcess, radioParent, radioSession, radioPermanent *gtk.RadioButton var btnApprove, btnDeny, btnIgnore *gtk.Button -var chkUser, chkGroup *gtk.CheckButton +var chkTLS, chkUser, chkGroup *gtk.CheckButton func dumpDecisions() { + return fmt.Println("XXX Total of decisions pending: ", len(decisionWaiters)) for i := 0; i < len(decisionWaiters); i++ { fmt.Printf("XXX %d ready = %v, rule = %v\n", i+1, decisionWaiters[i].Ready, decisionWaiters[i].Rule) @@ -69,6 +79,7 @@ func dumpDecisions() { } func addDecision() *decisionWaiter { + return nil decision := decisionWaiter{Lock: &sync.Mutex{}, Ready: false, Scope: int(sgfw.APPLY_ONCE), Rule: ""} decision.Cond = sync.NewCond(decision.Lock) decisionWaiters = append(decisionWaiters, &decision) @@ -306,7 +317,8 @@ func createColumn(title string, id int) *gtk.TreeViewColumn { } func createListStore(general bool) *gtk.ListStore { - colData := []glib.Type{glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING} + colData := []glib.Type{glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_STRING, + glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_INT} listStore, err := gtk.ListStoreNew(colData...) if err != nil { @@ -316,24 +328,92 @@ func createListStore(general bool) *gtk.ListStore { return listStore } -func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, hostname string, port, uid, gid int, origin string, is_socks bool, optstring string, sandbox string) *decisionWaiter { +func removeRequest(listStore *gtk.ListStore, guid string) { + removed := false + globalPromptLock.Lock() + defer globalPromptLock.Unlock() + + /* XXX: This is horrible. Figure out how to do this properly. */ + for ridx := 0; ridx < globalLS.IterNChildren(nil); ridx++ { + + rule, _, err := getRuleByIdx(ridx) + if err != nil { + break + } else if rule.GUID == guid { + removeSelectedRule(ridx, true) + removed = true + break + } + + } + + if !removed { + log.Printf("Unexpected condition: SGFW requested prompt removal for non-existent GUID %v\n", guid) + } + +} + +func addRequestInc(listStore *gtk.ListStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int, + origin string, is_socks bool, optstring string, sandbox string, action int) bool { + duplicated := false + + globalPromptLock.Lock() + defer globalPromptLock.Unlock() + + for ridx := 0; ridx < globalLS.IterNChildren(nil); ridx++ { + + /* XXX: This is horrible. Figure out how to do this properly. */ + rule, iter, err := getRuleByIdx(ridx) + if err != nil { + break + // XXX: not compared: optstring/sandbox + } else if (rule.Path == path) && (rule.Proto == proto) && (rule.Pid == pid) && (rule.Target == ipaddr) && (rule.Hostname == hostname) && + (rule.Port == port) && (rule.UID == uid) && (rule.GID == gid) && (rule.Origin == origin) && (rule.IsSocks == is_socks) { + rule.nrefs++ + + err := globalLS.SetValue(iter, 0, rule.nrefs) + if err != nil { + log.Println("Error creating duplicate firewall prompt entry:", err) + break + } + + fmt.Println("YES REALLY DUPLICATE: ", rule.nrefs) + duplicated = true + break + } + + } + + return duplicated +} + +func addRequestAsync(listStore *gtk.ListStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int, + origin, timestamp string, is_socks bool, optstring string, sandbox string, action int) bool { + addRequest(listStore, guid, path, icon, proto, pid, ipaddr, hostname, port, uid, gid, origin, timestamp, is_socks, + optstring, sandbox, action) + return true +} + +func addRequest(listStore *gtk.ListStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int, + origin, timestamp string, is_socks bool, optstring string, sandbox string, action int) *decisionWaiter { if listStore == nil { listStore = globalLS waitTimes := []int{1, 2, 5, 10} if listStore == nil { - log.Print("SGFW prompter was not ready to receive firewall request... waiting") - } + log.Println("SGFW prompter was not ready to receive firewall request... waiting") - for _, wtime := range waitTimes { - time.Sleep(time.Duration(wtime) * time.Second) - listStore = globalLS + for _, wtime := range waitTimes { + time.Sleep(time.Duration(wtime) * time.Second) + listStore = globalLS - if listStore != nil { - break + if listStore != nil { + break + } + + log.Println("SGFW prompter is still waiting...") } - log.Print("SGFW prompter is still waiting...") } } @@ -342,6 +422,18 @@ func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, h log.Fatal("SGFW prompter GUI failed to load for unknown reasons") } + if addRequestInc(listStore, guid, path, icon, proto, pid, ipaddr, hostname, port, uid, gid, origin, is_socks, optstring, sandbox, action) { + fmt.Println("REQUEST WAS DUPLICATE") + decision := addDecision() + globalPromptLock.Lock() + toggleHover() + globalPromptLock.Unlock() + return decision + } else { + fmt.Println("NOT DUPLICATE") + } + + globalPromptLock.Lock() iter := listStore.Append() if is_socks { @@ -352,24 +444,34 @@ func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, h } } - colVals := make([]interface{}, 11) + colVals := make([]interface{}, 16) colVals[0] = 1 - colVals[1] = path - colVals[2] = proto - colVals[3] = pid + colVals[1] = guid + colVals[2] = path + colVals[3] = icon + colVals[4] = proto + colVals[5] = pid if ipaddr == "" { - colVals[4] = "---" + colVals[6] = "---" } else { - colVals[4] = ipaddr + colVals[6] = ipaddr + } + + colVals[7] = hostname + colVals[8] = port + colVals[9] = uid + colVals[10] = gid + colVals[11] = origin + colVals[12] = timestamp + colVals[13] = 0 + + if is_socks { + colVals[13] = 1 } - colVals[5] = hostname - colVals[6] = port - colVals[7] = uid - colVals[8] = gid - colVals[9] = origin - colVals[10] = optstring + colVals[14] = optstring + colVals[15] = action colNums := make([]int, len(colVals)) @@ -386,6 +488,7 @@ func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, h decision := addDecision() dumpDecisions() toggleHover() + globalPromptLock.Unlock() return decision } @@ -479,22 +582,42 @@ func lsGetInt(ls *gtk.ListStore, iter *gtk.TreeIter, idx int) (int, error) { return ival.(int), nil } -func makeDecision(idx int, rule string, scope int) { +func makeDecision(idx int, rule string, scope int) error { + var dres bool + call := dbuso.Call("AddRuleAsync", 0, uint32(scope), rule, "*") + + err := call.Store(&dres) + if err != nil { + log.Println("Error notifying SGFW of asynchronous rule addition:", err) + return err + } + + fmt.Println("makeDecision remote result:", dres) + + return nil decisionWaiters[idx].Cond.L.Lock() decisionWaiters[idx].Rule = rule decisionWaiters[idx].Scope = scope decisionWaiters[idx].Ready = true decisionWaiters[idx].Cond.Signal() decisionWaiters[idx].Cond.L.Unlock() + return nil } +/* Do we need to hold the lock while this is called? Stay safe... */ func toggleHover() { - mainWin.SetKeepAbove(len(decisionWaiters) > 0) + nitems := globalLS.IterNChildren(nil) + + mainWin.SetKeepAbove(nitems > 0) } func toggleValidRuleState() { ok := true + // XXX: Unfortunately, this can cause deadlock since it's a part of the item removal cascade + // globalPromptLock.Lock() + // defer globalPromptLock.Unlock() + if numSelections() <= 0 { ok = false } @@ -536,7 +659,8 @@ func toggleValidRuleState() { btnApprove.SetSensitive(ok) btnDeny.SetSensitive(ok) - btnIgnore.SetSensitive(ok) + // btnIgnore.SetSensitive(ok) + btnIgnore.SetSensitive(false) } func createCurrentRule() (ruleColumns, error) { @@ -579,6 +703,9 @@ func createCurrentRule() (ruleColumns, error) { rule.UID, rule.GID = 0, 0 rule.Uname, rule.Gname = "", "" + + rule.ForceTLS = chkTLS.GetActive() + /* Pid int Origin string */ @@ -586,6 +713,7 @@ func createCurrentRule() (ruleColumns, error) { } func clearEditor() { + globalIcon.Clear() editApp.SetText("") editTarget.SetText("") editPort.SetText("") @@ -599,6 +727,7 @@ func clearEditor() { radioPermanent.SetActive(false) chkUser.SetActive(false) chkGroup.SetActive(false) + chkTLS.SetActive(false) } func removeSelectedRule(idx int, rmdecision bool) error { @@ -617,7 +746,7 @@ func removeSelectedRule(idx int, rmdecision bool) error { globalLS.Remove(iter) if rmdecision { - decisionWaiters = append(decisionWaiters[:idx], decisionWaiters[idx+1:]...) + // decisionWaiters = append(decisionWaiters[:idx], decisionWaiters[idx+1:]...) } toggleHover() @@ -634,78 +763,126 @@ func numSelections() int { return int(rows.Length()) } -func getSelectedRule() (ruleColumns, int, error) { +// Needs to be locked by the caller +func getRuleByIdx(idx int) (ruleColumns, *gtk.TreeIter, error) { rule := ruleColumns{} - sel, err := globalTV.GetSelection() + path, err := gtk.TreePathNewFromString(fmt.Sprintf("%d", idx)) if err != nil { - return rule, -1, err + return rule, nil, err } - rows := sel.GetSelectedRows(globalLS) + iter, err := globalLS.GetIter(path) + if err != nil { + return rule, nil, err + } - if rows.Length() <= 0 { - return rule, -1, errors.New("No selection was made") + rule.nrefs, err = lsGetInt(globalLS, iter, 0) + if err != nil { + return rule, nil, err } - rdata := rows.NthData(0) - lIndex, err := strconv.Atoi(rdata.(*gtk.TreePath).String()) + rule.GUID, err = lsGetStr(globalLS, iter, 1) if err != nil { - return rule, -1, err + return rule, nil, err } - fmt.Println("lindex = ", lIndex) - path, err := gtk.TreePathNewFromString(fmt.Sprintf("%d", lIndex)) + rule.Path, err = lsGetStr(globalLS, iter, 2) if err != nil { - return rule, -1, err + return rule, nil, err } - iter, err := globalLS.GetIter(path) + rule.Icon, err = lsGetStr(globalLS, iter, 3) if err != nil { - return rule, -1, err + return rule, nil, err } - rule.Path, err = lsGetStr(globalLS, iter, 1) + rule.Proto, err = lsGetStr(globalLS, iter, 4) if err != nil { - return rule, -1, err + return rule, nil, err } - rule.Proto, err = lsGetStr(globalLS, iter, 2) + rule.Pid, err = lsGetInt(globalLS, iter, 5) if err != nil { - return rule, -1, err + return rule, nil, err } - rule.Pid, err = lsGetInt(globalLS, iter, 3) + rule.Target, err = lsGetStr(globalLS, iter, 6) if err != nil { - return rule, -1, err + return rule, nil, err } - rule.Target, err = lsGetStr(globalLS, iter, 4) + rule.Hostname, err = lsGetStr(globalLS, iter, 7) if err != nil { - return rule, -1, err + return rule, nil, err } - rule.Hostname, err = lsGetStr(globalLS, iter, 5) + rule.Port, err = lsGetInt(globalLS, iter, 8) if err != nil { - return rule, -1, err + return rule, nil, err } - rule.Port, err = lsGetInt(globalLS, iter, 6) + rule.UID, err = lsGetInt(globalLS, iter, 9) if err != nil { - return rule, -1, err + return rule, nil, err + } + + rule.GID, err = lsGetInt(globalLS, iter, 10) + if err != nil { + return rule, nil, err } - rule.UID, err = lsGetInt(globalLS, iter, 7) + rule.Origin, err = lsGetStr(globalLS, iter, 11) + if err != nil { + return rule, nil, err + } + + rule.Timestamp, err = lsGetStr(globalLS, iter, 12) + if err != nil { + return rule, nil, err + } + + rule.IsSocks = false + is_socks, err := lsGetInt(globalLS, iter, 13) + if err != nil { + return rule, nil, err + } + + if is_socks != 0 { + rule.IsSocks = true + } + + rule.Scope, err = lsGetInt(globalLS, iter, 15) + if err != nil { + return rule, nil, err + } + + return rule, iter, nil +} + +// Needs to be locked by the caller +func getSelectedRule() (ruleColumns, int, error) { + rule := ruleColumns{} + + sel, err := globalTV.GetSelection() if err != nil { return rule, -1, err } - rule.GID, err = lsGetInt(globalLS, iter, 8) + rows := sel.GetSelectedRows(globalLS) + + if rows.Length() <= 0 { + return rule, -1, errors.New("No selection was made") + } + + rdata := rows.NthData(0) + lIndex, err := strconv.Atoi(rdata.(*gtk.TreePath).String()) if err != nil { return rule, -1, err } - rule.Origin, err = lsGetStr(globalLS, iter, 9) + fmt.Println("lindex = ", lIndex) + rule, _, err = getRuleByIdx(lIndex) if err != nil { return rule, -1, err } @@ -713,6 +890,109 @@ func getSelectedRule() (ruleColumns, int, error) { return rule, lIndex, nil } +func addPendingPrompts(rules []string) { + + for _, rule := range rules { + fields := strings.Split(rule, "|") + + if len(fields) != 19 { + log.Printf("Got saved prompt message with strange data: \"%s\"", rule) + continue + } + + guid := fields[0] + icon := fields[2] + path := fields[3] + address := fields[4] + + port, err := strconv.Atoi(fields[5]) + if err != nil { + log.Println("Error converting port in pending prompt message to integer:", err) + continue + } + + ip := fields[6] + origin := fields[7] + proto := fields[8] + + uid, err := strconv.Atoi(fields[9]) + if err != nil { + log.Println("Error converting UID in pending prompt message to integer:", err) + continue + } + + gid, err := strconv.Atoi(fields[10]) + if err != nil { + log.Println("Error converting GID in pending prompt message to integer:", err) + continue + } + + pid, err := strconv.Atoi(fields[13]) + if err != nil { + log.Println("Error converting pid in pending prompt message to integer:", err) + continue + } + + sandbox := fields[14] + + is_socks, err := strconv.ParseBool(fields[15]) + if err != nil { + log.Println("Error converting SOCKS flag in pending prompt message to boolean:", err) + continue + } + + timestamp := fields[16] + optstring := fields[17] + + action, err := strconv.Atoi(fields[18]) + if err != nil { + log.Println("Error converting action in pending prompt message to integer:", err) + continue + } + + addRequestAsync(nil, guid, path, icon, proto, int(pid), ip, address, int(port), int(uid), int(gid), origin, timestamp, is_socks, optstring, sandbox, action) + } + +} + +func buttonAction(action string) { + globalPromptLock.Lock() + rule, idx, err := getSelectedRule() + if err != nil { + globalPromptLock.Unlock() + promptError("Error occurred processing request: " + err.Error()) + return + } + + rule, err = createCurrentRule() + if err != nil { + globalPromptLock.Unlock() + promptError("Error occurred constructing new rule: " + err.Error()) + return + } + + fmt.Println("rule = ", rule) + rulestr := action + + if action == "ALLOW" && rule.ForceTLS { + rulestr += "_TLSONLY" + } + + rulestr += "|" + rule.Proto + ":" + rule.Target + ":" + strconv.Itoa(rule.Port) + rulestr += "|" + sgfw.RuleModeString[sgfw.RuleMode(rule.Scope)] + fmt.Println("RULESTR = ", rulestr) + makeDecision(idx, rulestr, int(rule.Scope)) + fmt.Println("Decision made.") + err = removeSelectedRule(idx, true) + globalPromptLock.Unlock() + if err == nil { + clearEditor() + } else { + promptError("Error setting new rule: " + err.Error()) + } + +} + func main() { decisionWaiters = make([]*decisionWaiter, 0) _, err := newDbusServer() @@ -721,6 +1001,11 @@ func main() { return } + dbuso, err = newDbusObjectAdd() + if err != nil { + log.Fatal("Failed to connect to dbus system bus: %v", err) + } + loadPreferences() gtk.Init(nil) @@ -811,10 +1096,18 @@ func main() { editbox := get_vbox() hbox := get_hbox() lbl := get_label("Application path:") + + globalIcon, err = gtk.ImageNew() + if err != nil { + log.Fatal("Unable to create image:", err) + } + + // globalIcon.SetFromIconName("firefox", gtk.ICON_SIZE_DND) editApp = get_entry("") editApp.Connect("changed", toggleValidRuleState) hbox.PackStart(lbl, false, false, 10) - hbox.PackStart(editApp, true, true, 50) + hbox.PackStart(editApp, true, true, 10) + hbox.PackStart(globalIcon, false, false, 10) editbox.PackStart(hbox, false, false, 5) hbox = get_hbox() @@ -842,7 +1135,9 @@ func main() { radioSession = get_radiobutton(radioOnce, "Session", false) radioPermanent = get_radiobutton(radioOnce, "Permanent", false) radioParent.SetSensitive(false) - hbox.PackStart(lbl, false, false, 10) + chkTLS = get_checkbox("Require TLS", false) + hbox.PackStart(chkTLS, false, false, 10) + hbox.PackStart(lbl, false, false, 20) hbox.PackStart(radioOnce, false, false, 5) hbox.PackStart(radioProcess, false, false, 5) hbox.PackStart(radioParent, false, false, 5) @@ -872,94 +1167,55 @@ func main() { box.PackStart(scrollbox, false, true, 5) tv.AppendColumn(createColumn("#", 0)) - tv.AppendColumn(createColumn("Path", 1)) - tv.AppendColumn(createColumn("Protocol", 2)) - tv.AppendColumn(createColumn("PID", 3)) - tv.AppendColumn(createColumn("IP Address", 4)) - tv.AppendColumn(createColumn("Hostname", 5)) - tv.AppendColumn(createColumn("Port", 6)) - tv.AppendColumn(createColumn("UID", 7)) - tv.AppendColumn(createColumn("GID", 8)) - tv.AppendColumn(createColumn("Origin", 9)) - tv.AppendColumn(createColumn("Details", 10)) - listStore := createListStore(true) - globalLS = listStore + guidcol := createColumn("GUID", 1) + guidcol.SetVisible(false) + tv.AppendColumn(guidcol) - tv.SetModel(listStore) + tv.AppendColumn(createColumn("Path", 2)) - btnApprove.Connect("clicked", func() { - rule, idx, err := getSelectedRule() - if err != nil { - promptError("Error occurred processing request: " + err.Error()) - return - } + icol := createColumn("Icon", 3) + icol.SetVisible(false) + tv.AppendColumn(icol) - rule, err = createCurrentRule() - if err != nil { - promptError("Error occurred constructing new rule: " + err.Error()) - return - } + tv.AppendColumn(createColumn("Protocol", 4)) + tv.AppendColumn(createColumn("PID", 5)) + tv.AppendColumn(createColumn("IP Address", 6)) + tv.AppendColumn(createColumn("Hostname", 7)) + tv.AppendColumn(createColumn("Port", 8)) + tv.AppendColumn(createColumn("UID", 9)) + tv.AppendColumn(createColumn("GID", 10)) + tv.AppendColumn(createColumn("Origin", 11)) + tv.AppendColumn(createColumn("Timestamp", 12)) - fmt.Println("rule = ", rule) - rulestr := "ALLOW|" + rule.Proto + ":" + rule.Target + ":" + strconv.Itoa(rule.Port) - fmt.Println("RULESTR = ", rulestr) - makeDecision(idx, rulestr, int(rule.Scope)) - fmt.Println("Decision made.") - err = removeSelectedRule(idx, true) - if err == nil { - clearEditor() - } else { - promptError("Error setting new rule: " + err.Error()) - } - }) + scol := createColumn("Is SOCKS", 13) + scol.SetVisible(false) + tv.AppendColumn(scol) - btnDeny.Connect("clicked", func() { - rule, idx, err := getSelectedRule() - if err != nil { - promptError("Error occurred processing request: " + err.Error()) - return - } + tv.AppendColumn(createColumn("Details", 14)) - rule, err = createCurrentRule() - if err != nil { - promptError("Error occurred constructing new rule: " + err.Error()) - return - } + acol := createColumn("Scope", 15) + acol.SetVisible(false) + tv.AppendColumn(acol) - fmt.Println("rule = ", rule) - rulestr := "DENY|" + rule.Proto + ":" + rule.Target + ":" + strconv.Itoa(rule.Port) - fmt.Println("RULESTR = ", rulestr) - makeDecision(idx, rulestr, int(rule.Scope)) - fmt.Println("Decision made.") - err = removeSelectedRule(idx, true) - if err == nil { - clearEditor() - } else { - promptError("Error setting new rule: " + err.Error()) - } - }) + listStore := createListStore(true) + globalLS = listStore - btnIgnore.Connect("clicked", func() { - _, idx, err := getSelectedRule() - if err != nil { - promptError("Error occurred processing request: " + err.Error()) - return - } + tv.SetModel(listStore) - makeDecision(idx, "", 0) - fmt.Println("Decision made.") - err = removeSelectedRule(idx, true) - if err == nil { - clearEditor() - } else { - promptError("Error setting new rule: " + err.Error()) - } + btnApprove.Connect("clicked", func() { + buttonAction("ALLOW") }) + btnDeny.Connect("clicked", func() { + buttonAction("DENY") + }) + // btnIgnore.Connect("clicked", buttonAction) // tv.SetActivateOnSingleClick(true) tv.Connect("row-activated", func() { + globalPromptLock.Lock() seldata, _, err := getSelectedRule() + globalPromptLock.Unlock() if err != nil { promptError("Unexpected error reading selected rule: " + err.Error()) return @@ -967,6 +1223,12 @@ func main() { editApp.SetText(seldata.Path) + if seldata.Icon != "" { + globalIcon.SetFromIconName(seldata.Icon, gtk.ICON_SIZE_DND) + } else { + globalIcon.Clear() + } + if seldata.Hostname != "" { editTarget.SetText(seldata.Hostname) } else { @@ -974,13 +1236,14 @@ func main() { } editPort.SetText(strconv.Itoa(seldata.Port)) - radioOnce.SetActive(true) - radioProcess.SetActive(false) + radioOnce.SetActive(seldata.Scope == int(sgfw.APPLY_ONCE)) radioProcess.SetSensitive(seldata.Pid > 0) radioParent.SetActive(false) - radioSession.SetActive(false) - radioPermanent.SetActive(false) + radioSession.SetActive(seldata.Scope == int(sgfw.APPLY_SESSION)) + radioPermanent.SetActive(seldata.Scope == int(sgfw.APPLY_FOREVER)) + comboProto.SetActiveID(seldata.Proto) + chkTLS.SetActive(seldata.IsSocks) if seldata.Uname != "" { editUser.SetText(seldata.Uname) @@ -1000,7 +1263,6 @@ func main() { chkUser.SetActive(false) chkGroup.SetActive(false) - return }) @@ -1023,5 +1285,16 @@ func main() { mainWin.ShowAll() // mainWin.SetKeepAbove(true) + + var dres = []string{} + call := dbuso.Call("GetPendingRequests", 0, "*") + err = call.Store(&dres) + if err != nil { + errmsg := "Could not query running SGFW instance (maybe it's not running?): " + err.Error() + promptError(errmsg) + } else { + addPendingPrompts(dres) + } + gtk.Main() } diff --git a/sgfw/dbus.go b/sgfw/dbus.go index 14547ce..95d42e2 100644 --- a/sgfw/dbus.go +++ b/sgfw/dbus.go @@ -199,6 +199,74 @@ func (ds *dbusServer) DeleteRule(id uint32) *dbus.Error { return nil } +func (ds *dbusServer) GetPendingRequests(policy string) ([]string, *dbus.Error) { + log.Debug("+++ GetPendingRequests()") + ds.fw.lock.Lock() + defer ds.fw.lock.Unlock() + + pending_data := make([]string, 0) + + for pname := range ds.fw.policyMap { + policy := ds.fw.policyMap[pname] + pqueue := policy.pendingQueue + + for _, pc := range pqueue { + addr := pc.hostname() + if addr == "" { + addr = pc.dst().String() + } + + dststr := "" + + if pc.dst() != nil { + dststr = pc.dst().String() + } else { + dststr = addr + " (via proxy resolver)" + } + + pstr := "" + pstr += pc.getGUID() + "|" + pstr += policy.application + "|" + pstr += policy.icon + "|" + pstr += policy.path + "|" + pstr += addr + "|" + pstr += strconv.FormatUint(uint64(pc.dstPort()), 10) + "|" + pstr += dststr + "|" + pstr += pc.src().String() + "|" + pstr += pc.proto() + "|" + pstr += strconv.FormatInt(int64(pc.procInfo().UID), 10) + "|" + pstr += strconv.FormatInt(int64(pc.procInfo().GID), 10) + "|" + pstr += uidToUser(pc.procInfo().UID) + "|" + pstr += gidToGroup(pc.procInfo().GID) + "|" + pstr += strconv.FormatInt(int64(pc.procInfo().Pid), 10) + "|" + pstr += pc.sandbox() + "|" + pstr += strconv.FormatBool(pc.socks()) + "|" + pstr += pc.getTimestamp() + "|" + pstr += pc.getOptString() + "|" + pstr += strconv.FormatUint(uint64(FirewallConfig.DefaultActionID), 10) + pending_data = append(pending_data, pstr) + } + + } + + return pending_data, nil +} + +func (ds *dbusServer) AddRuleAsync(scope uint32, rule string, policy string) (bool, *dbus.Error) { + log.Warningf("AddRuleAsync %v, %v / %v\n", scope, rule, policy) + ds.fw.lock.Lock() + defer ds.fw.lock.Unlock() + + prule := PendingRule{rule: rule, scope: int(scope), policy: policy} + + for pname := range ds.fw.policyMap { + log.Debug("+++ Adding prule to policy") + ds.fw.policyMap[pname].rulesPending = append(ds.fw.policyMap[pname].rulesPending, prule) + } + + return true, nil +} + func (ds *dbusServer) UpdateRule(rule DbusRule) *dbus.Error { log.Debugf("UpdateRule %v", rule) ds.fw.lock.Lock() @@ -268,11 +336,6 @@ func (ds *dbusServer) SetConfig(key string, val dbus.Variant) *dbus.Error { return nil } -func (ds *dbusServer) prompt(p *Policy) { - log.Info("prompting...") - ds.prompter.prompt(p) -} - func (ob *dbusObjectP) alertRule(data string) { ob.Call("com.subgraph.fwprompt.EventNotifier.Alert", 0, data) } diff --git a/sgfw/policy.go b/sgfw/policy.go index b1174fa..936ae73 100644 --- a/sgfw/policy.go +++ b/sgfw/policy.go @@ -6,8 +6,6 @@ import ( "strings" "sync" - // "encoding/binary" - // nfnetlink "github.com/subgraph/go-nfnetlink" "github.com/google/gopacket/layers" nfqueue "github.com/subgraph/go-nfnetlink/nfqueue" @@ -15,6 +13,7 @@ import ( "net" "os" "syscall" + "time" "unsafe" ) @@ -52,6 +51,10 @@ type pendingConnection interface { drop() setPrompting(bool) getPrompting() bool + setPrompter(*prompter) + getPrompter() *prompter + getGUID() string + getTimestamp() string print() string } @@ -62,6 +65,24 @@ type pendingPkt struct { pinfo *procsnitch.Info optstring string prompting bool + prompter *prompter + guid string + timestamp time.Time +} + +/* Not a *REAL* GUID */ +func genGUID() string { + frnd, err := os.Open("/dev/urandom") + if err != nil { + log.Fatal("Error reading random data source:", err) + } + + rndb := make([]byte, 16) + frnd.Read(rndb) + frnd.Close() + + guid := fmt.Sprintf("%x-%x-%x-%x", rndb[0:4], rndb[4:8], rndb[8:12], rndb[12:]) + return guid } func getEmptyPInfo() *procsnitch.Info { @@ -79,6 +100,10 @@ func (pp *pendingPkt) sandbox() string { return pp.pinfo.Sandbox } +func (pc *pendingPkt) getTimestamp() string { + return pc.timestamp.Format("15:04:05.00") +} + func (pp *pendingPkt) socks() bool { return false } @@ -165,6 +190,22 @@ func (pp *pendingPkt) drop() { pp.pkt.Accept() } +func (pp *pendingPkt) setPrompter(val *prompter) { + pp.prompter = val +} + +func (pp *pendingPkt) getPrompter() *prompter { + return pp.prompter +} + +func (pp *pendingPkt) getGUID() string { + if pp.guid == "" { + pp.guid = genGUID() + } + + return pp.guid +} + func (pp *pendingPkt) getPrompting() bool { return pp.prompting } @@ -177,6 +218,12 @@ func (pp *pendingPkt) print() string { return printPacket(pp.pkt, pp.name, pp.pinfo) } +type PendingRule struct { + rule string + scope int + policy string +} + type Policy struct { fw *Firewall path string @@ -187,6 +234,7 @@ type Policy struct { pendingQueue []pendingConnection promptInProgress bool lock sync.Mutex + rulesPending []PendingRule } func (fw *Firewall) PolicyForPath(path string) *Policy { @@ -240,17 +288,25 @@ func (fw *Firewall) policyForPath(path string) *Policy { return fw.policyMap[path] } -func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, pinfo *procsnitch.Info, optstr string) { +func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, timestamp time.Time, pinfo *procsnitch.Info, optstr string) { + fmt.Println("policy processPacket()") - /* hbytes, err := pkt.GetHWAddr() - if err != nil { - log.Notice("Failed to get HW address underlying packet: ", err) - } else { log.Notice("got hwaddr: ", hbytes) } */ p.lock.Lock() defer p.lock.Unlock() dstb := pkt.Packet.NetworkLayer().NetworkFlow().Dst().Raw() dstip := net.IP(dstb) srcip := net.IP(pkt.Packet.NetworkLayer().NetworkFlow().Src().Raw()) + + /* Can we pass this through quickly? */ + /* this probably isn't a performance enhancement. */ + /*_, dstp := getPacketPorts(pkt) + fres := p.rules.filter(pkt, srcip, dstip, dstp, dstip.String(), pinfo, optstr) + if fres == FILTER_ALLOW { + fmt.Printf("Packet passed wildcard rules without requiring DNS lookup; accepting: %s:%d\n", dstip, dstp) + pkt.Accept() + return + }*/ + name := p.fw.dns.Lookup(dstip, pinfo.Pid) log.Infof("Lookup(%s): %s", dstip.String(), name) @@ -267,7 +323,7 @@ func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, pinfo *procsnitch.Info, o case FILTER_ALLOW: pkt.Accept() case FILTER_PROMPT: - p.processPromptResult(&pendingPkt{pol: p, name: name, pkt: pkt, pinfo: pinfo, optstring: optstr, prompting: false}) + p.processPromptResult(&pendingPkt{pol: p, name: name, pkt: pkt, pinfo: pinfo, optstring: optstr, prompter: nil, timestamp: timestamp, prompting: false}) default: log.Warningf("Unexpected filter result: %d", result) } @@ -276,33 +332,27 @@ func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, pinfo *procsnitch.Info, o func (p *Policy) processPromptResult(pc pendingConnection) { p.pendingQueue = append(p.pendingQueue, pc) //fmt.Println("processPromptResult(): p.promptInProgress = ", p.promptInProgress) - if DoMultiPrompt || (!DoMultiPrompt && !p.promptInProgress) { - p.promptInProgress = true - go p.fw.dbus.prompt(p) - } + //if DoMultiPrompt || (!DoMultiPrompt && !p.promptInProgress) { + // if !p.promptInProgress { + p.promptInProgress = true + go p.fw.dbus.prompter.prompt(p) + // } } func (p *Policy) nextPending() (pendingConnection, bool) { p.lock.Lock() defer p.lock.Unlock() - if !DoMultiPrompt { - if len(p.pendingQueue) == 0 { - return nil, true - } - return p.pendingQueue[0], false - } if len(p.pendingQueue) == 0 { return nil, true } - // for len(p.pendingQueue) != 0 { for i := 0; i < len(p.pendingQueue); i++ { + // fmt.Printf("XXX: pendingQueue %v of %v: %v\n", i, len(p.pendingQueue), p.pendingQueue[i]) if !p.pendingQueue[i].getPrompting() { return p.pendingQueue[i], false } } - // } return nil, false } @@ -329,6 +379,7 @@ func (p *Policy) processNewRule(r *Rule, scope FilterScope) bool { if scope != APPLY_ONCE { p.rules = append(p.rules, r) } + fmt.Println("----------------------- processNewRule()") p.filterPending(r) if len(p.pendingQueue) == 0 { p.promptInProgress = false @@ -372,6 +423,15 @@ func (p *Policy) filterPending(rule *Rule) { remaining := []pendingConnection{} for _, pc := range p.pendingQueue { if rule.match(pc.src(), pc.dst(), pc.dstPort(), pc.hostname(), pc.proto(), pc.procInfo().UID, pc.procInfo().GID, uidToUser(pc.procInfo().UID), gidToGroup(pc.procInfo().GID), pc.procInfo().Sandbox) { + prompter := pc.getPrompter() + + if prompter == nil { + fmt.Println("-------- prompter = NULL") + } else { + call := prompter.dbusObj.Call("com.subgraph.FirewallPrompt.RemovePrompt", 0, pc.getGUID()) + fmt.Println("CAAAAAAAAAAAAAAALL = ", call) + } + log.Infof("Adding rule for: %s", rule.getString(FirewallConfig.LogRedact)) // log.Noticef("%s > %s", rule.getString(FirewallConfig.LogRedact), pc.print()) if rule.rtype == RULE_ACTION_ALLOW { @@ -442,7 +502,8 @@ func printPacket(pkt *nfqueue.NFQPacket, hostname string, pinfo *procsnitch.Info return fmt.Sprintf("%s %s %s:%d -> %s:%d", pinfo.ExePath, proto, SrcIp, SrcPort, name, DstPort) } -func (fw *Firewall) filterPacket(pkt *nfqueue.NFQPacket) { +func (fw *Firewall) filterPacket(pkt *nfqueue.NFQPacket, timestamp time.Time) { + fmt.Println("firewall: filterPacket()") isudp := pkt.Packet.Layer(layers.LayerTypeUDP) != nil if basicAllowPacket(pkt) { @@ -520,7 +581,7 @@ func (fw *Firewall) filterPacket(pkt *nfqueue.NFQPacket) { */ policy := fw.PolicyForPathAndSandbox(ppath, pinfo.Sandbox) //log.Notice("XXX: flunked basicallowpacket; policy = ", policy) - policy.processPacket(pkt, pinfo, optstring) + policy.processPacket(pkt, timestamp, pinfo, optstring) } func readFileDirect(filename string) ([]byte, error) { diff --git a/sgfw/prompt.go b/sgfw/prompt.go index 619daf2..bd7ef07 100644 --- a/sgfw/prompt.go +++ b/sgfw/prompt.go @@ -2,24 +2,18 @@ package sgfw import ( "fmt" - "net" + "os" "os/user" "strconv" "strings" "sync" + "syscall" "time" "github.com/godbus/dbus" "github.com/subgraph/fw-daemon/proc-coroner" ) -var DoMultiPrompt = true - -const MAX_PROMPTS = 5 - -var outstandingPrompts = 0 -var promptLock = &sync.Mutex{} - func newPrompter(conn *dbus.Conn) *prompter { p := new(prompter) p.cond = sync.NewCond(&p.lock) @@ -42,6 +36,7 @@ func (p *prompter) prompt(policy *Policy) { defer p.lock.Unlock() _, ok := p.policyMap[policy.sandbox+"|"+policy.path] if ok { + p.cond.Signal() return } p.policyMap[policy.sandbox+"|"+policy.path] = policy @@ -51,35 +46,22 @@ func (p *prompter) prompt(policy *Policy) { } func (p *prompter) promptLoop() { - p.lock.Lock() + // p.lock.Lock() for { - // fmt.Println("XXX: promptLoop() outer") - for p.processNextPacket() { - // fmt.Println("XXX: promptLoop() inner") - } - // fmt.Println("promptLoop() wait") - p.cond.Wait() + p.processNextPacket() } } func (p *prompter) processNextPacket() bool { + //fmt.Println("processNextPacket()") var pc pendingConnection = nil - - if !DoMultiPrompt { - pc, _ = p.nextConnection() - if pc == nil { - return false - } - p.lock.Unlock() - defer p.lock.Lock() - p.processConnection(pc) - return true - } - empty := true + for { + p.lock.Lock() pc, empty = p.nextConnection() - // fmt.Println("XXX: processNextPacket() loop; empty = ", empty, " / pc = ", pc) + p.lock.Unlock() + //fmt.Println("XXX: processNextPacket() loop; empty = ", empty, " / pc = ", pc) if pc == nil && empty { return false } else if pc == nil { @@ -88,49 +70,150 @@ func (p *prompter) processNextPacket() bool { break } } - p.lock.Unlock() - defer p.lock.Lock() - // fmt.Println("XXX: Waiting for prompt lock go...") - for { - promptLock.Lock() - if outstandingPrompts >= MAX_PROMPTS { - promptLock.Unlock() - continue - } - if pc.getPrompting() { - log.Debugf("Skipping over already prompted connection") - promptLock.Unlock() - continue - } - - break + if pc.getPrompting() { + log.Debugf("Skipping over already prompted connection") + return false } - // fmt.Println("XXX: Passed prompt lock!") - outstandingPrompts++ - // fmt.Println("XXX: Incremented outstanding to ", outstandingPrompts) - promptLock.Unlock() - // if !pc.getPrompting() { + pc.setPrompting(true) + fmt.Println("processConnection") go p.processConnection(pc) - // } return true } -func processReturn(pc pendingConnection) { - promptLock.Lock() - outstandingPrompts-- - // fmt.Println("XXX: Return decremented outstanding to ", outstandingPrompts) - promptLock.Unlock() - pc.setPrompting(false) +type PC2FDMapping struct { + guid string + inode uint64 + fd int + fdpath string + prompter *prompter +} + +var PC2FDMap = map[string]PC2FDMapping{} +var PC2FDMapLock = &sync.Mutex{} +var PC2FDMapRunning = false + +func monitorPromptFDs(pc pendingConnection) { + guid := pc.getGUID() + pid := pc.procInfo().Pid + inode := pc.procInfo().Inode + fd := pc.procInfo().FD + prompter := pc.getPrompter() + + fmt.Printf("ADD TO MONITOR: %v | %v / %v / %v\n", pc.policy().application, guid, pid, fd) + + if pid == -1 || fd == -1 || prompter == nil { + log.Warning("Unexpected error condition occurred while adding socket fd to monitor") + return + } + + PC2FDMapLock.Lock() + defer PC2FDMapLock.Unlock() + + fdpath := fmt.Sprintf("/proc/%d/fd/%d", pid, fd) + PC2FDMap[guid] = PC2FDMapping{guid: guid, inode: inode, fd: fd, fdpath: fdpath, prompter: prompter} + return +} + +func monitorPromptFDLoop() { + fmt.Println("++++++++++= monitorPromptFDLoop()") + + for true { + delete_guids := []string{} + PC2FDMapLock.Lock() + fmt.Println("++++ nentries = ", len(PC2FDMap)) + + for guid, fdmon := range PC2FDMap { + fmt.Println("ENTRY:", fdmon) + + lsb, err := os.Stat(fdmon.fdpath) + if err != nil { + log.Warningf("Error looking up socket \"%s\": %v\n", fdmon.fdpath, err) + delete_guids = append(delete_guids, guid) + continue + } + + sb, ok := lsb.Sys().(*syscall.Stat_t) + if !ok { + log.Warning("Not a syscall.Stat_t") + delete_guids = append(delete_guids, guid) + continue + } + + inode := sb.Ino + fmt.Println("+++ INODE = ", inode) + + if inode != fdmon.inode { + fmt.Printf("inode mismatch: %v vs %v\n", inode, fdmon.inode) + delete_guids = append(delete_guids, guid) + } + + } + + fmt.Println("guids to delete: ", delete_guids) + saved_mappings := []PC2FDMapping{} + for _, guid := range delete_guids { + saved_mappings = append(saved_mappings, PC2FDMap[guid]) + delete(PC2FDMap, guid) + } + + PC2FDMapLock.Unlock() + + for _, mapping := range saved_mappings { + call := mapping.prompter.dbusObj.Call("com.subgraph.FirewallPrompt.RemovePrompt", 0, mapping.guid) + fmt.Println("DISPOSING CALL = ", call) + prompter := mapping.prompter + + prompter.lock.Lock() + + for _, policy := range prompter.policyQueue { + policy.lock.Lock() + pcind := 0 + + for pcind < len(policy.pendingQueue) { + + if policy.pendingQueue[pcind].getGUID() == mapping.guid { + fmt.Println("-------------- found guid to remove") + policy.pendingQueue = append(policy.pendingQueue[:pcind], policy.pendingQueue[pcind+1:]...) + } else { + pcind++ + } + + } + + policy.lock.Unlock() + } + + prompter.lock.Unlock() + } + + fmt.Println("++++++++++= monitorPromptFDLoop WAIT") + time.Sleep(5 * time.Second) + } + } func (p *prompter) processConnection(pc pendingConnection) { var scope int32 + var dres bool var rule string - if DoMultiPrompt { - defer processReturn(pc) + if !PC2FDMapRunning { + PC2FDMapLock.Lock() + + if !PC2FDMapRunning { + PC2FDMapRunning = true + PC2FDMapLock.Unlock() + go monitorPromptFDLoop() + } else { + PC2FDMapLock.Unlock() + } + + } + + if pc.getPrompter() == nil { + pc.setPrompter(p) } addr := pc.hostname() @@ -144,10 +227,12 @@ func (p *prompter) processConnection(pc pendingConnection) { if pc.dst() != nil { dststr = pc.dst().String() } else { - dststr = addr + " (proxy to resolve)" + dststr = addr + " (via proxy resolver)" } - call := p.dbusObj.Call("com.subgraph.FirewallPrompt.RequestPrompt", 0, + monitorPromptFDs(pc) + call := p.dbusObj.Call("com.subgraph.FirewallPrompt.RequestPromptAsync", 0, + pc.getGUID(), policy.application, policy.icon, policy.path, @@ -163,18 +248,58 @@ func (p *prompter) processConnection(pc pendingConnection) { int32(pc.procInfo().Pid), pc.sandbox(), pc.socks(), + pc.getTimestamp(), pc.getOptString(), FirewallConfig.PromptExpanded, FirewallConfig.PromptExpert, int32(FirewallConfig.DefaultActionID)) - err := call.Store(&scope, &rule) + + err := call.Store(&dres) if err != nil { - log.Warningf("Error sending dbus RequestPrompt message: %v", err) + log.Warningf("Error sending dbus async RequestPrompt message: %v", err) policy.removePending(pc) pc.drop() return } + if !dres { + fmt.Println("Unexpected: fw-prompt async RequestPrompt message returned:", dres) + } + + return + + /* p.dbusObj.Go("com.subgraph.FirewallPrompt.RequestPrompt", 0, callChan, + pc.getGUID(), + policy.application, + policy.icon, + policy.path, + addr, + int32(pc.dstPort()), + dststr, + pc.src().String(), + pc.proto(), + int32(pc.procInfo().UID), + int32(pc.procInfo().GID), + uidToUser(pc.procInfo().UID), + gidToGroup(pc.procInfo().GID), + int32(pc.procInfo().Pid), + pc.sandbox(), + pc.socks(), + pc.getOptString(), + FirewallConfig.PromptExpanded, + FirewallConfig.PromptExpert, + int32(FirewallConfig.DefaultActionID)) + + saveChannel(callChan, false, true) + + /* err := call.Store(&scope, &rule) + if err != nil { + log.Warningf("Error sending dbus RequestPrompt message: %v", err) + policy.removePending(pc) + pc.drop() + return + } */ + // the prompt sends: // ALLOW|dest or DENY|dest // @@ -194,17 +319,19 @@ func (p *prompter) processConnection(pc pendingConnection) { } tempRule := fmt.Sprintf("%s|%s", toks[0], toks[1]) + tempRule += "||-1:-1|" + sandbox + "|" - if pc.src() != nil && !pc.src().Equal(net.ParseIP("127.0.0.1")) && sandbox != "" { + if pc.src() != nil && !pc.src().IsLoopback() && sandbox != "" { //if !strings.HasSuffix(rule, "SYSTEM") && !strings.HasSuffix(rule, "||") { //rule += "||" //} //ule += "|||" + pc.src().String() - tempRule += "||-1:-1|" + sandbox + "|" + pc.src().String() + // tempRule += "||-1:-1|" + sandbox + "|" + pc.src().String() + tempRule += pc.src().String() } else { - tempRule += "||-1:-1|" + sandbox + "|" + // tempRule += "||-1:-1|" + sandbox + "|" } r, err := policy.parseRule(tempRule, false) if err != nil { @@ -235,35 +362,109 @@ func (p *prompter) processConnection(pc pendingConnection) { } func (p *prompter) nextConnection() (pendingConnection, bool) { - for { - if len(p.policyQueue) == 0 { - return nil, true - } - policy := p.policyQueue[0] + pind := 0 + + if len(p.policyQueue) == 0 { + return nil, true + } + fmt.Println("policy queue len = ", len(p.policyQueue)) + + for pind < len(p.policyQueue) { + //fmt.Printf("XXX: pind = %v of %v\n", pind, len(p.policyQueue)) + policy := p.policyQueue[pind] pc, qempty := policy.nextPending() + if pc == nil && qempty { p.removePolicy(policy) + continue } else { + pind++ + // if pc == nil && !qempty { + + if len(policy.rulesPending) > 0 { + fmt.Println("policy rules pending = ", len(policy.rulesPending)) + + prule := policy.rulesPending[0] + policy.rulesPending = append(policy.rulesPending[:0], policy.rulesPending[1:]...) + + toks := strings.Split(prule.rule, "|") + sandbox := "" + + if len(toks) > 2 { + sandbox = toks[2] + } + sandbox += "" + + tempRule := fmt.Sprintf("%s|%s", toks[0], toks[1]) + tempRule += "||-1:-1|" + sandbox + "|" + + /*if pc.src() != nil && !pc.src().IsLoopback() && sandbox != "" { + tempRule += "||-1:-1|" + sandbox + "|" + pc.src().String() + } else { + tempRule += "||-1:-1|" + sandbox + "|" + }*/ + + r, err := policy.parseRule(tempRule, false) + if err != nil { + log.Warningf("Error parsing rule string returned from dbus RequestPrompt: %v", err) + continue + // policy.removePending(pc) + // pc.drop() + // return + } else { + fscope := FilterScope(prule.scope) + if fscope == APPLY_SESSION { + r.mode = RULE_MODE_SESSION + } else if fscope == APPLY_PROCESS { + r.mode = RULE_MODE_PROCESS + /*r.pid = pc.procInfo().Pid + pcoroner.MonitorProcess(r.pid)*/ + } + if !policy.processNewRule(r, fscope) { + // p.lock.Lock() + // defer p.lock.Unlock() + // p.removePolicy(pc.policy()) + } + if fscope == APPLY_FOREVER { + r.mode = RULE_MODE_PERMANENT + policy.fw.saveRules() + } + log.Warningf("Prompt returning rule: %v", tempRule) + dbusp.alertRule("sgfw prompt added new rule") + } + + } + if pc == nil && !qempty { - log.Errorf("FIX ME: I NEED TO SLEEP ON A WAKEABLE CONDITION PROPERLY!!") + // log.Errorf("FIX ME: I NEED TO SLEEP ON A WAKEABLE CONDITION PROPERLY!!") time.Sleep(time.Millisecond * 300) + continue + } + + if pc != nil && pc.getPrompting() { + fmt.Println("SKIPPING PROMPTED") + continue } + return pc, qempty } } + + return nil, true } func (p *prompter) removePolicy(policy *Policy) { var newQueue []*Policy = nil - if DoMultiPrompt { - if len(p.policyQueue) == 0 { - log.Debugf("Skipping over zero length policy queue") - newQueue = make([]*Policy, 0, 0) - } + // if DoMultiPrompt { + if len(p.policyQueue) == 0 { + log.Debugf("Skipping over zero length policy queue") + newQueue = make([]*Policy, 0, 0) } + // } - if !DoMultiPrompt || newQueue == nil { + // if !DoMultiPrompt || newQueue == nil { + if newQueue == nil { newQueue = make([]*Policy, 0, len(p.policyQueue)-1) } for _, pol := range p.policyQueue { @@ -277,11 +478,17 @@ func (p *prompter) removePolicy(policy *Policy) { var userMap = make(map[int]string) var groupMap = make(map[int]string) +var userMapLock = &sync.Mutex{} +var groupMapLock = &sync.Mutex{} func lookupUser(uid int) string { if uid == -1 { return "[unknown]" } + + userMapLock.Lock() + defer userMapLock.Unlock() + u, err := user.LookupId(strconv.Itoa(uid)) if err != nil { return fmt.Sprintf("%d", uid) @@ -293,6 +500,10 @@ func lookupGroup(gid int) string { if gid == -1 { return "[unknown]" } + + groupMapLock.Lock() + defer groupMapLock.Unlock() + g, err := user.LookupGroupId(strconv.Itoa(gid)) if err != nil { return fmt.Sprintf("%d", gid) diff --git a/sgfw/rules.go b/sgfw/rules.go index adeb3f7..f5795f8 100644 --- a/sgfw/rules.go +++ b/sgfw/rules.go @@ -479,7 +479,10 @@ func (fw *Firewall) loadRules() { if err != nil { if !os.IsNotExist(err) { log.Warningf("Failed to open %s for reading: %v", p, err) + } else { + log.Warningf("Did not find a rules file at %s: SGFW loaded with no rules\n", p) } + return } var policy *Policy @@ -497,7 +500,13 @@ func (fw *Firewall) loadRules() { func (fw *Firewall) processPathLine(line string) *Policy { pathLine := line[1 : len(line)-1] + toks := strings.Split(pathLine, "|") + if len(toks) != 2 { + log.Warning("Error parsing rules directive:", line) + return nil + } + policy := fw.policyForPathAndSandbox(toks[1], toks[0]) policy.lock.Lock() defer policy.lock.Unlock() diff --git a/sgfw/sgfw.go b/sgfw/sgfw.go index 8aa49ac..fe6a428 100644 --- a/sgfw/sgfw.go +++ b/sgfw/sgfw.go @@ -1,16 +1,16 @@ package sgfw import ( + "bufio" + "encoding/json" + "fmt" "os" "os/signal" "regexp" + "strings" "sync" "syscall" - // "time" - "bufio" - "encoding/json" - "fmt" - "strings" + "time" "github.com/op/go-logging" nfqueue "github.com/subgraph/go-nfnetlink/nfqueue" @@ -110,6 +110,8 @@ func (fw *Firewall) runFilter() { go func() { for p := range ps { + timestamp := time.Now() + if fw.isEnabled() { ipLayer := p.Packet.Layer(layers.LayerTypeIPv4) if ipLayer == nil { @@ -127,7 +129,7 @@ func (fw *Firewall) runFilter() { } - fw.filterPacket(p) + fw.filterPacket(p, timestamp) } else { p.Accept() } diff --git a/sgfw/socks_server_chain.go b/sgfw/socks_server_chain.go index c183953..ad597dc 100644 --- a/sgfw/socks_server_chain.go +++ b/sgfw/socks_server_chain.go @@ -56,7 +56,10 @@ type pendingSocksConnection struct { pinfo *procsnitch.Info verdict chan int prompting bool + prompter *prompter + guid string optstr string + timestamp time.Time } func (sc *pendingSocksConnection) sandbox() string { @@ -107,23 +110,58 @@ func (sc *pendingSocksConnection) deliverVerdict(v int) { } }() - sc.verdict <- v - close(sc.verdict) + if sc.verdict != nil { + sc.verdict <- v + close(sc.verdict) + sc.verdict = nil + } } -func (sc *pendingSocksConnection) accept() { sc.deliverVerdict(socksVerdictAccept) } +func (sc *pendingSocksConnection) accept() { + sc.deliverVerdict(socksVerdictAccept) +} // need to generalize special accept -func (sc *pendingSocksConnection) acceptTLSOnly() { sc.deliverVerdict(socksVerdictAcceptTLSOnly) } +func (sc *pendingSocksConnection) acceptTLSOnly() { + sc.deliverVerdict(socksVerdictAcceptTLSOnly) +} -func (sc *pendingSocksConnection) drop() { sc.deliverVerdict(socksVerdictDrop) } +func (sc *pendingSocksConnection) drop() { + sc.deliverVerdict(socksVerdictDrop) +} -func (sc *pendingSocksConnection) getPrompting() bool { return sc.prompting } +func (sc *pendingSocksConnection) setPrompter(val *prompter) { + sc.prompter = val +} -func (sc *pendingSocksConnection) setPrompting(val bool) { sc.prompting = val } +func (sc *pendingSocksConnection) getPrompter() *prompter { + return sc.prompter +} + +func (sc *pendingSocksConnection) getTimestamp() string { + return sc.timestamp.Format("15:04:05.00") +} -func (sc *pendingSocksConnection) print() string { return "socks connection" } +func (sc *pendingSocksConnection) getGUID() string { + if sc.guid == "" { + sc.guid = genGUID() + } + + return sc.guid +} + +func (sc *pendingSocksConnection) getPrompting() bool { + return sc.prompting +} + +func (sc *pendingSocksConnection) setPrompting(val bool) { + sc.prompting = val +} + +func (sc *pendingSocksConnection) print() string { + return "socks connection" +} func NewSocksChain(cfg *socksChainConfig, wg *sync.WaitGroup, fw *Firewall) *socksChain { chain := socksChain{ @@ -146,10 +184,11 @@ func (s *socksChain) start() { } s.wg.Add(1) - go s.socksAcceptLoop() + ts := time.Now() + go s.socksAcceptLoop(ts) } -func (s *socksChain) socksAcceptLoop() error { +func (s *socksChain) socksAcceptLoop(timestamp time.Time) error { defer s.wg.Done() defer s.listener.Close() @@ -163,11 +202,11 @@ func (s *socksChain) socksAcceptLoop() error { continue } session := &socksChainSession{cfg: s.cfg, clientConn: conn, procInfo: s.procInfo, server: s} - go session.sessionWorker() + go session.sessionWorker(timestamp) } } -func (c *socksChainSession) sessionWorker() { +func (c *socksChainSession) sessionWorker(timestamp time.Time) { defer c.clientConn.Close() clientAddr := c.clientConn.RemoteAddr() @@ -197,7 +236,7 @@ func (c *socksChainSession) sessionWorker() { c.req.ReplyAddr(ReplySucceeded, c.bndAddr) } case CommandConnect: - verdict, tls := c.filterConnect() + verdict, tls := c.filterConnect(timestamp) if !verdict { c.req.Reply(ReplyConnectionRefused) @@ -278,7 +317,7 @@ func findProxyEndpoint(pdata []string, conn net.Conn) (*procsnitch.Info, string) return nil, "" } -func (c *socksChainSession) filterConnect() (bool, bool) { +func (c *socksChainSession) filterConnect(timestamp time.Time) (bool, bool) { // return filter verdict, tlsguard allProxies, err := ListProxies() @@ -364,7 +403,9 @@ func (c *socksChainSession) filterConnect() (bool, bool) { pinfo: pinfo, verdict: make(chan int), prompting: false, + prompter: nil, optstr: optstr, + timestamp: timestamp, } policy.processPromptResult(pending) v := <-pending.verdict @@ -409,7 +450,7 @@ func (c *socksChainSession) forwardTraffic(tls bool) { if c.pinfo.Sandbox != "" { log.Errorf("TLSGuard violation: Dropping traffic from %s (sandbox: %s) to %s: %v", c.pinfo.ExePath, c.pinfo.Sandbox, c.req.Addr.addrStr, err) } else { - log.Errorf("TLSGuard violation: Dropping traffic from %s (unsandboxed) to %s: %v", c.pinfo.ExePath, c.req.Addr.addrStr, err) + log.Errorf("TLSGuard violation: Dropping traffic from %s (un-sandboxed) to %s: %v", c.pinfo.ExePath, c.req.Addr.addrStr, err) } return } else { diff --git a/sgfw/tlsguard.go b/sgfw/tlsguard.go index d6b4946..ff7db76 100644 --- a/sgfw/tlsguard.go +++ b/sgfw/tlsguard.go @@ -2,6 +2,8 @@ package sgfw import ( "crypto/x509" + "encoding/binary" + "encoding/hex" "errors" "fmt" "io" @@ -9,197 +11,576 @@ import ( "time" ) -const TLSGUARD_READ_TIMEOUT = 2 * time.Second +const TLSGUARD_READ_TIMEOUT = 10 * time.Second const TLSGUARD_MIN_TLS_VER_MAJ = 3 const TLSGUARD_MIN_TLS_VER_MIN = 1 +const TLS_RECORD_HDR_LEN = 5 + const SSL3_RT_CHANGE_CIPHER_SPEC = 20 const SSL3_RT_ALERT = 21 const SSL3_RT_HANDSHAKE = 22 const SSL3_RT_APPLICATION_DATA = 23 +const SSL3_MT_HELLO_REQUEST = 0 +const SSL3_MT_CLIENT_HELLO = 1 const SSL3_MT_SERVER_HELLO = 2 const SSL3_MT_CERTIFICATE = 11 const SSL3_MT_CERTIFICATE_REQUEST = 13 const SSL3_MT_SERVER_DONE = 14 +const SSL3_MT_CERTIFICATE_STATUS = 22 + +const SSL3_AL_WARNING = 1 +const SSL3_AL_FATAL = 2 +const SSL3_AD_CLOSE_NOTIFY = 0 +const SSL3_AD_UNEXPECTED_MESSAGE = 10 +const SSL3_AD_BAD_RECORD_MAC = 20 +const TLS1_AD_DECRYPTION_FAILED = 21 +const TLS1_AD_RECORD_OVERFLOW = 22 +const SSL3_AD_DECOMPRESSION_FAILURE = 30 +const SSL3_AD_HANDSHAKE_FAILURE = 40 +const SSL3_AD_NO_CERTIFICATE = 41 +const SSL3_AD_BAD_CERTIFICATE = 42 +const SSL3_AD_UNSUPPORTED_CERTIFICATE = 43 +const SSL3_AD_CERTIFICATE_REVOKED = 44 +const SSL3_AD_CERTIFICATE_EXPIRED = 45 +const SSL3_AD_CERTIFICATE_UNKNOWN = 46 +const SSL3_AD_ILLEGAL_PARAMETER = 47 +const TLS1_AD_UNKNOWN_CA = 48 +const TLS1_AD_ACCESS_DENIED = 49 +const TLS1_AD_DECODE_ERROR = 50 +const TLS1_AD_DECRYPT_ERROR = 51 +const TLS1_AD_EXPORT_RESTRICTION = 60 +const TLS1_AD_PROTOCOL_VERSION = 70 +const TLS1_AD_INSUFFICIENT_SECURITY = 71 +const TLS1_AD_INTERNAL_ERROR = 80 +const TLS1_AD_INAPPROPRIATE_FALLBACK = 86 +const TLS1_AD_USER_CANCELLED = 90 +const TLS1_AD_NO_RENEGOTIATION = 100 +const TLS1_AD_UNSUPPORTED_EXTENSION = 110 + +const TLSEXT_TYPE_server_name = 1 +const TLSEXT_TYPE_signature_algorithms = 13 +const TLSEXT_TYPE_client_certificate_type = 19 +const TLSEXT_TYPE_extended_master_secret = 23 +const TLSEXT_TYPE_renegotiate = 0xff01 + +type connReader struct { + client bool + data []byte + rtype int + err error +} -func readTLSChunk(conn net.Conn) ([]byte, int, error) { - conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT)) - cbytes, err := readNBytes(conn, 5) - conn.SetReadDeadline(time.Time{}) +var cipherSuiteMap map[uint16]string = map[uint16]string{ + 0x0000: "TLS_NULL_WITH_NULL_NULL", + 0x000a: "TLS_RSA_WITH_3DES_EDE_CBC_SHA", + 0x002f: "TLS_RSA_WITH_AES_128_CBC_SHA", + 0x0033: "TLS_DHE_RSA_WITH_AES_128_CBC_SHA", + 0x0039: "TLS_DHE_RSA_WITH_AES_256_CBC_SHA", + 0x0035: "TLS_RSA_WITH_AES_256_CBC_SHA", + 0x0030: "TLS_DH_DSS_WITH_AES_128_CBC_SHA", + 0xc009: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", + 0xc00a: "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", + 0xc013: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", + 0xc014: "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", + 0xc02b: "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", + 0xc02c: "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", + 0xc02f: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + 0xc030: "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", + 0xcca9: "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", + 0xcca8: "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", +} - if err != nil { - log.Errorf("TLS data chunk read failure: ", err) - return nil, 0, err +func getCipherSuiteName(value uint) string { + val, ok := cipherSuiteMap[uint16(value)] + if !ok { + return "UNKNOWN" } - if int(cbytes[1]) < TLSGUARD_MIN_TLS_VER_MAJ { - return nil, 0, errors.New("TLS protocol major version less than expected minimum") - } else if int(cbytes[2]) < TLSGUARD_MIN_TLS_VER_MIN { - return nil, 0, errors.New("TLS protocol minor version less than expected minimum") - } + return val +} + +func connectionReader(conn net.Conn, is_client bool, c chan connReader, done chan bool) { + var ret_error error = nil + buffered := []byte{} + mlen := 0 + rtype := 0 + stage := 1 + + for { + if ret_error != nil { + cr := connReader{client: is_client, data: nil, rtype: 0, err: ret_error} + c <- cr + break + } + + select { + case <-done: + fmt.Println("++ DONE: ", is_client) + if len(buffered) > 0 { + //fmt.Println("++ DONE BUT DISPOSING OF BUFFERED DATA") + c <- connReader{client: is_client, data: buffered, rtype: 0, err: nil} + } + + c <- connReader{client: is_client, data: nil, rtype: 0, err: nil} + return + default: + if stage == 1 { + header := make([]byte, TLS_RECORD_HDR_LEN) + conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT)) + _, err := io.ReadFull(conn, header) + conn.SetReadDeadline(time.Time{}) + if err != nil { + ret_error = err + continue + } + + if int(header[1]) < TLSGUARD_MIN_TLS_VER_MAJ { + ret_error = errors.New("TLS protocol major version less than expected minimum") + continue + } else if int(header[2]) < TLSGUARD_MIN_TLS_VER_MIN { + ret_error = errors.New("TLS protocol minor version less than expected minimum") + continue + } + + rtype = int(header[0]) + mlen = int(int(header[3])<<8 | int(header[4])) + fmt.Printf("TLS data chunk header read: type = %#x, maj = %v, min = %v, len = %v\n", rtype, header[1], header[2], mlen) + + /* 16384+1024 if compression is not null */ + /* or 16384+2048 if ciphertext */ + if mlen > 16384 { + ret_error = errors.New(fmt.Sprintf("TLSGuard read TLS plaintext record of excessively large length; dropping (%v bytes)", mlen)) + continue + } - cbyte := cbytes[0] - mlen := int(int(cbytes[3])<<8 | int(cbytes[4])) - // fmt.Printf("TLS data chunk header read: type = %#x, maj = %v, min = %v, len = %v\n", cbyte, cbytes[1], cbytes[2], mlen) + buffered = header + stage++ + } else if stage == 2 { + remainder := make([]byte, mlen) + conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT)) + _, err := io.ReadFull(conn, remainder) + conn.SetReadDeadline(time.Time{}) + if err != nil { + ret_error = err + continue + } - conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT)) - cbytes2, err := readNBytes(conn, mlen) - conn.SetReadDeadline(time.Time{}) + buffered = append(buffered, remainder...) + fmt.Printf("------- CHUNK READ: client: %v, err = %v, bytes = %v\n", is_client, err, len(buffered)) + cr := connReader{client: is_client, data: buffered, rtype: rtype, err: err} + c <- cr + + buffered = []byte{} + rtype = 0 + mlen = 0 + stage = 1 + } + + } - if err != nil { - return nil, 0, err } - cbytes = append(cbytes, cbytes2...) - return cbytes, int(cbyte), nil } func TLSGuard(conn, conn2 net.Conn, fqdn string) error { + x509Valid := false + ndone := 0 // Should this be a requirement? // if strings.HasSuffix(request.DestAddr.FQDN, "onion") { //conn client //conn2 server - // Read the opening message from the client - chunk, rtype, err := readTLSChunk(conn) - if err != nil { - return err - } + fmt.Println("-------- STARTING HANDSHAKE LOOP") + crChan := make(chan connReader) + dChan := make(chan bool, 10) + go connectionReader(conn, true, crChan, dChan) + go connectionReader(conn2, false, crChan, dChan) + + client_expected := SSL3_MT_CLIENT_HELLO + server_expected := SSL3_MT_SERVER_HELLO + +select_loop: + for { + if ndone == 2 { + fmt.Println("DONE channel got both notifications. Terminating loop.") + close(dChan) + close(crChan) + break + } - if rtype != SSL3_RT_HANDSHAKE { - return errors.New("Blocked client from attempting non-TLS connection") - } + select { + case cr := <-crChan: + other := conn - // Pass it on through to the server - conn2.Write(chunk) + if cr.client { + other = conn2 + } - // Read ServerHello - valid := false - loop := 1 + fmt.Printf("++++ SELECT: %v, %v, %v\n", cr.client, cr.err, len(cr.data)) + if cr.err == nil && cr.data == nil { + fmt.Println("DONE channel notification received") + ndone++ + continue + } - passthru := false + if cr.err == nil { + if cr.rtype == SSL3_RT_CHANGE_CIPHER_SPEC || cr.rtype == SSL3_RT_APPLICATION_DATA || + cr.rtype == SSL3_RT_ALERT { + + /* We expect only a single byte of data */ + if cr.rtype == SSL3_RT_CHANGE_CIPHER_SPEC { + fmt.Println("CHANGE CIPHER_SPEC: ", cr.data[TLS_RECORD_HDR_LEN]) + if len(cr.data) != 6 { + return errors.New(fmt.Sprintf("TLSGuard dropped connection with strange change cipher spec data length (%v bytes)", len(cr.data))) + } + if cr.data[TLS_RECORD_HDR_LEN] != 1 { + return errors.New(fmt.Sprintf("TLSGuard dropped connection with strange change cipher spec data (%#x bytes)", cr.data[TLS_RECORD_HDR_LEN])) + } + } else if cr.rtype == SSL3_RT_ALERT { + if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_WARNING { + fmt.Println("SSL ALERT TYPE: warning") + } else if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_FATAL { + fmt.Println("SSL ALERT TYPE: fatal") + } else { + fmt.Println("SSL ALERT TYPE UNKNOWN") + } + + alert_desc := int(int(cr.data[6])<<8 | int(cr.data[7])) + fmt.Println("ALERT DESCRIPTION: ", alert_desc) + + if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_FATAL { + return errors.New(fmt.Sprintf("TLSGuard dropped connection after fatal error alert detected")) + } else if alert_desc == SSL3_AD_CLOSE_NOTIFY { + return errors.New(fmt.Sprintf("TLSGuard dropped connection after close_notify alert detected")) + } + + } + + // fmt.Println("OTHER DATA; PASSING THRU") + if cr.rtype == SSL3_RT_ALERT { + fmt.Println("ALERT = ", cr.data) + } + other.Write(cr.data) + continue + } else if cr.client { + // other.Write(cr.data) + // continue + } else if cr.rtype != SSL3_RT_HANDSHAKE { + return errors.New(fmt.Sprintf("Expected TLS server handshake byte was not received [%#x vs 0x16]", cr.rtype)) + } - for 1 == 1 { - loop++ + if cr.rtype < SSL3_RT_CHANGE_CIPHER_SPEC || cr.rtype > SSL3_RT_APPLICATION_DATA { + return errors.New(fmt.Sprintf("TLSGuard dropping connection with unknown content type: %#x", cr.rtype)) + } - // fmt.Printf("SSL LOOP %v; trying to read: conn2\n", loop) - chunk, rtype, err = readTLSChunk(conn2) + handshakeMsg := cr.data[TLS_RECORD_HDR_LEN:] + s := uint(handshakeMsg[0]) + fmt.Printf("s = %#x\n", s) + // Message len, 3 bytes + if cr.rtype == SSL3_RT_HANDSHAKE { + handshakeMessageLen := handshakeMsg[1:4] + handshakeMessageLenInt := int(int(handshakeMessageLen[0])<<16 | int(handshakeMessageLen[1])<<8 | int(handshakeMessageLen[2])) + fmt.Println("lenint = \n", handshakeMessageLenInt) + } - if err != nil { - log.Debugf("TLSGUARD: OTHER loop %v: trying to read: conn\n", loop) - chunk, rtype, err2 := readTLSChunk(conn) - log.Debugf("TLSGUARD: read: %v, %v, %v\n", err2, rtype, len(chunk)) + if cr.client && s != uint(client_expected) { + return errors.New(fmt.Sprintf("Client sent handshake type %#x but expected %#x", s, client_expected)) + } else if !cr.client && s != uint(server_expected) { + return errors.New(fmt.Sprintf("Server sent handshake type %#x but expected %#x", s, server_expected)) + } - if err2 == nil { - conn2.Write(chunk) - continue - } + if (cr.client && s == SSL3_MT_CLIENT_HELLO) || (!cr.client && s == SSL3_MT_SERVER_HELLO) { + rewrite := false + rewrite_buf := []byte{} + SRC := "" + + if s == SSL3_MT_CLIENT_HELLO { + SRC = "CLIENT" + } else { + server_expected = SSL3_MT_CERTIFICATE + SRC = "SERVER" + } + + hello_offset := 4 + // 2 byte protocol version + fmt.Println(SRC, "HELLO VERSION = ", handshakeMsg[hello_offset:hello_offset+2]) + hello_offset += 2 + // 4 byte Random/GMT time + gmtbytes := binary.BigEndian.Uint32(handshakeMsg[hello_offset : hello_offset+4]) + gmt := time.Unix(int64(gmtbytes), 0) + fmt.Println(SRC, "HELLO GMT = ", gmt) + hello_offset += 4 + // 28 bytes Random/random_bytes + hello_offset += 28 + // 1 byte (32-bit session ID) + sess_len := uint(handshakeMsg[hello_offset]) + fmt.Println(SRC, "HELLO SESSION ID = ", sess_len) + + if sess_len != 0 { + fmt.Printf("ALERT: %v attempting to resume session; intercepting request\n", SRC) + rewrite = true + dcopy := make([]byte, len(cr.data)) + copy(dcopy, cr.data) + // Copy the bytes before the session ID start + rewrite_buf = dcopy[0 : TLS_RECORD_HDR_LEN+hello_offset+1] + // Set the session ID to 0 + rewrite_buf[len(rewrite_buf)-1] = 0 + // Write the new TLS record length + binary.BigEndian.PutUint16(rewrite_buf[3:5], uint16(len(dcopy)-(int(sess_len)+TLS_RECORD_HDR_LEN))) + // Write the new ClientHello length + // Starts after the first 6 bytes (record header + type byte) + orig_len := binary.BigEndian.Uint32(handshakeMsg[0:4]) + // But it's only 3 bytes so mask out the first one + b1 := orig_len & 0xff000000 + orig_len &= 0x00ffffff + orig_len -= uint32(sess_len) + orig_len |= b1 + binary.BigEndian.PutUint32(rewrite_buf[TLS_RECORD_HDR_LEN:], orig_len) + rewrite_buf = append(rewrite_buf, dcopy[TLS_RECORD_HDR_LEN+hello_offset+int(sess_len)+1:]...) + } + + hello_offset += int(sess_len) + 1 + // 2 byte cipher suite array + cs := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) + noCS := cs + fmt.Printf("cs = %v / %#x\n", noCS, noCS) + + if !cr.client { + fmt.Printf("SERVER selected ciphersuite: %#x (%s)\n", cs, getCipherSuiteName(uint(cs))) + hello_offset += 2 + } else { + + for csind := 0; csind < int(noCS/2); csind++ { + off := hello_offset + 2 + (csind * 2) + cs = binary.BigEndian.Uint16(handshakeMsg[off : off+2]) + fmt.Printf("%s HELLO CIPHERSUITE: %d/%d: %#x (%s)\n", SRC, csind+1, noCS/2, cs, getCipherSuiteName(uint(cs))) + } + + hello_offset += 2 + int(noCS) + } + + clen := uint(handshakeMsg[hello_offset]) + hello_offset++ + + if !cr.client { + fmt.Println("SERVER selected compression method: ", clen) + } else { + fmt.Println(SRC, "HELLO COMPRESSION METHODS LEN = ", clen) + fmt.Println(SRC, "HELLO COMPRESSION METHODS: ", handshakeMsg[hello_offset:hello_offset+int(clen)]) + hello_offset += int(clen) + } + + var extlen uint16 = 0 + + if hello_offset == len(handshakeMsg) { + fmt.Println("Message didn't have any extensions present") + } else { + extlen = binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) + fmt.Println(SRC, "HELLO EXTENSIONS LENGTH: ", extlen) + hello_offset += 2 + } + + var exttype uint16 = 0 + if extlen > 2 { + exttype = binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) + fmt.Println(SRC, "HELLO FIRST EXTENSION TYPE: ", exttype) + } + + if cr.client { + ext_ctr := 0 + + for ext_ctr < int(extlen)-2 { + hello_offset += 2 + ext_ctr += 2 + fmt.Printf("PROGRESS: %v of %v, %v of %v\n", ext_ctr, extlen, hello_offset, len(handshakeMsg)) + exttype2 := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) + fmt.Printf("EXTTYPE = %v, 2 = %v\n", exttype, exttype2) + if exttype2 == TLSEXT_TYPE_server_name { + fmt.Println("CLIENT specified server_name extension:") + } + if exttype != TLSEXT_TYPE_signature_algorithms { + fmt.Println("WTF") + } + + hello_offset += 2 + ext_ctr += 2 + inner_len := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) + // fmt.Println("INNER LEN = ", inner_len) + hello_offset += int(inner_len) + ext_ctr += int(inner_len) + } + + } + + if extlen > 0 { + fmt.Printf("ALERT: %v attempting to send extensions; intercepting request\n", SRC) + rewrite = true + tocopy := cr.data + + if len(rewrite_buf) > 0 { + tocopy = rewrite_buf + } + + dcopy := make([]byte, len(tocopy)-int(extlen)) + copy(dcopy, tocopy[0:len(tocopy)-int(extlen)]) + rewrite_buf = dcopy + // Write the new TLS record length + binary.BigEndian.PutUint16(rewrite_buf[3:5], uint16(len(dcopy)-(int(sess_len)+TLS_RECORD_HDR_LEN))) + // Write the new ClientHello length + // Starts after the first 6 bytes (record header + type byte) + orig_len := binary.BigEndian.Uint32(rewrite_buf[TLS_RECORD_HDR_LEN:]) + // But it's only 3 bytes so mask out the first one + b1 := orig_len & 0xff000000 + orig_len &= 0x00ffffff + orig_len -= uint32(extlen) + orig_len |= b1 + binary.BigEndian.PutUint32(rewrite_buf[TLS_RECORD_HDR_LEN:], orig_len) + // Write session length 0 at the end + rewrite_buf[len(rewrite_buf)-1] = 0 + rewrite_buf[len(rewrite_buf)-2] = 0 + } + + if rewrite { + fmt.Println("TLSGuard writing back modified handshake data to server") + fmt.Printf("ORIGINAL[%d]: %v\n", len(cr.data), hex.Dump(cr.data)) + fmt.Printf("NEW[%d]: %v\n", len(rewrite_buf), hex.Dump(rewrite_buf)) + other.Write(rewrite_buf) + } else { + other.Write(cr.data) + } + + continue + } - return err - } + if cr.client { + other.Write(cr.data) + continue + } - if rtype == SSL3_RT_CHANGE_CIPHER_SPEC || rtype == SSL3_RT_APPLICATION_DATA || - rtype == SSL3_RT_ALERT { - // fmt.Println("OTHER DATA; PASSING THRU") - passthru = true - } else if rtype == SSL3_RT_HANDSHAKE { - passthru = false - } else { - return errors.New(fmt.Sprintf("Expected TLS server handshake byte was not received [%#x vs 0x16]", rtype)) - } + if !cr.client && server_expected == SSL3_MT_SERVER_HELLO { + server_expected = SSL3_MT_CERTIFICATE + } - if passthru { - // fmt.Println("passthru writing buf again and continuing:") - conn.Write(chunk) - continue - } + if !cr.client && s == SSL3_MT_HELLO_REQUEST { + fmt.Println("Server sent hello request") + continue + } - serverMsg := chunk[5:] - s := serverMsg[0] - log.Debugf("TLSGUARD: s = %#x\n", s) - - if s == SSL3_MT_CERTIFICATE { - // Message len, 3 bytes - serverMessageLen := serverMsg[1:4] - serverMessageLenInt := int(int(serverMessageLen[0])<<16 | int(serverMessageLen[1])<<8 | int(serverMessageLen[2])) - // fmt.Printf("chunk len = %v, serverMsgLen = %v, slint = %v\n", len(chunk), len(serverMsg), serverMessageLenInt) - if len(serverMsg) < serverMessageLenInt { - return errors.New(fmt.Sprintf("len(serverMsg) %v < serverMessageLenInt %v!\n", len(serverMsg), serverMessageLenInt)) - } - serverHelloBody := serverMsg[4 : 4+serverMessageLenInt] - certChainLen := int(int(serverHelloBody[0])<<16 | int(serverHelloBody[1])<<8 | int(serverHelloBody[2])) - remaining := certChainLen - pos := serverHelloBody[3:certChainLen] + if s > SSL3_MT_CERTIFICATE_STATUS { + fmt.Println("WTF: ", cr.data) + } + + // Message len, 3 bytes + handshakeMessageLen := handshakeMsg[1:4] + handshakeMessageLenInt := int(int(handshakeMessageLen[0])<<16 | int(handshakeMessageLen[1])<<8 | int(handshakeMessageLen[2])) + + if s == SSL3_MT_CERTIFICATE { + fmt.Println("HMM") + // fmt.Printf("chunk len = %v, handshakeMsgLen = %v, slint = %v\n", len(chunk), len(handshakeMsg), handshakeMessageLenInt) + if len(handshakeMsg) < handshakeMessageLenInt { + return errors.New(fmt.Sprintf("len(handshakeMsg) %v < handshakeMessageLenInt %v!\n", len(handshakeMsg), handshakeMessageLenInt)) + } + serverHelloBody := handshakeMsg[4 : 4+handshakeMessageLenInt] + certChainLen := int(int(serverHelloBody[0])<<16 | int(serverHelloBody[1])<<8 | int(serverHelloBody[2])) + remaining := certChainLen + pos := serverHelloBody[3:certChainLen] + + // var certChain []*x509.Certificate + var verifyOptions x509.VerifyOptions + + //fqdn = "www.reddit.com" + if fqdn != "" { + verifyOptions.DNSName = fqdn + } + + pool := x509.NewCertPool() + var c *x509.Certificate + + for remaining > 0 { + certLen := int(int(pos[0])<<16 | int(pos[1])<<8 | int(pos[2])) + // fmt.Printf("Certs chain len %d, cert 1 len %d:\n", certChainLen, certLen) + cert := pos[3 : 3+certLen] + certs, err := x509.ParseCertificates(cert) + if remaining == certChainLen { + c = certs[0] + } else { + pool.AddCert(certs[0]) + } + // certChain = append(certChain, certs[0]) + if err != nil { + return err + } + remaining = remaining - certLen - 3 + if remaining > 0 { + pos = pos[3+certLen:] + } + } + + verifyOptions.Intermediates = pool + fmt.Println("ATTEMPTING TO VERIFY: ", fqdn) + _, err := c.Verify(verifyOptions) + fmt.Println("ATTEMPTING TO VERIFY RESULT: ", err) + if err != nil { + return err + } else { + x509Valid = true + } + } - // var certChain []*x509.Certificate - var verifyOptions x509.VerifyOptions + other.Write(cr.data) - if fqdn != "" { - verifyOptions.DNSName = fqdn - } + if x509Valid || (s == SSL3_MT_SERVER_DONE) || (s == SSL3_MT_CERTIFICATE_REQUEST) { + fmt.Println("BREAKING OUT OF LOOP 1") + dChan <- true + fmt.Println("BREAKING OUT OF LOOP 2") + break select_loop + } - pool := x509.NewCertPool() - var c *x509.Certificate + // fmt.Printf("Sending chunk of type %d to client.\n", s) + } else if cr.err != nil { + ndone++ - for remaining > 0 { - certLen := int(int(pos[0])<<16 | int(pos[1])<<8 | int(pos[2])) - // fmt.Printf("Certs chain len %d, cert 1 len %d:\n", certChainLen, certLen) - cert := pos[3 : 3+certLen] - certs, err := x509.ParseCertificates(cert) - if remaining == certChainLen { - c = certs[0] + if cr.client { + fmt.Println("Client read error: ", cr.err) } else { - pool.AddCert(certs[0]) + fmt.Println("Server read error: ", cr.err) } - // certChain = append(certChain, certs[0]) - if err != nil { - return err - } - remaining = remaining - certLen - 3 - if remaining > 0 { - pos = pos[3+certLen:] - } - } - verifyOptions.Intermediates = pool - - // fmt.Println("ATTEMPTING TO VERIFY: ", fqdn) - _, err = c.Verify(verifyOptions) - // fmt.Println("ATTEMPTING TO VERIFY RESULT: ", err) - if err != nil { - return err - } else { - valid = true + + return cr.err } - // lse if s == 0x0d { fmt.Printf("found a client cert request, sending buf to client\n") } - } else if s == SSL3_MT_SERVER_DONE { - conn.Write(chunk) - break - } else if s == SSL3_MT_CERTIFICATE_REQUEST { - break + } - // fmt.Printf("Sending chunk of type %d to client.\n", s) + } - conn.Write(chunk) + fmt.Println("WAITING; ndone = ", ndone) + for ndone < 2 { + fmt.Println("WAITING; ndone = ", ndone) + select { + case cr := <-crChan: + fmt.Printf("CHAN DATA: %v, %v, %v\n", cr.client, cr.err, len(cr.data)) + if cr.err != nil || cr.data == nil { + ndone++ + } else if cr.client { + conn2.Write(cr.data) + } else if !cr.client { + conn.Write(cr.data) + } + + } } - if !valid { + fmt.Println("______ ndone = 2\n") + + // dChan <- true + close(dChan) + + if !x509Valid { return errors.New("Unknown error: TLS connection could not be validated") } return nil -} -func readNBytes(conn net.Conn, numBytes int) ([]byte, error) { - res := make([]byte, 0) - temp := make([]byte, 1) - for i := 0; i < numBytes; i++ { - _, err := io.ReadAtLeast(conn, temp, 1) - if err != nil { - return res, err - } - res = append(res, temp[0]) - } - return res, nil } diff --git a/vendor/github.com/subgraph/go-procsnitch/proc.go b/vendor/github.com/subgraph/go-procsnitch/proc.go index 823850e..fe23232 100644 --- a/vendor/github.com/subgraph/go-procsnitch/proc.go +++ b/vendor/github.com/subgraph/go-procsnitch/proc.go @@ -1,8 +1,8 @@ package procsnitch import ( - "encoding/hex" "encoding/binary" + "encoding/hex" "errors" "fmt" "github.com/op/go-logging" @@ -169,7 +169,7 @@ func ParseIP(ip string) (net.IP, error) { } if isLittleEndian > 0 { - for i := 0; i < len(dst) / 4; i++ { + for i := 0; i < len(dst)/4; i++ { start, end := i*4, (i+1)*4 word := dst[start:end] lval := binary.LittleEndian.Uint32(word) @@ -177,13 +177,13 @@ func ParseIP(ip string) (net.IP, error) { } } -/* if len(dst) == 16 { - dst2 := []byte{dst[3], dst[2], dst[1], dst[0], dst[7], dst[6], dst[5], dst[4], dst[11], dst[10], dst[9], dst[8], dst[15], dst[14], dst[13], dst[12]} - return net.IP(dst2), nil - } - for i, j := 0, len(dst)-1; i < j; i, j = i+1, j-1 { - dst[i], dst[j] = dst[j], dst[i] - } */ + /* if len(dst) == 16 { + dst2 := []byte{dst[3], dst[2], dst[1], dst[0], dst[7], dst[6], dst[5], dst[4], dst[11], dst[10], dst[9], dst[8], dst[15], dst[14], dst[13], dst[12]} + return net.IP(dst2), nil + } + for i, j := 0, len(dst)-1; i < j; i, j = i+1, j-1 { + dst[i], dst[j] = dst[j], dst[i] + } */ return net.IP(dst), nil } @@ -312,6 +312,7 @@ func stripLabel(s string) string { // stolen from github.com/virtao/GoEndian const INT_SIZE int = int(unsafe.Sizeof(0)) + func setEndian() { var i int = 0x1 bs := (*[INT_SIZE]byte)(unsafe.Pointer(&i)) diff --git a/vendor/github.com/subgraph/go-procsnitch/proc_pid.go b/vendor/github.com/subgraph/go-procsnitch/proc_pid.go index a34c9b1..549ea37 100644 --- a/vendor/github.com/subgraph/go-procsnitch/proc_pid.go +++ b/vendor/github.com/subgraph/go-procsnitch/proc_pid.go @@ -23,7 +23,9 @@ type Info struct { FirstArg string ParentCmdLine string ParentExePath string - Sandbox string + Sandbox string + Inode uint64 + FD int } type pidCache struct { @@ -51,10 +53,12 @@ func loadCache() map[uint64]*Info { for _, n := range readdir("/proc") { pid := toPid(n) if pid != 0 { - pinfo := &Info{Pid: pid} - for _, inode := range inodesFromPid(pid) { + inodes, fds := inodesFromPid(pid) + for iind, inode := range inodes { + pinfo := &Info{Inode: inode, Pid: pid, FD: fds[iind]} cmap[inode] = pinfo } + } } return cmap @@ -76,8 +80,9 @@ func toPid(name string) int { return (int)(pid) } -func inodesFromPid(pid int) []uint64 { +func inodesFromPid(pid int) ([]uint64, []int) { var inodes []uint64 + var fds []int fdpath := fmt.Sprintf("/proc/%d/fd", pid) for _, n := range readdir(fdpath) { if link, err := os.Readlink(path.Join(fdpath, n)); err != nil { @@ -85,12 +90,19 @@ func inodesFromPid(pid int) []uint64 { log.Warningf("Error reading link %s: %v", n, err) } } else { + fd, err := strconv.Atoi(n) + if err != nil { + log.Warningf("Error retrieving fd associated with pid %v: %v", pid, err) + fd = -1 + } + if inode := extractSocket(link); inode > 0 { inodes = append(inodes, inode) + fds = append(fds, fd) } } } - return inodes + return inodes, fds } func extractSocket(name string) uint64 { diff --git a/vendor/github.com/subgraph/go-procsnitch/socket.go b/vendor/github.com/subgraph/go-procsnitch/socket.go index f137ed0..a5d838f 100644 --- a/vendor/github.com/subgraph/go-procsnitch/socket.go +++ b/vendor/github.com/subgraph/go-procsnitch/socket.go @@ -111,60 +111,65 @@ func findUDPSocketAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort ui if custdata == nil { if strictness == MATCH_STRICT { return findSocket(proto, func(ss socketStatus) bool { - fmt.Println("Match strict") + // fmt.Println("Match strict") return ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) //return ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) }) } else if strictness == MATCH_LOOSE { return findSocket(proto, func(ss socketStatus) bool { - /* - fmt.Println("Match loose") - fmt.Printf("sock dst = %v pkt dst = %v\n", ss.remote.ip, dstAddr) - fmt.Printf("sock port = %d pkt port = %d\n", ss.local.port, srcPort) - fmt.Printf("local ip: %v\n source ip: %v\n", ss.local.ip, srcAddr) + /* + fmt.Println("Match loose") + fmt.Printf("sock dst = %v pkt dst = %v\n", ss.remote.ip, dstAddr) + fmt.Printf("sock port = %d pkt port = %d\n", ss.local.port, srcPort) + fmt.Printf("local ip: %v\n source ip: %v\n", ss.local.ip, srcAddr) */ - if (ss.local.port == srcPort && (ss.local.ip.Equal(net.IPv4(0,0,0,0)) && ss.remote.ip.Equal(net.IPv4(0,0,0,0)))) { - fmt.Printf("Matching for UDP socket bound to *:%d\n",ss.local.port) + if (ss.local.port == srcPort) && addrMatchesAny(ss.local.ip) && addrMatchesAny(ss.remote.ip) { + fmt.Printf("Loose match for UDP socket bound to *:%d\n", ss.local.port) return true - } else if (ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr)) { + } else if ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) { return true } - // Finally, loop through all interfaces if src port matches - + // Finally, loop through all interfaces if src port matches if ss.local.port == srcPort { ifs, err := net.Interfaces() if err != nil { - log.Warningf("Error on net.Interfaces(): %v", err) + log.Warning("Error retrieving list of network interfaces for UDP socket lookup:", err) return false } + for _, i := range ifs { + addrs, err := i.Addrs() if err != nil { - log.Warningf("Error on Interface.Addrs(): %v", err) + log.Warning("Error retrieving network interface for UDP socket lookup:", err) return false } + for _, addr := range addrs { var ifip net.IP switch x := addr.(type) { - case *net.IPNet: - ifip = x.IP - case *net.IPAddr: - ifip = x.IP + case *net.IPNet: + ifip = x.IP + case *net.IPAddr: + ifip = x.IP } + if ss.local.ip.Equal(ifip) { - fmt.Printf("Matched on UDP socket bound to %v:%d\n",ifip,srcPort) + fmt.Printf("Matched on UDP socket bound to %v:%d\n", ifip, srcPort) return true } + } } } + return false //return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(net.IPv4(0,0,0,0))) && ss.local.port == srcPort && (ss.local.ip.Equal(srcAddr) || ss.local.ip.Equal(net.IPv4(0,0,0,0))) /* - return (ss.remote.ip.Equal(dstAddr) || addrMatchesAny(ss.remote.ip)) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) || - (ss.local.ip.Equal(dstAddr) || addrMatchesAny(ss.local.ip)) && ss.remote.port == srcPort && ss.remote.ip.Equal(srcAddr) */ + return (ss.remote.ip.Equal(dstAddr) || addrMatchesAny(ss.remote.ip)) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) || + (ss.local.ip.Equal(dstAddr) || addrMatchesAny(ss.local.ip)) && ss.remote.port == srcPort && ss.remote.ip.Equal(srcAddr) */ }) } return findSocket(proto, func(ss socketStatus) bool { @@ -367,11 +372,11 @@ func getSocketLines(proto string) []string { } func addrMatchesAny(addr net.IP) bool { - wildcard := net.IP{0,0,0,0} + wildcard := net.IP{0, 0, 0, 0} if addr.To4() == nil { - wildcard = net.IP{0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0} + wildcard = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} } - return wildcard.Equal(addr) + return wildcard.Equal(addr) }