refactor IPC to use SOCK_STREAM instead of SOCK_DGRAM

networking
brl 9 years ago
parent deb2751d85
commit 120dbd97a2

@ -3,10 +3,57 @@ import (
"reflect" "reflect"
"errors" "errors"
"fmt" "fmt"
"github.com/op/go-logging"
) )
type handlerMap map[string]reflect.Value type handlerMap map[string]reflect.Value
var defaultLog = logging.MustGetLogger("ipc")
type msgDispatcher struct {
log *logging.Logger
msgs chan *Message
hmap handlerMap
}
func createDispatcher(log *logging.Logger, handlers...interface{}) (*msgDispatcher, error) {
md := &msgDispatcher{
log: log,
msgs: make(chan *Message),
hmap: make(map[string]reflect.Value),
}
for _,h := range handlers {
if err := md.hmap.addHandler(h); err != nil {
return nil, err
}
}
go md.runDispatcher()
return md, nil
}
func (md *msgDispatcher) close() {
close(md.msgs)
}
func (md *msgDispatcher) dispatch(m *Message) {
md.msgs <- m
}
func (md *msgDispatcher) logger() *logging.Logger {
if md.log != nil {
return md.log
}
return defaultLog
}
func (md *msgDispatcher) runDispatcher() {
for m := range md.msgs {
if err := md.hmap.dispatch(m); err != nil {
md.logger().Warning("error dispatching message: %v", err)
}
}
}
func (handlers handlerMap) dispatch(m *Message) error { func (handlers handlerMap) dispatch(m *Message) error {
h,ok := handlers[m.Type] h,ok := handlers[m.Type]
if !ok { if !ok {

@ -11,37 +11,86 @@ import (
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io"
) )
const maxFdCount = 3 const maxFdCount = 3
var log = logging.MustGetLogger("oz")
type MsgConn struct { type MsgConn struct {
msgs chan *Message log *logging.Logger
addr *net.UnixAddr
conn *net.UnixConn conn *net.UnixConn
buf [1024]byte buf [1024]byte
oob []byte oob []byte
handlers handlerMap disp *msgDispatcher
factory MsgFactory factory MsgFactory
isClosed bool isClosed bool
done chan bool
idGen <-chan int idGen <-chan int
respMan *responseManager respMan *responseManager
onClose func()
} }
func NewMsgConn(factory MsgFactory, address string) *MsgConn { func RunServer(address string, factory MsgFactory, log *logging.Logger, handlers ...interface{}) error {
mc := new(MsgConn) md,err := createDispatcher(log, handlers...)
mc.addr = &net.UnixAddr{address, "unixgram"} if err != nil {
mc.oob = createOobBuffer() return err
mc.msgs = make(chan *Message) }
mc.handlers = make(map[string]reflect.Value)
mc.factory = factory addr := &net.UnixAddr{address, "unix"}
mc.done = make(chan bool) listener,err := net.ListenUnix("unix", addr)
mc.idGen = newIdGen(mc.done) if err != nil {
mc.respMan = newResponseManager() md.close()
return mc return err
}
done := make(chan bool)
idGen := newIdGen(done)
for {
conn,err := listener.AcceptUnix()
if err != nil {
close(md.msgs)
listener.Close()
return err
}
if err := setPassCred(conn); err != nil {
return errors.New("Failed to set SO_PASSCRED on accepted socket connection:"+ err.Error())
}
mc := &MsgConn{
conn: conn,
disp: md,
oob: createOobBuffer(),
factory: factory,
idGen: idGen,
respMan: newResponseManager(),
}
go mc.readLoop()
}
return nil
}
func Connect(address string, factory MsgFactory, log *logging.Logger, handlers ...interface{}) (*MsgConn, error) {
md,err := createDispatcher(log, handlers...)
if err != nil {
return nil, err
}
conn,err := net.DialUnix("unix", nil, &net.UnixAddr{address, "unix"})
if err != nil {
return nil, err
}
done := make(chan bool)
idGen := newIdGen(done)
mc := &MsgConn{
conn: conn,
disp: md,
oob: createOobBuffer(),
factory: factory,
idGen: idGen,
respMan: newResponseManager(),
onClose: func() {
md.close()
close(done)
},
}
go mc.readLoop()
return mc, nil
} }
func newIdGen(done <-chan bool) <-chan int { func newIdGen(done <-chan bool) <-chan int {
@ -62,38 +111,6 @@ func idGenLoop(done <-chan bool, out chan <- int) {
} }
} }
func (mc *MsgConn) Listen() error {
if mc.conn != nil {
return errors.New("cannot Listen(), already connected")
}
conn, err := net.ListenUnixgram("unixgram", mc.addr)
if err != nil {
return err
}
if err := setPassCred(conn); err != nil {
return err
}
mc.conn = conn
return nil
}
func (mc *MsgConn) Connect() error {
if mc.conn != nil {
return errors.New("cannot Connect(), already connected")
}
clientAddr,err := CreateRandomAddress("@oz-")
if err != nil {
return err
}
conn, err := net.DialUnix("unixgram", &net.UnixAddr{clientAddr, "unixgram"}, nil)
if err != nil {
return err
}
mc.conn = conn
go mc.readLoop()
return nil
}
func CreateRandomAddress(prefix string) (string,error) { func CreateRandomAddress(prefix string) (string,error) {
var bs [16]byte var bs [16]byte
n,err := rand.Read(bs[:]) n,err := rand.Read(bs[:])
@ -106,16 +123,6 @@ func CreateRandomAddress(prefix string) (string,error) {
return prefix+ hex.EncodeToString(bs[:]),nil return prefix+ hex.EncodeToString(bs[:]),nil
} }
func (mc *MsgConn) Run() error {
go mc.readLoop()
for m := range mc.msgs {
if err := mc.handlers.dispatch(m); err != nil {
return fmt.Errorf("error dispatching message: %v", err)
}
}
return nil
}
func (mc *MsgConn) readLoop() { func (mc *MsgConn) readLoop() {
for { for {
if mc.processOneMessage() { if mc.processOneMessage() {
@ -124,24 +131,36 @@ func (mc *MsgConn) readLoop() {
} }
} }
func (mc *MsgConn) logger() *logging.Logger {
if mc.log != nil {
return mc.log
}
return defaultLog
}
func (mc *MsgConn) processOneMessage() bool { func (mc *MsgConn) processOneMessage() bool {
m,err := mc.readMessage() m,err := mc.readMessage()
if err != nil { if err != nil {
close(mc.msgs) if err == io.EOF {
mc.Close()
return true
}
if !mc.isClosed { if !mc.isClosed {
log.Warning("error on MsgConn.readMessage(): %v", err) mc.logger().Warning("error on MsgConn.readMessage(): %v", err)
} }
return true return true
} }
if !mc.respMan.handle(m) { if !mc.respMan.handle(m) {
mc.msgs <- m mc.disp.dispatch(m)
} }
return false return false
} }
func (mc *MsgConn) Close() error { func (mc *MsgConn) Close() error {
mc.isClosed = true mc.isClosed = true
close(mc.done) if mc.onClose != nil {
mc.onClose()
}
return mc.conn.Close() return mc.conn.Close()
} }
@ -151,7 +170,7 @@ func createOobBuffer() []byte {
} }
func (mc *MsgConn) readMessage() (*Message, error) { func (mc *MsgConn) readMessage() (*Message, error) {
n, oobn, _, a, err := mc.conn.ReadMsgUnix(mc.buf[:], mc.oob) n, oobn, _, _, err := mc.conn.ReadMsgUnix(mc.buf[:], mc.oob)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -160,7 +179,6 @@ func (mc *MsgConn) readMessage() (*Message, error) {
return nil, err return nil, err
} }
m.mconn = mc m.mconn = mc
m.Peer = a
if oobn > 0 { if oobn > 0 {
err := m.parseControlData(mc.oob[:oobn]) err := m.parseControlData(mc.oob[:oobn])
@ -198,7 +216,7 @@ func (mc *MsgConn) readMessage() (*Message, error) {
func (mc *MsgConn) AddHandlers(args ...interface{}) error { func (mc *MsgConn) AddHandlers(args ...interface{}) error {
for len(args) > 0 { for len(args) > 0 {
if err := mc.handlers.addHandler(args[0]); err != nil { if err := mc.disp.hmap.addHandler(args[0]); err != nil {
return err return err
} }
args = args[1:] args = args[1:]
@ -207,21 +225,21 @@ func (mc *MsgConn) AddHandlers(args ...interface{}) error {
} }
func (mc *MsgConn) SendMsg(msg interface{}, fds... int) error { func (mc *MsgConn) SendMsg(msg interface{}, fds... int) error {
return mc.sendMessage(msg, <-mc.idGen, mc.addr, fds...) return mc.sendMessage(msg, <-mc.idGen, fds...)
} }
func (mc *MsgConn) ExchangeMsg(msg interface{}, fds... int) (ResponseReader, error) { func (mc *MsgConn) ExchangeMsg(msg interface{}, fds... int) (ResponseReader, error) {
id := <-mc.idGen id := <-mc.idGen
rr := mc.respMan.register(id) rr := mc.respMan.register(id)
if err := mc.sendMessage(msg, id, mc.addr, fds...); err != nil { if err := mc.sendMessage(msg, id, fds...); err != nil {
rr.Done() rr.Done()
return nil, err return nil, err
} }
return rr,nil return rr,nil
} }
func (mc *MsgConn) sendMessage(msg interface{}, msgID int, dst *net.UnixAddr, fds... int) error { func (mc *MsgConn) sendMessage(msg interface{}, msgID int, fds... int) error {
msgType, err := getMessageType(msg) msgType, err := getMessageType(msg)
if err != nil { if err != nil {
return err return err
@ -234,7 +252,7 @@ func (mc *MsgConn) sendMessage(msg interface{}, msgID int, dst *net.UnixAddr, fd
if err != nil { if err != nil {
return err return err
} }
return mc.sendRaw(raw, dst, fds...) return mc.sendRaw(raw, fds...)
} }
func getMessageType(msg interface{}) (string, error) { func getMessageType(msg interface{}) (string, error) {
@ -264,25 +282,17 @@ func (mc *MsgConn) newBaseMessage(msgType string, msgID int, body interface{}) (
return base, nil return base, nil
} }
func (mc *MsgConn) sendRaw(data []byte, dst *net.UnixAddr, fds ...int) error { func (mc *MsgConn) sendRaw(data []byte, fds ...int) error {
if len(fds) > 0 { if len(fds) > 0 {
return mc.sendWithFds(data, dst, fds) return mc.sendWithFds(data, fds)
}
return mc.write(data, dst)
}
func (mc *MsgConn) write(data []byte, dst *net.UnixAddr) error {
if dst != nil {
_,err := mc.conn.WriteToUnix(data, dst)
return err
} }
_,err := mc.conn.Write(data) _,err := mc.conn.Write(data)
return err return err
} }
func (mc *MsgConn) sendWithFds(data []byte, dst *net.UnixAddr, fds []int) error { func (mc *MsgConn) sendWithFds(data []byte, fds []int) error {
oob := syscall.UnixRights(fds...) oob := syscall.UnixRights(fds...)
_,_,err := mc.conn.WriteMsgUnix(data, oob, dst) _,_,err := mc.conn.WriteMsgUnix(data, oob, nil)
return err return err
} }

@ -2,7 +2,6 @@ package ipc
import ( import (
"encoding/json" "encoding/json"
"net"
"syscall" "syscall"
"fmt" "fmt"
"reflect" "reflect"
@ -15,7 +14,7 @@ func NewMsgFactory(msgTypes ...interface{}) MsgFactory {
mf := (MsgFactory)(make(map[string]func() interface{})) mf := (MsgFactory)(make(map[string]func() interface{}))
for _, mt := range msgTypes { for _, mt := range msgTypes {
if err := mf.register(mt); err != nil { if err := mf.register(mt); err != nil {
log.Fatalf("failed adding (%T) in NewMsgFactory: %v", mt, err) defaultLog.Fatalf("failed adding (%T) in NewMsgFactory: %v", mt, err)
return nil return nil
} }
} }
@ -56,7 +55,6 @@ type Message struct {
Type string Type string
MsgID int MsgID int
Body interface{} Body interface{}
Peer *net.UnixAddr
Ucred *syscall.Ucred Ucred *syscall.Ucred
Fds []int Fds []int
mconn *MsgConn mconn *MsgConn
@ -120,5 +118,5 @@ func (m *Message) parseControlData(data []byte) error {
} }
func (m *Message) Respond(msg interface{}, fds... int) error { func (m *Message) Respond(msg interface{}, fds... int) error {
return m.mconn.sendMessage(msg, m.MsgID, m.Peer, fds...) return m.mconn.sendMessage(msg, m.MsgID, fds...)
} }

@ -7,11 +7,7 @@ import (
) )
func clientConnect() (*ipc.MsgConn, error) { func clientConnect() (*ipc.MsgConn, error) {
c := ipc.NewMsgConn(messageFactory, SocketName) return ipc.Connect(SocketName, messageFactory, nil)
if err := c.Connect(); err != nil {
return nil, err
}
return c, nil
} }
func clientSend(msg interface{}) (*ipc.Message, error) { func clientSend(msg interface{}) (*ipc.Message, error) {

@ -25,6 +25,7 @@ func Main() {
d := initialize() d := initialize()
err := runServer( err := runServer(
d.log,
d.handlePing, d.handlePing,
d.handleListProfiles, d.handleListProfiles,
d.handleLaunch, d.handleLaunch,
@ -65,15 +66,12 @@ func (d *daemonState) handleChildExit(pid int, wstatus syscall.WaitStatus) {
d.Notice("No sandbox found with oz-init pid = %d", pid) d.Notice("No sandbox found with oz-init pid = %d", pid)
} }
func runServer(args ...interface{}) error { func runServer(log *logging.Logger, args ...interface{}) error {
serv := ipc.NewMsgConn(messageFactory, SocketName) err := ipc.RunServer(SocketName, messageFactory, log, args...)
if err := serv.AddHandlers(args...); err != nil { if err != nil {
return err
}
if err := serv.Listen(); err != nil {
return err return err
} }
return serv.Run() return nil
} }
func (d * daemonState) handlePing(msg *PingMsg, m *ipc.Message) error { func (d * daemonState) handlePing(msg *PingMsg, m *ipc.Message) error {

@ -6,11 +6,7 @@ import (
) )
func clientConnect(addr string) (*ipc.MsgConn, error) { func clientConnect(addr string) (*ipc.MsgConn, error) {
c := ipc.NewMsgConn(messageFactory, addr) return ipc.Connect(addr, messageFactory, nil)
if err := c.Connect(); err != nil {
return nil, err
}
return c, nil
} }
func clientSend(addr string, msg interface{}) (*ipc.Message, error) { func clientSend(addr string, msg interface{}) (*ipc.Message, error) {

@ -130,13 +130,13 @@ func (st *initState) runInit() {
oz.ReapChildProcs(st.log, st.handleChildExit) oz.ReapChildProcs(st.log, st.handleChildExit)
serv := ipc.NewMsgConn(messageFactory, st.address) err := ipc.RunServer(st.address, messageFactory, st.log,
serv.AddHandlers(
handlePing, handlePing,
st.handleRunShell, st.handleRunShell,
) )
serv.Listen() if err != nil {
serv.Run() st.log.Warning("RunServer returned err: %v", err)
}
st.log.Info("oz-init exiting...") st.log.Info("oz-init exiting...")
} }

Loading…
Cancel
Save