diff --git a/mortal_service.go b/mortal_service.go new file mode 100644 index 0000000..b18fd14 --- /dev/null +++ b/mortal_service.go @@ -0,0 +1,139 @@ +package main + +import ( + "fmt" + "net" + "sync" + "time" +) + +// MortalService can be killed at any time. +type MortalService struct { + network string + address string + connectionCallback func(net.Conn) error + + conns []net.Conn + quit chan bool + listener net.Listener + waitGroup *sync.WaitGroup +} + +// NewMortalService creates a new MortalService +func NewMortalService(network, address string, connectionCallback func(net.Conn) error) *MortalService { + l := MortalService{ + network: network, + address: address, + connectionCallback: connectionCallback, + + conns: make([]net.Conn, 0, 10), + quit: make(chan bool), + waitGroup: &sync.WaitGroup{}, + } + return &l +} + +// Stop will kill our listener and all it's connections +func (l *MortalService) Stop() { + log.Info("stopping listener service %s:%s", l.network, l.address) + close(l.quit) + if l.listener != nil { + l.listener.Close() + } + l.waitGroup.Wait() +} + +func (l *MortalService) acceptLoop() { + defer l.waitGroup.Done() + defer func() { + log.Info("stoping listener service %s:%s", l.network, l.address) + for i, conn := range l.conns { + if conn != nil { + log.Info("Closing connection #%d", i) + conn.Close() + } + } + }() + defer l.listener.Close() + + for { + conn, err := l.listener.Accept() + if nil != err { + if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() { + continue + } else { + log.Info("MortalService connection accept failure: %s\n", err) + select { + case <-l.quit: + return + default: + } + continue + } + } + + l.conns = append(l.conns, conn) + go l.handleConnection(conn, len(l.conns)-1) + } +} + +func (l *MortalService) createDeadlinedListener() error { + if l.network == "tcp" { + tcpAddr, err := net.ResolveTCPAddr("tcp", l.address) + if err != nil { + return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err) + } + tcpListener, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err) + } + tcpListener.SetDeadline(time.Now().Add(1e9)) + l.listener = tcpListener + return nil + } else if l.network == "unix" { + unixAddr, err := net.ResolveUnixAddr("unix", l.address) + if err != nil { + return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err) + } + unixListener, err := net.ListenUnix("unix", unixAddr) + if err != nil { + return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err) + } + unixListener.SetDeadline(time.Now().Add(1e9)) + l.listener = unixListener + return nil + } else { + panic("") + } + return nil +} + +// Start the MortalService +func (l *MortalService) Start() error { + var err error + err = l.createDeadlinedListener() + if err != nil { + return err + } + l.waitGroup.Add(1) + go l.acceptLoop() + return nil +} + +func (l *MortalService) handleConnection(conn net.Conn, id int) error { + defer func() { + log.Info("Closing connection #%d", id) + conn.Close() + l.conns[id] = nil + }() + + log.Info("Starting connection #%d", id) + + for { + if err := l.connectionCallback(conn); err != nil { + log.Error(err.Error()) + return err + } + return nil + } +} diff --git a/socks_server_chain_test.go b/socks_server_chain_test.go new file mode 100644 index 0000000..bc782f3 --- /dev/null +++ b/socks_server_chain_test.go @@ -0,0 +1,91 @@ +package main + +import ( + "bufio" + "bytes" + "fmt" + "golang.org/x/net/proxy" + "net" + "strings" + "sync" + "testing" +) + +type AccumulatingService struct { + net, address string + buffer bytes.Buffer + mortalService *MortalService + hasProtocolInfo bool + hasAuthenticate bool +} + +func NewAccumulatingService(net, address string) *AccumulatingService { + l := AccumulatingService{ + net: net, + address: address, + hasProtocolInfo: true, + hasAuthenticate: true, + } + return &l +} + +func (a *AccumulatingService) Start() { + a.mortalService = NewMortalService(a.net, a.address, a.SessionWorker) + a.mortalService.Start() +} + +func (a *AccumulatingService) Stop() { + fmt.Println("AccumulatingService STOP") + a.mortalService.Stop() +} + +func (a *AccumulatingService) SessionWorker(conn net.Conn) error { + connReader := bufio.NewReader(conn) + for { + + line, err := connReader.ReadBytes('\n') + if err != nil { + fmt.Println("AccumulatingService read error:", err) + } + lineStr := strings.TrimSpace(string(line)) + a.buffer.WriteString(lineStr + "\n") + } + return nil +} + +func TestSocksServerProxyChain(t *testing.T) { + socksConfig := SocksChainConfig{ + TargetSocksNet: "tcp", + TargetSocksAddr: "127.0.0.1:9050", + ListenSocksNet: "tcp", + ListenSocksAddr: "127.0.0.1:8850", + } + wg := sync.WaitGroup{} + InitSocksListener(&socksConfig, &wg) + + auth := proxy.Auth{ + User: "", + Password: "", + } + forward := proxy.NewPerHost(proxy.Direct, proxy.Direct) + + terminatingService := NewAccumulatingService("tcp", "127.0.0.1:1234") + terminatingService.Start() + + socksClient, err := proxy.SOCKS5("tcp", "127.0.0.1:8850", &auth, forward) + conn, err := socksClient.Dial("tcp", "127.0.0.1:1234") + + if err != nil { + panic(err) + } + + rd := bufio.NewReader(conn) + line := []byte{} + line, err = rd.ReadBytes('\n') + if err != nil { + panic(err) + } + fmt.Println("socks client received", string(line)) + + wg.Wait() +}