|
|
|
@ -11,37 +11,86 @@ import (
|
|
|
|
|
"crypto/rand"
|
|
|
|
|
"encoding/hex"
|
|
|
|
|
"fmt"
|
|
|
|
|
"io"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
const maxFdCount = 3
|
|
|
|
|
|
|
|
|
|
var log = logging.MustGetLogger("oz")
|
|
|
|
|
|
|
|
|
|
type MsgConn struct {
|
|
|
|
|
msgs chan *Message
|
|
|
|
|
addr *net.UnixAddr
|
|
|
|
|
log *logging.Logger
|
|
|
|
|
conn *net.UnixConn
|
|
|
|
|
buf [1024]byte
|
|
|
|
|
oob []byte
|
|
|
|
|
handlers handlerMap
|
|
|
|
|
disp *msgDispatcher
|
|
|
|
|
factory MsgFactory
|
|
|
|
|
isClosed bool
|
|
|
|
|
done chan bool
|
|
|
|
|
idGen <-chan int
|
|
|
|
|
respMan *responseManager
|
|
|
|
|
onClose func()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func NewMsgConn(factory MsgFactory, address string) *MsgConn {
|
|
|
|
|
mc := new(MsgConn)
|
|
|
|
|
mc.addr = &net.UnixAddr{address, "unixgram"}
|
|
|
|
|
mc.oob = createOobBuffer()
|
|
|
|
|
mc.msgs = make(chan *Message)
|
|
|
|
|
mc.handlers = make(map[string]reflect.Value)
|
|
|
|
|
mc.factory = factory
|
|
|
|
|
mc.done = make(chan bool)
|
|
|
|
|
mc.idGen = newIdGen(mc.done)
|
|
|
|
|
mc.respMan = newResponseManager()
|
|
|
|
|
return mc
|
|
|
|
|
func RunServer(address string, factory MsgFactory, log *logging.Logger, handlers ...interface{}) error {
|
|
|
|
|
md,err := createDispatcher(log, handlers...)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
addr := &net.UnixAddr{address, "unix"}
|
|
|
|
|
listener,err := net.ListenUnix("unix", addr)
|
|
|
|
|
if err != nil {
|
|
|
|
|
md.close()
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
done := make(chan bool)
|
|
|
|
|
idGen := newIdGen(done)
|
|
|
|
|
for {
|
|
|
|
|
conn,err := listener.AcceptUnix()
|
|
|
|
|
if err != nil {
|
|
|
|
|
close(md.msgs)
|
|
|
|
|
listener.Close()
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if err := setPassCred(conn); err != nil {
|
|
|
|
|
return errors.New("Failed to set SO_PASSCRED on accepted socket connection:"+ err.Error())
|
|
|
|
|
}
|
|
|
|
|
mc := &MsgConn{
|
|
|
|
|
conn: conn,
|
|
|
|
|
disp: md,
|
|
|
|
|
oob: createOobBuffer(),
|
|
|
|
|
factory: factory,
|
|
|
|
|
idGen: idGen,
|
|
|
|
|
respMan: newResponseManager(),
|
|
|
|
|
}
|
|
|
|
|
go mc.readLoop()
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Connect(address string, factory MsgFactory, log *logging.Logger, handlers ...interface{}) (*MsgConn, error) {
|
|
|
|
|
md,err := createDispatcher(log, handlers...)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
conn,err := net.DialUnix("unix", nil, &net.UnixAddr{address, "unix"})
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
done := make(chan bool)
|
|
|
|
|
idGen := newIdGen(done)
|
|
|
|
|
mc := &MsgConn{
|
|
|
|
|
conn: conn,
|
|
|
|
|
disp: md,
|
|
|
|
|
oob: createOobBuffer(),
|
|
|
|
|
factory: factory,
|
|
|
|
|
idGen: idGen,
|
|
|
|
|
respMan: newResponseManager(),
|
|
|
|
|
onClose: func() {
|
|
|
|
|
md.close()
|
|
|
|
|
close(done)
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
go mc.readLoop()
|
|
|
|
|
return mc, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func newIdGen(done <-chan bool) <-chan int {
|
|
|
|
@ -62,38 +111,6 @@ func idGenLoop(done <-chan bool, out chan <- int) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (mc *MsgConn) Listen() error {
|
|
|
|
|
if mc.conn != nil {
|
|
|
|
|
return errors.New("cannot Listen(), already connected")
|
|
|
|
|
}
|
|
|
|
|
conn, err := net.ListenUnixgram("unixgram", mc.addr)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if err := setPassCred(conn); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
mc.conn = conn
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (mc *MsgConn) Connect() error {
|
|
|
|
|
if mc.conn != nil {
|
|
|
|
|
return errors.New("cannot Connect(), already connected")
|
|
|
|
|
}
|
|
|
|
|
clientAddr,err := CreateRandomAddress("@oz-")
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
conn, err := net.DialUnix("unixgram", &net.UnixAddr{clientAddr, "unixgram"}, nil)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
mc.conn = conn
|
|
|
|
|
go mc.readLoop()
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func CreateRandomAddress(prefix string) (string,error) {
|
|
|
|
|
var bs [16]byte
|
|
|
|
|
n,err := rand.Read(bs[:])
|
|
|
|
@ -106,16 +123,6 @@ func CreateRandomAddress(prefix string) (string,error) {
|
|
|
|
|
return prefix+ hex.EncodeToString(bs[:]),nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (mc *MsgConn) Run() error {
|
|
|
|
|
go mc.readLoop()
|
|
|
|
|
for m := range mc.msgs {
|
|
|
|
|
if err := mc.handlers.dispatch(m); err != nil {
|
|
|
|
|
return fmt.Errorf("error dispatching message: %v", err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (mc *MsgConn) readLoop() {
|
|
|
|
|
for {
|
|
|
|
|
if mc.processOneMessage() {
|
|
|
|
@ -124,24 +131,36 @@ func (mc *MsgConn) readLoop() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (mc *MsgConn) logger() *logging.Logger {
|
|
|
|
|
if mc.log != nil {
|
|
|
|
|
return mc.log
|
|
|
|
|
}
|
|
|
|
|
return defaultLog
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (mc *MsgConn) processOneMessage() bool {
|
|
|
|
|
m,err := mc.readMessage()
|
|
|
|
|
if err != nil {
|
|
|
|
|
close(mc.msgs)
|
|
|
|
|
if err == io.EOF {
|
|
|
|
|
mc.Close()
|
|
|
|
|
return true
|
|
|
|
|
}
|
|
|
|
|
if !mc.isClosed {
|
|
|
|
|
log.Warning("error on MsgConn.readMessage(): %v", err)
|
|
|
|
|
mc.logger().Warning("error on MsgConn.readMessage(): %v", err)
|
|
|
|
|
}
|
|
|
|
|
return true
|
|
|
|
|
}
|
|
|
|
|
if !mc.respMan.handle(m) {
|
|
|
|
|
mc.msgs <- m
|
|
|
|
|
mc.disp.dispatch(m)
|
|
|
|
|
}
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (mc *MsgConn) Close() error {
|
|
|
|
|
mc.isClosed = true
|
|
|
|
|
close(mc.done)
|
|
|
|
|
if mc.onClose != nil {
|
|
|
|
|
mc.onClose()
|
|
|
|
|
}
|
|
|
|
|
return mc.conn.Close()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -151,7 +170,7 @@ func createOobBuffer() []byte {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (mc *MsgConn) readMessage() (*Message, error) {
|
|
|
|
|
n, oobn, _, a, err := mc.conn.ReadMsgUnix(mc.buf[:], mc.oob)
|
|
|
|
|
n, oobn, _, _, err := mc.conn.ReadMsgUnix(mc.buf[:], mc.oob)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
@ -160,7 +179,6 @@ func (mc *MsgConn) readMessage() (*Message, error) {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
m.mconn = mc
|
|
|
|
|
m.Peer = a
|
|
|
|
|
|
|
|
|
|
if oobn > 0 {
|
|
|
|
|
err := m.parseControlData(mc.oob[:oobn])
|
|
|
|
@ -198,7 +216,7 @@ func (mc *MsgConn) readMessage() (*Message, error) {
|
|
|
|
|
|
|
|
|
|
func (mc *MsgConn) AddHandlers(args ...interface{}) error {
|
|
|
|
|
for len(args) > 0 {
|
|
|
|
|
if err := mc.handlers.addHandler(args[0]); err != nil {
|
|
|
|
|
if err := mc.disp.hmap.addHandler(args[0]); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
args = args[1:]
|
|
|
|
@ -207,21 +225,21 @@ func (mc *MsgConn) AddHandlers(args ...interface{}) error {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (mc *MsgConn) SendMsg(msg interface{}, fds... int) error {
|
|
|
|
|
return mc.sendMessage(msg, <-mc.idGen, mc.addr, fds...)
|
|
|
|
|
return mc.sendMessage(msg, <-mc.idGen, fds...)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (mc *MsgConn) ExchangeMsg(msg interface{}, fds... int) (ResponseReader, error) {
|
|
|
|
|
id := <-mc.idGen
|
|
|
|
|
rr := mc.respMan.register(id)
|
|
|
|
|
|
|
|
|
|
if err := mc.sendMessage(msg, id, mc.addr, fds...); err != nil {
|
|
|
|
|
if err := mc.sendMessage(msg, id, fds...); err != nil {
|
|
|
|
|
rr.Done()
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
return rr,nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (mc *MsgConn) sendMessage(msg interface{}, msgID int, dst *net.UnixAddr, fds... int) error {
|
|
|
|
|
func (mc *MsgConn) sendMessage(msg interface{}, msgID int, fds... int) error {
|
|
|
|
|
msgType, err := getMessageType(msg)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
@ -234,7 +252,7 @@ func (mc *MsgConn) sendMessage(msg interface{}, msgID int, dst *net.UnixAddr, fd
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
return mc.sendRaw(raw, dst, fds...)
|
|
|
|
|
return mc.sendRaw(raw, fds...)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func getMessageType(msg interface{}) (string, error) {
|
|
|
|
@ -264,25 +282,17 @@ func (mc *MsgConn) newBaseMessage(msgType string, msgID int, body interface{}) (
|
|
|
|
|
return base, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (mc *MsgConn) sendRaw(data []byte, dst *net.UnixAddr, fds ...int) error {
|
|
|
|
|
func (mc *MsgConn) sendRaw(data []byte, fds ...int) error {
|
|
|
|
|
if len(fds) > 0 {
|
|
|
|
|
return mc.sendWithFds(data, dst, fds)
|
|
|
|
|
}
|
|
|
|
|
return mc.write(data, dst)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (mc *MsgConn) write(data []byte, dst *net.UnixAddr) error {
|
|
|
|
|
if dst != nil {
|
|
|
|
|
_,err := mc.conn.WriteToUnix(data, dst)
|
|
|
|
|
return err
|
|
|
|
|
return mc.sendWithFds(data, fds)
|
|
|
|
|
}
|
|
|
|
|
_,err := mc.conn.Write(data)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (mc *MsgConn) sendWithFds(data []byte, dst *net.UnixAddr, fds []int) error {
|
|
|
|
|
func (mc *MsgConn) sendWithFds(data []byte, fds []int) error {
|
|
|
|
|
oob := syscall.UnixRights(fds...)
|
|
|
|
|
_,_,err := mc.conn.WriteMsgUnix(data, oob, dst)
|
|
|
|
|
_,_,err := mc.conn.WriteMsgUnix(data, oob, nil)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|