mirror of https://github.com/xSmurf/oz.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
306 lines
6.6 KiB
306 lines
6.6 KiB
package ipc
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"net"
|
|
"syscall"
|
|
|
|
"fmt"
|
|
"github.com/op/go-logging"
|
|
"io"
|
|
"reflect"
|
|
)
|
|
|
|
const maxFdCount = 3
|
|
|
|
type MsgConn struct {
|
|
log *logging.Logger
|
|
conn *net.UnixConn
|
|
buf [1024]byte
|
|
oob []byte
|
|
disp *msgDispatcher
|
|
factory MsgFactory
|
|
isClosed bool
|
|
idGen <-chan int
|
|
respMan *responseManager
|
|
onClose func()
|
|
}
|
|
|
|
type MsgServer struct {
|
|
disp *msgDispatcher
|
|
factory MsgFactory
|
|
listener *net.UnixListener
|
|
done chan bool
|
|
idGen <-chan int
|
|
}
|
|
|
|
func NewServer(address string, factory MsgFactory, log *logging.Logger, handlers ...interface{}) (*MsgServer, error) {
|
|
md, err := createDispatcher(log, handlers...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
listener, err := net.ListenUnix("unix", &net.UnixAddr{address, "unix"})
|
|
if err := setPassCred(listener); err != nil {
|
|
return nil, errors.New("Failed to set SO_PASSCRED on listening socket: " + err.Error())
|
|
}
|
|
if err != nil {
|
|
md.close()
|
|
return nil, err
|
|
}
|
|
done := make(chan bool)
|
|
idGen := newIdGen(done)
|
|
return &MsgServer{
|
|
disp: md,
|
|
factory: factory,
|
|
listener: listener,
|
|
done: done,
|
|
idGen: idGen,
|
|
}, nil
|
|
}
|
|
|
|
func (s *MsgServer) Run() error {
|
|
for {
|
|
conn, err := s.listener.AcceptUnix()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := setPassCred(conn); err != nil {
|
|
return errors.New("Failed to set SO_PASSCRED on accepted socket connection:" + err.Error())
|
|
}
|
|
mc := &MsgConn{
|
|
conn: conn,
|
|
disp: s.disp,
|
|
oob: createOobBuffer(),
|
|
factory: s.factory,
|
|
idGen: s.idGen,
|
|
respMan: newResponseManager(),
|
|
}
|
|
go mc.readLoop()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *MsgServer) Close() error {
|
|
s.disp.close()
|
|
close(s.done)
|
|
return s.listener.Close()
|
|
}
|
|
|
|
func Connect(address string, factory MsgFactory, log *logging.Logger, handlers ...interface{}) (*MsgConn, error) {
|
|
md, err := createDispatcher(log, handlers...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
conn, err := net.DialUnix("unix", nil, &net.UnixAddr{address, "unix"})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
done := make(chan bool)
|
|
idGen := newIdGen(done)
|
|
mc := &MsgConn{
|
|
conn: conn,
|
|
disp: md,
|
|
oob: createOobBuffer(),
|
|
factory: factory,
|
|
idGen: idGen,
|
|
respMan: newResponseManager(),
|
|
onClose: func() {
|
|
md.close()
|
|
close(done)
|
|
},
|
|
}
|
|
go mc.readLoop()
|
|
return mc, nil
|
|
}
|
|
|
|
func newIdGen(done <-chan bool) <-chan int {
|
|
ch := make(chan int)
|
|
go idGenLoop(done, ch)
|
|
return ch
|
|
}
|
|
|
|
func idGenLoop(done <-chan bool, out chan<- int) {
|
|
current := int(1)
|
|
for {
|
|
select {
|
|
case out <- current:
|
|
current += 1
|
|
case <-done:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (mc *MsgConn) readLoop() {
|
|
for {
|
|
if mc.processOneMessage() {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (mc *MsgConn) logger() *logging.Logger {
|
|
if mc.log != nil {
|
|
return mc.log
|
|
}
|
|
return defaultLog
|
|
}
|
|
|
|
func (mc *MsgConn) processOneMessage() bool {
|
|
m, err := mc.readMessage()
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
mc.Close()
|
|
return true
|
|
}
|
|
if !mc.isClosed {
|
|
mc.logger().Warning("error on MsgConn.readMessage(): %v", err)
|
|
}
|
|
return true
|
|
}
|
|
if !mc.respMan.handle(m) {
|
|
mc.disp.dispatch(m)
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (mc *MsgConn) Close() error {
|
|
mc.isClosed = true
|
|
if mc.onClose != nil {
|
|
mc.onClose()
|
|
}
|
|
return mc.conn.Close()
|
|
}
|
|
|
|
func createOobBuffer() []byte {
|
|
oobSize := syscall.CmsgSpace(syscall.SizeofUcred) + syscall.CmsgSpace(4*maxFdCount)
|
|
return make([]byte, oobSize)
|
|
}
|
|
|
|
func (mc *MsgConn) readMessage() (*Message, error) {
|
|
n, oobn, _, _, err := mc.conn.ReadMsgUnix(mc.buf[:], mc.oob)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
m, err := mc.parseMessage(mc.buf[:n])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
m.mconn = mc
|
|
|
|
if oobn > 0 {
|
|
err := m.parseControlData(mc.oob[:oobn])
|
|
if err != nil {
|
|
}
|
|
}
|
|
return m, nil
|
|
}
|
|
|
|
// AddHandlers registers a list of message handling functions with a MsgConn instance.
|
|
// Each handler function must have two arguments and return a single error value. The
|
|
// first argument must be pointer to a message structure type. A message structure type
|
|
// is a structure that must have a struct tag on the first field:
|
|
//
|
|
// type FooMsg struct {
|
|
// Stuff string "Foo" // <------ struct tag
|
|
// // etc...
|
|
// }
|
|
//
|
|
// type SimpleMsg struct {
|
|
// dummy int "Simple" // struct has no fields, so add an unexported dummy field just for the tag
|
|
// }
|
|
//
|
|
// The second argument to a handler function must have type *ipc.Message. After a handler function
|
|
// has been registered, received messages matching the first argument will be dispatched to the corresponding
|
|
// handler function.
|
|
//
|
|
// func fooHandler(foo *FooMsg, msg *ipc.Message) error { /* ... */ }
|
|
// func simpleHandler(simple *SimpleMsg, msg *ipc.Message) error { /* ... */ }
|
|
//
|
|
// /* register fooHandler() to handle incoming FooMsg and SimpleHandler to handle SimpleMsg */
|
|
// conn.AddHandlers(fooHandler, simpleHandler)
|
|
//
|
|
|
|
func (mc *MsgConn) AddHandlers(args ...interface{}) error {
|
|
for len(args) > 0 {
|
|
if err := mc.disp.hmap.addHandler(args[0]); err != nil {
|
|
return err
|
|
}
|
|
args = args[1:]
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (mc *MsgConn) SendMsg(msg interface{}, fds ...int) error {
|
|
return mc.sendMessage(msg, <-mc.idGen, fds...)
|
|
}
|
|
|
|
func (mc *MsgConn) ExchangeMsg(msg interface{}, fds ...int) (ResponseReader, error) {
|
|
id := <-mc.idGen
|
|
rr := mc.respMan.register(id)
|
|
|
|
if err := mc.sendMessage(msg, id, fds...); err != nil {
|
|
rr.Done()
|
|
return nil, err
|
|
}
|
|
return rr, nil
|
|
}
|
|
|
|
func (mc *MsgConn) sendMessage(msg interface{}, msgID int, fds ...int) error {
|
|
msgType, err := getMessageType(msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
base, err := mc.newBaseMessage(msgType, msgID, msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
raw, err := json.Marshal(base)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return mc.sendRaw(raw, fds...)
|
|
}
|
|
|
|
func getMessageType(msg interface{}) (string, error) {
|
|
t := reflect.TypeOf(msg)
|
|
if t.Kind() == reflect.Ptr {
|
|
t = t.Elem()
|
|
}
|
|
if t.Kind() != reflect.Struct {
|
|
return "", fmt.Errorf("sendMessage() msg (%T) is not a struct", msg)
|
|
}
|
|
if t.NumField() == 0 || len(t.Field(0).Tag) == 0 {
|
|
return "", fmt.Errorf("sendMessage() msg struct (%T) does not have tag on first field")
|
|
}
|
|
return string(t.Field(0).Tag), nil
|
|
}
|
|
|
|
func (mc *MsgConn) newBaseMessage(msgType string, msgID int, body interface{}) (*BaseMsg, error) {
|
|
bodyBytes, err := json.Marshal(body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
base := new(BaseMsg)
|
|
base.Type = msgType
|
|
base.MsgID = msgID
|
|
base.Body = bodyBytes
|
|
return base, nil
|
|
}
|
|
|
|
func (mc *MsgConn) sendRaw(data []byte, fds ...int) error {
|
|
if len(fds) > 0 {
|
|
return mc.sendWithFds(data, fds)
|
|
}
|
|
_, err := mc.conn.Write(data)
|
|
return err
|
|
}
|
|
|
|
func (mc *MsgConn) sendWithFds(data []byte, fds []int) error {
|
|
oob := syscall.UnixRights(fds...)
|
|
_, _, err := mc.conn.WriteMsgUnix(data, oob, nil)
|
|
return err
|
|
}
|