From 2d8afe1d6082dcf3090de2fc5845ef51cca175cd Mon Sep 17 00:00:00 2001 From: brl Date: Thu, 3 Dec 2015 15:16:50 -0500 Subject: [PATCH] Initial commit --- .gitignore | 9 + README.md | 4 + dbus.go | 111 +++ dns.go | 68 ++ dnsmsg.go | 931 ++++++++++++++++++ gnome-shell/firewall@subgraph.com/dialog.js | 451 +++++++++ .../firewall@subgraph.com/extension.js | 97 ++ .../firewall@subgraph.com/metadata.json | 1 + .../firewall@subgraph.com/stylesheet.css | 71 ++ icons.go | 86 ++ iptables.go | 48 + main.go | 93 ++ nfqueue/LICENSE | 201 ++++ nfqueue/README.md | 42 + nfqueue/multiqueue.go | 41 + nfqueue/nfqueue.c | 85 ++ nfqueue/nfqueue.go | 180 ++++ nfqueue/nfqueue.h | 27 + nfqueue/packet.go | 145 +++ policy.go | 186 ++++ proc.go | 228 +++++ prompt.go | 170 ++++ rules.go | 269 +++++ 23 files changed, 3544 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 dbus.go create mode 100644 dns.go create mode 100644 dnsmsg.go create mode 100644 gnome-shell/firewall@subgraph.com/dialog.js create mode 100644 gnome-shell/firewall@subgraph.com/extension.js create mode 100644 gnome-shell/firewall@subgraph.com/metadata.json create mode 100644 gnome-shell/firewall@subgraph.com/stylesheet.css create mode 100644 icons.go create mode 100644 iptables.go create mode 100644 main.go create mode 100644 nfqueue/LICENSE create mode 100644 nfqueue/README.md create mode 100644 nfqueue/multiqueue.go create mode 100644 nfqueue/nfqueue.c create mode 100644 nfqueue/nfqueue.go create mode 100644 nfqueue/nfqueue.h create mode 100644 nfqueue/packet.go create mode 100644 policy.go create mode 100644 proc.go create mode 100644 prompt.go create mode 100644 rules.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3e03c47 --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +*.iml +.idea/*.iml +.idea/*.iml +.idea/*.iml +.idea/*.iml +.idea/*.iml +.idea/*.iml +.idea/*.iml +.idea/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..20462c8 --- /dev/null +++ b/README.md @@ -0,0 +1,4 @@ +# Subgraph Firewall + +A desktop application firewall for Subgraph OS. + diff --git a/dbus.go b/dbus.go new file mode 100644 index 0000000..2a5ae77 --- /dev/null +++ b/dbus.go @@ -0,0 +1,111 @@ +package main + +import ( + "errors" + "fmt" + "os" + "runtime" + "strings" + "syscall" + + "github.com/godbus/dbus" + "github.com/godbus/dbus/introspect" +) + +const introspectXml = ` + + + + + + ` + + introspect.IntrospectDataString + + `` + +const busName = "com.subgraph.Firewall" +const objectPath = "/com/subgraph/Firewall" +const interfaceName = "com.subgraph.Firewall" + +type dbusServer struct { + conn *dbus.Conn + prompter *prompter +} + +func dbusConnect() (*dbus.Conn, error) { + // https://github.com/golang/go/issues/1435 + runtime.LockOSThread() + syscall.Setresuid(-1, 1000, 0) + + conn, err := dbus.SessionBus() + if err != nil { + return nil, err + } + syscall.Setresuid(0, 0, -1) + runtime.UnlockOSThread() + + if os.Geteuid() != 0 || os.Getuid() != 0 { + log.Warning("Not root as expected") + os.Exit(0) + } + return conn, nil +} + +func newDbusServer(conn *dbus.Conn) (*dbusServer, error) { + reply, err := conn.RequestName(busName, dbus.NameFlagDoNotQueue) + if err != nil { + return nil, err + } + if reply != dbus.RequestNameReplyPrimaryOwner { + return nil, errors.New("Bus name is already owned") + } + ds := &dbusServer{} + + if err := conn.Export(ds, objectPath, interfaceName); err != nil { + return nil, err + } + + ps := strings.Split(objectPath, "/") + path := "/" + for _, p := range ps { + if len(path) > 1 { + path += "/" + } + path += p + + if err := conn.Export(ds, dbus.ObjectPath(path), "org.freedesktop.DBus.Introspectable"); err != nil { + return nil, err + } + } + ds.conn = conn + ds.prompter = newPrompter(conn) + return ds, nil +} + +func (ds *dbusServer) Introspect(msg dbus.Message) (string, *dbus.Error) { + path := string(msg.Headers[dbus.FieldPath].Value().(dbus.ObjectPath)) + if path == objectPath { + return introspectXml, nil + } + parts := strings.Split(objectPath, "/") + current := "/" + for i := 0; i < len(parts)-1; i++ { + if len(current) > 1 { + current += "/" + } + current += parts[i] + if path == current { + next := parts[i+1] + return fmt.Sprintf("", next), nil + } + } + return "", nil +} + +func (ds *dbusServer) SetEnabled(flag bool) *dbus.Error { + return nil +} + +func (ds *dbusServer) prompt(p *Policy) { + log.Info("prompting...") + ds.prompter.prompt(p) +} diff --git a/dns.go b/dns.go new file mode 100644 index 0000000..4bb2245 --- /dev/null +++ b/dns.go @@ -0,0 +1,68 @@ +package main + +import ( + "net" + "strings" + "sync" + + "github.com/subgraph/fw-daemon/nfqueue" +) + +type dnsCache struct { + ipMap map[string]string + lock sync.Mutex + done chan struct{} +} + +func NewDnsCache() *dnsCache { + return &dnsCache{ + ipMap: make(map[string]string), + done: make(chan struct{}), + } +} + +func (dc *dnsCache) processDNS(pkt *nfqueue.Packet) { + dns := &dnsMsg{} + if !dns.Unpack(pkt.Payload) { + log.Warning("Failed to Unpack DNS message") + return + } + if !dns.response { + return + } + if len(dns.question) != 1 { + log.Warning("Length of DNS Question section is not 1 as expected: %d", len(dns.question)) + return + } + q := dns.question[0] + if q.Qtype == dnsTypeA { + dc.processRecordA(q.Name, dns.answer) + return + } + log.Info("Unhandled DNS message: %v", dns) + +} + +func (dc *dnsCache) processRecordA(name string, answers []dnsRR) { + dc.lock.Lock() + defer dc.lock.Unlock() + for _, rr := range answers { + switch rec := rr.(type) { + case *dnsRR_A: + ip := net.IPv4(byte(rec.A>>24), byte(rec.A>>16), byte(rec.A>>8), byte(rec.A)).String() + if strings.HasSuffix(name, ".") { + name = name[:len(name)-1] + } + dc.ipMap[ip] = name + log.Info("Adding %s: %s", name, ip) + default: + log.Warning("Unexpected RR type in answer section of A response: %v", rec) + } + } +} + +func (dc *dnsCache) Lookup(ip net.IP) string { + dc.lock.Lock() + defer dc.lock.Unlock() + return dc.ipMap[ip.String()] +} diff --git a/dnsmsg.go b/dnsmsg.go new file mode 100644 index 0000000..d0e5e97 --- /dev/null +++ b/dnsmsg.go @@ -0,0 +1,931 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// DNS packet assembly. See RFC 1035. +// +// This is intended to support name resolution during Dial. +// It doesn't have to be blazing fast. +// +// Each message structure has a Walk method that is used by +// a generic pack/unpack routine. Thus, if in the future we need +// to define new message structs, no new pack/unpack/printing code +// needs to be written. +// +// The first half of this file defines the DNS message formats. +// The second half implements the conversion to and from wire format. +// A few of the structure elements have string tags to aid the +// generic pack/unpack routines. +// +// TODO(rsc): There are enough names defined in this file that they're all +// prefixed with dns. Perhaps put this in its own package later. + +package main + +import "net" + +// Packet formats + +// Wire constants. +const ( + // valid dnsRR_Header.Rrtype and dnsQuestion.qtype + dnsTypeA = 1 + dnsTypeNS = 2 + dnsTypeMD = 3 + dnsTypeMF = 4 + dnsTypeCNAME = 5 + dnsTypeSOA = 6 + dnsTypeMB = 7 + dnsTypeMG = 8 + dnsTypeMR = 9 + dnsTypeNULL = 10 + dnsTypeWKS = 11 + dnsTypePTR = 12 + dnsTypeHINFO = 13 + dnsTypeMINFO = 14 + dnsTypeMX = 15 + dnsTypeTXT = 16 + dnsTypeAAAA = 28 + dnsTypeSRV = 33 + + // valid dnsQuestion.qtype only + dnsTypeAXFR = 252 + dnsTypeMAILB = 253 + dnsTypeMAILA = 254 + dnsTypeALL = 255 + + // valid dnsQuestion.qclass + dnsClassINET = 1 + dnsClassCSNET = 2 + dnsClassCHAOS = 3 + dnsClassHESIOD = 4 + dnsClassANY = 255 + + // dnsMsg.rcode + dnsRcodeSuccess = 0 + dnsRcodeFormatError = 1 + dnsRcodeServerFailure = 2 + dnsRcodeNameError = 3 + dnsRcodeNotImplemented = 4 + dnsRcodeRefused = 5 +) + +// A dnsStruct describes how to iterate over its fields to emulate +// reflective marshalling. +type dnsStruct interface { + // Walk iterates over fields of a structure and calls f + // with a reference to that field, the name of the field + // and a tag ("", "domain", "ipv4", "ipv6") specifying + // particular encodings. Possible concrete types + // for v are *uint16, *uint32, *string, or []byte, and + // *int, *bool in the case of dnsMsgHdr. + // Whenever f returns false, Walk must stop and return + // false, and otherwise return true. + Walk(f func(v interface{}, name, tag string) (ok bool)) (ok bool) +} + +// The wire format for the DNS packet header. +type dnsHeader struct { + Id uint16 + Bits uint16 + Qdcount, Ancount, Nscount, Arcount uint16 +} + +func (h *dnsHeader) Walk(f func(v interface{}, name, tag string) bool) bool { + return f(&h.Id, "Id", "") && + f(&h.Bits, "Bits", "") && + f(&h.Qdcount, "Qdcount", "") && + f(&h.Ancount, "Ancount", "") && + f(&h.Nscount, "Nscount", "") && + f(&h.Arcount, "Arcount", "") +} + +const ( + // dnsHeader.Bits + _QR = 1 << 15 // query/response (response=1) + _AA = 1 << 10 // authoritative + _TC = 1 << 9 // truncated + _RD = 1 << 8 // recursion desired + _RA = 1 << 7 // recursion available +) + +// DNS queries. +type dnsQuestion struct { + Name string `net:"domain-name"` // `net:"domain-name"` specifies encoding; see packers below + Qtype uint16 + Qclass uint16 +} + +func (q *dnsQuestion) Walk(f func(v interface{}, name, tag string) bool) bool { + return f(&q.Name, "Name", "domain") && + f(&q.Qtype, "Qtype", "") && + f(&q.Qclass, "Qclass", "") +} + +// DNS responses (resource records). +// There are many types of messages, +// but they all share the same header. +type dnsRR_Header struct { + Name string `net:"domain-name"` + Rrtype uint16 + Class uint16 + Ttl uint32 + Rdlength uint16 // length of data after header +} + +func (h *dnsRR_Header) Header() *dnsRR_Header { + return h +} + +func (h *dnsRR_Header) Walk(f func(v interface{}, name, tag string) bool) bool { + return f(&h.Name, "Name", "domain") && + f(&h.Rrtype, "Rrtype", "") && + f(&h.Class, "Class", "") && + f(&h.Ttl, "Ttl", "") && + f(&h.Rdlength, "Rdlength", "") +} + +type dnsRR interface { + dnsStruct + Header() *dnsRR_Header +} + +// Specific DNS RR formats for each query type. + +type dnsRR_CNAME struct { + Hdr dnsRR_Header + Cname string `net:"domain-name"` +} + +func (rr *dnsRR_CNAME) Header() *dnsRR_Header { + return &rr.Hdr +} + +func (rr *dnsRR_CNAME) Walk(f func(v interface{}, name, tag string) bool) bool { + return rr.Hdr.Walk(f) && f(&rr.Cname, "Cname", "domain") +} + +type dnsRR_HINFO struct { + Hdr dnsRR_Header + Cpu string + Os string +} + +func (rr *dnsRR_HINFO) Header() *dnsRR_Header { + return &rr.Hdr +} + +func (rr *dnsRR_HINFO) Walk(f func(v interface{}, name, tag string) bool) bool { + return rr.Hdr.Walk(f) && f(&rr.Cpu, "Cpu", "") && f(&rr.Os, "Os", "") +} + +type dnsRR_MB struct { + Hdr dnsRR_Header + Mb string `net:"domain-name"` +} + +func (rr *dnsRR_MB) Header() *dnsRR_Header { + return &rr.Hdr +} + +func (rr *dnsRR_MB) Walk(f func(v interface{}, name, tag string) bool) bool { + return rr.Hdr.Walk(f) && f(&rr.Mb, "Mb", "domain") +} + +type dnsRR_MG struct { + Hdr dnsRR_Header + Mg string `net:"domain-name"` +} + +func (rr *dnsRR_MG) Header() *dnsRR_Header { + return &rr.Hdr +} + +func (rr *dnsRR_MG) Walk(f func(v interface{}, name, tag string) bool) bool { + return rr.Hdr.Walk(f) && f(&rr.Mg, "Mg", "domain") +} + +type dnsRR_MINFO struct { + Hdr dnsRR_Header + Rmail string `net:"domain-name"` + Email string `net:"domain-name"` +} + +func (rr *dnsRR_MINFO) Header() *dnsRR_Header { + return &rr.Hdr +} + +func (rr *dnsRR_MINFO) Walk(f func(v interface{}, name, tag string) bool) bool { + return rr.Hdr.Walk(f) && f(&rr.Rmail, "Rmail", "domain") && f(&rr.Email, "Email", "domain") +} + +type dnsRR_MR struct { + Hdr dnsRR_Header + Mr string `net:"domain-name"` +} + +func (rr *dnsRR_MR) Header() *dnsRR_Header { + return &rr.Hdr +} + +func (rr *dnsRR_MR) Walk(f func(v interface{}, name, tag string) bool) bool { + return rr.Hdr.Walk(f) && f(&rr.Mr, "Mr", "domain") +} + +type dnsRR_MX struct { + Hdr dnsRR_Header + Pref uint16 + Mx string `net:"domain-name"` +} + +func (rr *dnsRR_MX) Header() *dnsRR_Header { + return &rr.Hdr +} + +func (rr *dnsRR_MX) Walk(f func(v interface{}, name, tag string) bool) bool { + return rr.Hdr.Walk(f) && f(&rr.Pref, "Pref", "") && f(&rr.Mx, "Mx", "domain") +} + +type dnsRR_NS struct { + Hdr dnsRR_Header + Ns string `net:"domain-name"` +} + +func (rr *dnsRR_NS) Header() *dnsRR_Header { + return &rr.Hdr +} + +func (rr *dnsRR_NS) Walk(f func(v interface{}, name, tag string) bool) bool { + return rr.Hdr.Walk(f) && f(&rr.Ns, "Ns", "domain") +} + +type dnsRR_PTR struct { + Hdr dnsRR_Header + Ptr string `net:"domain-name"` +} + +func (rr *dnsRR_PTR) Header() *dnsRR_Header { + return &rr.Hdr +} + +func (rr *dnsRR_PTR) Walk(f func(v interface{}, name, tag string) bool) bool { + return rr.Hdr.Walk(f) && f(&rr.Ptr, "Ptr", "domain") +} + +type dnsRR_SOA struct { + Hdr dnsRR_Header + Ns string `net:"domain-name"` + Mbox string `net:"domain-name"` + Serial uint32 + Refresh uint32 + Retry uint32 + Expire uint32 + Minttl uint32 +} + +func (rr *dnsRR_SOA) Header() *dnsRR_Header { + return &rr.Hdr +} + +func (rr *dnsRR_SOA) Walk(f func(v interface{}, name, tag string) bool) bool { + return rr.Hdr.Walk(f) && + f(&rr.Ns, "Ns", "domain") && + f(&rr.Mbox, "Mbox", "domain") && + f(&rr.Serial, "Serial", "") && + f(&rr.Refresh, "Refresh", "") && + f(&rr.Retry, "Retry", "") && + f(&rr.Expire, "Expire", "") && + f(&rr.Minttl, "Minttl", "") +} + +type dnsRR_TXT struct { + Hdr dnsRR_Header + Txt string // not domain name +} + +func (rr *dnsRR_TXT) Header() *dnsRR_Header { + return &rr.Hdr +} + +func (rr *dnsRR_TXT) Walk(f func(v interface{}, name, tag string) bool) bool { + if !rr.Hdr.Walk(f) { + return false + } + var n uint16 = 0 + for n < rr.Hdr.Rdlength { + var txt string + if !f(&txt, "Txt", "") { + return false + } + // more bytes than rr.Hdr.Rdlength said there woudld be + if rr.Hdr.Rdlength-n < uint16(len(txt))+1 { + return false + } + n += uint16(len(txt)) + 1 + rr.Txt += txt + } + return true +} + +type dnsRR_SRV struct { + Hdr dnsRR_Header + Priority uint16 + Weight uint16 + Port uint16 + Target string `net:"domain-name"` +} + +func (rr *dnsRR_SRV) Header() *dnsRR_Header { + return &rr.Hdr +} + +func (rr *dnsRR_SRV) Walk(f func(v interface{}, name, tag string) bool) bool { + return rr.Hdr.Walk(f) && + f(&rr.Priority, "Priority", "") && + f(&rr.Weight, "Weight", "") && + f(&rr.Port, "Port", "") && + f(&rr.Target, "Target", "domain") +} + +type dnsRR_A struct { + Hdr dnsRR_Header + A uint32 `net:"ipv4"` +} + +func (rr *dnsRR_A) Header() *dnsRR_Header { + return &rr.Hdr +} + +func (rr *dnsRR_A) Walk(f func(v interface{}, name, tag string) bool) bool { + return rr.Hdr.Walk(f) && f(&rr.A, "A", "ipv4") +} + +type dnsRR_AAAA struct { + Hdr dnsRR_Header + AAAA [16]byte `net:"ipv6"` +} + +func (rr *dnsRR_AAAA) Header() *dnsRR_Header { + return &rr.Hdr +} + +func (rr *dnsRR_AAAA) Walk(f func(v interface{}, name, tag string) bool) bool { + return rr.Hdr.Walk(f) && f(rr.AAAA[:], "AAAA", "ipv6") +} + +// Packing and unpacking. +// +// All the packers and unpackers take a (msg []byte, off int) +// and return (off1 int, ok bool). If they return ok==false, they +// also return off1==len(msg), so that the next unpacker will +// also fail. This lets us avoid checks of ok until the end of a +// packing sequence. + +// Map of constructors for each RR wire type. +var rr_mk = map[int]func() dnsRR{ + dnsTypeCNAME: func() dnsRR { return new(dnsRR_CNAME) }, + dnsTypeHINFO: func() dnsRR { return new(dnsRR_HINFO) }, + dnsTypeMB: func() dnsRR { return new(dnsRR_MB) }, + dnsTypeMG: func() dnsRR { return new(dnsRR_MG) }, + dnsTypeMINFO: func() dnsRR { return new(dnsRR_MINFO) }, + dnsTypeMR: func() dnsRR { return new(dnsRR_MR) }, + dnsTypeMX: func() dnsRR { return new(dnsRR_MX) }, + dnsTypeNS: func() dnsRR { return new(dnsRR_NS) }, + dnsTypePTR: func() dnsRR { return new(dnsRR_PTR) }, + dnsTypeSOA: func() dnsRR { return new(dnsRR_SOA) }, + dnsTypeTXT: func() dnsRR { return new(dnsRR_TXT) }, + dnsTypeSRV: func() dnsRR { return new(dnsRR_SRV) }, + dnsTypeA: func() dnsRR { return new(dnsRR_A) }, + dnsTypeAAAA: func() dnsRR { return new(dnsRR_AAAA) }, +} + +// Pack a domain name s into msg[off:]. +// Domain names are a sequence of counted strings +// split at the dots. They end with a zero-length string. +func packDomainName(s string, msg []byte, off int) (off1 int, ok bool) { + // Add trailing dot to canonicalize name. + if n := len(s); n == 0 || s[n-1] != '.' { + s += "." + } + + // Each dot ends a segment of the name. + // We trade each dot byte for a length byte. + // There is also a trailing zero. + // Check that we have all the space we need. + tot := len(s) + 1 + if off+tot > len(msg) { + return len(msg), false + } + + // Emit sequence of counted strings, chopping at dots. + begin := 0 + for i := 0; i < len(s); i++ { + if s[i] == '.' { + if i-begin >= 1<<6 { // top two bits of length must be clear + return len(msg), false + } + msg[off] = byte(i - begin) + off++ + for j := begin; j < i; j++ { + msg[off] = s[j] + off++ + } + begin = i + 1 + } + } + msg[off] = 0 + off++ + return off, true +} + +// Unpack a domain name. +// In addition to the simple sequences of counted strings above, +// domain names are allowed to refer to strings elsewhere in the +// packet, to avoid repeating common suffixes when returning +// many entries in a single domain. The pointers are marked +// by a length byte with the top two bits set. Ignoring those +// two bits, that byte and the next give a 14 bit offset from msg[0] +// where we should pick up the trail. +// Note that if we jump elsewhere in the packet, +// we return off1 == the offset after the first pointer we found, +// which is where the next record will start. +// In theory, the pointers are only allowed to jump backward. +// We let them jump anywhere and stop jumping after a while. +func unpackDomainName(msg []byte, off int) (s string, off1 int, ok bool) { + s = "" + ptr := 0 // number of pointers followed +Loop: + for { + if off >= len(msg) { + return "", len(msg), false + } + c := int(msg[off]) + off++ + switch c & 0xC0 { + case 0x00: + if c == 0x00 { + // end of name + break Loop + } + // literal string + if off+c > len(msg) { + return "", len(msg), false + } + s += string(msg[off:off+c]) + "." + off += c + case 0xC0: + // pointer to somewhere else in msg. + // remember location after first ptr, + // since that's how many bytes we consumed. + // also, don't follow too many pointers -- + // maybe there's a loop. + if off >= len(msg) { + return "", len(msg), false + } + c1 := msg[off] + off++ + if ptr == 0 { + off1 = off + } + if ptr++; ptr > 10 { + return "", len(msg), false + } + off = (c^0xC0)<<8 | int(c1) + default: + // 0x80 and 0x40 are reserved + return "", len(msg), false + } + } + if ptr == 0 { + off1 = off + } + return s, off1, true +} + +// packStruct packs a structure into msg at specified offset off, and +// returns off1 such that msg[off:off1] is the encoded data. +func packStruct(any dnsStruct, msg []byte, off int) (off1 int, ok bool) { + ok = any.Walk(func(field interface{}, name, tag string) bool { + switch fv := field.(type) { + default: + println("net: dns: unknown packing type") + return false + case *uint16: + i := *fv + if off+2 > len(msg) { + return false + } + msg[off] = byte(i >> 8) + msg[off+1] = byte(i) + off += 2 + case *uint32: + i := *fv + msg[off] = byte(i >> 24) + msg[off+1] = byte(i >> 16) + msg[off+2] = byte(i >> 8) + msg[off+3] = byte(i) + off += 4 + case []byte: + n := len(fv) + if off+n > len(msg) { + return false + } + copy(msg[off:off+n], fv) + off += n + case *string: + s := *fv + switch tag { + default: + println("net: dns: unknown string tag", tag) + return false + case "domain": + off, ok = packDomainName(s, msg, off) + if !ok { + return false + } + case "": + // Counted string: 1 byte length. + if len(s) > 255 || off+1+len(s) > len(msg) { + return false + } + msg[off] = byte(len(s)) + off++ + off += copy(msg[off:], s) + } + } + return true + }) + if !ok { + return len(msg), false + } + return off, true +} + +// unpackStruct decodes msg[off:] into the given structure, and +// returns off1 such that msg[off:off1] is the encoded data. +func unpackStruct(any dnsStruct, msg []byte, off int) (off1 int, ok bool) { + ok = any.Walk(func(field interface{}, name, tag string) bool { + switch fv := field.(type) { + default: + println("net: dns: unknown packing type") + return false + case *uint16: + if off+2 > len(msg) { + return false + } + *fv = uint16(msg[off])<<8 | uint16(msg[off+1]) + off += 2 + case *uint32: + if off+4 > len(msg) { + return false + } + *fv = uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | + uint32(msg[off+2])<<8 | uint32(msg[off+3]) + off += 4 + case []byte: + n := len(fv) + if off+n > len(msg) { + return false + } + copy(fv, msg[off:off+n]) + off += n + case *string: + var s string + switch tag { + default: + println("net: dns: unknown string tag", tag) + return false + case "domain": + s, off, ok = unpackDomainName(msg, off) + if !ok { + return false + } + case "": + if off >= len(msg) || off+1+int(msg[off]) > len(msg) { + return false + } + n := int(msg[off]) + off++ + b := make([]byte, n) + for i := 0; i < n; i++ { + b[i] = msg[off+i] + } + off += n + s = string(b) + } + *fv = s + } + return true + }) + if !ok { + return len(msg), false + } + return off, true +} + +// Generic struct printer. Prints fields with tag "ipv4" or "ipv6" +// as IP addresses. +func printStruct(any dnsStruct) string { + s := "{" + i := 0 + any.Walk(func(val interface{}, name, tag string) bool { + i++ + if i > 1 { + s += ", " + } + s += name + "=" + switch tag { + case "ipv4": + i := *val.(*uint32) + s += net.IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String() + case "ipv6": + i := val.([]byte) + s += net.IP(i).String() + default: + var i int64 + switch v := val.(type) { + default: + // can't really happen. + s += "" + return true + case *string: + s += *v + return true + case []byte: + s += string(v) + return true + case *bool: + if *v { + s += "true" + } else { + s += "false" + } + return true + case *int: + i = int64(*v) + case *uint: + i = int64(*v) + case *uint8: + i = int64(*v) + case *uint16: + i = int64(*v) + case *uint32: + i = int64(*v) + case *uint64: + i = int64(*v) + case *uintptr: + i = int64(*v) + } + s += itoa(int(i)) + } + return true + }) + s += "}" + return s +} + +// Convert integer to decimal string. +func itoa(val int) string { + if val < 0 { + return "-" + uitoa(uint(-val)) + } + return uitoa(uint(val)) +} + +// Convert unsigned integer to decimal string. +func uitoa(val uint) string { + if val == 0 { // avoid string allocation + return "0" + } + var buf [20]byte // big enough for 64bit value base 10 + i := len(buf) - 1 + for val >= 10 { + q := val / 10 + buf[i] = byte('0' + val - q*10) + i-- + val = q + } + // val < 10 + buf[i] = byte('0' + val) + return string(buf[i:]) +} + +// Resource record packer. +func packRR(rr dnsRR, msg []byte, off int) (off2 int, ok bool) { + var off1 int + // pack twice, once to find end of header + // and again to find end of packet. + // a bit inefficient but this doesn't need to be fast. + // off1 is end of header + // off2 is end of rr + off1, ok = packStruct(rr.Header(), msg, off) + off2, ok = packStruct(rr, msg, off) + if !ok { + return len(msg), false + } + // pack a third time; redo header with correct data length + rr.Header().Rdlength = uint16(off2 - off1) + packStruct(rr.Header(), msg, off) + return off2, true +} + +// Resource record unpacker. +func unpackRR(msg []byte, off int) (rr dnsRR, off1 int, ok bool) { + // unpack just the header, to find the rr type and length + var h dnsRR_Header + off0 := off + if off, ok = unpackStruct(&h, msg, off); !ok { + return nil, len(msg), false + } + end := off + int(h.Rdlength) + + // make an rr of that type and re-unpack. + // again inefficient but doesn't need to be fast. + mk, known := rr_mk[int(h.Rrtype)] + if !known { + return &h, end, true + } + rr = mk() + off, ok = unpackStruct(rr, msg, off0) + if off != end { + return &h, end, true + } + return rr, off, ok +} + +// Usable representation of a DNS packet. + +// A manually-unpacked version of (id, bits). +// This is in its own struct for easy printing. +type dnsMsgHdr struct { + id uint16 + response bool + opcode int + authoritative bool + truncated bool + recursion_desired bool + recursion_available bool + rcode int +} + +func (h *dnsMsgHdr) Walk(f func(v interface{}, name, tag string) bool) bool { + return f(&h.id, "id", "") && + f(&h.response, "response", "") && + f(&h.opcode, "opcode", "") && + f(&h.authoritative, "authoritative", "") && + f(&h.truncated, "truncated", "") && + f(&h.recursion_desired, "recursion_desired", "") && + f(&h.recursion_available, "recursion_available", "") && + f(&h.rcode, "rcode", "") +} + +type dnsMsg struct { + dnsMsgHdr + question []dnsQuestion + answer []dnsRR + ns []dnsRR + extra []dnsRR +} + +func (dns *dnsMsg) Pack() (msg []byte, ok bool) { + var dh dnsHeader + + // Convert convenient dnsMsg into wire-like dnsHeader. + dh.Id = dns.id + dh.Bits = uint16(dns.opcode)<<11 | uint16(dns.rcode) + if dns.recursion_available { + dh.Bits |= _RA + } + if dns.recursion_desired { + dh.Bits |= _RD + } + if dns.truncated { + dh.Bits |= _TC + } + if dns.authoritative { + dh.Bits |= _AA + } + if dns.response { + dh.Bits |= _QR + } + + // Prepare variable sized arrays. + question := dns.question + answer := dns.answer + ns := dns.ns + extra := dns.extra + + dh.Qdcount = uint16(len(question)) + dh.Ancount = uint16(len(answer)) + dh.Nscount = uint16(len(ns)) + dh.Arcount = uint16(len(extra)) + + // Could work harder to calculate message size, + // but this is far more than we need and not + // big enough to hurt the allocator. + msg = make([]byte, 2000) + + // Pack it in: header and then the pieces. + off := 0 + off, ok = packStruct(&dh, msg, off) + for i := 0; i < len(question); i++ { + off, ok = packStruct(&question[i], msg, off) + } + for i := 0; i < len(answer); i++ { + off, ok = packRR(answer[i], msg, off) + } + for i := 0; i < len(ns); i++ { + off, ok = packRR(ns[i], msg, off) + } + for i := 0; i < len(extra); i++ { + off, ok = packRR(extra[i], msg, off) + } + if !ok { + return nil, false + } + return msg[0:off], true +} + +func (dns *dnsMsg) Unpack(msg []byte) bool { + // Header. + var dh dnsHeader + off := 0 + var ok bool + if off, ok = unpackStruct(&dh, msg, off); !ok { + return false + } + dns.id = dh.Id + dns.response = (dh.Bits & _QR) != 0 + dns.opcode = int(dh.Bits>>11) & 0xF + dns.authoritative = (dh.Bits & _AA) != 0 + dns.truncated = (dh.Bits & _TC) != 0 + dns.recursion_desired = (dh.Bits & _RD) != 0 + dns.recursion_available = (dh.Bits & _RA) != 0 + dns.rcode = int(dh.Bits & 0xF) + + // Arrays. + dns.question = make([]dnsQuestion, dh.Qdcount) + dns.answer = make([]dnsRR, 0, dh.Ancount) + dns.ns = make([]dnsRR, 0, dh.Nscount) + dns.extra = make([]dnsRR, 0, dh.Arcount) + + var rec dnsRR + + for i := 0; i < len(dns.question); i++ { + off, ok = unpackStruct(&dns.question[i], msg, off) + } + for i := 0; i < int(dh.Ancount); i++ { + rec, off, ok = unpackRR(msg, off) + if !ok { + return false + } + dns.answer = append(dns.answer, rec) + } + for i := 0; i < int(dh.Nscount); i++ { + rec, off, ok = unpackRR(msg, off) + if !ok { + return false + } + dns.ns = append(dns.ns, rec) + } + for i := 0; i < int(dh.Arcount); i++ { + rec, off, ok = unpackRR(msg, off) + if !ok { + return false + } + dns.extra = append(dns.extra, rec) + } + // if off != len(msg) { + // println("extra bytes in dns packet", off, "<", len(msg)); + // } + return true +} + +func (dns *dnsMsg) String() string { + s := "DNS: " + printStruct(&dns.dnsMsgHdr) + "\n" + if len(dns.question) > 0 { + s += "-- Questions\n" + for i := 0; i < len(dns.question); i++ { + s += printStruct(&dns.question[i]) + "\n" + } + } + if len(dns.answer) > 0 { + s += "-- Answers\n" + for i := 0; i < len(dns.answer); i++ { + s += printStruct(dns.answer[i]) + "\n" + } + } + if len(dns.ns) > 0 { + s += "-- Name servers\n" + for i := 0; i < len(dns.ns); i++ { + s += printStruct(dns.ns[i]) + "\n" + } + } + if len(dns.extra) > 0 { + s += "-- Extra\n" + for i := 0; i < len(dns.extra); i++ { + s += printStruct(dns.extra[i]) + "\n" + } + } + return s +} diff --git a/gnome-shell/firewall@subgraph.com/dialog.js b/gnome-shell/firewall@subgraph.com/dialog.js new file mode 100644 index 0000000..77857d5 --- /dev/null +++ b/gnome-shell/firewall@subgraph.com/dialog.js @@ -0,0 +1,451 @@ +const Clutter = imports.gi.Clutter; +const GLib = imports.gi.GLib; +const Gtk = imports.gi.Gtk; +const Lang = imports.lang; +const Pango = imports.gi.Pango; +const Signals = imports.signals; +const St = imports.gi.St; + +const ModalDialog = imports.ui.modalDialog; +const Tweener = imports.ui.tweener; + +const RuleScope = { + APPLY_ONCE: 0, + APPLY_SESSION: 1, + APPLY_FOREVER: 2, +}; + +const DetailSection = new Lang.Class({ + Name: 'DetailSection', + + _init: function() { + this.actor = new St.BoxLayout({ style_class: 'fw-details-section' }); + this._left = new St.BoxLayout({ vertical: true, style_class: 'fw-details-left'}); + this._right = new St.BoxLayout({ vertical: true }); + this.actor.add_child(this._left); + this.actor.add_child(this._right); + + this.ipAddr = this._addDetails("IP Address:"); + this.path = this._addDetails("Path:"); + this.pid = this._addDetails("Process ID:"); + this.user = this._addDetails("User:"); + }, + + _addDetails: function(text) { + let title = new St.Label({ style_class: 'fw-detail-title', text: text}); + let msg = new St.Label({ style_class: 'fw-detail-message' }); + this._left.add(title, { expand: true, x_fill: false, x_align: St.Align.END}); + this._right.add(msg); + return msg; + }, + + setDetails: function(ip, path, pid, user) { + this.ipAddr.text = ip; + this.path.text = path; + this.pid.text = pid.toString(); + this.user.text = user; + } +}); + +const OptionListItem = new Lang.Class({ + Name: 'OptionListItem', + + _init: function(text, idx) { + this.actor = new St.BoxLayout({ style_class: 'fw-option-item', reactive: true, can_focus: true }); + this._selectedIcon = new St.Icon({style_class: 'fw-option-item-icon', icon_name: 'object-select-symbolic'}); + this._selectedIcon.opacity = 0; + + this._label = new St.Label({text: text}); + let spacer = new St.Bin(); + this.actor.add_child(this._label); + this.actor.add(spacer, {expand: true}); + this.actor.add_child(this._selectedIcon); + this.idx = idx; + + let action = new Clutter.ClickAction(); + action.connect('clicked', Lang.bind(this, function() { + this.actor.grab_key_focus(); + this.emit('selected'); + })); + this.actor.add_action(action); + + this.actor.connect('key-press-event', Lang.bind(this, this._onKeyPressEvent)); + }, + + setText: function(text) { + this._label.text = text; + }, + + setSelected: function(isSelected) { + this._selectedIcon.opacity = isSelected ? 255 : 0; + }, + + _onKeyPressEvent: function(actor, event) { + let symbol = event.get_key_symbol(); + if (symbol == Clutter.KEY_space || symbol == Clutter.KEY_Return) { + this.emit('selected'); + } + } +}); +Signals.addSignalMethods(OptionListItem.prototype); + +const OptionList = new Lang.Class({ + Name: 'OptionList', + + _init: function() { + this.actor = new St.BoxLayout({vertical: true, style_class: 'fw-option-list'}); + this.buttonGroup = new ButtonGroup("Forever", "Session", "Once"); + this.actor.add_child(this.buttonGroup.actor); + this.items = []; + this._selected; + }, + + setOptionText: function(idx, text) { + if(this.items.length <= idx) { + log("attempt to setOptionText with idx = "+ idx + " when this.items.length = "+ this.items.length) + return; + } + this.items[idx].setText(text); + }, + + addOptions: function(options) { + for(let i = 0; i < options.length; i++) { + this._addOption(options[i], i) + } + if(this.items.length) { + this._optionSelected(this.items[0]) + } + }, + + _addOption: function(text, idx) { + let item = new OptionListItem(text, idx); + item.connect('selected', Lang.bind(this, function() { + this._optionSelected(item); + })); + this.actor.add_child(item.actor); + this.items.push(item); + }, + + _optionSelected: function(item) { + if (item == this._selected) { + return; + } + if(this._selected) { + this._selected.actor.remove_style_pseudo_class('selected'); + this._selected.setSelected(false); + } + item.setSelected(true); + this._selected = item; + this._selected.actor.add_style_pseudo_class('selected'); + }, + + selectedIdx: function() { + return this._selected.idx; + }, + + selectedScope: function() { + switch(this.buttonGroup._checked) { + case 0: + return RuleScope.APPLY_FOREVER; + case 1: + return RuleScope.APPLY_SESSION; + case 2: + return RuleScope.APPLY_ONCE; + default: + log("unexpected scope value "+ this.buttonGroup._selected); + return RuleScope.APPLY_SESSION; + } + } + +}); + +const ButtonGroup = new Lang.Class({ + Name: 'ButtonGroup', + + _init: function() { + this.actor = new St.BoxLayout({ style_class: 'fw-button-group'}); + this._checked = -1; + this._buttons= []; + for(let i = 0; i < arguments.length; i++) { + let idx = i; + this._buttons[i] = new St.Button({ style_class: 'fw-group-button button', + label: arguments[i], + can_focus: true, + x_expand: true }); + this._buttons[i].connect('clicked', Lang.bind(this, function(actor) { + this._setChecked(idx); + })); + this.actor.add_child(this._buttons[i]); + } + this._setChecked(0); + }, + + _setChecked: function(idx) { + + if(idx == this._checked) { + return; + } + this._buttons[idx].add_style_pseudo_class('checked'); + if(this._checked >= 0) { + this._buttons[this._checked].remove_style_pseudo_class('checked'); + } + this._checked = idx; + }, + +}); + +const ExpandingSection = new Lang.Class({ + Name: 'ExpandingSection', + + _init: function(text, content) { + this.actor = new St.BoxLayout({vertical: true}); + this._createHeader(this.actor, text); + this.scroll = new St.ScrollView({hscrollbar_policy: Gtk.PolicyType.NEVER, + vscrollbar_policy: Gtk.PolicyType.NEVER }); + this.actor.add_child(this.scroll); + this.isOpen = false; + }, + + _createHeader: function(parent, text) { + this.header = new St.BoxLayout({ style_class: 'fw-expanding-section-header', reactive: true, track_hover: true, can_focus: true}); + this.label = new St.Label({ style_class: 'fw-expanding-section-label', text: text, y_expand: true, y_align: Clutter.ActorAlign.CENTER }); + this.header.add_child(this.label); + let spacer = new St.Bin({ style_class: 'fw-expanding-section-spacer'}); + this.header.add(spacer, {expand: true}); + + this._triangle = new St.Icon({ style_class: 'popup-menu-arrow', + icon_name: 'pan-end-symbolic', + y_expand: true, + y_align: Clutter.ActorAlign.CENTER}); + this._triangle.pivot_point = new Clutter.Point({ x: 0.5, y: 0.6 }); + + this._triangleBin = new St.Widget({ y_expand: true, y_align: Clutter.ActorAlign.CENTER}); + this._triangleBin.add_child(this._triangle); + + this.header.add_child(this._triangleBin); + this.header.connect('button-press-event', Lang.bind(this, this._onButtonPressEvent)); + this.header.connect('button-release-event', Lang.bind(this, this._onButtonReleaseEvent)); + this.header.connect('key-press-event', Lang.bind(this, this._onKeyPressEvent)); + parent.add_child(this.header); + }, + + _onButtonPressEvent: function (actor, event) { + this.actor.add_style_pseudo_class('active'); + return Clutter.EVENT_PROPAGATE; + }, + + _onButtonReleaseEvent: function (actor, event) { + this.actor.remove_style_pseudo_class('active'); + this.activate(event); + return Clutter.EVENT_STOP; + }, + + _onKeyPressEvent: function(actor, event) { + let symbol = event.get_key_symbol(); + if (symbol == Clutter.KEY_space || symbol == Clutter.KEY_Return) { + this.activate(event); + } + }, + + activate: function(event) { + if(!this.isOpen) { + this.open(); + } else { + this.close(); + } + }, + + set_child: function(child) { + if(this.child) { + this.child.destroy(); + } + this.scroll.add_actor(child); + this.child = child; + let [min, nat] = this.child.get_preferred_width(-1); + this.scroll.width = nat; + this.scroll.show(); + this.scroll.height = 0; + this.child.hide(); + }, + + open: function() { + if(this.isOpen) { + return; + } + if(!this.child) { + return; + } + this.isOpen = true; + this.scroll.show(); + this.child.show(); + let targetAngle = 90; + let [minHeight, naturalHeight] = this.child.get_preferred_height(-1); + this.scroll.height = 0; + this.scroll._arrowRotation = this._triangle.rotation_angle_z; + Tweener.addTween(this.scroll, + { _arrowRotation: targetAngle, + height: naturalHeight, + time: 0.5, + onUpdateScope: this, + onUpdate: function() { + this._triangle.rotation_angle_z = this.scroll._arrowRotation; + }, + onCompleteScope: this, + onComplete: function() { + this.scroll.set_height(-1); + } + }); + }, + + close: function() { + if(!this.isOpen) { + return; + } + this.isOpen = false; + this.scroll._arrowRotation = this._triangle.rotation_angle_z; + Tweener.addTween(this.scroll, + { _arrowRotation: 0, + height: 0, + time: 0.5, + onUpdateScope: this, + onUpdate: function() { + this._triangle.rotation_angle_z = this.scroll._arrowRotation; + }, + onCompleteScope: this, + onComplete: function() { + this.child.hide(); + } + }); + } + +}); + +const PromptDialogHeader = new Lang.Class({ + Name: 'PromptDialogHeader', + + _init: function() { + this.actor = new St.BoxLayout(); + let inner = new St.BoxLayout({ vertical: true }); + this.icon = new St.Icon({style_class: 'fw-prompt-icon'}) + this.title = new St.Label({style_class: 'fw-prompt-title'}) + this.message = new St.Label({style_class: 'fw-prompt-message'}); + this.message.clutter_text.line_wrap = true; + this.message.clutter_text.ellipsize = Pango.EllipsizeMode.NONE; + inner.add_child(this.title); + inner.add_child(this.message); + this.actor.add_child(this.icon); + this.actor.add_child(inner); + }, + + setTitle: function(text) { + if(!text) { + text = "Unknown"; + } + this.title.text = text; + }, + + setMessage: function(text) { + this.message.text = text; + }, + + setIcon: function(name) { + this.icon.icon_name = name; + }, + + setIconDefault: function() { + this.setIcon('security-high-symbolic'); + }, + +}); + +const PromptDialog = new Lang.Class({ + Name: 'PromptDialog', + Extends: ModalDialog.ModalDialog, + + _init: function(invocation) { + this.parent({ styleClass: 'fw-prompt-dialog' }); + this._invocation = invocation; + this.header = new PromptDialogHeader(); + this.contentLayout.add_child(this.header.actor); + + this.details = new ExpandingSection("Details"); + this.contentLayout.add(this.details.actor, {y_fill: false, x_fill: true}); + let box = new St.BoxLayout({ vertical: true }); + this.details.set_child(box); + this.info = new DetailSection(); + box.add_child(this.info.actor); + + this.optionList = new OptionList(); + box.add_child(this.optionList.actor); + this.optionList.addOptions([ + "Any Connection", + "Only PORT", + "Only ADDRESS", + "Only PORT AND ADDRESS"]); + + + this.setButtons([ + { label: "Allow", action: Lang.bind(this, this.onAllow) }, + { label: "Deny", action: Lang.bind(this, this.onDeny) }]); + + }, + + onAllow: function() { + this.close(); + this.sendReturnValue(true); + }, + + onDeny: function() { + this.close(); + this.sendReturnValue(false); + }, + + sendReturnValue: function(allow) { + if(!this._invocation) { + return; + } + let verb = "DENY"; + if(allow) { + verb = "ALLOW"; + } + let rule = verb + "|" + this.ruleTarget(); + let scope = this.optionList.selectedScope() + this._invocation.return_value(GLib.Variant.new('(is)', [scope, rule])); + this._invocation = null; + }, + + ruleTarget: function() { + switch(this.optionList.selectedIdx()) { + case 0: + return "*:*"; + case 1: + return "*:" + this._port; + case 2: + return this._address + ":*"; + case 3: + return this._address + ":" + this._port; + } + + }, + + update: function(application, icon, path, address, port, ip, user, pid) { + this._address = address; + this._port = port; + + let port_str = "TCP Port "+ port; + + this.header.setTitle(application); + this.header.setMessage("Wants to connect to "+ address + " on " + port_str); + + if(icon) { + this.header.setIcon(icon); + } else { + this.header.setIconDefault(); + } + + this.optionList.setOptionText(1, "Only "+ port_str); + this.optionList.setOptionText(2, "Only "+ address); + this.optionList.setOptionText(3, "Only "+ address + " on "+ port_str); + this.info.setDetails(ip, path, pid, user); + }, +}); diff --git a/gnome-shell/firewall@subgraph.com/extension.js b/gnome-shell/firewall@subgraph.com/extension.js new file mode 100644 index 0000000..f1222fc --- /dev/null +++ b/gnome-shell/firewall@subgraph.com/extension.js @@ -0,0 +1,97 @@ +const Lang = imports.lang; +const Gio = imports.gi.Gio; + +const Extension = imports.misc.extensionUtils.getCurrentExtension(); +const Dialog = Extension.imports.dialog; + +function init() { + return new FirewallSupport(); +} + +const FirewallSupport = new Lang.Class({ + Name: 'FirewallSupport', + + _init: function() { + this.handler = null; + }, + + _destroyHandler: function() { + if(this.handler) { + this.handler.destroy(); + this.handler = null; + } + }, + enable: function() { + this._destroyHandler(); + this.handler = new FirewallPromptHandler(); + }, + disable: function() { + this._destroyHandler(); + } +}); + + +// $ busctl --user call com.subgraph.FirewallPrompt /com/subgraph/FirewallPrompt com.subgraph.FirewallPrompt TestPrompt +const FirewallPromptInterface = ' \ + \ + \ + \ + \ + \ + \ + \ + \ + \ + \ + \ + \ + \ + \ + \ + \ +'; + +const FirewallPromptHandler = new Lang.Class({ + Name: 'FirewallPromptHandler', + + _init: function() { + this._dbusImpl = Gio.DBusExportedObject.wrapJSObject(FirewallPromptInterface, this); + this._dbusImpl.export(Gio.DBus.session, '/com/subgraph/FirewallPrompt'); + Gio.bus_own_name_on_connection(Gio.DBus.session, 'com.subgraph.FirewallPrompt', Gio.BusNameOwnerFlags.REPLACE, null, null); + this._dialog = null; + }, + + destroy: function() { + this._closeDialog(); + this._dbusImpl.unexport(); + }, + + _closeDialog: function() { + if (this._dialog) { + this._dialog.close(); + this._dialog = null; + } + }, + + RequestPromptAsync: function(params, invocation) { + let [app, icon, path, address, port, ip, user, pid] = params; + this._closeDialog(); + this._dialog = new Dialog.PromptDialog(invocation); + this._invocation = invocation; + this._dialog.update(app, icon, path, address, port, ip, user, pid); + this._dialog.open(); + + }, + + CloseAsync: function(params, invocation) { + this._closeDialog(); + }, + + TestPrompt: function(params, invocation) { + this._closeDialog(); + this._dialog = new Dialog.PromptDialog(nil); + this._dialog.update("Firefox", "firefox", "/usr/bin/firefox", "242.12.111.18", "443", "linux", "2342"); + this._dialog.open(); + } +}); + diff --git a/gnome-shell/firewall@subgraph.com/metadata.json b/gnome-shell/firewall@subgraph.com/metadata.json new file mode 100644 index 0000000..8908562 --- /dev/null +++ b/gnome-shell/firewall@subgraph.com/metadata.json @@ -0,0 +1 @@ +{"description": "Firewall Extension", "shell-version": ["3.18"], "uuid": "firewall@subgraph.com", "name": "Firewall Extension"} diff --git a/gnome-shell/firewall@subgraph.com/stylesheet.css b/gnome-shell/firewall@subgraph.com/stylesheet.css new file mode 100644 index 0000000..df3b38e --- /dev/null +++ b/gnome-shell/firewall@subgraph.com/stylesheet.css @@ -0,0 +1,71 @@ +.fw-prompt-dialog { + min-width: 450px; + max-width: 500px; +} + +.fw-group-button:checked { + color: white; + border-color: rgba(0,0,0,0.7); + background-color: #222728; + box-shadow: inset 0 0 black; + text-shadow: none; +} + +.fw-button-group { + padding-bottom: 10px; +} + +.fw-option-list { + border: 2px solid rgba(238, 238, 236, 0.5); + padding: 10px; +} + +.fw-option-item{ + padding: 5px; +} + +.fw-option-item:focus { + background-color: #215d9c; +} + +.fw-option-item:selected { + font-weight: bold; +} + +.fw-option-item-icon { + icon-size: 16px; +} + +.fw-prompt-title { + font-size: 14pt; + font-weight: bold; +} + +.fw-prompt-icon { + padding: 10px; +} + +.fw-detail-title { + font-weight: bold; +} + +.fw-expanding-section-header { + border-width: 1px; + border-color: rgba(0, 0, 0, 0); +} + +.fw-expanding-section-header:focus { + border-color: #215d9c; +} + +.fw-expanding-section-label { + font-weight: bold; +} + +.fw-details-section { + padding: 20px; +} + +.fw-details-left { + padding-right: 10px; +} diff --git a/icons.go b/icons.go new file mode 100644 index 0000000..7e3db50 --- /dev/null +++ b/icons.go @@ -0,0 +1,86 @@ +package main + +import ( + "fmt" + "io/ioutil" + "os" + "path" + "strings" +) + +type DesktopEntry struct { + icon string + name string +} + +var entryMap = map[string]*DesktopEntry{} +var initialized = false + +func entryForPath(p string) *DesktopEntry { + if !initialized { + initIcons() + } + entry, ok := entryMap[path.Base(p)] + if ok { + return entry + } + return entryMap[p] +} + +func initIcons() { + if initialized { + return + } + path := "/usr/share/applications" + dir, err := os.Open(path) + if err != nil { + log.Warning("Failed to open %s for reading: %v", path, err) + return + } + names, err := dir.Readdirnames(0) + if err != nil { + log.Warning("Could not read directory %s: %v", path, err) + return + } + for _, n := range names { + if strings.HasSuffix(n, ".desktop") { + loadDesktopFile(fmt.Sprintf("%s/%s", path, n)) + } + } + initialized = true +} + +func loadDesktopFile(path string) { + bs, err := ioutil.ReadFile(path) + if err != nil { + log.Warning("Error reading %s: %v", path, err) + return + } + exec := "" + icon := "" + name := "" + inDE := false + + for _, line := range strings.Split(string(bs), "\n") { + if strings.Contains(line, "[Desktop Entry]") { + inDE = true + } else if len(line) > 0 && line[0] == '[' { + inDE = false + } + if inDE && strings.HasPrefix(line, "Exec=") { + exec = strings.Fields(line[5:])[0] + } + if inDE && strings.HasPrefix(line, "Icon=") { + icon = line[5:] + } + if inDE && strings.HasPrefix(line, "Name=") { + name = line[5:] + } + } + if exec != "" && icon != "" { + entryMap[exec] = &DesktopEntry{ + icon: icon, + name: name, + } + } +} diff --git a/iptables.go b/iptables.go new file mode 100644 index 0000000..1d07438 --- /dev/null +++ b/iptables.go @@ -0,0 +1,48 @@ +package main + +import ( + "fmt" + "os" + "os/exec" + "strings" +) + +const iptablesRule = "-t mangle -%c OUTPUT -m conntrack --ctstate NEW -j NFQUEUE --queue-num 0 --queue-bypass" +const dnsRule = "-%c INPUT --protocol udp -m multiport --source-ports 53 -j NFQUEUE --queue-num 0 --queue-bypass" + +func setupIPTables() { + removeIPTRules(dnsRule, iptablesRule) + addIPTRules(iptablesRule, dnsRule) +} + +func removeIPTRules(rules ...string) { + for _, r := range rules { + iptables('D', r) + } +} + +func addIPTRules(rules ...string) { + for _, r := range rules { + iptables('I', r) + } +} + +func iptables(verb rune, rule string) { + + iptablesPath, err := exec.LookPath("iptables") + if err != nil { + log.Warning("Could not find iptables binary in path") + os.Exit(1) + } + + argLine := fmt.Sprintf(rule, verb) + args := strings.Fields(argLine) + fmt.Println(iptablesPath, argLine) + cmd := exec.Command(iptablesPath, args...) + out, err := cmd.CombinedOutput() + fmt.Fprintf(os.Stderr, string(out)) + _, exitErr := err.(*exec.ExitError) + if err != nil && !exitErr { + log.Warning("Error reading output: %v", err) + } +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..331870c --- /dev/null +++ b/main.go @@ -0,0 +1,93 @@ +package main + +import ( + // _ "net/http/pprof" + "os" + "os/signal" + "runtime" + "time" + + "github.com/op/go-logging" + "github.com/subgraph/fw-daemon/nfqueue" + "sync" +) + +var log = logging.MustGetLogger("sgfw") +var format = logging.MustStringFormatter( + "%{color}%{time:15:04:05} ▶ %{level:.4s} %{id:03x}%{color:reset} %{message}", +) + +func init() { + backend := logging.NewLogBackend(os.Stderr, "", 0) + formatter := logging.NewBackendFormatter(backend, format) + leveler := logging.AddModuleLevel(formatter) + log.SetBackend(leveler) +} + +type Firewall struct { + dbus *dbusServer + dns *dnsCache + + lock sync.Mutex + policyMap map[string]*Policy + policies []*Policy +} + +func (fw *Firewall) runFilter() { + q := nfqueue.NewNFQueue(0) + defer q.Destroy() + + q.DefaultVerdict = nfqueue.DROP + q.Timeout = 5 * time.Minute + packets := q.Process() + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, os.Interrupt, os.Kill) + + for { + select { + case pkt := <-packets: + fw.filterPacket(pkt) + case <-sigs: + return + } + } +} + +func main() { + runtime.GOMAXPROCS(1) + + if os.Geteuid() != 0 { + log.Error("Must be run as root") + os.Exit(1) + } + + setupIPTables() + + dbus, err := dbusConnect() + if err != nil { + log.Error(err.Error()) + os.Exit(1) + } + ds, err := newDbusServer(dbus) + if err != nil { + log.Error(err.Error()) + os.Exit(1) + } + + fw := &Firewall{ + dbus: ds, + dns: NewDnsCache(), + policyMap: make(map[string]*Policy), + } + + fw.loadRules() + + /* + go func() { + http.ListenAndServe("localhost:6060", nil) + }() + */ + + fw.runFilter() +} diff --git a/nfqueue/LICENSE b/nfqueue/LICENSE new file mode 100644 index 0000000..ad410e1 --- /dev/null +++ b/nfqueue/LICENSE @@ -0,0 +1,201 @@ +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/nfqueue/README.md b/nfqueue/README.md new file mode 100644 index 0000000..895a1fc --- /dev/null +++ b/nfqueue/README.md @@ -0,0 +1,42 @@ +Go-NFQueue +========== +Go Wrapper For Creating IPTables' NFQueue clients in Go + +Usage +------ +Check the `examples/main.go` file + +```bash + cd $GOPATH/github.com/OneOfOne/go-nfqueue/examples + go build -race && sudo ./examples +``` +* Open another terminal : +```bash +sudo iptables -I INPUT 1 -m conntrack --ctstate NEW -j NFQUEUE --queue-num 0 +#or +sudo iptables -I INPUT -i eth0 -m conntrack --ctstate NEW -j NFQUEUE --queue-num 0 +curl --head localhost +ping localhost +sudo iptables -D INPUT -m conntrack --ctstate NEW -j NFQUEUE --queue-num 0 +``` +Then you can `ctrl+c` the program to exit. + +* If you have recent enough iptables/nfqueue you could also use a balanced (multithreaded queue). +* check the example in `examples/mq/multiqueue.go` + +```bash +iptables -I INPUT 1 -m conntrack --ctstate NEW -j NFQUEUE --queue-balance 0:5 --queue-cpu-fanout +``` +Notes +----- + +You must run the executable as root. +This is *WIP*, but all patches are welcome. + +License +------- +go-nfqueue is under the Apache v2 license, check the included license file. +Copyright © [Ahmed W.](http://www.limitlessfx.com/) +See the included `LICENSE` file. + +> Copyright (c) 2014 Ahmed W. \ No newline at end of file diff --git a/nfqueue/multiqueue.go b/nfqueue/multiqueue.go new file mode 100644 index 0000000..08532f8 --- /dev/null +++ b/nfqueue/multiqueue.go @@ -0,0 +1,41 @@ +package nfqueue + +import "sync" + +type multiQueue struct { + qs []*nfQueue +} + +func NewMultiQueue(min, max uint16) (mq *multiQueue) { + mq = &multiQueue{make([]*nfQueue, 0, max-min)} + for i := min; i < max; i++ { + mq.qs = append(mq.qs, NewNFQueue(i)) + } + return mq +} + +func (mq *multiQueue) Process() <-chan *Packet { + var ( + wg sync.WaitGroup + out = make(chan *Packet, len(mq.qs)) + ) + for _, q := range mq.qs { + wg.Add(1) + go func(ch <-chan *Packet) { + for pkt := range ch { + out <- pkt + } + wg.Done() + }(q.Process()) + } + go func() { + wg.Wait() + close(out) + }() + return out +} +func (mq *multiQueue) Destroy() { + for _, q := range mq.qs { + q.Destroy() + } +} diff --git a/nfqueue/nfqueue.c b/nfqueue/nfqueue.c new file mode 100644 index 0000000..5c3b0f2 --- /dev/null +++ b/nfqueue/nfqueue.c @@ -0,0 +1,85 @@ +#include "nfqueue.h" +#include "_cgo_export.h" + + +int nfqueue_cb_new(struct nfq_q_handle *qh, struct nfgenmsg *nfmsg, struct nfq_data *nfa, void *data) { + + struct nfqnl_msg_packet_hdr *ph = nfq_get_msg_packet_hdr(nfa); + + if(ph == NULL) { + return 1; + } + + int id = ntohl(ph->packet_id); + + unsigned char * payload; + unsigned char * saddr, * daddr; + uint16_t sport = 0, dport = 0, checksum = 0; + uint32_t mark = nfq_get_nfmark(nfa); + + int len = nfq_get_payload(nfa, &payload); + + if(len < sizeof(struct iphdr)) { + return 0; + } + + struct iphdr * ip = (struct iphdr *) payload; + + if(ip->version == 4) { + uint32_t ipsz = (ip->ihl << 2); + if(len < ipsz) { + return 0; + } + len -= ipsz; + payload += ipsz; + + saddr = (unsigned char *)&ip->saddr; + daddr = (unsigned char *)&ip->daddr; + + if(ip->protocol == IPPROTO_TCP) { + if(len < sizeof(struct tcphdr)) { + return 0; + } + struct tcphdr *tcp = (struct tcphdr *) payload; + uint32_t tcpsz = (tcp->doff << 2); + if(len < tcpsz) { + return 0; + } + len -= tcpsz; + payload += tcpsz; + + sport = ntohs(tcp->source); + dport = ntohs(tcp->dest); + checksum = ntohs(tcp->check); + } else if(ip->protocol == IPPROTO_UDP) { + if(len < sizeof(struct udphdr)) { + return 0; + } + struct udphdr *u = (struct udphdr *) payload; + len -= sizeof(struct udphdr); + payload += sizeof(struct udphdr); + + sport = ntohs(u->source); + dport = ntohs(u->dest); + checksum = ntohs(u->check); + } + } else { + struct ipv6hdr *ip6 = (struct ipv6hdr*) payload; + saddr = (unsigned char *)&ip6->saddr; + daddr = (unsigned char *)&ip6->daddr; + //ipv6 + } + //pass everything we can and let Go handle it, I'm not a big fan of C + uint32_t verdict = go_nfq_callback(id, ntohs(ph->hw_protocol), ph->hook, &mark, ip->version, ip->protocol, + ip->tos, ip->ttl, saddr, daddr, sport, dport, checksum, len, payload, data); + return nfq_set_verdict2(qh, id, verdict, mark, 0, NULL); +} + +void loop_for_packets(struct nfq_handle *h) { + int fd = nfq_fd(h); + char buf[4096] __attribute__ ((aligned)); + int rv; + while ((rv = recv(fd, buf, sizeof(buf), 0)) && rv >= 0) { + nfq_handle_packet(h, buf, rv); + } +} diff --git a/nfqueue/nfqueue.go b/nfqueue/nfqueue.go new file mode 100644 index 0000000..0795930 --- /dev/null +++ b/nfqueue/nfqueue.go @@ -0,0 +1,180 @@ +package nfqueue + +/* +#cgo LDFLAGS: -lnetfilter_queue +#cgo CFLAGS: -Wall +#include "nfqueue.h" +*/ +import "C" + +import ( + "net" + "os" + "runtime" + "sync" + "syscall" + "time" + "unsafe" +) + +type nfQueue struct { + DefaultVerdict Verdict + Timeout time.Duration + qid uint16 + h *C.struct_nfq_handle + //qh *C.struct_q_handle + qh *C.struct_nfq_q_handle + fd int + lk sync.Mutex + + pktch chan *Packet +} + +func NewNFQueue(qid uint16) (nfq *nfQueue) { + if os.Geteuid() != 0 { + + } + if os.Geteuid() != 0 { + panic("Must be ran by root.") + } + nfq = &nfQueue{DefaultVerdict: ACCEPT, Timeout: time.Microsecond * 5, qid: qid} + return nfq +} + +/* +This returns a channel that will recieve packets, +the user then must call pkt.Accept() or pkt.Drop() +*/ +func (this *nfQueue) Process() <-chan *Packet { + if this.h != nil { + return this.pktch + } + this.init() + + go func() { + runtime.LockOSThread() + C.loop_for_packets(this.h) + }() + + return this.pktch +} + +func (this *nfQueue) init() { + var err error + if this.h, err = C.nfq_open(); err != nil || this.h == nil { + panic(err) + } + + //if this.qh, err = C.nfq_create_queue(this.h, qid, C.get_cb(), unsafe.Pointer(nfq)); err != nil || this.qh == nil { + + this.pktch = make(chan *Packet, 1) + + if C.nfq_unbind_pf(this.h, C.AF_INET) < 0 { + this.Destroy() + panic("nfq_unbind_pf(AF_INET) failed, are you running root?.") + } + if C.nfq_unbind_pf(this.h, C.AF_INET6) < 0 { + this.Destroy() + panic("nfq_unbind_pf(AF_INET6) failed.") + } + + if C.nfq_bind_pf(this.h, C.AF_INET) < 0 { + this.Destroy() + panic("nfq_bind_pf(AF_INET) failed.") + } + + if C.nfq_bind_pf(this.h, C.AF_INET6) < 0 { + this.Destroy() + panic("nfq_bind_pf(AF_INET6) failed.") + } + + if this.qh, err = C.create_queue(this.h, C.uint16_t(this.qid), unsafe.Pointer(this)); err != nil || this.qh == nil { + C.nfq_close(this.h) + panic(err) + } + + this.fd = int(C.nfq_fd(this.h)) + + if C.nfq_set_mode(this.qh, C.NFQNL_COPY_PACKET, 0xffff) < 0 { + this.Destroy() + panic("nfq_set_mode(NFQNL_COPY_PACKET) failed.") + } + if C.nfq_set_queue_maxlen(this.qh, 1024*8) < 0 { + this.Destroy() + panic("nfq_set_queue_maxlen(1024 * 8) failed.") + } +} + +func (this *nfQueue) Destroy() { + this.lk.Lock() + defer this.lk.Unlock() + + if this.fd != 0 && this.Valid() { + syscall.Close(this.fd) + } + if this.qh != nil { + C.nfq_destroy_queue(this.qh) + this.qh = nil + } + if this.h != nil { + C.nfq_close(this.h) + this.h = nil + } + + if this.pktch != nil { + close(this.pktch) + } +} + +func (this *nfQueue) Valid() bool { + return this.h != nil && this.qh != nil +} + +//export go_nfq_callback +func go_nfq_callback(id uint32, hwproto uint16, hook uint8, mark *uint32, + version, protocol, tos, ttl uint8, saddr, daddr unsafe.Pointer, + sport, dport, checksum uint16, payload_len uint32, payload, nfqptr unsafe.Pointer) (v uint32) { + + var ( + nfq = (*nfQueue)(nfqptr) + ipver = IPVersion(version) + ipsz = C.int(ipver.Size()) + ) + bs := C.GoBytes(payload, (C.int)(payload_len)) + + verdict := make(chan uint32, 1) + pkt := Packet{ + QueueId: nfq.qid, + Id: id, + HWProtocol: hwproto, + Hook: hook, + Mark: *mark, + Payload: bs, + IPHeader: &IPHeader{ + Version: ipver, + Protocol: IPProtocol(protocol), + Tos: tos, + TTL: ttl, + Src: net.IP(C.GoBytes(saddr, ipsz)), + Dst: net.IP(C.GoBytes(daddr, ipsz)), + }, + + TCPUDPHeader: &TCPUDPHeader{ + SrcPort: sport, + DstPort: dport, + Checksum: checksum, + }, + + verdict: verdict, + } + nfq.pktch <- &pkt + + select { + case v = <-pkt.verdict: + *mark = pkt.Mark + case <-time.After(nfq.Timeout): + v = uint32(nfq.DefaultVerdict) + } + + return v +} diff --git a/nfqueue/nfqueue.h b/nfqueue/nfqueue.h new file mode 100644 index 0000000..e897bd7 --- /dev/null +++ b/nfqueue/nfqueue.h @@ -0,0 +1,27 @@ +#pragma once +// #define _BSD_SOURCE +// #define __BSD_SOURCE + +// #define __FAVOR_BSD // Just Using _BSD_SOURCE didn't work on my system for some reason +// #define __USE_BSD +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// extern int nfq_callback(uint8_t version, uint8_t protocol, unsigned char *saddr, unsigned char *daddr, +// uint16_t sport, uint16_t dport, unsigned char * extra, void* data); + +int nfqueue_cb_new(struct nfq_q_handle *qh, struct nfgenmsg *nfmsg, struct nfq_data *nfa, void *data); +void loop_for_packets(struct nfq_handle *h); + +static inline struct nfq_q_handle * create_queue(struct nfq_handle *h, uint16_t num, void *data) { + //we use this because it's more convient to pass the callback in C + return nfq_create_queue(h, num, &nfqueue_cb_new, data); +} diff --git a/nfqueue/packet.go b/nfqueue/packet.go new file mode 100644 index 0000000..aec8f36 --- /dev/null +++ b/nfqueue/packet.go @@ -0,0 +1,145 @@ +package nfqueue + +import ( + "fmt" + "net" + "syscall" +) + +type ( + IPVersion uint8 + IPProtocol uint8 + Verdict uint8 +) + +const ( + IPv4 = IPVersion(4) + IPv6 = IPVersion(6) + + //convience really + IGMP = IPProtocol(syscall.IPPROTO_IGMP) + RAW = IPProtocol(syscall.IPPROTO_RAW) + TCP = IPProtocol(syscall.IPPROTO_TCP) + UDP = IPProtocol(syscall.IPPROTO_UDP) + ICMP = IPProtocol(syscall.IPPROTO_ICMP) + ICMPv6 = IPProtocol(syscall.IPPROTO_ICMPV6) +) + +const ( + DROP Verdict = iota + ACCEPT + STOLEN + QUEUE + REPEAT + STOP +) + +var ( + ErrVerdictSentOrTimedOut error = fmt.Errorf("The verdict was already sent or timed out.") +) + +func (v IPVersion) String() string { + switch v { + case IPv4: + return "IPv4" + case IPv6: + return "IPv6" + } + return fmt.Sprintf("", uint8(v)) +} + +// Returns the byte size of the ip, IPv4 = 4 bytes, IPv6 = 16 +func (v IPVersion) Size() int { + switch v { + case IPv4: + return 4 + case IPv6: + return 16 + } + return 0 +} + +func (p IPProtocol) String() string { + switch p { + case RAW: + return "RAW" + case TCP: + return "TCP" + case UDP: + return "UDP" + case ICMP: + return "ICMP" + case ICMPv6: + return "ICMPv6" + case IGMP: + return "IGMP" + } + return fmt.Sprintf("", uint8(p)) +} + +func (v Verdict) String() string { + switch v { + case DROP: + return "DROP" + case ACCEPT: + return "ACCEPT" + } + return fmt.Sprintf("", uint8(v)) +} + +type IPHeader struct { + Version IPVersion + + Tos, TTL uint8 + Protocol IPProtocol + Src, Dst net.IP +} + +type TCPUDPHeader struct { + SrcPort, DstPort uint16 + Checksum uint16 //not implemented +} + +// TODO handle other protocols + +type Packet struct { + QueueId uint16 + Id uint32 + HWProtocol uint16 + Hook uint8 + Mark uint32 + Payload []byte + *IPHeader + *TCPUDPHeader + + verdict chan uint32 +} + +func (pkt *Packet) String() string { + return fmt.Sprintf("", + pkt.QueueId, pkt.Id, pkt.Protocol, pkt.Src, pkt.SrcPort, pkt.Dst, pkt.DstPort, pkt.Mark, pkt.Checksum, pkt.Tos, pkt.TTL) +} + +func (pkt *Packet) setVerdict(v Verdict) (err error) { + defer func() { + if x := recover(); x != nil { + err = ErrVerdictSentOrTimedOut + } + }() + pkt.verdict <- uint32(v) + close(pkt.verdict) + return err +} + +func (pkt *Packet) Accept() error { + return pkt.setVerdict(ACCEPT) +} + +func (pkt *Packet) Drop() error { + return pkt.setVerdict(DROP) +} + +//HUGE warning, if the iptables rules aren't set correctly this can cause some problems. +// func (pkt *Packet) Repeat() error { +// return this.SetVerdict(REPEAT) +// } diff --git a/policy.go b/policy.go new file mode 100644 index 0000000..5672627 --- /dev/null +++ b/policy.go @@ -0,0 +1,186 @@ +package main + +import ( + "fmt" + "sync" + + "github.com/subgraph/fw-daemon/nfqueue" +) + +type pendingPkt struct { + policy *Policy + hostname string + pkt *nfqueue.Packet + proc *ProcInfo +} + +type Policy struct { + fw *Firewall + path string + application string + icon string + rules RuleList + pendingQueue []*pendingPkt + promptInProgress bool + lock sync.Mutex +} + +func (fw *Firewall) policyForPath(path string) *Policy { + fw.lock.Lock() + defer fw.lock.Unlock() + if _, ok := fw.policyMap[path]; !ok { + p := new(Policy) + p.fw = fw + p.path = path + p.application = path + entry := entryForPath(path) + if entry != nil { + p.application = entry.name + p.icon = entry.icon + } + fw.policyMap[path] = p + fw.policies = append(fw.policies, p) + } + return fw.policyMap[path] +} + +func (p *Policy) processPacket(pkt *nfqueue.Packet, proc *ProcInfo) { + p.lock.Lock() + defer p.lock.Unlock() + name := p.fw.dns.Lookup(pkt.Dst) + log.Info("Lookup(%s): %s", pkt.Dst.String(), name) + result := p.rules.filter(pkt, proc, name) + switch result { + case FILTER_DENY: + pkt.Drop() + case FILTER_ALLOW: + pkt.Accept() + case FILTER_PROMPT: + p.processPromptResult(&pendingPkt{policy: p, hostname: name, pkt: pkt, proc: proc}) + default: + log.Warning("Unexpected filter result: %d", result) + } +} + +func (p *Policy) processPromptResult(pp *pendingPkt) { + p.pendingQueue = append(p.pendingQueue, pp) + if !p.promptInProgress { + p.promptInProgress = true + go p.fw.dbus.prompt(p) + } +} + +func (p *Policy) nextPending() *pendingPkt { + p.lock.Lock() + defer p.lock.Unlock() + if len(p.pendingQueue) == 0 { + return nil + } + return p.pendingQueue[0] +} + +func (p *Policy) removePending(pp *pendingPkt) { + p.lock.Lock() + defer p.lock.Unlock() + + remaining := []*pendingPkt{} + for _, pkt := range p.pendingQueue { + if pkt != pp { + remaining = append(remaining, pkt) + } + } + if len(remaining) != len(p.pendingQueue) { + p.pendingQueue = remaining + } +} + +func (p *Policy) processNewRule(r *Rule, scope int32) bool { + p.lock.Lock() + defer p.lock.Unlock() + + if scope != APPLY_ONCE { + p.rules = append(p.rules, r) + } + + p.filterPending(r) + if len(p.pendingQueue) == 0 { + p.promptInProgress = false + } + + return p.promptInProgress +} + +func (p *Policy) filterPending(rule *Rule) { + remaining := []*pendingPkt{} + for _, pp := range p.pendingQueue { + if rule.match(pp.pkt, pp.hostname) { + log.Info("Also applying %s to %s", rule, printPacket(pp.pkt, pp.hostname)) + if rule.rtype == RULE_ALLOW { + pp.pkt.Accept() + } else { + pp.pkt.Drop() + } + } else { + remaining = append(remaining, pp) + } + } + if len(remaining) != len(p.pendingQueue) { + p.pendingQueue = remaining + } +} + +func (p *Policy) hasPersistentRules() bool { + for _, r := range p.rules { + if !r.sessionOnly { + return true + } + } + return false +} + +func printPacket(pkt *nfqueue.Packet, hostname string) string { + proto := func() string { + switch pkt.Protocol { + case nfqueue.TCP: + return "TCP" + case nfqueue.UDP: + return "UDP" + default: + return "???" + } + }() + name := hostname + if name == "" { + name = pkt.Dst.String() + } + return fmt.Sprintf("(%s %s:%d --> %s:%d)", proto, pkt.Src, pkt.SrcPort, name, pkt.DstPort) +} + +func (fw *Firewall) filterPacket(pkt *nfqueue.Packet) { + if pkt.Protocol == nfqueue.UDP && pkt.SrcPort == 53 { + pkt.Accept() + fw.dns.processDNS(pkt) + return + } + log.Debug("filterPacket %s", printPacket(pkt, fw.dns.Lookup(pkt.Dst))) + if basicAllowPacket(pkt) { + pkt.Accept() + return + } + + proc := findProcessForPacket(pkt) + + if proc == nil { + log.Warning("No process for: %v", pkt) + pkt.Accept() + return + } + policy := fw.policyForPath(proc.exePath) + policy.processPacket(pkt, proc) +} + +func basicAllowPacket(pkt *nfqueue.Packet) bool { + return pkt.Dst.IsLoopback() || + pkt.Dst.IsLinkLocalMulticast() || + pkt.Protocol != nfqueue.TCP +} diff --git a/proc.go b/proc.go new file mode 100644 index 0000000..d492a4d --- /dev/null +++ b/proc.go @@ -0,0 +1,228 @@ +package main + +import ( + "encoding/hex" + "errors" + "fmt" + "io/ioutil" + "net" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/subgraph/fw-daemon/nfqueue" +) + +type ProcInfo struct { + pid int + uid int + exePath string + cmdLine string +} + +type socketAddr struct { + ip net.IP + port uint16 +} + +type socketStatus struct { + local socketAddr + remote socketAddr + uid int + inode uint64 + pid int +} + +func findProcessForPacket(pkt *nfqueue.Packet) *ProcInfo { + ss := getSocketForPacket(pkt) + if ss == nil { + return nil + } + exePath, err := os.Readlink(fmt.Sprintf("/proc/%d/exe", ss.pid)) + if err != nil { + log.Warning("Error reading exe link for pid %d: %v", ss.pid, err) + return nil + } + bs, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/cmdline", ss.pid)) + if err != nil { + log.Warning("Error reading cmdline for pid %d: %v", ss.pid, err) + return nil + } + for i, b := range bs { + if b == 0 { + bs[i] = byte(' ') + } + } + + finfo, err := os.Stat(fmt.Sprintf("/proc/%d", ss.pid)) + if err != nil { + log.Warning("Could not stat /proc/%d: %v", ss.pid, err) + return nil + } + finfo.Sys() + return &ProcInfo{ + pid: ss.pid, + uid: ss.uid, + exePath: exePath, + cmdLine: string(bs), + } +} + +func getSocketLinesForPacket(pkt *nfqueue.Packet) []string { + if pkt.Protocol == nfqueue.TCP { + return getSocketLines("tcp") + } else if pkt.Protocol == nfqueue.UDP { + return getSocketLines("udp") + } else { + log.Warning("Cannot lookup socket for protocol %s", pkt.Protocol) + return nil + } +} + +func getSocketLines(proto string) []string { + path := fmt.Sprintf("/proc/net/%s", proto) + data, err := ioutil.ReadFile(path) + if err != nil { + log.Warning("Error reading %s: %v", path, err) + return nil + } + lines := strings.Split(string(data), "\n") + if len(lines) > 0 { + lines = lines[1:] + } + return lines +} + +func (sa *socketAddr) parse(s string) error { + ipPort := strings.Split(s, ":") + if len(ipPort) != 2 { + return fmt.Errorf("badly formatted socket address field: %s", s) + } + ip, err := ParseIp(ipPort[0]) + if err != nil { + return fmt.Errorf("error parsing ip field [%s]: %v", ipPort[0], err) + } + port, err := ParsePort(ipPort[1]) + if err != nil { + return fmt.Errorf("error parsing port field [%s]: %v", ipPort[1], err) + } + sa.ip = ip + sa.port = port + return nil +} + +func (ss *socketStatus) parseLine(line string) error { + fs := strings.Fields(line) + if len(fs) < 10 { + return errors.New("insufficient fields") + } + if err := ss.local.parse(fs[1]); err != nil { + return err + } + if err := ss.remote.parse(fs[2]); err != nil { + return err + } + uid, err := strconv.ParseUint(fs[7], 10, 32) + if err != nil { + return err + } + ss.uid = int(uid) + inode, err := strconv.ParseUint(fs[9], 10, 64) + if err != nil { + return err + } + ss.inode = inode + return nil +} + +func getSocketForPacket(pkt *nfqueue.Packet) *socketStatus { + ss := findSocket(pkt) + if ss == nil { + return nil + } + pid := findPidForInode(ss.inode) + if pid == -1 { + return nil + } + ss.pid = pid + return ss +} + +func findSocket(pkt *nfqueue.Packet) *socketStatus { + var status socketStatus + for _, line := range getSocketLinesForPacket(pkt) { + if err := status.parseLine(line); err != nil { + log.Warning("Unable to parse line [%s]: %v", line, err) + } else { + if status.remote.ip.Equal(pkt.Dst) && status.remote.port == pkt.DstPort { + return &status + } + } + } + return nil +} + +func ParseIp(ip string) (net.IP, error) { + var result net.IP + dst, err := hex.DecodeString(ip) + if err != nil { + return result, fmt.Errorf("Error parsing IP: %s", err) + } + // Reverse byte order -- /proc/net/tcp etc. is little-endian + // TODO: Does this vary by architecture? + for i, j := 0, len(dst)-1; i < j; i, j = i+1, j-1 { + dst[i], dst[j] = dst[j], dst[i] + } + result = net.IP(dst) + return result, nil +} + +func ParsePort(port string) (uint16, error) { + p64, err := strconv.ParseInt(port, 16, 32) + if err != nil { + return 0, fmt.Errorf("Error parsing port: %s", err) + } + return uint16(p64), nil +} + +func findPidForInode(inode uint64) int { + search := fmt.Sprintf("socket:[%d]", inode) + for _, pid := range getAllPids() { + if matchesSocketLink(pid, search) { + return pid + } + } + return -1 +} + +func matchesSocketLink(pid int, search string) bool { + paths, _ := filepath.Glob(fmt.Sprintf("/proc/%d/fd/*", pid)) + for _, p := range paths { + link, err := os.Readlink(p) + if err == nil && link == search { + return true + } + } + return false +} + +func getAllPids() []int { + var pids []int + d, err := os.Open("/proc") + if err != nil { + log.Warning("Error opening /proc: %v", err) + return nil + } + names, err := d.Readdirnames(0) + if err != nil { + log.Warning("Error reading directory names from /proc: %v", err) + return nil + } + for _, n := range names { + if pid, err := strconv.ParseUint(n, 10, 32); err == nil { + pids = append(pids, int(pid)) + } + } + return pids +} diff --git a/prompt.go b/prompt.go new file mode 100644 index 0000000..14df65e --- /dev/null +++ b/prompt.go @@ -0,0 +1,170 @@ +package main + +import ( + "fmt" + "github.com/godbus/dbus" + "os/user" + "strconv" + "sync" +) + +const ( + APPLY_ONCE = iota + APPLY_SESSION + APPLY_FOREVER +) + +func newPrompter(conn *dbus.Conn) *prompter { + p := new(prompter) + p.cond = sync.NewCond(&p.lock) + p.dbusObj = conn.Object("com.subgraph.FirewallPrompt", "/com/subgraph/FirewallPrompt") + p.policyMap = make(map[string]*Policy) + go p.promptLoop() + return p +} + +type prompter struct { + dbusObj dbus.BusObject + lock sync.Mutex + cond *sync.Cond + policyMap map[string]*Policy + policyQueue []*Policy +} + +func (p *prompter) prompt(policy *Policy) { + p.lock.Lock() + defer p.lock.Unlock() + _, ok := p.policyMap[policy.path] + if ok { + return + } + p.policyMap[policy.path] = policy + p.policyQueue = append(p.policyQueue, policy) + p.cond.Signal() +} + +func (p *prompter) promptLoop() { + p.lock.Lock() + for { + for p.processNextPacket() { + } + p.cond.Wait() + } +} + +func (p *prompter) processNextPacket() bool { + pp := p.nextPacket() + if pp == nil { + return false + } + p.lock.Unlock() + defer p.lock.Lock() + p.processPacket(pp) + return true +} + +func printScope(scope int32) string { + switch scope { + case APPLY_FOREVER: + return "APPLY_FOREVER" + case APPLY_SESSION: + return "APPLY_SESSION" + case APPLY_ONCE: + return "APPLY_ONCE" + default: + return fmt.Sprintf("Unknown (%d)", scope) + } +} + +func (p *prompter) processPacket(pp *pendingPkt) { + var scope int32 + var rule string + + addr := pp.hostname + if addr == "" { + addr = pp.pkt.Dst.String() + } + + call := p.dbusObj.Call("com.subgraph.FirewallPrompt.RequestPrompt", 0, + pp.policy.application, + pp.policy.icon, + pp.policy.path, + addr, + int32(pp.pkt.DstPort), + pp.pkt.Dst.String(), + uidToUser(pp.proc.uid), + int32(pp.proc.pid)) + err := call.Store(&scope, &rule) + if err != nil { + log.Warning("Error sending dbus RequestPrompt message: %v", err) + pp.policy.removePending(pp) + pp.pkt.Drop() + return + } + log.Debug("Received prompt response: %s [%s]", printScope(scope), rule) + + r, err := parseRule(rule) + if err != nil { + log.Warning("Error parsing rule string returned from dbus RequestPrompt: %v", err) + pp.policy.removePending(pp) + pp.pkt.Drop() + return + } + if scope == APPLY_SESSION { + r.sessionOnly = true + } + if !pp.policy.processNewRule(r, scope) { + p.lock.Lock() + defer p.lock.Unlock() + p.removePolicy(pp.policy) + } + if scope == APPLY_FOREVER { + pp.policy.fw.saveRules() + } +} + +func (p *prompter) nextPacket() *pendingPkt { + for { + if len(p.policyQueue) == 0 { + return nil + } + policy := p.policyQueue[0] + pp := policy.nextPending() + if pp == nil { + p.removePolicy(policy) + } else { + return pp + } + } +} + +func (p *prompter) removePolicy(policy *Policy) { + newQueue := make([]*Policy, 0, len(p.policyQueue)-1) + for _, pol := range p.policyQueue { + if pol != policy { + newQueue = append(newQueue, pol) + } + } + p.policyQueue = newQueue + delete(p.policyMap, policy.path) +} + +var userMap = make(map[int]string) + +func lookupUser(uid int) string { + u, err := user.LookupId(strconv.Itoa(uid)) + if err != nil { + return fmt.Sprintf("%d", uid) + } + return u.Name +} + +func uidToUser(uid int) string { + uname, ok := userMap[uid] + if ok { + return uname + } + uname = lookupUser(uid) + userMap[uid] = uname + return uname +} diff --git a/rules.go b/rules.go new file mode 100644 index 0000000..513bf35 --- /dev/null +++ b/rules.go @@ -0,0 +1,269 @@ +package main + +import ( + "encoding/binary" + "fmt" + "net" + "strings" + "unicode" + + "github.com/subgraph/fw-daemon/nfqueue" + "io/ioutil" + "os" + "path/filepath" + "strconv" +) + +const ( + RULE_DENY = iota + RULE_ALLOW +) + +const matchAny = 0 +const noAddress = uint32(0xffffffff) + +type Rule struct { + sessionOnly bool + rtype int + hostname string + addr uint32 + port uint16 +} + +func (r *Rule) String() string { + addr := "*" + port := "*" + rtype := "DENY" + + if r.hostname != "" { + addr = r.hostname + } else if r.addr != matchAny && r.addr != noAddress { + bs := make([]byte, 4) + binary.BigEndian.PutUint32(bs, r.addr) + addr = fmt.Sprintf("%d.%d.%d.%d", bs[0], bs[1], bs[2], bs[3]) + } + + if r.port != matchAny { + port = fmt.Sprintf("%d", r.port) + } + + if r.rtype == RULE_ALLOW { + rtype = "ALLOW" + } + + return fmt.Sprintf("%s %s:%s", rtype, addr, port) +} + +type RuleList []*Rule + +func (r *Rule) match(pkt *nfqueue.Packet, name string) bool { + if r.port != matchAny && r.port != pkt.DstPort { + return false + } + if r.addr == matchAny { + return true + } + if r.hostname != "" { + return r.hostname == name + } + return r.addr == binary.BigEndian.Uint32(pkt.Dst) +} + +type FilterResult int + +const ( + FILTER_DENY FilterResult = iota + FILTER_ALLOW + FILTER_PROMPT +) + +func (rl *RuleList) filter(p *nfqueue.Packet, proc *ProcInfo, hostname string) FilterResult { + if rl == nil { + return FILTER_PROMPT + } + result := FILTER_PROMPT + for _, r := range *rl { + if r.match(p, hostname) { + log.Info("%s (%s -> %s:%d)", r, proc.exePath, p.Dst.String(), p.DstPort) + if r.rtype == RULE_DENY { + return FILTER_DENY + } else if r.rtype == RULE_ALLOW { + result = FILTER_ALLOW + } + } + } + return result +} + +func parseError(s string) error { + return fmt.Errorf("unable to parse rule string: %s", s) +} + +func (r *Rule) parse(s string) bool { + r.addr = noAddress + parts := strings.Split(s, "|") + if len(parts) != 2 { + return false + } + return r.parseVerb(parts[0]) && r.parseTarget(parts[1]) +} + +func (r *Rule) parseVerb(v string) bool { + switch v { + case "ALLOW": + r.rtype = RULE_ALLOW + return true + case "DENY": + r.rtype = RULE_DENY + return true + } + return false +} + +func (r *Rule) parseTarget(t string) bool { + addrPort := strings.Split(t, ":") + if len(addrPort) != 2 { + return false + } + return r.parseAddr(addrPort[0]) && r.parsePort(addrPort[1]) +} + +func (r *Rule) parseAddr(a string) bool { + if a == "*" { + r.hostname = "" + r.addr = matchAny + return true + } + if strings.IndexFunc(a, unicode.IsLetter) != -1 { + r.hostname = a + return true + } + ip := net.ParseIP(a) + if ip == nil || len(ip) != 4 { + return false + } + r.addr = binary.BigEndian.Uint32(ip) + return true +} + +func (r *Rule) parsePort(p string) bool { + if p == "*" { + r.port = matchAny + return true + } + var err error + port, err := strconv.ParseUint(p, 10, 16) + if err != nil { + return false + } + r.port = uint16(port) + return true +} + +func parseRule(s string) (*Rule, error) { + r := new(Rule) + if !r.parse(s) { + return nil, parseError(s) + } + return r, nil +} + +const ruleFile = ".sgfw_rules" + +func rulesPath() string { + home := os.Getenv("HOME") + if home != "" { + return filepath.Join(home, ruleFile) + } + // XXX try something else? + return "" +} + +func (fw *Firewall) saveRules() { + fw.lock.Lock() + defer fw.lock.Unlock() + + f, err := os.Create(rulesPath()) + if err != nil { + log.Warning("Failed to open %s for writing: %v", rulesPath(), err) + return + } + defer f.Close() + + for _, p := range fw.policies { + savePolicy(f, p) + } +} + +func savePolicy(f *os.File, p *Policy) { + p.lock.Lock() + defer p.lock.Unlock() + if !p.hasPersistentRules() { + return + } + + if !writeLine(f, "["+p.path+"]") { + return + } + for _, r := range p.rules { + if !r.sessionOnly { + if !writeLine(f, r.String()) { + return + } + } + } +} + +func writeLine(f *os.File, line string) bool { + _, err := f.WriteString(line + "\n") + if err != nil { + log.Warning("Error writing to rule file: %v", err) + return false + } + return true +} + +func (fw *Firewall) loadRules() { + fw.lock.Lock() + defer fw.lock.Unlock() + + bs, err := ioutil.ReadFile(rulesPath()) + if err != nil { + if !os.IsNotExist(err) { + log.Warning("Failed to open %s for reading: %v", rulesPath(), err) + } + return + } + var policy *Policy + for _, line := range strings.Split(string(bs), "\n") { + if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { + policy = fw.processPathLine(line) + } else { + processRuleLine(policy, line) + } + } +} + +func (fw *Firewall) processPathLine(line string) *Policy { + path := line[1 : len(line)-1] + policy := fw.policyForPath(path) + policy.lock.Lock() + defer policy.lock.Unlock() + policy.rules = nil + return policy +} + +func processRuleLine(policy *Policy, line string) { + if policy == nil { + log.Warning("Cannot process rule line without first seeing path line: %s", line) + return + } + rule, err := parseRule(line) + if err != nil { + log.Warning("Error parsing rule (%s): %v", line, err) + return + } + policy.lock.Lock() + defer policy.lock.Unlock() + policy.rules = append(policy.rules, rule) +}