Merge back...

shw-merge
xSmurf 7 years ago
parent e5dd1cb538
commit b6ff6c4857

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gotk3/gotk3/gdk"
"github.com/gotk3/gotk3/glib" "github.com/gotk3/gotk3/glib"
"github.com/gotk3/gotk3/gtk" "github.com/gotk3/gotk3/gtk"
"io/ioutil" "io/ioutil"
@ -25,14 +26,6 @@ type fpPreferences struct {
Winleft uint Winleft uint
} }
type decisionWaiter struct {
Cond *sync.Cond
Lock sync.Locker
Ready bool
Scope int
Rule string
}
type ruleColumns struct { type ruleColumns struct {
nrefs int nrefs int
Path string Path string
@ -54,36 +47,63 @@ type ruleColumns struct {
Scope int Scope int
} }
const (
COL_NO_NREFS = iota
COL_NO_ICON_PIXBUF
COL_NO_GUID
COL_NO_PATH
COL_NO_ICON
COL_NO_PROTO
COL_NO_PID
COL_NO_DSTIP
COL_NO_HOSTNAME
COL_NO_PORT
COL_NO_UID
COL_NO_GID
COL_NO_ORIGIN
COL_NO_TIMESTAMP
COL_NO_IS_SOCKS
COL_NO_OPTSTRING
COL_NO_ACTION
COL_NO_LAST
)
var dbuso *dbusObject var dbuso *dbusObject
var userPrefs fpPreferences var userPrefs fpPreferences
var mainWin *gtk.Window var mainWin *gtk.Window
var Notebook *gtk.Notebook var Notebook *gtk.Notebook
var globalLS *gtk.ListStore = nil var globalTS *gtk.TreeStore = nil
var globalTV *gtk.TreeView var globalTV *gtk.TreeView
var globalPromptLock = &sync.Mutex{} var globalPromptLock = &sync.Mutex{}
var recentLock = &sync.Mutex{}
var globalIcon *gtk.Image var globalIcon *gtk.Image
var decisionWaiters []*decisionWaiter
var editApp, editTarget, editPort, editUser, editGroup *gtk.Entry var editApp, editTarget, editPort, editUser, editGroup *gtk.Entry
var comboProto *gtk.ComboBoxText var comboProto *gtk.ComboBoxText
var radioOnce, radioProcess, radioParent, radioSession, radioPermanent *gtk.RadioButton var radioOnce, radioProcess, radioParent, radioSession, radioPermanent *gtk.RadioButton
var btnApprove, btnDeny, btnIgnore *gtk.Button var btnApprove, btnDeny, btnIgnore *gtk.Button
var chkTLS, chkUser, chkGroup *gtk.CheckButton var chkTLS, chkUser, chkGroup *gtk.CheckButton
var recentlyRemoved = []string{}
func dumpDecisions() { func wasRecentlyRemoved(guid string) bool {
return recentLock.Lock()
fmt.Println("XXX Total of decisions pending: ", len(decisionWaiters)) defer recentLock.Unlock()
for i := 0; i < len(decisionWaiters); i++ {
fmt.Printf("XXX %d ready = %v, rule = %v\n", i+1, decisionWaiters[i].Ready, decisionWaiters[i].Rule) for gind, g := range recentlyRemoved {
if g == guid {
recentlyRemoved = append(recentlyRemoved[:gind], recentlyRemoved[gind+1:]...)
return true
}
} }
return false
} }
func addDecision() *decisionWaiter { func addRecentlyRemoved(guid string) {
return nil recentLock.Lock()
decision := decisionWaiter{Lock: &sync.Mutex{}, Ready: false, Scope: int(sgfw.APPLY_ONCE), Rule: ""} defer recentLock.Unlock()
decision.Cond = sync.NewCond(decision.Lock) fmt.Println("RECENTLY REMOVED: ", guid)
decisionWaiters = append(decisionWaiters, &decision) recentlyRemoved = append(recentlyRemoved, guid)
return &decision
} }
func promptInfo(msg string) { func promptInfo(msg string) {
@ -298,15 +318,27 @@ func get_label(text string) *gtk.Label {
return label return label
} }
func createColumn(title string, id int) *gtk.TreeViewColumn { func createColumnImg(title string, id int) *gtk.TreeViewColumn {
cellRenderer, err := gtk.CellRendererTextNew() cellRenderer, err := gtk.CellRendererPixbufNew()
if err != nil {
log.Fatal("Unable to create image cell renderer:", err)
}
column, err := gtk.TreeViewColumnNewWithAttribute(title, cellRenderer, "pixbuf", id)
if err != nil {
log.Fatal("Unable to create cell column:", err)
}
return column
}
func createColumnText(title string, id int) *gtk.TreeViewColumn {
cellRenderer, err := gtk.CellRendererTextNew()
if err != nil { if err != nil {
log.Fatal("Unable to create text cell renderer:", err) log.Fatal("Unable to create text cell renderer:", err)
} }
column, err := gtk.TreeViewColumnNewWithAttribute(title, cellRenderer, "text", id) column, err := gtk.TreeViewColumnNewWithAttribute(title, cellRenderer, "text", id)
if err != nil { if err != nil {
log.Fatal("Unable to create cell column:", err) log.Fatal("Unable to create cell column:", err)
} }
@ -316,34 +348,58 @@ func createColumn(title string, id int) *gtk.TreeViewColumn {
return column return column
} }
func createListStore(general bool) *gtk.ListStore { func createTreeStore(general bool) *gtk.TreeStore {
colData := []glib.Type{glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_STRING, colData := []glib.Type{glib.TYPE_INT, glib.TYPE_OBJECT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_STRING,
glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_INT} glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_INT}
listStore, err := gtk.ListStoreNew(colData...)
treeStore, err := gtk.TreeStoreNew(colData...)
if err != nil { if err != nil {
log.Fatal("Unable to create list store:", err) log.Fatal("Unable to create list store:", err)
} }
return listStore return treeStore
} }
func removeRequest(listStore *gtk.ListStore, guid string) { func removeRequest(treeStore *gtk.TreeStore, guid string) {
if wasRecentlyRemoved(guid) {
fmt.Printf("Entry for %s was recently removed; deleting from cache\n", guid)
return
}
removed := false removed := false
if globalTS == nil {
return
}
globalPromptLock.Lock() globalPromptLock.Lock()
defer globalPromptLock.Unlock() defer globalPromptLock.Unlock()
/* XXX: This is horrible. Figure out how to do this properly. */ remove_outer:
for ridx := 0; ridx < globalLS.IterNChildren(nil); ridx++ { for ridx := 0; ridx < globalTS.IterNChildren(nil); ridx++ {
nchildren := 0
this_iter, err := globalTS.GetIterFromString(fmt.Sprintf("%d", ridx))
if err != nil {
log.Println("Strange condition; couldn't get iter of known tree index:", err)
} else {
nchildren = globalTS.IterNChildren(this_iter)
}
rule, _, err := getRuleByIdx(ridx) for cidx := 0; cidx < nchildren-1; cidx++ {
sidx := cidx
if cidx == nchildren {
cidx = -1
}
rule, _, err := getRuleByIdx(ridx, sidx)
if err != nil { if err != nil {
break break remove_outer
} else if rule.GUID == guid { } else if rule.GUID == guid {
removeSelectedRule(ridx, true) removeSelectedRule(ridx, sidx)
removed = true removed = true
break break
} }
}
} }
@ -353,17 +409,104 @@ func removeRequest(listStore *gtk.ListStore, guid string) {
} }
func addRequestInc(listStore *gtk.ListStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int, // Needs to be locked by caller
origin string, is_socks bool, optstring string, sandbox string, action int) bool { func storeNewEntry(ts *gtk.TreeStore, iter *gtk.TreeIter, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int, origin,
timestamp string, is_socks bool, optstring, sandbox string, action int) {
var colVals = [COL_NO_LAST]interface{}{}
if is_socks {
if (optstring != "") && (strings.Index(optstring, "SOCKS") == -1) {
optstring = "SOCKS5 / " + optstring
} else if optstring == "" {
optstring = "SOCKS5"
}
}
colVals[COL_NO_NREFS] = 1
colVals[COL_NO_ICON_PIXBUF] = nil
colVals[COL_NO_GUID] = guid
colVals[COL_NO_PATH] = path
colVals[COL_NO_ICON] = icon
colVals[COL_NO_PROTO] = proto
colVals[COL_NO_PID] = pid
if ipaddr == "" {
colVals[COL_NO_DSTIP] = "---"
} else {
colVals[COL_NO_DSTIP] = ipaddr
}
colVals[COL_NO_HOSTNAME] = hostname
colVals[COL_NO_PORT] = port
colVals[COL_NO_UID] = uid
colVals[COL_NO_GID] = gid
colVals[COL_NO_ORIGIN] = origin
colVals[COL_NO_TIMESTAMP] = timestamp
colVals[COL_NO_IS_SOCKS] = 0
if is_socks {
colVals[COL_NO_IS_SOCKS] = 1
}
colVals[COL_NO_OPTSTRING] = optstring
colVals[COL_NO_ACTION] = action
itheme, err := gtk.IconThemeGetDefault()
if err != nil {
log.Fatal("Could not load default icon theme:", err)
}
make_blank := false
if icon != "" {
pb, err := itheme.LoadIcon(icon, 24, gtk.ICON_LOOKUP_GENERIC_FALLBACK)
if err != nil {
log.Println("Could not load icon:", err)
make_blank = true
} else {
colVals[COL_NO_ICON_PIXBUF] = pb
}
} else {
make_blank = true
}
if make_blank {
pb, err := gdk.PixbufNew(gdk.COLORSPACE_RGB, true, 8, 24, 24)
if err != nil {
log.Println("Error creating blank icon:", err)
} else {
colVals[COL_NO_ICON_PIXBUF] = pb
img, err := gtk.ImageNewFromPixbuf(pb)
if err != nil {
log.Println("Error creating image from pixbuf:", err)
} else {
img.Clear()
pb = img.GetPixbuf()
colVals[COL_NO_ICON_PIXBUF] = pb
}
}
}
for n := 0; n < len(colVals); n++ {
err := ts.SetValue(iter, n, colVals[n])
if err != nil {
log.Fatal("Unable to add row:", err)
}
}
return
}
func addRequestInc(treeStore *gtk.TreeStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int,
origin, timestamp string, is_socks bool, optstring string, sandbox string, action int) bool {
duplicated := false duplicated := false
globalPromptLock.Lock() globalPromptLock.Lock()
defer globalPromptLock.Unlock() defer globalPromptLock.Unlock()
for ridx := 0; ridx < globalLS.IterNChildren(nil); ridx++ { for ridx := 0; ridx < globalTS.IterNChildren(nil); ridx++ {
rule, iter, err := getRuleByIdx(ridx, -1)
/* XXX: This is horrible. Figure out how to do this properly. */
rule, iter, err := getRuleByIdx(ridx)
if err != nil { if err != nil {
break break
// XXX: not compared: optstring/sandbox // XXX: not compared: optstring/sandbox
@ -371,14 +514,15 @@ func addRequestInc(listStore *gtk.ListStore, guid, path, icon, proto string, pid
(rule.Port == port) && (rule.UID == uid) && (rule.GID == gid) && (rule.Origin == origin) && (rule.IsSocks == is_socks) { (rule.Port == port) && (rule.UID == uid) && (rule.GID == gid) && (rule.Origin == origin) && (rule.IsSocks == is_socks) {
rule.nrefs++ rule.nrefs++
err := globalLS.SetValue(iter, 0, rule.nrefs) err := globalTS.SetValue(iter, 0, rule.nrefs)
if err != nil { if err != nil {
log.Println("Error creating duplicate firewall prompt entry:", err) log.Println("Error creating duplicate firewall prompt entry:", err)
break break
} }
fmt.Println("YES REALLY DUPLICATE: ", rule.nrefs)
duplicated = true duplicated = true
subiter := globalTS.Append(iter)
storeNewEntry(globalTS, subiter, guid, path, icon, proto, pid, ipaddr, hostname, port, uid, gid, origin, timestamp, is_socks, optstring, sandbox, action)
break break
} }
@ -387,27 +531,27 @@ func addRequestInc(listStore *gtk.ListStore, guid, path, icon, proto string, pid
return duplicated return duplicated
} }
func addRequestAsync(listStore *gtk.ListStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int, func addRequestAsync(treeStore *gtk.TreeStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int,
origin, timestamp string, is_socks bool, optstring string, sandbox string, action int) bool { origin, timestamp string, is_socks bool, optstring string, sandbox string, action int) bool {
addRequest(listStore, guid, path, icon, proto, pid, ipaddr, hostname, port, uid, gid, origin, timestamp, is_socks, addRequest(treeStore, guid, path, icon, proto, pid, ipaddr, hostname, port, uid, gid, origin, timestamp, is_socks,
optstring, sandbox, action) optstring, sandbox, action)
return true return true
} }
func addRequest(listStore *gtk.ListStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int, func addRequest(treeStore *gtk.TreeStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int,
origin, timestamp string, is_socks bool, optstring string, sandbox string, action int) *decisionWaiter { origin, timestamp string, is_socks bool, optstring string, sandbox string, action int) bool {
if listStore == nil { if treeStore == nil {
listStore = globalLS treeStore = globalTS
waitTimes := []int{1, 2, 5, 10} waitTimes := []int{1, 2, 5, 10}
if listStore == nil { if treeStore == nil {
log.Println("SGFW prompter was not ready to receive firewall request... waiting") log.Println("SGFW prompter was not ready to receive firewall request... waiting")
for _, wtime := range waitTimes { for _, wtime := range waitTimes {
time.Sleep(time.Duration(wtime) * time.Second) time.Sleep(time.Duration(wtime) * time.Second)
listStore = globalLS treeStore = globalTS
if listStore != nil { if treeStore != nil {
break break
} }
@ -418,78 +562,26 @@ func addRequest(listStore *gtk.ListStore, guid, path, icon, proto string, pid in
} }
if listStore == nil { if treeStore == nil {
log.Fatal("SGFW prompter GUI failed to load for unknown reasons") log.Fatal("SGFW prompter GUI failed to load for unknown reasons")
} }
if addRequestInc(listStore, guid, path, icon, proto, pid, ipaddr, hostname, port, uid, gid, origin, is_socks, optstring, sandbox, action) { if addRequestInc(treeStore, guid, path, icon, proto, pid, ipaddr, hostname, port, uid, gid, origin, timestamp, is_socks, optstring, sandbox, action) {
fmt.Println("REQUEST WAS DUPLICATE") fmt.Println("Request was duplicate: ", guid)
decision := addDecision()
globalPromptLock.Lock() globalPromptLock.Lock()
toggleHover() toggleHover()
globalPromptLock.Unlock() globalPromptLock.Unlock()
return decision return true
} else {
fmt.Println("NOT DUPLICATE")
} }
globalPromptLock.Lock() globalPromptLock.Lock()
iter := listStore.Append() defer globalPromptLock.Unlock()
if is_socks {
if (optstring != "") && (strings.Index(optstring, "SOCKS") == -1) {
optstring = "SOCKS5 / " + optstring
} else if optstring == "" {
optstring = "SOCKS5"
}
}
colVals := make([]interface{}, 16)
colVals[0] = 1
colVals[1] = guid
colVals[2] = path
colVals[3] = icon
colVals[4] = proto
colVals[5] = pid
if ipaddr == "" {
colVals[6] = "---"
} else {
colVals[6] = ipaddr
}
colVals[7] = hostname
colVals[8] = port
colVals[9] = uid
colVals[10] = gid
colVals[11] = origin
colVals[12] = timestamp
colVals[13] = 0
if is_socks {
colVals[13] = 1
}
colVals[14] = optstring
colVals[15] = action
colNums := make([]int, len(colVals))
for n := 0; n < len(colVals); n++ {
colNums[n] = n
}
err := listStore.Set(iter, colNums, colVals)
if err != nil { iter := treeStore.Append(nil)
log.Fatal("Unable to add row:", err) storeNewEntry(treeStore, iter, guid, path, icon, proto, pid, ipaddr, hostname, port, uid, gid, origin, timestamp, is_socks, optstring, sandbox, action)
}
decision := addDecision()
dumpDecisions()
toggleHover() toggleHover()
globalPromptLock.Unlock() return true
return decision
} }
func setup_settings() { func setup_settings() {
@ -554,8 +646,8 @@ func setup_settings() {
Notebook.AppendPage(scrollbox, hLabel) Notebook.AppendPage(scrollbox, hLabel)
} }
func lsGetStr(ls *gtk.ListStore, iter *gtk.TreeIter, idx int) (string, error) { func lsGetStr(ls *gtk.TreeStore, iter *gtk.TreeIter, idx int) (string, error) {
val, err := globalLS.GetValue(iter, idx) val, err := globalTS.GetValue(iter, idx)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -568,8 +660,8 @@ func lsGetStr(ls *gtk.ListStore, iter *gtk.TreeIter, idx int) (string, error) {
return sval, nil return sval, nil
} }
func lsGetInt(ls *gtk.ListStore, iter *gtk.TreeIter, idx int) (int, error) { func lsGetInt(ls *gtk.TreeStore, iter *gtk.TreeIter, idx int) (int, error) {
val, err := globalLS.GetValue(iter, idx) val, err := globalTS.GetValue(iter, idx)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -582,9 +674,9 @@ func lsGetInt(ls *gtk.ListStore, iter *gtk.TreeIter, idx int) (int, error) {
return ival.(int), nil return ival.(int), nil
} }
func makeDecision(idx int, rule string, scope int) error { func makeDecision(rule string, scope int, guid string) error {
var dres bool var dres bool
call := dbuso.Call("AddRuleAsync", 0, uint32(scope), rule, "*") call := dbuso.Call("AddRuleAsync", 0, uint32(scope), rule, "*", guid)
err := call.Store(&dres) err := call.Store(&dres)
if err != nil { if err != nil {
@ -593,20 +685,12 @@ func makeDecision(idx int, rule string, scope int) error {
} }
fmt.Println("makeDecision remote result:", dres) fmt.Println("makeDecision remote result:", dres)
return nil
decisionWaiters[idx].Cond.L.Lock()
decisionWaiters[idx].Rule = rule
decisionWaiters[idx].Scope = scope
decisionWaiters[idx].Ready = true
decisionWaiters[idx].Cond.Signal()
decisionWaiters[idx].Cond.L.Unlock()
return nil return nil
} }
/* Do we need to hold the lock while this is called? Stay safe... */ /* Do we need to hold the lock while this is called? Stay safe... */
func toggleHover() { func toggleHover() {
nitems := globalLS.IterNChildren(nil) nitems := globalTS.IterNChildren(nil)
mainWin.SetKeepAbove(nitems > 0) mainWin.SetKeepAbove(nitems > 0)
} }
@ -730,120 +814,187 @@ func clearEditor() {
chkTLS.SetActive(false) chkTLS.SetActive(false)
} }
func removeSelectedRule(idx int, rmdecision bool) error { func removeSelectedRule(idx, subidx int) error {
fmt.Println("XXX: attempting to remove idx = ", idx) fmt.Printf("XXX: attempting to remove idx = %v, %v\n", idx, subidx)
ppathstr := fmt.Sprintf("%d", idx)
pathstr := ppathstr
if subidx > -1 {
pathstr = fmt.Sprintf("%d:%d", idx, subidx)
}
iter, err := globalTS.GetIterFromString(pathstr)
if err != nil {
return err
}
nchildren := globalTS.IterNChildren(iter)
if nchildren >= 1 {
firstpath := fmt.Sprintf("%d:0", idx)
citer, err := globalTS.GetIterFromString(firstpath)
if err != nil {
return err
}
gnrefs, err := globalTS.GetValue(iter, COL_NO_NREFS)
if err != nil {
return err
}
vnrefs, err := gnrefs.GoValue()
if err != nil {
return err
}
nrefs := vnrefs.(int) - 1
for n := 0; n < COL_NO_LAST; n++ {
val, err := globalTS.GetValue(citer, n)
if err != nil {
return err
}
if n == COL_NO_NREFS {
err = globalTS.SetValue(iter, n, nrefs)
} else {
err = globalTS.SetValue(iter, n, val)
}
if err != nil {
return err
}
}
globalTS.Remove(citer)
return nil
}
globalTS.Remove(iter)
path, err := gtk.TreePathNewFromString(fmt.Sprintf("%d", idx)) if subidx > -1 {
ppath, err := gtk.TreePathNewFromString(ppathstr)
if err != nil { if err != nil {
return err return err
} }
iter, err := globalLS.GetIter(path) piter, err := globalTS.GetIter(ppath)
if err != nil { if err != nil {
return err return err
} }
globalLS.Remove(iter) nrefs, err := lsGetInt(globalTS, piter, COL_NO_NREFS)
if err != nil {
return err
}
if rmdecision { err = globalTS.SetValue(piter, COL_NO_NREFS, nrefs-1)
// decisionWaiters = append(decisionWaiters[:idx], decisionWaiters[idx+1:]...) if err != nil {
return err
}
} }
toggleHover() toggleHover()
return nil return nil
} }
// Needs to be locked by the caller
func numSelections() int { func numSelections() int {
sel, err := globalTV.GetSelection() sel, err := globalTV.GetSelection()
if err != nil { if err != nil {
return -1 return -1
} }
rows := sel.GetSelectedRows(globalLS) rows := sel.GetSelectedRows(globalTS)
return int(rows.Length()) return int(rows.Length())
} }
// Needs to be locked by the caller // Needs to be locked by the caller
func getRuleByIdx(idx int) (ruleColumns, *gtk.TreeIter, error) { func getRuleByIdx(idx, subidx int) (ruleColumns, *gtk.TreeIter, error) {
rule := ruleColumns{} rule := ruleColumns{}
tpath := fmt.Sprintf("%d", idx)
if subidx != -1 {
tpath = fmt.Sprintf("%d:%d", idx, subidx)
}
path, err := gtk.TreePathNewFromString(fmt.Sprintf("%d", idx)) path, err := gtk.TreePathNewFromString(tpath)
if err != nil { if err != nil {
return rule, nil, err return rule, nil, err
} }
iter, err := globalLS.GetIter(path) iter, err := globalTS.GetIter(path)
if err != nil { if err != nil {
return rule, nil, err return rule, nil, err
} }
rule.nrefs, err = lsGetInt(globalLS, iter, 0) rule.nrefs, err = lsGetInt(globalTS, iter, COL_NO_NREFS)
if err != nil { if err != nil {
return rule, nil, err return rule, nil, err
} }
rule.GUID, err = lsGetStr(globalLS, iter, 1) rule.GUID, err = lsGetStr(globalTS, iter, COL_NO_GUID)
if err != nil { if err != nil {
return rule, nil, err return rule, nil, err
} }
rule.Path, err = lsGetStr(globalLS, iter, 2) rule.Path, err = lsGetStr(globalTS, iter, COL_NO_PATH)
if err != nil { if err != nil {
return rule, nil, err return rule, nil, err
} }
rule.Icon, err = lsGetStr(globalLS, iter, 3) rule.Icon, err = lsGetStr(globalTS, iter, COL_NO_ICON)
if err != nil { if err != nil {
return rule, nil, err return rule, nil, err
} }
rule.Proto, err = lsGetStr(globalLS, iter, 4) rule.Proto, err = lsGetStr(globalTS, iter, COL_NO_PROTO)
if err != nil { if err != nil {
return rule, nil, err return rule, nil, err
} }
rule.Pid, err = lsGetInt(globalLS, iter, 5) rule.Pid, err = lsGetInt(globalTS, iter, COL_NO_PID)
if err != nil { if err != nil {
return rule, nil, err return rule, nil, err
} }
rule.Target, err = lsGetStr(globalLS, iter, 6) rule.Target, err = lsGetStr(globalTS, iter, COL_NO_DSTIP)
if err != nil { if err != nil {
return rule, nil, err return rule, nil, err
} }
rule.Hostname, err = lsGetStr(globalLS, iter, 7) rule.Hostname, err = lsGetStr(globalTS, iter, COL_NO_HOSTNAME)
if err != nil { if err != nil {
return rule, nil, err return rule, nil, err
} }
rule.Port, err = lsGetInt(globalLS, iter, 8) rule.Port, err = lsGetInt(globalTS, iter, COL_NO_PORT)
if err != nil { if err != nil {
return rule, nil, err return rule, nil, err
} }
rule.UID, err = lsGetInt(globalLS, iter, 9) rule.UID, err = lsGetInt(globalTS, iter, COL_NO_UID)
if err != nil { if err != nil {
return rule, nil, err return rule, nil, err
} }
rule.GID, err = lsGetInt(globalLS, iter, 10) rule.GID, err = lsGetInt(globalTS, iter, COL_NO_GID)
if err != nil { if err != nil {
return rule, nil, err return rule, nil, err
} }
rule.Origin, err = lsGetStr(globalLS, iter, 11) rule.Origin, err = lsGetStr(globalTS, iter, COL_NO_ORIGIN)
if err != nil { if err != nil {
return rule, nil, err return rule, nil, err
} }
rule.Timestamp, err = lsGetStr(globalLS, iter, 12) rule.Timestamp, err = lsGetStr(globalTS, iter, COL_NO_TIMESTAMP)
if err != nil { if err != nil {
return rule, nil, err return rule, nil, err
} }
rule.IsSocks = false rule.IsSocks = false
is_socks, err := lsGetInt(globalLS, iter, 13) is_socks, err := lsGetInt(globalTS, iter, COL_NO_IS_SOCKS)
if err != nil { if err != nil {
return rule, nil, err return rule, nil, err
} }
@ -852,7 +1003,7 @@ func getRuleByIdx(idx int) (ruleColumns, *gtk.TreeIter, error) {
rule.IsSocks = true rule.IsSocks = true
} }
rule.Scope, err = lsGetInt(globalLS, iter, 15) rule.Scope, err = lsGetInt(globalTS, iter, COL_NO_ACTION)
if err != nil { if err != nil {
return rule, nil, err return rule, nil, err
} }
@ -861,116 +1012,78 @@ func getRuleByIdx(idx int) (ruleColumns, *gtk.TreeIter, error) {
} }
// Needs to be locked by the caller // Needs to be locked by the caller
func getSelectedRule() (ruleColumns, int, error) { func getSelectedRule() (ruleColumns, int, int, error) {
rule := ruleColumns{} rule := ruleColumns{}
sel, err := globalTV.GetSelection() sel, err := globalTV.GetSelection()
if err != nil { if err != nil {
return rule, -1, err return rule, -1, -1, err
} }
rows := sel.GetSelectedRows(globalLS) rows := sel.GetSelectedRows(globalTS)
if rows.Length() <= 0 { if rows.Length() <= 0 {
return rule, -1, errors.New("No selection was made") return rule, -1, -1, errors.New("no selection was made")
} }
rdata := rows.NthData(0) rdata := rows.NthData(0)
lIndex, err := strconv.Atoi(rdata.(*gtk.TreePath).String()) tpath := rdata.(*gtk.TreePath).String()
if err != nil {
return rule, -1, err
}
fmt.Println("lindex = ", lIndex)
rule, _, err = getRuleByIdx(lIndex)
if err != nil {
return rule, -1, err
}
return rule, lIndex, nil
}
func addPendingPrompts(rules []string) {
for _, rule := range rules {
fields := strings.Split(rule, "|")
if len(fields) != 19 {
log.Printf("Got saved prompt message with strange data: \"%s\"", rule)
continue
}
guid := fields[0]
icon := fields[2]
path := fields[3]
address := fields[4]
port, err := strconv.Atoi(fields[5])
if err != nil {
log.Println("Error converting port in pending prompt message to integer:", err)
continue
}
ip := fields[6] subidx := -1
origin := fields[7] ptoks := strings.Split(tpath, ":")
proto := fields[8]
uid, err := strconv.Atoi(fields[9]) if len(ptoks) > 2 {
return rule, -1, -1, errors.New("internal error parsing selected item tree path")
} else if len(ptoks) == 2 {
subidx, err = strconv.Atoi(ptoks[1])
if err != nil { if err != nil {
log.Println("Error converting UID in pending prompt message to integer:", err) return rule, -1, -1, err
continue
} }
tpath = ptoks[0]
gid, err := strconv.Atoi(fields[10])
if err != nil {
log.Println("Error converting GID in pending prompt message to integer:", err)
continue
}
pid, err := strconv.Atoi(fields[13])
if err != nil {
log.Println("Error converting pid in pending prompt message to integer:", err)
continue
} }
sandbox := fields[14] lIndex, err := strconv.Atoi(tpath)
is_socks, err := strconv.ParseBool(fields[15])
if err != nil { if err != nil {
log.Println("Error converting SOCKS flag in pending prompt message to boolean:", err) return rule, -1, -1, err
continue
} }
timestamp := fields[16] // fmt.Printf("lindex = %d : %d\n", lIndex, subidx)
optstring := fields[17] rule, _, err = getRuleByIdx(lIndex, subidx)
action, err := strconv.Atoi(fields[18])
if err != nil { if err != nil {
log.Println("Error converting action in pending prompt message to integer:", err) return rule, -1, -1, err
continue
}
addRequestAsync(nil, guid, path, icon, proto, int(pid), ip, address, int(port), int(uid), int(gid), origin, timestamp, is_socks, optstring, sandbox, action)
} }
return rule, lIndex, subidx, nil
} }
func buttonAction(action string) { func buttonAction(action string) {
globalPromptLock.Lock() globalPromptLock.Lock()
rule, idx, err := getSelectedRule() rule, idx, subidx, err := getSelectedRule()
if err != nil { if err != nil {
globalPromptLock.Unlock() globalPromptLock.Unlock()
promptError("Error occurred processing request: " + err.Error()) promptError("Error occurred processing request: " + err.Error())
return return
} }
rule, err = createCurrentRule() urule, err := createCurrentRule()
if err != nil { if err != nil {
globalPromptLock.Unlock() globalPromptLock.Unlock()
promptError("Error occurred constructing new rule: " + err.Error()) promptError("Error occurred constructing new rule: " + err.Error())
return return
} }
// Overlay the rules
rule.Scope = urule.Scope
//rule.Path = urule.Path
rule.Port = urule.Port
rule.Target = urule.Target
rule.Proto = urule.Proto
// rule.UID = urule.UID
// rule.GID = urule.GID
// rule.Uname = urule.Uname
// rule.Gname = urule.Gname
rule.ForceTLS = urule.ForceTLS
fmt.Println("rule = ", rule) fmt.Println("rule = ", rule)
rulestr := action rulestr := action
@ -981,9 +1094,9 @@ func buttonAction(action string) {
rulestr += "|" + rule.Proto + ":" + rule.Target + ":" + strconv.Itoa(rule.Port) rulestr += "|" + rule.Proto + ":" + rule.Target + ":" + strconv.Itoa(rule.Port)
rulestr += "|" + sgfw.RuleModeString[sgfw.RuleMode(rule.Scope)] rulestr += "|" + sgfw.RuleModeString[sgfw.RuleMode(rule.Scope)]
fmt.Println("RULESTR = ", rulestr) fmt.Println("RULESTR = ", rulestr)
makeDecision(idx, rulestr, int(rule.Scope)) makeDecision(rulestr, int(rule.Scope), rule.GUID)
fmt.Println("Decision made.") err = removeSelectedRule(idx, subidx)
err = removeSelectedRule(idx, true) addRecentlyRemoved(rule.GUID)
globalPromptLock.Unlock() globalPromptLock.Unlock()
if err == nil { if err == nil {
clearEditor() clearEditor()
@ -994,7 +1107,6 @@ func buttonAction(action string) {
} }
func main() { func main() {
decisionWaiters = make([]*decisionWaiter, 0)
_, err := newDbusServer() _, err := newDbusServer()
if err != nil { if err != nil {
log.Fatal("Error:", err) log.Fatal("Error:", err)
@ -1104,6 +1216,7 @@ func main() {
// globalIcon.SetFromIconName("firefox", gtk.ICON_SIZE_DND) // globalIcon.SetFromIconName("firefox", gtk.ICON_SIZE_DND)
editApp = get_entry("") editApp = get_entry("")
editApp.SetEditable(false)
editApp.Connect("changed", toggleValidRuleState) editApp.Connect("changed", toggleValidRuleState)
hbox.PackStart(lbl, false, false, 10) hbox.PackStart(lbl, false, false, 10)
hbox.PackStart(editApp, true, true, 10) hbox.PackStart(editApp, true, true, 10)
@ -1166,42 +1279,43 @@ func main() {
// box.PackStart(tv, false, true, 5) // box.PackStart(tv, false, true, 5)
box.PackStart(scrollbox, false, true, 5) box.PackStart(scrollbox, false, true, 5)
tv.AppendColumn(createColumn("#", 0)) tv.AppendColumn(createColumnText("#", COL_NO_NREFS))
tv.AppendColumn(createColumnImg("", COL_NO_ICON_PIXBUF))
guidcol := createColumn("GUID", 1) guidcol := createColumnText("GUID", COL_NO_GUID)
guidcol.SetVisible(false) guidcol.SetVisible(false)
tv.AppendColumn(guidcol) tv.AppendColumn(guidcol)
tv.AppendColumn(createColumn("Path", 2)) tv.AppendColumn(createColumnText("Path", COL_NO_PATH))
icol := createColumn("Icon", 3) icol := createColumnText("Icon", COL_NO_ICON)
icol.SetVisible(false) icol.SetVisible(false)
tv.AppendColumn(icol) tv.AppendColumn(icol)
tv.AppendColumn(createColumn("Protocol", 4)) tv.AppendColumn(createColumnText("Protocol", COL_NO_PROTO))
tv.AppendColumn(createColumn("PID", 5)) tv.AppendColumn(createColumnText("PID", COL_NO_PID))
tv.AppendColumn(createColumn("IP Address", 6)) tv.AppendColumn(createColumnText("IP Address", COL_NO_DSTIP))
tv.AppendColumn(createColumn("Hostname", 7)) tv.AppendColumn(createColumnText("Hostname", COL_NO_HOSTNAME))
tv.AppendColumn(createColumn("Port", 8)) tv.AppendColumn(createColumnText("Port", COL_NO_PORT))
tv.AppendColumn(createColumn("UID", 9)) tv.AppendColumn(createColumnText("UID", COL_NO_UID))
tv.AppendColumn(createColumn("GID", 10)) tv.AppendColumn(createColumnText("GID", COL_NO_GID))
tv.AppendColumn(createColumn("Origin", 11)) tv.AppendColumn(createColumnText("Origin", COL_NO_ORIGIN))
tv.AppendColumn(createColumn("Timestamp", 12)) tv.AppendColumn(createColumnText("Timestamp", COL_NO_TIMESTAMP))
scol := createColumn("Is SOCKS", 13) scol := createColumnText("Is SOCKS", COL_NO_IS_SOCKS)
scol.SetVisible(false) scol.SetVisible(false)
tv.AppendColumn(scol) tv.AppendColumn(scol)
tv.AppendColumn(createColumn("Details", 14)) tv.AppendColumn(createColumnText("Details", COL_NO_OPTSTRING))
acol := createColumn("Scope", 15) acol := createColumnText("Scope", COL_NO_ACTION)
acol.SetVisible(false) acol.SetVisible(false)
tv.AppendColumn(acol) tv.AppendColumn(acol)
listStore := createListStore(true) treeStore := createTreeStore(true)
globalLS = listStore globalTS = treeStore
tv.SetModel(listStore) tv.SetModel(treeStore)
btnApprove.Connect("clicked", func() { btnApprove.Connect("clicked", func() {
buttonAction("ALLOW") buttonAction("ALLOW")
@ -1214,7 +1328,7 @@ func main() {
// tv.SetActivateOnSingleClick(true) // tv.SetActivateOnSingleClick(true)
tv.Connect("row-activated", func() { tv.Connect("row-activated", func() {
globalPromptLock.Lock() globalPromptLock.Lock()
seldata, _, err := getSelectedRule() seldata, _, _, err := getSelectedRule()
globalPromptLock.Unlock() globalPromptLock.Unlock()
if err != nil { if err != nil {
promptError("Unexpected error reading selected rule: " + err.Error()) promptError("Unexpected error reading selected rule: " + err.Error())
@ -1237,7 +1351,7 @@ func main() {
editPort.SetText(strconv.Itoa(seldata.Port)) editPort.SetText(strconv.Itoa(seldata.Port))
radioOnce.SetActive(seldata.Scope == int(sgfw.APPLY_ONCE)) radioOnce.SetActive(seldata.Scope == int(sgfw.APPLY_ONCE))
radioProcess.SetSensitive(seldata.Pid > 0) radioProcess.SetSensitive(seldata.Scope == int(sgfw.APPLY_PROCESS))
radioParent.SetActive(false) radioParent.SetActive(false)
radioSession.SetActive(seldata.Scope == int(sgfw.APPLY_SESSION)) radioSession.SetActive(seldata.Scope == int(sgfw.APPLY_SESSION))
radioPermanent.SetActive(seldata.Scope == int(sgfw.APPLY_FOREVER)) radioPermanent.SetActive(seldata.Scope == int(sgfw.APPLY_FOREVER))
@ -1286,14 +1400,14 @@ func main() {
mainWin.ShowAll() mainWin.ShowAll()
// mainWin.SetKeepAbove(true) // mainWin.SetKeepAbove(true)
var dres = []string{} var dres bool
call := dbuso.Call("GetPendingRequests", 0, "*") call := dbuso.Call("GetPendingRequests", 0, "*")
err = call.Store(&dres) err = call.Store(&dres)
if err != nil { if err != nil {
errmsg := "Could not query running SGFW instance (maybe it's not running?): " + err.Error() errmsg := "Could not query running SGFW instance (maybe it's not running?): " + err.Error()
promptError(errmsg) promptError(errmsg)
} else { } else if !dres {
addPendingPrompts(dres) promptError("Call to sgfw did not succeed; fw-prompt may have loaded without retrieving all pending connections")
} }
gtk.Main() gtk.Main()

@ -58,7 +58,7 @@ func readConfig() {
PromptExpanded: false, PromptExpanded: false,
PromptExpert: false, PromptExpert: false,
DefaultAction: "SESSION", DefaultAction: "SESSION",
DefaultActionID: 1, DefaultActionID: 0,
} }
if len(buf) > 0 { if len(buf) > 0 {

@ -41,6 +41,7 @@ const (
RULE_MODE_PROCESS RULE_MODE_PROCESS
RULE_MODE_PERMANENT RULE_MODE_PERMANENT
RULE_MODE_SYSTEM RULE_MODE_SYSTEM
RULE_MODE_ONCE
) )
// RuleModeString is used to get a rule mode string from its id // RuleModeString is used to get a rule mode string from its id
@ -49,6 +50,7 @@ var RuleModeString = map[RuleMode]string{
RULE_MODE_PROCESS: "PROCESS", RULE_MODE_PROCESS: "PROCESS",
RULE_MODE_PERMANENT: "PERMANENT", RULE_MODE_PERMANENT: "PERMANENT",
RULE_MODE_SYSTEM: "SYSTEM", RULE_MODE_SYSTEM: "SYSTEM",
RULE_MODE_ONCE: "ONCE",
} }
// RuleModeValue converts a mode string to its id // RuleModeValue converts a mode string to its id
@ -57,16 +59,18 @@ var RuleModeValue = map[string]RuleMode{
RuleModeString[RULE_MODE_PROCESS]: RULE_MODE_PROCESS, RuleModeString[RULE_MODE_PROCESS]: RULE_MODE_PROCESS,
RuleModeString[RULE_MODE_PERMANENT]: RULE_MODE_PERMANENT, RuleModeString[RULE_MODE_PERMANENT]: RULE_MODE_PERMANENT,
RuleModeString[RULE_MODE_SYSTEM]: RULE_MODE_SYSTEM, RuleModeString[RULE_MODE_SYSTEM]: RULE_MODE_SYSTEM,
RuleModeString[RULE_MODE_ONCE]: RULE_MODE_ONCE,
} }
//FilterScope contains a filter's time scope //FilterScope contains a filter's time scope
type FilterScope uint16 type FilterScope uint16
const ( const (
APPLY_ONCE FilterScope = iota APPLY_SESSION FilterScope = iota
APPLY_SESSION
APPLY_PROCESS APPLY_PROCESS
APPLY_FOREVER APPLY_FOREVER
APPLY_SYSTEM
APPLY_ONCE
) )
// FilterScopeString converts a filter scope ID to its string // FilterScopeString converts a filter scope ID to its string
@ -143,9 +147,3 @@ type DbusRule struct {
Mode uint16 Mode uint16
Sandbox string Sandbox string
} }
/*const (
OZ_FWRULE_WHITELIST = iota
OZ_FWRULE_BLACKLIST
OZ_FWRULE_NONE
) */

@ -199,18 +199,19 @@ func (ds *dbusServer) DeleteRule(id uint32) *dbus.Error {
return nil return nil
} }
func (ds *dbusServer) GetPendingRequests(policy string) ([]string, *dbus.Error) { func (ds *dbusServer) GetPendingRequests(policy string) (bool, *dbus.Error) {
succeeded := true
log.Debug("+++ GetPendingRequests()") log.Debug("+++ GetPendingRequests()")
ds.fw.lock.Lock() ds.fw.lock.Lock()
defer ds.fw.lock.Unlock() defer ds.fw.lock.Unlock()
pending_data := make([]string, 0)
for pname := range ds.fw.policyMap { for pname := range ds.fw.policyMap {
policy := ds.fw.policyMap[pname] policy := ds.fw.policyMap[pname]
pqueue := policy.pendingQueue pqueue := policy.pendingQueue
for _, pc := range pqueue { for _, pc := range pqueue {
var dres bool
addr := pc.hostname() addr := pc.hostname()
if addr == "" { if addr == "" {
addr = pc.dst().String() addr = pc.dst().String()
@ -224,40 +225,48 @@ func (ds *dbusServer) GetPendingRequests(policy string) ([]string, *dbus.Error)
dststr = addr + " (via proxy resolver)" dststr = addr + " (via proxy resolver)"
} }
pstr := "" call := ds.prompter.dbusObj.Call("com.subgraph.FirewallPrompt.RequestPromptAsync", 0,
pstr += pc.getGUID() + "|" pc.getGUID(),
pstr += policy.application + "|" policy.application,
pstr += policy.icon + "|" policy.icon,
pstr += policy.path + "|" policy.path,
pstr += addr + "|" addr,
pstr += strconv.FormatUint(uint64(pc.dstPort()), 10) + "|" int32(pc.dstPort()),
pstr += dststr + "|" dststr,
pstr += pc.src().String() + "|" pc.src().String(),
pstr += pc.proto() + "|" pc.proto(),
pstr += strconv.FormatInt(int64(pc.procInfo().UID), 10) + "|" int32(pc.procInfo().UID),
pstr += strconv.FormatInt(int64(pc.procInfo().GID), 10) + "|" int32(pc.procInfo().GID),
pstr += uidToUser(pc.procInfo().UID) + "|" uidToUser(pc.procInfo().UID),
pstr += gidToGroup(pc.procInfo().GID) + "|" gidToGroup(pc.procInfo().GID),
pstr += strconv.FormatInt(int64(pc.procInfo().Pid), 10) + "|" int32(pc.procInfo().Pid),
pstr += pc.sandbox() + "|" pc.sandbox(),
pstr += strconv.FormatBool(pc.socks()) + "|" pc.socks(),
pstr += pc.getTimestamp() + "|" pc.getTimestamp(),
pstr += pc.getOptString() + "|" pc.getOptString(),
pstr += strconv.FormatUint(uint64(FirewallConfig.DefaultActionID), 10) FirewallConfig.PromptExpanded,
pending_data = append(pending_data, pstr) FirewallConfig.PromptExpert,
} int32(FirewallConfig.DefaultActionID))
} err := call.Store(&dres)
if err != nil {
return pending_data, nil log.Warningf("Error sending DBus async pending RequestPrompt message: %v", err)
succeeded = false
}
}
}
return succeeded, nil
} }
func (ds *dbusServer) AddRuleAsync(scope uint32, rule string, policy string) (bool, *dbus.Error) { func (ds *dbusServer) AddRuleAsync(scope uint32, rule, policy, guid string) (bool, *dbus.Error) {
log.Warningf("AddRuleAsync %v, %v / %v\n", scope, rule, policy) log.Warningf("AddRuleAsync %v, %v / %v / %v\n", scope, rule, policy, guid)
ds.fw.lock.Lock() ds.fw.lock.Lock()
defer ds.fw.lock.Unlock() defer ds.fw.lock.Unlock()
prule := PendingRule{rule: rule, scope: int(scope), policy: policy} prule := PendingRule{rule: rule, scope: int(scope), policy: policy, guid: guid}
for pname := range ds.fw.policyMap { for pname := range ds.fw.policyMap {
log.Debug("+++ Adding prule to policy") log.Debug("+++ Adding prule to policy")

@ -85,5 +85,14 @@ func loadDesktopFile(path string) {
icon: icon, icon: icon,
name: name, name: name,
} }
lname := exec
for i := 0; i < 5; i++ {
lname, err = os.Readlink(lname)
if err == nil {
entryMap[lname] = entryMap[exec]
}
}
} }
} }

@ -8,6 +8,7 @@ import (
"os" "os"
"strconv" "strconv"
"strings" "strings"
"sync"
"github.com/subgraph/oz/ipc" "github.com/subgraph/oz/ipc"
) )
@ -21,9 +22,14 @@ type OzInitProc struct {
} }
var OzInitPids []OzInitProc = []OzInitProc{} var OzInitPids []OzInitProc = []OzInitProc{}
var OzInitPidsLock = sync.Mutex{}
func addInitPid(pid int, name string, sboxid int) { func addInitPid(pid int, name string, sboxid int) {
fmt.Println("::::::::::: init pid added: ", pid, " -> ", name) fmt.Println("::::::::::: init pid added: ", pid, " -> ", name)
OzInitPidsLock.Lock()
defer OzInitPidsLock.Unlock()
for i := 0; i < len(OzInitPids); i++ { for i := 0; i < len(OzInitPids); i++ {
if OzInitPids[i].Pid == pid { if OzInitPids[i].Pid == pid {
return return
@ -36,6 +42,9 @@ func addInitPid(pid int, name string, sboxid int) {
func removeInitPid(pid int) { func removeInitPid(pid int) {
fmt.Println("::::::::::: removing PID: ", pid) fmt.Println("::::::::::: removing PID: ", pid)
OzInitPidsLock.Lock()
defer OzInitPidsLock.Unlock()
for i := 0; i < len(OzInitPids); i++ { for i := 0; i < len(OzInitPids); i++ {
if OzInitPids[i].Pid == pid { if OzInitPids[i].Pid == pid {
OzInitPids = append(OzInitPids[:i], OzInitPids[i+1:]...) OzInitPids = append(OzInitPids[:i], OzInitPids[i+1:]...)
@ -139,19 +148,6 @@ func ReceiverLoop(fw *Firewall, c net.Conn) {
c.Write([]byte(ruledesc)) c.Write([]byte(ruledesc))
} }
/* for i := 0; i < len(sandboxRules); i++ {
rulestr := ""
if sandboxRules[i].Whitelist {
rulestr += "whitelist"
} else {
rulestr += "blacklist"
}
rulestr += " " + sandboxRules[i].SrcIf.String() + " -> " + sandboxRules[i].DstIP.String() + " : " + strconv.Itoa(int(sandboxRules[i].DstPort)) + "\n"
c.Write([]byte(rulestr))
} */
return return
} else { } else {
tokens := strings.Split(data, " ") tokens := strings.Split(data, " ")
@ -337,12 +333,7 @@ const OzSocketName = "@oz-control"
var bSockName = OzSocketName var bSockName = OzSocketName
var messageFactory = ipc.NewMsgFactory( func init() {
new(ListProxiesMsg),
new(ListProxiesResp),
)
func clientConnect() (*ipc.MsgConn, error) {
bSockName = os.Getenv("SOCKET_NAME") bSockName = os.Getenv("SOCKET_NAME")
if bSockName != "" { if bSockName != "" {
@ -356,7 +347,14 @@ func clientConnect() (*ipc.MsgConn, error) {
} else { } else {
bSockName = OzSocketName bSockName = OzSocketName
} }
}
var messageFactory = ipc.NewMsgFactory(
new(ListProxiesMsg),
new(ListProxiesResp),
)
func clientConnect() (*ipc.MsgConn, error) {
return ipc.Connect(bSockName, messageFactory, nil) return ipc.Connect(bSockName, messageFactory, nil)
} }

@ -23,17 +23,6 @@ var _interpreters = []string{
"bash", "bash",
} }
/*type sandboxRule struct {
SrcIf net.IP
DstIP net.IP
DstPort uint16
Whitelist bool
}
var sandboxRules = []sandboxRule {
// { net.IP{172,16,1,42}, net.IP{140,211,166,134}, 21, false },
} */
type pendingConnection interface { type pendingConnection interface {
policy() *Policy policy() *Policy
procInfo() *procsnitch.Info procInfo() *procsnitch.Info
@ -222,6 +211,7 @@ type PendingRule struct {
rule string rule string
scope int scope int
policy string policy string
guid string
} }
type Policy struct { type Policy struct {
@ -313,7 +303,6 @@ func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, timestamp time.Time, pinf
if !FirewallConfig.LogRedact { if !FirewallConfig.LogRedact {
log.Infof("Lookup(%s): %s", dstip.String(), name) log.Infof("Lookup(%s): %s", dstip.String(), name)
} }
// fwo := matchAgainstOzRules(srcip, dstip, dstp)
result := p.rules.filterPacket(pkt, pinfo, srcip, name, optstr) result := p.rules.filterPacket(pkt, pinfo, srcip, name, optstr)
switch result { switch result {
@ -331,12 +320,8 @@ func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, timestamp time.Time, pinf
func (p *Policy) processPromptResult(pc pendingConnection) { func (p *Policy) processPromptResult(pc pendingConnection) {
p.pendingQueue = append(p.pendingQueue, pc) p.pendingQueue = append(p.pendingQueue, pc)
//fmt.Println("processPromptResult(): p.promptInProgress = ", p.promptInProgress)
//if DoMultiPrompt || (!DoMultiPrompt && !p.promptInProgress) {
// if !p.promptInProgress {
p.promptInProgress = true p.promptInProgress = true
go p.fw.dbus.prompter.prompt(p) go p.fw.dbus.prompter.prompt(p)
// }
} }
func (p *Policy) nextPending() (pendingConnection, bool) { func (p *Policy) nextPending() (pendingConnection, bool) {
@ -372,6 +357,15 @@ func (p *Policy) removePending(pc pendingConnection) {
} }
} }
func (p *Policy) processNewRuleOnce(r *Rule, guid string) bool {
p.lock.Lock()
defer p.lock.Unlock()
fmt.Println("----------------------- processNewRule() ONCE")
p.filterPendingOne(r, guid)
return true
}
func (p *Policy) processNewRule(r *Rule, scope FilterScope) bool { func (p *Policy) processNewRule(r *Rule, scope FilterScope) bool {
p.lock.Lock() p.lock.Lock()
defer p.lock.Unlock() defer p.lock.Unlock()
@ -419,6 +413,47 @@ func (p *Policy) removeRule(r *Rule) {
p.rules = newRules p.rules = newRules
} }
func (p *Policy) filterPendingOne(rule *Rule, guid string) {
remaining := []pendingConnection{}
for _, pc := range p.pendingQueue {
if guid != "" && guid != pc.getGUID() {
continue
}
if rule.match(pc.src(), pc.dst(), pc.dstPort(), pc.hostname(), pc.proto(), pc.procInfo().UID, pc.procInfo().GID, uidToUser(pc.procInfo().UID), gidToGroup(pc.procInfo().GID), pc.procInfo().Sandbox) {
prompter := pc.getPrompter()
if prompter == nil {
fmt.Println("-------- prompter = NULL")
} else {
call := prompter.dbusObj.Call("com.subgraph.FirewallPrompt.RemovePrompt", 0, pc.getGUID())
fmt.Println("CAAAAAAAAAAAAAAALL = ", call)
}
log.Infof("Adding rule for: %s", rule.getString(FirewallConfig.LogRedact))
// log.Noticef("%s > %s", rule.getString(FirewallConfig.LogRedact), pc.print())
if rule.rtype == RULE_ACTION_ALLOW {
pc.accept()
} else if rule.rtype == RULE_ACTION_ALLOW_TLSONLY {
pc.acceptTLSOnly()
} else {
srcs := pc.src().String() + ":" + strconv.Itoa(int(pc.srcPort()))
log.Warningf("DENIED outgoing connection attempt by %s from %s %s -> %s:%d (user prompt) %v",
pc.procInfo().ExePath, pc.proto(), srcs, pc.dst(), pc.dstPort, rule.rtype)
pc.drop()
}
// XXX: If matching a GUID, we can break out immediately
} else {
remaining = append(remaining, pc)
}
}
if len(remaining) != len(p.pendingQueue) {
p.pendingQueue = remaining
}
}
func (p *Policy) filterPending(rule *Rule) { func (p *Policy) filterPending(rule *Rule) {
remaining := []pendingConnection{} remaining := []pendingConnection{}
for _, pc := range p.pendingQueue { for _, pc := range p.pendingQueue {
@ -532,18 +567,6 @@ func (fw *Firewall) filterPacket(pkt *nfqueue.NFQPacket, timestamp time.Time) {
*/ */
_, dstip := getPacketIPAddrs(pkt) _, dstip := getPacketIPAddrs(pkt)
/* _, dstp := getPacketPorts(pkt) /* _, dstp := getPacketPorts(pkt)
fwo := eatchAgainstOzRules(srcip, dstip, dstp)
log.Notice("XXX: Attempting [2] to filter packet on rules -> ", fwo)
if fwo == OZ_FWRULE_WHITELIST {
log.Noticef("Automatically passed through whitelisted sandbox traffic from %s to %s:%d\n", srcip, dstip, dstp)
pkt.Accept()
return
} else if fwo == OZ_FWRULE_BLACKLIST {
log.Noticef("Automatically blocking blacklisted sandbox traffic from %s to %s:%d\n", srcip, dstip, dstp)
pkt.SetMark(1)
pkt.Accept()
return
} */ } */
ppath := "*" ppath := "*"
@ -633,6 +656,7 @@ func readFileDirect(filename string) ([]byte, error) {
func getAllProcNetDataLocal() ([]string, error) { func getAllProcNetDataLocal() ([]string, error) {
data := "" data := ""
OzInitPidsLock.Lock()
for i := 0; i < len(OzInitPids); i++ { for i := 0; i < len(OzInitPids); i++ {
fname := fmt.Sprintf("/proc/%d/net/tcp", OzInitPids[i]) fname := fmt.Sprintf("/proc/%d/net/tcp", OzInitPids[i])
@ -647,6 +671,8 @@ func getAllProcNetDataLocal() ([]string, error) {
} }
OzInitPidsLock.Unlock()
lines := strings.Split(data, "\n") lines := strings.Split(data, "\n")
rlines := make([]string, 0) rlines := make([]string, 0)
ctr := 1 ctr := 1
@ -692,6 +718,7 @@ func LookupSandboxProc(srcip net.IP, srcp uint16, dstip net.IP, dstp uint16, pro
var res *procsnitch.Info = nil var res *procsnitch.Info = nil
var optstr string var optstr string
removePids := make([]int, 0) removePids := make([]int, 0)
OzInitPidsLock.Lock()
for i := 0; i < len(OzInitPids); i++ { for i := 0; i < len(OzInitPids); i++ {
data := "" data := ""
@ -746,6 +773,8 @@ func LookupSandboxProc(srcip net.IP, srcp uint16, dstip net.IP, dstp uint16, pro
} }
OzInitPidsLock.Unlock()
for _, p := range removePids { for _, p := range removePids {
removeInitPid(p) removeInitPid(p)
} }
@ -797,6 +826,7 @@ func findProcessForPacket(pkt *nfqueue.NFQPacket, reverse bool, strictness int)
if res == nil { if res == nil {
removePids := make([]int, 0) removePids := make([]int, 0)
OzInitPidsLock.Lock()
for i := 0; i < len(OzInitPids); i++ { for i := 0; i < len(OzInitPids); i++ {
data := "" data := ""
@ -845,6 +875,8 @@ func findProcessForPacket(pkt *nfqueue.NFQPacket, reverse bool, strictness int)
} }
OzInitPidsLock.Unlock()
for _, p := range removePids { for _, p := range removePids {
removeInitPid(p) removeInitPid(p)
} }
@ -942,21 +974,3 @@ func getPacketPorts(pkt *nfqueue.NFQPacket) (uint16, uint16) {
return s, d return s, d
} }
/*func matchAgainstOzRules(srci, dsti net.IP, dstp uint16) int {
for i := 0; i < len(sandboxRules); i++ {
log.Notice("XXX: Attempting to match: ", srci, " / ", dsti, " / ", dstp, " | ", sandboxRules[i])
if sandboxRules[i].SrcIf.Equal(srci) && sandboxRules[i].DstIP.Equal(dsti) && sandboxRules[i].DstPort == dstp {
if sandboxRules[i].Whitelist {
return OZ_FWRULE_WHITELIST
}
return OZ_FWRULE_BLACKLIST
}
}
return OZ_FWRULE_NONE
} */

@ -61,6 +61,9 @@ func (p *prompter) processNextPacket() bool {
p.lock.Lock() p.lock.Lock()
pc, empty = p.nextConnection() pc, empty = p.nextConnection()
p.lock.Unlock() p.lock.Unlock()
if pc != nil {
fmt.Println("GOT NON NIL")
}
//fmt.Println("XXX: processNextPacket() loop; empty = ", empty, " / pc = ", pc) //fmt.Println("XXX: processNextPacket() loop; empty = ", empty, " / pc = ", pc)
if pc == nil && empty { if pc == nil && empty {
return false return false
@ -142,7 +145,7 @@ func monitorPromptFDLoop() {
} }
inode := sb.Ino inode := sb.Ino
fmt.Println("+++ INODE = ", inode) // fmt.Println("+++ INODE = ", inode)
if inode != fdmon.inode { if inode != fdmon.inode {
fmt.Printf("inode mismatch: %v vs %v\n", inode, fdmon.inode) fmt.Printf("inode mismatch: %v vs %v\n", inode, fdmon.inode)
@ -268,38 +271,6 @@ func (p *prompter) processConnection(pc pendingConnection) {
return return
/* p.dbusObj.Go("com.subgraph.FirewallPrompt.RequestPrompt", 0, callChan,
pc.getGUID(),
policy.application,
policy.icon,
policy.path,
addr,
int32(pc.dstPort()),
dststr,
pc.src().String(),
pc.proto(),
int32(pc.procInfo().UID),
int32(pc.procInfo().GID),
uidToUser(pc.procInfo().UID),
gidToGroup(pc.procInfo().GID),
int32(pc.procInfo().Pid),
pc.sandbox(),
pc.socks(),
pc.getOptString(),
FirewallConfig.PromptExpanded,
FirewallConfig.PromptExpert,
int32(FirewallConfig.DefaultActionID))
saveChannel(callChan, false, true)
/* err := call.Store(&scope, &rule)
if err != nil {
log.Warningf("Error sending dbus RequestPrompt message: %v", err)
policy.removePending(pc)
pc.drop()
return
} */
// the prompt sends: // the prompt sends:
// ALLOW|dest or DENY|dest // ALLOW|dest or DENY|dest
// //
@ -370,6 +341,7 @@ func (p *prompter) nextConnection() (pendingConnection, bool) {
fmt.Println("policy queue len = ", len(p.policyQueue)) fmt.Println("policy queue len = ", len(p.policyQueue))
for pind < len(p.policyQueue) { for pind < len(p.policyQueue) {
fmt.Printf("policy loop %d of %d\n", pind, len(p.policyQueue))
//fmt.Printf("XXX: pind = %v of %v\n", pind, len(p.policyQueue)) //fmt.Printf("XXX: pind = %v of %v\n", pind, len(p.policyQueue))
policy := p.policyQueue[pind] policy := p.policyQueue[pind]
pc, qempty := policy.nextPending() pc, qempty := policy.nextPending()
@ -379,21 +351,54 @@ func (p *prompter) nextConnection() (pendingConnection, bool) {
continue continue
} else { } else {
pind++ pind++
// if pc == nil && !qempty {
if len(policy.rulesPending) > 0 { pendingOnce := make([]PendingRule, 0)
fmt.Println("policy rules pending = ", len(policy.rulesPending)) pendingOther := make([]PendingRule, 0)
for _, r := range policy.rulesPending {
if r.scope == int(APPLY_ONCE) {
pendingOnce = append(pendingOnce, r)
} else {
pendingOther = append(pendingOther, r)
}
}
fmt.Printf("# pending once = %d, other = %d, pc = %p / policy = %p\n", len(pendingOnce), len(pendingOther), pc, policy)
policy.rulesPending = pendingOther
// One time filters are all applied right here, at once.
for _, pr := range pendingOnce {
toks := strings.Split(pr.rule, "|")
sandbox := ""
if len(toks) > 2 {
sandbox = toks[2]
}
tempRule := fmt.Sprintf("%s|%s", toks[0], toks[1])
tempRule += "||-1:-1|" + sandbox + "|"
r, err := policy.parseRule(tempRule, false)
if err != nil {
log.Warningf("Error parsing rule string returned from dbus RequestPrompt: %v", err)
continue
}
r.mode = RuleMode(pr.scope)
fmt.Println("+++++++ processing one time rule: ", pr.rule)
policy.processNewRuleOnce(r, pr.guid)
}
// if pc == nil && !qempty {
if len(policy.rulesPending) > 0 {
fmt.Println("non/once policy rules pending = ", len(policy.rulesPending))
prule := policy.rulesPending[0] prule := policy.rulesPending[0]
policy.rulesPending = append(policy.rulesPending[:0], policy.rulesPending[1:]...) policy.rulesPending = append(policy.rulesPending[:0], policy.rulesPending[1:]...)
toks := strings.Split(prule.rule, "|") toks := strings.Split(prule.rule, "|")
sandbox := "" sandbox := ""
if len(toks) > 2 { if len(toks) > 2 {
sandbox = toks[2] sandbox = toks[2]
} }
sandbox += ""
tempRule := fmt.Sprintf("%s|%s", toks[0], toks[1]) tempRule := fmt.Sprintf("%s|%s", toks[0], toks[1])
tempRule += "||-1:-1|" + sandbox + "|" tempRule += "||-1:-1|" + sandbox + "|"

@ -50,6 +50,7 @@ func (r *Rule) getString(redact bool) string {
if r.rtype == RULE_ACTION_ALLOW || r.rtype == RULE_ACTION_ALLOW_TLSONLY { if r.rtype == RULE_ACTION_ALLOW || r.rtype == RULE_ACTION_ALLOW_TLSONLY {
rtype = RuleActionString[r.rtype] rtype = RuleActionString[r.rtype]
} }
rmode := "|" + RuleModeString[r.mode] rmode := "|" + RuleModeString[r.mode]
protostr := "" protostr := ""
@ -247,7 +248,7 @@ func (r *Rule) parse(s string) bool {
r.saddr = nil r.saddr = nil
parts := strings.Split(s, "|") parts := strings.Split(s, "|")
if len(parts) < 4 || len(parts) > 6 { if len(parts) < 4 || len(parts) > 6 {
log.Notice("invalid number ", len(parts), " of rule parts in line ", s) log.Notice("Error: invalid number ", len(parts), " of rule parts in line ", s)
return false return false
} }
if parts[2] == "SYSTEM" { if parts[2] == "SYSTEM" {
@ -275,7 +276,7 @@ func (r *Rule) parse(s string) bool {
r.saddr = net.ParseIP(parts[5]) r.saddr = net.ParseIP(parts[5])
if r.saddr == nil { if r.saddr == nil {
log.Notice("invalid source IP ", parts[5], " in line ", s) log.Notice("Error: invalid source IP ", parts[5], " in line ", s)
return false return false
} }

@ -3,7 +3,7 @@ package sgfw
import ( import (
"crypto/x509" "crypto/x509"
"encoding/binary" "encoding/binary"
"encoding/hex" // "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -11,7 +11,7 @@ import (
"time" "time"
) )
const TLSGUARD_READ_TIMEOUT = 10 * time.Second const TLSGUARD_READ_TIMEOUT = 8 * time.Second
const TLSGUARD_MIN_TLS_VER_MAJ = 3 const TLSGUARD_MIN_TLS_VER_MAJ = 3
const TLSGUARD_MIN_TLS_VER_MIN = 1 const TLSGUARD_MIN_TLS_VER_MIN = 1
@ -59,12 +59,66 @@ const TLS1_AD_USER_CANCELLED = 90
const TLS1_AD_NO_RENEGOTIATION = 100 const TLS1_AD_NO_RENEGOTIATION = 100
const TLS1_AD_UNSUPPORTED_EXTENSION = 110 const TLS1_AD_UNSUPPORTED_EXTENSION = 110
const TLSEXT_TYPE_server_name = 1 const TLSEXT_TYPE_server_name = 0
const TLSEXT_TYPE_max_fragment_length = 1
const TLSEXT_TYPE_client_certificate_url = 2
const TLSEXT_TYPE_trusted_ca_keys = 3
const TLSEXT_TYPE_truncated_hmac = 4
const TLSEXT_TYPE_status_request = 5
const TLSEXT_TYPE_user_mapping = 6
const TLSEXT_TYPE_client_authz = 7
const TLSEXT_TYPE_server_authz = 8
const TLSEXT_TYPE_cert_type = 9
const TLSEXT_TYPE_supported_groups = 10
const TLSEXT_TYPE_ec_point_formats = 11
const TLSEXT_TYPE_srp = 12
const TLSEXT_TYPE_signature_algorithms = 13 const TLSEXT_TYPE_signature_algorithms = 13
const TLSEXT_TYPE_use_srtp = 14
const TLSEXT_TYPE_heartbeat = 15
const TLSEXT_TYPE_application_layer_protocol_negotiation = 16
const TLSEXT_TYPE_status_request_v2 = 17
const TLSEXT_TYPE_signed_certificate_timestamp = 18
const TLSEXT_TYPE_client_certificate_type = 19 const TLSEXT_TYPE_client_certificate_type = 19
const TLSEXT_TYPE_server_certificate_type = 20
const TLSEXT_TYPE_padding = 21
const TLSEXT_TYPE_encrypt_then_mac = 22
const TLSEXT_TYPE_extended_master_secret = 23 const TLSEXT_TYPE_extended_master_secret = 23
const TLSEXT_TYPE_token_binding = 24
const TLSEXT_TYPE_cached_info = 25
const TLSEXT_TYPE_SessionTicket = 35
const TLSEXT_TYPE_renegotiate = 0xff01 const TLSEXT_TYPE_renegotiate = 0xff01
var tlsExtensionMap map[uint16]string = map[uint16]string{
TLSEXT_TYPE_server_name: "TLSEXT_TYPE_server_name",
TLSEXT_TYPE_max_fragment_length: "TLSEXT_TYPE_max_fragment_length",
TLSEXT_TYPE_client_certificate_url: "TLSEXT_TYPE_client_certificate_url",
TLSEXT_TYPE_trusted_ca_keys: "TLSEXT_TYPE_trusted_ca_keys",
TLSEXT_TYPE_truncated_hmac: "TLSEXT_TYPE_truncated_hmac",
TLSEXT_TYPE_status_request: "TLSEXT_TYPE_status_request",
TLSEXT_TYPE_user_mapping: "TLSEXT_TYPE_user_mapping",
TLSEXT_TYPE_client_authz: "TLSEXT_TYPE_client_authz",
TLSEXT_TYPE_server_authz: "TLSEXT_TYPE_server_authz",
TLSEXT_TYPE_cert_type: "TLSEXT_TYPE_cert_type",
TLSEXT_TYPE_supported_groups: "TLSEXT_TYPE_supported_groups",
TLSEXT_TYPE_ec_point_formats: "TLSEXT_TYPE_ec_point_formats",
TLSEXT_TYPE_srp: "TLSEXT_TYPE_srp",
TLSEXT_TYPE_signature_algorithms: "TLSEXT_TYPE_signature_algorithms",
TLSEXT_TYPE_use_srtp: "TLSEXT_TYPE_use_srtp",
TLSEXT_TYPE_heartbeat: "TLSEXT_TYPE_heartbeat",
TLSEXT_TYPE_application_layer_protocol_negotiation: "TLSEXT_TYPE_application_layer_protocol_negotiation",
TLSEXT_TYPE_status_request_v2: "TLSEXT_TYPE_status_request_v2",
TLSEXT_TYPE_signed_certificate_timestamp: "TLSEXT_TYPE_signed_certificate_timestamp",
TLSEXT_TYPE_client_certificate_type: "TLSEXT_TYPE_client_certificate_type",
TLSEXT_TYPE_server_certificate_type: "TLSEXT_TYPE_server_certificate_type",
TLSEXT_TYPE_padding: "TLSEXT_TYPE_padding",
TLSEXT_TYPE_encrypt_then_mac: "TLSEXT_TYPE_encrypt_then_mac",
TLSEXT_TYPE_extended_master_secret: "TLSEXT_TYPE_extended_master_secret",
TLSEXT_TYPE_token_binding: "TLSEXT_TYPE_token_binding",
TLSEXT_TYPE_cached_info: "TLSEXT_TYPE_cached_info",
TLSEXT_TYPE_SessionTicket: "TLSEXT_TYPE_SessionTicket",
TLSEXT_TYPE_renegotiate: "TLSEXT_TYPE_renegotiate",
}
type connReader struct { type connReader struct {
client bool client bool
data []byte data []byte
@ -80,18 +134,54 @@ var cipherSuiteMap map[uint16]string = map[uint16]string{
0x0039: "TLS_DHE_RSA_WITH_AES_256_CBC_SHA", 0x0039: "TLS_DHE_RSA_WITH_AES_256_CBC_SHA",
0x0035: "TLS_RSA_WITH_AES_256_CBC_SHA", 0x0035: "TLS_RSA_WITH_AES_256_CBC_SHA",
0x0030: "TLS_DH_DSS_WITH_AES_128_CBC_SHA", 0x0030: "TLS_DH_DSS_WITH_AES_128_CBC_SHA",
0x0067: "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256",
0x006b: "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256",
0x009e: "TLS_DHE_RSA_WITH_AES_128_GCM_SHA256",
0x009f: "TLS_DHE_RSA_WITH_AES_256_GCM_SHA384",
0x00c4: "TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256",
0xc009: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", 0xc009: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA",
0xc00a: "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", 0xc00a: "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA",
0xc013: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", 0xc013: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
0xc014: "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", 0xc014: "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
0xc023: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256",
0xc024: "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384",
0xc027: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256",
0xc028: "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384",
0xc02b: "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", 0xc02b: "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
0xc02c: "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", 0xc02c: "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
0xc02f: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", 0xc02f: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
0xc030: "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", 0xc030: "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
0xc076: "TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256",
0xc077: "TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384",
0xcc13: "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
0xcc14: "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
0xcc15: "TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
0xcca9: "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", 0xcca9: "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
0xcca8: "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", 0xcca8: "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
} }
var whitelistedCiphers = []string{
"SSL_DHE_RSA_WITH_3DES_EDE_CBC_SHA",
"TLS_DHE_RSA_WITH_AES_128_CBC_SHA",
"TLS_DHE_RSA_WITH_AES_128_GCM_SHA256",
"TLS_DHE_RSA_WITH_AES_256_GCM_SHA384",
"TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384",
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
"TLS_RSA_WITH_AES_128_CBC_SHA",
"SSL_RSA_WITH_3DES_EDE_CBC_SHA",
}
var blacklistedCiphers = []string{
"TLS_NULL_WITH_NULL_NULL",
"TLS_RSA_WITH_AES_128_CBC_SHA",
}
func getCipherSuiteName(value uint) string { func getCipherSuiteName(value uint) string {
val, ok := cipherSuiteMap[uint16(value)] val, ok := cipherSuiteMap[uint16(value)]
if !ok { if !ok {
@ -101,6 +191,79 @@ func getCipherSuiteName(value uint) string {
return val return val
} }
func isBadCipher(cname string) bool {
for _, cipher := range blacklistedCiphers {
if cipher == cname {
return true
}
}
return false
}
func gettlsExtensionName(value uint) string {
// 26-34: Unassigned
// 36-65280: Unassigned
// 65282-65535: Unassigned
if (value >= 26 && value <= 34) || (value >= 36 && value <= 65280) || (value >= 65282 && value <= 65535) {
return fmt.Sprintf("Unassigned TLS Extension %#x", value)
}
val, ok := tlsExtensionMap[uint16(value)]
if !ok {
return "UNKNOWN"
}
return val
}
func stripTLSData(record []byte, start_ind, end_ind int, len_ind int, len_size int) []byte {
var size uint = 0
if len_size < 1 || len_size > 2 {
return nil
} else if start_ind >= end_ind {
return nil
} else if len_ind >= start_ind {
return nil
}
rcopy := make([]byte, len(record))
copy(rcopy, record)
if len_size == 1 {
size = uint(rcopy[len_ind])
} else if len_size == 2 {
size = uint(binary.BigEndian.Uint16(rcopy[len_ind : len_ind+len_size]))
}
size -= uint(end_ind - start_ind)
// Put back the length size
if len_size == 1 {
rcopy[len_ind] = byte(size)
} else if len_size == 2 {
binary.BigEndian.PutUint16(rcopy[len_ind:len_ind+len_size], uint16(size))
}
// Patch the record size
rsize := binary.BigEndian.Uint16(rcopy[3:5])
rsize -= uint16(end_ind - start_ind)
binary.BigEndian.PutUint16(rcopy[3:5], rsize)
// And finally the 3 byte hello record
hsize := binary.BigEndian.Uint32(rcopy[5:9])
saved_b := hsize & 0xff000000
hsize &= 0x00ffffff
hsize -= uint32(end_ind - start_ind)
hsize |= saved_b
binary.BigEndian.PutUint32(rcopy[5:9], hsize)
result := append(rcopy[:start_ind], rcopy[end_ind:]...)
return result
}
func connectionReader(conn net.Conn, is_client bool, c chan connReader, done chan bool) { func connectionReader(conn net.Conn, is_client bool, c chan connReader, done chan bool) {
var ret_error error = nil var ret_error error = nil
buffered := []byte{} buffered := []byte{}
@ -142,6 +305,9 @@ func connectionReader(conn net.Conn, is_client bool, c chan connReader, done cha
} else if int(header[2]) < TLSGUARD_MIN_TLS_VER_MIN { } else if int(header[2]) < TLSGUARD_MIN_TLS_VER_MIN {
ret_error = errors.New("TLS protocol minor version less than expected minimum") ret_error = errors.New("TLS protocol minor version less than expected minimum")
continue continue
} else if int(header[1]) > 3 {
ret_error = errors.New("TLS protocol major version was larger than expected; maybe not TLS handshake?")
continue
} }
rtype = int(header[0]) rtype = int(header[0])
@ -184,6 +350,16 @@ func connectionReader(conn net.Conn, is_client bool, c chan connReader, done cha
} }
func isExpected(val uint, possibilities []uint) bool {
for _, pval := range possibilities {
if val == pval {
return true
}
}
return false
}
func TLSGuard(conn, conn2 net.Conn, fqdn string) error { func TLSGuard(conn, conn2 net.Conn, fqdn string) error {
x509Valid := false x509Valid := false
ndone := 0 ndone := 0
@ -196,17 +372,24 @@ func TLSGuard(conn, conn2 net.Conn, fqdn string) error {
fmt.Println("-------- STARTING HANDSHAKE LOOP") fmt.Println("-------- STARTING HANDSHAKE LOOP")
crChan := make(chan connReader) crChan := make(chan connReader)
dChan := make(chan bool, 10) dChan := make(chan bool, 10)
dChan2 := make(chan bool, 10)
go connectionReader(conn, true, crChan, dChan) go connectionReader(conn, true, crChan, dChan)
go connectionReader(conn2, false, crChan, dChan) go connectionReader(conn2, false, crChan, dChan2)
client_expected := []uint{SSL3_MT_CLIENT_HELLO}
server_expected := []uint{SSL3_MT_SERVER_HELLO}
client_expected := SSL3_MT_CLIENT_HELLO client_sess := false
server_expected := SSL3_MT_SERVER_HELLO server_sess := false
client_change_cipher := false
server_change_cipher := false
select_loop: select_loop:
for { for {
if ndone == 2 { if ndone == 2 {
fmt.Println("DONE channel got both notifications. Terminating loop.") fmt.Println("DONE channel got both notifications. Terminating loop.")
close(dChan) close(dChan)
close(dChan2)
close(crChan) close(crChan)
break break
} }
@ -239,6 +422,12 @@ select_loop:
if cr.data[TLS_RECORD_HDR_LEN] != 1 { if cr.data[TLS_RECORD_HDR_LEN] != 1 {
return errors.New(fmt.Sprintf("TLSGuard dropped connection with strange change cipher spec data (%#x bytes)", cr.data[TLS_RECORD_HDR_LEN])) return errors.New(fmt.Sprintf("TLSGuard dropped connection with strange change cipher spec data (%#x bytes)", cr.data[TLS_RECORD_HDR_LEN]))
} }
if cr.client {
client_change_cipher = true
} else {
server_change_cipher = true
}
} else if cr.rtype == SSL3_RT_ALERT { } else if cr.rtype == SSL3_RT_ALERT {
if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_WARNING { if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_WARNING {
fmt.Println("SSL ALERT TYPE: warning") fmt.Println("SSL ALERT TYPE: warning")
@ -248,7 +437,7 @@ select_loop:
fmt.Println("SSL ALERT TYPE UNKNOWN") fmt.Println("SSL ALERT TYPE UNKNOWN")
} }
alert_desc := int(int(cr.data[6])<<8 | int(cr.data[7])) alert_desc := int(int(cr.data[5])<<8 | int(cr.data[6]))
fmt.Println("ALERT DESCRIPTION: ", alert_desc) fmt.Println("ALERT DESCRIPTION: ", alert_desc)
if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_FATAL { if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_FATAL {
@ -258,49 +447,46 @@ select_loop:
} }
} }
// fmt.Println("OTHER DATA; PASSING THRU")
if cr.rtype == SSL3_RT_ALERT {
fmt.Println("ALERT = ", cr.data)
}
other.Write(cr.data) other.Write(cr.data)
continue continue
} else if cr.client {
// other.Write(cr.data)
// continue
} else if cr.rtype != SSL3_RT_HANDSHAKE { } else if cr.rtype != SSL3_RT_HANDSHAKE {
return errors.New(fmt.Sprintf("Expected TLS server handshake byte was not received [%#x vs 0x16]", cr.rtype)) return errors.New(fmt.Sprintf("Expected TLS server handshake byte was not received [%#x vs 0x16]", cr.rtype))
} }
if cr.rtype < SSL3_RT_CHANGE_CIPHER_SPEC || cr.rtype > SSL3_RT_APPLICATION_DATA {
return errors.New(fmt.Sprintf("TLSGuard dropping connection with unknown content type: %#x", cr.rtype))
}
handshakeMsg := cr.data[TLS_RECORD_HDR_LEN:] handshakeMsg := cr.data[TLS_RECORD_HDR_LEN:]
s := uint(handshakeMsg[0]) s := uint(handshakeMsg[0])
fmt.Printf("s = %#x\n", s)
// Message len, 3 bytes
if cr.rtype == SSL3_RT_HANDSHAKE {
handshakeMessageLen := handshakeMsg[1:4] handshakeMessageLen := handshakeMsg[1:4]
handshakeMessageLenInt := int(int(handshakeMessageLen[0])<<16 | int(handshakeMessageLen[1])<<8 | int(handshakeMessageLen[2])) handshakeMessageLenInt := int(int(handshakeMessageLen[0])<<16 | int(handshakeMessageLen[1])<<8 | int(handshakeMessageLen[2]))
fmt.Println("lenint = \n", handshakeMessageLenInt) fmt.Printf("s = %#x, lenint = %v, total = %d\n", s, handshakeMessageLenInt, len(cr.data))
if (client_sess || server_sess) && (client_change_cipher || server_change_cipher) {
if handshakeMessageLenInt > len(cr.data)+9 {
log.Notice("TLSGuard saw what looks like a resumed encrypted session... passing connection through")
other.Write(cr.data)
dChan <- true
dChan2 <- true
x509Valid = true
break select_loop
}
} }
if cr.client && s != uint(client_expected) { if cr.client && !isExpected(s, client_expected) {
return errors.New(fmt.Sprintf("Client sent handshake type %#x but expected %#x", s, client_expected)) return errors.New(fmt.Sprintf("Client sent handshake type %#x but expected %#x", s, client_expected))
} else if !cr.client && s != uint(server_expected) { } else if !cr.client && !isExpected(s, server_expected) {
return errors.New(fmt.Sprintf("Server sent handshake type %#x but expected %#x", s, server_expected)) return errors.New(fmt.Sprintf("Server sent handshake type %#x but expected %#x", s, server_expected))
} }
if (cr.client && s == SSL3_MT_CLIENT_HELLO) || (!cr.client && s == SSL3_MT_SERVER_HELLO) { if (cr.client && s == SSL3_MT_CLIENT_HELLO) || (!cr.client && s == SSL3_MT_SERVER_HELLO) {
rewrite := false // rewrite := false
rewrite_buf := []byte{} // rewrite_buf := []byte{}
SRC := "" SRC := ""
if s == SSL3_MT_CLIENT_HELLO { if s == SSL3_MT_CLIENT_HELLO {
SRC = "CLIENT" SRC = "CLIENT"
} else { } else {
server_expected = SSL3_MT_CERTIFICATE server_expected = []uint{SSL3_MT_CERTIFICATE, SSL3_MT_HELLO_REQUEST}
SRC = "SERVER" SRC = "SERVER"
} }
@ -319,35 +505,21 @@ select_loop:
sess_len := uint(handshakeMsg[hello_offset]) sess_len := uint(handshakeMsg[hello_offset])
fmt.Println(SRC, "HELLO SESSION ID = ", sess_len) fmt.Println(SRC, "HELLO SESSION ID = ", sess_len)
if sess_len != 0 { if cr.client && sess_len > 0 {
fmt.Printf("ALERT: %v attempting to resume session; intercepting request\n", SRC) client_sess = true
rewrite = true } else {
dcopy := make([]byte, len(cr.data)) server_sess = true
copy(dcopy, cr.data)
// Copy the bytes before the session ID start
rewrite_buf = dcopy[0 : TLS_RECORD_HDR_LEN+hello_offset+1]
// Set the session ID to 0
rewrite_buf[len(rewrite_buf)-1] = 0
// Write the new TLS record length
binary.BigEndian.PutUint16(rewrite_buf[3:5], uint16(len(dcopy)-(int(sess_len)+TLS_RECORD_HDR_LEN)))
// Write the new ClientHello length
// Starts after the first 6 bytes (record header + type byte)
orig_len := binary.BigEndian.Uint32(handshakeMsg[0:4])
// But it's only 3 bytes so mask out the first one
b1 := orig_len & 0xff000000
orig_len &= 0x00ffffff
orig_len -= uint32(sess_len)
orig_len |= b1
binary.BigEndian.PutUint32(rewrite_buf[TLS_RECORD_HDR_LEN:], orig_len)
rewrite_buf = append(rewrite_buf, dcopy[TLS_RECORD_HDR_LEN+hello_offset+int(sess_len)+1:]...)
} }
/*
hello_offset += int(sess_len) + 1 hello_offset += int(sess_len) + 1
// 2 byte cipher suite array // 2 byte cipher suite array
cs := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) cs := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2])
noCS := cs noCS := cs
fmt.Printf("cs = %v / %#x\n", noCS, noCS) fmt.Printf("cs = %v / %#x\n", noCS, noCS)
saved_ciphersuite_size_off := hello_offset
if !cr.client { if !cr.client {
fmt.Printf("SERVER selected ciphersuite: %#x (%s)\n", cs, getCipherSuiteName(uint(cs))) fmt.Printf("SERVER selected ciphersuite: %#x (%s)\n", cs, getCipherSuiteName(uint(cs)))
hello_offset += 2 hello_offset += 2
@ -356,7 +528,13 @@ select_loop:
for csind := 0; csind < int(noCS/2); csind++ { for csind := 0; csind < int(noCS/2); csind++ {
off := hello_offset + 2 + (csind * 2) off := hello_offset + 2 + (csind * 2)
cs = binary.BigEndian.Uint16(handshakeMsg[off : off+2]) cs = binary.BigEndian.Uint16(handshakeMsg[off : off+2])
fmt.Printf("%s HELLO CIPHERSUITE: %d/%d: %#x (%s)\n", SRC, csind+1, noCS/2, cs, getCipherSuiteName(uint(cs))) cname := getCipherSuiteName(uint(cs))
fmt.Printf("%s HELLO CIPHERSUITE: %d/%d: %#x (%s)\n", SRC, csind+1, noCS/2, cs, cname)
if isBadCipher(cname) {
fmt.Println("BAD CIPHER: ", cname)
}
} }
hello_offset += 2 + int(noCS) hello_offset += 2 + int(noCS)
@ -383,75 +561,23 @@ select_loop:
hello_offset += 2 hello_offset += 2
} }
var exttype uint16 = 0
if extlen > 2 {
exttype = binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2])
fmt.Println(SRC, "HELLO FIRST EXTENSION TYPE: ", exttype)
}
if cr.client { if cr.client {
ext_ctr := 0 ext_ctr := 0
for ext_ctr < int(extlen)-2 { for ext_ctr < int(extlen)-2 {
exttype := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2])
hello_offset += 2 hello_offset += 2
ext_ctr += 2 ext_ctr += 2
fmt.Printf("PROGRESS: %v of %v, %v of %v\n", ext_ctr, extlen, hello_offset, len(handshakeMsg)) // fmt.Printf("PROGRESS: %v of %v, %v of %v\n", ext_ctr, extlen, hello_offset, len(handshakeMsg))
exttype2 := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) fmt.Printf("EXTTYPE = %#x (%s)\n", exttype, gettlsExtensionName(uint(exttype)))
fmt.Printf("EXTTYPE = %v, 2 = %v\n", exttype, exttype2)
if exttype2 == TLSEXT_TYPE_server_name {
fmt.Println("CLIENT specified server_name extension:")
}
if exttype != TLSEXT_TYPE_signature_algorithms {
fmt.Println("WTF")
}
hello_offset += 2
ext_ctr += 2
inner_len := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) inner_len := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2])
// fmt.Println("INNER LEN = ", inner_len) hello_offset += int(inner_len) + 2
hello_offset += int(inner_len) ext_ctr += int(inner_len) + 2
ext_ctr += int(inner_len)
}
}
if extlen > 0 {
fmt.Printf("ALERT: %v attempting to send extensions; intercepting request\n", SRC)
rewrite = true
tocopy := cr.data
if len(rewrite_buf) > 0 {
tocopy = rewrite_buf
}
dcopy := make([]byte, len(tocopy)-int(extlen))
copy(dcopy, tocopy[0:len(tocopy)-int(extlen)])
rewrite_buf = dcopy
// Write the new TLS record length
binary.BigEndian.PutUint16(rewrite_buf[3:5], uint16(len(dcopy)-(int(sess_len)+TLS_RECORD_HDR_LEN)))
// Write the new ClientHello length
// Starts after the first 6 bytes (record header + type byte)
orig_len := binary.BigEndian.Uint32(rewrite_buf[TLS_RECORD_HDR_LEN:])
// But it's only 3 bytes so mask out the first one
b1 := orig_len & 0xff000000
orig_len &= 0x00ffffff
orig_len -= uint32(extlen)
orig_len |= b1
binary.BigEndian.PutUint32(rewrite_buf[TLS_RECORD_HDR_LEN:], orig_len)
// Write session length 0 at the end
rewrite_buf[len(rewrite_buf)-1] = 0
rewrite_buf[len(rewrite_buf)-2] = 0
}
if rewrite {
fmt.Println("TLSGuard writing back modified handshake data to server")
fmt.Printf("ORIGINAL[%d]: %v\n", len(cr.data), hex.Dump(cr.data))
fmt.Printf("NEW[%d]: %v\n", len(rewrite_buf), hex.Dump(rewrite_buf))
other.Write(rewrite_buf)
} else {
other.Write(cr.data)
} }
}*/
other.Write(cr.data)
continue continue
} }
@ -460,25 +586,19 @@ select_loop:
continue continue
} }
if !cr.client && server_expected == SSL3_MT_SERVER_HELLO { if !cr.client && isExpected(SSL3_MT_SERVER_HELLO, server_expected) {
server_expected = SSL3_MT_CERTIFICATE server_expected = []uint{SSL3_MT_CERTIFICATE}
} }
if !cr.client && s == SSL3_MT_HELLO_REQUEST { if !cr.client && s == SSL3_MT_HELLO_REQUEST {
fmt.Println("Server sent hello request") fmt.Println("Server sent hello request")
continue
} }
if s > SSL3_MT_CERTIFICATE_STATUS { if s > SSL3_MT_CERTIFICATE_STATUS {
fmt.Println("WTF: ", cr.data) fmt.Println("WTF: ", cr.data)
} }
// Message len, 3 bytes
handshakeMessageLen := handshakeMsg[1:4]
handshakeMessageLenInt := int(int(handshakeMessageLen[0])<<16 | int(handshakeMessageLen[1])<<8 | int(handshakeMessageLen[2]))
if s == SSL3_MT_CERTIFICATE { if s == SSL3_MT_CERTIFICATE {
fmt.Println("HMM")
// fmt.Printf("chunk len = %v, handshakeMsgLen = %v, slint = %v\n", len(chunk), len(handshakeMsg), handshakeMessageLenInt) // fmt.Printf("chunk len = %v, handshakeMsgLen = %v, slint = %v\n", len(chunk), len(handshakeMsg), handshakeMessageLenInt)
if len(handshakeMsg) < handshakeMessageLenInt { if len(handshakeMsg) < handshakeMessageLenInt {
return errors.New(fmt.Sprintf("len(handshakeMsg) %v < handshakeMessageLenInt %v!\n", len(handshakeMsg), handshakeMessageLenInt)) return errors.New(fmt.Sprintf("len(handshakeMsg) %v < handshakeMessageLenInt %v!\n", len(handshakeMsg), handshakeMessageLenInt))
@ -535,6 +655,7 @@ select_loop:
if x509Valid || (s == SSL3_MT_SERVER_DONE) || (s == SSL3_MT_CERTIFICATE_REQUEST) { if x509Valid || (s == SSL3_MT_SERVER_DONE) || (s == SSL3_MT_CERTIFICATE_REQUEST) {
fmt.Println("BREAKING OUT OF LOOP 1") fmt.Println("BREAKING OUT OF LOOP 1")
dChan <- true dChan <- true
dChan2 <- true
fmt.Println("BREAKING OUT OF LOOP 2") fmt.Println("BREAKING OUT OF LOOP 2")
break select_loop break select_loop
} }
@ -576,6 +697,7 @@ select_loop:
// dChan <- true // dChan <- true
close(dChan) close(dChan)
close(dChan2)
if !x509Valid { if !x509Valid {
return errors.New("Unknown error: TLS connection could not be validated") return errors.New("Unknown error: TLS connection could not be validated")

Loading…
Cancel
Save