diff --git a/fw-prompt/fw-prompt.go b/fw-prompt/fw-prompt.go index 124a4fa..38f73d0 100644 --- a/fw-prompt/fw-prompt.go +++ b/fw-prompt/fw-prompt.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/gotk3/gotk3/gdk" "github.com/gotk3/gotk3/glib" "github.com/gotk3/gotk3/gtk" "io/ioutil" @@ -25,14 +26,6 @@ type fpPreferences struct { Winleft uint } -type decisionWaiter struct { - Cond *sync.Cond - Lock sync.Locker - Ready bool - Scope int - Rule string -} - type ruleColumns struct { nrefs int Path string @@ -54,36 +47,63 @@ type ruleColumns struct { Scope int } +const ( + COL_NO_NREFS = iota + COL_NO_ICON_PIXBUF + COL_NO_GUID + COL_NO_PATH + COL_NO_ICON + COL_NO_PROTO + COL_NO_PID + COL_NO_DSTIP + COL_NO_HOSTNAME + COL_NO_PORT + COL_NO_UID + COL_NO_GID + COL_NO_ORIGIN + COL_NO_TIMESTAMP + COL_NO_IS_SOCKS + COL_NO_OPTSTRING + COL_NO_ACTION + COL_NO_LAST +) + var dbuso *dbusObject var userPrefs fpPreferences var mainWin *gtk.Window var Notebook *gtk.Notebook -var globalLS *gtk.ListStore = nil +var globalTS *gtk.TreeStore = nil var globalTV *gtk.TreeView var globalPromptLock = &sync.Mutex{} +var recentLock = &sync.Mutex{} var globalIcon *gtk.Image -var decisionWaiters []*decisionWaiter var editApp, editTarget, editPort, editUser, editGroup *gtk.Entry var comboProto *gtk.ComboBoxText var radioOnce, radioProcess, radioParent, radioSession, radioPermanent *gtk.RadioButton var btnApprove, btnDeny, btnIgnore *gtk.Button var chkTLS, chkUser, chkGroup *gtk.CheckButton +var recentlyRemoved = []string{} -func dumpDecisions() { - return - fmt.Println("XXX Total of decisions pending: ", len(decisionWaiters)) - for i := 0; i < len(decisionWaiters); i++ { - fmt.Printf("XXX %d ready = %v, rule = %v\n", i+1, decisionWaiters[i].Ready, decisionWaiters[i].Rule) +func wasRecentlyRemoved(guid string) bool { + recentLock.Lock() + defer recentLock.Unlock() + + for gind, g := range recentlyRemoved { + if g == guid { + recentlyRemoved = append(recentlyRemoved[:gind], recentlyRemoved[gind+1:]...) + return true + } } + + return false } -func addDecision() *decisionWaiter { - return nil - decision := decisionWaiter{Lock: &sync.Mutex{}, Ready: false, Scope: int(sgfw.APPLY_ONCE), Rule: ""} - decision.Cond = sync.NewCond(decision.Lock) - decisionWaiters = append(decisionWaiters, &decision) - return &decision +func addRecentlyRemoved(guid string) { + recentLock.Lock() + defer recentLock.Unlock() + fmt.Println("RECENTLY REMOVED: ", guid) + recentlyRemoved = append(recentlyRemoved, guid) } func promptInfo(msg string) { @@ -298,15 +318,27 @@ func get_label(text string) *gtk.Label { return label } -func createColumn(title string, id int) *gtk.TreeViewColumn { - cellRenderer, err := gtk.CellRendererTextNew() +func createColumnImg(title string, id int) *gtk.TreeViewColumn { + cellRenderer, err := gtk.CellRendererPixbufNew() + if err != nil { + log.Fatal("Unable to create image cell renderer:", err) + } + column, err := gtk.TreeViewColumnNewWithAttribute(title, cellRenderer, "pixbuf", id) + if err != nil { + log.Fatal("Unable to create cell column:", err) + } + + return column +} + +func createColumnText(title string, id int) *gtk.TreeViewColumn { + cellRenderer, err := gtk.CellRendererTextNew() if err != nil { log.Fatal("Unable to create text cell renderer:", err) } column, err := gtk.TreeViewColumnNewWithAttribute(title, cellRenderer, "text", id) - if err != nil { log.Fatal("Unable to create cell column:", err) } @@ -316,33 +348,57 @@ func createColumn(title string, id int) *gtk.TreeViewColumn { return column } -func createListStore(general bool) *gtk.ListStore { - colData := []glib.Type{glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_STRING, +func createTreeStore(general bool) *gtk.TreeStore { + colData := []glib.Type{glib.TYPE_INT, glib.TYPE_OBJECT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_INT} - listStore, err := gtk.ListStoreNew(colData...) + treeStore, err := gtk.TreeStoreNew(colData...) if err != nil { log.Fatal("Unable to create list store:", err) } - return listStore + return treeStore } -func removeRequest(listStore *gtk.ListStore, guid string) { +func removeRequest(treeStore *gtk.TreeStore, guid string) { + if wasRecentlyRemoved(guid) { + fmt.Printf("Entry for %s was recently removed; deleting from cache\n", guid) + return + } + removed := false + + if globalTS == nil { + return + } + globalPromptLock.Lock() defer globalPromptLock.Unlock() - /* XXX: This is horrible. Figure out how to do this properly. */ - for ridx := 0; ridx < globalLS.IterNChildren(nil); ridx++ { - - rule, _, err := getRuleByIdx(ridx) +remove_outer: + for ridx := 0; ridx < globalTS.IterNChildren(nil); ridx++ { + nchildren := 0 + this_iter, err := globalTS.GetIterFromString(fmt.Sprintf("%d", ridx)) if err != nil { - break - } else if rule.GUID == guid { - removeSelectedRule(ridx, true) - removed = true - break + log.Println("Strange condition; couldn't get iter of known tree index:", err) + } else { + nchildren = globalTS.IterNChildren(this_iter) + } + + for cidx := 0; cidx < nchildren-1; cidx++ { + sidx := cidx + if cidx == nchildren { + cidx = -1 + } + + rule, _, err := getRuleByIdx(ridx, sidx) + if err != nil { + break remove_outer + } else if rule.GUID == guid { + removeSelectedRule(ridx, sidx) + removed = true + break + } } } @@ -353,17 +409,104 @@ func removeRequest(listStore *gtk.ListStore, guid string) { } -func addRequestInc(listStore *gtk.ListStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int, - origin string, is_socks bool, optstring string, sandbox string, action int) bool { +// Needs to be locked by caller +func storeNewEntry(ts *gtk.TreeStore, iter *gtk.TreeIter, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int, origin, + timestamp string, is_socks bool, optstring, sandbox string, action int) { + var colVals = [COL_NO_LAST]interface{}{} + + if is_socks { + if (optstring != "") && (strings.Index(optstring, "SOCKS") == -1) { + optstring = "SOCKS5 / " + optstring + } else if optstring == "" { + optstring = "SOCKS5" + } + } + + colVals[COL_NO_NREFS] = 1 + colVals[COL_NO_ICON_PIXBUF] = nil + colVals[COL_NO_GUID] = guid + colVals[COL_NO_PATH] = path + colVals[COL_NO_ICON] = icon + colVals[COL_NO_PROTO] = proto + colVals[COL_NO_PID] = pid + + if ipaddr == "" { + colVals[COL_NO_DSTIP] = "---" + } else { + colVals[COL_NO_DSTIP] = ipaddr + } + + colVals[COL_NO_HOSTNAME] = hostname + colVals[COL_NO_PORT] = port + colVals[COL_NO_UID] = uid + colVals[COL_NO_GID] = gid + colVals[COL_NO_ORIGIN] = origin + colVals[COL_NO_TIMESTAMP] = timestamp + colVals[COL_NO_IS_SOCKS] = 0 + + if is_socks { + colVals[COL_NO_IS_SOCKS] = 1 + } + + colVals[COL_NO_OPTSTRING] = optstring + colVals[COL_NO_ACTION] = action + + itheme, err := gtk.IconThemeGetDefault() + if err != nil { + log.Fatal("Could not load default icon theme:", err) + } + + make_blank := false + if icon != "" { + pb, err := itheme.LoadIcon(icon, 24, gtk.ICON_LOOKUP_GENERIC_FALLBACK) + if err != nil { + log.Println("Could not load icon:", err) + make_blank = true + } else { + colVals[COL_NO_ICON_PIXBUF] = pb + } + } else { + make_blank = true + } + + if make_blank { + pb, err := gdk.PixbufNew(gdk.COLORSPACE_RGB, true, 8, 24, 24) + if err != nil { + log.Println("Error creating blank icon:", err) + } else { + colVals[COL_NO_ICON_PIXBUF] = pb + + img, err := gtk.ImageNewFromPixbuf(pb) + if err != nil { + log.Println("Error creating image from pixbuf:", err) + } else { + img.Clear() + pb = img.GetPixbuf() + colVals[COL_NO_ICON_PIXBUF] = pb + } + } + + } + + for n := 0; n < len(colVals); n++ { + err := ts.SetValue(iter, n, colVals[n]) + if err != nil { + log.Fatal("Unable to add row:", err) + } + } + + return +} + +func addRequestInc(treeStore *gtk.TreeStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int, + origin, timestamp string, is_socks bool, optstring string, sandbox string, action int) bool { duplicated := false globalPromptLock.Lock() defer globalPromptLock.Unlock() - for ridx := 0; ridx < globalLS.IterNChildren(nil); ridx++ { - - /* XXX: This is horrible. Figure out how to do this properly. */ - rule, iter, err := getRuleByIdx(ridx) + for ridx := 0; ridx < globalTS.IterNChildren(nil); ridx++ { + rule, iter, err := getRuleByIdx(ridx, -1) if err != nil { break // XXX: not compared: optstring/sandbox @@ -371,14 +514,15 @@ func addRequestInc(listStore *gtk.ListStore, guid, path, icon, proto string, pid (rule.Port == port) && (rule.UID == uid) && (rule.GID == gid) && (rule.Origin == origin) && (rule.IsSocks == is_socks) { rule.nrefs++ - err := globalLS.SetValue(iter, 0, rule.nrefs) + err := globalTS.SetValue(iter, 0, rule.nrefs) if err != nil { log.Println("Error creating duplicate firewall prompt entry:", err) break } - fmt.Println("YES REALLY DUPLICATE: ", rule.nrefs) duplicated = true + subiter := globalTS.Append(iter) + storeNewEntry(globalTS, subiter, guid, path, icon, proto, pid, ipaddr, hostname, port, uid, gid, origin, timestamp, is_socks, optstring, sandbox, action) break } @@ -387,27 +531,27 @@ func addRequestInc(listStore *gtk.ListStore, guid, path, icon, proto string, pid return duplicated } -func addRequestAsync(listStore *gtk.ListStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int, +func addRequestAsync(treeStore *gtk.TreeStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int, origin, timestamp string, is_socks bool, optstring string, sandbox string, action int) bool { - addRequest(listStore, guid, path, icon, proto, pid, ipaddr, hostname, port, uid, gid, origin, timestamp, is_socks, + addRequest(treeStore, guid, path, icon, proto, pid, ipaddr, hostname, port, uid, gid, origin, timestamp, is_socks, optstring, sandbox, action) return true } -func addRequest(listStore *gtk.ListStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int, - origin, timestamp string, is_socks bool, optstring string, sandbox string, action int) *decisionWaiter { - if listStore == nil { - listStore = globalLS +func addRequest(treeStore *gtk.TreeStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int, + origin, timestamp string, is_socks bool, optstring string, sandbox string, action int) bool { + if treeStore == nil { + treeStore = globalTS waitTimes := []int{1, 2, 5, 10} - if listStore == nil { + if treeStore == nil { log.Println("SGFW prompter was not ready to receive firewall request... waiting") for _, wtime := range waitTimes { time.Sleep(time.Duration(wtime) * time.Second) - listStore = globalLS + treeStore = globalTS - if listStore != nil { + if treeStore != nil { break } @@ -418,78 +562,26 @@ func addRequest(listStore *gtk.ListStore, guid, path, icon, proto string, pid in } - if listStore == nil { + if treeStore == nil { log.Fatal("SGFW prompter GUI failed to load for unknown reasons") } - if addRequestInc(listStore, guid, path, icon, proto, pid, ipaddr, hostname, port, uid, gid, origin, is_socks, optstring, sandbox, action) { - fmt.Println("REQUEST WAS DUPLICATE") - decision := addDecision() + if addRequestInc(treeStore, guid, path, icon, proto, pid, ipaddr, hostname, port, uid, gid, origin, timestamp, is_socks, optstring, sandbox, action) { + fmt.Println("Request was duplicate: ", guid) globalPromptLock.Lock() toggleHover() globalPromptLock.Unlock() - return decision - } else { - fmt.Println("NOT DUPLICATE") + return true } globalPromptLock.Lock() - iter := listStore.Append() - - if is_socks { - if (optstring != "") && (strings.Index(optstring, "SOCKS") == -1) { - optstring = "SOCKS5 / " + optstring - } else if optstring == "" { - optstring = "SOCKS5" - } - } - - colVals := make([]interface{}, 16) - colVals[0] = 1 - colVals[1] = guid - colVals[2] = path - colVals[3] = icon - colVals[4] = proto - colVals[5] = pid - - if ipaddr == "" { - colVals[6] = "---" - } else { - colVals[6] = ipaddr - } - - colVals[7] = hostname - colVals[8] = port - colVals[9] = uid - colVals[10] = gid - colVals[11] = origin - colVals[12] = timestamp - colVals[13] = 0 - - if is_socks { - colVals[13] = 1 - } - - colVals[14] = optstring - colVals[15] = action - - colNums := make([]int, len(colVals)) - - for n := 0; n < len(colVals); n++ { - colNums[n] = n - } - - err := listStore.Set(iter, colNums, colVals) + defer globalPromptLock.Unlock() - if err != nil { - log.Fatal("Unable to add row:", err) - } + iter := treeStore.Append(nil) + storeNewEntry(treeStore, iter, guid, path, icon, proto, pid, ipaddr, hostname, port, uid, gid, origin, timestamp, is_socks, optstring, sandbox, action) - decision := addDecision() - dumpDecisions() toggleHover() - globalPromptLock.Unlock() - return decision + return true } func setup_settings() { @@ -554,8 +646,8 @@ func setup_settings() { Notebook.AppendPage(scrollbox, hLabel) } -func lsGetStr(ls *gtk.ListStore, iter *gtk.TreeIter, idx int) (string, error) { - val, err := globalLS.GetValue(iter, idx) +func lsGetStr(ls *gtk.TreeStore, iter *gtk.TreeIter, idx int) (string, error) { + val, err := globalTS.GetValue(iter, idx) if err != nil { return "", err } @@ -568,8 +660,8 @@ func lsGetStr(ls *gtk.ListStore, iter *gtk.TreeIter, idx int) (string, error) { return sval, nil } -func lsGetInt(ls *gtk.ListStore, iter *gtk.TreeIter, idx int) (int, error) { - val, err := globalLS.GetValue(iter, idx) +func lsGetInt(ls *gtk.TreeStore, iter *gtk.TreeIter, idx int) (int, error) { + val, err := globalTS.GetValue(iter, idx) if err != nil { return 0, err } @@ -582,9 +674,9 @@ func lsGetInt(ls *gtk.ListStore, iter *gtk.TreeIter, idx int) (int, error) { return ival.(int), nil } -func makeDecision(idx int, rule string, scope int) error { +func makeDecision(rule string, scope int, guid string) error { var dres bool - call := dbuso.Call("AddRuleAsync", 0, uint32(scope), rule, "*") + call := dbuso.Call("AddRuleAsync", 0, uint32(scope), rule, "*", guid) err := call.Store(&dres) if err != nil { @@ -593,20 +685,12 @@ func makeDecision(idx int, rule string, scope int) error { } fmt.Println("makeDecision remote result:", dres) - - return nil - decisionWaiters[idx].Cond.L.Lock() - decisionWaiters[idx].Rule = rule - decisionWaiters[idx].Scope = scope - decisionWaiters[idx].Ready = true - decisionWaiters[idx].Cond.Signal() - decisionWaiters[idx].Cond.L.Unlock() return nil } /* Do we need to hold the lock while this is called? Stay safe... */ func toggleHover() { - nitems := globalLS.IterNChildren(nil) + nitems := globalTS.IterNChildren(nil) mainWin.SetKeepAbove(nitems > 0) } @@ -730,120 +814,187 @@ func clearEditor() { chkTLS.SetActive(false) } -func removeSelectedRule(idx int, rmdecision bool) error { - fmt.Println("XXX: attempting to remove idx = ", idx) +func removeSelectedRule(idx, subidx int) error { + fmt.Printf("XXX: attempting to remove idx = %v, %v\n", idx, subidx) + ppathstr := fmt.Sprintf("%d", idx) + pathstr := ppathstr - path, err := gtk.TreePathNewFromString(fmt.Sprintf("%d", idx)) - if err != nil { - return err + if subidx > -1 { + pathstr = fmt.Sprintf("%d:%d", idx, subidx) } - iter, err := globalLS.GetIter(path) + iter, err := globalTS.GetIterFromString(pathstr) if err != nil { return err } - globalLS.Remove(iter) + nchildren := globalTS.IterNChildren(iter) + + if nchildren >= 1 { + firstpath := fmt.Sprintf("%d:0", idx) + citer, err := globalTS.GetIterFromString(firstpath) + if err != nil { + return err + } + + gnrefs, err := globalTS.GetValue(iter, COL_NO_NREFS) + if err != nil { + return err + } + + vnrefs, err := gnrefs.GoValue() + if err != nil { + return err + } + + nrefs := vnrefs.(int) - 1 + + for n := 0; n < COL_NO_LAST; n++ { + val, err := globalTS.GetValue(citer, n) + if err != nil { + return err + } + + if n == COL_NO_NREFS { + err = globalTS.SetValue(iter, n, nrefs) + } else { + err = globalTS.SetValue(iter, n, val) + } + + if err != nil { + return err + } + } + + globalTS.Remove(citer) + return nil + } + + globalTS.Remove(iter) + + if subidx > -1 { + ppath, err := gtk.TreePathNewFromString(ppathstr) + if err != nil { + return err + } + + piter, err := globalTS.GetIter(ppath) + if err != nil { + return err + } + + nrefs, err := lsGetInt(globalTS, piter, COL_NO_NREFS) + if err != nil { + return err + } - if rmdecision { - // decisionWaiters = append(decisionWaiters[:idx], decisionWaiters[idx+1:]...) + err = globalTS.SetValue(piter, COL_NO_NREFS, nrefs-1) + if err != nil { + return err + } } toggleHover() return nil } +// Needs to be locked by the caller func numSelections() int { sel, err := globalTV.GetSelection() if err != nil { return -1 } - rows := sel.GetSelectedRows(globalLS) + rows := sel.GetSelectedRows(globalTS) return int(rows.Length()) } // Needs to be locked by the caller -func getRuleByIdx(idx int) (ruleColumns, *gtk.TreeIter, error) { +func getRuleByIdx(idx, subidx int) (ruleColumns, *gtk.TreeIter, error) { rule := ruleColumns{} + tpath := fmt.Sprintf("%d", idx) + + if subidx != -1 { + tpath = fmt.Sprintf("%d:%d", idx, subidx) + } - path, err := gtk.TreePathNewFromString(fmt.Sprintf("%d", idx)) + path, err := gtk.TreePathNewFromString(tpath) if err != nil { return rule, nil, err } - iter, err := globalLS.GetIter(path) + iter, err := globalTS.GetIter(path) if err != nil { return rule, nil, err } - rule.nrefs, err = lsGetInt(globalLS, iter, 0) + rule.nrefs, err = lsGetInt(globalTS, iter, COL_NO_NREFS) if err != nil { return rule, nil, err } - rule.GUID, err = lsGetStr(globalLS, iter, 1) + rule.GUID, err = lsGetStr(globalTS, iter, COL_NO_GUID) if err != nil { return rule, nil, err } - rule.Path, err = lsGetStr(globalLS, iter, 2) + rule.Path, err = lsGetStr(globalTS, iter, COL_NO_PATH) if err != nil { return rule, nil, err } - rule.Icon, err = lsGetStr(globalLS, iter, 3) + rule.Icon, err = lsGetStr(globalTS, iter, COL_NO_ICON) if err != nil { return rule, nil, err } - rule.Proto, err = lsGetStr(globalLS, iter, 4) + rule.Proto, err = lsGetStr(globalTS, iter, COL_NO_PROTO) if err != nil { return rule, nil, err } - rule.Pid, err = lsGetInt(globalLS, iter, 5) + rule.Pid, err = lsGetInt(globalTS, iter, COL_NO_PID) if err != nil { return rule, nil, err } - rule.Target, err = lsGetStr(globalLS, iter, 6) + rule.Target, err = lsGetStr(globalTS, iter, COL_NO_DSTIP) if err != nil { return rule, nil, err } - rule.Hostname, err = lsGetStr(globalLS, iter, 7) + rule.Hostname, err = lsGetStr(globalTS, iter, COL_NO_HOSTNAME) if err != nil { return rule, nil, err } - rule.Port, err = lsGetInt(globalLS, iter, 8) + rule.Port, err = lsGetInt(globalTS, iter, COL_NO_PORT) if err != nil { return rule, nil, err } - rule.UID, err = lsGetInt(globalLS, iter, 9) + rule.UID, err = lsGetInt(globalTS, iter, COL_NO_UID) if err != nil { return rule, nil, err } - rule.GID, err = lsGetInt(globalLS, iter, 10) + rule.GID, err = lsGetInt(globalTS, iter, COL_NO_GID) if err != nil { return rule, nil, err } - rule.Origin, err = lsGetStr(globalLS, iter, 11) + rule.Origin, err = lsGetStr(globalTS, iter, COL_NO_ORIGIN) if err != nil { return rule, nil, err } - rule.Timestamp, err = lsGetStr(globalLS, iter, 12) + rule.Timestamp, err = lsGetStr(globalTS, iter, COL_NO_TIMESTAMP) if err != nil { return rule, nil, err } rule.IsSocks = false - is_socks, err := lsGetInt(globalLS, iter, 13) + is_socks, err := lsGetInt(globalTS, iter, COL_NO_IS_SOCKS) if err != nil { return rule, nil, err } @@ -852,7 +1003,7 @@ func getRuleByIdx(idx int) (ruleColumns, *gtk.TreeIter, error) { rule.IsSocks = true } - rule.Scope, err = lsGetInt(globalLS, iter, 15) + rule.Scope, err = lsGetInt(globalTS, iter, COL_NO_ACTION) if err != nil { return rule, nil, err } @@ -861,116 +1012,78 @@ func getRuleByIdx(idx int) (ruleColumns, *gtk.TreeIter, error) { } // Needs to be locked by the caller -func getSelectedRule() (ruleColumns, int, error) { +func getSelectedRule() (ruleColumns, int, int, error) { rule := ruleColumns{} sel, err := globalTV.GetSelection() if err != nil { - return rule, -1, err + return rule, -1, -1, err } - rows := sel.GetSelectedRows(globalLS) + rows := sel.GetSelectedRows(globalTS) if rows.Length() <= 0 { - return rule, -1, errors.New("No selection was made") + return rule, -1, -1, errors.New("no selection was made") } rdata := rows.NthData(0) - lIndex, err := strconv.Atoi(rdata.(*gtk.TreePath).String()) - if err != nil { - return rule, -1, err - } - - fmt.Println("lindex = ", lIndex) - rule, _, err = getRuleByIdx(lIndex) - if err != nil { - return rule, -1, err - } - - return rule, lIndex, nil -} - -func addPendingPrompts(rules []string) { - - for _, rule := range rules { - fields := strings.Split(rule, "|") - - if len(fields) != 19 { - log.Printf("Got saved prompt message with strange data: \"%s\"", rule) - continue - } + tpath := rdata.(*gtk.TreePath).String() - guid := fields[0] - icon := fields[2] - path := fields[3] - address := fields[4] + subidx := -1 + ptoks := strings.Split(tpath, ":") - port, err := strconv.Atoi(fields[5]) + if len(ptoks) > 2 { + return rule, -1, -1, errors.New("internal error parsing selected item tree path") + } else if len(ptoks) == 2 { + subidx, err = strconv.Atoi(ptoks[1]) if err != nil { - log.Println("Error converting port in pending prompt message to integer:", err) - continue + return rule, -1, -1, err } + tpath = ptoks[0] + } - ip := fields[6] - origin := fields[7] - proto := fields[8] - - uid, err := strconv.Atoi(fields[9]) - if err != nil { - log.Println("Error converting UID in pending prompt message to integer:", err) - continue - } - - gid, err := strconv.Atoi(fields[10]) - if err != nil { - log.Println("Error converting GID in pending prompt message to integer:", err) - continue - } - - pid, err := strconv.Atoi(fields[13]) - if err != nil { - log.Println("Error converting pid in pending prompt message to integer:", err) - continue - } - - sandbox := fields[14] - - is_socks, err := strconv.ParseBool(fields[15]) - if err != nil { - log.Println("Error converting SOCKS flag in pending prompt message to boolean:", err) - continue - } - - timestamp := fields[16] - optstring := fields[17] - - action, err := strconv.Atoi(fields[18]) - if err != nil { - log.Println("Error converting action in pending prompt message to integer:", err) - continue - } + lIndex, err := strconv.Atoi(tpath) + if err != nil { + return rule, -1, -1, err + } - addRequestAsync(nil, guid, path, icon, proto, int(pid), ip, address, int(port), int(uid), int(gid), origin, timestamp, is_socks, optstring, sandbox, action) + // fmt.Printf("lindex = %d : %d\n", lIndex, subidx) + rule, _, err = getRuleByIdx(lIndex, subidx) + if err != nil { + return rule, -1, -1, err } + return rule, lIndex, subidx, nil } func buttonAction(action string) { globalPromptLock.Lock() - rule, idx, err := getSelectedRule() + rule, idx, subidx, err := getSelectedRule() if err != nil { globalPromptLock.Unlock() promptError("Error occurred processing request: " + err.Error()) return } - rule, err = createCurrentRule() + urule, err := createCurrentRule() if err != nil { globalPromptLock.Unlock() promptError("Error occurred constructing new rule: " + err.Error()) return } + // Overlay the rules + rule.Scope = urule.Scope + //rule.Path = urule.Path + rule.Port = urule.Port + rule.Target = urule.Target + rule.Proto = urule.Proto + // rule.UID = urule.UID + // rule.GID = urule.GID + // rule.Uname = urule.Uname + // rule.Gname = urule.Gname + rule.ForceTLS = urule.ForceTLS + fmt.Println("rule = ", rule) rulestr := action @@ -981,9 +1094,9 @@ func buttonAction(action string) { rulestr += "|" + rule.Proto + ":" + rule.Target + ":" + strconv.Itoa(rule.Port) rulestr += "|" + sgfw.RuleModeString[sgfw.RuleMode(rule.Scope)] fmt.Println("RULESTR = ", rulestr) - makeDecision(idx, rulestr, int(rule.Scope)) - fmt.Println("Decision made.") - err = removeSelectedRule(idx, true) + makeDecision(rulestr, int(rule.Scope), rule.GUID) + err = removeSelectedRule(idx, subidx) + addRecentlyRemoved(rule.GUID) globalPromptLock.Unlock() if err == nil { clearEditor() @@ -994,7 +1107,6 @@ func buttonAction(action string) { } func main() { - decisionWaiters = make([]*decisionWaiter, 0) _, err := newDbusServer() if err != nil { log.Fatal("Error:", err) @@ -1104,6 +1216,7 @@ func main() { // globalIcon.SetFromIconName("firefox", gtk.ICON_SIZE_DND) editApp = get_entry("") + editApp.SetEditable(false) editApp.Connect("changed", toggleValidRuleState) hbox.PackStart(lbl, false, false, 10) hbox.PackStart(editApp, true, true, 10) @@ -1166,42 +1279,43 @@ func main() { // box.PackStart(tv, false, true, 5) box.PackStart(scrollbox, false, true, 5) - tv.AppendColumn(createColumn("#", 0)) + tv.AppendColumn(createColumnText("#", COL_NO_NREFS)) + tv.AppendColumn(createColumnImg("", COL_NO_ICON_PIXBUF)) - guidcol := createColumn("GUID", 1) + guidcol := createColumnText("GUID", COL_NO_GUID) guidcol.SetVisible(false) tv.AppendColumn(guidcol) - tv.AppendColumn(createColumn("Path", 2)) + tv.AppendColumn(createColumnText("Path", COL_NO_PATH)) - icol := createColumn("Icon", 3) + icol := createColumnText("Icon", COL_NO_ICON) icol.SetVisible(false) tv.AppendColumn(icol) - tv.AppendColumn(createColumn("Protocol", 4)) - tv.AppendColumn(createColumn("PID", 5)) - tv.AppendColumn(createColumn("IP Address", 6)) - tv.AppendColumn(createColumn("Hostname", 7)) - tv.AppendColumn(createColumn("Port", 8)) - tv.AppendColumn(createColumn("UID", 9)) - tv.AppendColumn(createColumn("GID", 10)) - tv.AppendColumn(createColumn("Origin", 11)) - tv.AppendColumn(createColumn("Timestamp", 12)) - - scol := createColumn("Is SOCKS", 13) + tv.AppendColumn(createColumnText("Protocol", COL_NO_PROTO)) + tv.AppendColumn(createColumnText("PID", COL_NO_PID)) + tv.AppendColumn(createColumnText("IP Address", COL_NO_DSTIP)) + tv.AppendColumn(createColumnText("Hostname", COL_NO_HOSTNAME)) + tv.AppendColumn(createColumnText("Port", COL_NO_PORT)) + tv.AppendColumn(createColumnText("UID", COL_NO_UID)) + tv.AppendColumn(createColumnText("GID", COL_NO_GID)) + tv.AppendColumn(createColumnText("Origin", COL_NO_ORIGIN)) + tv.AppendColumn(createColumnText("Timestamp", COL_NO_TIMESTAMP)) + + scol := createColumnText("Is SOCKS", COL_NO_IS_SOCKS) scol.SetVisible(false) tv.AppendColumn(scol) - tv.AppendColumn(createColumn("Details", 14)) + tv.AppendColumn(createColumnText("Details", COL_NO_OPTSTRING)) - acol := createColumn("Scope", 15) + acol := createColumnText("Scope", COL_NO_ACTION) acol.SetVisible(false) tv.AppendColumn(acol) - listStore := createListStore(true) - globalLS = listStore + treeStore := createTreeStore(true) + globalTS = treeStore - tv.SetModel(listStore) + tv.SetModel(treeStore) btnApprove.Connect("clicked", func() { buttonAction("ALLOW") @@ -1214,7 +1328,7 @@ func main() { // tv.SetActivateOnSingleClick(true) tv.Connect("row-activated", func() { globalPromptLock.Lock() - seldata, _, err := getSelectedRule() + seldata, _, _, err := getSelectedRule() globalPromptLock.Unlock() if err != nil { promptError("Unexpected error reading selected rule: " + err.Error()) @@ -1237,7 +1351,7 @@ func main() { editPort.SetText(strconv.Itoa(seldata.Port)) radioOnce.SetActive(seldata.Scope == int(sgfw.APPLY_ONCE)) - radioProcess.SetSensitive(seldata.Pid > 0) + radioProcess.SetSensitive(seldata.Scope == int(sgfw.APPLY_PROCESS)) radioParent.SetActive(false) radioSession.SetActive(seldata.Scope == int(sgfw.APPLY_SESSION)) radioPermanent.SetActive(seldata.Scope == int(sgfw.APPLY_FOREVER)) @@ -1286,14 +1400,14 @@ func main() { mainWin.ShowAll() // mainWin.SetKeepAbove(true) - var dres = []string{} + var dres bool call := dbuso.Call("GetPendingRequests", 0, "*") err = call.Store(&dres) if err != nil { errmsg := "Could not query running SGFW instance (maybe it's not running?): " + err.Error() promptError(errmsg) - } else { - addPendingPrompts(dres) + } else if !dres { + promptError("Call to sgfw did not succeed; fw-prompt may have loaded without retrieving all pending connections") } gtk.Main() diff --git a/sgfw/config.go b/sgfw/config.go index 81b0564..2b1047f 100644 --- a/sgfw/config.go +++ b/sgfw/config.go @@ -58,7 +58,7 @@ func readConfig() { PromptExpanded: false, PromptExpert: false, DefaultAction: "SESSION", - DefaultActionID: 1, + DefaultActionID: 0, } if len(buf) > 0 { diff --git a/sgfw/const.go b/sgfw/const.go index f0f8bd1..74808f9 100644 --- a/sgfw/const.go +++ b/sgfw/const.go @@ -41,6 +41,7 @@ const ( RULE_MODE_PROCESS RULE_MODE_PERMANENT RULE_MODE_SYSTEM + RULE_MODE_ONCE ) // RuleModeString is used to get a rule mode string from its id @@ -49,6 +50,7 @@ var RuleModeString = map[RuleMode]string{ RULE_MODE_PROCESS: "PROCESS", RULE_MODE_PERMANENT: "PERMANENT", RULE_MODE_SYSTEM: "SYSTEM", + RULE_MODE_ONCE: "ONCE", } // RuleModeValue converts a mode string to its id @@ -57,16 +59,18 @@ var RuleModeValue = map[string]RuleMode{ RuleModeString[RULE_MODE_PROCESS]: RULE_MODE_PROCESS, RuleModeString[RULE_MODE_PERMANENT]: RULE_MODE_PERMANENT, RuleModeString[RULE_MODE_SYSTEM]: RULE_MODE_SYSTEM, + RuleModeString[RULE_MODE_ONCE]: RULE_MODE_ONCE, } //FilterScope contains a filter's time scope type FilterScope uint16 const ( - APPLY_ONCE FilterScope = iota - APPLY_SESSION + APPLY_SESSION FilterScope = iota APPLY_PROCESS APPLY_FOREVER + APPLY_SYSTEM + APPLY_ONCE ) // FilterScopeString converts a filter scope ID to its string @@ -143,9 +147,3 @@ type DbusRule struct { Mode uint16 Sandbox string } - -/*const ( - OZ_FWRULE_WHITELIST = iota - OZ_FWRULE_BLACKLIST - OZ_FWRULE_NONE -) */ diff --git a/sgfw/dbus.go b/sgfw/dbus.go index 95d42e2..fed6100 100644 --- a/sgfw/dbus.go +++ b/sgfw/dbus.go @@ -199,18 +199,19 @@ func (ds *dbusServer) DeleteRule(id uint32) *dbus.Error { return nil } -func (ds *dbusServer) GetPendingRequests(policy string) ([]string, *dbus.Error) { +func (ds *dbusServer) GetPendingRequests(policy string) (bool, *dbus.Error) { + succeeded := true + log.Debug("+++ GetPendingRequests()") ds.fw.lock.Lock() defer ds.fw.lock.Unlock() - pending_data := make([]string, 0) - for pname := range ds.fw.policyMap { policy := ds.fw.policyMap[pname] pqueue := policy.pendingQueue for _, pc := range pqueue { + var dres bool addr := pc.hostname() if addr == "" { addr = pc.dst().String() @@ -224,40 +225,48 @@ func (ds *dbusServer) GetPendingRequests(policy string) ([]string, *dbus.Error) dststr = addr + " (via proxy resolver)" } - pstr := "" - pstr += pc.getGUID() + "|" - pstr += policy.application + "|" - pstr += policy.icon + "|" - pstr += policy.path + "|" - pstr += addr + "|" - pstr += strconv.FormatUint(uint64(pc.dstPort()), 10) + "|" - pstr += dststr + "|" - pstr += pc.src().String() + "|" - pstr += pc.proto() + "|" - pstr += strconv.FormatInt(int64(pc.procInfo().UID), 10) + "|" - pstr += strconv.FormatInt(int64(pc.procInfo().GID), 10) + "|" - pstr += uidToUser(pc.procInfo().UID) + "|" - pstr += gidToGroup(pc.procInfo().GID) + "|" - pstr += strconv.FormatInt(int64(pc.procInfo().Pid), 10) + "|" - pstr += pc.sandbox() + "|" - pstr += strconv.FormatBool(pc.socks()) + "|" - pstr += pc.getTimestamp() + "|" - pstr += pc.getOptString() + "|" - pstr += strconv.FormatUint(uint64(FirewallConfig.DefaultActionID), 10) - pending_data = append(pending_data, pstr) + call := ds.prompter.dbusObj.Call("com.subgraph.FirewallPrompt.RequestPromptAsync", 0, + pc.getGUID(), + policy.application, + policy.icon, + policy.path, + addr, + int32(pc.dstPort()), + dststr, + pc.src().String(), + pc.proto(), + int32(pc.procInfo().UID), + int32(pc.procInfo().GID), + uidToUser(pc.procInfo().UID), + gidToGroup(pc.procInfo().GID), + int32(pc.procInfo().Pid), + pc.sandbox(), + pc.socks(), + pc.getTimestamp(), + pc.getOptString(), + FirewallConfig.PromptExpanded, + FirewallConfig.PromptExpert, + int32(FirewallConfig.DefaultActionID)) + + err := call.Store(&dres) + if err != nil { + log.Warningf("Error sending DBus async pending RequestPrompt message: %v", err) + succeeded = false + } + } } - return pending_data, nil + return succeeded, nil } -func (ds *dbusServer) AddRuleAsync(scope uint32, rule string, policy string) (bool, *dbus.Error) { - log.Warningf("AddRuleAsync %v, %v / %v\n", scope, rule, policy) +func (ds *dbusServer) AddRuleAsync(scope uint32, rule, policy, guid string) (bool, *dbus.Error) { + log.Warningf("AddRuleAsync %v, %v / %v / %v\n", scope, rule, policy, guid) ds.fw.lock.Lock() defer ds.fw.lock.Unlock() - prule := PendingRule{rule: rule, scope: int(scope), policy: policy} + prule := PendingRule{rule: rule, scope: int(scope), policy: policy, guid: guid} for pname := range ds.fw.policyMap { log.Debug("+++ Adding prule to policy") diff --git a/sgfw/icons.go b/sgfw/icons.go index 4beed49..778c0da 100644 --- a/sgfw/icons.go +++ b/sgfw/icons.go @@ -85,5 +85,14 @@ func loadDesktopFile(path string) { icon: icon, name: name, } + + lname := exec + for i := 0; i < 5; i++ { + lname, err = os.Readlink(lname) + if err == nil { + entryMap[lname] = entryMap[exec] + } + } + } } diff --git a/sgfw/ipc.go b/sgfw/ipc.go index cabe8f8..06f599c 100644 --- a/sgfw/ipc.go +++ b/sgfw/ipc.go @@ -8,6 +8,7 @@ import ( "os" "strconv" "strings" + "sync" "github.com/subgraph/oz/ipc" ) @@ -21,9 +22,14 @@ type OzInitProc struct { } var OzInitPids []OzInitProc = []OzInitProc{} +var OzInitPidsLock = sync.Mutex{} + func addInitPid(pid int, name string, sboxid int) { fmt.Println("::::::::::: init pid added: ", pid, " -> ", name) + OzInitPidsLock.Lock() + defer OzInitPidsLock.Unlock() + for i := 0; i < len(OzInitPids); i++ { if OzInitPids[i].Pid == pid { return @@ -36,6 +42,9 @@ func addInitPid(pid int, name string, sboxid int) { func removeInitPid(pid int) { fmt.Println("::::::::::: removing PID: ", pid) + OzInitPidsLock.Lock() + defer OzInitPidsLock.Unlock() + for i := 0; i < len(OzInitPids); i++ { if OzInitPids[i].Pid == pid { OzInitPids = append(OzInitPids[:i], OzInitPids[i+1:]...) @@ -139,19 +148,6 @@ func ReceiverLoop(fw *Firewall, c net.Conn) { c.Write([]byte(ruledesc)) } - /* for i := 0; i < len(sandboxRules); i++ { - rulestr := "" - - if sandboxRules[i].Whitelist { - rulestr += "whitelist" - } else { - rulestr += "blacklist" - } - - rulestr += " " + sandboxRules[i].SrcIf.String() + " -> " + sandboxRules[i].DstIP.String() + " : " + strconv.Itoa(int(sandboxRules[i].DstPort)) + "\n" - c.Write([]byte(rulestr)) - } */ - return } else { tokens := strings.Split(data, " ") @@ -337,12 +333,7 @@ const OzSocketName = "@oz-control" var bSockName = OzSocketName -var messageFactory = ipc.NewMsgFactory( - new(ListProxiesMsg), - new(ListProxiesResp), -) - -func clientConnect() (*ipc.MsgConn, error) { +func init() { bSockName = os.Getenv("SOCKET_NAME") if bSockName != "" { @@ -356,7 +347,14 @@ func clientConnect() (*ipc.MsgConn, error) { } else { bSockName = OzSocketName } +} +var messageFactory = ipc.NewMsgFactory( + new(ListProxiesMsg), + new(ListProxiesResp), +) + +func clientConnect() (*ipc.MsgConn, error) { return ipc.Connect(bSockName, messageFactory, nil) } diff --git a/sgfw/policy.go b/sgfw/policy.go index 936ae73..7d642f0 100644 --- a/sgfw/policy.go +++ b/sgfw/policy.go @@ -23,17 +23,6 @@ var _interpreters = []string{ "bash", } -/*type sandboxRule struct { - SrcIf net.IP - DstIP net.IP - DstPort uint16 - Whitelist bool -} - -var sandboxRules = []sandboxRule { -// { net.IP{172,16,1,42}, net.IP{140,211,166,134}, 21, false }, -} */ - type pendingConnection interface { policy() *Policy procInfo() *procsnitch.Info @@ -222,6 +211,7 @@ type PendingRule struct { rule string scope int policy string + guid string } type Policy struct { @@ -313,7 +303,6 @@ func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, timestamp time.Time, pinf if !FirewallConfig.LogRedact { log.Infof("Lookup(%s): %s", dstip.String(), name) } - // fwo := matchAgainstOzRules(srcip, dstip, dstp) result := p.rules.filterPacket(pkt, pinfo, srcip, name, optstr) switch result { @@ -331,12 +320,8 @@ func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, timestamp time.Time, pinf func (p *Policy) processPromptResult(pc pendingConnection) { p.pendingQueue = append(p.pendingQueue, pc) - //fmt.Println("processPromptResult(): p.promptInProgress = ", p.promptInProgress) - //if DoMultiPrompt || (!DoMultiPrompt && !p.promptInProgress) { - // if !p.promptInProgress { p.promptInProgress = true go p.fw.dbus.prompter.prompt(p) - // } } func (p *Policy) nextPending() (pendingConnection, bool) { @@ -372,6 +357,15 @@ func (p *Policy) removePending(pc pendingConnection) { } } +func (p *Policy) processNewRuleOnce(r *Rule, guid string) bool { + p.lock.Lock() + defer p.lock.Unlock() + + fmt.Println("----------------------- processNewRule() ONCE") + p.filterPendingOne(r, guid) + return true +} + func (p *Policy) processNewRule(r *Rule, scope FilterScope) bool { p.lock.Lock() defer p.lock.Unlock() @@ -419,6 +413,47 @@ func (p *Policy) removeRule(r *Rule) { p.rules = newRules } +func (p *Policy) filterPendingOne(rule *Rule, guid string) { + remaining := []pendingConnection{} + + for _, pc := range p.pendingQueue { + if guid != "" && guid != pc.getGUID() { + continue + } + + if rule.match(pc.src(), pc.dst(), pc.dstPort(), pc.hostname(), pc.proto(), pc.procInfo().UID, pc.procInfo().GID, uidToUser(pc.procInfo().UID), gidToGroup(pc.procInfo().GID), pc.procInfo().Sandbox) { + prompter := pc.getPrompter() + + if prompter == nil { + fmt.Println("-------- prompter = NULL") + } else { + call := prompter.dbusObj.Call("com.subgraph.FirewallPrompt.RemovePrompt", 0, pc.getGUID()) + fmt.Println("CAAAAAAAAAAAAAAALL = ", call) + } + + log.Infof("Adding rule for: %s", rule.getString(FirewallConfig.LogRedact)) + // log.Noticef("%s > %s", rule.getString(FirewallConfig.LogRedact), pc.print()) + if rule.rtype == RULE_ACTION_ALLOW { + pc.accept() + } else if rule.rtype == RULE_ACTION_ALLOW_TLSONLY { + pc.acceptTLSOnly() + } else { + srcs := pc.src().String() + ":" + strconv.Itoa(int(pc.srcPort())) + log.Warningf("DENIED outgoing connection attempt by %s from %s %s -> %s:%d (user prompt) %v", + pc.procInfo().ExePath, pc.proto(), srcs, pc.dst(), pc.dstPort, rule.rtype) + pc.drop() + } + + // XXX: If matching a GUID, we can break out immediately + } else { + remaining = append(remaining, pc) + } + } + if len(remaining) != len(p.pendingQueue) { + p.pendingQueue = remaining + } +} + func (p *Policy) filterPending(rule *Rule) { remaining := []pendingConnection{} for _, pc := range p.pendingQueue { @@ -532,18 +567,6 @@ func (fw *Firewall) filterPacket(pkt *nfqueue.NFQPacket, timestamp time.Time) { */ _, dstip := getPacketIPAddrs(pkt) /* _, dstp := getPacketPorts(pkt) - fwo := eatchAgainstOzRules(srcip, dstip, dstp) - log.Notice("XXX: Attempting [2] to filter packet on rules -> ", fwo) - - if fwo == OZ_FWRULE_WHITELIST { - log.Noticef("Automatically passed through whitelisted sandbox traffic from %s to %s:%d\n", srcip, dstip, dstp) - pkt.Accept() - return - } else if fwo == OZ_FWRULE_BLACKLIST { - log.Noticef("Automatically blocking blacklisted sandbox traffic from %s to %s:%d\n", srcip, dstip, dstp) - pkt.SetMark(1) - pkt.Accept() - return } */ ppath := "*" @@ -633,6 +656,7 @@ func readFileDirect(filename string) ([]byte, error) { func getAllProcNetDataLocal() ([]string, error) { data := "" + OzInitPidsLock.Lock() for i := 0; i < len(OzInitPids); i++ { fname := fmt.Sprintf("/proc/%d/net/tcp", OzInitPids[i]) @@ -647,6 +671,8 @@ func getAllProcNetDataLocal() ([]string, error) { } + OzInitPidsLock.Unlock() + lines := strings.Split(data, "\n") rlines := make([]string, 0) ctr := 1 @@ -692,6 +718,7 @@ func LookupSandboxProc(srcip net.IP, srcp uint16, dstip net.IP, dstp uint16, pro var res *procsnitch.Info = nil var optstr string removePids := make([]int, 0) + OzInitPidsLock.Lock() for i := 0; i < len(OzInitPids); i++ { data := "" @@ -746,6 +773,8 @@ func LookupSandboxProc(srcip net.IP, srcp uint16, dstip net.IP, dstp uint16, pro } + OzInitPidsLock.Unlock() + for _, p := range removePids { removeInitPid(p) } @@ -797,6 +826,7 @@ func findProcessForPacket(pkt *nfqueue.NFQPacket, reverse bool, strictness int) if res == nil { removePids := make([]int, 0) + OzInitPidsLock.Lock() for i := 0; i < len(OzInitPids); i++ { data := "" @@ -845,6 +875,8 @@ func findProcessForPacket(pkt *nfqueue.NFQPacket, reverse bool, strictness int) } + OzInitPidsLock.Unlock() + for _, p := range removePids { removeInitPid(p) } @@ -942,21 +974,3 @@ func getPacketPorts(pkt *nfqueue.NFQPacket) (uint16, uint16) { return s, d } - -/*func matchAgainstOzRules(srci, dsti net.IP, dstp uint16) int { - - for i := 0; i < len(sandboxRules); i++ { - - log.Notice("XXX: Attempting to match: ", srci, " / ", dsti, " / ", dstp, " | ", sandboxRules[i]) - - if sandboxRules[i].SrcIf.Equal(srci) && sandboxRules[i].DstIP.Equal(dsti) && sandboxRules[i].DstPort == dstp { - if sandboxRules[i].Whitelist { - return OZ_FWRULE_WHITELIST - } - return OZ_FWRULE_BLACKLIST - } - - } - - return OZ_FWRULE_NONE -} */ diff --git a/sgfw/prompt.go b/sgfw/prompt.go index bd7ef07..d3efd09 100644 --- a/sgfw/prompt.go +++ b/sgfw/prompt.go @@ -61,6 +61,9 @@ func (p *prompter) processNextPacket() bool { p.lock.Lock() pc, empty = p.nextConnection() p.lock.Unlock() + if pc != nil { + fmt.Println("GOT NON NIL") + } //fmt.Println("XXX: processNextPacket() loop; empty = ", empty, " / pc = ", pc) if pc == nil && empty { return false @@ -142,7 +145,7 @@ func monitorPromptFDLoop() { } inode := sb.Ino - fmt.Println("+++ INODE = ", inode) + // fmt.Println("+++ INODE = ", inode) if inode != fdmon.inode { fmt.Printf("inode mismatch: %v vs %v\n", inode, fdmon.inode) @@ -268,38 +271,6 @@ func (p *prompter) processConnection(pc pendingConnection) { return - /* p.dbusObj.Go("com.subgraph.FirewallPrompt.RequestPrompt", 0, callChan, - pc.getGUID(), - policy.application, - policy.icon, - policy.path, - addr, - int32(pc.dstPort()), - dststr, - pc.src().String(), - pc.proto(), - int32(pc.procInfo().UID), - int32(pc.procInfo().GID), - uidToUser(pc.procInfo().UID), - gidToGroup(pc.procInfo().GID), - int32(pc.procInfo().Pid), - pc.sandbox(), - pc.socks(), - pc.getOptString(), - FirewallConfig.PromptExpanded, - FirewallConfig.PromptExpert, - int32(FirewallConfig.DefaultActionID)) - - saveChannel(callChan, false, true) - - /* err := call.Store(&scope, &rule) - if err != nil { - log.Warningf("Error sending dbus RequestPrompt message: %v", err) - policy.removePending(pc) - pc.drop() - return - } */ - // the prompt sends: // ALLOW|dest or DENY|dest // @@ -370,6 +341,7 @@ func (p *prompter) nextConnection() (pendingConnection, bool) { fmt.Println("policy queue len = ", len(p.policyQueue)) for pind < len(p.policyQueue) { + fmt.Printf("policy loop %d of %d\n", pind, len(p.policyQueue)) //fmt.Printf("XXX: pind = %v of %v\n", pind, len(p.policyQueue)) policy := p.policyQueue[pind] pc, qempty := policy.nextPending() @@ -379,21 +351,54 @@ func (p *prompter) nextConnection() (pendingConnection, bool) { continue } else { pind++ - // if pc == nil && !qempty { - if len(policy.rulesPending) > 0 { - fmt.Println("policy rules pending = ", len(policy.rulesPending)) + pendingOnce := make([]PendingRule, 0) + pendingOther := make([]PendingRule, 0) + + for _, r := range policy.rulesPending { + if r.scope == int(APPLY_ONCE) { + pendingOnce = append(pendingOnce, r) + } else { + pendingOther = append(pendingOther, r) + } + } + fmt.Printf("# pending once = %d, other = %d, pc = %p / policy = %p\n", len(pendingOnce), len(pendingOther), pc, policy) + policy.rulesPending = pendingOther + + // One time filters are all applied right here, at once. + for _, pr := range pendingOnce { + toks := strings.Split(pr.rule, "|") + sandbox := "" + + if len(toks) > 2 { + sandbox = toks[2] + } + tempRule := fmt.Sprintf("%s|%s", toks[0], toks[1]) + tempRule += "||-1:-1|" + sandbox + "|" + + r, err := policy.parseRule(tempRule, false) + if err != nil { + log.Warningf("Error parsing rule string returned from dbus RequestPrompt: %v", err) + continue + } + + r.mode = RuleMode(pr.scope) + fmt.Println("+++++++ processing one time rule: ", pr.rule) + policy.processNewRuleOnce(r, pr.guid) + } + + // if pc == nil && !qempty { + if len(policy.rulesPending) > 0 { + fmt.Println("non/once policy rules pending = ", len(policy.rulesPending)) prule := policy.rulesPending[0] policy.rulesPending = append(policy.rulesPending[:0], policy.rulesPending[1:]...) - toks := strings.Split(prule.rule, "|") sandbox := "" if len(toks) > 2 { sandbox = toks[2] } - sandbox += "" tempRule := fmt.Sprintf("%s|%s", toks[0], toks[1]) tempRule += "||-1:-1|" + sandbox + "|" diff --git a/sgfw/rules.go b/sgfw/rules.go index 610515d..a3e1623 100644 --- a/sgfw/rules.go +++ b/sgfw/rules.go @@ -50,6 +50,7 @@ func (r *Rule) getString(redact bool) string { if r.rtype == RULE_ACTION_ALLOW || r.rtype == RULE_ACTION_ALLOW_TLSONLY { rtype = RuleActionString[r.rtype] } + rmode := "|" + RuleModeString[r.mode] protostr := "" @@ -247,7 +248,7 @@ func (r *Rule) parse(s string) bool { r.saddr = nil parts := strings.Split(s, "|") if len(parts) < 4 || len(parts) > 6 { - log.Notice("invalid number ", len(parts), " of rule parts in line ", s) + log.Notice("Error: invalid number ", len(parts), " of rule parts in line ", s) return false } if parts[2] == "SYSTEM" { @@ -275,7 +276,7 @@ func (r *Rule) parse(s string) bool { r.saddr = net.ParseIP(parts[5]) if r.saddr == nil { - log.Notice("invalid source IP ", parts[5], " in line ", s) + log.Notice("Error: invalid source IP ", parts[5], " in line ", s) return false } diff --git a/sgfw/tlsguard.go b/sgfw/tlsguard.go index ff7db76..9370a06 100644 --- a/sgfw/tlsguard.go +++ b/sgfw/tlsguard.go @@ -3,7 +3,7 @@ package sgfw import ( "crypto/x509" "encoding/binary" - "encoding/hex" + // "encoding/hex" "errors" "fmt" "io" @@ -11,7 +11,7 @@ import ( "time" ) -const TLSGUARD_READ_TIMEOUT = 10 * time.Second +const TLSGUARD_READ_TIMEOUT = 8 * time.Second const TLSGUARD_MIN_TLS_VER_MAJ = 3 const TLSGUARD_MIN_TLS_VER_MIN = 1 @@ -59,12 +59,66 @@ const TLS1_AD_USER_CANCELLED = 90 const TLS1_AD_NO_RENEGOTIATION = 100 const TLS1_AD_UNSUPPORTED_EXTENSION = 110 -const TLSEXT_TYPE_server_name = 1 +const TLSEXT_TYPE_server_name = 0 +const TLSEXT_TYPE_max_fragment_length = 1 +const TLSEXT_TYPE_client_certificate_url = 2 +const TLSEXT_TYPE_trusted_ca_keys = 3 +const TLSEXT_TYPE_truncated_hmac = 4 +const TLSEXT_TYPE_status_request = 5 +const TLSEXT_TYPE_user_mapping = 6 +const TLSEXT_TYPE_client_authz = 7 +const TLSEXT_TYPE_server_authz = 8 +const TLSEXT_TYPE_cert_type = 9 +const TLSEXT_TYPE_supported_groups = 10 +const TLSEXT_TYPE_ec_point_formats = 11 +const TLSEXT_TYPE_srp = 12 const TLSEXT_TYPE_signature_algorithms = 13 +const TLSEXT_TYPE_use_srtp = 14 +const TLSEXT_TYPE_heartbeat = 15 +const TLSEXT_TYPE_application_layer_protocol_negotiation = 16 +const TLSEXT_TYPE_status_request_v2 = 17 +const TLSEXT_TYPE_signed_certificate_timestamp = 18 const TLSEXT_TYPE_client_certificate_type = 19 +const TLSEXT_TYPE_server_certificate_type = 20 +const TLSEXT_TYPE_padding = 21 +const TLSEXT_TYPE_encrypt_then_mac = 22 const TLSEXT_TYPE_extended_master_secret = 23 +const TLSEXT_TYPE_token_binding = 24 +const TLSEXT_TYPE_cached_info = 25 +const TLSEXT_TYPE_SessionTicket = 35 const TLSEXT_TYPE_renegotiate = 0xff01 +var tlsExtensionMap map[uint16]string = map[uint16]string{ + TLSEXT_TYPE_server_name: "TLSEXT_TYPE_server_name", + TLSEXT_TYPE_max_fragment_length: "TLSEXT_TYPE_max_fragment_length", + TLSEXT_TYPE_client_certificate_url: "TLSEXT_TYPE_client_certificate_url", + TLSEXT_TYPE_trusted_ca_keys: "TLSEXT_TYPE_trusted_ca_keys", + TLSEXT_TYPE_truncated_hmac: "TLSEXT_TYPE_truncated_hmac", + TLSEXT_TYPE_status_request: "TLSEXT_TYPE_status_request", + TLSEXT_TYPE_user_mapping: "TLSEXT_TYPE_user_mapping", + TLSEXT_TYPE_client_authz: "TLSEXT_TYPE_client_authz", + TLSEXT_TYPE_server_authz: "TLSEXT_TYPE_server_authz", + TLSEXT_TYPE_cert_type: "TLSEXT_TYPE_cert_type", + TLSEXT_TYPE_supported_groups: "TLSEXT_TYPE_supported_groups", + TLSEXT_TYPE_ec_point_formats: "TLSEXT_TYPE_ec_point_formats", + TLSEXT_TYPE_srp: "TLSEXT_TYPE_srp", + TLSEXT_TYPE_signature_algorithms: "TLSEXT_TYPE_signature_algorithms", + TLSEXT_TYPE_use_srtp: "TLSEXT_TYPE_use_srtp", + TLSEXT_TYPE_heartbeat: "TLSEXT_TYPE_heartbeat", + TLSEXT_TYPE_application_layer_protocol_negotiation: "TLSEXT_TYPE_application_layer_protocol_negotiation", + TLSEXT_TYPE_status_request_v2: "TLSEXT_TYPE_status_request_v2", + TLSEXT_TYPE_signed_certificate_timestamp: "TLSEXT_TYPE_signed_certificate_timestamp", + TLSEXT_TYPE_client_certificate_type: "TLSEXT_TYPE_client_certificate_type", + TLSEXT_TYPE_server_certificate_type: "TLSEXT_TYPE_server_certificate_type", + TLSEXT_TYPE_padding: "TLSEXT_TYPE_padding", + TLSEXT_TYPE_encrypt_then_mac: "TLSEXT_TYPE_encrypt_then_mac", + TLSEXT_TYPE_extended_master_secret: "TLSEXT_TYPE_extended_master_secret", + TLSEXT_TYPE_token_binding: "TLSEXT_TYPE_token_binding", + TLSEXT_TYPE_cached_info: "TLSEXT_TYPE_cached_info", + TLSEXT_TYPE_SessionTicket: "TLSEXT_TYPE_SessionTicket", + TLSEXT_TYPE_renegotiate: "TLSEXT_TYPE_renegotiate", +} + type connReader struct { client bool data []byte @@ -80,18 +134,54 @@ var cipherSuiteMap map[uint16]string = map[uint16]string{ 0x0039: "TLS_DHE_RSA_WITH_AES_256_CBC_SHA", 0x0035: "TLS_RSA_WITH_AES_256_CBC_SHA", 0x0030: "TLS_DH_DSS_WITH_AES_128_CBC_SHA", + 0x0067: "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256", + 0x006b: "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256", + 0x009e: "TLS_DHE_RSA_WITH_AES_128_GCM_SHA256", + 0x009f: "TLS_DHE_RSA_WITH_AES_256_GCM_SHA384", + 0x00c4: "TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256", 0xc009: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", 0xc00a: "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", 0xc013: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", 0xc014: "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", + 0xc023: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", + 0xc024: "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384", + 0xc027: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", + 0xc028: "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384", 0xc02b: "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", 0xc02c: "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", 0xc02f: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", 0xc030: "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", + 0xc076: "TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256", + 0xc077: "TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384", + 0xcc13: "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", + 0xcc14: "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", + 0xcc15: "TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256", 0xcca9: "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", 0xcca8: "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", } +var whitelistedCiphers = []string{ + "SSL_DHE_RSA_WITH_3DES_EDE_CBC_SHA", + "TLS_DHE_RSA_WITH_AES_128_CBC_SHA", + "TLS_DHE_RSA_WITH_AES_128_GCM_SHA256", + "TLS_DHE_RSA_WITH_AES_256_GCM_SHA384", + "TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256", + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", + "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", + "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384", + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", + "TLS_RSA_WITH_AES_128_CBC_SHA", + "SSL_RSA_WITH_3DES_EDE_CBC_SHA", +} + +var blacklistedCiphers = []string{ + "TLS_NULL_WITH_NULL_NULL", + "TLS_RSA_WITH_AES_128_CBC_SHA", +} + func getCipherSuiteName(value uint) string { val, ok := cipherSuiteMap[uint16(value)] if !ok { @@ -101,6 +191,79 @@ func getCipherSuiteName(value uint) string { return val } +func isBadCipher(cname string) bool { + for _, cipher := range blacklistedCiphers { + if cipher == cname { + return true + } + } + + return false +} + +func gettlsExtensionName(value uint) string { + // 26-34: Unassigned + // 36-65280: Unassigned + // 65282-65535: Unassigned + + if (value >= 26 && value <= 34) || (value >= 36 && value <= 65280) || (value >= 65282 && value <= 65535) { + return fmt.Sprintf("Unassigned TLS Extension %#x", value) + } + + val, ok := tlsExtensionMap[uint16(value)] + if !ok { + return "UNKNOWN" + } + + return val +} + +func stripTLSData(record []byte, start_ind, end_ind int, len_ind int, len_size int) []byte { + var size uint = 0 + + if len_size < 1 || len_size > 2 { + return nil + } else if start_ind >= end_ind { + return nil + } else if len_ind >= start_ind { + return nil + } + + rcopy := make([]byte, len(record)) + copy(rcopy, record) + + if len_size == 1 { + size = uint(rcopy[len_ind]) + } else if len_size == 2 { + size = uint(binary.BigEndian.Uint16(rcopy[len_ind : len_ind+len_size])) + } + + size -= uint(end_ind - start_ind) + + // Put back the length size + if len_size == 1 { + rcopy[len_ind] = byte(size) + } else if len_size == 2 { + binary.BigEndian.PutUint16(rcopy[len_ind:len_ind+len_size], uint16(size)) + } + + // Patch the record size + rsize := binary.BigEndian.Uint16(rcopy[3:5]) + rsize -= uint16(end_ind - start_ind) + binary.BigEndian.PutUint16(rcopy[3:5], rsize) + + // And finally the 3 byte hello record + hsize := binary.BigEndian.Uint32(rcopy[5:9]) + saved_b := hsize & 0xff000000 + hsize &= 0x00ffffff + hsize -= uint32(end_ind - start_ind) + hsize |= saved_b + binary.BigEndian.PutUint32(rcopy[5:9], hsize) + + result := append(rcopy[:start_ind], rcopy[end_ind:]...) + return result +} + func connectionReader(conn net.Conn, is_client bool, c chan connReader, done chan bool) { var ret_error error = nil buffered := []byte{} @@ -142,6 +305,9 @@ func connectionReader(conn net.Conn, is_client bool, c chan connReader, done cha } else if int(header[2]) < TLSGUARD_MIN_TLS_VER_MIN { ret_error = errors.New("TLS protocol minor version less than expected minimum") continue + } else if int(header[1]) > 3 { + ret_error = errors.New("TLS protocol major version was larger than expected; maybe not TLS handshake?") + continue } rtype = int(header[0]) @@ -184,6 +350,16 @@ func connectionReader(conn net.Conn, is_client bool, c chan connReader, done cha } +func isExpected(val uint, possibilities []uint) bool { + for _, pval := range possibilities { + if val == pval { + return true + } + } + + return false +} + func TLSGuard(conn, conn2 net.Conn, fqdn string) error { x509Valid := false ndone := 0 @@ -196,17 +372,24 @@ func TLSGuard(conn, conn2 net.Conn, fqdn string) error { fmt.Println("-------- STARTING HANDSHAKE LOOP") crChan := make(chan connReader) dChan := make(chan bool, 10) + dChan2 := make(chan bool, 10) go connectionReader(conn, true, crChan, dChan) - go connectionReader(conn2, false, crChan, dChan) + go connectionReader(conn2, false, crChan, dChan2) + + client_expected := []uint{SSL3_MT_CLIENT_HELLO} + server_expected := []uint{SSL3_MT_SERVER_HELLO} - client_expected := SSL3_MT_CLIENT_HELLO - server_expected := SSL3_MT_SERVER_HELLO + client_sess := false + server_sess := false + client_change_cipher := false + server_change_cipher := false select_loop: for { if ndone == 2 { fmt.Println("DONE channel got both notifications. Terminating loop.") close(dChan) + close(dChan2) close(crChan) break } @@ -239,6 +422,12 @@ select_loop: if cr.data[TLS_RECORD_HDR_LEN] != 1 { return errors.New(fmt.Sprintf("TLSGuard dropped connection with strange change cipher spec data (%#x bytes)", cr.data[TLS_RECORD_HDR_LEN])) } + + if cr.client { + client_change_cipher = true + } else { + server_change_cipher = true + } } else if cr.rtype == SSL3_RT_ALERT { if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_WARNING { fmt.Println("SSL ALERT TYPE: warning") @@ -248,7 +437,7 @@ select_loop: fmt.Println("SSL ALERT TYPE UNKNOWN") } - alert_desc := int(int(cr.data[6])<<8 | int(cr.data[7])) + alert_desc := int(int(cr.data[5])<<8 | int(cr.data[6])) fmt.Println("ALERT DESCRIPTION: ", alert_desc) if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_FATAL { @@ -258,49 +447,46 @@ select_loop: } } - - // fmt.Println("OTHER DATA; PASSING THRU") - if cr.rtype == SSL3_RT_ALERT { - fmt.Println("ALERT = ", cr.data) - } other.Write(cr.data) continue - } else if cr.client { - // other.Write(cr.data) - // continue } else if cr.rtype != SSL3_RT_HANDSHAKE { return errors.New(fmt.Sprintf("Expected TLS server handshake byte was not received [%#x vs 0x16]", cr.rtype)) } - if cr.rtype < SSL3_RT_CHANGE_CIPHER_SPEC || cr.rtype > SSL3_RT_APPLICATION_DATA { - return errors.New(fmt.Sprintf("TLSGuard dropping connection with unknown content type: %#x", cr.rtype)) - } - handshakeMsg := cr.data[TLS_RECORD_HDR_LEN:] s := uint(handshakeMsg[0]) - fmt.Printf("s = %#x\n", s) - // Message len, 3 bytes - if cr.rtype == SSL3_RT_HANDSHAKE { - handshakeMessageLen := handshakeMsg[1:4] - handshakeMessageLenInt := int(int(handshakeMessageLen[0])<<16 | int(handshakeMessageLen[1])<<8 | int(handshakeMessageLen[2])) - fmt.Println("lenint = \n", handshakeMessageLenInt) + handshakeMessageLen := handshakeMsg[1:4] + handshakeMessageLenInt := int(int(handshakeMessageLen[0])<<16 | int(handshakeMessageLen[1])<<8 | int(handshakeMessageLen[2])) + fmt.Printf("s = %#x, lenint = %v, total = %d\n", s, handshakeMessageLenInt, len(cr.data)) + + if (client_sess || server_sess) && (client_change_cipher || server_change_cipher) { + + if handshakeMessageLenInt > len(cr.data)+9 { + log.Notice("TLSGuard saw what looks like a resumed encrypted session... passing connection through") + other.Write(cr.data) + dChan <- true + dChan2 <- true + x509Valid = true + break select_loop + } + } - if cr.client && s != uint(client_expected) { + if cr.client && !isExpected(s, client_expected) { return errors.New(fmt.Sprintf("Client sent handshake type %#x but expected %#x", s, client_expected)) - } else if !cr.client && s != uint(server_expected) { + } else if !cr.client && !isExpected(s, server_expected) { return errors.New(fmt.Sprintf("Server sent handshake type %#x but expected %#x", s, server_expected)) } if (cr.client && s == SSL3_MT_CLIENT_HELLO) || (!cr.client && s == SSL3_MT_SERVER_HELLO) { - rewrite := false - rewrite_buf := []byte{} + // rewrite := false + // rewrite_buf := []byte{} SRC := "" if s == SSL3_MT_CLIENT_HELLO { SRC = "CLIENT" } else { - server_expected = SSL3_MT_CERTIFICATE + server_expected = []uint{SSL3_MT_CERTIFICATE, SSL3_MT_HELLO_REQUEST} SRC = "SERVER" } @@ -319,139 +505,79 @@ select_loop: sess_len := uint(handshakeMsg[hello_offset]) fmt.Println(SRC, "HELLO SESSION ID = ", sess_len) - if sess_len != 0 { - fmt.Printf("ALERT: %v attempting to resume session; intercepting request\n", SRC) - rewrite = true - dcopy := make([]byte, len(cr.data)) - copy(dcopy, cr.data) - // Copy the bytes before the session ID start - rewrite_buf = dcopy[0 : TLS_RECORD_HDR_LEN+hello_offset+1] - // Set the session ID to 0 - rewrite_buf[len(rewrite_buf)-1] = 0 - // Write the new TLS record length - binary.BigEndian.PutUint16(rewrite_buf[3:5], uint16(len(dcopy)-(int(sess_len)+TLS_RECORD_HDR_LEN))) - // Write the new ClientHello length - // Starts after the first 6 bytes (record header + type byte) - orig_len := binary.BigEndian.Uint32(handshakeMsg[0:4]) - // But it's only 3 bytes so mask out the first one - b1 := orig_len & 0xff000000 - orig_len &= 0x00ffffff - orig_len -= uint32(sess_len) - orig_len |= b1 - binary.BigEndian.PutUint32(rewrite_buf[TLS_RECORD_HDR_LEN:], orig_len) - rewrite_buf = append(rewrite_buf, dcopy[TLS_RECORD_HDR_LEN+hello_offset+int(sess_len)+1:]...) - } - - hello_offset += int(sess_len) + 1 - // 2 byte cipher suite array - cs := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) - noCS := cs - fmt.Printf("cs = %v / %#x\n", noCS, noCS) - - if !cr.client { - fmt.Printf("SERVER selected ciphersuite: %#x (%s)\n", cs, getCipherSuiteName(uint(cs))) - hello_offset += 2 + if cr.client && sess_len > 0 { + client_sess = true } else { - - for csind := 0; csind < int(noCS/2); csind++ { - off := hello_offset + 2 + (csind * 2) - cs = binary.BigEndian.Uint16(handshakeMsg[off : off+2]) - fmt.Printf("%s HELLO CIPHERSUITE: %d/%d: %#x (%s)\n", SRC, csind+1, noCS/2, cs, getCipherSuiteName(uint(cs))) - } - - hello_offset += 2 + int(noCS) + server_sess = true } - clen := uint(handshakeMsg[hello_offset]) - hello_offset++ + /* + hello_offset += int(sess_len) + 1 + // 2 byte cipher suite array + cs := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) + noCS := cs + fmt.Printf("cs = %v / %#x\n", noCS, noCS) - if !cr.client { - fmt.Println("SERVER selected compression method: ", clen) - } else { - fmt.Println(SRC, "HELLO COMPRESSION METHODS LEN = ", clen) - fmt.Println(SRC, "HELLO COMPRESSION METHODS: ", handshakeMsg[hello_offset:hello_offset+int(clen)]) - hello_offset += int(clen) - } - - var extlen uint16 = 0 + saved_ciphersuite_size_off := hello_offset - if hello_offset == len(handshakeMsg) { - fmt.Println("Message didn't have any extensions present") - } else { - extlen = binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) - fmt.Println(SRC, "HELLO EXTENSIONS LENGTH: ", extlen) - hello_offset += 2 - } + if !cr.client { + fmt.Printf("SERVER selected ciphersuite: %#x (%s)\n", cs, getCipherSuiteName(uint(cs))) + hello_offset += 2 + } else { - var exttype uint16 = 0 - if extlen > 2 { - exttype = binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) - fmt.Println(SRC, "HELLO FIRST EXTENSION TYPE: ", exttype) - } + for csind := 0; csind < int(noCS/2); csind++ { + off := hello_offset + 2 + (csind * 2) + cs = binary.BigEndian.Uint16(handshakeMsg[off : off+2]) + cname := getCipherSuiteName(uint(cs)) + fmt.Printf("%s HELLO CIPHERSUITE: %d/%d: %#x (%s)\n", SRC, csind+1, noCS/2, cs, cname) - if cr.client { - ext_ctr := 0 + if isBadCipher(cname) { + fmt.Println("BAD CIPHER: ", cname) + } - for ext_ctr < int(extlen)-2 { - hello_offset += 2 - ext_ctr += 2 - fmt.Printf("PROGRESS: %v of %v, %v of %v\n", ext_ctr, extlen, hello_offset, len(handshakeMsg)) - exttype2 := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) - fmt.Printf("EXTTYPE = %v, 2 = %v\n", exttype, exttype2) - if exttype2 == TLSEXT_TYPE_server_name { - fmt.Println("CLIENT specified server_name extension:") - } - if exttype != TLSEXT_TYPE_signature_algorithms { - fmt.Println("WTF") } - hello_offset += 2 - ext_ctr += 2 - inner_len := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) - // fmt.Println("INNER LEN = ", inner_len) - hello_offset += int(inner_len) - ext_ctr += int(inner_len) + hello_offset += 2 + int(noCS) } - } + clen := uint(handshakeMsg[hello_offset]) + hello_offset++ - if extlen > 0 { - fmt.Printf("ALERT: %v attempting to send extensions; intercepting request\n", SRC) - rewrite = true - tocopy := cr.data + if !cr.client { + fmt.Println("SERVER selected compression method: ", clen) + } else { + fmt.Println(SRC, "HELLO COMPRESSION METHODS LEN = ", clen) + fmt.Println(SRC, "HELLO COMPRESSION METHODS: ", handshakeMsg[hello_offset:hello_offset+int(clen)]) + hello_offset += int(clen) + } + + var extlen uint16 = 0 - if len(rewrite_buf) > 0 { - tocopy = rewrite_buf + if hello_offset == len(handshakeMsg) { + fmt.Println("Message didn't have any extensions present") + } else { + extlen = binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) + fmt.Println(SRC, "HELLO EXTENSIONS LENGTH: ", extlen) + hello_offset += 2 } - dcopy := make([]byte, len(tocopy)-int(extlen)) - copy(dcopy, tocopy[0:len(tocopy)-int(extlen)]) - rewrite_buf = dcopy - // Write the new TLS record length - binary.BigEndian.PutUint16(rewrite_buf[3:5], uint16(len(dcopy)-(int(sess_len)+TLS_RECORD_HDR_LEN))) - // Write the new ClientHello length - // Starts after the first 6 bytes (record header + type byte) - orig_len := binary.BigEndian.Uint32(rewrite_buf[TLS_RECORD_HDR_LEN:]) - // But it's only 3 bytes so mask out the first one - b1 := orig_len & 0xff000000 - orig_len &= 0x00ffffff - orig_len -= uint32(extlen) - orig_len |= b1 - binary.BigEndian.PutUint32(rewrite_buf[TLS_RECORD_HDR_LEN:], orig_len) - // Write session length 0 at the end - rewrite_buf[len(rewrite_buf)-1] = 0 - rewrite_buf[len(rewrite_buf)-2] = 0 - } + if cr.client { + ext_ctr := 0 + + for ext_ctr < int(extlen)-2 { + exttype := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) + hello_offset += 2 + ext_ctr += 2 + // fmt.Printf("PROGRESS: %v of %v, %v of %v\n", ext_ctr, extlen, hello_offset, len(handshakeMsg)) + fmt.Printf("EXTTYPE = %#x (%s)\n", exttype, gettlsExtensionName(uint(exttype))) + inner_len := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) + hello_offset += int(inner_len) + 2 + ext_ctr += int(inner_len) + 2 + } - if rewrite { - fmt.Println("TLSGuard writing back modified handshake data to server") - fmt.Printf("ORIGINAL[%d]: %v\n", len(cr.data), hex.Dump(cr.data)) - fmt.Printf("NEW[%d]: %v\n", len(rewrite_buf), hex.Dump(rewrite_buf)) - other.Write(rewrite_buf) - } else { - other.Write(cr.data) - } + }*/ + other.Write(cr.data) continue } @@ -460,25 +586,19 @@ select_loop: continue } - if !cr.client && server_expected == SSL3_MT_SERVER_HELLO { - server_expected = SSL3_MT_CERTIFICATE + if !cr.client && isExpected(SSL3_MT_SERVER_HELLO, server_expected) { + server_expected = []uint{SSL3_MT_CERTIFICATE} } if !cr.client && s == SSL3_MT_HELLO_REQUEST { fmt.Println("Server sent hello request") - continue } if s > SSL3_MT_CERTIFICATE_STATUS { fmt.Println("WTF: ", cr.data) } - // Message len, 3 bytes - handshakeMessageLen := handshakeMsg[1:4] - handshakeMessageLenInt := int(int(handshakeMessageLen[0])<<16 | int(handshakeMessageLen[1])<<8 | int(handshakeMessageLen[2])) - if s == SSL3_MT_CERTIFICATE { - fmt.Println("HMM") // fmt.Printf("chunk len = %v, handshakeMsgLen = %v, slint = %v\n", len(chunk), len(handshakeMsg), handshakeMessageLenInt) if len(handshakeMsg) < handshakeMessageLenInt { return errors.New(fmt.Sprintf("len(handshakeMsg) %v < handshakeMessageLenInt %v!\n", len(handshakeMsg), handshakeMessageLenInt)) @@ -535,6 +655,7 @@ select_loop: if x509Valid || (s == SSL3_MT_SERVER_DONE) || (s == SSL3_MT_CERTIFICATE_REQUEST) { fmt.Println("BREAKING OUT OF LOOP 1") dChan <- true + dChan2 <- true fmt.Println("BREAKING OUT OF LOOP 2") break select_loop } @@ -576,6 +697,7 @@ select_loop: // dChan <- true close(dChan) + close(dChan2) if !x509Valid { return errors.New("Unknown error: TLS connection could not be validated")