Merged from shw_dev

shw-merge
xSmurf 7 years ago
parent 5f26317c44
commit 7472b4d828

@ -4,33 +4,23 @@ import (
"errors"
"github.com/godbus/dbus"
"log"
// "github.com/gotk3/gotk3/glib"
)
type dbusObject struct {
dbus.BusObject
}
type dbusServer struct {
conn *dbus.Conn
run bool
}
type promptData struct {
Application string
Icon string
Path string
Address string
Port int
IP string
Origin string
Proto string
UID int
GID int
Username string
Groupname string
Pid int
Sandbox string
OptString string
Expanded bool
Expert bool
Action int
func newDbusObjectAdd() (*dbusObject, error) {
conn, err := dbus.SystemBus()
if err != nil {
return nil, err
}
return &dbusObject{conn.Object("com.subgraph.Firewall", "/com/subgraph/Firewall")}, nil
}
func newDbusServer() (*dbusServer, error) {
@ -62,10 +52,10 @@ func newDbusServer() (*dbusServer, error) {
return ds, nil
}
func (ds *dbusServer) RequestPrompt(application, icon, path, address string, port int32, ip, origin, proto string, uid, gid int32, username, groupname string, pid int32, sandbox string,
is_socks bool, optstring string, expanded, expert bool, action int32) (int32, string, *dbus.Error) {
log.Printf("request prompt: app = %s, icon = %s, path = %s, address = %s, is_socks = %v, action = %v\n", application, icon, path, address, is_socks, action)
decision := addRequest(nil, path, proto, int(pid), ip, address, int(port), int(uid), int(gid), origin, is_socks, optstring, sandbox)
/*func (ds *dbusServer) RequestPrompt(guid, application, icon, path, address string, port int32, ip, origin, proto string, uid, gid int32, username, groupname string, pid int32, sandbox string,
is_socks bool, timestamp string, optstring string, expanded, expert bool, action int32) (int32, string, *dbus.Error) {
log.Printf("request prompt: app = %s, icon = %s, path = %s, address = %s / ip = %s, is_socks = %v, action = %v\n", application, icon, path, address, ip, is_socks, action)
decision := addRequest(nil, guid, path, icon, proto, int(pid), ip, address, int(port), int(uid), int(gid), origin, timestamp, is_socks, optstring, sandbox, int(action))
log.Print("Waiting on decision...")
decision.Cond.L.Lock()
for !decision.Ready {
@ -73,6 +63,18 @@ func (ds *dbusServer) RequestPrompt(application, icon, path, address string, por
}
log.Print("Decision returned: ", decision.Rule)
decision.Cond.L.Unlock()
// glib.IdleAdd(func, data)
return int32(decision.Scope), decision.Rule, nil
}*/
func (ds *dbusServer) RequestPromptAsync(guid, application, icon, path, address string, port int32, ip, origin, proto string, uid, gid int32, username, groupname string, pid int32, sandbox string,
is_socks bool, timestamp string, optstring string, expanded, expert bool, action int32) (bool, *dbus.Error) {
log.Printf("ASYNC request prompt: guid = %s, app = %s, icon = %s, path = %s, address = %s / ip = %s, is_socks = %v, action = %v\n", guid, application, icon, path, address, ip, is_socks, action)
addRequestAsync(nil, guid, path, icon, proto, int(pid), ip, address, int(port), int(uid), int(gid), origin, timestamp, is_socks, optstring, sandbox, int(action))
return true, nil
}
func (ds *dbusServer) RemovePrompt(guid string) *dbus.Error {
log.Printf("++++++++ Cancelling prompt: %s\n", guid)
removeRequest(nil, guid)
return nil
}

@ -34,34 +34,44 @@ type decisionWaiter struct {
}
type ruleColumns struct {
Path string
Proto string
Pid int
Target string
Hostname string
Port int
UID int
GID int
Uname string
Gname string
Origin string
Scope int
nrefs int
Path string
GUID string
Icon string
Proto string
Pid int
Target string
Hostname string
Port int
UID int
GID int
Uname string
Gname string
Origin string
Timestamp string
IsSocks bool
ForceTLS bool
Scope int
}
var dbuso *dbusObject
var userPrefs fpPreferences
var mainWin *gtk.Window
var Notebook *gtk.Notebook
var globalLS *gtk.ListStore
var globalLS *gtk.ListStore = nil
var globalTV *gtk.TreeView
var globalPromptLock = &sync.Mutex{}
var globalIcon *gtk.Image
var decisionWaiters []*decisionWaiter
var editApp, editTarget, editPort, editUser, editGroup *gtk.Entry
var comboProto *gtk.ComboBoxText
var radioOnce, radioProcess, radioParent, radioSession, radioPermanent *gtk.RadioButton
var btnApprove, btnDeny, btnIgnore *gtk.Button
var chkUser, chkGroup *gtk.CheckButton
var chkTLS, chkUser, chkGroup *gtk.CheckButton
func dumpDecisions() {
return
fmt.Println("XXX Total of decisions pending: ", len(decisionWaiters))
for i := 0; i < len(decisionWaiters); i++ {
fmt.Printf("XXX %d ready = %v, rule = %v\n", i+1, decisionWaiters[i].Ready, decisionWaiters[i].Rule)
@ -69,6 +79,7 @@ func dumpDecisions() {
}
func addDecision() *decisionWaiter {
return nil
decision := decisionWaiter{Lock: &sync.Mutex{}, Ready: false, Scope: int(sgfw.APPLY_ONCE), Rule: ""}
decision.Cond = sync.NewCond(decision.Lock)
decisionWaiters = append(decisionWaiters, &decision)
@ -306,7 +317,8 @@ func createColumn(title string, id int) *gtk.TreeViewColumn {
}
func createListStore(general bool) *gtk.ListStore {
colData := []glib.Type{glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING}
colData := []glib.Type{glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_STRING,
glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_STRING, glib.TYPE_INT, glib.TYPE_STRING, glib.TYPE_INT}
listStore, err := gtk.ListStoreNew(colData...)
if err != nil {
@ -316,24 +328,92 @@ func createListStore(general bool) *gtk.ListStore {
return listStore
}
func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, hostname string, port, uid, gid int, origin string, is_socks bool, optstring string, sandbox string) *decisionWaiter {
func removeRequest(listStore *gtk.ListStore, guid string) {
removed := false
globalPromptLock.Lock()
defer globalPromptLock.Unlock()
/* XXX: This is horrible. Figure out how to do this properly. */
for ridx := 0; ridx < globalLS.IterNChildren(nil); ridx++ {
rule, _, err := getRuleByIdx(ridx)
if err != nil {
break
} else if rule.GUID == guid {
removeSelectedRule(ridx, true)
removed = true
break
}
}
if !removed {
log.Printf("Unexpected condition: SGFW requested prompt removal for non-existent GUID %v\n", guid)
}
}
func addRequestInc(listStore *gtk.ListStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int,
origin string, is_socks bool, optstring string, sandbox string, action int) bool {
duplicated := false
globalPromptLock.Lock()
defer globalPromptLock.Unlock()
for ridx := 0; ridx < globalLS.IterNChildren(nil); ridx++ {
/* XXX: This is horrible. Figure out how to do this properly. */
rule, iter, err := getRuleByIdx(ridx)
if err != nil {
break
// XXX: not compared: optstring/sandbox
} else if (rule.Path == path) && (rule.Proto == proto) && (rule.Pid == pid) && (rule.Target == ipaddr) && (rule.Hostname == hostname) &&
(rule.Port == port) && (rule.UID == uid) && (rule.GID == gid) && (rule.Origin == origin) && (rule.IsSocks == is_socks) {
rule.nrefs++
err := globalLS.SetValue(iter, 0, rule.nrefs)
if err != nil {
log.Println("Error creating duplicate firewall prompt entry:", err)
break
}
fmt.Println("YES REALLY DUPLICATE: ", rule.nrefs)
duplicated = true
break
}
}
return duplicated
}
func addRequestAsync(listStore *gtk.ListStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int,
origin, timestamp string, is_socks bool, optstring string, sandbox string, action int) bool {
addRequest(listStore, guid, path, icon, proto, pid, ipaddr, hostname, port, uid, gid, origin, timestamp, is_socks,
optstring, sandbox, action)
return true
}
func addRequest(listStore *gtk.ListStore, guid, path, icon, proto string, pid int, ipaddr, hostname string, port, uid, gid int,
origin, timestamp string, is_socks bool, optstring string, sandbox string, action int) *decisionWaiter {
if listStore == nil {
listStore = globalLS
waitTimes := []int{1, 2, 5, 10}
if listStore == nil {
log.Print("SGFW prompter was not ready to receive firewall request... waiting")
}
log.Println("SGFW prompter was not ready to receive firewall request... waiting")
for _, wtime := range waitTimes {
time.Sleep(time.Duration(wtime) * time.Second)
listStore = globalLS
for _, wtime := range waitTimes {
time.Sleep(time.Duration(wtime) * time.Second)
listStore = globalLS
if listStore != nil {
break
if listStore != nil {
break
}
log.Println("SGFW prompter is still waiting...")
}
log.Print("SGFW prompter is still waiting...")
}
}
@ -342,6 +422,18 @@ func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, h
log.Fatal("SGFW prompter GUI failed to load for unknown reasons")
}
if addRequestInc(listStore, guid, path, icon, proto, pid, ipaddr, hostname, port, uid, gid, origin, is_socks, optstring, sandbox, action) {
fmt.Println("REQUEST WAS DUPLICATE")
decision := addDecision()
globalPromptLock.Lock()
toggleHover()
globalPromptLock.Unlock()
return decision
} else {
fmt.Println("NOT DUPLICATE")
}
globalPromptLock.Lock()
iter := listStore.Append()
if is_socks {
@ -352,24 +444,34 @@ func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, h
}
}
colVals := make([]interface{}, 11)
colVals := make([]interface{}, 16)
colVals[0] = 1
colVals[1] = path
colVals[2] = proto
colVals[3] = pid
colVals[1] = guid
colVals[2] = path
colVals[3] = icon
colVals[4] = proto
colVals[5] = pid
if ipaddr == "" {
colVals[4] = "---"
colVals[6] = "---"
} else {
colVals[4] = ipaddr
colVals[6] = ipaddr
}
colVals[7] = hostname
colVals[8] = port
colVals[9] = uid
colVals[10] = gid
colVals[11] = origin
colVals[12] = timestamp
colVals[13] = 0
if is_socks {
colVals[13] = 1
}
colVals[5] = hostname
colVals[6] = port
colVals[7] = uid
colVals[8] = gid
colVals[9] = origin
colVals[10] = optstring
colVals[14] = optstring
colVals[15] = action
colNums := make([]int, len(colVals))
@ -386,6 +488,7 @@ func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, h
decision := addDecision()
dumpDecisions()
toggleHover()
globalPromptLock.Unlock()
return decision
}
@ -479,22 +582,42 @@ func lsGetInt(ls *gtk.ListStore, iter *gtk.TreeIter, idx int) (int, error) {
return ival.(int), nil
}
func makeDecision(idx int, rule string, scope int) {
func makeDecision(idx int, rule string, scope int) error {
var dres bool
call := dbuso.Call("AddRuleAsync", 0, uint32(scope), rule, "*")
err := call.Store(&dres)
if err != nil {
log.Println("Error notifying SGFW of asynchronous rule addition:", err)
return err
}
fmt.Println("makeDecision remote result:", dres)
return nil
decisionWaiters[idx].Cond.L.Lock()
decisionWaiters[idx].Rule = rule
decisionWaiters[idx].Scope = scope
decisionWaiters[idx].Ready = true
decisionWaiters[idx].Cond.Signal()
decisionWaiters[idx].Cond.L.Unlock()
return nil
}
/* Do we need to hold the lock while this is called? Stay safe... */
func toggleHover() {
mainWin.SetKeepAbove(len(decisionWaiters) > 0)
nitems := globalLS.IterNChildren(nil)
mainWin.SetKeepAbove(nitems > 0)
}
func toggleValidRuleState() {
ok := true
// XXX: Unfortunately, this can cause deadlock since it's a part of the item removal cascade
// globalPromptLock.Lock()
// defer globalPromptLock.Unlock()
if numSelections() <= 0 {
ok = false
}
@ -536,7 +659,8 @@ func toggleValidRuleState() {
btnApprove.SetSensitive(ok)
btnDeny.SetSensitive(ok)
btnIgnore.SetSensitive(ok)
// btnIgnore.SetSensitive(ok)
btnIgnore.SetSensitive(false)
}
func createCurrentRule() (ruleColumns, error) {
@ -579,6 +703,9 @@ func createCurrentRule() (ruleColumns, error) {
rule.UID, rule.GID = 0, 0
rule.Uname, rule.Gname = "", ""
rule.ForceTLS = chkTLS.GetActive()
/* Pid int
Origin string */
@ -586,6 +713,7 @@ func createCurrentRule() (ruleColumns, error) {
}
func clearEditor() {
globalIcon.Clear()
editApp.SetText("")
editTarget.SetText("")
editPort.SetText("")
@ -599,6 +727,7 @@ func clearEditor() {
radioPermanent.SetActive(false)
chkUser.SetActive(false)
chkGroup.SetActive(false)
chkTLS.SetActive(false)
}
func removeSelectedRule(idx int, rmdecision bool) error {
@ -617,7 +746,7 @@ func removeSelectedRule(idx int, rmdecision bool) error {
globalLS.Remove(iter)
if rmdecision {
decisionWaiters = append(decisionWaiters[:idx], decisionWaiters[idx+1:]...)
// decisionWaiters = append(decisionWaiters[:idx], decisionWaiters[idx+1:]...)
}
toggleHover()
@ -634,78 +763,126 @@ func numSelections() int {
return int(rows.Length())
}
func getSelectedRule() (ruleColumns, int, error) {
// Needs to be locked by the caller
func getRuleByIdx(idx int) (ruleColumns, *gtk.TreeIter, error) {
rule := ruleColumns{}
sel, err := globalTV.GetSelection()
path, err := gtk.TreePathNewFromString(fmt.Sprintf("%d", idx))
if err != nil {
return rule, -1, err
return rule, nil, err
}
rows := sel.GetSelectedRows(globalLS)
iter, err := globalLS.GetIter(path)
if err != nil {
return rule, nil, err
}
if rows.Length() <= 0 {
return rule, -1, errors.New("No selection was made")
rule.nrefs, err = lsGetInt(globalLS, iter, 0)
if err != nil {
return rule, nil, err
}
rdata := rows.NthData(0)
lIndex, err := strconv.Atoi(rdata.(*gtk.TreePath).String())
rule.GUID, err = lsGetStr(globalLS, iter, 1)
if err != nil {
return rule, -1, err
return rule, nil, err
}
fmt.Println("lindex = ", lIndex)
path, err := gtk.TreePathNewFromString(fmt.Sprintf("%d", lIndex))
rule.Path, err = lsGetStr(globalLS, iter, 2)
if err != nil {
return rule, -1, err
return rule, nil, err
}
iter, err := globalLS.GetIter(path)
rule.Icon, err = lsGetStr(globalLS, iter, 3)
if err != nil {
return rule, -1, err
return rule, nil, err
}
rule.Path, err = lsGetStr(globalLS, iter, 1)
rule.Proto, err = lsGetStr(globalLS, iter, 4)
if err != nil {
return rule, -1, err
return rule, nil, err
}
rule.Proto, err = lsGetStr(globalLS, iter, 2)
rule.Pid, err = lsGetInt(globalLS, iter, 5)
if err != nil {
return rule, -1, err
return rule, nil, err
}
rule.Pid, err = lsGetInt(globalLS, iter, 3)
rule.Target, err = lsGetStr(globalLS, iter, 6)
if err != nil {
return rule, -1, err
return rule, nil, err
}
rule.Target, err = lsGetStr(globalLS, iter, 4)
rule.Hostname, err = lsGetStr(globalLS, iter, 7)
if err != nil {
return rule, -1, err
return rule, nil, err
}
rule.Hostname, err = lsGetStr(globalLS, iter, 5)
rule.Port, err = lsGetInt(globalLS, iter, 8)
if err != nil {
return rule, -1, err
return rule, nil, err
}
rule.Port, err = lsGetInt(globalLS, iter, 6)
rule.UID, err = lsGetInt(globalLS, iter, 9)
if err != nil {
return rule, -1, err
return rule, nil, err
}
rule.GID, err = lsGetInt(globalLS, iter, 10)
if err != nil {
return rule, nil, err
}
rule.UID, err = lsGetInt(globalLS, iter, 7)
rule.Origin, err = lsGetStr(globalLS, iter, 11)
if err != nil {
return rule, nil, err
}
rule.Timestamp, err = lsGetStr(globalLS, iter, 12)
if err != nil {
return rule, nil, err
}
rule.IsSocks = false
is_socks, err := lsGetInt(globalLS, iter, 13)
if err != nil {
return rule, nil, err
}
if is_socks != 0 {
rule.IsSocks = true
}
rule.Scope, err = lsGetInt(globalLS, iter, 15)
if err != nil {
return rule, nil, err
}
return rule, iter, nil
}
// Needs to be locked by the caller
func getSelectedRule() (ruleColumns, int, error) {
rule := ruleColumns{}
sel, err := globalTV.GetSelection()
if err != nil {
return rule, -1, err
}
rule.GID, err = lsGetInt(globalLS, iter, 8)
rows := sel.GetSelectedRows(globalLS)
if rows.Length() <= 0 {
return rule, -1, errors.New("No selection was made")
}
rdata := rows.NthData(0)
lIndex, err := strconv.Atoi(rdata.(*gtk.TreePath).String())
if err != nil {
return rule, -1, err
}
rule.Origin, err = lsGetStr(globalLS, iter, 9)
fmt.Println("lindex = ", lIndex)
rule, _, err = getRuleByIdx(lIndex)
if err != nil {
return rule, -1, err
}
@ -713,6 +890,109 @@ func getSelectedRule() (ruleColumns, int, error) {
return rule, lIndex, nil
}
func addPendingPrompts(rules []string) {
for _, rule := range rules {
fields := strings.Split(rule, "|")
if len(fields) != 19 {
log.Printf("Got saved prompt message with strange data: \"%s\"", rule)
continue
}
guid := fields[0]
icon := fields[2]
path := fields[3]
address := fields[4]
port, err := strconv.Atoi(fields[5])
if err != nil {
log.Println("Error converting port in pending prompt message to integer:", err)
continue
}
ip := fields[6]
origin := fields[7]
proto := fields[8]
uid, err := strconv.Atoi(fields[9])
if err != nil {
log.Println("Error converting UID in pending prompt message to integer:", err)
continue
}
gid, err := strconv.Atoi(fields[10])
if err != nil {
log.Println("Error converting GID in pending prompt message to integer:", err)
continue
}
pid, err := strconv.Atoi(fields[13])
if err != nil {
log.Println("Error converting pid in pending prompt message to integer:", err)
continue
}
sandbox := fields[14]
is_socks, err := strconv.ParseBool(fields[15])
if err != nil {
log.Println("Error converting SOCKS flag in pending prompt message to boolean:", err)
continue
}
timestamp := fields[16]
optstring := fields[17]
action, err := strconv.Atoi(fields[18])
if err != nil {
log.Println("Error converting action in pending prompt message to integer:", err)
continue
}
addRequestAsync(nil, guid, path, icon, proto, int(pid), ip, address, int(port), int(uid), int(gid), origin, timestamp, is_socks, optstring, sandbox, action)
}
}
func buttonAction(action string) {
globalPromptLock.Lock()
rule, idx, err := getSelectedRule()
if err != nil {
globalPromptLock.Unlock()
promptError("Error occurred processing request: " + err.Error())
return
}
rule, err = createCurrentRule()
if err != nil {
globalPromptLock.Unlock()
promptError("Error occurred constructing new rule: " + err.Error())
return
}
fmt.Println("rule = ", rule)
rulestr := action
if action == "ALLOW" && rule.ForceTLS {
rulestr += "_TLSONLY"
}
rulestr += "|" + rule.Proto + ":" + rule.Target + ":" + strconv.Itoa(rule.Port)
rulestr += "|" + sgfw.RuleModeString[sgfw.RuleMode(rule.Scope)]
fmt.Println("RULESTR = ", rulestr)
makeDecision(idx, rulestr, int(rule.Scope))
fmt.Println("Decision made.")
err = removeSelectedRule(idx, true)
globalPromptLock.Unlock()
if err == nil {
clearEditor()
} else {
promptError("Error setting new rule: " + err.Error())
}
}
func main() {
decisionWaiters = make([]*decisionWaiter, 0)
_, err := newDbusServer()
@ -721,6 +1001,11 @@ func main() {
return
}
dbuso, err = newDbusObjectAdd()
if err != nil {
log.Fatal("Failed to connect to dbus system bus: %v", err)
}
loadPreferences()
gtk.Init(nil)
@ -811,10 +1096,18 @@ func main() {
editbox := get_vbox()
hbox := get_hbox()
lbl := get_label("Application path:")
globalIcon, err = gtk.ImageNew()
if err != nil {
log.Fatal("Unable to create image:", err)
}
// globalIcon.SetFromIconName("firefox", gtk.ICON_SIZE_DND)
editApp = get_entry("")
editApp.Connect("changed", toggleValidRuleState)
hbox.PackStart(lbl, false, false, 10)
hbox.PackStart(editApp, true, true, 50)
hbox.PackStart(editApp, true, true, 10)
hbox.PackStart(globalIcon, false, false, 10)
editbox.PackStart(hbox, false, false, 5)
hbox = get_hbox()
@ -842,7 +1135,9 @@ func main() {
radioSession = get_radiobutton(radioOnce, "Session", false)
radioPermanent = get_radiobutton(radioOnce, "Permanent", false)
radioParent.SetSensitive(false)
hbox.PackStart(lbl, false, false, 10)
chkTLS = get_checkbox("Require TLS", false)
hbox.PackStart(chkTLS, false, false, 10)
hbox.PackStart(lbl, false, false, 20)
hbox.PackStart(radioOnce, false, false, 5)
hbox.PackStart(radioProcess, false, false, 5)
hbox.PackStart(radioParent, false, false, 5)
@ -872,94 +1167,55 @@ func main() {
box.PackStart(scrollbox, false, true, 5)
tv.AppendColumn(createColumn("#", 0))
tv.AppendColumn(createColumn("Path", 1))
tv.AppendColumn(createColumn("Protocol", 2))
tv.AppendColumn(createColumn("PID", 3))
tv.AppendColumn(createColumn("IP Address", 4))
tv.AppendColumn(createColumn("Hostname", 5))
tv.AppendColumn(createColumn("Port", 6))
tv.AppendColumn(createColumn("UID", 7))
tv.AppendColumn(createColumn("GID", 8))
tv.AppendColumn(createColumn("Origin", 9))
tv.AppendColumn(createColumn("Details", 10))
listStore := createListStore(true)
globalLS = listStore
guidcol := createColumn("GUID", 1)
guidcol.SetVisible(false)
tv.AppendColumn(guidcol)
tv.SetModel(listStore)
tv.AppendColumn(createColumn("Path", 2))
btnApprove.Connect("clicked", func() {
rule, idx, err := getSelectedRule()
if err != nil {
promptError("Error occurred processing request: " + err.Error())
return
}
icol := createColumn("Icon", 3)
icol.SetVisible(false)
tv.AppendColumn(icol)
rule, err = createCurrentRule()
if err != nil {
promptError("Error occurred constructing new rule: " + err.Error())
return
}
tv.AppendColumn(createColumn("Protocol", 4))
tv.AppendColumn(createColumn("PID", 5))
tv.AppendColumn(createColumn("IP Address", 6))
tv.AppendColumn(createColumn("Hostname", 7))
tv.AppendColumn(createColumn("Port", 8))
tv.AppendColumn(createColumn("UID", 9))
tv.AppendColumn(createColumn("GID", 10))
tv.AppendColumn(createColumn("Origin", 11))
tv.AppendColumn(createColumn("Timestamp", 12))
fmt.Println("rule = ", rule)
rulestr := "ALLOW|" + rule.Proto + ":" + rule.Target + ":" + strconv.Itoa(rule.Port)
fmt.Println("RULESTR = ", rulestr)
makeDecision(idx, rulestr, int(rule.Scope))
fmt.Println("Decision made.")
err = removeSelectedRule(idx, true)
if err == nil {
clearEditor()
} else {
promptError("Error setting new rule: " + err.Error())
}
})
scol := createColumn("Is SOCKS", 13)
scol.SetVisible(false)
tv.AppendColumn(scol)
btnDeny.Connect("clicked", func() {
rule, idx, err := getSelectedRule()
if err != nil {
promptError("Error occurred processing request: " + err.Error())
return
}
tv.AppendColumn(createColumn("Details", 14))
rule, err = createCurrentRule()
if err != nil {
promptError("Error occurred constructing new rule: " + err.Error())
return
}
acol := createColumn("Scope", 15)
acol.SetVisible(false)
tv.AppendColumn(acol)
fmt.Println("rule = ", rule)
rulestr := "DENY|" + rule.Proto + ":" + rule.Target + ":" + strconv.Itoa(rule.Port)
fmt.Println("RULESTR = ", rulestr)
makeDecision(idx, rulestr, int(rule.Scope))
fmt.Println("Decision made.")
err = removeSelectedRule(idx, true)
if err == nil {
clearEditor()
} else {
promptError("Error setting new rule: " + err.Error())
}
})
listStore := createListStore(true)
globalLS = listStore
btnIgnore.Connect("clicked", func() {
_, idx, err := getSelectedRule()
if err != nil {
promptError("Error occurred processing request: " + err.Error())
return
}
tv.SetModel(listStore)
makeDecision(idx, "", 0)
fmt.Println("Decision made.")
err = removeSelectedRule(idx, true)
if err == nil {
clearEditor()
} else {
promptError("Error setting new rule: " + err.Error())
}
btnApprove.Connect("clicked", func() {
buttonAction("ALLOW")
})
btnDeny.Connect("clicked", func() {
buttonAction("DENY")
})
// btnIgnore.Connect("clicked", buttonAction)
// tv.SetActivateOnSingleClick(true)
tv.Connect("row-activated", func() {
globalPromptLock.Lock()
seldata, _, err := getSelectedRule()
globalPromptLock.Unlock()
if err != nil {
promptError("Unexpected error reading selected rule: " + err.Error())
return
@ -967,6 +1223,12 @@ func main() {
editApp.SetText(seldata.Path)
if seldata.Icon != "" {
globalIcon.SetFromIconName(seldata.Icon, gtk.ICON_SIZE_DND)
} else {
globalIcon.Clear()
}
if seldata.Hostname != "" {
editTarget.SetText(seldata.Hostname)
} else {
@ -974,13 +1236,14 @@ func main() {
}
editPort.SetText(strconv.Itoa(seldata.Port))
radioOnce.SetActive(true)
radioProcess.SetActive(false)
radioOnce.SetActive(seldata.Scope == int(sgfw.APPLY_ONCE))
radioProcess.SetSensitive(seldata.Pid > 0)
radioParent.SetActive(false)
radioSession.SetActive(false)
radioPermanent.SetActive(false)
radioSession.SetActive(seldata.Scope == int(sgfw.APPLY_SESSION))
radioPermanent.SetActive(seldata.Scope == int(sgfw.APPLY_FOREVER))
comboProto.SetActiveID(seldata.Proto)
chkTLS.SetActive(seldata.IsSocks)
if seldata.Uname != "" {
editUser.SetText(seldata.Uname)
@ -1000,7 +1263,6 @@ func main() {
chkUser.SetActive(false)
chkGroup.SetActive(false)
return
})
@ -1023,5 +1285,16 @@ func main() {
mainWin.ShowAll()
// mainWin.SetKeepAbove(true)
var dres = []string{}
call := dbuso.Call("GetPendingRequests", 0, "*")
err = call.Store(&dres)
if err != nil {
errmsg := "Could not query running SGFW instance (maybe it's not running?): " + err.Error()
promptError(errmsg)
} else {
addPendingPrompts(dres)
}
gtk.Main()
}

@ -199,6 +199,74 @@ func (ds *dbusServer) DeleteRule(id uint32) *dbus.Error {
return nil
}
func (ds *dbusServer) GetPendingRequests(policy string) ([]string, *dbus.Error) {
log.Debug("+++ GetPendingRequests()")
ds.fw.lock.Lock()
defer ds.fw.lock.Unlock()
pending_data := make([]string, 0)
for pname := range ds.fw.policyMap {
policy := ds.fw.policyMap[pname]
pqueue := policy.pendingQueue
for _, pc := range pqueue {
addr := pc.hostname()
if addr == "" {
addr = pc.dst().String()
}
dststr := ""
if pc.dst() != nil {
dststr = pc.dst().String()
} else {
dststr = addr + " (via proxy resolver)"
}
pstr := ""
pstr += pc.getGUID() + "|"
pstr += policy.application + "|"
pstr += policy.icon + "|"
pstr += policy.path + "|"
pstr += addr + "|"
pstr += strconv.FormatUint(uint64(pc.dstPort()), 10) + "|"
pstr += dststr + "|"
pstr += pc.src().String() + "|"
pstr += pc.proto() + "|"
pstr += strconv.FormatInt(int64(pc.procInfo().UID), 10) + "|"
pstr += strconv.FormatInt(int64(pc.procInfo().GID), 10) + "|"
pstr += uidToUser(pc.procInfo().UID) + "|"
pstr += gidToGroup(pc.procInfo().GID) + "|"
pstr += strconv.FormatInt(int64(pc.procInfo().Pid), 10) + "|"
pstr += pc.sandbox() + "|"
pstr += strconv.FormatBool(pc.socks()) + "|"
pstr += pc.getTimestamp() + "|"
pstr += pc.getOptString() + "|"
pstr += strconv.FormatUint(uint64(FirewallConfig.DefaultActionID), 10)
pending_data = append(pending_data, pstr)
}
}
return pending_data, nil
}
func (ds *dbusServer) AddRuleAsync(scope uint32, rule string, policy string) (bool, *dbus.Error) {
log.Warningf("AddRuleAsync %v, %v / %v\n", scope, rule, policy)
ds.fw.lock.Lock()
defer ds.fw.lock.Unlock()
prule := PendingRule{rule: rule, scope: int(scope), policy: policy}
for pname := range ds.fw.policyMap {
log.Debug("+++ Adding prule to policy")
ds.fw.policyMap[pname].rulesPending = append(ds.fw.policyMap[pname].rulesPending, prule)
}
return true, nil
}
func (ds *dbusServer) UpdateRule(rule DbusRule) *dbus.Error {
log.Debugf("UpdateRule %v", rule)
ds.fw.lock.Lock()
@ -268,11 +336,6 @@ func (ds *dbusServer) SetConfig(key string, val dbus.Variant) *dbus.Error {
return nil
}
func (ds *dbusServer) prompt(p *Policy) {
log.Info("prompting...")
ds.prompter.prompt(p)
}
func (ob *dbusObjectP) alertRule(data string) {
ob.Call("com.subgraph.fwprompt.EventNotifier.Alert", 0, data)
}

@ -6,8 +6,6 @@ import (
"strings"
"sync"
// "encoding/binary"
// nfnetlink "github.com/subgraph/go-nfnetlink"
"github.com/google/gopacket/layers"
nfqueue "github.com/subgraph/go-nfnetlink/nfqueue"
@ -15,6 +13,7 @@ import (
"net"
"os"
"syscall"
"time"
"unsafe"
)
@ -52,6 +51,10 @@ type pendingConnection interface {
drop()
setPrompting(bool)
getPrompting() bool
setPrompter(*prompter)
getPrompter() *prompter
getGUID() string
getTimestamp() string
print() string
}
@ -62,6 +65,24 @@ type pendingPkt struct {
pinfo *procsnitch.Info
optstring string
prompting bool
prompter *prompter
guid string
timestamp time.Time
}
/* Not a *REAL* GUID */
func genGUID() string {
frnd, err := os.Open("/dev/urandom")
if err != nil {
log.Fatal("Error reading random data source:", err)
}
rndb := make([]byte, 16)
frnd.Read(rndb)
frnd.Close()
guid := fmt.Sprintf("%x-%x-%x-%x", rndb[0:4], rndb[4:8], rndb[8:12], rndb[12:])
return guid
}
func getEmptyPInfo() *procsnitch.Info {
@ -79,6 +100,10 @@ func (pp *pendingPkt) sandbox() string {
return pp.pinfo.Sandbox
}
func (pc *pendingPkt) getTimestamp() string {
return pc.timestamp.Format("15:04:05.00")
}
func (pp *pendingPkt) socks() bool {
return false
}
@ -165,6 +190,22 @@ func (pp *pendingPkt) drop() {
pp.pkt.Accept()
}
func (pp *pendingPkt) setPrompter(val *prompter) {
pp.prompter = val
}
func (pp *pendingPkt) getPrompter() *prompter {
return pp.prompter
}
func (pp *pendingPkt) getGUID() string {
if pp.guid == "" {
pp.guid = genGUID()
}
return pp.guid
}
func (pp *pendingPkt) getPrompting() bool {
return pp.prompting
}
@ -177,6 +218,12 @@ func (pp *pendingPkt) print() string {
return printPacket(pp.pkt, pp.name, pp.pinfo)
}
type PendingRule struct {
rule string
scope int
policy string
}
type Policy struct {
fw *Firewall
path string
@ -187,6 +234,7 @@ type Policy struct {
pendingQueue []pendingConnection
promptInProgress bool
lock sync.Mutex
rulesPending []PendingRule
}
func (fw *Firewall) PolicyForPath(path string) *Policy {
@ -240,17 +288,25 @@ func (fw *Firewall) policyForPath(path string) *Policy {
return fw.policyMap[path]
}
func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, pinfo *procsnitch.Info, optstr string) {
func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, timestamp time.Time, pinfo *procsnitch.Info, optstr string) {
fmt.Println("policy processPacket()")
/* hbytes, err := pkt.GetHWAddr()
if err != nil {
log.Notice("Failed to get HW address underlying packet: ", err)
} else { log.Notice("got hwaddr: ", hbytes) } */
p.lock.Lock()
defer p.lock.Unlock()
dstb := pkt.Packet.NetworkLayer().NetworkFlow().Dst().Raw()
dstip := net.IP(dstb)
srcip := net.IP(pkt.Packet.NetworkLayer().NetworkFlow().Src().Raw())
/* Can we pass this through quickly? */
/* this probably isn't a performance enhancement. */
/*_, dstp := getPacketPorts(pkt)
fres := p.rules.filter(pkt, srcip, dstip, dstp, dstip.String(), pinfo, optstr)
if fres == FILTER_ALLOW {
fmt.Printf("Packet passed wildcard rules without requiring DNS lookup; accepting: %s:%d\n", dstip, dstp)
pkt.Accept()
return
}*/
name := p.fw.dns.Lookup(dstip, pinfo.Pid)
log.Infof("Lookup(%s): %s", dstip.String(), name)
@ -267,7 +323,7 @@ func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, pinfo *procsnitch.Info, o
case FILTER_ALLOW:
pkt.Accept()
case FILTER_PROMPT:
p.processPromptResult(&pendingPkt{pol: p, name: name, pkt: pkt, pinfo: pinfo, optstring: optstr, prompting: false})
p.processPromptResult(&pendingPkt{pol: p, name: name, pkt: pkt, pinfo: pinfo, optstring: optstr, prompter: nil, timestamp: timestamp, prompting: false})
default:
log.Warningf("Unexpected filter result: %d", result)
}
@ -276,33 +332,27 @@ func (p *Policy) processPacket(pkt *nfqueue.NFQPacket, pinfo *procsnitch.Info, o
func (p *Policy) processPromptResult(pc pendingConnection) {
p.pendingQueue = append(p.pendingQueue, pc)
//fmt.Println("processPromptResult(): p.promptInProgress = ", p.promptInProgress)
if DoMultiPrompt || (!DoMultiPrompt && !p.promptInProgress) {
p.promptInProgress = true
go p.fw.dbus.prompt(p)
}
//if DoMultiPrompt || (!DoMultiPrompt && !p.promptInProgress) {
// if !p.promptInProgress {
p.promptInProgress = true
go p.fw.dbus.prompter.prompt(p)
// }
}
func (p *Policy) nextPending() (pendingConnection, bool) {
p.lock.Lock()
defer p.lock.Unlock()
if !DoMultiPrompt {
if len(p.pendingQueue) == 0 {
return nil, true
}
return p.pendingQueue[0], false
}
if len(p.pendingQueue) == 0 {
return nil, true
}
// for len(p.pendingQueue) != 0 {
for i := 0; i < len(p.pendingQueue); i++ {
// fmt.Printf("XXX: pendingQueue %v of %v: %v\n", i, len(p.pendingQueue), p.pendingQueue[i])
if !p.pendingQueue[i].getPrompting() {
return p.pendingQueue[i], false
}
}
// }
return nil, false
}
@ -329,6 +379,7 @@ func (p *Policy) processNewRule(r *Rule, scope FilterScope) bool {
if scope != APPLY_ONCE {
p.rules = append(p.rules, r)
}
fmt.Println("----------------------- processNewRule()")
p.filterPending(r)
if len(p.pendingQueue) == 0 {
p.promptInProgress = false
@ -372,6 +423,15 @@ func (p *Policy) filterPending(rule *Rule) {
remaining := []pendingConnection{}
for _, pc := range p.pendingQueue {
if rule.match(pc.src(), pc.dst(), pc.dstPort(), pc.hostname(), pc.proto(), pc.procInfo().UID, pc.procInfo().GID, uidToUser(pc.procInfo().UID), gidToGroup(pc.procInfo().GID), pc.procInfo().Sandbox) {
prompter := pc.getPrompter()
if prompter == nil {
fmt.Println("-------- prompter = NULL")
} else {
call := prompter.dbusObj.Call("com.subgraph.FirewallPrompt.RemovePrompt", 0, pc.getGUID())
fmt.Println("CAAAAAAAAAAAAAAALL = ", call)
}
log.Infof("Adding rule for: %s", rule.getString(FirewallConfig.LogRedact))
// log.Noticef("%s > %s", rule.getString(FirewallConfig.LogRedact), pc.print())
if rule.rtype == RULE_ACTION_ALLOW {
@ -442,7 +502,8 @@ func printPacket(pkt *nfqueue.NFQPacket, hostname string, pinfo *procsnitch.Info
return fmt.Sprintf("%s %s %s:%d -> %s:%d", pinfo.ExePath, proto, SrcIp, SrcPort, name, DstPort)
}
func (fw *Firewall) filterPacket(pkt *nfqueue.NFQPacket) {
func (fw *Firewall) filterPacket(pkt *nfqueue.NFQPacket, timestamp time.Time) {
fmt.Println("firewall: filterPacket()")
isudp := pkt.Packet.Layer(layers.LayerTypeUDP) != nil
if basicAllowPacket(pkt) {
@ -520,7 +581,7 @@ func (fw *Firewall) filterPacket(pkt *nfqueue.NFQPacket) {
*/
policy := fw.PolicyForPathAndSandbox(ppath, pinfo.Sandbox)
//log.Notice("XXX: flunked basicallowpacket; policy = ", policy)
policy.processPacket(pkt, pinfo, optstring)
policy.processPacket(pkt, timestamp, pinfo, optstring)
}
func readFileDirect(filename string) ([]byte, error) {

@ -2,24 +2,18 @@ package sgfw
import (
"fmt"
"net"
"os"
"os/user"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/godbus/dbus"
"github.com/subgraph/fw-daemon/proc-coroner"
)
var DoMultiPrompt = true
const MAX_PROMPTS = 5
var outstandingPrompts = 0
var promptLock = &sync.Mutex{}
func newPrompter(conn *dbus.Conn) *prompter {
p := new(prompter)
p.cond = sync.NewCond(&p.lock)
@ -42,6 +36,7 @@ func (p *prompter) prompt(policy *Policy) {
defer p.lock.Unlock()
_, ok := p.policyMap[policy.sandbox+"|"+policy.path]
if ok {
p.cond.Signal()
return
}
p.policyMap[policy.sandbox+"|"+policy.path] = policy
@ -51,35 +46,22 @@ func (p *prompter) prompt(policy *Policy) {
}
func (p *prompter) promptLoop() {
p.lock.Lock()
// p.lock.Lock()
for {
// fmt.Println("XXX: promptLoop() outer")
for p.processNextPacket() {
// fmt.Println("XXX: promptLoop() inner")
}
// fmt.Println("promptLoop() wait")
p.cond.Wait()
p.processNextPacket()
}
}
func (p *prompter) processNextPacket() bool {
//fmt.Println("processNextPacket()")
var pc pendingConnection = nil
if !DoMultiPrompt {
pc, _ = p.nextConnection()
if pc == nil {
return false
}
p.lock.Unlock()
defer p.lock.Lock()
p.processConnection(pc)
return true
}
empty := true
for {
p.lock.Lock()
pc, empty = p.nextConnection()
// fmt.Println("XXX: processNextPacket() loop; empty = ", empty, " / pc = ", pc)
p.lock.Unlock()
//fmt.Println("XXX: processNextPacket() loop; empty = ", empty, " / pc = ", pc)
if pc == nil && empty {
return false
} else if pc == nil {
@ -88,49 +70,150 @@ func (p *prompter) processNextPacket() bool {
break
}
}
p.lock.Unlock()
defer p.lock.Lock()
// fmt.Println("XXX: Waiting for prompt lock go...")
for {
promptLock.Lock()
if outstandingPrompts >= MAX_PROMPTS {
promptLock.Unlock()
continue
}
if pc.getPrompting() {
log.Debugf("Skipping over already prompted connection")
promptLock.Unlock()
continue
}
break
if pc.getPrompting() {
log.Debugf("Skipping over already prompted connection")
return false
}
// fmt.Println("XXX: Passed prompt lock!")
outstandingPrompts++
// fmt.Println("XXX: Incremented outstanding to ", outstandingPrompts)
promptLock.Unlock()
// if !pc.getPrompting() {
pc.setPrompting(true)
fmt.Println("processConnection")
go p.processConnection(pc)
// }
return true
}
func processReturn(pc pendingConnection) {
promptLock.Lock()
outstandingPrompts--
// fmt.Println("XXX: Return decremented outstanding to ", outstandingPrompts)
promptLock.Unlock()
pc.setPrompting(false)
type PC2FDMapping struct {
guid string
inode uint64
fd int
fdpath string
prompter *prompter
}
var PC2FDMap = map[string]PC2FDMapping{}
var PC2FDMapLock = &sync.Mutex{}
var PC2FDMapRunning = false
func monitorPromptFDs(pc pendingConnection) {
guid := pc.getGUID()
pid := pc.procInfo().Pid
inode := pc.procInfo().Inode
fd := pc.procInfo().FD
prompter := pc.getPrompter()
fmt.Printf("ADD TO MONITOR: %v | %v / %v / %v\n", pc.policy().application, guid, pid, fd)
if pid == -1 || fd == -1 || prompter == nil {
log.Warning("Unexpected error condition occurred while adding socket fd to monitor")
return
}
PC2FDMapLock.Lock()
defer PC2FDMapLock.Unlock()
fdpath := fmt.Sprintf("/proc/%d/fd/%d", pid, fd)
PC2FDMap[guid] = PC2FDMapping{guid: guid, inode: inode, fd: fd, fdpath: fdpath, prompter: prompter}
return
}
func monitorPromptFDLoop() {
fmt.Println("++++++++++= monitorPromptFDLoop()")
for true {
delete_guids := []string{}
PC2FDMapLock.Lock()
fmt.Println("++++ nentries = ", len(PC2FDMap))
for guid, fdmon := range PC2FDMap {
fmt.Println("ENTRY:", fdmon)
lsb, err := os.Stat(fdmon.fdpath)
if err != nil {
log.Warningf("Error looking up socket \"%s\": %v\n", fdmon.fdpath, err)
delete_guids = append(delete_guids, guid)
continue
}
sb, ok := lsb.Sys().(*syscall.Stat_t)
if !ok {
log.Warning("Not a syscall.Stat_t")
delete_guids = append(delete_guids, guid)
continue
}
inode := sb.Ino
fmt.Println("+++ INODE = ", inode)
if inode != fdmon.inode {
fmt.Printf("inode mismatch: %v vs %v\n", inode, fdmon.inode)
delete_guids = append(delete_guids, guid)
}
}
fmt.Println("guids to delete: ", delete_guids)
saved_mappings := []PC2FDMapping{}
for _, guid := range delete_guids {
saved_mappings = append(saved_mappings, PC2FDMap[guid])
delete(PC2FDMap, guid)
}
PC2FDMapLock.Unlock()
for _, mapping := range saved_mappings {
call := mapping.prompter.dbusObj.Call("com.subgraph.FirewallPrompt.RemovePrompt", 0, mapping.guid)
fmt.Println("DISPOSING CALL = ", call)
prompter := mapping.prompter
prompter.lock.Lock()
for _, policy := range prompter.policyQueue {
policy.lock.Lock()
pcind := 0
for pcind < len(policy.pendingQueue) {
if policy.pendingQueue[pcind].getGUID() == mapping.guid {
fmt.Println("-------------- found guid to remove")
policy.pendingQueue = append(policy.pendingQueue[:pcind], policy.pendingQueue[pcind+1:]...)
} else {
pcind++
}
}
policy.lock.Unlock()
}
prompter.lock.Unlock()
}
fmt.Println("++++++++++= monitorPromptFDLoop WAIT")
time.Sleep(5 * time.Second)
}
}
func (p *prompter) processConnection(pc pendingConnection) {
var scope int32
var dres bool
var rule string
if DoMultiPrompt {
defer processReturn(pc)
if !PC2FDMapRunning {
PC2FDMapLock.Lock()
if !PC2FDMapRunning {
PC2FDMapRunning = true
PC2FDMapLock.Unlock()
go monitorPromptFDLoop()
} else {
PC2FDMapLock.Unlock()
}
}
if pc.getPrompter() == nil {
pc.setPrompter(p)
}
addr := pc.hostname()
@ -144,10 +227,12 @@ func (p *prompter) processConnection(pc pendingConnection) {
if pc.dst() != nil {
dststr = pc.dst().String()
} else {
dststr = addr + " (proxy to resolve)"
dststr = addr + " (via proxy resolver)"
}
call := p.dbusObj.Call("com.subgraph.FirewallPrompt.RequestPrompt", 0,
monitorPromptFDs(pc)
call := p.dbusObj.Call("com.subgraph.FirewallPrompt.RequestPromptAsync", 0,
pc.getGUID(),
policy.application,
policy.icon,
policy.path,
@ -163,18 +248,58 @@ func (p *prompter) processConnection(pc pendingConnection) {
int32(pc.procInfo().Pid),
pc.sandbox(),
pc.socks(),
pc.getTimestamp(),
pc.getOptString(),
FirewallConfig.PromptExpanded,
FirewallConfig.PromptExpert,
int32(FirewallConfig.DefaultActionID))
err := call.Store(&scope, &rule)
err := call.Store(&dres)
if err != nil {
log.Warningf("Error sending dbus RequestPrompt message: %v", err)
log.Warningf("Error sending dbus async RequestPrompt message: %v", err)
policy.removePending(pc)
pc.drop()
return
}
if !dres {
fmt.Println("Unexpected: fw-prompt async RequestPrompt message returned:", dres)
}
return
/* p.dbusObj.Go("com.subgraph.FirewallPrompt.RequestPrompt", 0, callChan,
pc.getGUID(),
policy.application,
policy.icon,
policy.path,
addr,
int32(pc.dstPort()),
dststr,
pc.src().String(),
pc.proto(),
int32(pc.procInfo().UID),
int32(pc.procInfo().GID),
uidToUser(pc.procInfo().UID),
gidToGroup(pc.procInfo().GID),
int32(pc.procInfo().Pid),
pc.sandbox(),
pc.socks(),
pc.getOptString(),
FirewallConfig.PromptExpanded,
FirewallConfig.PromptExpert,
int32(FirewallConfig.DefaultActionID))
saveChannel(callChan, false, true)
/* err := call.Store(&scope, &rule)
if err != nil {
log.Warningf("Error sending dbus RequestPrompt message: %v", err)
policy.removePending(pc)
pc.drop()
return
} */
// the prompt sends:
// ALLOW|dest or DENY|dest
//
@ -194,17 +319,19 @@ func (p *prompter) processConnection(pc pendingConnection) {
}
tempRule := fmt.Sprintf("%s|%s", toks[0], toks[1])
tempRule += "||-1:-1|" + sandbox + "|"
if pc.src() != nil && !pc.src().Equal(net.ParseIP("127.0.0.1")) && sandbox != "" {
if pc.src() != nil && !pc.src().IsLoopback() && sandbox != "" {
//if !strings.HasSuffix(rule, "SYSTEM") && !strings.HasSuffix(rule, "||") {
//rule += "||"
//}
//ule += "|||" + pc.src().String()
tempRule += "||-1:-1|" + sandbox + "|" + pc.src().String()
// tempRule += "||-1:-1|" + sandbox + "|" + pc.src().String()
tempRule += pc.src().String()
} else {
tempRule += "||-1:-1|" + sandbox + "|"
// tempRule += "||-1:-1|" + sandbox + "|"
}
r, err := policy.parseRule(tempRule, false)
if err != nil {
@ -235,35 +362,109 @@ func (p *prompter) processConnection(pc pendingConnection) {
}
func (p *prompter) nextConnection() (pendingConnection, bool) {
for {
if len(p.policyQueue) == 0 {
return nil, true
}
policy := p.policyQueue[0]
pind := 0
if len(p.policyQueue) == 0 {
return nil, true
}
fmt.Println("policy queue len = ", len(p.policyQueue))
for pind < len(p.policyQueue) {
//fmt.Printf("XXX: pind = %v of %v\n", pind, len(p.policyQueue))
policy := p.policyQueue[pind]
pc, qempty := policy.nextPending()
if pc == nil && qempty {
p.removePolicy(policy)
continue
} else {
pind++
// if pc == nil && !qempty {
if len(policy.rulesPending) > 0 {
fmt.Println("policy rules pending = ", len(policy.rulesPending))
prule := policy.rulesPending[0]
policy.rulesPending = append(policy.rulesPending[:0], policy.rulesPending[1:]...)
toks := strings.Split(prule.rule, "|")
sandbox := ""
if len(toks) > 2 {
sandbox = toks[2]
}
sandbox += ""
tempRule := fmt.Sprintf("%s|%s", toks[0], toks[1])
tempRule += "||-1:-1|" + sandbox + "|"
/*if pc.src() != nil && !pc.src().IsLoopback() && sandbox != "" {
tempRule += "||-1:-1|" + sandbox + "|" + pc.src().String()
} else {
tempRule += "||-1:-1|" + sandbox + "|"
}*/
r, err := policy.parseRule(tempRule, false)
if err != nil {
log.Warningf("Error parsing rule string returned from dbus RequestPrompt: %v", err)
continue
// policy.removePending(pc)
// pc.drop()
// return
} else {
fscope := FilterScope(prule.scope)
if fscope == APPLY_SESSION {
r.mode = RULE_MODE_SESSION
} else if fscope == APPLY_PROCESS {
r.mode = RULE_MODE_PROCESS
/*r.pid = pc.procInfo().Pid
pcoroner.MonitorProcess(r.pid)*/
}
if !policy.processNewRule(r, fscope) {
// p.lock.Lock()
// defer p.lock.Unlock()
// p.removePolicy(pc.policy())
}
if fscope == APPLY_FOREVER {
r.mode = RULE_MODE_PERMANENT
policy.fw.saveRules()
}
log.Warningf("Prompt returning rule: %v", tempRule)
dbusp.alertRule("sgfw prompt added new rule")
}
}
if pc == nil && !qempty {
log.Errorf("FIX ME: I NEED TO SLEEP ON A WAKEABLE CONDITION PROPERLY!!")
// log.Errorf("FIX ME: I NEED TO SLEEP ON A WAKEABLE CONDITION PROPERLY!!")
time.Sleep(time.Millisecond * 300)
continue
}
if pc != nil && pc.getPrompting() {
fmt.Println("SKIPPING PROMPTED")
continue
}
return pc, qempty
}
}
return nil, true
}
func (p *prompter) removePolicy(policy *Policy) {
var newQueue []*Policy = nil
if DoMultiPrompt {
if len(p.policyQueue) == 0 {
log.Debugf("Skipping over zero length policy queue")
newQueue = make([]*Policy, 0, 0)
}
// if DoMultiPrompt {
if len(p.policyQueue) == 0 {
log.Debugf("Skipping over zero length policy queue")
newQueue = make([]*Policy, 0, 0)
}
// }
if !DoMultiPrompt || newQueue == nil {
// if !DoMultiPrompt || newQueue == nil {
if newQueue == nil {
newQueue = make([]*Policy, 0, len(p.policyQueue)-1)
}
for _, pol := range p.policyQueue {
@ -277,11 +478,17 @@ func (p *prompter) removePolicy(policy *Policy) {
var userMap = make(map[int]string)
var groupMap = make(map[int]string)
var userMapLock = &sync.Mutex{}
var groupMapLock = &sync.Mutex{}
func lookupUser(uid int) string {
if uid == -1 {
return "[unknown]"
}
userMapLock.Lock()
defer userMapLock.Unlock()
u, err := user.LookupId(strconv.Itoa(uid))
if err != nil {
return fmt.Sprintf("%d", uid)
@ -293,6 +500,10 @@ func lookupGroup(gid int) string {
if gid == -1 {
return "[unknown]"
}
groupMapLock.Lock()
defer groupMapLock.Unlock()
g, err := user.LookupGroupId(strconv.Itoa(gid))
if err != nil {
return fmt.Sprintf("%d", gid)

@ -479,7 +479,10 @@ func (fw *Firewall) loadRules() {
if err != nil {
if !os.IsNotExist(err) {
log.Warningf("Failed to open %s for reading: %v", p, err)
} else {
log.Warningf("Did not find a rules file at %s: SGFW loaded with no rules\n", p)
}
return
}
var policy *Policy
@ -497,7 +500,13 @@ func (fw *Firewall) loadRules() {
func (fw *Firewall) processPathLine(line string) *Policy {
pathLine := line[1 : len(line)-1]
toks := strings.Split(pathLine, "|")
if len(toks) != 2 {
log.Warning("Error parsing rules directive:", line)
return nil
}
policy := fw.policyForPathAndSandbox(toks[1], toks[0])
policy.lock.Lock()
defer policy.lock.Unlock()

@ -1,16 +1,16 @@
package sgfw
import (
"bufio"
"encoding/json"
"fmt"
"os"
"os/signal"
"regexp"
"strings"
"sync"
"syscall"
// "time"
"bufio"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/op/go-logging"
nfqueue "github.com/subgraph/go-nfnetlink/nfqueue"
@ -110,6 +110,8 @@ func (fw *Firewall) runFilter() {
go func() {
for p := range ps {
timestamp := time.Now()
if fw.isEnabled() {
ipLayer := p.Packet.Layer(layers.LayerTypeIPv4)
if ipLayer == nil {
@ -127,7 +129,7 @@ func (fw *Firewall) runFilter() {
}
fw.filterPacket(p)
fw.filterPacket(p, timestamp)
} else {
p.Accept()
}

@ -56,7 +56,10 @@ type pendingSocksConnection struct {
pinfo *procsnitch.Info
verdict chan int
prompting bool
prompter *prompter
guid string
optstr string
timestamp time.Time
}
func (sc *pendingSocksConnection) sandbox() string {
@ -107,23 +110,58 @@ func (sc *pendingSocksConnection) deliverVerdict(v int) {
}
}()
sc.verdict <- v
close(sc.verdict)
if sc.verdict != nil {
sc.verdict <- v
close(sc.verdict)
sc.verdict = nil
}
}
func (sc *pendingSocksConnection) accept() { sc.deliverVerdict(socksVerdictAccept) }
func (sc *pendingSocksConnection) accept() {
sc.deliverVerdict(socksVerdictAccept)
}
// need to generalize special accept
func (sc *pendingSocksConnection) acceptTLSOnly() { sc.deliverVerdict(socksVerdictAcceptTLSOnly) }
func (sc *pendingSocksConnection) acceptTLSOnly() {
sc.deliverVerdict(socksVerdictAcceptTLSOnly)
}
func (sc *pendingSocksConnection) drop() { sc.deliverVerdict(socksVerdictDrop) }
func (sc *pendingSocksConnection) drop() {
sc.deliverVerdict(socksVerdictDrop)
}
func (sc *pendingSocksConnection) getPrompting() bool { return sc.prompting }
func (sc *pendingSocksConnection) setPrompter(val *prompter) {
sc.prompter = val
}
func (sc *pendingSocksConnection) setPrompting(val bool) { sc.prompting = val }
func (sc *pendingSocksConnection) getPrompter() *prompter {
return sc.prompter
}
func (sc *pendingSocksConnection) getTimestamp() string {
return sc.timestamp.Format("15:04:05.00")
}
func (sc *pendingSocksConnection) print() string { return "socks connection" }
func (sc *pendingSocksConnection) getGUID() string {
if sc.guid == "" {
sc.guid = genGUID()
}
return sc.guid
}
func (sc *pendingSocksConnection) getPrompting() bool {
return sc.prompting
}
func (sc *pendingSocksConnection) setPrompting(val bool) {
sc.prompting = val
}
func (sc *pendingSocksConnection) print() string {
return "socks connection"
}
func NewSocksChain(cfg *socksChainConfig, wg *sync.WaitGroup, fw *Firewall) *socksChain {
chain := socksChain{
@ -146,10 +184,11 @@ func (s *socksChain) start() {
}
s.wg.Add(1)
go s.socksAcceptLoop()
ts := time.Now()
go s.socksAcceptLoop(ts)
}
func (s *socksChain) socksAcceptLoop() error {
func (s *socksChain) socksAcceptLoop(timestamp time.Time) error {
defer s.wg.Done()
defer s.listener.Close()
@ -163,11 +202,11 @@ func (s *socksChain) socksAcceptLoop() error {
continue
}
session := &socksChainSession{cfg: s.cfg, clientConn: conn, procInfo: s.procInfo, server: s}
go session.sessionWorker()
go session.sessionWorker(timestamp)
}
}
func (c *socksChainSession) sessionWorker() {
func (c *socksChainSession) sessionWorker(timestamp time.Time) {
defer c.clientConn.Close()
clientAddr := c.clientConn.RemoteAddr()
@ -197,7 +236,7 @@ func (c *socksChainSession) sessionWorker() {
c.req.ReplyAddr(ReplySucceeded, c.bndAddr)
}
case CommandConnect:
verdict, tls := c.filterConnect()
verdict, tls := c.filterConnect(timestamp)
if !verdict {
c.req.Reply(ReplyConnectionRefused)
@ -278,7 +317,7 @@ func findProxyEndpoint(pdata []string, conn net.Conn) (*procsnitch.Info, string)
return nil, ""
}
func (c *socksChainSession) filterConnect() (bool, bool) {
func (c *socksChainSession) filterConnect(timestamp time.Time) (bool, bool) {
// return filter verdict, tlsguard
allProxies, err := ListProxies()
@ -364,7 +403,9 @@ func (c *socksChainSession) filterConnect() (bool, bool) {
pinfo: pinfo,
verdict: make(chan int),
prompting: false,
prompter: nil,
optstr: optstr,
timestamp: timestamp,
}
policy.processPromptResult(pending)
v := <-pending.verdict
@ -409,7 +450,7 @@ func (c *socksChainSession) forwardTraffic(tls bool) {
if c.pinfo.Sandbox != "" {
log.Errorf("TLSGuard violation: Dropping traffic from %s (sandbox: %s) to %s: %v", c.pinfo.ExePath, c.pinfo.Sandbox, c.req.Addr.addrStr, err)
} else {
log.Errorf("TLSGuard violation: Dropping traffic from %s (unsandboxed) to %s: %v", c.pinfo.ExePath, c.req.Addr.addrStr, err)
log.Errorf("TLSGuard violation: Dropping traffic from %s (un-sandboxed) to %s: %v", c.pinfo.ExePath, c.req.Addr.addrStr, err)
}
return
} else {

@ -2,6 +2,8 @@ package sgfw
import (
"crypto/x509"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"io"
@ -9,197 +11,576 @@ import (
"time"
)
const TLSGUARD_READ_TIMEOUT = 2 * time.Second
const TLSGUARD_READ_TIMEOUT = 10 * time.Second
const TLSGUARD_MIN_TLS_VER_MAJ = 3
const TLSGUARD_MIN_TLS_VER_MIN = 1
const TLS_RECORD_HDR_LEN = 5
const SSL3_RT_CHANGE_CIPHER_SPEC = 20
const SSL3_RT_ALERT = 21
const SSL3_RT_HANDSHAKE = 22
const SSL3_RT_APPLICATION_DATA = 23
const SSL3_MT_HELLO_REQUEST = 0
const SSL3_MT_CLIENT_HELLO = 1
const SSL3_MT_SERVER_HELLO = 2
const SSL3_MT_CERTIFICATE = 11
const SSL3_MT_CERTIFICATE_REQUEST = 13
const SSL3_MT_SERVER_DONE = 14
const SSL3_MT_CERTIFICATE_STATUS = 22
const SSL3_AL_WARNING = 1
const SSL3_AL_FATAL = 2
const SSL3_AD_CLOSE_NOTIFY = 0
const SSL3_AD_UNEXPECTED_MESSAGE = 10
const SSL3_AD_BAD_RECORD_MAC = 20
const TLS1_AD_DECRYPTION_FAILED = 21
const TLS1_AD_RECORD_OVERFLOW = 22
const SSL3_AD_DECOMPRESSION_FAILURE = 30
const SSL3_AD_HANDSHAKE_FAILURE = 40
const SSL3_AD_NO_CERTIFICATE = 41
const SSL3_AD_BAD_CERTIFICATE = 42
const SSL3_AD_UNSUPPORTED_CERTIFICATE = 43
const SSL3_AD_CERTIFICATE_REVOKED = 44
const SSL3_AD_CERTIFICATE_EXPIRED = 45
const SSL3_AD_CERTIFICATE_UNKNOWN = 46
const SSL3_AD_ILLEGAL_PARAMETER = 47
const TLS1_AD_UNKNOWN_CA = 48
const TLS1_AD_ACCESS_DENIED = 49
const TLS1_AD_DECODE_ERROR = 50
const TLS1_AD_DECRYPT_ERROR = 51
const TLS1_AD_EXPORT_RESTRICTION = 60
const TLS1_AD_PROTOCOL_VERSION = 70
const TLS1_AD_INSUFFICIENT_SECURITY = 71
const TLS1_AD_INTERNAL_ERROR = 80
const TLS1_AD_INAPPROPRIATE_FALLBACK = 86
const TLS1_AD_USER_CANCELLED = 90
const TLS1_AD_NO_RENEGOTIATION = 100
const TLS1_AD_UNSUPPORTED_EXTENSION = 110
const TLSEXT_TYPE_server_name = 1
const TLSEXT_TYPE_signature_algorithms = 13
const TLSEXT_TYPE_client_certificate_type = 19
const TLSEXT_TYPE_extended_master_secret = 23
const TLSEXT_TYPE_renegotiate = 0xff01
type connReader struct {
client bool
data []byte
rtype int
err error
}
func readTLSChunk(conn net.Conn) ([]byte, int, error) {
conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT))
cbytes, err := readNBytes(conn, 5)
conn.SetReadDeadline(time.Time{})
var cipherSuiteMap map[uint16]string = map[uint16]string{
0x0000: "TLS_NULL_WITH_NULL_NULL",
0x000a: "TLS_RSA_WITH_3DES_EDE_CBC_SHA",
0x002f: "TLS_RSA_WITH_AES_128_CBC_SHA",
0x0033: "TLS_DHE_RSA_WITH_AES_128_CBC_SHA",
0x0039: "TLS_DHE_RSA_WITH_AES_256_CBC_SHA",
0x0035: "TLS_RSA_WITH_AES_256_CBC_SHA",
0x0030: "TLS_DH_DSS_WITH_AES_128_CBC_SHA",
0xc009: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA",
0xc00a: "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA",
0xc013: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
0xc014: "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
0xc02b: "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
0xc02c: "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
0xc02f: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
0xc030: "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
0xcca9: "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
0xcca8: "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
}
if err != nil {
log.Errorf("TLS data chunk read failure: ", err)
return nil, 0, err
func getCipherSuiteName(value uint) string {
val, ok := cipherSuiteMap[uint16(value)]
if !ok {
return "UNKNOWN"
}
if int(cbytes[1]) < TLSGUARD_MIN_TLS_VER_MAJ {
return nil, 0, errors.New("TLS protocol major version less than expected minimum")
} else if int(cbytes[2]) < TLSGUARD_MIN_TLS_VER_MIN {
return nil, 0, errors.New("TLS protocol minor version less than expected minimum")
}
return val
}
func connectionReader(conn net.Conn, is_client bool, c chan connReader, done chan bool) {
var ret_error error = nil
buffered := []byte{}
mlen := 0
rtype := 0
stage := 1
for {
if ret_error != nil {
cr := connReader{client: is_client, data: nil, rtype: 0, err: ret_error}
c <- cr
break
}
select {
case <-done:
fmt.Println("++ DONE: ", is_client)
if len(buffered) > 0 {
//fmt.Println("++ DONE BUT DISPOSING OF BUFFERED DATA")
c <- connReader{client: is_client, data: buffered, rtype: 0, err: nil}
}
c <- connReader{client: is_client, data: nil, rtype: 0, err: nil}
return
default:
if stage == 1 {
header := make([]byte, TLS_RECORD_HDR_LEN)
conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT))
_, err := io.ReadFull(conn, header)
conn.SetReadDeadline(time.Time{})
if err != nil {
ret_error = err
continue
}
if int(header[1]) < TLSGUARD_MIN_TLS_VER_MAJ {
ret_error = errors.New("TLS protocol major version less than expected minimum")
continue
} else if int(header[2]) < TLSGUARD_MIN_TLS_VER_MIN {
ret_error = errors.New("TLS protocol minor version less than expected minimum")
continue
}
rtype = int(header[0])
mlen = int(int(header[3])<<8 | int(header[4]))
fmt.Printf("TLS data chunk header read: type = %#x, maj = %v, min = %v, len = %v\n", rtype, header[1], header[2], mlen)
/* 16384+1024 if compression is not null */
/* or 16384+2048 if ciphertext */
if mlen > 16384 {
ret_error = errors.New(fmt.Sprintf("TLSGuard read TLS plaintext record of excessively large length; dropping (%v bytes)", mlen))
continue
}
cbyte := cbytes[0]
mlen := int(int(cbytes[3])<<8 | int(cbytes[4]))
// fmt.Printf("TLS data chunk header read: type = %#x, maj = %v, min = %v, len = %v\n", cbyte, cbytes[1], cbytes[2], mlen)
buffered = header
stage++
} else if stage == 2 {
remainder := make([]byte, mlen)
conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT))
_, err := io.ReadFull(conn, remainder)
conn.SetReadDeadline(time.Time{})
if err != nil {
ret_error = err
continue
}
conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT))
cbytes2, err := readNBytes(conn, mlen)
conn.SetReadDeadline(time.Time{})
buffered = append(buffered, remainder...)
fmt.Printf("------- CHUNK READ: client: %v, err = %v, bytes = %v\n", is_client, err, len(buffered))
cr := connReader{client: is_client, data: buffered, rtype: rtype, err: err}
c <- cr
buffered = []byte{}
rtype = 0
mlen = 0
stage = 1
}
}
if err != nil {
return nil, 0, err
}
cbytes = append(cbytes, cbytes2...)
return cbytes, int(cbyte), nil
}
func TLSGuard(conn, conn2 net.Conn, fqdn string) error {
x509Valid := false
ndone := 0
// Should this be a requirement?
// if strings.HasSuffix(request.DestAddr.FQDN, "onion") {
//conn client
//conn2 server
// Read the opening message from the client
chunk, rtype, err := readTLSChunk(conn)
if err != nil {
return err
}
fmt.Println("-------- STARTING HANDSHAKE LOOP")
crChan := make(chan connReader)
dChan := make(chan bool, 10)
go connectionReader(conn, true, crChan, dChan)
go connectionReader(conn2, false, crChan, dChan)
client_expected := SSL3_MT_CLIENT_HELLO
server_expected := SSL3_MT_SERVER_HELLO
select_loop:
for {
if ndone == 2 {
fmt.Println("DONE channel got both notifications. Terminating loop.")
close(dChan)
close(crChan)
break
}
if rtype != SSL3_RT_HANDSHAKE {
return errors.New("Blocked client from attempting non-TLS connection")
}
select {
case cr := <-crChan:
other := conn
// Pass it on through to the server
conn2.Write(chunk)
if cr.client {
other = conn2
}
// Read ServerHello
valid := false
loop := 1
fmt.Printf("++++ SELECT: %v, %v, %v\n", cr.client, cr.err, len(cr.data))
if cr.err == nil && cr.data == nil {
fmt.Println("DONE channel notification received")
ndone++
continue
}
passthru := false
if cr.err == nil {
if cr.rtype == SSL3_RT_CHANGE_CIPHER_SPEC || cr.rtype == SSL3_RT_APPLICATION_DATA ||
cr.rtype == SSL3_RT_ALERT {
/* We expect only a single byte of data */
if cr.rtype == SSL3_RT_CHANGE_CIPHER_SPEC {
fmt.Println("CHANGE CIPHER_SPEC: ", cr.data[TLS_RECORD_HDR_LEN])
if len(cr.data) != 6 {
return errors.New(fmt.Sprintf("TLSGuard dropped connection with strange change cipher spec data length (%v bytes)", len(cr.data)))
}
if cr.data[TLS_RECORD_HDR_LEN] != 1 {
return errors.New(fmt.Sprintf("TLSGuard dropped connection with strange change cipher spec data (%#x bytes)", cr.data[TLS_RECORD_HDR_LEN]))
}
} else if cr.rtype == SSL3_RT_ALERT {
if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_WARNING {
fmt.Println("SSL ALERT TYPE: warning")
} else if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_FATAL {
fmt.Println("SSL ALERT TYPE: fatal")
} else {
fmt.Println("SSL ALERT TYPE UNKNOWN")
}
alert_desc := int(int(cr.data[6])<<8 | int(cr.data[7]))
fmt.Println("ALERT DESCRIPTION: ", alert_desc)
if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_FATAL {
return errors.New(fmt.Sprintf("TLSGuard dropped connection after fatal error alert detected"))
} else if alert_desc == SSL3_AD_CLOSE_NOTIFY {
return errors.New(fmt.Sprintf("TLSGuard dropped connection after close_notify alert detected"))
}
}
// fmt.Println("OTHER DATA; PASSING THRU")
if cr.rtype == SSL3_RT_ALERT {
fmt.Println("ALERT = ", cr.data)
}
other.Write(cr.data)
continue
} else if cr.client {
// other.Write(cr.data)
// continue
} else if cr.rtype != SSL3_RT_HANDSHAKE {
return errors.New(fmt.Sprintf("Expected TLS server handshake byte was not received [%#x vs 0x16]", cr.rtype))
}
for 1 == 1 {
loop++
if cr.rtype < SSL3_RT_CHANGE_CIPHER_SPEC || cr.rtype > SSL3_RT_APPLICATION_DATA {
return errors.New(fmt.Sprintf("TLSGuard dropping connection with unknown content type: %#x", cr.rtype))
}
// fmt.Printf("SSL LOOP %v; trying to read: conn2\n", loop)
chunk, rtype, err = readTLSChunk(conn2)
handshakeMsg := cr.data[TLS_RECORD_HDR_LEN:]
s := uint(handshakeMsg[0])
fmt.Printf("s = %#x\n", s)
// Message len, 3 bytes
if cr.rtype == SSL3_RT_HANDSHAKE {
handshakeMessageLen := handshakeMsg[1:4]
handshakeMessageLenInt := int(int(handshakeMessageLen[0])<<16 | int(handshakeMessageLen[1])<<8 | int(handshakeMessageLen[2]))
fmt.Println("lenint = \n", handshakeMessageLenInt)
}
if err != nil {
log.Debugf("TLSGUARD: OTHER loop %v: trying to read: conn\n", loop)
chunk, rtype, err2 := readTLSChunk(conn)
log.Debugf("TLSGUARD: read: %v, %v, %v\n", err2, rtype, len(chunk))
if cr.client && s != uint(client_expected) {
return errors.New(fmt.Sprintf("Client sent handshake type %#x but expected %#x", s, client_expected))
} else if !cr.client && s != uint(server_expected) {
return errors.New(fmt.Sprintf("Server sent handshake type %#x but expected %#x", s, server_expected))
}
if err2 == nil {
conn2.Write(chunk)
continue
}
if (cr.client && s == SSL3_MT_CLIENT_HELLO) || (!cr.client && s == SSL3_MT_SERVER_HELLO) {
rewrite := false
rewrite_buf := []byte{}
SRC := ""
if s == SSL3_MT_CLIENT_HELLO {
SRC = "CLIENT"
} else {
server_expected = SSL3_MT_CERTIFICATE
SRC = "SERVER"
}
hello_offset := 4
// 2 byte protocol version
fmt.Println(SRC, "HELLO VERSION = ", handshakeMsg[hello_offset:hello_offset+2])
hello_offset += 2
// 4 byte Random/GMT time
gmtbytes := binary.BigEndian.Uint32(handshakeMsg[hello_offset : hello_offset+4])
gmt := time.Unix(int64(gmtbytes), 0)
fmt.Println(SRC, "HELLO GMT = ", gmt)
hello_offset += 4
// 28 bytes Random/random_bytes
hello_offset += 28
// 1 byte (32-bit session ID)
sess_len := uint(handshakeMsg[hello_offset])
fmt.Println(SRC, "HELLO SESSION ID = ", sess_len)
if sess_len != 0 {
fmt.Printf("ALERT: %v attempting to resume session; intercepting request\n", SRC)
rewrite = true
dcopy := make([]byte, len(cr.data))
copy(dcopy, cr.data)
// Copy the bytes before the session ID start
rewrite_buf = dcopy[0 : TLS_RECORD_HDR_LEN+hello_offset+1]
// Set the session ID to 0
rewrite_buf[len(rewrite_buf)-1] = 0
// Write the new TLS record length
binary.BigEndian.PutUint16(rewrite_buf[3:5], uint16(len(dcopy)-(int(sess_len)+TLS_RECORD_HDR_LEN)))
// Write the new ClientHello length
// Starts after the first 6 bytes (record header + type byte)
orig_len := binary.BigEndian.Uint32(handshakeMsg[0:4])
// But it's only 3 bytes so mask out the first one
b1 := orig_len & 0xff000000
orig_len &= 0x00ffffff
orig_len -= uint32(sess_len)
orig_len |= b1
binary.BigEndian.PutUint32(rewrite_buf[TLS_RECORD_HDR_LEN:], orig_len)
rewrite_buf = append(rewrite_buf, dcopy[TLS_RECORD_HDR_LEN+hello_offset+int(sess_len)+1:]...)
}
hello_offset += int(sess_len) + 1
// 2 byte cipher suite array
cs := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2])
noCS := cs
fmt.Printf("cs = %v / %#x\n", noCS, noCS)
if !cr.client {
fmt.Printf("SERVER selected ciphersuite: %#x (%s)\n", cs, getCipherSuiteName(uint(cs)))
hello_offset += 2
} else {
for csind := 0; csind < int(noCS/2); csind++ {
off := hello_offset + 2 + (csind * 2)
cs = binary.BigEndian.Uint16(handshakeMsg[off : off+2])
fmt.Printf("%s HELLO CIPHERSUITE: %d/%d: %#x (%s)\n", SRC, csind+1, noCS/2, cs, getCipherSuiteName(uint(cs)))
}
hello_offset += 2 + int(noCS)
}
clen := uint(handshakeMsg[hello_offset])
hello_offset++
if !cr.client {
fmt.Println("SERVER selected compression method: ", clen)
} else {
fmt.Println(SRC, "HELLO COMPRESSION METHODS LEN = ", clen)
fmt.Println(SRC, "HELLO COMPRESSION METHODS: ", handshakeMsg[hello_offset:hello_offset+int(clen)])
hello_offset += int(clen)
}
var extlen uint16 = 0
if hello_offset == len(handshakeMsg) {
fmt.Println("Message didn't have any extensions present")
} else {
extlen = binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2])
fmt.Println(SRC, "HELLO EXTENSIONS LENGTH: ", extlen)
hello_offset += 2
}
var exttype uint16 = 0
if extlen > 2 {
exttype = binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2])
fmt.Println(SRC, "HELLO FIRST EXTENSION TYPE: ", exttype)
}
if cr.client {
ext_ctr := 0
for ext_ctr < int(extlen)-2 {
hello_offset += 2
ext_ctr += 2
fmt.Printf("PROGRESS: %v of %v, %v of %v\n", ext_ctr, extlen, hello_offset, len(handshakeMsg))
exttype2 := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2])
fmt.Printf("EXTTYPE = %v, 2 = %v\n", exttype, exttype2)
if exttype2 == TLSEXT_TYPE_server_name {
fmt.Println("CLIENT specified server_name extension:")
}
if exttype != TLSEXT_TYPE_signature_algorithms {
fmt.Println("WTF")
}
hello_offset += 2
ext_ctr += 2
inner_len := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2])
// fmt.Println("INNER LEN = ", inner_len)
hello_offset += int(inner_len)
ext_ctr += int(inner_len)
}
}
if extlen > 0 {
fmt.Printf("ALERT: %v attempting to send extensions; intercepting request\n", SRC)
rewrite = true
tocopy := cr.data
if len(rewrite_buf) > 0 {
tocopy = rewrite_buf
}
dcopy := make([]byte, len(tocopy)-int(extlen))
copy(dcopy, tocopy[0:len(tocopy)-int(extlen)])
rewrite_buf = dcopy
// Write the new TLS record length
binary.BigEndian.PutUint16(rewrite_buf[3:5], uint16(len(dcopy)-(int(sess_len)+TLS_RECORD_HDR_LEN)))
// Write the new ClientHello length
// Starts after the first 6 bytes (record header + type byte)
orig_len := binary.BigEndian.Uint32(rewrite_buf[TLS_RECORD_HDR_LEN:])
// But it's only 3 bytes so mask out the first one
b1 := orig_len & 0xff000000
orig_len &= 0x00ffffff
orig_len -= uint32(extlen)
orig_len |= b1
binary.BigEndian.PutUint32(rewrite_buf[TLS_RECORD_HDR_LEN:], orig_len)
// Write session length 0 at the end
rewrite_buf[len(rewrite_buf)-1] = 0
rewrite_buf[len(rewrite_buf)-2] = 0
}
if rewrite {
fmt.Println("TLSGuard writing back modified handshake data to server")
fmt.Printf("ORIGINAL[%d]: %v\n", len(cr.data), hex.Dump(cr.data))
fmt.Printf("NEW[%d]: %v\n", len(rewrite_buf), hex.Dump(rewrite_buf))
other.Write(rewrite_buf)
} else {
other.Write(cr.data)
}
continue
}
return err
}
if cr.client {
other.Write(cr.data)
continue
}
if rtype == SSL3_RT_CHANGE_CIPHER_SPEC || rtype == SSL3_RT_APPLICATION_DATA ||
rtype == SSL3_RT_ALERT {
// fmt.Println("OTHER DATA; PASSING THRU")
passthru = true
} else if rtype == SSL3_RT_HANDSHAKE {
passthru = false
} else {
return errors.New(fmt.Sprintf("Expected TLS server handshake byte was not received [%#x vs 0x16]", rtype))
}
if !cr.client && server_expected == SSL3_MT_SERVER_HELLO {
server_expected = SSL3_MT_CERTIFICATE
}
if passthru {
// fmt.Println("passthru writing buf again and continuing:")
conn.Write(chunk)
continue
}
if !cr.client && s == SSL3_MT_HELLO_REQUEST {
fmt.Println("Server sent hello request")
continue
}
serverMsg := chunk[5:]
s := serverMsg[0]
log.Debugf("TLSGUARD: s = %#x\n", s)
if s == SSL3_MT_CERTIFICATE {
// Message len, 3 bytes
serverMessageLen := serverMsg[1:4]
serverMessageLenInt := int(int(serverMessageLen[0])<<16 | int(serverMessageLen[1])<<8 | int(serverMessageLen[2]))
// fmt.Printf("chunk len = %v, serverMsgLen = %v, slint = %v\n", len(chunk), len(serverMsg), serverMessageLenInt)
if len(serverMsg) < serverMessageLenInt {
return errors.New(fmt.Sprintf("len(serverMsg) %v < serverMessageLenInt %v!\n", len(serverMsg), serverMessageLenInt))
}
serverHelloBody := serverMsg[4 : 4+serverMessageLenInt]
certChainLen := int(int(serverHelloBody[0])<<16 | int(serverHelloBody[1])<<8 | int(serverHelloBody[2]))
remaining := certChainLen
pos := serverHelloBody[3:certChainLen]
if s > SSL3_MT_CERTIFICATE_STATUS {
fmt.Println("WTF: ", cr.data)
}
// Message len, 3 bytes
handshakeMessageLen := handshakeMsg[1:4]
handshakeMessageLenInt := int(int(handshakeMessageLen[0])<<16 | int(handshakeMessageLen[1])<<8 | int(handshakeMessageLen[2]))
if s == SSL3_MT_CERTIFICATE {
fmt.Println("HMM")
// fmt.Printf("chunk len = %v, handshakeMsgLen = %v, slint = %v\n", len(chunk), len(handshakeMsg), handshakeMessageLenInt)
if len(handshakeMsg) < handshakeMessageLenInt {
return errors.New(fmt.Sprintf("len(handshakeMsg) %v < handshakeMessageLenInt %v!\n", len(handshakeMsg), handshakeMessageLenInt))
}
serverHelloBody := handshakeMsg[4 : 4+handshakeMessageLenInt]
certChainLen := int(int(serverHelloBody[0])<<16 | int(serverHelloBody[1])<<8 | int(serverHelloBody[2]))
remaining := certChainLen
pos := serverHelloBody[3:certChainLen]
// var certChain []*x509.Certificate
var verifyOptions x509.VerifyOptions
//fqdn = "www.reddit.com"
if fqdn != "" {
verifyOptions.DNSName = fqdn
}
pool := x509.NewCertPool()
var c *x509.Certificate
for remaining > 0 {
certLen := int(int(pos[0])<<16 | int(pos[1])<<8 | int(pos[2]))
// fmt.Printf("Certs chain len %d, cert 1 len %d:\n", certChainLen, certLen)
cert := pos[3 : 3+certLen]
certs, err := x509.ParseCertificates(cert)
if remaining == certChainLen {
c = certs[0]
} else {
pool.AddCert(certs[0])
}
// certChain = append(certChain, certs[0])
if err != nil {
return err
}
remaining = remaining - certLen - 3
if remaining > 0 {
pos = pos[3+certLen:]
}
}
verifyOptions.Intermediates = pool
fmt.Println("ATTEMPTING TO VERIFY: ", fqdn)
_, err := c.Verify(verifyOptions)
fmt.Println("ATTEMPTING TO VERIFY RESULT: ", err)
if err != nil {
return err
} else {
x509Valid = true
}
}
// var certChain []*x509.Certificate
var verifyOptions x509.VerifyOptions
other.Write(cr.data)
if fqdn != "" {
verifyOptions.DNSName = fqdn
}
if x509Valid || (s == SSL3_MT_SERVER_DONE) || (s == SSL3_MT_CERTIFICATE_REQUEST) {
fmt.Println("BREAKING OUT OF LOOP 1")
dChan <- true
fmt.Println("BREAKING OUT OF LOOP 2")
break select_loop
}
pool := x509.NewCertPool()
var c *x509.Certificate
// fmt.Printf("Sending chunk of type %d to client.\n", s)
} else if cr.err != nil {
ndone++
for remaining > 0 {
certLen := int(int(pos[0])<<16 | int(pos[1])<<8 | int(pos[2]))
// fmt.Printf("Certs chain len %d, cert 1 len %d:\n", certChainLen, certLen)
cert := pos[3 : 3+certLen]
certs, err := x509.ParseCertificates(cert)
if remaining == certChainLen {
c = certs[0]
if cr.client {
fmt.Println("Client read error: ", cr.err)
} else {
pool.AddCert(certs[0])
fmt.Println("Server read error: ", cr.err)
}
// certChain = append(certChain, certs[0])
if err != nil {
return err
}
remaining = remaining - certLen - 3
if remaining > 0 {
pos = pos[3+certLen:]
}
}
verifyOptions.Intermediates = pool
// fmt.Println("ATTEMPTING TO VERIFY: ", fqdn)
_, err = c.Verify(verifyOptions)
// fmt.Println("ATTEMPTING TO VERIFY RESULT: ", err)
if err != nil {
return err
} else {
valid = true
return cr.err
}
// lse if s == 0x0d { fmt.Printf("found a client cert request, sending buf to client\n") }
} else if s == SSL3_MT_SERVER_DONE {
conn.Write(chunk)
break
} else if s == SSL3_MT_CERTIFICATE_REQUEST {
break
}
// fmt.Printf("Sending chunk of type %d to client.\n", s)
}
conn.Write(chunk)
fmt.Println("WAITING; ndone = ", ndone)
for ndone < 2 {
fmt.Println("WAITING; ndone = ", ndone)
select {
case cr := <-crChan:
fmt.Printf("CHAN DATA: %v, %v, %v\n", cr.client, cr.err, len(cr.data))
if cr.err != nil || cr.data == nil {
ndone++
} else if cr.client {
conn2.Write(cr.data)
} else if !cr.client {
conn.Write(cr.data)
}
}
}
if !valid {
fmt.Println("______ ndone = 2\n")
// dChan <- true
close(dChan)
if !x509Valid {
return errors.New("Unknown error: TLS connection could not be validated")
}
return nil
}
func readNBytes(conn net.Conn, numBytes int) ([]byte, error) {
res := make([]byte, 0)
temp := make([]byte, 1)
for i := 0; i < numBytes; i++ {
_, err := io.ReadAtLeast(conn, temp, 1)
if err != nil {
return res, err
}
res = append(res, temp[0])
}
return res, nil
}

@ -1,8 +1,8 @@
package procsnitch
import (
"encoding/hex"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"github.com/op/go-logging"
@ -169,7 +169,7 @@ func ParseIP(ip string) (net.IP, error) {
}
if isLittleEndian > 0 {
for i := 0; i < len(dst) / 4; i++ {
for i := 0; i < len(dst)/4; i++ {
start, end := i*4, (i+1)*4
word := dst[start:end]
lval := binary.LittleEndian.Uint32(word)
@ -177,13 +177,13 @@ func ParseIP(ip string) (net.IP, error) {
}
}
/* if len(dst) == 16 {
dst2 := []byte{dst[3], dst[2], dst[1], dst[0], dst[7], dst[6], dst[5], dst[4], dst[11], dst[10], dst[9], dst[8], dst[15], dst[14], dst[13], dst[12]}
return net.IP(dst2), nil
}
for i, j := 0, len(dst)-1; i < j; i, j = i+1, j-1 {
dst[i], dst[j] = dst[j], dst[i]
} */
/* if len(dst) == 16 {
dst2 := []byte{dst[3], dst[2], dst[1], dst[0], dst[7], dst[6], dst[5], dst[4], dst[11], dst[10], dst[9], dst[8], dst[15], dst[14], dst[13], dst[12]}
return net.IP(dst2), nil
}
for i, j := 0, len(dst)-1; i < j; i, j = i+1, j-1 {
dst[i], dst[j] = dst[j], dst[i]
} */
return net.IP(dst), nil
}
@ -312,6 +312,7 @@ func stripLabel(s string) string {
// stolen from github.com/virtao/GoEndian
const INT_SIZE int = int(unsafe.Sizeof(0))
func setEndian() {
var i int = 0x1
bs := (*[INT_SIZE]byte)(unsafe.Pointer(&i))

@ -23,7 +23,9 @@ type Info struct {
FirstArg string
ParentCmdLine string
ParentExePath string
Sandbox string
Sandbox string
Inode uint64
FD int
}
type pidCache struct {
@ -51,10 +53,12 @@ func loadCache() map[uint64]*Info {
for _, n := range readdir("/proc") {
pid := toPid(n)
if pid != 0 {
pinfo := &Info{Pid: pid}
for _, inode := range inodesFromPid(pid) {
inodes, fds := inodesFromPid(pid)
for iind, inode := range inodes {
pinfo := &Info{Inode: inode, Pid: pid, FD: fds[iind]}
cmap[inode] = pinfo
}
}
}
return cmap
@ -76,8 +80,9 @@ func toPid(name string) int {
return (int)(pid)
}
func inodesFromPid(pid int) []uint64 {
func inodesFromPid(pid int) ([]uint64, []int) {
var inodes []uint64
var fds []int
fdpath := fmt.Sprintf("/proc/%d/fd", pid)
for _, n := range readdir(fdpath) {
if link, err := os.Readlink(path.Join(fdpath, n)); err != nil {
@ -85,12 +90,19 @@ func inodesFromPid(pid int) []uint64 {
log.Warningf("Error reading link %s: %v", n, err)
}
} else {
fd, err := strconv.Atoi(n)
if err != nil {
log.Warningf("Error retrieving fd associated with pid %v: %v", pid, err)
fd = -1
}
if inode := extractSocket(link); inode > 0 {
inodes = append(inodes, inode)
fds = append(fds, fd)
}
}
}
return inodes
return inodes, fds
}
func extractSocket(name string) uint64 {

@ -111,60 +111,65 @@ func findUDPSocketAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort ui
if custdata == nil {
if strictness == MATCH_STRICT {
return findSocket(proto, func(ss socketStatus) bool {
fmt.Println("Match strict")
// fmt.Println("Match strict")
return ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr)
//return ss.local.port == srcPort && ss.local.ip.Equal(srcAddr)
})
} else if strictness == MATCH_LOOSE {
return findSocket(proto, func(ss socketStatus) bool {
/*
fmt.Println("Match loose")
fmt.Printf("sock dst = %v pkt dst = %v\n", ss.remote.ip, dstAddr)
fmt.Printf("sock port = %d pkt port = %d\n", ss.local.port, srcPort)
fmt.Printf("local ip: %v\n source ip: %v\n", ss.local.ip, srcAddr)
/*
fmt.Println("Match loose")
fmt.Printf("sock dst = %v pkt dst = %v\n", ss.remote.ip, dstAddr)
fmt.Printf("sock port = %d pkt port = %d\n", ss.local.port, srcPort)
fmt.Printf("local ip: %v\n source ip: %v\n", ss.local.ip, srcAddr)
*/
if (ss.local.port == srcPort && (ss.local.ip.Equal(net.IPv4(0,0,0,0)) && ss.remote.ip.Equal(net.IPv4(0,0,0,0)))) {
fmt.Printf("Matching for UDP socket bound to *:%d\n",ss.local.port)
if (ss.local.port == srcPort) && addrMatchesAny(ss.local.ip) && addrMatchesAny(ss.remote.ip) {
fmt.Printf("Loose match for UDP socket bound to *:%d\n", ss.local.port)
return true
} else if (ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr)) {
} else if ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) {
return true
}
// Finally, loop through all interfaces if src port matches
// Finally, loop through all interfaces if src port matches
if ss.local.port == srcPort {
ifs, err := net.Interfaces()
if err != nil {
log.Warningf("Error on net.Interfaces(): %v", err)
log.Warning("Error retrieving list of network interfaces for UDP socket lookup:", err)
return false
}
for _, i := range ifs {
addrs, err := i.Addrs()
if err != nil {
log.Warningf("Error on Interface.Addrs(): %v", err)
log.Warning("Error retrieving network interface for UDP socket lookup:", err)
return false
}
for _, addr := range addrs {
var ifip net.IP
switch x := addr.(type) {
case *net.IPNet:
ifip = x.IP
case *net.IPAddr:
ifip = x.IP
case *net.IPNet:
ifip = x.IP
case *net.IPAddr:
ifip = x.IP
}
if ss.local.ip.Equal(ifip) {
fmt.Printf("Matched on UDP socket bound to %v:%d\n",ifip,srcPort)
fmt.Printf("Matched on UDP socket bound to %v:%d\n", ifip, srcPort)
return true
}
}
}
}
return false
//return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(net.IPv4(0,0,0,0))) && ss.local.port == srcPort && (ss.local.ip.Equal(srcAddr) || ss.local.ip.Equal(net.IPv4(0,0,0,0)))
/*
return (ss.remote.ip.Equal(dstAddr) || addrMatchesAny(ss.remote.ip)) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) ||
(ss.local.ip.Equal(dstAddr) || addrMatchesAny(ss.local.ip)) && ss.remote.port == srcPort && ss.remote.ip.Equal(srcAddr) */
return (ss.remote.ip.Equal(dstAddr) || addrMatchesAny(ss.remote.ip)) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) ||
(ss.local.ip.Equal(dstAddr) || addrMatchesAny(ss.local.ip)) && ss.remote.port == srcPort && ss.remote.ip.Equal(srcAddr) */
})
}
return findSocket(proto, func(ss socketStatus) bool {
@ -367,11 +372,11 @@ func getSocketLines(proto string) []string {
}
func addrMatchesAny(addr net.IP) bool {
wildcard := net.IP{0,0,0,0}
wildcard := net.IP{0, 0, 0, 0}
if addr.To4() == nil {
wildcard = net.IP{0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}
wildcard = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
}
return wildcard.Equal(addr)
return wildcard.Equal(addr)
}

Loading…
Cancel
Save