mirror of https://github.com/subgraph/fw-daemon
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
306 lines
7.4 KiB
306 lines
7.4 KiB
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/subgraph/fw-daemon/socks5"
|
|
"golang.org/x/net/proxy"
|
|
)
|
|
|
|
// MortalService is a listener that can be killed at any time: Stop closes the
// listener, after which the accept loop closes the connections it is still
// tracking.
type MortalService struct {
	// network and address describe where to listen; createDeadlinedListener
	// handles the "tcp" and "unix" networks.
	network, address string

	// connectionCallback is invoked for each accepted connection.
	connectionCallback func(net.Conn) error

	conns     []net.Conn      // connections accepted so far
	quit      chan bool       // closed by Stop to signal shutdown
	listener  net.Listener    // created by Start via createDeadlinedListener
	waitGroup *sync.WaitGroup // tracks the accept loop goroutine
}
|
|
|
|
// NewMortalService creates a new MortalService
|
|
func NewMortalService(network, address string, connectionCallback func(net.Conn) error) *MortalService {
|
|
l := MortalService{
|
|
network: network,
|
|
address: address,
|
|
connectionCallback: connectionCallback,
|
|
|
|
conns: make([]net.Conn, 0, 10),
|
|
quit: make(chan bool),
|
|
waitGroup: &sync.WaitGroup{},
|
|
}
|
|
return &l
|
|
}
|
|
|
|
// Stop will kill our listener and all it's connections
|
|
func (l *MortalService) Stop() {
|
|
log.Info("stopping listener service %s:%s", l.network, l.address)
|
|
close(l.quit)
|
|
if l.listener != nil {
|
|
l.listener.Close()
|
|
}
|
|
l.waitGroup.Wait()
|
|
}
|
|
|
|
func (l *MortalService) acceptLoop() {
|
|
defer l.waitGroup.Done()
|
|
defer func() {
|
|
log.Info("stoping listener service %s:%s", l.network, l.address)
|
|
for i, conn := range l.conns {
|
|
if conn != nil {
|
|
log.Info("Closing connection #%d", i)
|
|
conn.Close()
|
|
}
|
|
}
|
|
}()
|
|
defer l.listener.Close()
|
|
|
|
for {
|
|
conn, err := l.listener.Accept()
|
|
if nil != err {
|
|
if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() {
|
|
continue
|
|
} else {
|
|
log.Info("MortalService connection accept failure: %s\n", err)
|
|
select {
|
|
case <-l.quit:
|
|
return
|
|
default:
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
|
|
l.conns = append(l.conns, conn)
|
|
go l.handleConnection(conn, len(l.conns)-1)
|
|
}
|
|
}
|
|
|
|
func (l *MortalService) createDeadlinedListener() error {
|
|
if l.network == "tcp" {
|
|
tcpAddr, err := net.ResolveTCPAddr("tcp", l.address)
|
|
if err != nil {
|
|
return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err)
|
|
}
|
|
tcpListener, err := net.ListenTCP("tcp", tcpAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err)
|
|
}
|
|
tcpListener.SetDeadline(time.Now().Add(1e9))
|
|
l.listener = tcpListener
|
|
return nil
|
|
} else if l.network == "unix" {
|
|
unixAddr, err := net.ResolveUnixAddr("unix", l.address)
|
|
if err != nil {
|
|
return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err)
|
|
}
|
|
unixListener, err := net.ListenUnix("unix", unixAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err)
|
|
}
|
|
unixListener.SetDeadline(time.Now().Add(1e9))
|
|
l.listener = unixListener
|
|
return nil
|
|
} else {
|
|
panic("")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Start the MortalService
|
|
func (l *MortalService) Start() error {
|
|
var err error
|
|
err = l.createDeadlinedListener()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
l.waitGroup.Add(1)
|
|
go l.acceptLoop()
|
|
return nil
|
|
}
|
|
|
|
func (l *MortalService) handleConnection(conn net.Conn, id int) error {
|
|
defer func() {
|
|
log.Info("Closing connection #%d", id)
|
|
conn.Close()
|
|
l.conns[id] = nil
|
|
}()
|
|
|
|
log.Info("Starting connection #%d", id)
|
|
|
|
for {
|
|
if err := l.connectionCallback(conn); err != nil {
|
|
log.Error(err.Error())
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
type AccumulatingService struct {
|
|
net, address string
|
|
banner string
|
|
buffer bytes.Buffer
|
|
mortalService *MortalService
|
|
hasProtocolInfo bool
|
|
hasAuthenticate bool
|
|
receivedChan chan bool
|
|
}
|
|
|
|
func NewAccumulatingService(net, address, banner string) *AccumulatingService {
|
|
l := AccumulatingService{
|
|
net: net,
|
|
address: address,
|
|
banner: banner,
|
|
hasProtocolInfo: true,
|
|
hasAuthenticate: true,
|
|
receivedChan: make(chan bool, 0),
|
|
}
|
|
return &l
|
|
}
|
|
|
|
func (a *AccumulatingService) Start() {
|
|
a.mortalService = NewMortalService(a.net, a.address, a.SessionWorker)
|
|
a.mortalService.Start()
|
|
}
|
|
|
|
func (a *AccumulatingService) Stop() {
|
|
fmt.Println("AccumulatingService STOP")
|
|
a.mortalService.Stop()
|
|
}
|
|
|
|
func (a *AccumulatingService) WaitUntilReceived() {
|
|
<-a.receivedChan
|
|
}
|
|
|
|
func (a *AccumulatingService) SessionWorker(conn net.Conn) error {
|
|
connReader := bufio.NewReader(conn)
|
|
conn.Write([]byte(a.banner))
|
|
for {
|
|
line, err := connReader.ReadBytes('\n')
|
|
if err != nil {
|
|
fmt.Printf("AccumulatingService read error: %s\n", err)
|
|
}
|
|
lineStr := strings.TrimSpace(string(line))
|
|
a.buffer.WriteString(lineStr + "\n")
|
|
a.receivedChan <- true
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func fakeSocksSessionWorker(clientConn net.Conn, targetNet, targetAddr string) error {
|
|
defer clientConn.Close()
|
|
|
|
clientAddr := clientConn.RemoteAddr()
|
|
fmt.Printf("INFO/socks: New connection from: %v\n", clientAddr)
|
|
|
|
// Do the SOCKS handshake with the client, and read the command.
|
|
req, err := socks5.Handshake(clientConn)
|
|
if err != nil {
|
|
panic(fmt.Sprintf("ERR/socks: Failed SOCKS5 handshake: %v", err))
|
|
}
|
|
|
|
var upstreamConn net.Conn
|
|
upstreamConn, err = net.Dial(targetNet, targetAddr)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
defer upstreamConn.Close()
|
|
req.Reply(socks5.ReplySucceeded)
|
|
|
|
// A upstream connection has been established, push data back and forth
|
|
// till the session is done.
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
//upstreamConn.Write([]byte("meow 123\r\n"))
|
|
//clientConn.Write([]byte("meow 123\r\n"))
|
|
copyLoop := func(dst, src net.Conn) {
|
|
defer wg.Done()
|
|
defer dst.Close()
|
|
|
|
io.Copy(dst, src)
|
|
}
|
|
go copyLoop(upstreamConn, clientConn)
|
|
go copyLoop(clientConn, upstreamConn)
|
|
|
|
wg.Wait()
|
|
fmt.Printf("INFO/socks: Closed SOCKS connection from: %v\n", clientAddr)
|
|
return nil
|
|
}
|
|
|
|
func TestSocksServerProxyChain(t *testing.T) {
|
|
// socks client ---> socks chain ---> socks server ---> service
|
|
socksChainNet := "tcp"
|
|
socksChainAddr := "127.0.0.1:7750"
|
|
socksServerNet := "tcp"
|
|
socksServerAddr := "127.0.0.1:8850"
|
|
serviceNet := "tcp"
|
|
serviceAddr := "127.0.0.1:9950"
|
|
|
|
banner := "meow 123\r\n"
|
|
// setup the service listener
|
|
service := NewAccumulatingService(serviceNet, serviceAddr, banner)
|
|
service.Start()
|
|
defer service.Stop()
|
|
|
|
// setup the "socks server"
|
|
session := func(clientConn net.Conn) error {
|
|
return fakeSocksSessionWorker(clientConn, serviceNet, serviceAddr)
|
|
}
|
|
socksService := NewMortalService(socksServerNet, socksServerAddr, session)
|
|
socksService.Start()
|
|
defer socksService.Stop()
|
|
|
|
// setup the SOCKS proxy chain
|
|
socksConfig := SocksChainConfig{
|
|
TargetSocksNet: socksServerNet,
|
|
TargetSocksAddr: socksServerAddr,
|
|
ListenSocksNet: socksChainNet,
|
|
ListenSocksAddr: socksChainAddr,
|
|
}
|
|
wg := sync.WaitGroup{}
|
|
InitSocksListener(&socksConfig, &wg)
|
|
|
|
// setup the SOCKS client
|
|
auth := proxy.Auth{
|
|
User: "",
|
|
Password: "",
|
|
}
|
|
forward := proxy.NewPerHost(proxy.Direct, proxy.Direct)
|
|
socksClient, err := proxy.SOCKS5(socksChainNet, socksChainAddr, &auth, forward)
|
|
conn, err := socksClient.Dial(serviceNet, serviceAddr)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
// read a banner from the service
|
|
rd := bufio.NewReader(conn)
|
|
line := []byte{}
|
|
line, err = rd.ReadBytes('\n')
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
if string(line) != banner {
|
|
t.Errorf("Did not receive expected banner. Got %s, wanted %s\n", string(line), banner)
|
|
t.Fail()
|
|
}
|
|
|
|
// send the service some data and verify it was received
|
|
clientData := "hello world\r\n"
|
|
conn.Write([]byte(clientData))
|
|
service.WaitUntilReceived()
|
|
if service.buffer.String() != strings.TrimSpace(clientData)+"\n" {
|
|
t.Errorf("Client sent %s but service only received %s\n", "hello world\n", service.buffer.String())
|
|
t.Fail()
|
|
}
|
|
}
|