delimit each message with a length field

master
brl 10 years ago
parent ea2034fc45
commit dee514bd4e

@ -6,6 +6,7 @@ import (
"net" "net"
"syscall" "syscall"
"encoding/binary"
"fmt" "fmt"
"github.com/op/go-logging" "github.com/op/go-logging"
"io" "io"
@ -13,11 +14,13 @@ import (
) )
const maxFdCount = 3 const maxFdCount = 3
const maxMessageSz = 128 * 1024
const bufferSz = 1024
type MsgConn struct { type MsgConn struct {
log *logging.Logger log *logging.Logger
conn *net.UnixConn conn *net.UnixConn
buf [1024]byte buf []byte
oob []byte oob []byte
disp *msgDispatcher disp *msgDispatcher
factory MsgFactory factory MsgFactory
@ -28,6 +31,7 @@ type MsgConn struct {
} }
type MsgServer struct { type MsgServer struct {
log *logging.Logger
disp *msgDispatcher disp *msgDispatcher
factory MsgFactory factory MsgFactory
listener *net.UnixListener listener *net.UnixListener
@ -52,6 +56,7 @@ func NewServer(address string, factory MsgFactory, log *logging.Logger, handlers
done := make(chan bool) done := make(chan bool)
idGen := newIdGen(done) idGen := newIdGen(done)
return &MsgServer{ return &MsgServer{
log: log,
disp: md, disp: md,
factory: factory, factory: factory,
listener: listener, listener: listener,
@ -70,8 +75,10 @@ func (s *MsgServer) Run() error {
return errors.New("Failed to set SO_PASSCRED on accepted socket connection:" + err.Error()) return errors.New("Failed to set SO_PASSCRED on accepted socket connection:" + err.Error())
} }
mc := &MsgConn{ mc := &MsgConn{
log: s.log,
conn: conn, conn: conn,
disp: s.disp, disp: s.disp,
buf: make([]byte, bufferSz),
oob: createOobBuffer(), oob: createOobBuffer(),
factory: s.factory, factory: s.factory,
idGen: s.idGen, idGen: s.idGen,
@ -100,6 +107,7 @@ func Connect(address string, factory MsgFactory, log *logging.Logger, handlers .
done := make(chan bool) done := make(chan bool)
idGen := newIdGen(done) idGen := newIdGen(done)
mc := &MsgConn{ mc := &MsgConn{
log: log,
conn: conn, conn: conn,
disp: md, disp: md,
oob: createOobBuffer(), oob: createOobBuffer(),
@ -180,7 +188,19 @@ func createOobBuffer() []byte {
} }
func (mc *MsgConn) readMessage() (*Message, error) { func (mc *MsgConn) readMessage() (*Message, error) {
n, oobn, _, _, err := mc.conn.ReadMsgUnix(mc.buf[:], mc.oob) var szbuf [4]byte
n, oobn, _, _, err := mc.conn.ReadMsgUnix(szbuf[:], mc.oob)
if err != nil {
return nil, err
}
sz := binary.BigEndian.Uint32(szbuf[:])
if sz > maxMessageSz {
return nil, fmt.Errorf("message size of (%d) exceeds maximum message size (%d)", sz, maxMessageSz)
}
if sz > uint32(len(mc.buf)) {
mc.buf = make([]byte, sz)
}
n, _, _, _, err = mc.conn.ReadMsgUnix(mc.buf[:sz], nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -261,7 +281,10 @@ func (mc *MsgConn) sendMessage(msg interface{}, msgID int, fds ...int) error {
if err != nil { if err != nil {
return err return err
} }
return mc.sendRaw(raw, fds...) buf := make([]byte, len(raw)+4)
binary.BigEndian.PutUint32(buf, uint32(len(raw)))
copy(buf[4:], raw)
return mc.sendRaw(buf, fds...)
} }
func getMessageType(msg interface{}) (string, error) { func getMessageType(msg interface{}) (string, error) {

Loading…
Cancel
Save