refactor IPC to use SOCK_STREAM instead of SOCK_DGRAM

networking
brl 9 years ago
parent deb2751d85
commit 120dbd97a2

@ -3,10 +3,57 @@ import (
"reflect"
"errors"
"fmt"
"github.com/op/go-logging"
)
type handlerMap map[string]reflect.Value
var defaultLog = logging.MustGetLogger("ipc")
type msgDispatcher struct {
log *logging.Logger
msgs chan *Message
hmap handlerMap
}
func createDispatcher(log *logging.Logger, handlers...interface{}) (*msgDispatcher, error) {
md := &msgDispatcher{
log: log,
msgs: make(chan *Message),
hmap: make(map[string]reflect.Value),
}
for _,h := range handlers {
if err := md.hmap.addHandler(h); err != nil {
return nil, err
}
}
go md.runDispatcher()
return md, nil
}
func (md *msgDispatcher) close() {
close(md.msgs)
}
func (md *msgDispatcher) dispatch(m *Message) {
md.msgs <- m
}
func (md *msgDispatcher) logger() *logging.Logger {
if md.log != nil {
return md.log
}
return defaultLog
}
func (md *msgDispatcher) runDispatcher() {
for m := range md.msgs {
if err := md.hmap.dispatch(m); err != nil {
md.logger().Warning("error dispatching message: %v", err)
}
}
}
func (handlers handlerMap) dispatch(m *Message) error {
h,ok := handlers[m.Type]
if !ok {

@ -11,37 +11,86 @@ import (
"crypto/rand"
"encoding/hex"
"fmt"
"io"
)
const maxFdCount = 3
var log = logging.MustGetLogger("oz")
type MsgConn struct {
msgs chan *Message
addr *net.UnixAddr
log *logging.Logger
conn *net.UnixConn
buf [1024]byte
oob []byte
handlers handlerMap
disp *msgDispatcher
factory MsgFactory
isClosed bool
done chan bool
idGen <-chan int
respMan *responseManager
onClose func()
}
func NewMsgConn(factory MsgFactory, address string) *MsgConn {
mc := new(MsgConn)
mc.addr = &net.UnixAddr{address, "unixgram"}
mc.oob = createOobBuffer()
mc.msgs = make(chan *Message)
mc.handlers = make(map[string]reflect.Value)
mc.factory = factory
mc.done = make(chan bool)
mc.idGen = newIdGen(mc.done)
mc.respMan = newResponseManager()
return mc
func RunServer(address string, factory MsgFactory, log *logging.Logger, handlers ...interface{}) error {
md,err := createDispatcher(log, handlers...)
if err != nil {
return err
}
addr := &net.UnixAddr{address, "unix"}
listener,err := net.ListenUnix("unix", addr)
if err != nil {
md.close()
return err
}
done := make(chan bool)
idGen := newIdGen(done)
for {
conn,err := listener.AcceptUnix()
if err != nil {
close(md.msgs)
listener.Close()
return err
}
if err := setPassCred(conn); err != nil {
return errors.New("Failed to set SO_PASSCRED on accepted socket connection:"+ err.Error())
}
mc := &MsgConn{
conn: conn,
disp: md,
oob: createOobBuffer(),
factory: factory,
idGen: idGen,
respMan: newResponseManager(),
}
go mc.readLoop()
}
return nil
}
func Connect(address string, factory MsgFactory, log *logging.Logger, handlers ...interface{}) (*MsgConn, error) {
md,err := createDispatcher(log, handlers...)
if err != nil {
return nil, err
}
conn,err := net.DialUnix("unix", nil, &net.UnixAddr{address, "unix"})
if err != nil {
return nil, err
}
done := make(chan bool)
idGen := newIdGen(done)
mc := &MsgConn{
conn: conn,
disp: md,
oob: createOobBuffer(),
factory: factory,
idGen: idGen,
respMan: newResponseManager(),
onClose: func() {
md.close()
close(done)
},
}
go mc.readLoop()
return mc, nil
}
func newIdGen(done <-chan bool) <-chan int {
@ -62,38 +111,6 @@ func idGenLoop(done <-chan bool, out chan <- int) {
}
}
func (mc *MsgConn) Listen() error {
if mc.conn != nil {
return errors.New("cannot Listen(), already connected")
}
conn, err := net.ListenUnixgram("unixgram", mc.addr)
if err != nil {
return err
}
if err := setPassCred(conn); err != nil {
return err
}
mc.conn = conn
return nil
}
func (mc *MsgConn) Connect() error {
if mc.conn != nil {
return errors.New("cannot Connect(), already connected")
}
clientAddr,err := CreateRandomAddress("@oz-")
if err != nil {
return err
}
conn, err := net.DialUnix("unixgram", &net.UnixAddr{clientAddr, "unixgram"}, nil)
if err != nil {
return err
}
mc.conn = conn
go mc.readLoop()
return nil
}
func CreateRandomAddress(prefix string) (string,error) {
var bs [16]byte
n,err := rand.Read(bs[:])
@ -106,16 +123,6 @@ func CreateRandomAddress(prefix string) (string,error) {
return prefix+ hex.EncodeToString(bs[:]),nil
}
func (mc *MsgConn) Run() error {
go mc.readLoop()
for m := range mc.msgs {
if err := mc.handlers.dispatch(m); err != nil {
return fmt.Errorf("error dispatching message: %v", err)
}
}
return nil
}
func (mc *MsgConn) readLoop() {
for {
if mc.processOneMessage() {
@ -124,24 +131,36 @@ func (mc *MsgConn) readLoop() {
}
}
func (mc *MsgConn) logger() *logging.Logger {
if mc.log != nil {
return mc.log
}
return defaultLog
}
func (mc *MsgConn) processOneMessage() bool {
m,err := mc.readMessage()
if err != nil {
close(mc.msgs)
if err == io.EOF {
mc.Close()
return true
}
if !mc.isClosed {
log.Warning("error on MsgConn.readMessage(): %v", err)
mc.logger().Warning("error on MsgConn.readMessage(): %v", err)
}
return true
}
if !mc.respMan.handle(m) {
mc.msgs <- m
mc.disp.dispatch(m)
}
return false
}
func (mc *MsgConn) Close() error {
mc.isClosed = true
close(mc.done)
if mc.onClose != nil {
mc.onClose()
}
return mc.conn.Close()
}
@ -151,7 +170,7 @@ func createOobBuffer() []byte {
}
func (mc *MsgConn) readMessage() (*Message, error) {
n, oobn, _, a, err := mc.conn.ReadMsgUnix(mc.buf[:], mc.oob)
n, oobn, _, _, err := mc.conn.ReadMsgUnix(mc.buf[:], mc.oob)
if err != nil {
return nil, err
}
@ -160,7 +179,6 @@ func (mc *MsgConn) readMessage() (*Message, error) {
return nil, err
}
m.mconn = mc
m.Peer = a
if oobn > 0 {
err := m.parseControlData(mc.oob[:oobn])
@ -198,7 +216,7 @@ func (mc *MsgConn) readMessage() (*Message, error) {
func (mc *MsgConn) AddHandlers(args ...interface{}) error {
for len(args) > 0 {
if err := mc.handlers.addHandler(args[0]); err != nil {
if err := mc.disp.hmap.addHandler(args[0]); err != nil {
return err
}
args = args[1:]
@ -207,21 +225,21 @@ func (mc *MsgConn) AddHandlers(args ...interface{}) error {
}
func (mc *MsgConn) SendMsg(msg interface{}, fds... int) error {
return mc.sendMessage(msg, <-mc.idGen, mc.addr, fds...)
return mc.sendMessage(msg, <-mc.idGen, fds...)
}
func (mc *MsgConn) ExchangeMsg(msg interface{}, fds... int) (ResponseReader, error) {
id := <-mc.idGen
rr := mc.respMan.register(id)
if err := mc.sendMessage(msg, id, mc.addr, fds...); err != nil {
if err := mc.sendMessage(msg, id, fds...); err != nil {
rr.Done()
return nil, err
}
return rr,nil
}
func (mc *MsgConn) sendMessage(msg interface{}, msgID int, dst *net.UnixAddr, fds... int) error {
func (mc *MsgConn) sendMessage(msg interface{}, msgID int, fds... int) error {
msgType, err := getMessageType(msg)
if err != nil {
return err
@ -234,7 +252,7 @@ func (mc *MsgConn) sendMessage(msg interface{}, msgID int, dst *net.UnixAddr, fd
if err != nil {
return err
}
return mc.sendRaw(raw, dst, fds...)
return mc.sendRaw(raw, fds...)
}
func getMessageType(msg interface{}) (string, error) {
@ -264,25 +282,17 @@ func (mc *MsgConn) newBaseMessage(msgType string, msgID int, body interface{}) (
return base, nil
}
func (mc *MsgConn) sendRaw(data []byte, dst *net.UnixAddr, fds ...int) error {
func (mc *MsgConn) sendRaw(data []byte, fds ...int) error {
if len(fds) > 0 {
return mc.sendWithFds(data, dst, fds)
}
return mc.write(data, dst)
}
func (mc *MsgConn) write(data []byte, dst *net.UnixAddr) error {
if dst != nil {
_,err := mc.conn.WriteToUnix(data, dst)
return err
return mc.sendWithFds(data, fds)
}
_,err := mc.conn.Write(data)
return err
}
func (mc *MsgConn) sendWithFds(data []byte, dst *net.UnixAddr, fds []int) error {
func (mc *MsgConn) sendWithFds(data []byte, fds []int) error {
oob := syscall.UnixRights(fds...)
_,_,err := mc.conn.WriteMsgUnix(data, oob, dst)
_,_,err := mc.conn.WriteMsgUnix(data, oob, nil)
return err
}

@ -2,7 +2,6 @@ package ipc
import (
"encoding/json"
"net"
"syscall"
"fmt"
"reflect"
@ -15,7 +14,7 @@ func NewMsgFactory(msgTypes ...interface{}) MsgFactory {
mf := (MsgFactory)(make(map[string]func() interface{}))
for _, mt := range msgTypes {
if err := mf.register(mt); err != nil {
log.Fatalf("failed adding (%T) in NewMsgFactory: %v", mt, err)
defaultLog.Fatalf("failed adding (%T) in NewMsgFactory: %v", mt, err)
return nil
}
}
@ -56,7 +55,6 @@ type Message struct {
Type string
MsgID int
Body interface{}
Peer *net.UnixAddr
Ucred *syscall.Ucred
Fds []int
mconn *MsgConn
@ -120,5 +118,5 @@ func (m *Message) parseControlData(data []byte) error {
}
func (m *Message) Respond(msg interface{}, fds... int) error {
return m.mconn.sendMessage(msg, m.MsgID, m.Peer, fds...)
return m.mconn.sendMessage(msg, m.MsgID, fds...)
}

@ -7,11 +7,7 @@ import (
)
func clientConnect() (*ipc.MsgConn, error) {
c := ipc.NewMsgConn(messageFactory, SocketName)
if err := c.Connect(); err != nil {
return nil, err
}
return c, nil
return ipc.Connect(SocketName, messageFactory, nil)
}
func clientSend(msg interface{}) (*ipc.Message, error) {

@ -25,6 +25,7 @@ func Main() {
d := initialize()
err := runServer(
d.log,
d.handlePing,
d.handleListProfiles,
d.handleLaunch,
@ -65,15 +66,12 @@ func (d *daemonState) handleChildExit(pid int, wstatus syscall.WaitStatus) {
d.Notice("No sandbox found with oz-init pid = %d", pid)
}
func runServer(args ...interface{}) error {
serv := ipc.NewMsgConn(messageFactory, SocketName)
if err := serv.AddHandlers(args...); err != nil {
return err
}
if err := serv.Listen(); err != nil {
func runServer(log *logging.Logger, args ...interface{}) error {
err := ipc.RunServer(SocketName, messageFactory, log, args...)
if err != nil {
return err
}
return serv.Run()
return nil
}
func (d * daemonState) handlePing(msg *PingMsg, m *ipc.Message) error {

@ -6,11 +6,7 @@ import (
)
func clientConnect(addr string) (*ipc.MsgConn, error) {
c := ipc.NewMsgConn(messageFactory, addr)
if err := c.Connect(); err != nil {
return nil, err
}
return c, nil
return ipc.Connect(addr, messageFactory, nil)
}
func clientSend(addr string, msg interface{}) (*ipc.Message, error) {

@ -130,13 +130,13 @@ func (st *initState) runInit() {
oz.ReapChildProcs(st.log, st.handleChildExit)
serv := ipc.NewMsgConn(messageFactory, st.address)
serv.AddHandlers(
err := ipc.RunServer(st.address, messageFactory, st.log,
handlePing,
st.handleRunShell,
)
serv.Listen()
serv.Run()
if err != nil {
st.log.Warning("RunServer returned err: %v", err)
}
st.log.Info("oz-init exiting...")
}

Loading…
Cancel
Save