diff --git a/ipc/handlers.go b/ipc/handlers.go index e317300..27a9443 100644 --- a/ipc/handlers.go +++ b/ipc/handlers.go @@ -3,10 +3,57 @@ import ( "reflect" "errors" "fmt" + "github.com/op/go-logging" ) type handlerMap map[string]reflect.Value +var defaultLog = logging.MustGetLogger("ipc") + +type msgDispatcher struct { + log *logging.Logger + msgs chan *Message + hmap handlerMap +} + +func createDispatcher(log *logging.Logger, handlers...interface{}) (*msgDispatcher, error) { + md := &msgDispatcher{ + log: log, + msgs: make(chan *Message), + hmap: make(map[string]reflect.Value), + } + for _,h := range handlers { + if err := md.hmap.addHandler(h); err != nil { + return nil, err + } + } + go md.runDispatcher() + return md, nil +} + +func (md *msgDispatcher) close() { + close(md.msgs) +} + +func (md *msgDispatcher) dispatch(m *Message) { + md.msgs <- m +} + +func (md *msgDispatcher) logger() *logging.Logger { + if md.log != nil { + return md.log + } + return defaultLog +} + +func (md *msgDispatcher) runDispatcher() { + for m := range md.msgs { + if err := md.hmap.dispatch(m); err != nil { + md.logger().Warning("error dispatching message: %v", err) + } + } +} + func (handlers handlerMap) dispatch(m *Message) error { h,ok := handlers[m.Type] if !ok { diff --git a/ipc/ipc.go b/ipc/ipc.go index 863ca32..d904b38 100644 --- a/ipc/ipc.go +++ b/ipc/ipc.go @@ -11,37 +11,86 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "io" ) const maxFdCount = 3 -var log = logging.MustGetLogger("oz") - type MsgConn struct { - msgs chan *Message - addr *net.UnixAddr + log *logging.Logger conn *net.UnixConn buf [1024]byte oob []byte - handlers handlerMap + disp *msgDispatcher factory MsgFactory isClosed bool - done chan bool idGen <-chan int respMan *responseManager + onClose func() } -func NewMsgConn(factory MsgFactory, address string) *MsgConn { - mc := new(MsgConn) - mc.addr = &net.UnixAddr{address, "unixgram"} - mc.oob = createOobBuffer() - mc.msgs = make(chan *Message) - mc.handlers = make(map[string]reflect.Value) - mc.factory = factory - mc.done = make(chan bool) - mc.idGen = newIdGen(mc.done) - mc.respMan = newResponseManager() - return mc +func RunServer(address string, factory MsgFactory, log *logging.Logger, handlers ...interface{}) error { + md,err := createDispatcher(log, handlers...) + if err != nil { + return err + } + + addr := &net.UnixAddr{address, "unix"} + listener,err := net.ListenUnix("unix", addr) + if err != nil { + md.close() + return err + } + done := make(chan bool) + idGen := newIdGen(done) + for { + conn,err := listener.AcceptUnix() + if err != nil { + close(md.msgs) + listener.Close() + return err + } + if err := setPassCred(conn); err != nil { + return errors.New("Failed to set SO_PASSCRED on accepted socket connection:"+ err.Error()) + } + mc := &MsgConn{ + conn: conn, + disp: md, + oob: createOobBuffer(), + factory: factory, + idGen: idGen, + respMan: newResponseManager(), + } + go mc.readLoop() + } + return nil +} + +func Connect(address string, factory MsgFactory, log *logging.Logger, handlers ...interface{}) (*MsgConn, error) { + md,err := createDispatcher(log, handlers...) + if err != nil { + return nil, err + } + conn,err := net.DialUnix("unix", nil, &net.UnixAddr{address, "unix"}) + if err != nil { + return nil, err + } + done := make(chan bool) + idGen := newIdGen(done) + mc := &MsgConn{ + conn: conn, + disp: md, + oob: createOobBuffer(), + factory: factory, + idGen: idGen, + respMan: newResponseManager(), + onClose: func() { + md.close() + close(done) + }, + } + go mc.readLoop() + return mc, nil } func newIdGen(done <-chan bool) <-chan int { @@ -62,38 +111,6 @@ func idGenLoop(done <-chan bool, out chan <- int) { } } -func (mc *MsgConn) Listen() error { - if mc.conn != nil { - return errors.New("cannot Listen(), already connected") - } - conn, err := net.ListenUnixgram("unixgram", mc.addr) - if err != nil { - return err - } - if err := setPassCred(conn); err != nil { - return err - } - mc.conn = conn - return nil -} - -func (mc *MsgConn) Connect() error { - if mc.conn != nil { - return errors.New("cannot Connect(), already connected") - } - clientAddr,err := CreateRandomAddress("@oz-") - if err != nil { - return err - } - conn, err := net.DialUnix("unixgram", &net.UnixAddr{clientAddr, "unixgram"}, nil) - if err != nil { - return err - } - mc.conn = conn - go mc.readLoop() - return nil -} - func CreateRandomAddress(prefix string) (string,error) { var bs [16]byte n,err := rand.Read(bs[:]) @@ -106,16 +123,6 @@ func CreateRandomAddress(prefix string) (string,error) { return prefix+ hex.EncodeToString(bs[:]),nil } -func (mc *MsgConn) Run() error { - go mc.readLoop() - for m := range mc.msgs { - if err := mc.handlers.dispatch(m); err != nil { - return fmt.Errorf("error dispatching message: %v", err) - } - } - return nil -} - func (mc *MsgConn) readLoop() { for { if mc.processOneMessage() { @@ -124,24 +131,36 @@ func (mc *MsgConn) readLoop() { } } +func (mc *MsgConn) logger() *logging.Logger { + if mc.log != nil { + return mc.log + } + return defaultLog +} + func (mc *MsgConn) processOneMessage() bool { m,err := mc.readMessage() if err != nil { - close(mc.msgs) + if err == io.EOF { + mc.Close() + return true + } if !mc.isClosed { - log.Warning("error on MsgConn.readMessage(): %v", err) + mc.logger().Warning("error on MsgConn.readMessage(): %v", err) } return true } if !mc.respMan.handle(m) { - mc.msgs <- m + mc.disp.dispatch(m) } return false } func (mc *MsgConn) Close() error { mc.isClosed = true - close(mc.done) + if mc.onClose != nil { + mc.onClose() + } return mc.conn.Close() } @@ -151,7 +170,7 @@ func createOobBuffer() []byte { } func (mc *MsgConn) readMessage() (*Message, error) { - n, oobn, _, a, err := mc.conn.ReadMsgUnix(mc.buf[:], mc.oob) + n, oobn, _, _, err := mc.conn.ReadMsgUnix(mc.buf[:], mc.oob) if err != nil { return nil, err } @@ -160,7 +179,6 @@ func (mc *MsgConn) readMessage() (*Message, error) { return nil, err } m.mconn = mc - m.Peer = a if oobn > 0 { err := m.parseControlData(mc.oob[:oobn]) @@ -198,7 +216,7 @@ func (mc *MsgConn) readMessage() (*Message, error) { func (mc *MsgConn) AddHandlers(args ...interface{}) error { for len(args) > 0 { - if err := mc.handlers.addHandler(args[0]); err != nil { + if err := mc.disp.hmap.addHandler(args[0]); err != nil { return err } args = args[1:] @@ -207,21 +225,21 @@ func (mc *MsgConn) AddHandlers(args ...interface{}) error { } func (mc *MsgConn) SendMsg(msg interface{}, fds... int) error { - return mc.sendMessage(msg, <-mc.idGen, mc.addr, fds...) + return mc.sendMessage(msg, <-mc.idGen, fds...) } func (mc *MsgConn) ExchangeMsg(msg interface{}, fds... int) (ResponseReader, error) { id := <-mc.idGen rr := mc.respMan.register(id) - if err := mc.sendMessage(msg, id, mc.addr, fds...); err != nil { + if err := mc.sendMessage(msg, id, fds...); err != nil { rr.Done() return nil, err } return rr,nil } -func (mc *MsgConn) sendMessage(msg interface{}, msgID int, dst *net.UnixAddr, fds... int) error { +func (mc *MsgConn) sendMessage(msg interface{}, msgID int, fds... int) error { msgType, err := getMessageType(msg) if err != nil { return err @@ -234,7 +252,7 @@ func (mc *MsgConn) sendMessage(msg interface{}, msgID int, dst *net.UnixAddr, fd if err != nil { return err } - return mc.sendRaw(raw, dst, fds...) + return mc.sendRaw(raw, fds...) } func getMessageType(msg interface{}) (string, error) { @@ -264,25 +282,17 @@ func (mc *MsgConn) newBaseMessage(msgType string, msgID int, body interface{}) ( return base, nil } -func (mc *MsgConn) sendRaw(data []byte, dst *net.UnixAddr, fds ...int) error { +func (mc *MsgConn) sendRaw(data []byte, fds ...int) error { if len(fds) > 0 { - return mc.sendWithFds(data, dst, fds) - } - return mc.write(data, dst) -} - -func (mc *MsgConn) write(data []byte, dst *net.UnixAddr) error { - if dst != nil { - _,err := mc.conn.WriteToUnix(data, dst) - return err + return mc.sendWithFds(data, fds) } _,err := mc.conn.Write(data) return err } -func (mc *MsgConn) sendWithFds(data []byte, dst *net.UnixAddr, fds []int) error { +func (mc *MsgConn) sendWithFds(data []byte, fds []int) error { oob := syscall.UnixRights(fds...) - _,_,err := mc.conn.WriteMsgUnix(data, oob, dst) + _,_,err := mc.conn.WriteMsgUnix(data, oob, nil) return err } diff --git a/ipc/message.go b/ipc/message.go index 6320bbc..e72f0d8 100644 --- a/ipc/message.go +++ b/ipc/message.go @@ -2,7 +2,6 @@ package ipc import ( "encoding/json" - "net" "syscall" "fmt" "reflect" @@ -15,7 +14,7 @@ func NewMsgFactory(msgTypes ...interface{}) MsgFactory { mf := (MsgFactory)(make(map[string]func() interface{})) for _, mt := range msgTypes { if err := mf.register(mt); err != nil { - log.Fatalf("failed adding (%T) in NewMsgFactory: %v", mt, err) + defaultLog.Fatalf("failed adding (%T) in NewMsgFactory: %v", mt, err) return nil } } @@ -56,7 +55,6 @@ type Message struct { Type string MsgID int Body interface{} - Peer *net.UnixAddr Ucred *syscall.Ucred Fds []int mconn *MsgConn @@ -120,5 +118,5 @@ func (m *Message) parseControlData(data []byte) error { } func (m *Message) Respond(msg interface{}, fds... int) error { - return m.mconn.sendMessage(msg, m.MsgID, m.Peer, fds...) + return m.mconn.sendMessage(msg, m.MsgID, fds...) } diff --git a/oz-daemon/client.go b/oz-daemon/client.go index e2cbef2..8dd8202 100644 --- a/oz-daemon/client.go +++ b/oz-daemon/client.go @@ -7,11 +7,7 @@ import ( ) func clientConnect() (*ipc.MsgConn, error) { - c := ipc.NewMsgConn(messageFactory, SocketName) - if err := c.Connect(); err != nil { - return nil, err - } - return c, nil + return ipc.Connect(SocketName, messageFactory, nil) } func clientSend(msg interface{}) (*ipc.Message, error) { diff --git a/oz-daemon/daemon.go b/oz-daemon/daemon.go index 1b65ab4..d0b6111 100644 --- a/oz-daemon/daemon.go +++ b/oz-daemon/daemon.go @@ -25,6 +25,7 @@ func Main() { d := initialize() err := runServer( + d.log, d.handlePing, d.handleListProfiles, d.handleLaunch, @@ -65,15 +66,12 @@ func (d *daemonState) handleChildExit(pid int, wstatus syscall.WaitStatus) { d.Notice("No sandbox found with oz-init pid = %d", pid) } -func runServer(args ...interface{}) error { - serv := ipc.NewMsgConn(messageFactory, SocketName) - if err := serv.AddHandlers(args...); err != nil { - return err - } - if err := serv.Listen(); err != nil { +func runServer(log *logging.Logger, args ...interface{}) error { + err := ipc.RunServer(SocketName, messageFactory, log, args...) + if err != nil { return err } - return serv.Run() + return nil } func (d * daemonState) handlePing(msg *PingMsg, m *ipc.Message) error { diff --git a/oz-init/client.go b/oz-init/client.go index 90d6541..8671205 100644 --- a/oz-init/client.go +++ b/oz-init/client.go @@ -6,11 +6,7 @@ import ( ) func clientConnect(addr string) (*ipc.MsgConn, error) { - c := ipc.NewMsgConn(messageFactory, addr) - if err := c.Connect(); err != nil { - return nil, err - } - return c, nil + return ipc.Connect(addr, messageFactory, nil) } func clientSend(addr string, msg interface{}) (*ipc.Message, error) { diff --git a/oz-init/init.go b/oz-init/init.go index e0aa228..46a1be7 100644 --- a/oz-init/init.go +++ b/oz-init/init.go @@ -130,13 +130,13 @@ func (st *initState) runInit() { oz.ReapChildProcs(st.log, st.handleChildExit) - serv := ipc.NewMsgConn(messageFactory, st.address) - serv.AddHandlers( + err := ipc.RunServer(st.address, messageFactory, st.log, handlePing, st.handleRunShell, ) - serv.Listen() - serv.Run() + if err != nil { + st.log.Warning("RunServer returned err: %v", err) + } st.log.Info("oz-init exiting...") }