Working (but not intensively tested) IPv6 support!

shw_dev
shw 7 years ago
parent 5f5042fed4
commit 8546f6c416

@ -65,7 +65,7 @@ func (dc *dnsCache) processDNS(pkt *nfqueue.NFQPacket) {
} }
q := dns.question[0] q := dns.question[0]
if q.Qtype == dnsTypeA { if q.Qtype == dnsTypeA {
srcip, _ := getPacketIP4Addrs(pkt) srcip, _ := getPacketIPAddrs(pkt)
pinfo := getEmptyPInfo() pinfo := getEmptyPInfo()
if !isNSTrusted(srcip) { if !isNSTrusted(srcip) {
pinfo, _ = findProcessForPacket(pkt, true, procsnitch.MATCH_LOOSEST) pinfo, _ = findProcessForPacket(pkt, true, procsnitch.MATCH_LOOSEST)
@ -145,6 +145,8 @@ func (dc *dnsCache) processRecordA(name string, answers []dnsRR, pid int) {
if !FirewallConfig.LogRedact { if !FirewallConfig.LogRedact {
log.Infof("Adding %s: %s", name, ip) log.Infof("Adding %s: %s", name, ip)
} }
case *dnsRR_AAAA:
log.Warning("AAAA record read from DNS; not supported.")
case *dnsRR_CNAME: case *dnsRR_CNAME:
// Not that exotic; just ignore it // Not that exotic; just ignore it
default: default:

@ -68,7 +68,10 @@ func loadDesktopFile(path string) {
inDE = false inDE = false
} }
if inDE && strings.HasPrefix(line, "Exec=") { if inDE && strings.HasPrefix(line, "Exec=") {
exec = strings.Fields(line[5:])[0] fields := strings.Fields(line[5:])
if len(fields) > 0 {
exec = fields[0]
}
} }
if inDE && strings.HasPrefix(line, "Icon=") { if inDE && strings.HasPrefix(line, "Icon=") {
icon = line[5:] icon = line[5:]

@ -7,7 +7,6 @@ import (
"bufio" "bufio"
"strings" "strings"
"strconv" "strconv"
"encoding/binary"
) )
const ReceiverSocketPath = "/tmp/fwoz.sock" const ReceiverSocketPath = "/tmp/fwoz.sock"
@ -120,8 +119,6 @@ func ReceiverLoop(fw *Firewall, c net.Conn) {
c.Write([]byte(banner)) c.Write([]byte(banner))
for r := 0; r < len(rl); r++ { for r := 0; r < len(rl); r++ {
ip := make([]byte, 4)
binary.BigEndian.PutUint32(ip, rl[r].addr)
hostname := "" hostname := ""
if rl[r].hostname != "" { if rl[r].hostname != "" {
@ -134,7 +131,7 @@ func ReceiverLoop(fw *Firewall, c net.Conn) {
portstr = "*" portstr = "*"
} }
ruledesc := fmt.Sprintf("id %v, %v | %v, src:%v -> %v%v: %v\n", rl[r].id, RuleModeString[rl[r].mode], RuleActionString[rl[r].rtype], rl[r].saddr, net.IP(ip), hostname, portstr) ruledesc := fmt.Sprintf("id %v, %v | %v, src:%v -> %v%v: %v\n", rl[r].id, RuleModeString[rl[r].mode], RuleActionString[rl[r].rtype], rl[r].saddr, rl[r].addr, hostname, portstr)
c.Write([]byte(ruledesc)) c.Write([]byte(ruledesc))
} }

@ -93,21 +93,13 @@ func (pp *pendingPkt) hostname() string {
} }
func (pp *pendingPkt) src() net.IP { func (pp *pendingPkt) src() net.IP {
src, _ := getPacketIP4Addrs(pp.pkt) src, _ := getPacketIPAddrs(pp.pkt)
return src return src
} }
func (pp *pendingPkt) dst() net.IP { func (pp *pendingPkt) dst() net.IP {
_, dst := getPacketIP4Addrs(pp.pkt) _, dst := getPacketIPAddrs(pp.pkt)
return dst return dst
/* dst := pp.pkt.Packet.NetworkLayer().NetworkFlow().Dst()
if dst.EndpointType() != layers.EndpointIPv4 {
return nil
}
return dst.Raw() */
// pp.pkt.NetworkLayer().Layer
} }
func getNFQProto(pkt *nfqueue.NFQPacket) string { func getNFQProto(pkt *nfqueue.NFQPacket) string {
@ -368,7 +360,7 @@ func (p *Policy) hasPersistentRules() bool {
func printPacket(pkt *nfqueue.NFQPacket, hostname string, pinfo *procsnitch.Info) string { func printPacket(pkt *nfqueue.NFQPacket, hostname string, pinfo *procsnitch.Info) string {
proto := "???" proto := "???"
SrcPort, DstPort := uint16(0), uint16(0) SrcPort, DstPort := uint16(0), uint16(0)
SrcIp, DstIp := getPacketIP4Addrs(pkt) SrcIp, DstIp := getPacketIPAddrs(pkt)
code := 0 code := 0
codestr := "" codestr := ""
@ -418,7 +410,7 @@ func (fw *Firewall) filterPacket(pkt *nfqueue.NFQPacket) {
} }
} }
_, dstip := getPacketIP4Addrs(pkt) _, dstip := getPacketIPAddrs(pkt)
/* _, dstp := getPacketPorts(pkt) /* _, dstp := getPacketPorts(pkt)
fwo := matchAgainstOzRules(srcip, dstip, dstp) fwo := matchAgainstOzRules(srcip, dstip, dstp)
log.Notice("XXX: Attempting [2] to filter packet on rules -> ", fwo) log.Notice("XXX: Attempting [2] to filter packet on rules -> ", fwo)
@ -558,14 +550,14 @@ func getRealRoot(pathname string, pid int) string {
} }
func findProcessForPacket(pkt *nfqueue.NFQPacket, reverse bool, strictness int) (*procsnitch.Info, string) { func findProcessForPacket(pkt *nfqueue.NFQPacket, reverse bool, strictness int) (*procsnitch.Info, string) {
srcip, dstip := getPacketIP4Addrs(pkt) srcip, dstip := getPacketIPAddrs(pkt)
srcp, dstp := getPacketPorts(pkt) srcp, dstp := getPacketPorts(pkt)
proto := "" proto := ""
optstr := "" optstr := ""
icode := -1 icode := -1
if reverse { if reverse {
dstip, srcip = getPacketIP4Addrs(pkt) dstip, srcip = getPacketIPAddrs(pkt)
dstp, srcp = getPacketPorts(pkt) dstp, srcp = getPacketPorts(pkt)
} }
@ -576,6 +568,9 @@ func findProcessForPacket(pkt *nfqueue.NFQPacket, reverse bool, strictness int)
} else if pkt.Packet.Layer(layers.LayerTypeICMPv4) != nil { } else if pkt.Packet.Layer(layers.LayerTypeICMPv4) != nil {
proto = "icmp" proto = "icmp"
icode, _ = getpacketICMPCode(pkt) icode, _ = getpacketICMPCode(pkt)
} else if pkt.Packet.Layer(layers.LayerTypeICMPv6) != nil {
proto = "icmp"
icode, _ = getpacketICMPCode(pkt)
} }
if proto == "" { if proto == "" {
@ -656,7 +651,7 @@ func findProcessForPacket(pkt *nfqueue.NFQPacket, reverse bool, strictness int)
} }
func basicAllowPacket(pkt *nfqueue.NFQPacket) bool { func basicAllowPacket(pkt *nfqueue.NFQPacket) bool {
srcip, dstip := getPacketIP4Addrs(pkt) srcip, dstip := getPacketIPAddrs(pkt)
if pkt.Packet.Layer(layers.LayerTypeUDP) != nil { if pkt.Packet.Layer(layers.LayerTypeUDP) != nil {
_, dport := getPacketUDPPorts(pkt) _, dport := getPacketUDPPorts(pkt)
if dport == 53 { if dport == 53 {
@ -674,15 +669,29 @@ func basicAllowPacket(pkt *nfqueue.NFQPacket) bool {
pkt.Packet.Layer(layers.LayerTypeICMPv4) == nil) pkt.Packet.Layer(layers.LayerTypeICMPv4) == nil)
} }
func getPacketIP4Addrs(pkt *nfqueue.NFQPacket) (net.IP, net.IP) { func getPacketIPAddrs(pkt *nfqueue.NFQPacket) (net.IP, net.IP) {
ipv4 := true
ipLayer := pkt.Packet.Layer(layers.LayerTypeIPv4) ipLayer := pkt.Packet.Layer(layers.LayerTypeIPv4)
if ipLayer == nil { if ipLayer == nil {
return net.IP{0,0,0,0}, net.IP{0,0,0,0} ipv4 = false
ipLayer = pkt.Packet.Layer(layers.LayerTypeIPv6)
}
if ipLayer == nil {
if ipv4 {
return net.IP{0,0,0,0}, net.IP{0,0,0,0}
}
return net.IP{}, net.IP{}
}
if !ipv4 {
ip6, _ := ipLayer.(*layers.IPv6)
return ip6.SrcIP, ip6.DstIP
} }
ip, _ := ipLayer.(*layers.IPv4) ip4, _ := ipLayer.(*layers.IPv4)
return ip.SrcIP, ip.DstIP return ip4.SrcIP, ip4.DstIP
} }
func getpacketICMPCode(pkt *nfqueue.NFQPacket) (int, string) { func getpacketICMPCode(pkt *nfqueue.NFQPacket) (int, string) {

@ -9,7 +9,6 @@ import (
"path" "path"
"strconv" "strconv"
"strings" "strings"
"unicode"
"regexp" "regexp"
nfqueue "github.com/subgraph/go-nfnetlink/nfqueue" nfqueue "github.com/subgraph/go-nfnetlink/nfqueue"
@ -18,7 +17,9 @@ import (
) )
const matchAny = 0 const matchAny = 0
const noAddress = uint32(0xffffffff) //const noAddress = uint32(0xffffffff)
var anyAddress net.IP = net.IP{0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,}
var noAddress net.IP = net.IP{0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff}
type Rule struct { type Rule struct {
id uint id uint
@ -28,7 +29,7 @@ type Rule struct {
proto string proto string
hostname string hostname string
network *net.IPNet network *net.IPNet
addr uint32 addr net.IP
saddr net.IP saddr net.IP
port uint16 port uint16
uid int uid int
@ -66,10 +67,11 @@ func (r *Rule) AddrString(redact bool) string {
addr = r.hostname addr = r.hostname
} else if r.network != nil { } else if r.network != nil {
addr = r.network.String() addr = r.network.String()
} else if r.addr != matchAny && r.addr != noAddress { } else if !addrMatchesAny(r.addr) && !addrMatchesNone(r.addr) {
bs := make([]byte, 4) // bs := make([]byte, 4)
binary.BigEndian.PutUint32(bs, r.addr) // binary.BigEndian.PutUint32(bs, r.addr)
addr = fmt.Sprintf("%d.%d.%d.%d", bs[0], bs[1], bs[2], bs[3]) // addr = fmt.Sprintf("%d.%d.%d.%d", bs[0], bs[1], bs[2], bs[3])
addr = r.addr.String()
} }
if r.port != matchAny || r.proto == "icmp" { if r.port != matchAny || r.proto == "icmp" {
@ -99,13 +101,11 @@ func (r *Rule) match(src net.IP, dst net.IP, dstPort uint16, hostname string, pr
return false return false
} }
xip := make(net.IP, 4) log.Notice("comparison: ", hostname, " / ", dst, " : ", dstPort, " -> ", r.addr, " / ", r.hostname, " : ", r.port)
binary.BigEndian.PutUint32(xip, r.addr)
log.Notice("comparison: ", hostname, " / ", dst, " : ", dstPort, " -> ", xip, " / ", r.hostname, " : ", r.port)
if r.port != matchAny && r.port != dstPort { if r.port != matchAny && r.port != dstPort {
return false return false
} }
if r.addr == matchAny { if addrMatchesAny(r.addr) {
return true return true
} }
if r.hostname != "" { if r.hostname != "" {
@ -126,15 +126,15 @@ log.Notice("comparison: ", hostname, " / ", dst, " : ", dstPort, " -> ", xip, "
} }
if proto == "icmp" { if proto == "icmp" {
fmt.Printf("network = %v, src = %v, r.addr = %x, src to4 = %x\n", r.network, src, r.addr, binary.BigEndian.Uint32(src.To4())) fmt.Printf("network = %v, src = %v, r.addr = %x, src to4 = %x\n", r.network, src, r.addr, binary.BigEndian.Uint32(src.To4()))
if (r.network != nil && r.network.Contains(src)) || (r.addr == binary.BigEndian.Uint32(src.To4())) { if (r.network != nil && r.network.Contains(src)) || (r.addr.Equal(src)) {
return true return true
} }
} }
return r.addr == binary.BigEndian.Uint32(dst.To4()) return r.addr.Equal(dst)
} }
func (rl *RuleList) filterPacket(p *nfqueue.NFQPacket, pinfo *procsnitch.Info, srcip net.IP, hostname, optstr string) FilterResult { func (rl *RuleList) filterPacket(p *nfqueue.NFQPacket, pinfo *procsnitch.Info, srcip net.IP, hostname, optstr string) FilterResult {
_, dstip := getPacketIP4Addrs(p) _, dstip := getPacketIPAddrs(p)
_, dstp := getPacketPorts(p) _, dstp := getPacketPorts(p)
return rl.filter(p, srcip, dstip, dstp, hostname, pinfo, optstr) return rl.filter(p, srcip, dstip, dstp, hostname, pinfo, optstr)
} }
@ -162,7 +162,7 @@ log.Notice("+ MATCH SUCCEEDED")
} }
srcStr := STR_UNKNOWN srcStr := STR_UNKNOWN
if pkt != nil { if pkt != nil {
srcip, _ := getPacketIP4Addrs(pkt) srcip, _ := getPacketIPAddrs(pkt)
srcp, _ := getPacketPorts(pkt) srcp, _ := getPacketPorts(pkt)
srcStr = fmt.Sprintf("%s:%d", srcip, srcp) srcStr = fmt.Sprintf("%s:%d", srcip, srcp)
} }
@ -272,30 +272,31 @@ func (r *Rule) parseVerb(v string) bool {
func (r *Rule) parseTarget(t string) bool { func (r *Rule) parseTarget(t string) bool {
addrPort := strings.Split(t, ":") addrPort := strings.Split(t, ":")
if len(addrPort) != 2 && len(addrPort) != 3 { if len(addrPort) < 2 {
return false return false
} }
sind := 0 sind := 0
if len(addrPort) == 3 { lind := len(addrPort)-1
if addrPort[0] != "udp" && addrPort[0] != "icmp" && addrPort[0] != "tcp" { if addrPort[0] == "udp" || addrPort[0] == "icmp" || addrPort[0] == "tcp" {
return false
}
r.proto = addrPort[0] r.proto = addrPort[0]
sind++ sind++
} else { } else {
r.proto = "tcp" r.proto = "tcp"
} }
return r.parseAddr(addrPort[sind]) && r.parsePort(addrPort[sind+1]) newAddr := strings.Join(addrPort[sind:lind], ":")
return r.parseAddr(newAddr) && r.parsePort(addrPort[lind])
// return r.parseAddr(addrPort[sind]) && r.parsePort(addrPort[sind+1])
} }
func (r *Rule) parseAddr(a string) bool { func (r *Rule) parseAddr(a string) bool {
if a == "*" { if a == "*" {
r.hostname = "" r.hostname = ""
r.addr = matchAny r.addr = anyAddress
return true return true
} }
if strings.IndexFunc(a, unicode.IsLetter) != -1 { // if strings.IndexFunc(a, unicode.IsLetter) != -1 {
if net.ParseIP(a) == nil {
r.hostname = a r.hostname = a
return true return true
} }
@ -310,7 +311,7 @@ func (r *Rule) parseAddr(a string) bool {
} else { } else {
r.network = ipnet r.network = ipnet
} }
r.addr = binary.BigEndian.Uint32(ip.To4()) r.addr = ip
return true return true
} }
@ -445,3 +446,23 @@ func processRuleLine(policy *Policy, line string) {
return return
} }
} }
func addrMatchesAny(addr net.IP) bool {
any := anyAddress
if addr.To4() != nil {
any = net.IP{0,0,0,0}
}
return any.Equal(addr)
}
func addrMatchesNone(addr net.IP) bool {
none := noAddress
if addr.To4() != nil {
none = net.IP{0xff,0xff,0xff,0xff}
}
return none.Equal(addr)
}

@ -16,6 +16,8 @@ import (
nfqueue "github.com/subgraph/go-nfnetlink/nfqueue" nfqueue "github.com/subgraph/go-nfnetlink/nfqueue"
// "github.com/subgraph/go-nfnetlink" // "github.com/subgraph/go-nfnetlink"
"github.com/subgraph/go-procsnitch" "github.com/subgraph/go-procsnitch"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
) )
var dbusp *dbusObjectP = nil var dbusp *dbusObjectP = nil
@ -108,6 +110,22 @@ func (fw *Firewall) runFilter() {
go func() { go func() {
for p := range ps { for p := range ps {
if fw.isEnabled() { if fw.isEnabled() {
ipLayer := p.Packet.Layer(layers.LayerTypeIPv4)
if ipLayer == nil {
continue
}
ip, _ := ipLayer.(*layers.IPv4)
if ip == nil {
continue
}
if ip.Version == 6 {
ip6p := gopacket.NewPacket(ip.LayerContents(), layers.LayerTypeIPv6, gopacket.Default)
p.Packet = ip6p
}
fw.filterPacket(p) fw.filterPacket(p)
} else { } else {
p.Accept() p.Accept()

@ -2,6 +2,7 @@ package procsnitch
import ( import (
"encoding/hex" "encoding/hex"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"github.com/op/go-logging" "github.com/op/go-logging"
@ -9,9 +10,11 @@ import (
"net" "net"
"strconv" "strconv"
"strings" "strings"
"unsafe"
) )
var log = logging.MustGetLogger("go-procsockets") var log = logging.MustGetLogger("go-procsockets")
var isLittleEndian = -1
// SetLogger allows setting a custom go-logging instance // SetLogger allows setting a custom go-logging instance
func SetLogger(logger *logging.Logger) { func SetLogger(logger *logging.Logger) {
@ -157,11 +160,32 @@ func ParseIP(ip string) (net.IP, error) {
} }
// Reverse byte order -- /proc/net/tcp etc. is little-endian // Reverse byte order -- /proc/net/tcp etc. is little-endian
// TODO: Does this vary by architecture? // TODO: Does this vary by architecture?
for i, j := 0, len(dst)-1; i < j; i, j = i+1, j-1 { if isLittleEndian == -1 {
dst[i], dst[j] = dst[j], dst[i] setEndian()
} }
result = net.IP(dst)
return result, nil if len(dst) != 4 && len(dst) != 16 {
return result, errors.New("Unsupported address type (not IPv4 or IPv16)")
}
if isLittleEndian > 0 {
for i := 0; i < len(dst) / 4; i++ {
start, end := i*4, (i+1)*4
word := dst[start:end]
lval := binary.LittleEndian.Uint32(word)
binary.BigEndian.PutUint32(dst[start:], lval)
}
}
/* if len(dst) == 16 {
dst2 := []byte{dst[3], dst[2], dst[1], dst[0], dst[7], dst[6], dst[5], dst[4], dst[11], dst[10], dst[9], dst[8], dst[15], dst[14], dst[13], dst[12]}
return net.IP(dst2), nil
}
for i, j := 0, len(dst)-1; i < j; i, j = i+1, j-1 {
dst[i], dst[j] = dst[j], dst[i]
} */
return net.IP(dst), nil
} }
// ParsePort parses a base16 port represented as a string to a uint16 // ParsePort parses a base16 port represented as a string to a uint16
@ -285,3 +309,15 @@ func stripLabel(s string) string {
} }
return s[idx+1:] return s[idx+1:]
} }
// stolen from github.com/virtao/GoEndian
const INT_SIZE int = int(unsafe.Sizeof(0))
func setEndian() {
var i int = 0x1
bs := (*[INT_SIZE]byte)(unsafe.Pointer(&i))
if bs[0] == 0 {
isLittleEndian = 0
} else {
isLittleEndian = 1
}
}

@ -85,52 +85,61 @@ func (ss *socketStatus) String() string {
} }
func findICMPSocketAll(srcAddr net.IP, dstAddr net.IP, code int, custdata []string) *socketStatus { func findICMPSocketAll(srcAddr net.IP, dstAddr net.IP, code int, custdata []string) *socketStatus {
proto := "icmp"
if srcAddr.To4() == nil {
proto += "6"
}
if custdata == nil { if custdata == nil {
return findSocket("icmp", func(ss socketStatus) bool { return findSocket(proto, func(ss socketStatus) bool {
return ss.remote.ip.Equal(dstAddr) && ss.local.ip.Equal(srcAddr) return ss.remote.ip.Equal(dstAddr) && ss.local.ip.Equal(srcAddr)
}) })
} }
return findSocketCustom("icmp", custdata, func(ss socketStatus) bool { return findSocketCustom(proto, custdata, func(ss socketStatus) bool {
return ss.remote.ip.Equal(dstAddr) && ss.local.ip.Equal(srcAddr) return ss.remote.ip.Equal(dstAddr) && ss.local.ip.Equal(srcAddr)
}) })
} }
func findUDPSocketAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort uint16, custdata []string, strictness int) *socketStatus { func findUDPSocketAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort uint16, custdata []string, strictness int) *socketStatus {
wildcard := net.IP{0,0,0,0} proto := "udp"
if srcAddr.To4() == nil {
proto += "6"
}
if custdata == nil { if custdata == nil {
if strictness == MATCH_STRICT { if strictness == MATCH_STRICT {
return findSocket("udp", func(ss socketStatus) bool { return findSocket(proto, func(ss socketStatus) bool {
return ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) return ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr)
}) })
} else if strictness == MATCH_LOOSE { } else if strictness == MATCH_LOOSE {
return findSocket("udp", func(ss socketStatus) bool { return findSocket(proto, func(ss socketStatus) bool {
return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(wildcard)) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) || return (ss.remote.ip.Equal(dstAddr) || addrMatchesAny(ss.remote.ip)) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) ||
(ss.local.ip.Equal(dstAddr) || ss.local.ip.Equal(wildcard)) && ss.remote.port == srcPort && ss.remote.ip.Equal(srcAddr) (ss.local.ip.Equal(dstAddr) || addrMatchesAny(ss.local.ip)) && ss.remote.port == srcPort && ss.remote.ip.Equal(srcAddr)
}) })
} }
return findSocket("udp", func(ss socketStatus) bool { return findSocket(proto, func(ss socketStatus) bool {
return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(wildcard)) && ss.local.port == srcPort && (ss.local.ip.Equal(srcAddr) || ss.local.ip.Equal(wildcard)) || return (ss.remote.ip.Equal(dstAddr) || addrMatchesAny(ss.remote.ip)) && ss.local.port == srcPort && (ss.local.ip.Equal(srcAddr) || addrMatchesAny(ss.local.ip)) ||
(ss.local.ip.Equal(dstAddr) || ss.local.ip.Equal(wildcard)) && ss.remote.port == srcPort && (ss.remote.ip.Equal(srcAddr) || ss.remote.ip.Equal(srcAddr)) (ss.local.ip.Equal(dstAddr) || addrMatchesAny(ss.local.ip)) && ss.remote.port == srcPort && (ss.remote.ip.Equal(srcAddr) || ss.remote.ip.Equal(srcAddr))
}) })
} }
if strictness == MATCH_STRICT { if strictness == MATCH_STRICT {
return findSocketCustom("udp", custdata, func(ss socketStatus) bool { return findSocketCustom(proto, custdata, func(ss socketStatus) bool {
return ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) return ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr)
}) })
} else if strictness == MATCH_LOOSE { } else if strictness == MATCH_LOOSE {
return findSocketCustom("udp", custdata, func(ss socketStatus) bool { return findSocketCustom(proto, custdata, func(ss socketStatus) bool {
return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(wildcard)) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) || return (ss.remote.ip.Equal(dstAddr) || addrMatchesAny(ss.remote.ip)) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) ||
(ss.local.ip.Equal(dstAddr) || ss.local.ip.Equal(wildcard)) && ss.remote.port == srcPort && ss.remote.ip.Equal(srcAddr) (ss.local.ip.Equal(dstAddr) || addrMatchesAny(ss.local.ip)) && ss.remote.port == srcPort && ss.remote.ip.Equal(srcAddr)
}) })
} }
return findSocketCustom("udp", custdata, func(ss socketStatus) bool { return findSocketCustom(proto, custdata, func(ss socketStatus) bool {
return (ss.remote.ip.Equal(dstAddr) || ss.remote.ip.Equal(wildcard)) && ss.local.port == srcPort && (ss.local.ip.Equal(srcAddr) || ss.local.ip.Equal(wildcard)) || return (ss.remote.ip.Equal(dstAddr) || addrMatchesAny(ss.remote.ip)) && ss.local.port == srcPort && (ss.local.ip.Equal(srcAddr) || addrMatchesAny(ss.local.ip)) ||
(ss.local.ip.Equal(dstAddr) || ss.local.ip.Equal(wildcard)) && ss.remote.port == srcPort && (ss.remote.ip.Equal(srcAddr) || ss.remote.ip.Equal(srcAddr)) (ss.local.ip.Equal(dstAddr) || addrMatchesAny(ss.local.ip)) && ss.remote.port == srcPort && (ss.remote.ip.Equal(srcAddr) || ss.remote.ip.Equal(srcAddr))
}) })
} }
@ -141,19 +150,29 @@ func findUDPSocket(srcPort uint16) *socketStatus {
} }
func findTCPSocket(srcPort uint16, dstAddr net.IP, dstPort uint16) *socketStatus { func findTCPSocket(srcPort uint16, dstAddr net.IP, dstPort uint16) *socketStatus {
return findSocket("tcp", func(ss socketStatus) bool { proto := "tcp"
if dstAddr.To4() == nil {
proto += "6"
}
return findSocket(proto, func(ss socketStatus) bool {
return ss.remote.port == dstPort && ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort return ss.remote.port == dstPort && ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort
}) })
} }
func findTCPSocketAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort uint16, custdata []string) *socketStatus { func findTCPSocketAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort uint16, custdata []string) *socketStatus {
proto := "tcp"
if srcAddr.To4() == nil {
proto += "6"
}
if custdata == nil { if custdata == nil {
return findSocket("tcp", func(ss socketStatus) bool { return findSocket(proto, func(ss socketStatus) bool {
return ss.remote.port == dstPort && ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) return ss.remote.port == dstPort && ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr)
}) })
} }
return findSocketCustom("tcp", custdata, func(ss socketStatus) bool { return findSocketCustom(proto, custdata, func(ss socketStatus) bool {
return ss.remote.port == dstPort && ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) return ss.remote.port == dstPort && ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr)
}) })
} }
@ -298,3 +317,13 @@ func getSocketLines(proto string) []string {
} }
return lines return lines
} }
func addrMatchesAny(addr net.IP) bool {
wildcard := net.IP{0,0,0,0}
if addr.To4() == nil {
wildcard = net.IP{0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}
}
return wildcard.Equal(addr)
}

Loading…
Cancel
Save