diff --git a/ipc/ipc.go b/ipc/ipc.go index 2c87707..1ec0a18 100644 --- a/ipc/ipc.go +++ b/ipc/ipc.go @@ -6,6 +6,7 @@ import ( "net" "syscall" + "encoding/binary" "fmt" "github.com/op/go-logging" "io" @@ -13,11 +14,13 @@ import ( ) const maxFdCount = 3 +const maxMessageSz = 128 * 1024 +const bufferSz = 1024 type MsgConn struct { log *logging.Logger conn *net.UnixConn - buf [1024]byte + buf []byte oob []byte disp *msgDispatcher factory MsgFactory @@ -28,6 +31,7 @@ type MsgConn struct { } type MsgServer struct { + log *logging.Logger disp *msgDispatcher factory MsgFactory listener *net.UnixListener @@ -52,6 +56,7 @@ func NewServer(address string, factory MsgFactory, log *logging.Logger, handlers done := make(chan bool) idGen := newIdGen(done) return &MsgServer{ + log: log, disp: md, factory: factory, listener: listener, @@ -70,8 +75,10 @@ func (s *MsgServer) Run() error { return errors.New("Failed to set SO_PASSCRED on accepted socket connection:" + err.Error()) } mc := &MsgConn{ + log: s.log, conn: conn, disp: s.disp, + buf: make([]byte, bufferSz), oob: createOobBuffer(), factory: s.factory, idGen: s.idGen, @@ -100,6 +107,7 @@ func Connect(address string, factory MsgFactory, log *logging.Logger, handlers . done := make(chan bool) idGen := newIdGen(done) mc := &MsgConn{ + log: log, conn: conn, disp: md, oob: createOobBuffer(), @@ -180,7 +188,19 @@ func createOobBuffer() []byte { } func (mc *MsgConn) readMessage() (*Message, error) { - n, oobn, _, _, err := mc.conn.ReadMsgUnix(mc.buf[:], mc.oob) + var szbuf [4]byte + n, oobn, _, _, err := mc.conn.ReadMsgUnix(szbuf[:], mc.oob) + if err != nil { + return nil, err + } + sz := binary.BigEndian.Uint32(szbuf[:]) + if sz > maxMessageSz { + return nil, fmt.Errorf("message size of (%d) exceeds maximum message size (%d)", sz, maxMessageSz) + } + if sz > uint32(len(mc.buf)) { + mc.buf = make([]byte, sz) + } + n, _, _, _, err = mc.conn.ReadMsgUnix(mc.buf[:sz], nil) if err != nil { return nil, err } @@ -261,7 +281,10 @@ func (mc *MsgConn) sendMessage(msg interface{}, msgID int, fds ...int) error { if err != nil { return err } - return mc.sendRaw(raw, fds...) + buf := make([]byte, len(raw)+4) + binary.BigEndian.PutUint32(buf, uint32(len(raw))) + copy(buf[4:], raw) + return mc.sendRaw(buf, fds...) } func getMessageType(msg interface{}) (string, error) {