diff --git a/vendor/github.com/subgraph/go-procsnitch/proc_test.go b/vendor/github.com/subgraph/go-procsnitch/proc_test.go new file mode 100644 index 0000000..a23dcf8 --- /dev/null +++ b/vendor/github.com/subgraph/go-procsnitch/proc_test.go @@ -0,0 +1,78 @@ +package procsnitch + +import ( + "fmt" + "net" + "os" + "sync" + "testing" + "time" +) + +type TestListener struct { + network string + address string + waitGroup *sync.WaitGroup +} + +func NewTestListener(network, address string, wg *sync.WaitGroup) *TestListener { + l := TestListener{ + network: network, + address: address, + waitGroup: wg, + } + return &l +} + +func (l *TestListener) AcceptLoop() { + l.waitGroup.Add(1) + listener, err := net.Listen(l.network, l.address) + if err != nil { + panic(err) + } + defer listener.Close() + + l.waitGroup.Done() + + for { + conn, err := listener.Accept() + if err != nil { + panic(err) + } + + go l.SessionWorker(conn) + } +} + +func (l *TestListener) SessionWorker(conn net.Conn) { + for { + time.Sleep(time.Second * 60) + } +} + +func TestLookupUNIXSocketProcess(t *testing.T) { + // listen for a connection + var wg sync.WaitGroup + network := "unix" + address := "./testing_socket" + l := NewTestListener(network, address, &wg) + go l.AcceptLoop() + wg.Wait() + + // XXX fix me + time.Sleep(time.Second * 1) + + // dial a connection + conn, err := net.Dial(network, address) + if err != nil { + panic(err) + } + defer os.Remove(address) + conn.Write([]byte("hello")) + procInfo := LookupUNIXSocketProcess(address) + if procInfo == nil { + t.Error("failured to acquire proc info for unix domain socket") + t.Fail() + } + fmt.Println("Acquired proc info for UNIX domain socket!", procInfo) +} diff --git a/vendor/github.com/subgraph/go-procsnitch/socket.go b/vendor/github.com/subgraph/go-procsnitch/socket.go index df667e0..448eba2 100644 --- a/vendor/github.com/subgraph/go-procsnitch/socket.go +++ b/vendor/github.com/subgraph/go-procsnitch/socket.go @@ -111,15 +111,12 @@ func findUDPSocketAll(srcAddr net.IP, srcPort uint16, dstAddr net.IP, dstPort ui if custdata == nil { if strictness == MATCH_STRICT { return findSocket(proto, func(ss socketStatus) bool { - //return ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) - return ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) + return ss.remote.ip.Equal(dstAddr) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) }) } else if strictness == MATCH_LOOSE { return findSocket(proto, func(ss socketStatus) bool { - return ss.local.port == srcPort && (ss.local.ip.Equal(srcAddr) || ss.local.ip.Equal(net.IPv4(0,0,0,0))) - /* return (ss.remote.ip.Equal(dstAddr) || addrMatchesAny(ss.remote.ip)) && ss.local.port == srcPort && ss.local.ip.Equal(srcAddr) || - (ss.local.ip.Equal(dstAddr) || addrMatchesAny(ss.local.ip)) && ss.remote.port == srcPort && ss.remote.ip.Equal(srcAddr) */ + (ss.local.ip.Equal(dstAddr) || addrMatchesAny(ss.local.ip)) && ss.remote.port == srcPort && ss.remote.ip.Equal(srcAddr) }) } return findSocket(proto, func(ss socketStatus) bool {