Merged from shw_dev

shw-merge
xSmurf 7 years ago
parent 5f26317c44
commit 7472b4d828

@ -4,33 +4,23 @@ import (
"errors" "errors"
"github.com/godbus/dbus" "github.com/godbus/dbus"
"log" "log"
// "github.com/gotk3/gotk3/glib"
) )
type dbusObject struct {
dbus.BusObject
}
type dbusServer struct { type dbusServer struct {
conn *dbus.Conn conn *dbus.Conn
run bool run bool
} }
type promptData struct { func newDbusObjectAdd() (*dbusObject, error) {
Application string conn, err := dbus.SystemBus()
Icon string if err != nil {
Path string return nil, err
Address string }
Port int return &dbusObject{conn.Object("com.subgraph.Firewall", "/com/subgraph/Firewall")}, nil
IP string
Origin string
Proto string
UID int
GID int
Username string
Groupname string
Pid int
Sandbox string
OptString string
Expanded bool
Expert bool
Action int
} }
func newDbusServer() (*dbusServer, error) { func newDbusServer() (*dbusServer, error) {
@ -62,10 +52,10 @@ func newDbusServer() (*dbusServer, error) {
return ds, nil return ds, nil
} }
func (ds *dbusServer) RequestPrompt(application, icon, path, address string, port int32, ip, origin, proto string, uid, gid int32, username, groupname string, pid int32, sandbox string, /*func (ds *dbusServer) RequestPrompt(guid, application, icon, path, address string, port int32, ip, origin, proto string, uid, gid int32, username, groupname string, pid int32, sandbox string,
is_socks bool, optstring string, expanded, expert bool, action int32) (int32, string, *dbus.Error) { is_socks bool, timestamp string, optstring string, expanded, expert bool, action int32) (int32, string, *dbus.Error) {
log.Printf("request prompt: app = %s, icon = %s, path = %s, address = %s, is_socks = %v, action = %v\n", application, icon, path, address, is_socks, action) log.Printf("request prompt: app = %s, icon = %s, path = %s, address = %s / ip = %s, is_socks = %v, action = %v\n", application, icon, path, address, ip, is_socks, action)
decision := addRequest(nil, path, proto, int(pid), ip, address, int(port), int(uid), int(gid), origin, is_socks, optstring, sandbox) decision := addRequest(nil, guid, path, icon, proto, int(pid), ip, address, int(port), int(uid), int(gid), origin, timestamp, is_socks, optstring, sandbox, int(action))
log.Print("Waiting on decision...") log.Print("Waiting on decision...")
decision.Cond.L.Lock() decision.Cond.L.Lock()
for !decision.Ready { for !decision.Ready {
@ -73,6 +63,18 @@ func (ds *dbusServer) RequestPrompt(application, icon, path, address string, por
} }
log.Print("Decision returned: ", decision.Rule) log.Print("Decision returned: ", decision.Rule)
decision.Cond.L.Unlock() decision.Cond.L.Unlock()
// glib.IdleAdd(func, data)
return int32(decision.Scope), decision.Rule, nil return int32(decision.Scope), decision.Rule, nil
}*/
func (ds *dbusServer) RequestPromptAsync(guid, application, icon, path, address string, port int32, ip, origin, proto string, uid, gid int32, username, groupname string, pid int32, sandbox string,
is_socks bool, timestamp string, optstring string, expanded, expert bool, action int32) (bool, *dbus.Error) {
log.Printf("ASYNC request prompt: guid = %s, app = %s, icon = %s, path = %s, address = %s / ip = %s, is_socks = %v, action = %v\n", guid, application, icon, path, address, ip, is_socks, action)
addRequestAsync(nil, guid, path, icon, proto, int(pid), ip, address, int(port), int(uid), int(gid), origin, timestamp, is_socks, optstring, sandbox, int(action))
return true, nil
}
func (ds *dbusServer) RemovePrompt(guid string) *dbus.Error {
log.Printf("++++++++ Cancelling prompt: %s\n", guid)
removeRequest(nil, guid)
return nil
} }

@ -34,7 +34,10 @@ type decisionWaiter struct {
} }
type ruleColumns struct { type ruleColumns struct {
nrefs int
Path string Path string
GUID string
Icon string
Proto string Proto string
Pid int Pid int
Target string Target string
@ -45,23 +48,30 @@ type ruleColumns struct {
Uname string Uname string
Gname string Gname string
Origin string Origin string
Timestamp string
IsSocks bool
ForceTLS bool
Scope int Scope int
} }
var dbuso *dbusObject
var userPrefs fpPreferences var userPrefs fpPreferences
var mainWin *gtk.Window var mainWin *gtk.Window
var Notebook *gtk.Notebook var Notebook *gtk.Notebook
var globalLS *gtk.ListStore var globalLS *gtk.ListStore = nil
var globalTV *gtk.TreeView var globalTV *gtk.TreeView
var globalPromptLock = &sync.Mutex{}
var globalIcon *gtk.Image
var decisionWaiters []*decisionWaiter var decisionWaiters []*decisionWaiter
var editApp, editTarget, editPort, editUser, editGroup *gtk.Entry var editApp, editTarget, editPort, editUser, editGroup *gtk.Entry
var comboProto *gtk.ComboBoxText var comboProto *gtk.ComboBoxText
var radioOnce, radioProcess, radioParent, radioSession, radioPermanent *gtk.RadioButton var radioOnce, radioProcess, radioParent, radioSession, radioPermanent *gtk.RadioButton
var btnApprove, btnDeny, btnIgnore *gtk.Button var btnApprove, btnDeny, btnIgnore *gtk.Button
var chkUser, chkGroup *gtk.CheckButton var chkTLS, chkUser, chkGroup *gtk.CheckButton
func dumpDecisions() { func dumpDecisions() {
return
fmt.Println("XXX Total of decisions pending: ", len(decisionWaiters)) fmt.Println("XXX Total of decisions pending: ", len(decisionWaiters))
for i := 0; i < len(decisionWaiters); i++ { for i := 0; i < len(decisionWaiters); i++ {
fmt.Printf("XXX %d ready = %v, rule = %v\n", i+1, decisionWaiters[i].Ready, decisionWaiters[i].Rule) fmt.Printf("XXX %d ready = %v, rule = %v\n", i+1, decisionWaiters[i].Ready, decisionWaiters[i].Rule)
@ -69,6 +79,7 @@ func dumpDecisions() {
} }
func addDecision() *decisionWaiter { func addDecision() *decisionWaiter {
return nil
decision := decisionWaiter{Lock: &sync.Mutex{}, Ready: false, Scope: int(sgfw.APPLY_ONCE), Rule: ""} decision := decisionWaiter{Lock: &sync.Mutex{}, Ready: false, Scope: int(sgfw.APPLY_ONCE), Rule: ""}
decision.Cond = sync.NewCond(decision.Lock) decision.Cond = sync.NewCond(decision.Lock)
decisionWaiters = append(decisionWaiters, &decision) decisionWaiters = append(decisionWaiters, &decision)
@ -306,7 +317,8 @@ func createColumn(title string, id int) *gtk.TreeViewColumn {
} }
func createListStore(general bool) *gtk.ListStore { func createListStore(general bool) *gtk.ListStore {
colData := []glib.Type{glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING} colData := []glib.Type{glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_STRING,
glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_INT}
listStore, err := gtk.ListStoreNew(colData...) listStore, err := gtk.ListStoreNew(colData...)
if err != nil { if err != nil {
@ -316,14 +328,80 @@ func createListStore(general bool) *gtk.ListStore {
return listStore return listStore
} }
func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, hostname string, port, uid, gid int, origin string, is_socks bool, optstring string, sandbox string) *decisionWaiter { func removeRequest(listStore *gtk.ListStore, guid string) {
removed := false
globalPromptLock.Lock()
defer globalPromptLock.Unlock()
/* XXX: This is horrible. Figure out how to do this properly. */
for ridx := 0; ridx < globalLS.IterNChildren(nil); ridx++ {
rule, _, err := getRuleByIdx(ridx)
if err != nil {
break
} else if rule.GUID == guid {
removeSelectedRule(ridx, true)
removed = true
break
}
}
if !removed {
log.Printf("Unexpected condition: SGFW requested prompt removal for non-existent GUID %v\n", guid)
}
}
func addRequestInc(listStore *gtk.ListStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int,
origin string, is_socks bool, optstring string, sandbox string, action int) bool {
duplicated := false
globalPromptLock.Lock()
defer globalPromptLock.Unlock()
for ridx := 0; ridx < globalLS.IterNChildren(nil); ridx++ {
/* XXX: This is horrible. Figure out how to do this properly. */
rule, iter, err := getRuleByIdx(ridx)
if err != nil {
break
// XXX: not compared: optstring/sandbox
} else if (rule.Path == path) && (rule.Proto == proto) && (rule.Pid == pid) && (rule.Target == ipaddr) && (rule.Hostname == hostname) &&
(rule.Port == port) && (rule.UID == uid) && (rule.GID == gid) && (rule.Origin == origin) && (rule.IsSocks == is_socks) {
rule.nrefs++
err := globalLS.SetValue(iter, 0, rule.nrefs)
if err != nil {
log.Println("Error creating duplicate firewall prompt entry:", err)
break
}
fmt.Println("YES REALLY DUPLICATE: ", rule.nrefs)
duplicated = true
break
}
}
return duplicated
}
func addRequestAsync(listStore *gtk.ListStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int,
origin, timestamp string, is_socks bool, optstring string, sandbox string, action int) bool {
addRequest(listStore, guid, path, icon, proto, pid, ipaddr, hostname, port, uid, gid, origin, timestamp, is_socks,
optstring, sandbox, action)
return true
}
func addRequest(listStore *gtk.ListStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int,
origin, timestamp string, is_socks bool, optstring string, sandbox string, action int) *decisionWaiter {
if listStore == nil { if listStore == nil {
listStore = globalLS listStore = globalLS
waitTimes := []int{1, 2, 5, 10} waitTimes := []int{1, 2, 5, 10}
if listStore == nil { if listStore == nil {
log.Print("SGFW prompter was not ready to receive firewall request... waiting") log.Println("SGFW prompter was not ready to receive firewall request... waiting")
}
for _, wtime := range waitTimes { for _, wtime := range waitTimes {
time.Sleep(time.Duration(wtime) * time.Second) time.Sleep(time.Duration(wtime) * time.Second)
@ -333,7 +411,9 @@ func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, h
break break
} }
log.Print("SGFW prompter is still waiting...") log.Println("SGFW prompter is still waiting...")
}
} }
} }
@ -342,6 +422,18 @@ func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, h
log.Fatal("SGFW prompter GUI failed to load for unknown reasons") log.Fatal("SGFW prompter GUI failed to load for unknown reasons")
} }
if addRequestInc(listStore, guid, path, icon, proto, pid, ipaddr, hostname, port, uid, gid, origin, is_socks, optstring, sandbox, action) {
fmt.Println("REQUEST WAS DUPLICATE")
decision := addDecision()
globalPromptLock.Lock()
toggleHover()
globalPromptLock.Unlock()
return decision
} else {
fmt.Println("NOT DUPLICATE")
}
globalPromptLock.Lock()
iter := listStore.Append() iter := listStore.Append()
if is_socks { if is_socks {
@ -352,24 +444,34 @@ func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, h
} }
} }
colVals := make([]interface{}, 11) colVals := make([]interface{}, 16)
colVals[0] = 1 colVals[0] = 1
colVals[1] = path colVals[1] = guid
colVals[2] = proto colVals[2] = path
colVals[3] = pid colVals[3] = icon
colVals[4] = proto
colVals[5] = pid
if ipaddr == "" { if ipaddr == "" {
colVals[4] = "---" colVals[6] = "---"
} else { } else {
colVals[4] = ipaddr colVals[6] = ipaddr
}
colVals[7] = hostname
colVals[8] = port
colVals[9] = uid
colVals[10] = gid
colVals[11] = origin
colVals[12] = timestamp
colVals[13] = 0
if is_socks {
colVals[13] = 1
} }
colVals[5] = hostname colVals[14] = optstring
colVals[6] = port colVals[15] = action
colVals[7] = uid
colVals[8] = gid
colVals[9] = origin
colVals[10] = optstring
colNums := make([]int, len(colVals)) colNums := make([]int, len(colVals))
@ -386,6 +488,7 @@ func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, h
decision := addDecision() decision := addDecision()
dumpDecisions() dumpDecisions()
toggleHover() toggleHover()
globalPromptLock.Unlock()
return decision return decision
} }
@ -479,22 +582,42 @@ func lsGetInt(ls *gtk.ListStore, iter *gtk.TreeIter, idx int) (int, error) {
return ival.(int), nil return ival.(int), nil
} }
func makeDecision(idx int, rule string, scope int) { func makeDecision(idx int, rule string, scope int) error {
var dres bool
call := dbuso.Call("AddRuleAsync", 0, uint32(scope), rule, "*")
err := call.Store(&dres)
if err != nil {
log.Println("Error notifying SGFW of asynchronous rule addition:", err)
return err
}
fmt.Println("makeDecision remote result:", dres)
return nil
decisionWaiters[idx].Cond.L.Lock() decisionWaiters[idx].Cond.L.Lock()
decisionWaiters[idx].Rule = rule decisionWaiters[idx].Rule = rule
decisionWaiters[idx].Scope = scope decisionWaiters[idx].Scope = scope
decisionWaiters[idx].Ready = true decisionWaiters[idx].Ready = true
decisionWaiters[idx].Cond.Signal() decisionWaiters[idx].Cond.Signal()
decisionWaiters[idx].Cond.L.Unlock() decisionWaiters[idx].Cond.L.Unlock()
return nil
} }
/* Do we need to hold the lock while this is called? Stay safe... */
func toggleHover() { func toggleHover() {
mainWin.SetKeepAbove(len(decisionWaiters) > 0) nitems := globalLS.IterNChildren(nil)
mainWin.SetKeepAbove(nitems > 0)
} }
func toggleValidRuleState() { func toggleValidRuleState() {
ok := true ok := true
// XXX: Unfortunately, this can cause deadlock since it's a part of the item removal cascade
// globalPromptLock.Lock()
// defer globalPromptLock.Unlock()
if numSelections() <= 0 { if numSelections() <= 0 {
ok = false ok = false
} }
@ -536,7 +659,8 @@ func toggleValidRuleState() {
btnApprove.SetSensitive(ok) btnApprove.SetSensitive(ok)
btnDeny.SetSensitive(ok) btnDeny.SetSensitive(ok)
btnIgnore.SetSensitive(ok) // btnIgnore.SetSensitive(ok)
btnIgnore.SetSensitive(false)
} }
func createCurrentRule() (ruleColumns, error) { func createCurrentRule() (ruleColumns, error) {
@ -579,6 +703,9 @@ func createCurrentRule() (ruleColumns, error) {
rule.UID, rule.GID = 0, 0 rule.UID, rule.GID = 0, 0
rule.Uname, rule.Gname = "", "" rule.Uname, rule.Gname = "", ""
rule.ForceTLS = chkTLS.GetActive()
/* Pid int /* Pid int
Origin string */ Origin string */
@ -586,6 +713,7 @@ func createCurrentRule() (ruleColumns, error) {
} }
func clearEditor() { func clearEditor() {
globalIcon.Clear()
editApp.SetText("") editApp.SetText("")
editTarget.SetText("") editTarget.SetText("")
editPort.SetText("") editPort.SetText("")
@ -599,6 +727,7 @@ func clearEditor() {
radioPermanent.SetActive(false) radioPermanent.SetActive(false)
chkUser.SetActive(false) chkUser.SetActive(false)
chkGroup.SetActive(false) chkGroup.SetActive(false)
chkTLS.SetActive(false)
} }
func removeSelectedRule(idx int, rmdecision bool) error { func removeSelectedRule(idx int, rmdecision bool) error {
@ -617,7 +746,7 @@ func removeSelectedRule(idx int, rmdecision bool) error {
globalLS.Remove(iter) globalLS.Remove(iter)
if rmdecision { if rmdecision {
decisionWaiters = append(decisionWaiters[:idx], decisionWaiters[idx+1:]...) // decisionWaiters = append(decisionWaiters[:idx], decisionWaiters[idx+1:]...)
} }
toggleHover() toggleHover()
@ -634,6 +763,104 @@ func numSelections() int {
return int(rows.Length()) return int(rows.Length())
} }
// Needs to be locked by the caller
func getRuleByIdx(idx int) (ruleColumns, *gtk.TreeIter, error) {
rule := ruleColumns{}
path, err := gtk.TreePathNewFromString(fmt.Sprintf("%d", idx))
if err != nil {
return rule, nil, err
}
iter, err := globalLS.GetIter(path)
if err != nil {
return rule, nil, err
}
rule.nrefs, err = lsGetInt(globalLS, iter, 0)
if err != nil {
return rule, nil, err
}
rule.GUID, err = lsGetStr(globalLS, iter, 1)
if err != nil {
return rule, nil, err
}
rule.Path, err = lsGetStr(globalLS, iter, 2)
if err != nil {
return rule, nil, err
}
rule.Icon, err = lsGetStr(globalLS, iter, 3)
if err != nil {
return rule, nil, err
}
rule.Proto, err = lsGetStr(globalLS, iter, 4)
if err != nil {
return rule, nil, err
}
rule.Pid, err = lsGetInt(globalLS, iter, 5)
if err != nil {
return rule, nil, err
}
rule.Target, err = lsGetStr(globalLS, iter, 6)
if err != nil {
return rule, nil, err
}
rule.Hostname, err = lsGetStr(globalLS, iter, 7)
if err != nil {
return rule, nil, err
}
rule.Port, err = lsGetInt(globalLS, iter, 8)
if err != nil {
return rule, nil, err
}
rule.UID, err = lsGetInt(globalLS, iter, 9)
if err != nil {
return rule, nil, err
}
rule.GID, err = lsGetInt(globalLS, iter, 10)
if err != nil {
return rule, nil, err
}
rule.Origin, err = lsGetStr(globalLS, iter, 11)
if err != nil {
return rule, nil, err
}
rule.Timestamp, err = lsGetStr(globalLS, iter, 12)
if err != nil {
return rule, nil, err
}
rule.IsSocks = false
is_socks, err := lsGetInt(globalLS, iter, 13)
if err != nil {
return rule, nil, err
}
if is_socks != 0 {
rule.IsSocks = true
}
rule.Scope, err = lsGetInt(globalLS, iter, 15)
if err != nil {
return rule, nil, err
}
return rule, iter, nil
}
// Needs to be locked by the caller
func getSelectedRule() (ruleColumns, int, error) { func getSelectedRule() (ruleColumns, int, error) {
rule := ruleColumns{} rule := ruleColumns{}
@ -655,62 +882,115 @@ func getSelectedRule() (ruleColumns, int, error) {
} }
fmt.Println("lindex = ", lIndex) fmt.Println("lindex = ", lIndex)
path, err := gtk.TreePathNewFromString(fmt.Sprintf("%d", lIndex)) rule, _, err = getRuleByIdx(lIndex)
if err != nil { if err != nil {
return rule, -1, err return rule, -1, err
} }
iter, err := globalLS.GetIter(path) return rule, lIndex, nil
if err != nil {
return rule, -1, err
} }
rule.Path, err = lsGetStr(globalLS, iter, 1) func addPendingPrompts(rules []string) {
if err != nil {
return rule, -1, err for _, rule := range rules {
fields := strings.Split(rule, "|")
if len(fields) != 19 {
log.Printf("Got saved prompt message with strange data: \"%s\"", rule)
continue
} }
rule.Proto, err = lsGetStr(globalLS, iter, 2) guid := fields[0]
icon := fields[2]
path := fields[3]
address := fields[4]
port, err := strconv.Atoi(fields[5])
if err != nil { if err != nil {
return rule, -1, err log.Println("Error converting port in pending prompt message to integer:", err)
continue
} }
rule.Pid, err = lsGetInt(globalLS, iter, 3) ip := fields[6]
origin := fields[7]
proto := fields[8]
uid, err := strconv.Atoi(fields[9])
if err != nil { if err != nil {
return rule, -1, err log.Println("Error converting UID in pending prompt message to integer:", err)
continue
} }
rule.Target, err = lsGetStr(globalLS, iter, 4) gid, err := strconv.Atoi(fields[10])
if err != nil { if err != nil {
return rule, -1, err log.Println("Error converting GID in pending prompt message to integer:", err)
continue
} }
rule.Hostname, err = lsGetStr(globalLS, iter, 5) pid, err := strconv.Atoi(fields[13])
if err != nil { if err != nil {
return rule, -1, err log.Println("Error converting pid in pending prompt message to integer:", err)
continue
} }
rule.Port, err = lsGetInt(globalLS, iter, 6) sandbox := fields[14]
is_socks, err := strconv.ParseBool(fields[15])
if err != nil { if err != nil {
return rule, -1, err log.Println("Error converting SOCKS flag in pending prompt message to boolean:", err)
continue
} }
rule.UID, err = lsGetInt(globalLS, iter, 7) timestamp := fields[16]
optstring := fields[17]
action, err := strconv.Atoi(fields[18])
if err != nil { if err != nil {
return rule, -1, err log.Println("Error converting action in pending prompt message to integer:", err)
continue
}
addRequestAsync(nil, guid, path, icon, proto, int(pid), ip, address, int(port), int(uid), int(gid), origin, timestamp, is_socks, optstring, sandbox, action)
}
} }
rule.GID, err = lsGetInt(globalLS, iter, 8) func buttonAction(action string) {
globalPromptLock.Lock()
rule, idx, err := getSelectedRule()
if err != nil { if err != nil {
return rule, -1, err globalPromptLock.Unlock()
promptError("Error occurred processing request: " + err.Error())
return
} }
rule.Origin, err = lsGetStr(globalLS, iter, 9) rule, err = createCurrentRule()
if err != nil { if err != nil {
return rule, -1, err globalPromptLock.Unlock()
promptError("Error occurred constructing new rule: " + err.Error())
return
}
fmt.Println("rule = ", rule)
rulestr := action
if action == "ALLOW" && rule.ForceTLS {
rulestr += "_TLSONLY"
}
rulestr += "|" + rule.Proto + ":" + rule.Target + ":" + strconv.Itoa(rule.Port)
rulestr += "|" + sgfw.RuleModeString[sgfw.RuleMode(rule.Scope)]
fmt.Println("RULESTR = ", rulestr)
makeDecision(idx, rulestr, int(rule.Scope))
fmt.Println("Decision made.")
err = removeSelectedRule(idx, true)
globalPromptLock.Unlock()
if err == nil {
clearEditor()
} else {
promptError("Error setting new rule: " + err.Error())
} }
return rule, lIndex, nil
} }
func main() { func main() {
@ -721,6 +1001,11 @@ func main() {
return return
} }
dbuso, err = newDbusObjectAdd()
if err != nil {
log.Fatal("Failed to connect to dbus system bus: %v", err)
}
loadPreferences() loadPreferences()
gtk.Init(nil) gtk.Init(nil)
@ -811,10 +1096,18 @@ func main() {
editbox := get_vbox() editbox := get_vbox()
hbox := get_hbox() hbox := get_hbox()
lbl := get_label("Application path:") lbl := get_label("Application path:")
globalIcon, err = gtk.ImageNew()
if err != nil {
log.Fatal("Unable to create image:", err)
}
// globalIcon.SetFromIconName("firefox", gtk.ICON_SIZE_DND)
editApp = get_entry("") editApp = get_entry("")
editApp.Connect("changed", toggleValidRuleState) editApp.Connect("changed", toggleValidRuleState)
hbox.PackStart(lbl, false, false, 10) hbox.PackStart(lbl, false, false, 10)
hbox.PackStart(editApp, true, true, 50) hbox.PackStart(editApp, true, true, 10)
hbox.PackStart(globalIcon, false, false, 10)
editbox.PackStart(hbox, false, false, 5) editbox.PackStart(hbox, false, false, 5)
hbox = get_hbox() hbox = get_hbox()
@ -842,7 +1135,9 @@ func main() {
radioSession = get_radiobutton(radioOnce, "Session", false) radioSession = get_radiobutton(radioOnce, "Session", false)
radioPermanent = get_radiobutton(radioOnce, "Permanent", false) radioPermanent = get_radiobutton(radioOnce, "Permanent", false)
radioParent.SetSensitive(false) radioParent.SetSensitive(false)
hbox.PackStart(lbl, false, false, 10) chkTLS = get_checkbox("Require TLS", false)
hbox.PackStart(chkTLS, false, false, 10)
hbox.PackStart(lbl, false, false, 20)
hbox.PackStart(radioOnce, false, false, 5) hbox.PackStart(radioOnce, false, false, 5)
hbox.PackStart(radioProcess, false, false, 5) hbox.PackStart(radioProcess, false, false, 5)
hbox.PackStart(radioParent, false, false, 5) hbox.PackStart(radioParent, false, false, 5)
@ -872,94 +1167,55 @@ func main() {
box.PackStart(scrollbox, false, true, 5) box.PackStart(scrollbox, false, true, 5)
tv.AppendColumn(createColumn("#", 0)) tv.AppendColumn(createColumn("#", 0))
tv.AppendColumn(createColumn("Path", 1))
tv.AppendColumn(createColumn("Protocol", 2))
tv.AppendColumn(createColumn("PID", 3))
tv.AppendColumn(createColumn("IP Address", 4))
tv.AppendColumn(createColumn("Hostname", 5))
tv.AppendColumn(createColumn("Port", 6))
tv.AppendColumn(createColumn("UID", 7))
tv.AppendColumn(createColumn("GID", 8))
tv.AppendColumn(createColumn("Origin", 9))
tv.AppendColumn(createColumn("Details", 10))
listStore := createListStore(true) guidcol := createColumn("GUID", 1)
globalLS = listStore guidcol.SetVisible(false)
tv.AppendColumn(guidcol)
tv.SetModel(listStore) tv.AppendColumn(createColumn("Path", 2))
btnApprove.Connect("clicked", func() { icol := createColumn("Icon", 3)
rule, idx, err := getSelectedRule() icol.SetVisible(false)
if err != nil { tv.AppendColumn(icol)
promptError("Error occurred processing request: " + err.Error())
return
}
rule, err = createCurrentRule() tv.AppendColumn(createColumn("Protocol", 4))
if err != nil { tv.AppendColumn(createColumn("PID", 5))
promptError("Error occurred constructing new rule: " + err.Error()) tv.AppendColumn(createColumn("IP Address", 6))
return tv.AppendColumn(createColumn("Hostname", 7))
} tv.AppendColumn(createColumn("Port", 8))
tv.AppendColumn(createColumn("UID", 9))
tv.AppendColumn(createColumn("GID", 10))
tv.AppendColumn(createColumn("Origin", 11))
tv.AppendColumn(createColumn("Timestamp", 12))
fmt.Println("rule = ", rule) scol := createColumn("Is SOCKS", 13)
rulestr := "ALLOW|" + rule.Proto + ":" + rule.Target + ":" + strconv.Itoa(rule.Port) scol.SetVisible(false)
fmt.Println("RULESTR = ", rulestr) tv.AppendColumn(scol)
makeDecision(idx, rulestr, int(rule.Scope))
fmt.Println("Decision made.")
err = removeSelectedRule(idx, true)
if err == nil {
clearEditor()
} else {
promptError("Error setting new rule: " + err.Error())
}
})
btnDeny.Connect("clicked", func() { tv.AppendColumn(createColumn("Details", 14))
rule, idx, err := getSelectedRule()
if err != nil {
promptError("Error occurred processing request: " + err.Error())
return
}
rule, err = createCurrentRule() acol := createColumn("Scope", 15)
if err != nil { acol.SetVisible(false)
promptError("Error occurred constructing new rule: " + err.Error()) tv.AppendColumn(acol)
return
}
fmt.Println("rule = ", rule) listStore := createListStore(true)
rulestr := "DENY|" + rule.Proto + ":" + rule.Target + ":" + strconv.Itoa(rule.Port) globalLS = listStore
fmt.Println("RULESTR = ", rulestr)
makeDecision(idx, rulestr, int(rule.Scope))
fmt.Println("Decision made.")
err = removeSelectedRule(idx, true)
if err == nil {
clearEditor()
} else {
promptError("Error setting new rule: " + err.Error())
}
})
btnIgnore.Connect("clicked", func() { tv.SetModel(listStore)
_, idx, err := getSelectedRule()
if err != nil {
promptError("Error occurred processing request: " + err.Error())
return
}
makeDecision(idx, "", 0) btnApprove.Connect("clicked", func() {
fmt.Println("Decision made.") buttonAction("ALLOW")
err = removeSelectedRule(idx, true)
if err == nil {
clearEditor()
} else {
promptError("Error setting new rule: " + err.Error())
}
}) })
btnDeny.Connect("clicked", func() {
buttonAction("DENY")
})
// btnIgnore.Connect("clicked", buttonAction)
// tv.SetActivateOnSingleClick(true) // tv.SetActivateOnSingleClick(true)
tv.Connect("row-activated", func() { tv.Connect("row-activated", func() {
globalPromptLock.Lock()
seldata, _, err := getSelectedRule() seldata, _, err := getSelectedRule()
globalPromptLock.Unlock()
if err != nil { if err != nil {
promptError("Unexpected error reading selected rule: " + err.Error()) promptError("Unexpected error reading selected rule: " + err.Error())
return return
@ -967,6 +1223,12 @@ func main() {
editApp.SetText(seldata.Path) editApp.SetText(seldata.Path)
if seldata.Icon != "" {
globalIcon.SetFromIconName(seldata.Icon, gtk.ICON_SIZE_DND)
} else {
globalIcon.Clear()
}
if seldata.Hostname != "" { if seldata.Hostname != "" {
editTarget.SetText(seldata.Hostname) editTarget.SetText(seldata.Hostname)
} else { } else {
@ -974,13 +1236,14 @@ func main() {
} }
editPort.SetText(strconv.Itoa(seldata.Port)) editPort.SetText(strconv.Itoa(seldata.Port))
radioOnce.SetActive(true) radioOnce.SetActive(seldata.Scope == int(sgfw.APPLY_ONCE))
radioProcess.SetActive(false)
radioProcess.SetSensitive(seldata.Pid > 0) radioProcess.SetSensitive(seldata.Pid > 0)
radioParent.SetActive(false) radioParent.SetActive(false)
radioSession.SetActive(false) radioSession.SetActive(seldata.Scope == int(sgfw.APPLY_SESSION))
radioPermanent.SetActive(false) radioPermanent.SetActive(seldata.Scope == int(sgfw.APPLY_FOREVER))
comboProto.SetActiveID(seldata.Proto) comboProto.SetActiveID(seldata.Proto)
chkTLS.SetActive(seldata.IsSocks)
if seldata.Uname != "" { if seldata.Uname != "" {
editUser.SetText(seldata.Uname) editUser.SetText(seldata.Uname)
@ -1000,7 +1263,6 @@ func main() {
chkUser.SetActive(false) chkUser.SetActive(false)
chkGroup.SetActive(false) chkGroup.SetActive(false)
return return
}) })
@ -1023,5 +1285,16 @@ func main() {
mainWin.ShowAll() mainWin.ShowAll()
// mainWin.SetKeepAbove(true) // mainWin.SetKeepAbove(true)
var dres = []string{}
call := dbuso.Call("GetPendingRequests", 0, "*")
err = call.Store(&dres)
if err != nil {
errmsg := "Could not query running SGFW instance (maybe it's not running?): " + err.Error()
promptError(errmsg)
} else {
addPendingPrompts(dres)
}
gtk.Main() gtk.Main()
} }

@ -199,6 +199,74 @@ func (ds *dbusServer) DeleteRule(id uint32) *dbus.Error {
return nil return nil
} }
func (ds *dbusServer) GetPendingRequests(policy string) ([]string, *dbus.Error) {
log.Debug("+++ GetPendingRequests()")
ds.fw.lock.Lock()
defer ds.fw.lock.Unlock()
pending_data := make([]string, 0)
for pname := range ds.fw.policyMap {
policy := ds.fw.policyMap[pname]
pqueue := policy.pendingQueue
for _, pc := range pqueue {
addr := pc.hostname()
if addr == "" {
addr = pc.dst().String()
}
dststr := ""
if pc.dst() != nil {
dststr = pc.dst().String()
} else {
dststr = addr + " (via proxy resolver)"
}
pstr := ""
pstr += pc.getGUID() + "|"
pstr += policy.application + "|"
pstr += policy.icon + "|"
pstr += policy.path + "|"
pstr += addr + "|"
pstr += strconv.FormatUint(uint64(pc.dstPort()), 10) + "|"
pstr += dststr + "|"
pstr += pc.src().String() + "|"
pstr += pc.proto() + "|"
pstr += strconv.FormatInt(int64(pc.procInfo().UID), 10) + "|"
pstr += strconv.FormatInt(int64(pc.procInfo().GID), 10) + "|"
pstr += uidToUser(pc.procInfo().UID) + "|"
pstr += gidToGroup(pc.procInfo().GID) + "|"
pstr += strconv.FormatInt(int64(pc.procInfo().Pid), 10) + "|"
pstr += pc.sandbox() + "|"
pstr += strconv.FormatBool(pc.socks()) + "|"
pstr += pc.getTimestamp() + "|"
pstr += pc.getOptString() + "|"
pstr += strconv.FormatUint(uint64(FirewallConfig.DefaultActionID), 10)
pending_data = append(pending_data, pstr)
}
}
return pending_data, nil
}
func (ds *dbusServer) AddRuleAsync(scope uint32, rule string, policy string) (bool, *dbus.Error) {
log.Warningf("AddRuleAsync %v, %v / %v\n", scope, rule, policy)
ds.fw.lock.Lock()
defer ds.fw.lock.Unlock()
prule := PendingRule{rule: rule, scope: int(scope), policy: policy}
for pname := range ds.fw.policyMap {
log.Debug("+++ Adding prule to policy")
ds.fw.policyMap[pname].rulesPending = append(ds.fw.policyMap[pname].rulesPending, prule)
}
return true, nil
}
func (ds *dbusServer) UpdateRule(rule DbusRule) *dbus.Error { func (ds *dbusServer) UpdateRule(rule DbusRule) *dbus.Error {
log.Debugf("UpdateRule %v", rule) log.Debugf("UpdateRule %v", rule)
ds.fw.lock.Lock() ds.fw.lock.Lock()
@ -268,11 +336,6 @@ func (ds *dbusServer) SetConfig(key string, val dbus.Variant) *dbus.Error {
return nil return nil
} }
func (ds *dbusServer) prompt(p *Policy) {
log.Info("prompting...")
ds.prompter.prompt(p)
}
func (ob *dbusObjectP) alertRule(data string) { func (ob *dbusObjectP) alertRule(data string) {
ob.Call("com.subgraph.fwprompt.EventNotifier.Alert", 0, data) ob.Call("com.subgraph.fwprompt.EventNotifier.Alert", 0, data)
} }

@ -6,8 +6,6 @@ import (
"strings" "strings"
"sync" "sync"
// "encoding/binary"
// nfnetlink "github.com/subgraph/go-nfnetlink" // nfnetlink "github.com/subgraph/go-nfnetlink"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
nfqueue "github.com/subgraph/go-nfnetlink/nfqueue" nfqueue "github.com/subgraph/go-nfnetlink/nfqueue"
@ -15,6 +13,7 @@ import (
"net" "net"
"os" "os"
"syscall" "syscall"
"time"
"unsafe" "unsafe"
) )
@ -52,6 +51,10 @@ type pendingConnection interface {
drop() drop()
setPrompting(bool) setPrompting(bool)
getPrompting() bool getPrompting() bool
setPrompter(*prompter)
getPrompter() *prompter
getGUID() string
getTimestamp() string
print() string print() string
} }
@ -62,6 +65,24 @@ type pendingPkt struct {
pinfo *procsnitch.Info pinfo *procsnitch.Info
optstring string optstring string
prompting bool prompting bool
prompter *prompter
guid string
timestamp time.Time
}
/* Not a *REAL* GUID */
func genGUID() string {
frnd, err := os.Open("/dev/urandom")
if err != nil {
log.Fatal("Error reading random data source:", err)
}
rndb := make([]byte, 16)
frnd.Read(rndb)
frnd.Close()
guid := fmt.Sprintf("%x-%x-%x-%x", rndb[0:4], rndb[4:8], rndb[8:12], rndb[12:])
return guid
} }
func getEmptyPInfo() *procsnitch.Info { func getEmptyPInfo() *procsnitch.Info {
@ -79,6 +100,10 @@ func (pp *pendingPkt) sandbox() string {
return pp.pinfo.Sandbox return pp.pinfo.Sandbox
} }
func (pc *pendingPkt) getTimestamp() string {
return pc.timestamp.Format("15:04:05.00")
}
func (pp *pendingPkt) socks() bool { func (pp *pendingPkt) socks() bool {
return false return false
} }
@ -165,6 +190,22 @@ func (pp *pendingPkt) drop() {
pp.pkt.Accept() pp.pkt.Accept()
} }
func (pp *pendingPkt) setPrompter(val *prompter) {
pp.prompter = val
}
func (pp *pendingPkt) getPrompter() *prompter {
return pp.prompter
}
func (pp *pendingPkt) getGUID() string {
if pp.guid == "" {
pp.guid = genGUID()
}
return pp.guid
}
func (pp *pendingPkt) getPrompting() bool { func (pp *pendingPkt) getPrompting() bool {
return pp.prompting return pp.prompting
} }
@ -177,6 +218,12 @@ func (pp *pendingPkt) print() string {
return printPacket(pp.pkt, pp.name, pp.pinfo) return printPacket(pp.pkt, pp.name, pp.pinfo)
} }
type PendingRule struct {
rule string
scope int
policy string
}
type Policy struct { type Policy struct {
fw *Firewall fw *Firewall
path string path string
@ -187,6 +234,7 @@ type Policy struct {
pendingQueue []pendingConnection pendingQueue []pendingConnection
promptInProgress bool promptInProgress bool
lock sync.Mutex lock sync.Mutex
rulesPending []PendingRule
} }
func (fw *Firewall) PolicyForPath(path string) *Policy { func (fw *Firewall) PolicyForPath(path string) *Policy {
@ -240,17 +288,25 @@ func (fw *Firewall) policyForPath(path string) *Policy {
return fw.policyMap[path] return fw.policyMap[path]
} }
func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, pinfo *procsnitch.Info, optstr string) { func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, timestamp time.Time, pinfo *procsnitch.Info, optstr string) {
fmt.Println("policy processPacket()")
/* hbytes, err := pkt.GetHWAddr()
if err != nil {
log.Notice("Failed to get HW address underlying packet: ", err)
} else { log.Notice("got hwaddr: ", hbytes) } */
p.lock.Lock() p.lock.Lock()
defer p.lock.Unlock() defer p.lock.Unlock()
dstb := pkt.Packet.NetworkLayer().NetworkFlow().Dst().Raw() dstb := pkt.Packet.NetworkLayer().NetworkFlow().Dst().Raw()
dstip := net.IP(dstb) dstip := net.IP(dstb)
srcip := net.IP(pkt.Packet.NetworkLayer().NetworkFlow().Src().Raw()) srcip := net.IP(pkt.Packet.NetworkLayer().NetworkFlow().Src().Raw())
/* Can we pass this through quickly? */
/* this probably isn't a performance enhancement. */
/*_, dstp := getPacketPorts(pkt)
fres := p.rules.filter(pkt, srcip, dstip, dstp, dstip.String(), pinfo, optstr)
if fres == FILTER_ALLOW {
fmt.Printf("Packet passed wildcard rules without requiring DNS lookup; accepting: %s:%d\n", dstip, dstp)
pkt.Accept()
return
}*/
name := p.fw.dns.Lookup(dstip, pinfo.Pid) name := p.fw.dns.Lookup(dstip, pinfo.Pid)
log.Infof("Lookup(%s): %s", dstip.String(), name) log.Infof("Lookup(%s): %s", dstip.String(), name)
@ -267,7 +323,7 @@ func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, pinfo *procsnitch.Info, o
case FILTER_ALLOW: case FILTER_ALLOW:
pkt.Accept() pkt.Accept()
case FILTER_PROMPT: case FILTER_PROMPT:
p.processPromptResult(&pendingPkt{pol: p, name: name, pkt: pkt, pinfo: pinfo, optstring: optstr, prompting: false}) p.processPromptResult(&pendingPkt{pol: p, name: name, pkt: pkt, pinfo: pinfo, optstring: optstr, prompter: nil, timestamp: timestamp, prompting: false})
default: default:
log.Warningf("Unexpected filter result: %d", result) log.Warningf("Unexpected filter result: %d", result)
} }
@ -276,33 +332,27 @@ func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, pinfo *procsnitch.Info, o
func (p *Policy) processPromptResult(pc pendingConnection) { func (p *Policy) processPromptResult(pc pendingConnection) {
p.pendingQueue = append(p.pendingQueue, pc) p.pendingQueue = append(p.pendingQueue, pc)
//fmt.Println("processPromptResult(): p.promptInProgress = ", p.promptInProgress) //fmt.Println("processPromptResult(): p.promptInProgress = ", p.promptInProgress)
if DoMultiPrompt || (!DoMultiPrompt && !p.promptInProgress) { //if DoMultiPrompt || (!DoMultiPrompt && !p.promptInProgress) {
// if !p.promptInProgress {
p.promptInProgress = true p.promptInProgress = true
go p.fw.dbus.prompt(p) go p.fw.dbus.prompter.prompt(p)
} // }
} }
func (p *Policy) nextPending() (pendingConnection, bool) { func (p *Policy) nextPending() (pendingConnection, bool) {
p.lock.Lock() p.lock.Lock()
defer p.lock.Unlock() defer p.lock.Unlock()
if !DoMultiPrompt {
if len(p.pendingQueue) == 0 {
return nil, true
}
return p.pendingQueue[0], false
}
if len(p.pendingQueue) == 0 { if len(p.pendingQueue) == 0 {
return nil, true return nil, true
} }
// for len(p.pendingQueue) != 0 {
for i := 0; i < len(p.pendingQueue); i++ { for i := 0; i < len(p.pendingQueue); i++ {
// fmt.Printf("XXX: pendingQueue %v of %v: %v\n", i, len(p.pendingQueue), p.pendingQueue[i])
if !p.pendingQueue[i].getPrompting() { if !p.pendingQueue[i].getPrompting() {
return p.pendingQueue[i], false return p.pendingQueue[i], false
} }
} }
// }
return nil, false return nil, false
} }
@ -329,6 +379,7 @@ func (p *Policy) processNewRule(r *Rule, scope FilterScope) bool {
if scope != APPLY_ONCE { if scope != APPLY_ONCE {
p.rules = append(p.rules, r) p.rules = append(p.rules, r)
} }
fmt.Println("----------------------- processNewRule()")
p.filterPending(r) p.filterPending(r)
if len(p.pendingQueue) == 0 { if len(p.pendingQueue) == 0 {
p.promptInProgress = false p.promptInProgress = false
@ -372,6 +423,15 @@ func (p *Policy) filterPending(rule *Rule) {
remaining := []pendingConnection{} remaining := []pendingConnection{}
for _, pc := range p.pendingQueue { for _, pc := range p.pendingQueue {
if rule.match(pc.src(), pc.dst(), pc.dstPort(), pc.hostname(), pc.proto(), pc.procInfo().UID, pc.procInfo().GID, uidToUser(pc.procInfo().UID), gidToGroup(pc.procInfo().GID), pc.procInfo().Sandbox) { if rule.match(pc.src(), pc.dst(), pc.dstPort(), pc.hostname(), pc.proto(), pc.procInfo().UID, pc.procInfo().GID, uidToUser(pc.procInfo().UID), gidToGroup(pc.procInfo().GID), pc.procInfo().Sandbox) {
prompter := pc.getPrompter()
if prompter == nil {
fmt.Println("-------- prompter = NULL")
} else {
call := prompter.dbusObj.Call("com.subgraph.FirewallPrompt.RemovePrompt", 0, pc.getGUID())
fmt.Println("CAAAAAAAAAAAAAAALL = ", call)
}
log.Infof("Adding rule for: %s", rule.getString(FirewallConfig.LogRedact)) log.Infof("Adding rule for: %s", rule.getString(FirewallConfig.LogRedact))
// log.Noticef("%s > %s", rule.getString(FirewallConfig.LogRedact), pc.print()) // log.Noticef("%s > %s", rule.getString(FirewallConfig.LogRedact), pc.print())
if rule.rtype == RULE_ACTION_ALLOW { if rule.rtype == RULE_ACTION_ALLOW {
@ -442,7 +502,8 @@ func printPacket(pkt *nfqueue.NFQPacket, hostname string, pinfo *procsnitch.Info
return fmt.Sprintf("%s %s %s:%d -> %s:%d", pinfo.ExePath, proto, SrcIp, SrcPort, name, DstPort) return fmt.Sprintf("%s %s %s:%d -> %s:%d", pinfo.ExePath, proto, SrcIp, SrcPort, name, DstPort)
} }
func (fw *Firewall) filterPacket(pkt *nfqueue.NFQPacket) { func (fw *Firewall) filterPacket(pkt *nfqueue.NFQPacket, timestamp time.Time) {
fmt.Println("firewall: filterPacket()")
isudp := pkt.Packet.Layer(layers.LayerTypeUDP) != nil isudp := pkt.Packet.Layer(layers.LayerTypeUDP) != nil
if basicAllowPacket(pkt) { if basicAllowPacket(pkt) {
@ -520,7 +581,7 @@ func (fw *Firewall) filterPacket(pkt *nfqueue.NFQPacket) {
*/ */
policy := fw.PolicyForPathAndSandbox(ppath, pinfo.Sandbox) policy := fw.PolicyForPathAndSandbox(ppath, pinfo.Sandbox)
//log.Notice("XXX: flunked basicallowpacket; policy = ", policy) //log.Notice("XXX: flunked basicallowpacket; policy = ", policy)
policy.processPacket(pkt, pinfo, optstring) policy.processPacket(pkt, timestamp, pinfo, optstring)
} }
func readFileDirect(filename string) ([]byte, error) { func readFileDirect(filename string) ([]byte, error) {

@ -2,24 +2,18 @@ package sgfw
import ( import (
"fmt" "fmt"
"net" "os"
"os/user" "os/user"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"syscall"
"time" "time"
"github.com/godbus/dbus" "github.com/godbus/dbus"
"github.com/subgraph/fw-daemon/proc-coroner" "github.com/subgraph/fw-daemon/proc-coroner"
) )
var DoMultiPrompt = true
const MAX_PROMPTS = 5
var outstandingPrompts = 0
var promptLock = &sync.Mutex{}
func newPrompter(conn *dbus.Conn) *prompter { func newPrompter(conn *dbus.Conn) *prompter {
p := new(prompter) p := new(prompter)
p.cond = sync.NewCond(&p.lock) p.cond = sync.NewCond(&p.lock)
@ -42,6 +36,7 @@ func (p *prompter) prompt(policy *Policy) {
defer p.lock.Unlock() defer p.lock.Unlock()
_, ok := p.policyMap[policy.sandbox+"|"+policy.path] _, ok := p.policyMap[policy.sandbox+"|"+policy.path]
if ok { if ok {
p.cond.Signal()
return return
} }
p.policyMap[policy.sandbox+"|"+policy.path] = policy p.policyMap[policy.sandbox+"|"+policy.path] = policy
@ -51,34 +46,21 @@ func (p *prompter) prompt(policy *Policy) {
} }
func (p *prompter) promptLoop() { func (p *prompter) promptLoop() {
p.lock.Lock() // p.lock.Lock()
for { for {
// fmt.Println("XXX: promptLoop() outer") p.processNextPacket()
for p.processNextPacket() {
// fmt.Println("XXX: promptLoop() inner")
}
// fmt.Println("promptLoop() wait")
p.cond.Wait()
} }
} }
func (p *prompter) processNextPacket() bool { func (p *prompter) processNextPacket() bool {
//fmt.Println("processNextPacket()")
var pc pendingConnection = nil var pc pendingConnection = nil
if !DoMultiPrompt {
pc, _ = p.nextConnection()
if pc == nil {
return false
}
p.lock.Unlock()
defer p.lock.Lock()
p.processConnection(pc)
return true
}
empty := true empty := true
for { for {
p.lock.Lock()
pc, empty = p.nextConnection() pc, empty = p.nextConnection()
p.lock.Unlock()
//fmt.Println("XXX: processNextPacket() loop; empty = ", empty, " / pc = ", pc) //fmt.Println("XXX: processNextPacket() loop; empty = ", empty, " / pc = ", pc)
if pc == nil && empty { if pc == nil && empty {
return false return false
@ -88,49 +70,150 @@ func (p *prompter) processNextPacket() bool {
break break
} }
} }
p.lock.Unlock()
defer p.lock.Lock()
// fmt.Println("XXX: Waiting for prompt lock go...")
for {
promptLock.Lock()
if outstandingPrompts >= MAX_PROMPTS {
promptLock.Unlock()
continue
}
if pc.getPrompting() { if pc.getPrompting() {
log.Debugf("Skipping over already prompted connection") log.Debugf("Skipping over already prompted connection")
promptLock.Unlock() return false
continue
} }
break
}
// fmt.Println("XXX: Passed prompt lock!")
outstandingPrompts++
// fmt.Println("XXX: Incremented outstanding to ", outstandingPrompts)
promptLock.Unlock()
// if !pc.getPrompting() {
pc.setPrompting(true) pc.setPrompting(true)
fmt.Println("processConnection")
go p.processConnection(pc) go p.processConnection(pc)
// }
return true return true
} }
func processReturn(pc pendingConnection) { type PC2FDMapping struct {
promptLock.Lock() guid string
outstandingPrompts-- inode uint64
// fmt.Println("XXX: Return decremented outstanding to ", outstandingPrompts) fd int
promptLock.Unlock() fdpath string
pc.setPrompting(false) prompter *prompter
}
var PC2FDMap = map[string]PC2FDMapping{}
var PC2FDMapLock = &sync.Mutex{}
var PC2FDMapRunning = false
func monitorPromptFDs(pc pendingConnection) {
guid := pc.getGUID()
pid := pc.procInfo().Pid
inode := pc.procInfo().Inode
fd := pc.procInfo().FD
prompter := pc.getPrompter()
fmt.Printf("ADD TO MONITOR: %v | %v / %v / %v\n", pc.policy().application, guid, pid, fd)
if pid == -1 || fd == -1 || prompter == nil {
log.Warning("Unexpected error condition occurred while adding socket fd to monitor")
return
}
PC2FDMapLock.Lock()
defer PC2FDMapLock.Unlock()
fdpath := fmt.Sprintf("/proc/%d/fd/%d", pid, fd)
PC2FDMap[guid] = PC2FDMapping{guid: guid, inode: inode, fd: fd, fdpath: fdpath, prompter: prompter}
return
}
func monitorPromptFDLoop() {
fmt.Println("++++++++++= monitorPromptFDLoop()")
for true {
delete_guids := []string{}
PC2FDMapLock.Lock()
fmt.Println("++++ nentries = ", len(PC2FDMap))
for guid, fdmon := range PC2FDMap {
fmt.Println("ENTRY:", fdmon)
lsb, err := os.Stat(fdmon.fdpath)
if err != nil {
log.Warningf("Error looking up socket \"%s\": %v\n", fdmon.fdpath, err)
delete_guids = append(delete_guids, guid)
continue
}
sb, ok := lsb.Sys().(*syscall.Stat_t)
if !ok {
log.Warning("Not a syscall.Stat_t")
delete_guids = append(delete_guids, guid)
continue
}
inode := sb.Ino
fmt.Println("+++ INODE = ", inode)
if inode != fdmon.inode {
fmt.Printf("inode mismatch: %v vs %v\n", inode, fdmon.inode)
delete_guids = append(delete_guids, guid)
}
}
fmt.Println("guids to delete: ", delete_guids)
saved_mappings := []PC2FDMapping{}
for _, guid := range delete_guids {
saved_mappings = append(saved_mappings, PC2FDMap[guid])
delete(PC2FDMap, guid)
}
PC2FDMapLock.Unlock()
for _, mapping := range saved_mappings {
call := mapping.prompter.dbusObj.Call("com.subgraph.FirewallPrompt.RemovePrompt", 0, mapping.guid)
fmt.Println("DISPOSING CALL = ", call)
prompter := mapping.prompter
prompter.lock.Lock()
for _, policy := range prompter.policyQueue {
policy.lock.Lock()
pcind := 0
for pcind < len(policy.pendingQueue) {
if policy.pendingQueue[pcind].getGUID() == mapping.guid {
fmt.Println("-------------- found guid to remove")
policy.pendingQueue = append(policy.pendingQueue[:pcind], policy.pendingQueue[pcind+1:]...)
} else {
pcind++
}
}
policy.lock.Unlock()
}
prompter.lock.Unlock()
}
fmt.Println("++++++++++= monitorPromptFDLoop WAIT")
time.Sleep(5 * time.Second)
}
} }
func (p *prompter) processConnection(pc pendingConnection) { func (p *prompter) processConnection(pc pendingConnection) {
var scope int32 var scope int32
var dres bool
var rule string var rule string
if DoMultiPrompt { if !PC2FDMapRunning {
defer processReturn(pc) PC2FDMapLock.Lock()
if !PC2FDMapRunning {
PC2FDMapRunning = true
PC2FDMapLock.Unlock()
go monitorPromptFDLoop()
} else {
PC2FDMapLock.Unlock()
}
}
if pc.getPrompter() == nil {
pc.setPrompter(p)
} }
addr := pc.hostname() addr := pc.hostname()
@ -144,10 +227,12 @@ func (p *prompter) processConnection(pc pendingConnection) {
if pc.dst() != nil { if pc.dst() != nil {
dststr = pc.dst().String() dststr = pc.dst().String()
} else { } else {
dststr = addr + " (proxy to resolve)" dststr = addr + " (via proxy resolver)"
} }
call := p.dbusObj.Call("com.subgraph.FirewallPrompt.RequestPrompt", 0, monitorPromptFDs(pc)
call := p.dbusObj.Call("com.subgraph.FirewallPrompt.RequestPromptAsync", 0,
pc.getGUID(),
policy.application, policy.application,
policy.icon, policy.icon,
policy.path, policy.path,
@ -163,18 +248,58 @@ func (p *prompter) processConnection(pc pendingConnection) {
int32(pc.procInfo().Pid), int32(pc.procInfo().Pid),
pc.sandbox(), pc.sandbox(),
pc.socks(), pc.socks(),
pc.getTimestamp(),
pc.getOptString(), pc.getOptString(),
FirewallConfig.PromptExpanded, FirewallConfig.PromptExpanded,
FirewallConfig.PromptExpert, FirewallConfig.PromptExpert,
int32(FirewallConfig.DefaultActionID)) int32(FirewallConfig.DefaultActionID))
err := call.Store(&scope, &rule)
err := call.Store(&dres)
if err != nil { if err != nil {
log.Warningf("Error sending dbus RequestPrompt message: %v", err) log.Warningf("Error sending dbus async RequestPrompt message: %v", err)
policy.removePending(pc) policy.removePending(pc)
pc.drop() pc.drop()
return return
} }
if !dres {
fmt.Println("Unexpected: fw-prompt async RequestPrompt message returned:", dres)
}
return
/* p.dbusObj.Go("com.subgraph.FirewallPrompt.RequestPrompt", 0, callChan,
pc.getGUID(),
policy.application,
policy.icon,
policy.path,
addr,
int32(pc.dstPort()),
dststr,
pc.src().String(),
pc.proto(),
int32(pc.procInfo().UID),
int32(pc.procInfo().GID),
uidToUser(pc.procInfo().UID),
gidToGroup(pc.procInfo().GID),
int32(pc.procInfo().Pid),
pc.sandbox(),
pc.socks(),
pc.getOptString(),
FirewallConfig.PromptExpanded,
FirewallConfig.PromptExpert,
int32(FirewallConfig.DefaultActionID))
saveChannel(callChan, false, true)
/* err := call.Store(&scope, &rule)
if err != nil {
log.Warningf("Error sending dbus RequestPrompt message: %v", err)
policy.removePending(pc)
pc.drop()
return
} */
// the prompt sends: // the prompt sends:
// ALLOW|dest or DENY|dest // ALLOW|dest or DENY|dest
// //
@ -194,17 +319,19 @@ func (p *prompter) processConnection(pc pendingConnection) {
} }
tempRule := fmt.Sprintf("%s|%s", toks[0], toks[1]) tempRule := fmt.Sprintf("%s|%s", toks[0], toks[1])
tempRule += "||-1:-1|" + sandbox + "|"
if pc.src() != nil && !pc.src().Equal(net.ParseIP("127.0.0.1")) && sandbox != "" { if pc.src() != nil && !pc.src().IsLoopback() && sandbox != "" {
//if !strings.HasSuffix(rule, "SYSTEM") && !strings.HasSuffix(rule, "||") { //if !strings.HasSuffix(rule, "SYSTEM") && !strings.HasSuffix(rule, "||") {
//rule += "||" //rule += "||"
//} //}
//ule += "|||" + pc.src().String() //ule += "|||" + pc.src().String()
tempRule += "||-1:-1|" + sandbox + "|" + pc.src().String() // tempRule += "||-1:-1|" + sandbox + "|" + pc.src().String()
tempRule += pc.src().String()
} else { } else {
tempRule += "||-1:-1|" + sandbox + "|" // tempRule += "||-1:-1|" + sandbox + "|"
} }
r, err := policy.parseRule(tempRule, false) r, err := policy.parseRule(tempRule, false)
if err != nil { if err != nil {
@ -235,35 +362,109 @@ func (p *prompter) processConnection(pc pendingConnection) {
} }
func (p *prompter) nextConnection() (pendingConnection, bool) { func (p *prompter) nextConnection() (pendingConnection, bool) {
for { pind := 0
if len(p.policyQueue) == 0 { if len(p.policyQueue) == 0 {
return nil, true return nil, true
} }
policy := p.policyQueue[0] fmt.Println("policy queue len = ", len(p.policyQueue))
for pind < len(p.policyQueue) {
//fmt.Printf("XXX: pind = %v of %v\n", pind, len(p.policyQueue))
policy := p.policyQueue[pind]
pc, qempty := policy.nextPending() pc, qempty := policy.nextPending()
if pc == nil && qempty { if pc == nil && qempty {
p.removePolicy(policy) p.removePolicy(policy)
continue
} else {
pind++
// if pc == nil && !qempty {
if len(policy.rulesPending) > 0 {
fmt.Println("policy rules pending = ", len(policy.rulesPending))
prule := policy.rulesPending[0]
policy.rulesPending = append(policy.rulesPending[:0], policy.rulesPending[1:]...)
toks := strings.Split(prule.rule, "|")
sandbox := ""
if len(toks) > 2 {
sandbox = toks[2]
}
sandbox += ""
tempRule := fmt.Sprintf("%s|%s", toks[0], toks[1])
tempRule += "||-1:-1|" + sandbox + "|"
/*if pc.src() != nil && !pc.src().IsLoopback() && sandbox != "" {
tempRule += "||-1:-1|" + sandbox + "|" + pc.src().String()
} else { } else {
tempRule += "||-1:-1|" + sandbox + "|"
}*/
r, err := policy.parseRule(tempRule, false)
if err != nil {
log.Warningf("Error parsing rule string returned from dbus RequestPrompt: %v", err)
continue
// policy.removePending(pc)
// pc.drop()
// return
} else {
fscope := FilterScope(prule.scope)
if fscope == APPLY_SESSION {
r.mode = RULE_MODE_SESSION
} else if fscope == APPLY_PROCESS {
r.mode = RULE_MODE_PROCESS
/*r.pid = pc.procInfo().Pid
pcoroner.MonitorProcess(r.pid)*/
}
if !policy.processNewRule(r, fscope) {
// p.lock.Lock()
// defer p.lock.Unlock()
// p.removePolicy(pc.policy())
}
if fscope == APPLY_FOREVER {
r.mode = RULE_MODE_PERMANENT
policy.fw.saveRules()
}
log.Warningf("Prompt returning rule: %v", tempRule)
dbusp.alertRule("sgfw prompt added new rule")
}
}
if pc == nil && !qempty { if pc == nil && !qempty {
log.Errorf("FIX ME: I NEED TO SLEEP ON A WAKEABLE CONDITION PROPERLY!!") // log.Errorf("FIX ME: I NEED TO SLEEP ON A WAKEABLE CONDITION PROPERLY!!")
time.Sleep(time.Millisecond * 300) time.Sleep(time.Millisecond * 300)
continue
} }
if pc != nil && pc.getPrompting() {
fmt.Println("SKIPPING PROMPTED")
continue
}
return pc, qempty return pc, qempty
} }
} }
return nil, true
} }
func (p *prompter) removePolicy(policy *Policy) { func (p *prompter) removePolicy(policy *Policy) {
var newQueue []*Policy = nil var newQueue []*Policy = nil
if DoMultiPrompt { // if DoMultiPrompt {
if len(p.policyQueue) == 0 { if len(p.policyQueue) == 0 {
log.Debugf("Skipping over zero length policy queue") log.Debugf("Skipping over zero length policy queue")
newQueue = make([]*Policy, 0, 0) newQueue = make([]*Policy, 0, 0)
} }
} // }
if !DoMultiPrompt || newQueue == nil { // if !DoMultiPrompt || newQueue == nil {
if newQueue == nil {
newQueue = make([]*Policy, 0, len(p.policyQueue)-1) newQueue = make([]*Policy, 0, len(p.policyQueue)-1)
} }
for _, pol := range p.policyQueue { for _, pol := range p.policyQueue {
@ -277,11 +478,17 @@ func (p *prompter) removePolicy(policy *Policy) {
var userMap = make(map[int]string) var userMap = make(map[int]string)
var groupMap = make(map[int]string) var groupMap = make(map[int]string)
var userMapLock = &sync.Mutex{}
var groupMapLock = &sync.Mutex{}
func lookupUser(uid int) string { func lookupUser(uid int) string {
if uid == -1 { if uid == -1 {
return "[unknown]" return "[unknown]"
} }
userMapLock.Lock()
defer userMapLock.Unlock()
u, err := user.LookupId(strconv.Itoa(uid)) u, err := user.LookupId(strconv.Itoa(uid))
if err != nil { if err != nil {
return fmt.Sprintf("%d", uid) return fmt.Sprintf("%d", uid)
@ -293,6 +500,10 @@ func lookupGroup(gid int) string {
if gid == -1 { if gid == -1 {
return "[unknown]" return "[unknown]"
} }
groupMapLock.Lock()
defer groupMapLock.Unlock()
g, err := user.LookupGroupId(strconv.Itoa(gid)) g, err := user.LookupGroupId(strconv.Itoa(gid))
if err != nil { if err != nil {
return fmt.Sprintf("%d", gid) return fmt.Sprintf("%d", gid)

@ -479,7 +479,10 @@ func (fw *Firewall) loadRules() {
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
log.Warningf("Failed to open %s for reading: %v", p, err) log.Warningf("Failed to open %s for reading: %v", p, err)
} else {
log.Warningf("Did not find a rules file at %s: SGFW loaded with no rules\n", p)
} }
return return
} }
var policy *Policy var policy *Policy
@ -497,7 +500,13 @@ func (fw *Firewall) loadRules() {
func (fw *Firewall) processPathLine(line string) *Policy { func (fw *Firewall) processPathLine(line string) *Policy {
pathLine := line[1 : len(line)-1] pathLine := line[1 : len(line)-1]
toks := strings.Split(pathLine, "|") toks := strings.Split(pathLine, "|")
if len(toks) != 2 {
log.Warning("Error parsing rules directive:", line)
return nil
}
policy := fw.policyForPathAndSandbox(toks[1], toks[0]) policy := fw.policyForPathAndSandbox(toks[1], toks[0])
policy.lock.Lock() policy.lock.Lock()
defer policy.lock.Unlock() defer policy.lock.Unlock()

@ -1,16 +1,16 @@
package sgfw package sgfw
import ( import (
"bufio"
"encoding/json"
"fmt"
"os" "os"
"os/signal" "os/signal"
"regexp" "regexp"
"strings"
"sync" "sync"
"syscall" "syscall"
// "time" "time"
"bufio"
"encoding/json"
"fmt"
"strings"
"github.com/op/go-logging" "github.com/op/go-logging"
nfqueue "github.com/subgraph/go-nfnetlink/nfqueue" nfqueue "github.com/subgraph/go-nfnetlink/nfqueue"
@ -110,6 +110,8 @@ func (fw *Firewall) runFilter() {
go func() { go func() {
for p := range ps { for p := range ps {
timestamp := time.Now()
if fw.isEnabled() { if fw.isEnabled() {
ipLayer := p.Packet.Layer(layers.LayerTypeIPv4) ipLayer := p.Packet.Layer(layers.LayerTypeIPv4)
if ipLayer == nil { if ipLayer == nil {
@ -127,7 +129,7 @@ func (fw *Firewall) runFilter() {
} }
fw.filterPacket(p) fw.filterPacket(p, timestamp)
} else { } else {
p.Accept() p.Accept()
} }

@ -56,7 +56,10 @@ type pendingSocksConnection struct {
pinfo *procsnitch.Info pinfo *procsnitch.Info
verdict chan int verdict chan int
prompting bool prompting bool
prompter *prompter
guid string
optstr string optstr string
timestamp time.Time
} }
func (sc *pendingSocksConnection) sandbox() string { func (sc *pendingSocksConnection) sandbox() string {
@ -107,23 +110,58 @@ func (sc *pendingSocksConnection) deliverVerdict(v int) {
} }
}() }()
if sc.verdict != nil {
sc.verdict <- v sc.verdict <- v
close(sc.verdict) close(sc.verdict)
sc.verdict = nil
}
} }
func (sc *pendingSocksConnection) accept() { sc.deliverVerdict(socksVerdictAccept) } func (sc *pendingSocksConnection) accept() {
sc.deliverVerdict(socksVerdictAccept)
}
// need to generalize special accept // need to generalize special accept
func (sc *pendingSocksConnection) acceptTLSOnly() { sc.deliverVerdict(socksVerdictAcceptTLSOnly) } func (sc *pendingSocksConnection) acceptTLSOnly() {
sc.deliverVerdict(socksVerdictAcceptTLSOnly)
}
func (sc *pendingSocksConnection) drop() { sc.deliverVerdict(socksVerdictDrop) } func (sc *pendingSocksConnection) drop() {
sc.deliverVerdict(socksVerdictDrop)
}
func (sc *pendingSocksConnection) getPrompting() bool { return sc.prompting } func (sc *pendingSocksConnection) setPrompter(val *prompter) {
sc.prompter = val
}
func (sc *pendingSocksConnection) setPrompting(val bool) { sc.prompting = val } func (sc *pendingSocksConnection) getPrompter() *prompter {
return sc.prompter
}
func (sc *pendingSocksConnection) getTimestamp() string {
return sc.timestamp.Format("15:04:05.00")
}
func (sc *pendingSocksConnection) print() string { return "socks connection" } func (sc *pendingSocksConnection) getGUID() string {
if sc.guid == "" {
sc.guid = genGUID()
}
return sc.guid
}
func (sc *pendingSocksConnection) getPrompting() bool {
return sc.prompting
}
func (sc *pendingSocksConnection) setPrompting(val bool) {
sc.prompting = val
}
func (sc *pendingSocksConnection) print() string {
return "socks connection"
}
func NewSocksChain(cfg *socksChainConfig, wg *sync.WaitGroup, fw *Firewall) *socksChain { func NewSocksChain(cfg *socksChainConfig, wg *sync.WaitGroup, fw *Firewall) *socksChain {
chain := socksChain{ chain := socksChain{
@ -146,10 +184,11 @@ func (s *socksChain) start() {
} }
s.wg.Add(1) s.wg.Add(1)
go s.socksAcceptLoop() ts := time.Now()
go s.socksAcceptLoop(ts)
} }
func (s *socksChain) socksAcceptLoop() error { func (s *socksChain) socksAcceptLoop(timestamp time.Time) error {
defer s.wg.Done() defer s.wg.Done()
defer s.listener.Close() defer s.listener.Close()
@ -163,11 +202,11 @@ func (s *socksChain) socksAcceptLoop() error {
continue continue
} }
session := &socksChainSession{cfg: s.cfg, clientConn: conn, procInfo: s.procInfo, server: s} session := &socksChainSession{cfg: s.cfg, clientConn: conn, procInfo: s.procInfo, server: s}
go session.sessionWorker() go session.sessionWorker(timestamp)
} }
} }
func (c *socksChainSession) sessionWorker() { func (c *socksChainSession) sessionWorker(timestamp time.Time) {
defer c.clientConn.Close() defer c.clientConn.Close()
clientAddr := c.clientConn.RemoteAddr() clientAddr := c.clientConn.RemoteAddr()
@ -197,7 +236,7 @@ func (c *socksChainSession) sessionWorker() {
c.req.ReplyAddr(ReplySucceeded, c.bndAddr) c.req.ReplyAddr(ReplySucceeded, c.bndAddr)
} }
case CommandConnect: case CommandConnect:
verdict, tls := c.filterConnect() verdict, tls := c.filterConnect(timestamp)
if !verdict { if !verdict {
c.req.Reply(ReplyConnectionRefused) c.req.Reply(ReplyConnectionRefused)
@ -278,7 +317,7 @@ func findProxyEndpoint(pdata []string, conn net.Conn) (*procsnitch.Info, string)
return nil, "" return nil, ""
} }
func (c *socksChainSession) filterConnect() (bool, bool) { func (c *socksChainSession) filterConnect(timestamp time.Time) (bool, bool) {
// return filter verdict, tlsguard // return filter verdict, tlsguard
allProxies, err := ListProxies() allProxies, err := ListProxies()
@ -364,7 +403,9 @@ func (c *socksChainSession) filterConnect() (bool, bool) {
pinfo: pinfo, pinfo: pinfo,
verdict: make(chan int), verdict: make(chan int),
prompting: false, prompting: false,
prompter: nil,
optstr: optstr, optstr: optstr,
timestamp: timestamp,
} }
policy.processPromptResult(pending) policy.processPromptResult(pending)
v := <-pending.verdict v := <-pending.verdict
@ -409,7 +450,7 @@ func (c *socksChainSession) forwardTraffic(tls bool) {
if c.pinfo.Sandbox != "" { if c.pinfo.Sandbox != "" {
log.Errorf("TLSGuard violation: Dropping traffic from %s (sandbox: %s) to %s: %v", c.pinfo.ExePath, c.pinfo.Sandbox, c.req.Addr.addrStr, err) log.Errorf("TLSGuard violation: Dropping traffic from %s (sandbox: %s) to %s: %v", c.pinfo.ExePath, c.pinfo.Sandbox, c.req.Addr.addrStr, err)
} else { } else {
log.Errorf("TLSGuard violation: Dropping traffic from %s (unsandboxed) to %s: %v", c.pinfo.ExePath, c.req.Addr.addrStr, err) log.Errorf("TLSGuard violation: Dropping traffic from %s (un-sandboxed) to %s: %v", c.pinfo.ExePath, c.req.Addr.addrStr, err)
} }
return return
} else { } else {

@ -2,6 +2,8 @@ package sgfw
import ( import (
"crypto/x509" "crypto/x509"
"encoding/binary"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -9,126 +11,479 @@ import (
"time" "time"
) )
const TLSGUARD_READ_TIMEOUT = 2 * time.Second const TLSGUARD_READ_TIMEOUT = 10 * time.Second
const TLSGUARD_MIN_TLS_VER_MAJ = 3 const TLSGUARD_MIN_TLS_VER_MAJ = 3
const TLSGUARD_MIN_TLS_VER_MIN = 1 const TLSGUARD_MIN_TLS_VER_MIN = 1
const TLS_RECORD_HDR_LEN = 5
const SSL3_RT_CHANGE_CIPHER_SPEC = 20 const SSL3_RT_CHANGE_CIPHER_SPEC = 20
const SSL3_RT_ALERT = 21 const SSL3_RT_ALERT = 21
const SSL3_RT_HANDSHAKE = 22 const SSL3_RT_HANDSHAKE = 22
const SSL3_RT_APPLICATION_DATA = 23 const SSL3_RT_APPLICATION_DATA = 23
const SSL3_MT_HELLO_REQUEST = 0
const SSL3_MT_CLIENT_HELLO = 1
const SSL3_MT_SERVER_HELLO = 2 const SSL3_MT_SERVER_HELLO = 2
const SSL3_MT_CERTIFICATE = 11 const SSL3_MT_CERTIFICATE = 11
const SSL3_MT_CERTIFICATE_REQUEST = 13 const SSL3_MT_CERTIFICATE_REQUEST = 13
const SSL3_MT_SERVER_DONE = 14 const SSL3_MT_SERVER_DONE = 14
const SSL3_MT_CERTIFICATE_STATUS = 22
const SSL3_AL_WARNING = 1
const SSL3_AL_FATAL = 2
const SSL3_AD_CLOSE_NOTIFY = 0
const SSL3_AD_UNEXPECTED_MESSAGE = 10
const SSL3_AD_BAD_RECORD_MAC = 20
const TLS1_AD_DECRYPTION_FAILED = 21
const TLS1_AD_RECORD_OVERFLOW = 22
const SSL3_AD_DECOMPRESSION_FAILURE = 30
const SSL3_AD_HANDSHAKE_FAILURE = 40
const SSL3_AD_NO_CERTIFICATE = 41
const SSL3_AD_BAD_CERTIFICATE = 42
const SSL3_AD_UNSUPPORTED_CERTIFICATE = 43
const SSL3_AD_CERTIFICATE_REVOKED = 44
const SSL3_AD_CERTIFICATE_EXPIRED = 45
const SSL3_AD_CERTIFICATE_UNKNOWN = 46
const SSL3_AD_ILLEGAL_PARAMETER = 47
const TLS1_AD_UNKNOWN_CA = 48
const TLS1_AD_ACCESS_DENIED = 49
const TLS1_AD_DECODE_ERROR = 50
const TLS1_AD_DECRYPT_ERROR = 51
const TLS1_AD_EXPORT_RESTRICTION = 60
const TLS1_AD_PROTOCOL_VERSION = 70
const TLS1_AD_INSUFFICIENT_SECURITY = 71
const TLS1_AD_INTERNAL_ERROR = 80
const TLS1_AD_INAPPROPRIATE_FALLBACK = 86
const TLS1_AD_USER_CANCELLED = 90
const TLS1_AD_NO_RENEGOTIATION = 100
const TLS1_AD_UNSUPPORTED_EXTENSION = 110
const TLSEXT_TYPE_server_name = 1
const TLSEXT_TYPE_signature_algorithms = 13
const TLSEXT_TYPE_client_certificate_type = 19
const TLSEXT_TYPE_extended_master_secret = 23
const TLSEXT_TYPE_renegotiate = 0xff01
type connReader struct {
client bool
data []byte
rtype int
err error
}
var cipherSuiteMap map[uint16]string = map[uint16]string{
0x0000: "TLS_NULL_WITH_NULL_NULL",
0x000a: "TLS_RSA_WITH_3DES_EDE_CBC_SHA",
0x002f: "TLS_RSA_WITH_AES_128_CBC_SHA",
0x0033: "TLS_DHE_RSA_WITH_AES_128_CBC_SHA",
0x0039: "TLS_DHE_RSA_WITH_AES_256_CBC_SHA",
0x0035: "TLS_RSA_WITH_AES_256_CBC_SHA",
0x0030: "TLS_DH_DSS_WITH_AES_128_CBC_SHA",
0xc009: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA",
0xc00a: "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA",
0xc013: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
0xc014: "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
0xc02b: "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
0xc02c: "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
0xc02f: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
0xc030: "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
0xcca9: "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
0xcca8: "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
}
func getCipherSuiteName(value uint) string {
val, ok := cipherSuiteMap[uint16(value)]
if !ok {
return "UNKNOWN"
}
return val
}
func readTLSChunk(conn net.Conn) ([]byte, int, error) { func connectionReader(conn net.Conn, is_client bool, c chan connReader, done chan bool) {
var ret_error error = nil
buffered := []byte{}
mlen := 0
rtype := 0
stage := 1
for {
if ret_error != nil {
cr := connReader{client: is_client, data: nil, rtype: 0, err: ret_error}
c <- cr
break
}
select {
case <-done:
fmt.Println("++ DONE: ", is_client)
if len(buffered) > 0 {
//fmt.Println("++ DONE BUT DISPOSING OF BUFFERED DATA")
c <- connReader{client: is_client, data: buffered, rtype: 0, err: nil}
}
c <- connReader{client: is_client, data: nil, rtype: 0, err: nil}
return
default:
if stage == 1 {
header := make([]byte, TLS_RECORD_HDR_LEN)
conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT)) conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT))
cbytes, err := readNBytes(conn, 5) _, err := io.ReadFull(conn, header)
conn.SetReadDeadline(time.Time{}) conn.SetReadDeadline(time.Time{})
if err != nil { if err != nil {
log.Errorf("TLS data chunk read failure: ", err) ret_error = err
return nil, 0, err continue
} }
if int(cbytes[1]) < TLSGUARD_MIN_TLS_VER_MAJ { if int(header[1]) < TLSGUARD_MIN_TLS_VER_MAJ {
return nil, 0, errors.New("TLS protocol major version less than expected minimum") ret_error = errors.New("TLS protocol major version less than expected minimum")
} else if int(cbytes[2]) < TLSGUARD_MIN_TLS_VER_MIN { continue
return nil, 0, errors.New("TLS protocol minor version less than expected minimum") } else if int(header[2]) < TLSGUARD_MIN_TLS_VER_MIN {
ret_error = errors.New("TLS protocol minor version less than expected minimum")
continue
} }
cbyte := cbytes[0] rtype = int(header[0])
mlen := int(int(cbytes[3])<<8 | int(cbytes[4])) mlen = int(int(header[3])<<8 | int(header[4]))
// fmt.Printf("TLS data chunk header read: type = %#x, maj = %v, min = %v, len = %v\n", cbyte, cbytes[1], cbytes[2], mlen) fmt.Printf("TLS data chunk header read: type = %#x, maj = %v, min = %v, len = %v\n", rtype, header[1], header[2], mlen)
/* 16384+1024 if compression is not null */
/* or 16384+2048 if ciphertext */
if mlen > 16384 {
ret_error = errors.New(fmt.Sprintf("TLSGuard read TLS plaintext record of excessively large length; dropping (%v bytes)", mlen))
continue
}
buffered = header
stage++
} else if stage == 2 {
remainder := make([]byte, mlen)
conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT)) conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT))
cbytes2, err := readNBytes(conn, mlen) _, err := io.ReadFull(conn, remainder)
conn.SetReadDeadline(time.Time{}) conn.SetReadDeadline(time.Time{})
if err != nil { if err != nil {
return nil, 0, err ret_error = err
continue
}
buffered = append(buffered, remainder...)
fmt.Printf("------- CHUNK READ: client: %v, err = %v, bytes = %v\n", is_client, err, len(buffered))
cr := connReader{client: is_client, data: buffered, rtype: rtype, err: err}
c <- cr
buffered = []byte{}
rtype = 0
mlen = 0
stage = 1
}
}
} }
cbytes = append(cbytes, cbytes2...)
return cbytes, int(cbyte), nil
} }
func TLSGuard(conn, conn2 net.Conn, fqdn string) error { func TLSGuard(conn, conn2 net.Conn, fqdn string) error {
x509Valid := false
ndone := 0
// Should this be a requirement? // Should this be a requirement?
// if strings.HasSuffix(request.DestAddr.FQDN, "onion") { // if strings.HasSuffix(request.DestAddr.FQDN, "onion") {
//conn client //conn client
//conn2 server //conn2 server
// Read the opening message from the client fmt.Println("-------- STARTING HANDSHAKE LOOP")
chunk, rtype, err := readTLSChunk(conn) crChan := make(chan connReader)
if err != nil { dChan := make(chan bool, 10)
return err go connectionReader(conn, true, crChan, dChan)
go connectionReader(conn2, false, crChan, dChan)
client_expected := SSL3_MT_CLIENT_HELLO
server_expected := SSL3_MT_SERVER_HELLO
select_loop:
for {
if ndone == 2 {
fmt.Println("DONE channel got both notifications. Terminating loop.")
close(dChan)
close(crChan)
break
} }
if rtype != SSL3_RT_HANDSHAKE { select {
return errors.New("Blocked client from attempting non-TLS connection") case cr := <-crChan:
other := conn
if cr.client {
other = conn2
} }
// Pass it on through to the server fmt.Printf("++++ SELECT: %v, %v, %v\n", cr.client, cr.err, len(cr.data))
conn2.Write(chunk) if cr.err == nil && cr.data == nil {
fmt.Println("DONE channel notification received")
ndone++
continue
}
// Read ServerHello if cr.err == nil {
valid := false if cr.rtype == SSL3_RT_CHANGE_CIPHER_SPEC || cr.rtype == SSL3_RT_APPLICATION_DATA ||
loop := 1 cr.rtype == SSL3_RT_ALERT {
passthru := false /* We expect only a single byte of data */
if cr.rtype == SSL3_RT_CHANGE_CIPHER_SPEC {
fmt.Println("CHANGE CIPHER_SPEC: ", cr.data[TLS_RECORD_HDR_LEN])
if len(cr.data) != 6 {
return errors.New(fmt.Sprintf("TLSGuard dropped connection with strange change cipher spec data length (%v bytes)", len(cr.data)))
}
if cr.data[TLS_RECORD_HDR_LEN] != 1 {
return errors.New(fmt.Sprintf("TLSGuard dropped connection with strange change cipher spec data (%#x bytes)", cr.data[TLS_RECORD_HDR_LEN]))
}
} else if cr.rtype == SSL3_RT_ALERT {
if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_WARNING {
fmt.Println("SSL ALERT TYPE: warning")
} else if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_FATAL {
fmt.Println("SSL ALERT TYPE: fatal")
} else {
fmt.Println("SSL ALERT TYPE UNKNOWN")
}
for 1 == 1 { alert_desc := int(int(cr.data[6])<<8 | int(cr.data[7]))
loop++ fmt.Println("ALERT DESCRIPTION: ", alert_desc)
// fmt.Printf("SSL LOOP %v; trying to read: conn2\n", loop) if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_FATAL {
chunk, rtype, err = readTLSChunk(conn2) return errors.New(fmt.Sprintf("TLSGuard dropped connection after fatal error alert detected"))
} else if alert_desc == SSL3_AD_CLOSE_NOTIFY {
return errors.New(fmt.Sprintf("TLSGuard dropped connection after close_notify alert detected"))
}
if err != nil { }
log.Debugf("TLSGUARD: OTHER loop %v: trying to read: conn\n", loop)
chunk, rtype, err2 := readTLSChunk(conn)
log.Debugf("TLSGUARD: read: %v, %v, %v\n", err2, rtype, len(chunk))
if err2 == nil { // fmt.Println("OTHER DATA; PASSING THRU")
conn2.Write(chunk) if cr.rtype == SSL3_RT_ALERT {
fmt.Println("ALERT = ", cr.data)
}
other.Write(cr.data)
continue continue
} else if cr.client {
// other.Write(cr.data)
// continue
} else if cr.rtype != SSL3_RT_HANDSHAKE {
return errors.New(fmt.Sprintf("Expected TLS server handshake byte was not received [%#x vs 0x16]", cr.rtype))
} }
return err if cr.rtype < SSL3_RT_CHANGE_CIPHER_SPEC || cr.rtype > SSL3_RT_APPLICATION_DATA {
return errors.New(fmt.Sprintf("TLSGuard dropping connection with unknown content type: %#x", cr.rtype))
} }
if rtype == SSL3_RT_CHANGE_CIPHER_SPEC || rtype == SSL3_RT_APPLICATION_DATA || handshakeMsg := cr.data[TLS_RECORD_HDR_LEN:]
rtype == SSL3_RT_ALERT { s := uint(handshakeMsg[0])
// fmt.Println("OTHER DATA; PASSING THRU") fmt.Printf("s = %#x\n", s)
passthru = true // Message len, 3 bytes
} else if rtype == SSL3_RT_HANDSHAKE { if cr.rtype == SSL3_RT_HANDSHAKE {
passthru = false handshakeMessageLen := handshakeMsg[1:4]
handshakeMessageLenInt := int(int(handshakeMessageLen[0])<<16 | int(handshakeMessageLen[1])<<8 | int(handshakeMessageLen[2]))
fmt.Println("lenint = \n", handshakeMessageLenInt)
}
if cr.client && s != uint(client_expected) {
return errors.New(fmt.Sprintf("Client sent handshake type %#x but expected %#x", s, client_expected))
} else if !cr.client && s != uint(server_expected) {
return errors.New(fmt.Sprintf("Server sent handshake type %#x but expected %#x", s, server_expected))
}
if (cr.client && s == SSL3_MT_CLIENT_HELLO) || (!cr.client && s == SSL3_MT_SERVER_HELLO) {
rewrite := false
rewrite_buf := []byte{}
SRC := ""
if s == SSL3_MT_CLIENT_HELLO {
SRC = "CLIENT"
} else { } else {
return errors.New(fmt.Sprintf("Expected TLS server handshake byte was not received [%#x vs 0x16]", rtype)) server_expected = SSL3_MT_CERTIFICATE
SRC = "SERVER"
}
hello_offset := 4
// 2 byte protocol version
fmt.Println(SRC, "HELLO VERSION = ", handshakeMsg[hello_offset:hello_offset+2])
hello_offset += 2
// 4 byte Random/GMT time
gmtbytes := binary.BigEndian.Uint32(handshakeMsg[hello_offset : hello_offset+4])
gmt := time.Unix(int64(gmtbytes), 0)
fmt.Println(SRC, "HELLO GMT = ", gmt)
hello_offset += 4
// 28 bytes Random/random_bytes
hello_offset += 28
// 1 byte (32-bit session ID)
sess_len := uint(handshakeMsg[hello_offset])
fmt.Println(SRC, "HELLO SESSION ID = ", sess_len)
if sess_len != 0 {
fmt.Printf("ALERT: %v attempting to resume session; intercepting request\n", SRC)
rewrite = true
dcopy := make([]byte, len(cr.data))
copy(dcopy, cr.data)
// Copy the bytes before the session ID start
rewrite_buf = dcopy[0 : TLS_RECORD_HDR_LEN+hello_offset+1]
// Set the session ID to 0
rewrite_buf[len(rewrite_buf)-1] = 0
// Write the new TLS record length
binary.BigEndian.PutUint16(rewrite_buf[3:5], uint16(len(dcopy)-(int(sess_len)+TLS_RECORD_HDR_LEN)))
// Write the new ClientHello length
// Starts after the first 6 bytes (record header + type byte)
orig_len := binary.BigEndian.Uint32(handshakeMsg[0:4])
// But it's only 3 bytes so mask out the first one
b1 := orig_len & 0xff000000
orig_len &= 0x00ffffff
orig_len -= uint32(sess_len)
orig_len |= b1
binary.BigEndian.PutUint32(rewrite_buf[TLS_RECORD_HDR_LEN:], orig_len)
rewrite_buf = append(rewrite_buf, dcopy[TLS_RECORD_HDR_LEN+hello_offset+int(sess_len)+1:]...)
}
hello_offset += int(sess_len) + 1
// 2 byte cipher suite array
cs := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2])
noCS := cs
fmt.Printf("cs = %v / %#x\n", noCS, noCS)
if !cr.client {
fmt.Printf("SERVER selected ciphersuite: %#x (%s)\n", cs, getCipherSuiteName(uint(cs)))
hello_offset += 2
} else {
for csind := 0; csind < int(noCS/2); csind++ {
off := hello_offset + 2 + (csind * 2)
cs = binary.BigEndian.Uint16(handshakeMsg[off : off+2])
fmt.Printf("%s HELLO CIPHERSUITE: %d/%d: %#x (%s)\n", SRC, csind+1, noCS/2, cs, getCipherSuiteName(uint(cs)))
}
hello_offset += 2 + int(noCS)
}
clen := uint(handshakeMsg[hello_offset])
hello_offset++
if !cr.client {
fmt.Println("SERVER selected compression method: ", clen)
} else {
fmt.Println(SRC, "HELLO COMPRESSION METHODS LEN = ", clen)
fmt.Println(SRC, "HELLO COMPRESSION METHODS: ", handshakeMsg[hello_offset:hello_offset+int(clen)])
hello_offset += int(clen)
}
var extlen uint16 = 0
if hello_offset == len(handshakeMsg) {
fmt.Println("Message didn't have any extensions present")
} else {
extlen = binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2])
fmt.Println(SRC, "HELLO EXTENSIONS LENGTH: ", extlen)
hello_offset += 2
}
var exttype uint16 = 0
if extlen > 2 {
exttype = binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2])
fmt.Println(SRC, "HELLO FIRST EXTENSION TYPE: ", exttype)
}
if cr.client {
ext_ctr := 0
for ext_ctr < int(extlen)-2 {
hello_offset += 2
ext_ctr += 2
fmt.Printf("PROGRESS: %v of %v, %v of %v\n", ext_ctr, extlen, hello_offset, len(handshakeMsg))
exttype2 := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2])
fmt.Printf("EXTTYPE = %v, 2 = %v\n", exttype, exttype2)
if exttype2 == TLSEXT_TYPE_server_name {
fmt.Println("CLIENT specified server_name extension:")
}
if exttype != TLSEXT_TYPE_signature_algorithms {
fmt.Println("WTF")
}
hello_offset += 2
ext_ctr += 2
inner_len := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2])
// fmt.Println("INNER LEN = ", inner_len)
hello_offset += int(inner_len)
ext_ctr += int(inner_len)
}
}
if extlen > 0 {
fmt.Printf("ALERT: %v attempting to send extensions; intercepting request\n", SRC)
rewrite = true
tocopy := cr.data
if len(rewrite_buf) > 0 {
tocopy = rewrite_buf
}
dcopy := make([]byte, len(tocopy)-int(extlen))
copy(dcopy, tocopy[0:len(tocopy)-int(extlen)])
rewrite_buf = dcopy
// Write the new TLS record length
binary.BigEndian.PutUint16(rewrite_buf[3:5], uint16(len(dcopy)-(int(sess_len)+TLS_RECORD_HDR_LEN)))
// Write the new ClientHello length
// Starts after the first 6 bytes (record header + type byte)
orig_len := binary.BigEndian.Uint32(rewrite_buf[TLS_RECORD_HDR_LEN:])
// But it's only 3 bytes so mask out the first one
b1 := orig_len & 0xff000000
orig_len &= 0x00ffffff
orig_len -= uint32(extlen)
orig_len |= b1
binary.BigEndian.PutUint32(rewrite_buf[TLS_RECORD_HDR_LEN:], orig_len)
// Write session length 0 at the end
rewrite_buf[len(rewrite_buf)-1] = 0
rewrite_buf[len(rewrite_buf)-2] = 0
}
if rewrite {
fmt.Println("TLSGuard writing back modified handshake data to server")
fmt.Printf("ORIGINAL[%d]: %v\n", len(cr.data), hex.Dump(cr.data))
fmt.Printf("NEW[%d]: %v\n", len(rewrite_buf), hex.Dump(rewrite_buf))
other.Write(rewrite_buf)
} else {
other.Write(cr.data)
} }
if passthru {
// fmt.Println("passthru writing buf again and continuing:")
conn.Write(chunk)
continue continue
} }
serverMsg := chunk[5:] if cr.client {
s := serverMsg[0] other.Write(cr.data)
log.Debugf("TLSGUARD: s = %#x\n", s) continue
}
if !cr.client && server_expected == SSL3_MT_SERVER_HELLO {
server_expected = SSL3_MT_CERTIFICATE
}
if !cr.client && s == SSL3_MT_HELLO_REQUEST {
fmt.Println("Server sent hello request")
continue
}
if s > SSL3_MT_CERTIFICATE_STATUS {
fmt.Println("WTF: ", cr.data)
}
if s == SSL3_MT_CERTIFICATE {
// Message len, 3 bytes // Message len, 3 bytes
serverMessageLen := serverMsg[1:4] handshakeMessageLen := handshakeMsg[1:4]
serverMessageLenInt := int(int(serverMessageLen[0])<<16 | int(serverMessageLen[1])<<8 | int(serverMessageLen[2])) handshakeMessageLenInt := int(int(handshakeMessageLen[0])<<16 | int(handshakeMessageLen[1])<<8 | int(handshakeMessageLen[2]))
// fmt.Printf("chunk len = %v, serverMsgLen = %v, slint = %v\n", len(chunk), len(serverMsg), serverMessageLenInt)
if len(serverMsg) < serverMessageLenInt { if s == SSL3_MT_CERTIFICATE {
return errors.New(fmt.Sprintf("len(serverMsg) %v < serverMessageLenInt %v!\n", len(serverMsg), serverMessageLenInt)) fmt.Println("HMM")
// fmt.Printf("chunk len = %v, handshakeMsgLen = %v, slint = %v\n", len(chunk), len(handshakeMsg), handshakeMessageLenInt)
if len(handshakeMsg) < handshakeMessageLenInt {
return errors.New(fmt.Sprintf("len(handshakeMsg) %v < handshakeMessageLenInt %v!\n", len(handshakeMsg), handshakeMessageLenInt))
} }
serverHelloBody := serverMsg[4 : 4+serverMessageLenInt] serverHelloBody := handshakeMsg[4 : 4+handshakeMessageLenInt]
certChainLen := int(int(serverHelloBody[0])<<16 | int(serverHelloBody[1])<<8 | int(serverHelloBody[2])) certChainLen := int(int(serverHelloBody[0])<<16 | int(serverHelloBody[1])<<8 | int(serverHelloBody[2]))
remaining := certChainLen remaining := certChainLen
pos := serverHelloBody[3:certChainLen] pos := serverHelloBody[3:certChainLen]
@ -136,6 +491,7 @@ func TLSGuard(conn, conn2 net.Conn, fqdn string) error {
// var certChain []*x509.Certificate // var certChain []*x509.Certificate
var verifyOptions x509.VerifyOptions var verifyOptions x509.VerifyOptions
//fqdn = "www.reddit.com"
if fqdn != "" { if fqdn != "" {
verifyOptions.DNSName = fqdn verifyOptions.DNSName = fqdn
} }
@ -162,44 +518,69 @@ func TLSGuard(conn, conn2 net.Conn, fqdn string) error {
pos = pos[3+certLen:] pos = pos[3+certLen:]
} }
} }
verifyOptions.Intermediates = pool
// fmt.Println("ATTEMPTING TO VERIFY: ", fqdn) verifyOptions.Intermediates = pool
_, err = c.Verify(verifyOptions) fmt.Println("ATTEMPTING TO VERIFY: ", fqdn)
// fmt.Println("ATTEMPTING TO VERIFY RESULT: ", err) _, err := c.Verify(verifyOptions)
fmt.Println("ATTEMPTING TO VERIFY RESULT: ", err)
if err != nil { if err != nil {
return err return err
} else { } else {
valid = true x509Valid = true
} }
// lse if s == 0x0d { fmt.Printf("found a client cert request, sending buf to client\n") }
} else if s == SSL3_MT_SERVER_DONE {
conn.Write(chunk)
break
} else if s == SSL3_MT_CERTIFICATE_REQUEST {
break
} }
other.Write(cr.data)
if x509Valid || (s == SSL3_MT_SERVER_DONE) || (s == SSL3_MT_CERTIFICATE_REQUEST) {
fmt.Println("BREAKING OUT OF LOOP 1")
dChan <- true
fmt.Println("BREAKING OUT OF LOOP 2")
break select_loop
}
// fmt.Printf("Sending chunk of type %d to client.\n", s) // fmt.Printf("Sending chunk of type %d to client.\n", s)
} else if cr.err != nil {
ndone++
conn.Write(chunk) if cr.client {
fmt.Println("Client read error: ", cr.err)
} else {
fmt.Println("Server read error: ", cr.err)
} }
if !valid { return cr.err
return errors.New("Unknown error: TLS connection could not be validated")
} }
return nil }
}
fmt.Println("WAITING; ndone = ", ndone)
for ndone < 2 {
fmt.Println("WAITING; ndone = ", ndone)
select {
case cr := <-crChan:
fmt.Printf("CHAN DATA: %v, %v, %v\n", cr.client, cr.err, len(cr.data))
if cr.err != nil || cr.data == nil {
ndone++
} else if cr.client {
conn2.Write(cr.data)
} else if !cr.client {
conn.Write(cr.data)
} }
func readNBytes(conn net.Conn, numBytes int) ([]byte, error) {
res := make([]byte, 0)
temp := make([]byte, 1)
for i := 0; i < numBytes; i++ {
_, err := io.ReadAtLeast(conn, temp, 1)
if err != nil {
return res, err
} }
res = append(res, temp[0])
} }
return res, nil
fmt.Println("______ ndone = 2\n")
// dChan <- true
close(dChan)
if !x509Valid {
return errors.New("Unknown error: TLS connection could not be validated")
}
return nil
} }

@ -1,8 +1,8 @@
package procsnitch package procsnitch
import ( import (
"encoding/hex"
"encoding/binary" "encoding/binary"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"github.com/op/go-logging" "github.com/op/go-logging"
@ -312,6 +312,7 @@ func stripLabel(s string) string {
// stolen from github.com/virtao/GoEndian // stolen from github.com/virtao/GoEndian
const INT_SIZE int = int(unsafe.Sizeof(0)) const INT_SIZE int = int(unsafe.Sizeof(0))
func setEndian() { func setEndian() {
var i int = 0x1 var i int = 0x1
bs := (*[INT_SIZE]byte)(unsafe.Pointer(&i)) bs := (*[INT_SIZE]byte)(unsafe.Pointer(&i))

@ -24,6 +24,8 @@ type Info struct {
ParentCmdLine string ParentCmdLine string
ParentExePath string ParentExePath string
Sandbox string Sandbox string
Inode uint64
FD int
} }
type pidCache struct { type pidCache struct {
@ -51,10 +53,12 @@ func loadCache() map[uint64]*Info {
for _, n := range readdir("/proc") { for _, n := range readdir("/proc") {
pid := toPid(n) pid := toPid(n)
if pid != 0 { if pid != 0 {
pinfo := &Info{Pid: pid} inodes, fds := inodesFromPid(pid)
for _, inode := range inodesFromPid(pid) { for iind, inode := range inodes {
pinfo := &Info{Inode: inode, Pid: pid, FD: fds[iind]}
cmap[inode] = pinfo cmap[inode] = pinfo
} }
} }
} }
return cmap return cmap
@ -76,8 +80,9 @@ func toPid(name string) int {
return (int)(pid) return (int)(pid)
} }
func inodesFromPid(pid int) []uint64 { func inodesFromPid(pid int) ([]uint64, []int) {
var inodes []uint64 var inodes []uint64
var fds []int
fdpath := fmt.Sprintf("/proc/%d/fd", pid) fdpath := fmt.Sprintf("/proc/%d/fd", pid)
for _, n := range readdir(fdpath) { for _, n := range readdir(fdpath) {
if link, err := os.Readlink(path.Join(fdpath, n)); err != nil { if link, err := os.Readlink(path.Join(fdpath, n)); err != nil {
@ -85,12 +90,19 @@ func inodesFromPid(pid int) []uint64 {
log.Warningf("Error reading link %s: %v", n, err) log.Warningf("Error reading link %s: %v", n, err)
} }
} else { } else {
fd, err := strconv.Atoi(n)
if err != nil {
log.Warningf("Error retrieving fd associated with pid %v: %v", pid, err)
fd = -1
}
if inode := extractSocket(link); inode > 0 { if inode := extractSocket(link); inode > 0 {
inodes = append(inodes, inode) inodes = append(inodes, inode)
fds = append(fds, fd)
} }
} }
} }
return inodes return inodes, fds
} }
func extractSocket(name string) uint64 { func extractSocket(name string) uint64 {

@ -111,7 +111,7 @@ func findUDPSocketAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort ui
if custdata == nil { if custdata == nil {
if strictness == MATCH_STRICT { if strictness == MATCH_STRICT {
return findSocket(proto, func(ss socketStatus) bool { return findSocket(proto, func(ss socketStatus) bool {
fmt.Println("Match strict") // fmt.Println("Match strict")
return ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) return ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr)
//return ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) //return ss.local.port == srcPort && ss.local.ip.Equal(srcAddr)
}) })
@ -124,27 +124,29 @@ func findUDPSocketAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort ui
fmt.Printf("local ip: %v\n source ip: %v\n", ss.local.ip, srcAddr) fmt.Printf("local ip: %v\n source ip: %v\n", ss.local.ip, srcAddr)
*/ */
if (ss.local.port == srcPort && (ss.local.ip.Equal(net.IPv4(0,0,0,0)) && ss.remote.ip.Equal(net.IPv4(0,0,0,0)))) { if (ss.local.port == srcPort) && addrMatchesAny(ss.local.ip) && addrMatchesAny(ss.remote.ip) {
fmt.Printf("Matching for UDP socket bound to *:%d\n",ss.local.port) fmt.Printf("Loose match for UDP socket bound to *:%d\n", ss.local.port)
return true return true
} else if (ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr)) { } else if ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) {
return true return true
} }
// Finally, loop through all interfaces if src port matches // Finally, loop through all interfaces if src port matches
if ss.local.port == srcPort { if ss.local.port == srcPort {
ifs, err := net.Interfaces() ifs, err := net.Interfaces()
if err != nil { if err != nil {
log.Warningf("Error on net.Interfaces(): %v", err) log.Warning("Error retrieving list of network interfaces for UDP socket lookup:", err)
return false return false
} }
for _, i := range ifs { for _, i := range ifs {
addrs, err := i.Addrs() addrs, err := i.Addrs()
if err != nil { if err != nil {
log.Warningf("Error on Interface.Addrs(): %v", err) log.Warning("Error retrieving network interface for UDP socket lookup:", err)
return false return false
} }
for _, addr := range addrs { for _, addr := range addrs {
var ifip net.IP var ifip net.IP
switch x := addr.(type) { switch x := addr.(type) {
@ -153,13 +155,16 @@ func findUDPSocketAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort ui
case *net.IPAddr: case *net.IPAddr:
ifip = x.IP ifip = x.IP
} }
if ss.local.ip.Equal(ifip) { if ss.local.ip.Equal(ifip) {
fmt.Printf("Matched on UDP socket bound to %v:%d\n", ifip, srcPort) fmt.Printf("Matched on UDP socket bound to %v:%d\n", ifip, srcPort)
return true return true
} }
} }
} }
} }
return false return false
//return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(net.IPv4(0,0,0,0))) && ss.local.port == srcPort && (ss.local.ip.Equal(srcAddr) || ss.local.ip.Equal(net.IPv4(0,0,0,0))) //return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(net.IPv4(0,0,0,0))) && ss.local.port == srcPort && (ss.local.ip.Equal(srcAddr) || ss.local.ip.Equal(net.IPv4(0,0,0,0)))
/* /*

Loading…
Cancel
Save