Reincorporated socks5 code.

Fixed small but critical bug in rules matching/IP comparison.
shw_dev
shw 7 years ago
parent 3bb8d65ed1
commit 1e84a6e168

@ -0,0 +1,153 @@
/*
* client.go - SOCSK5 client implementation.
*
* To the extent possible under law, Yawning Angel has waived all copyright and
* related or neighboring rights to or-ctl-filter, using the creative commons
* "cc0" public domain dedication. See LICENSE or
* <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
*/
package sgfw
import (
"fmt"
"io"
"net"
"time"
)
// Redispatch dials the provided proxy and redispatches an existing request.
func Redispatch(proxyNet, proxyAddr string, req *Request) (conn net.Conn, bndAddr *Address, err error) {
defer func() {
if err != nil && conn != nil {
conn.Close()
}
}()
conn, err = clientHandshake(proxyNet, proxyAddr, req)
if err != nil {
return nil, nil, err
}
bndAddr, err = clientCmd(conn, req)
return
}
func clientHandshake(proxyNet, proxyAddr string, req *Request) (net.Conn, error) {
conn, err := net.Dial(proxyNet, proxyAddr)
if err != nil {
return nil, err
}
if err := conn.SetDeadline(time.Now().Add(requestTimeout)); err != nil {
return conn, err
}
authMethod, err := clientNegotiateAuth(conn, req)
if err != nil {
return conn, err
}
if err := clientAuthenticate(conn, req, authMethod); err != nil {
return conn, err
}
if err := conn.SetDeadline(time.Time{}); err != nil {
return conn, err
}
return conn, nil
}
func clientNegotiateAuth(conn net.Conn, req *Request) (byte, error) {
useRFC1929 := req.Auth.Uname != nil && req.Auth.Passwd != nil
// XXX: Validate uname/passwd lengths, though should always be valid.
var buf [3]byte
buf[0] = version
buf[1] = 1
if useRFC1929 {
buf[2] = authUsernamePassword
} else {
buf[2] = authNoneRequired
}
if _, err := conn.Write(buf[:]); err != nil {
return authNoAcceptableMethods, err
}
var resp [2]byte
if _, err := io.ReadFull(conn, resp[:]); err != nil {
return authNoAcceptableMethods, err
}
if err := validateByte("version", resp[0], version); err != nil {
return authNoAcceptableMethods, err
}
if err := validateByte("method", resp[1], buf[2]); err != nil {
return authNoAcceptableMethods, err
}
return resp[1], nil
}
func clientAuthenticate(conn net.Conn, req *Request, authMethod byte) error {
switch authMethod {
case authNoneRequired:
case authUsernamePassword:
var buf []byte
buf = append(buf, authRFC1929Ver)
buf = append(buf, byte(len(req.Auth.Uname)))
buf = append(buf, req.Auth.Uname...)
buf = append(buf, byte(len(req.Auth.Passwd)))
buf = append(buf, req.Auth.Passwd...)
if _, err := conn.Write(buf); err != nil {
return err
}
var resp [2]byte
if _, err := io.ReadFull(conn, resp[:]); err != nil {
return err
}
if err := validateByte("version", resp[0], authRFC1929Ver); err != nil {
return err
}
if err := validateByte("status", resp[1], authRFC1929Success); err != nil {
return err
}
default:
panic(fmt.Sprintf("unknown authentication method: 0x%02x", authMethod))
}
return nil
}
func clientCmd(conn net.Conn, req *Request) (*Address, error) {
var buf []byte
buf = append(buf, version)
buf = append(buf, byte(req.Cmd))
buf = append(buf, rsv)
buf = append(buf, req.Addr.raw...)
if _, err := conn.Write(buf); err != nil {
return nil, err
}
var respHdr [3]byte
if _, err := io.ReadFull(conn, respHdr[:]); err != nil {
return nil, err
}
if err := validateByte("version", respHdr[0], version); err != nil {
return nil, err
}
if err := validateByte("rep", respHdr[1], byte(ReplySucceeded)); err != nil {
return nil, clientError(respHdr[1])
}
if err := validateByte("rsv", respHdr[2], rsv); err != nil {
return nil, err
}
var bndAddr Address
if err := bndAddr.read(conn); err != nil {
return nil, err
}
if err := conn.SetDeadline(time.Time{}); err != nil {
return nil, err
}
return &bndAddr, nil
}

