mirror of https://github.com/subgraph/fw-daemon
parent
28e89eb149
commit
f87ac5639e
@ -0,0 +1,139 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MortalService can be killed at any time.
|
||||
type MortalService struct {
|
||||
network string
|
||||
address string
|
||||
connectionCallback func(net.Conn) error
|
||||
|
||||
conns []net.Conn
|
||||
quit chan bool
|
||||
listener net.Listener
|
||||
waitGroup *sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewMortalService creates a new MortalService
|
||||
func NewMortalService(network, address string, connectionCallback func(net.Conn) error) *MortalService {
|
||||
l := MortalService{
|
||||
network: network,
|
||||
address: address,
|
||||
connectionCallback: connectionCallback,
|
||||
|
||||
conns: make([]net.Conn, 0, 10),
|
||||
quit: make(chan bool),
|
||||
waitGroup: &sync.WaitGroup{},
|
||||
}
|
||||
return &l
|
||||
}
|
||||
|
||||
// Stop will kill our listener and all it's connections
|
||||
func (l *MortalService) Stop() {
|
||||
log.Info("stopping listener service %s:%s", l.network, l.address)
|
||||
close(l.quit)
|
||||
if l.listener != nil {
|
||||
l.listener.Close()
|
||||
}
|
||||
l.waitGroup.Wait()
|
||||
}
|
||||
|
||||
func (l *MortalService) acceptLoop() {
|
||||
defer l.waitGroup.Done()
|
||||
defer func() {
|
||||
log.Info("stoping listener service %s:%s", l.network, l.address)
|
||||
for i, conn := range l.conns {
|
||||
if conn != nil {
|
||||
log.Info("Closing connection #%d", i)
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer l.listener.Close()
|
||||
|
||||
for {
|
||||
conn, err := l.listener.Accept()
|
||||
if nil != err {
|
||||
if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() {
|
||||
continue
|
||||
} else {
|
||||
log.Info("MortalService connection accept failure: %s\n", err)
|
||||
select {
|
||||
case <-l.quit:
|
||||
return
|
||||
default:
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
l.conns = append(l.conns, conn)
|
||||
go l.handleConnection(conn, len(l.conns)-1)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *MortalService) createDeadlinedListener() error {
|
||||
if l.network == "tcp" {
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", l.address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err)
|
||||
}
|
||||
tcpListener, err := net.ListenTCP("tcp", tcpAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err)
|
||||
}
|
||||
tcpListener.SetDeadline(time.Now().Add(1e9))
|
||||
l.listener = tcpListener
|
||||
return nil
|
||||
} else if l.network == "unix" {
|
||||
unixAddr, err := net.ResolveUnixAddr("unix", l.address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err)
|
||||
}
|
||||
unixListener, err := net.ListenUnix("unix", unixAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err)
|
||||
}
|
||||
unixListener.SetDeadline(time.Now().Add(1e9))
|
||||
l.listener = unixListener
|
||||
return nil
|
||||
} else {
|
||||
panic("")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start the MortalService
|
||||
func (l *MortalService) Start() error {
|
||||
var err error
|
||||
err = l.createDeadlinedListener()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.waitGroup.Add(1)
|
||||
go l.acceptLoop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *MortalService) handleConnection(conn net.Conn, id int) error {
|
||||
defer func() {
|
||||
log.Info("Closing connection #%d", id)
|
||||
conn.Close()
|
||||
l.conns[id] = nil
|
||||
}()
|
||||
|
||||
log.Info("Starting connection #%d", id)
|
||||
|
||||
for {
|
||||
if err := l.connectionCallback(conn); err != nil {
|
||||
log.Error(err.Error())
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
@ -0,0 +1,91 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"golang.org/x/net/proxy"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type AccumulatingService struct {
|
||||
net, address string
|
||||
buffer bytes.Buffer
|
||||
mortalService *MortalService
|
||||
hasProtocolInfo bool
|
||||
hasAuthenticate bool
|
||||
}
|
||||
|
||||
func NewAccumulatingService(net, address string) *AccumulatingService {
|
||||
l := AccumulatingService{
|
||||
net: net,
|
||||
address: address,
|
||||
hasProtocolInfo: true,
|
||||
hasAuthenticate: true,
|
||||
}
|
||||
return &l
|
||||
}
|
||||
|
||||
func (a *AccumulatingService) Start() {
|
||||
a.mortalService = NewMortalService(a.net, a.address, a.SessionWorker)
|
||||
a.mortalService.Start()
|
||||
}
|
||||
|
||||
func (a *AccumulatingService) Stop() {
|
||||
fmt.Println("AccumulatingService STOP")
|
||||
a.mortalService.Stop()
|
||||
}
|
||||
|
||||
func (a *AccumulatingService) SessionWorker(conn net.Conn) error {
|
||||
connReader := bufio.NewReader(conn)
|
||||
for {
|
||||
|
||||
line, err := connReader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
fmt.Println("AccumulatingService read error:", err)
|
||||
}
|
||||
lineStr := strings.TrimSpace(string(line))
|
||||
a.buffer.WriteString(lineStr + "\n")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSocksServerProxyChain(t *testing.T) {
|
||||
socksConfig := SocksChainConfig{
|
||||
TargetSocksNet: "tcp",
|
||||
TargetSocksAddr: "127.0.0.1:9050",
|
||||
ListenSocksNet: "tcp",
|
||||
ListenSocksAddr: "127.0.0.1:8850",
|
||||
}
|
||||
wg := sync.WaitGroup{}
|
||||
InitSocksListener(&socksConfig, &wg)
|
||||
|
||||
auth := proxy.Auth{
|
||||
User: "",
|
||||
Password: "",
|
||||
}
|
||||
forward := proxy.NewPerHost(proxy.Direct, proxy.Direct)
|
||||
|
||||
terminatingService := NewAccumulatingService("tcp", "127.0.0.1:1234")
|
||||
terminatingService.Start()
|
||||
|
||||
socksClient, err := proxy.SOCKS5("tcp", "127.0.0.1:8850", &auth, forward)
|
||||
conn, err := socksClient.Dial("tcp", "127.0.0.1:1234")
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
rd := bufio.NewReader(conn)
|
||||
line := []byte{}
|
||||
line, err = rd.ReadBytes('\n')
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Println("socks client received", string(line))
|
||||
|
||||
wg.Wait()
|
||||
}
|
Loading…
Reference in new issue