@ -0,0 +1,280 @@
/*
* common.go - SOCSK5 common definitons/routines.
*
* To the extent possible under law, Yawning Angel has waived all copyright and
* related or neighboring rights to or-ctl-filter, using the creative commons
* "cc0" public domain dedication. See LICENSE or
* <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
*/
// Package socks5 implements a SOCKS5 client/server. For more information see
// RFC 1928 and RFC 1929.
//
// Notes:
// * GSSAPI authentication, is NOT supported.
// * The authentication provided by the client is always accepted.
// * A lot of the code is shamelessly stolen from obfs4proxy.
package sgfw
import (
"errors"
"fmt"
"io"
"net"
"strconv"
"syscall"
"time"
)
const (
version = 0x05
rsv = 0x00
atypIPv4 = 0x01
atypDomainName = 0x03
atypIPv6 = 0x04
authNoneRequired = 0x00
authUsernamePassword = 0x02
authNoAcceptableMethods = 0xff
inboundTimeout = 5 * time.Second
requestTimeout = 30 * time.Second
)
var errInvalidAtyp = errors.New("invalid address type")
// ReplyCode is a SOCKS 5 reply code.
type ReplyCode byte
// The various SOCKS 5 reply codes from RFC 1928.
const (
ReplySucceeded ReplyCode = iota
ReplyGeneralFailure
ReplyConnectionNotAllowed
ReplyNetworkUnreachable
ReplyHostUnreachable
ReplyConnectionRefused
ReplyTTLExpired
ReplyCommandNotSupported
ReplyAddressNotSupported
)
// Command is a SOCKS 5 command.
type Command byte
// The various SOCKS 5 commands.
const (
CommandConnect Command = 0x01
CommandTorResolve Command = 0xf0
CommandTorResolvePTR Command = 0xf1
)
// Address is a SOCKS 5 address + port.
type Address struct {
atyp uint8
raw []byte
addrStr string
portStr string
}
// FromString parses the provided "host:port" format address and populates the
// Address fields.
func (addr *Address) FromString(addrStr string) (err error) {
addr.addrStr, addr.portStr, err = net.SplitHostPort(addrStr)
if err != nil {
return
}
var raw []byte
if ip := net.ParseIP(addr.addrStr); ip != nil {
if v4Addr := ip.To4(); v4Addr != nil {
raw = append(raw, atypIPv4)
raw = append(raw, v4Addr...)
} else if v6Addr := ip.To16(); v6Addr != nil {
raw = append(raw, atypIPv6)
raw = append(raw, v6Addr...)
} else {
return errors.New("unsupported IP address type")
}
} else {
// Must be a FQDN.
if len(addr.addrStr) > 255 {
return fmt.Errorf("invalid FQDN, len > 255 bytes (%d bytes)", len(addr.addrStr))
}
raw = append(raw, atypDomainName)
raw = append(raw, addr.addrStr...)
}
var port uint64
if port, err = strconv.ParseUint(addr.portStr, 10, 16); err != nil {
return
}
raw = append(raw, byte(port>>8))
raw = append(raw, byte(port&0xff))
addr.raw = raw
return
}
// String returns the string representation of the address, in "host:port"
// format.
func (addr *Address) String() string {
return addr.addrStr + ":" + addr.portStr
}
// HostPort returns the string representation of the addess, split into the
// host and port components.
func (addr *Address) HostPort() (string, string) {
return addr.addrStr, addr.portStr
}
// Type returns the address type from the connect command this address was
// parsed from
func (addr *Address) Type() uint8 {
return addr.atyp
}
func (addr *Address) read(conn net.Conn) (err error) {
// The address looks like:
// uint8_t atyp
// uint8_t addr[] (Length depends on atyp)
// uint16_t port
// Read the atype.
var atyp byte
if atyp, err = readByte(conn); err != nil {
return
}
addr.raw = append(addr.raw, atyp)
// Read the address.
var rawAddr []byte
switch atyp {
case atypIPv4:
rawAddr = make([]byte, net.IPv4len)
if _, err = io.ReadFull(conn, rawAddr); err != nil {
return
}
v4Addr := net.IPv4(rawAddr[0], rawAddr[1], rawAddr[2], rawAddr[3])
addr.addrStr = v4Addr.String()
case atypDomainName:
var alen byte
if alen, err = readByte(conn); err != nil {
return
}
if alen == 0 {
return fmt.Errorf("domain name with 0 length")
}
rawAddr = make([]byte, alen)
addr.raw = append(addr.raw, alen)
if _, err = io.ReadFull(conn, rawAddr); err != nil {
return
}
addr.addrStr = string(rawAddr)
case atypIPv6:
rawAddr = make([]byte, net.IPv6len)
if _, err = io.ReadFull(conn, rawAddr); err != nil {
return
}
v6Addr := make(net.IP, net.IPv6len)
copy(v6Addr[:], rawAddr)
addr.addrStr = fmt.Sprintf("[%s]", v6Addr.String())
default:
return errInvalidAtyp
}
addr.atyp = atyp
addr.raw = append(addr.raw, rawAddr...)
// Read the port.
var rawPort [2]byte
if _, err = io.ReadFull(conn, rawPort[:]); err != nil {
return
}
port := int(rawPort[0])<<8 | int(rawPort[1])
addr.portStr = fmt.Sprintf("%d", port)
addr.raw = append(addr.raw, rawPort[:]...)
return
}
// ErrorToReplyCode converts an error to the "best" reply code.
func ErrorToReplyCode(err error) ReplyCode {
if cErr, ok := err.(clientError); ok {
return ReplyCode(cErr)
}
opErr, ok := err.(*net.OpError)
if !ok {
return ReplyGeneralFailure
}
errno, ok := opErr.Err.(syscall.Errno)
if !ok {
return ReplyGeneralFailure
}
switch errno {
case syscall.EADDRNOTAVAIL:
return ReplyAddressNotSupported
case syscall.ETIMEDOUT:
return ReplyTTLExpired
case syscall.ENETUNREACH:
return ReplyNetworkUnreachable
case syscall.EHOSTUNREACH:
return ReplyHostUnreachable
case syscall.ECONNREFUSED, syscall.ECONNRESET:
return ReplyConnectionRefused
default:
return ReplyGeneralFailure
}
}
// Request describes a SOCKS 5 request.
type Request struct {
Auth AuthInfo
Cmd Command
Addr Address
conn net.Conn
}
type clientError ReplyCode
func (e clientError) Error() string {
switch ReplyCode(e) {
case ReplySucceeded:
return "socks5: succeeded"
case ReplyGeneralFailure:
return "socks5: general failure"
case ReplyConnectionNotAllowed:
return "socks5: connection not allowed"
case ReplyNetworkUnreachable:
return "socks5: network unreachable"
case ReplyHostUnreachable:
return "socks5: host unreachable"
case ReplyConnectionRefused:
return "socks5: connection refused"
case ReplyTTLExpired:
return "socks5: ttl expired"
case ReplyCommandNotSupported:
return "socks5: command not supported"
case ReplyAddressNotSupported:
return "socks5: address not supported"
default:
return fmt.Sprintf("socks5: reply code: 0x%02x", e)
}
}
func readByte(conn net.Conn) (byte, error) {
var tmp [1]byte
if _, err := conn.Read(tmp[:]); err != nil {
return 0, err
}
return tmp[0], nil
}
func validateByte(descr string, val, expected byte) error {
if val != expected {
return fmt.Errorf("message field '%s' was 0x%02x (expected 0x%02x)", descr, val, expected)
}
return nil
}

@ -79,7 +79,7 @@ func (r *Rule) match(dst net.IP, dstPort uint16, hostname string) bool {
if r.hostname != "" {
return r.hostname == hostname
}
return r.addr == binary.BigEndian.Uint32(dst)
return r.addr == binary.BigEndian.Uint32(dst.To4())
}
func (rl *RuleList) filterPacket(p *nfqueue.Packet, pinfo *procsnitch.Info, hostname string) FilterResult {

@ -0,0 +1,200 @@
/*
* server.go - SOCSK5 server implementation.
*
* To the extent possible under law, Yawning Angel has waived all copyright and
* related or neighboring rights to or-ctl-filter, using the creative commons
* "cc0" public domain dedication. See LICENSE or
* <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
*/
package sgfw
import (
"bytes"
"fmt"
"io"
"net"
"time"
)
// Handshake attempts to handle a incoming client handshake over the provided
// connection and receive the SOCKS5 request. The routine handles sending
// appropriate errors if applicable, but will not close the connection.
func Handshake(conn net.Conn) (*Request, error) {
// Arm the handshake timeout.
var err error
if err = conn.SetDeadline(time.Now().Add(inboundTimeout)); err != nil {
return nil, err
}
defer func() {
// Disarm the handshake timeout, only propagate the error if
// the handshake was successful.
nerr := conn.SetDeadline(time.Time{})
if err == nil {
err = nerr
}
}()
req := new(Request)
req.conn = conn
// Negotiate the protocol version and authentication method.
var method byte
if method, err = req.negotiateAuth(); err != nil {
return nil, err
}
// Authenticate if neccecary.
if err = req.authenticate(method); err != nil {
return nil, err
}
// Read the client command.
if err = req.readCommand(); err != nil {
return nil, err
}
return req, err
}
// Reply sends a SOCKS5 reply to the corresponding request. The BND.ADDR and
// BND.PORT fields are always set to an address/port corresponding to
// "0.0.0.0:0".
func (req *Request) Reply(code ReplyCode) error {
return req.ReplyAddr(code, nil)
}
// ReplyAddr sends a SOCKS5 reply to the corresponding request. The BND.ADDR
// and BND.PORT fields are specified by addr, or "0.0.0.0:0" if not provided.
func (req *Request) ReplyAddr(code ReplyCode, addr *Address) error {
// The server sends a reply message.
// uint8_t ver (0x05)
// uint8_t rep
// uint8_t rsv (0x00)
// uint8_t atyp
// uint8_t bnd_addr[]
// uint16_t bnd_port
resp := []byte{version, byte(code), rsv}
if addr == nil {
var nilAddr [net.IPv4len + 2]byte
resp = append(resp, atypIPv4)
resp = append(resp, nilAddr[:]...)
} else {
resp = append(resp, addr.raw...)
}
_, err := req.conn.Write(resp[:])
return err
}
func (req *Request) negotiateAuth() (byte, error) {
// The client sends a version identifier/selection message.
// uint8_t ver (0x05)
// uint8_t nmethods (>= 1).
// uint8_t methods[nmethods]
var err error
if err = req.readByteVerify("version", version); err != nil {
return 0, err
}
// Read the number of methods, and the methods.
var nmethods byte
method := byte(authNoAcceptableMethods)
if nmethods, err = req.readByte(); err != nil {
return method, err
}
methods := make([]byte, nmethods)
if _, err := io.ReadFull(req.conn, methods); err != nil {
return 0, err
}
// Pick the best authentication method, prioritizing authenticating
// over not if both options are present.
if bytes.IndexByte(methods, authUsernamePassword) != -1 {
method = authUsernamePassword
} else if bytes.IndexByte(methods, authNoneRequired) != -1 {
method = authNoneRequired
}
// The server sends a method selection message.
// uint8_t ver (0x05)
// uint8_t method
msg := []byte{version, method}
if _, err = req.conn.Write(msg); err != nil {
return 0, err
}
return method, nil
}
func (req *Request) authenticate(method byte) error {
switch method {
case authNoneRequired:
return nil
case authUsernamePassword:
return req.authRFC1929()
case authNoAcceptableMethods:
return fmt.Errorf("no acceptable authentication methods")
default:
// This should never happen as only supported auth methods should be
// negotiated.
return fmt.Errorf("negotiated unsupported method 0x%02x", method)
}
}
func (req *Request) readCommand() error {
// The client sends the request details.
// uint8_t ver (0x05)
// uint8_t cmd
// uint8_t rsv (0x00)
// uint8_t atyp
// uint8_t dst_addr[]
// uint16_t dst_port
var err error
var cmd byte
if err = req.readByteVerify("version", version); err != nil {
req.Reply(ReplyGeneralFailure)
return err
}
if cmd, err = req.readByte(); err != nil {
req.Reply(ReplyGeneralFailure)
return err
}
switch Command(cmd) {
case CommandConnect, CommandTorResolve, CommandTorResolvePTR:
req.Cmd = Command(cmd)
default:
req.Reply(ReplyCommandNotSupported)
return fmt.Errorf("unsupported SOCKS command: 0x%02x", cmd)
}
if err = req.readByteVerify("reserved", rsv); err != nil {
req.Reply(ReplyGeneralFailure)
return err
}
// Read the destination address/port.
err = req.Addr.read(req.conn)
if err == errInvalidAtyp {
req.Reply(ReplyAddressNotSupported)
} else if err != nil {
req.Reply(ReplyGeneralFailure)
}
return err
}
func (req *Request) readByte() (byte, error) {
return readByte(req.conn)
}
func (req *Request) readByteVerify(descr string, expected byte) error {
val, err := req.readByte()
if err != nil {
return err
}
return validateByte(descr, val, expected)
}

@ -0,0 +1,84 @@
/*
* server_rfc1929.go - SOCSK 5 server authentication.
*
* To the extent possible under law, Yawning Angel has waived all copyright and
* related or neighboring rights to or-ctl-filter, using the creative commons
* "cc0" public domain dedication. See LICENSE or
* <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
*/
package sgfw
import (
"fmt"
"io"
)
const (
authRFC1929Ver = 0x01
authRFC1929Success = 0x00
authRFC1929Fail = 0x01
)
// AuthInfo is the RFC 1929 Username/Password authentication data.
type AuthInfo struct {
Uname []byte
Passwd []byte
}
func (req *Request) authRFC1929() (err error) {
sendErrResp := func() {
// Swallow write/flush errors, the auth failure is the relevant error.
resp := []byte{authRFC1929Ver, authRFC1929Fail}
req.conn.Write(resp[:])
}
// The client sends a Username/Password request.
// uint8_t ver (0x01)
// uint8_t ulen (>= 1)
// uint8_t uname[ulen]
// uint8_t plen (>= 1)
// uint8_t passwd[plen]
if err = req.readByteVerify("auth version", authRFC1929Ver); err != nil {
sendErrResp()
return
}
// Read the username.
var ulen byte
if ulen, err = req.readByte(); err != nil {
sendErrResp()
return
} else if ulen < 1 {
sendErrResp()
return fmt.Errorf("username with 0 length")
}
uname := make([]byte, ulen)
if _, err = io.ReadFull(req.conn, uname); err != nil {
sendErrResp()
return
}
// Read the password.
var plen byte
if plen, err = req.readByte(); err != nil {
sendErrResp()
return
} else if plen < 1 {
sendErrResp()
return fmt.Errorf("password with 0 length")
}
passwd := make([]byte, plen)
if _, err = io.ReadFull(req.conn, passwd); err != nil {
sendErrResp()
return
}
req.Auth.Uname = uname
req.Auth.Passwd = passwd
resp := []byte{authRFC1929Ver, authRFC1929Success}
_, err = req.conn.Write(resp[:])
return
}

@ -7,6 +7,9 @@ import (
"sync"
"syscall"
"time"
"bufio"
"encoding/json"
"strings"
"github.com/op/go-logging"
@ -107,8 +110,55 @@ func (fw *Firewall) runFilter() {
}
}
type SocksJsonConfig struct {
SocksListener string
TorSocks string
}
var commentRegexp = regexp.MustCompile("^[ \t]*#")
const defaultSocksCfgPath = "/etc/fw-daemon-socks.json"
func loadSocksConfiguration(configFilePath string) (*SocksJsonConfig, error) {
config := SocksJsonConfig{}
file, err := os.Open(configFilePath)
if err != nil {
return nil, err
}
scanner := bufio.NewScanner(file)
bs := ""
for scanner.Scan() {
line := scanner.Text()
if !commentRegexp.MatchString(line) {
bs += line + "\n"
}
}
if err := json.Unmarshal([]byte(bs), &config); err != nil {
return nil, err
}
return &config, nil
}
func getSocksChainConfig(config *SocksJsonConfig) *socksChainConfig {
// XXX
fields := strings.Split(config.TorSocks, "|")
torSocksNet := fields[0]
torSocksAddr := fields[1]
fields = strings.Split(config.SocksListener, "|")
socksListenNet := fields[0]
socksListenAddr := fields[1]
socksConfig := socksChainConfig{
TargetSocksNet: torSocksNet,
TargetSocksAddr: torSocksAddr,
ListenSocksNet: socksListenNet,
ListenSocksAddr: socksListenAddr,
}
log.Notice("Loaded Socks chain config:")
log.Notice(socksConfig)
return &socksConfig
}
func Main() {
readConfig()
logBackend := setupLoggerBackend(FirewallConfig.LoggingLevel)
@ -141,6 +191,26 @@ func Main() {
fw.loadRules()
/*
go func() {
http.ListenAndServe("localhost:6060", nil)
}()
*/
wg := sync.WaitGroup{}
config, err := loadSocksConfiguration(defaultSocksCfgPath)
if err != nil && !os.IsNotExist(err) {
panic(err)
}
if config != nil {
socksConfig := getSocksChainConfig(config)
chain := NewSocksChain(socksConfig, &wg, fw)
chain.start()
} else {
log.Notice("Did not find SOCKS5 configuration file at", defaultSocksCfgPath, "; ignoring subsystem...")
}
fw.runFilter()
// observe process signals and either

@ -0,0 +1,261 @@
package sgfw
import (
"io"
"net"
"os"
"sync"
"github.com/subgraph/go-procsnitch"
"strconv"
)
type socksChainConfig struct {
TargetSocksNet string
TargetSocksAddr string
ListenSocksNet string
ListenSocksAddr string
}
type socksChain struct {
cfg *socksChainConfig
fw *Firewall
listener net.Listener
wg *sync.WaitGroup
procInfo procsnitch.ProcInfo
}
type socksChainSession struct {
cfg *socksChainConfig
clientConn net.Conn
upstreamConn net.Conn
req *Request
bndAddr *Address
optData []byte
procInfo procsnitch.ProcInfo
server *socksChain
}
const (
socksVerdictDrop = 1
socksVerdictAccept = 2
)
type pendingSocksConnection struct {
pol *Policy
hname string
destIP net.IP
destPort uint16
pinfo *procsnitch.Info
verdict chan int
}
func (sc *pendingSocksConnection) policy() *Policy {
return sc.pol
}
func (sc *pendingSocksConnection) procInfo() *procsnitch.Info {
return sc.pinfo
}
func (sc *pendingSocksConnection) hostname() string {
return sc.hname
}
func (sc *pendingSocksConnection) dst() net.IP {
return sc.destIP
}
func (sc *pendingSocksConnection) dstPort() uint16 {
return sc.destPort
}
func (sc *pendingSocksConnection) deliverVerdict(v int) {
sc.verdict <- v
close(sc.verdict)
}
func (sc *pendingSocksConnection) accept() { sc.deliverVerdict(socksVerdictAccept) }
func (sc *pendingSocksConnection) drop() { sc.deliverVerdict(socksVerdictDrop) }
func (sc *pendingSocksConnection) print() string { return "socks connection" }
func NewSocksChain(cfg *socksChainConfig, wg *sync.WaitGroup, fw *Firewall) *socksChain {
chain := socksChain{
cfg: cfg,
fw: fw,
wg: wg,
procInfo: procsnitch.SystemProcInfo{},
}
return &chain
}
// Start initializes the SOCKS 5 server and starts
// accepting connections.
func (s *socksChain) start() {
var err error
s.listener, err = net.Listen(s.cfg.ListenSocksNet, s.cfg.ListenSocksAddr)
if err != nil {
log.Errorf("ERR/socks: Failed to listen on the socks address: %v", err)
os.Exit(1)
}
s.wg.Add(1)
go s.socksAcceptLoop()
}
func (s *socksChain) socksAcceptLoop() error {
defer s.wg.Done()
defer s.listener.Close()
for {
conn, err := s.listener.Accept()
if err != nil {
if e, ok := err.(net.Error); ok && !e.Temporary() {
log.Infof("ERR/socks: Failed to Accept(): %v", err)
return err
}
continue
}
session := &socksChainSession{cfg: s.cfg, clientConn: conn, procInfo: s.procInfo, server: s}
go session.sessionWorker()
}
}
func (c *socksChainSession) sessionWorker() {
defer c.clientConn.Close()
clientAddr := c.clientConn.RemoteAddr()
log.Infof("INFO/socks: New connection from: %v", clientAddr)
// Do the SOCKS handshake with the client, and read the command.
var err error
if c.req, err = Handshake(c.clientConn); err != nil {
log.Infof("ERR/socks: Failed SOCKS5 handshake: %v", err)
return
}
switch c.req.Cmd {
case CommandTorResolve, CommandTorResolvePTR:
err = c.dispatchTorSOCKS()
// If we reach here, the request has been dispatched and completed.
if err == nil {
// Successfully even, send the response back with the addresc.
c.req.ReplyAddr(ReplySucceeded, c.bndAddr)
}
case CommandConnect:
if !c.filterConnect() {
c.req.Reply(ReplyConnectionRefused)
return
}
c.handleConnect()
default:
// Should *NEVER* happen, validated as part of handshake.
log.Infof("BUG/socks: Unsupported SOCKS command: 0x%02x", c.req.Cmd)
c.req.Reply(ReplyCommandNotSupported)
}
}
func (c *socksChainSession) addressDetails() (string, net.IP, uint16) {
addr := c.req.Addr
host, pstr := addr.HostPort()
port, err := strconv.ParseUint(pstr, 10, 16)
if err != nil || port == 0 || port > 0xFFFF {
log.Warningf("Illegal port value in socks address: %v", addr)
return "", nil, 0
}
if addr.Type() == 3 {
return host, nil, uint16(port)
}
ip := net.ParseIP(host)
if ip == nil {
log.Warningf("Failed to extract address information from socks address: %v", addr)
}
return "", ip, uint16(port)
}
func (c *socksChainSession) filterConnect() bool {
pinfo := procsnitch.FindProcessForConnection(c.clientConn, c.procInfo)
if pinfo == nil {
log.Warningf("No proc found for connection from: %s", c.clientConn.RemoteAddr())
return false
}
policy := c.server.fw.PolicyForPath(pinfo.ExePath)
hostname, ip, port := c.addressDetails()
if ip == nil && hostname == "" {
return false
}
result := policy.rules.filter(nil, ip, port, hostname, pinfo)
switch result {
case FILTER_DENY:
return false
case FILTER_ALLOW:
return true
case FILTER_PROMPT:
pending := &pendingSocksConnection{
pol: policy,
hname: hostname,
destIP: ip,
destPort: port,
pinfo: pinfo,
verdict: make(chan int),
}
policy.processPromptResult(pending)
v := <-pending.verdict
if v == socksVerdictAccept {
return true
}
}
return false
}
func (c *socksChainSession) handleConnect() {
err := c.dispatchTorSOCKS()
if err != nil {
return
}
c.req.Reply(ReplySucceeded)
defer c.upstreamConn.Close()
if c.optData != nil {
if _, err = c.upstreamConn.Write(c.optData); err != nil {
log.Infof("ERR/socks: Failed writing OptData: %v", err)
return
}
c.optData = nil
}
// A upstream connection has been established, push data back and forth
// till the session is done.
c.forwardTraffic()
log.Infof("INFO/socks: Closed SOCKS connection from: %v", c.clientConn.RemoteAddr())
}
func (c *socksChainSession) forwardTraffic() {
var wg sync.WaitGroup
wg.Add(2)
copyLoop := func(dst, src net.Conn) {
defer wg.Done()
defer dst.Close()
io.Copy(dst, src)
}
go copyLoop(c.upstreamConn, c.clientConn)
go copyLoop(c.clientConn, c.upstreamConn)
wg.Wait()
}
func (c *socksChainSession) dispatchTorSOCKS() (err error) {
c.upstreamConn, c.bndAddr, err = Redispatch(c.cfg.TargetSocksNet, c.cfg.TargetSocksAddr, c.req)
if err != nil {
c.req.Reply(ErrorToReplyCode(err))
}
return
}

@ -0,0 +1,305 @@
package sgfw
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"strings"
"sync"
"testing"
"time"
"github.com/subgraph/fw-daemon/socks5"
"golang.org/x/net/proxy"
)
// MortalService can be killed at any time.
type MortalService struct {
network string
address string
connectionCallback func(net.Conn) error
conns []net.Conn
quit chan bool
listener net.Listener
waitGroup *sync.WaitGroup
}
// NewMortalService creates a new MortalService
func NewMortalService(network, address string, connectionCallback func(net.Conn) error) *MortalService {
l := MortalService{
network: network,
address: address,
connectionCallback: connectionCallback,
conns: make([]net.Conn, 0, 10),
quit: make(chan bool),
waitGroup: &sync.WaitGroup{},
}
return &l
}
// Stop will kill our listener and all it's connections
func (l *MortalService) Stop() {
log.Infof("stopping listener service %s:%s", l.network, l.address)
close(l.quit)
if l.listener != nil {
l.listener.Close()
}
l.waitGroup.Wait()
}
func (l *MortalService) acceptLoop() {
defer l.waitGroup.Done()
defer func() {
log.Infof("stoping listener service %s:%s", l.network, l.address)
for i, conn := range l.conns {
if conn != nil {
log.Infof("Closing connection #%d", i)
conn.Close()
}
}
}()
defer l.listener.Close()
for {
conn, err := l.listener.Accept()
if nil != err {
if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() {
continue
} else {
log.Infof("MortalService connection accept failure: %s\n", err)
select {
case <-l.quit:
return
default:
}
continue
}
}
l.conns = append(l.conns, conn)
go l.handleConnection(conn, len(l.conns)-1)
}
}
func (l *MortalService) createDeadlinedListener() error {
if l.network == "tcp" {
tcpAddr, err := net.ResolveTCPAddr("tcp", l.address)
if err != nil {
return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err)
}
tcpListener, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err)
}
tcpListener.SetDeadline(time.Now().Add(1e9))
l.listener = tcpListener
return nil
} else if l.network == "unix" {
unixAddr, err := net.ResolveUnixAddr("unix", l.address)
if err != nil {
return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err)
}
unixListener, err := net.ListenUnix("unix", unixAddr)
if err != nil {
return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err)
}
unixListener.SetDeadline(time.Now().Add(1e9))
l.listener = unixListener
return nil
} else {
panic("")
}
return nil
}
// Start the MortalService
func (l *MortalService) Start() error {
var err error
err = l.createDeadlinedListener()
if err != nil {
return err
}
l.waitGroup.Add(1)
go l.acceptLoop()
return nil
}
func (l *MortalService) handleConnection(conn net.Conn, id int) error {
defer func() {
log.Infof("Closing connection #%d", id)
conn.Close()
l.conns[id] = nil
}()
log.Infof("Starting connection #%d", id)
for {
if err := l.connectionCallback(conn); err != nil {
log.Error(err.Error())
return err
}
return nil
}
}
type AccumulatingService struct {
net, address string
banner string
buffer bytes.Buffer
mortalService *MortalService
hasProtocolInfo bool
hasAuthenticate bool
receivedChan chan bool
}
func NewAccumulatingService(net, address, banner string) *AccumulatingService {
l := AccumulatingService{
net: net,
address: address,
banner: banner,
hasProtocolInfo: true,
hasAuthenticate: true,
receivedChan: make(chan bool, 0),
}
return &l
}
func (a *AccumulatingService) Start() {
a.mortalService = NewMortalService(a.net, a.address, a.SessionWorker)
a.mortalService.Start()
}
func (a *AccumulatingService) Stop() {
fmt.Println("AccumulatingService STOP")
a.mortalService.Stop()
}
func (a *AccumulatingService) WaitUntilReceived() {
<-a.receivedChan
}
func (a *AccumulatingService) SessionWorker(conn net.Conn) error {
connReader := bufio.NewReader(conn)
conn.Write([]byte(a.banner))
for {
line, err := connReader.ReadBytes('\n')
if err != nil {
fmt.Printf("AccumulatingService read error: %s\n", err)
}
lineStr := strings.TrimSpace(string(line))
a.buffer.WriteString(lineStr + "\n")
a.receivedChan <- true
}
return nil
}
func fakeSocksSessionWorker(clientConn net.Conn, targetNet, targetAddr string) error {
defer clientConn.Close()
clientAddr := clientConn.RemoteAddr()
fmt.Printf("INFO/socks: New connection from: %v\n", clientAddr)
// Do the SOCKS handshake with the client, and read the command.
req, err := socks5.Handshake(clientConn)
if err != nil {
panic(fmt.Sprintf("ERR/socks: Failed SOCKS5 handshake: %v", err))
}
var upstreamConn net.Conn
upstreamConn, err = net.Dial(targetNet, targetAddr)
if err != nil {
panic(err)
}
defer upstreamConn.Close()
req.Reply(socks5.ReplySucceeded)
// A upstream connection has been established, push data back and forth
// till the session is done.
var wg sync.WaitGroup
wg.Add(2)
copyLoop := func(dst, src net.Conn) {
defer wg.Done()
defer dst.Close()
io.Copy(dst, src)
}
go copyLoop(upstreamConn, clientConn)
go copyLoop(clientConn, upstreamConn)
wg.Wait()
fmt.Printf("INFO/socks: Closed SOCKS connection from: %v\n", clientAddr)
return nil
}
func TestSocksServerProxyChain(t *testing.T) {
// socks client ---> socks chain ---> socks server ---> service
socksChainNet := "tcp"
socksChainAddr := "127.0.0.1:7750"
socksServerNet := "tcp"
socksServerAddr := "127.0.0.1:8850"
serviceNet := "tcp"
serviceAddr := "127.0.0.1:9950"
banner := "meow 123\r\n"
// setup the service listener
service := NewAccumulatingService(serviceNet, serviceAddr, banner)
service.Start()
defer service.Stop()
// setup the "socks server"
session := func(clientConn net.Conn) error {
return fakeSocksSessionWorker(clientConn, serviceNet, serviceAddr)
}
socksService := NewMortalService(socksServerNet, socksServerAddr, session)
socksService.Start()
defer socksService.Stop()
// setup the SOCKS proxy chain
socksConfig := socksChainConfig{
TargetSocksNet: socksServerNet,
TargetSocksAddr: socksServerAddr,
ListenSocksNet: socksChainNet,
ListenSocksAddr: socksChainAddr,
}
wg := sync.WaitGroup{}
ds := dbusServer{}
chain := NewSocksChain(&socksConfig, &wg, &ds)
chain.start()
// setup the SOCKS client
auth := proxy.Auth{
User: "",
Password: "",
}
forward := proxy.NewPerHost(proxy.Direct, proxy.Direct)
socksClient, err := proxy.SOCKS5(socksChainNet, socksChainAddr, &auth, forward)
conn, err := socksClient.Dial(serviceNet, serviceAddr)
if err != nil {
panic(err)
}
// read a banner from the service
rd := bufio.NewReader(conn)
line := []byte{}
line, err = rd.ReadBytes('\n')
if err != nil {
panic(err)
}
if string(line) != banner {
t.Errorf("Did not receive expected banner. Got %s, wanted %s\n", string(line), banner)
t.Fail()
}
// send the service some data and verify it was received
clientData := "hello world\r\n"
conn.Write([]byte(clientData))
service.WaitUntilReceived()
if service.buffer.String() != strings.TrimSpace(clientData)+"\n" {
t.Errorf("Client sent %s but service only received %s\n", "hello world\n", service.buffer.String())
t.Fail()
}
}
Loading…
Cancel
Save