diff --git a/sgfw/client.go b/sgfw/client.go
new file mode 100644
index 0000000..48be55c
--- /dev/null
+++ b/sgfw/client.go
@@ -0,0 +1,153 @@
+/*
+ * client.go - SOCSK5 client implementation.
+ *
+ * To the extent possible under law, Yawning Angel has waived all copyright and
+ * related or neighboring rights to or-ctl-filter, using the creative commons
+ * "cc0" public domain dedication. See LICENSE or
+ * for full details.
+ */
+
+package sgfw
+
+import (
+ "fmt"
+ "io"
+ "net"
+ "time"
+)
+
+// Redispatch dials the provided proxy and redispatches an existing request.
+func Redispatch(proxyNet, proxyAddr string, req *Request) (conn net.Conn, bndAddr *Address, err error) {
+ defer func() {
+ if err != nil && conn != nil {
+ conn.Close()
+ }
+ }()
+
+ conn, err = clientHandshake(proxyNet, proxyAddr, req)
+ if err != nil {
+ return nil, nil, err
+ }
+ bndAddr, err = clientCmd(conn, req)
+ return
+}
+
+func clientHandshake(proxyNet, proxyAddr string, req *Request) (net.Conn, error) {
+ conn, err := net.Dial(proxyNet, proxyAddr)
+ if err != nil {
+ return nil, err
+ }
+ if err := conn.SetDeadline(time.Now().Add(requestTimeout)); err != nil {
+ return conn, err
+ }
+ authMethod, err := clientNegotiateAuth(conn, req)
+ if err != nil {
+ return conn, err
+ }
+ if err := clientAuthenticate(conn, req, authMethod); err != nil {
+ return conn, err
+ }
+ if err := conn.SetDeadline(time.Time{}); err != nil {
+ return conn, err
+ }
+
+ return conn, nil
+}
+
+func clientNegotiateAuth(conn net.Conn, req *Request) (byte, error) {
+ useRFC1929 := req.Auth.Uname != nil && req.Auth.Passwd != nil
+ // XXX: Validate uname/passwd lengths, though should always be valid.
+
+ var buf [3]byte
+ buf[0] = version
+ buf[1] = 1
+ if useRFC1929 {
+ buf[2] = authUsernamePassword
+ } else {
+ buf[2] = authNoneRequired
+ }
+
+ if _, err := conn.Write(buf[:]); err != nil {
+ return authNoAcceptableMethods, err
+ }
+
+ var resp [2]byte
+ if _, err := io.ReadFull(conn, resp[:]); err != nil {
+ return authNoAcceptableMethods, err
+ }
+ if err := validateByte("version", resp[0], version); err != nil {
+ return authNoAcceptableMethods, err
+ }
+ if err := validateByte("method", resp[1], buf[2]); err != nil {
+ return authNoAcceptableMethods, err
+ }
+
+ return resp[1], nil
+}
+
+func clientAuthenticate(conn net.Conn, req *Request, authMethod byte) error {
+ switch authMethod {
+ case authNoneRequired:
+ case authUsernamePassword:
+ var buf []byte
+ buf = append(buf, authRFC1929Ver)
+ buf = append(buf, byte(len(req.Auth.Uname)))
+ buf = append(buf, req.Auth.Uname...)
+ buf = append(buf, byte(len(req.Auth.Passwd)))
+ buf = append(buf, req.Auth.Passwd...)
+ if _, err := conn.Write(buf); err != nil {
+ return err
+ }
+
+ var resp [2]byte
+ if _, err := io.ReadFull(conn, resp[:]); err != nil {
+ return err
+ }
+ if err := validateByte("version", resp[0], authRFC1929Ver); err != nil {
+ return err
+ }
+ if err := validateByte("status", resp[1], authRFC1929Success); err != nil {
+ return err
+ }
+ default:
+ panic(fmt.Sprintf("unknown authentication method: 0x%02x", authMethod))
+ }
+ return nil
+}
+
+func clientCmd(conn net.Conn, req *Request) (*Address, error) {
+ var buf []byte
+ buf = append(buf, version)
+ buf = append(buf, byte(req.Cmd))
+ buf = append(buf, rsv)
+ buf = append(buf, req.Addr.raw...)
+ if _, err := conn.Write(buf); err != nil {
+ return nil, err
+ }
+
+ var respHdr [3]byte
+ if _, err := io.ReadFull(conn, respHdr[:]); err != nil {
+ return nil, err
+ }
+
+ if err := validateByte("version", respHdr[0], version); err != nil {
+ return nil, err
+ }
+ if err := validateByte("rep", respHdr[1], byte(ReplySucceeded)); err != nil {
+ return nil, clientError(respHdr[1])
+ }
+ if err := validateByte("rsv", respHdr[2], rsv); err != nil {
+ return nil, err
+ }
+
+ var bndAddr Address
+ if err := bndAddr.read(conn); err != nil {
+ return nil, err
+ }
+
+ if err := conn.SetDeadline(time.Time{}); err != nil {
+ return nil, err
+ }
+
+ return &bndAddr, nil
+}
diff --git a/sgfw/common.go b/sgfw/common.go
new file mode 100644
index 0000000..1717722
--- /dev/null
+++ b/sgfw/common.go
@@ -0,0 +1,280 @@
+/*
+ * common.go - SOCSK5 common definitons/routines.
+ *
+ * To the extent possible under law, Yawning Angel has waived all copyright and
+ * related or neighboring rights to or-ctl-filter, using the creative commons
+ * "cc0" public domain dedication. See LICENSE or
+ * for full details.
+ */
+
+// Package socks5 implements a SOCKS5 client/server. For more information see
+// RFC 1928 and RFC 1929.
+//
+// Notes:
+// * GSSAPI authentication, is NOT supported.
+// * The authentication provided by the client is always accepted.
+// * A lot of the code is shamelessly stolen from obfs4proxy.
+package sgfw
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "strconv"
+ "syscall"
+ "time"
+)
+
+const (
+ version = 0x05
+ rsv = 0x00
+
+ atypIPv4 = 0x01
+ atypDomainName = 0x03
+ atypIPv6 = 0x04
+
+ authNoneRequired = 0x00
+ authUsernamePassword = 0x02
+ authNoAcceptableMethods = 0xff
+
+ inboundTimeout = 5 * time.Second
+ requestTimeout = 30 * time.Second
+)
+
+var errInvalidAtyp = errors.New("invalid address type")
+
+// ReplyCode is a SOCKS 5 reply code.
+type ReplyCode byte
+
+// The various SOCKS 5 reply codes from RFC 1928.
+const (
+ ReplySucceeded ReplyCode = iota
+ ReplyGeneralFailure
+ ReplyConnectionNotAllowed
+ ReplyNetworkUnreachable
+ ReplyHostUnreachable
+ ReplyConnectionRefused
+ ReplyTTLExpired
+ ReplyCommandNotSupported
+ ReplyAddressNotSupported
+)
+
+// Command is a SOCKS 5 command.
+type Command byte
+
+// The various SOCKS 5 commands.
+const (
+ CommandConnect Command = 0x01
+ CommandTorResolve Command = 0xf0
+ CommandTorResolvePTR Command = 0xf1
+)
+
+// Address is a SOCKS 5 address + port.
+type Address struct {
+ atyp uint8
+ raw []byte
+ addrStr string
+ portStr string
+}
+
+// FromString parses the provided "host:port" format address and populates the
+// Address fields.
+func (addr *Address) FromString(addrStr string) (err error) {
+ addr.addrStr, addr.portStr, err = net.SplitHostPort(addrStr)
+ if err != nil {
+ return
+ }
+
+ var raw []byte
+ if ip := net.ParseIP(addr.addrStr); ip != nil {
+ if v4Addr := ip.To4(); v4Addr != nil {
+ raw = append(raw, atypIPv4)
+ raw = append(raw, v4Addr...)
+ } else if v6Addr := ip.To16(); v6Addr != nil {
+ raw = append(raw, atypIPv6)
+ raw = append(raw, v6Addr...)
+ } else {
+ return errors.New("unsupported IP address type")
+ }
+ } else {
+ // Must be a FQDN.
+ if len(addr.addrStr) > 255 {
+ return fmt.Errorf("invalid FQDN, len > 255 bytes (%d bytes)", len(addr.addrStr))
+ }
+ raw = append(raw, atypDomainName)
+ raw = append(raw, addr.addrStr...)
+ }
+
+ var port uint64
+ if port, err = strconv.ParseUint(addr.portStr, 10, 16); err != nil {
+ return
+ }
+ raw = append(raw, byte(port>>8))
+ raw = append(raw, byte(port&0xff))
+
+ addr.raw = raw
+ return
+}
+
+// String returns the string representation of the address, in "host:port"
+// format.
+func (addr *Address) String() string {
+ return addr.addrStr + ":" + addr.portStr
+}
+
+// HostPort returns the string representation of the addess, split into the
+// host and port components.
+func (addr *Address) HostPort() (string, string) {
+ return addr.addrStr, addr.portStr
+}
+
+// Type returns the address type from the connect command this address was
+// parsed from
+func (addr *Address) Type() uint8 {
+ return addr.atyp
+}
+
+func (addr *Address) read(conn net.Conn) (err error) {
+ // The address looks like:
+ // uint8_t atyp
+ // uint8_t addr[] (Length depends on atyp)
+ // uint16_t port
+
+ // Read the atype.
+ var atyp byte
+ if atyp, err = readByte(conn); err != nil {
+ return
+ }
+ addr.raw = append(addr.raw, atyp)
+
+ // Read the address.
+ var rawAddr []byte
+ switch atyp {
+ case atypIPv4:
+ rawAddr = make([]byte, net.IPv4len)
+ if _, err = io.ReadFull(conn, rawAddr); err != nil {
+ return
+ }
+ v4Addr := net.IPv4(rawAddr[0], rawAddr[1], rawAddr[2], rawAddr[3])
+ addr.addrStr = v4Addr.String()
+ case atypDomainName:
+ var alen byte
+ if alen, err = readByte(conn); err != nil {
+ return
+ }
+ if alen == 0 {
+ return fmt.Errorf("domain name with 0 length")
+ }
+ rawAddr = make([]byte, alen)
+ addr.raw = append(addr.raw, alen)
+ if _, err = io.ReadFull(conn, rawAddr); err != nil {
+ return
+ }
+ addr.addrStr = string(rawAddr)
+ case atypIPv6:
+ rawAddr = make([]byte, net.IPv6len)
+ if _, err = io.ReadFull(conn, rawAddr); err != nil {
+ return
+ }
+ v6Addr := make(net.IP, net.IPv6len)
+ copy(v6Addr[:], rawAddr)
+ addr.addrStr = fmt.Sprintf("[%s]", v6Addr.String())
+ default:
+ return errInvalidAtyp
+ }
+ addr.atyp = atyp
+ addr.raw = append(addr.raw, rawAddr...)
+
+ // Read the port.
+ var rawPort [2]byte
+ if _, err = io.ReadFull(conn, rawPort[:]); err != nil {
+ return
+ }
+ port := int(rawPort[0])<<8 | int(rawPort[1])
+ addr.portStr = fmt.Sprintf("%d", port)
+ addr.raw = append(addr.raw, rawPort[:]...)
+
+ return
+}
+
+// ErrorToReplyCode converts an error to the "best" reply code.
+func ErrorToReplyCode(err error) ReplyCode {
+ if cErr, ok := err.(clientError); ok {
+ return ReplyCode(cErr)
+ }
+ opErr, ok := err.(*net.OpError)
+ if !ok {
+ return ReplyGeneralFailure
+ }
+
+ errno, ok := opErr.Err.(syscall.Errno)
+ if !ok {
+ return ReplyGeneralFailure
+ }
+ switch errno {
+ case syscall.EADDRNOTAVAIL:
+ return ReplyAddressNotSupported
+ case syscall.ETIMEDOUT:
+ return ReplyTTLExpired
+ case syscall.ENETUNREACH:
+ return ReplyNetworkUnreachable
+ case syscall.EHOSTUNREACH:
+ return ReplyHostUnreachable
+ case syscall.ECONNREFUSED, syscall.ECONNRESET:
+ return ReplyConnectionRefused
+ default:
+ return ReplyGeneralFailure
+ }
+}
+
+// Request describes a SOCKS 5 request.
+type Request struct {
+ Auth AuthInfo
+ Cmd Command
+ Addr Address
+
+ conn net.Conn
+}
+
+type clientError ReplyCode
+
+func (e clientError) Error() string {
+ switch ReplyCode(e) {
+ case ReplySucceeded:
+ return "socks5: succeeded"
+ case ReplyGeneralFailure:
+ return "socks5: general failure"
+ case ReplyConnectionNotAllowed:
+ return "socks5: connection not allowed"
+ case ReplyNetworkUnreachable:
+ return "socks5: network unreachable"
+ case ReplyHostUnreachable:
+ return "socks5: host unreachable"
+ case ReplyConnectionRefused:
+ return "socks5: connection refused"
+ case ReplyTTLExpired:
+ return "socks5: ttl expired"
+ case ReplyCommandNotSupported:
+ return "socks5: command not supported"
+ case ReplyAddressNotSupported:
+ return "socks5: address not supported"
+ default:
+ return fmt.Sprintf("socks5: reply code: 0x%02x", e)
+ }
+}
+
+func readByte(conn net.Conn) (byte, error) {
+ var tmp [1]byte
+ if _, err := conn.Read(tmp[:]); err != nil {
+ return 0, err
+ }
+ return tmp[0], nil
+}
+
+func validateByte(descr string, val, expected byte) error {
+ if val != expected {
+ return fmt.Errorf("message field '%s' was 0x%02x (expected 0x%02x)", descr, val, expected)
+ }
+ return nil
+}
diff --git a/sgfw/rules.go b/sgfw/rules.go
index 72c9db4..66b7e50 100644
--- a/sgfw/rules.go
+++ b/sgfw/rules.go
@@ -79,7 +79,7 @@ func (r *Rule) match(dst net.IP, dstPort uint16, hostname string) bool {
if r.hostname != "" {
return r.hostname == hostname
}
- return r.addr == binary.BigEndian.Uint32(dst)
+ return r.addr == binary.BigEndian.Uint32(dst.To4())
}
func (rl *RuleList) filterPacket(p *nfqueue.Packet, pinfo *procsnitch.Info, hostname string) FilterResult {
diff --git a/sgfw/server.go b/sgfw/server.go
new file mode 100644
index 0000000..498a548
--- /dev/null
+++ b/sgfw/server.go
@@ -0,0 +1,200 @@
+/*
+ * server.go - SOCSK5 server implementation.
+ *
+ * To the extent possible under law, Yawning Angel has waived all copyright and
+ * related or neighboring rights to or-ctl-filter, using the creative commons
+ * "cc0" public domain dedication. See LICENSE or
+ * for full details.
+ */
+
+package sgfw
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "net"
+ "time"
+)
+
+// Handshake attempts to handle a incoming client handshake over the provided
+// connection and receive the SOCKS5 request. The routine handles sending
+// appropriate errors if applicable, but will not close the connection.
+func Handshake(conn net.Conn) (*Request, error) {
+ // Arm the handshake timeout.
+ var err error
+ if err = conn.SetDeadline(time.Now().Add(inboundTimeout)); err != nil {
+ return nil, err
+ }
+ defer func() {
+ // Disarm the handshake timeout, only propagate the error if
+ // the handshake was successful.
+ nerr := conn.SetDeadline(time.Time{})
+ if err == nil {
+ err = nerr
+ }
+ }()
+
+ req := new(Request)
+ req.conn = conn
+
+ // Negotiate the protocol version and authentication method.
+ var method byte
+ if method, err = req.negotiateAuth(); err != nil {
+ return nil, err
+ }
+
+ // Authenticate if neccecary.
+ if err = req.authenticate(method); err != nil {
+ return nil, err
+ }
+
+ // Read the client command.
+ if err = req.readCommand(); err != nil {
+ return nil, err
+ }
+
+ return req, err
+}
+
+// Reply sends a SOCKS5 reply to the corresponding request. The BND.ADDR and
+// BND.PORT fields are always set to an address/port corresponding to
+// "0.0.0.0:0".
+func (req *Request) Reply(code ReplyCode) error {
+ return req.ReplyAddr(code, nil)
+}
+
+// ReplyAddr sends a SOCKS5 reply to the corresponding request. The BND.ADDR
+// and BND.PORT fields are specified by addr, or "0.0.0.0:0" if not provided.
+func (req *Request) ReplyAddr(code ReplyCode, addr *Address) error {
+ // The server sends a reply message.
+ // uint8_t ver (0x05)
+ // uint8_t rep
+ // uint8_t rsv (0x00)
+ // uint8_t atyp
+ // uint8_t bnd_addr[]
+ // uint16_t bnd_port
+
+ resp := []byte{version, byte(code), rsv}
+ if addr == nil {
+ var nilAddr [net.IPv4len + 2]byte
+ resp = append(resp, atypIPv4)
+ resp = append(resp, nilAddr[:]...)
+ } else {
+ resp = append(resp, addr.raw...)
+ }
+
+ _, err := req.conn.Write(resp[:])
+ return err
+
+}
+
+func (req *Request) negotiateAuth() (byte, error) {
+ // The client sends a version identifier/selection message.
+ // uint8_t ver (0x05)
+ // uint8_t nmethods (>= 1).
+ // uint8_t methods[nmethods]
+
+ var err error
+ if err = req.readByteVerify("version", version); err != nil {
+ return 0, err
+ }
+
+ // Read the number of methods, and the methods.
+ var nmethods byte
+ method := byte(authNoAcceptableMethods)
+ if nmethods, err = req.readByte(); err != nil {
+ return method, err
+ }
+ methods := make([]byte, nmethods)
+ if _, err := io.ReadFull(req.conn, methods); err != nil {
+ return 0, err
+ }
+
+ // Pick the best authentication method, prioritizing authenticating
+ // over not if both options are present.
+ if bytes.IndexByte(methods, authUsernamePassword) != -1 {
+ method = authUsernamePassword
+ } else if bytes.IndexByte(methods, authNoneRequired) != -1 {
+ method = authNoneRequired
+ }
+
+ // The server sends a method selection message.
+ // uint8_t ver (0x05)
+ // uint8_t method
+ msg := []byte{version, method}
+ if _, err = req.conn.Write(msg); err != nil {
+ return 0, err
+ }
+
+ return method, nil
+}
+
+func (req *Request) authenticate(method byte) error {
+ switch method {
+ case authNoneRequired:
+ return nil
+ case authUsernamePassword:
+ return req.authRFC1929()
+ case authNoAcceptableMethods:
+ return fmt.Errorf("no acceptable authentication methods")
+ default:
+ // This should never happen as only supported auth methods should be
+ // negotiated.
+ return fmt.Errorf("negotiated unsupported method 0x%02x", method)
+ }
+}
+
+func (req *Request) readCommand() error {
+ // The client sends the request details.
+ // uint8_t ver (0x05)
+ // uint8_t cmd
+ // uint8_t rsv (0x00)
+ // uint8_t atyp
+ // uint8_t dst_addr[]
+ // uint16_t dst_port
+
+ var err error
+ var cmd byte
+ if err = req.readByteVerify("version", version); err != nil {
+ req.Reply(ReplyGeneralFailure)
+ return err
+ }
+ if cmd, err = req.readByte(); err != nil {
+ req.Reply(ReplyGeneralFailure)
+ return err
+ }
+ switch Command(cmd) {
+ case CommandConnect, CommandTorResolve, CommandTorResolvePTR:
+ req.Cmd = Command(cmd)
+ default:
+ req.Reply(ReplyCommandNotSupported)
+ return fmt.Errorf("unsupported SOCKS command: 0x%02x", cmd)
+ }
+ if err = req.readByteVerify("reserved", rsv); err != nil {
+ req.Reply(ReplyGeneralFailure)
+ return err
+ }
+
+ // Read the destination address/port.
+ err = req.Addr.read(req.conn)
+ if err == errInvalidAtyp {
+ req.Reply(ReplyAddressNotSupported)
+ } else if err != nil {
+ req.Reply(ReplyGeneralFailure)
+ }
+
+ return err
+}
+
+func (req *Request) readByte() (byte, error) {
+ return readByte(req.conn)
+}
+
+func (req *Request) readByteVerify(descr string, expected byte) error {
+ val, err := req.readByte()
+ if err != nil {
+ return err
+ }
+ return validateByte(descr, val, expected)
+}
diff --git a/sgfw/server_rfc1929.go b/sgfw/server_rfc1929.go
new file mode 100644
index 0000000..146559d
--- /dev/null
+++ b/sgfw/server_rfc1929.go
@@ -0,0 +1,84 @@
+/*
+ * server_rfc1929.go - SOCSK 5 server authentication.
+ *
+ * To the extent possible under law, Yawning Angel has waived all copyright and
+ * related or neighboring rights to or-ctl-filter, using the creative commons
+ * "cc0" public domain dedication. See LICENSE or
+ * for full details.
+ */
+
+package sgfw
+
+import (
+ "fmt"
+ "io"
+)
+
+const (
+ authRFC1929Ver = 0x01
+ authRFC1929Success = 0x00
+ authRFC1929Fail = 0x01
+)
+
+// AuthInfo is the RFC 1929 Username/Password authentication data.
+type AuthInfo struct {
+ Uname []byte
+ Passwd []byte
+}
+
+func (req *Request) authRFC1929() (err error) {
+ sendErrResp := func() {
+ // Swallow write/flush errors, the auth failure is the relevant error.
+ resp := []byte{authRFC1929Ver, authRFC1929Fail}
+ req.conn.Write(resp[:])
+ }
+
+ // The client sends a Username/Password request.
+ // uint8_t ver (0x01)
+ // uint8_t ulen (>= 1)
+ // uint8_t uname[ulen]
+ // uint8_t plen (>= 1)
+ // uint8_t passwd[plen]
+
+ if err = req.readByteVerify("auth version", authRFC1929Ver); err != nil {
+ sendErrResp()
+ return
+ }
+
+ // Read the username.
+ var ulen byte
+ if ulen, err = req.readByte(); err != nil {
+ sendErrResp()
+ return
+ } else if ulen < 1 {
+ sendErrResp()
+ return fmt.Errorf("username with 0 length")
+ }
+ uname := make([]byte, ulen)
+ if _, err = io.ReadFull(req.conn, uname); err != nil {
+ sendErrResp()
+ return
+ }
+
+ // Read the password.
+ var plen byte
+ if plen, err = req.readByte(); err != nil {
+ sendErrResp()
+ return
+ } else if plen < 1 {
+ sendErrResp()
+ return fmt.Errorf("password with 0 length")
+ }
+ passwd := make([]byte, plen)
+ if _, err = io.ReadFull(req.conn, passwd); err != nil {
+ sendErrResp()
+ return
+ }
+
+ req.Auth.Uname = uname
+ req.Auth.Passwd = passwd
+
+ resp := []byte{authRFC1929Ver, authRFC1929Success}
+ _, err = req.conn.Write(resp[:])
+ return
+}
diff --git a/sgfw/sgfw.go b/sgfw/sgfw.go
index c4a0b22..21679d5 100644
--- a/sgfw/sgfw.go
+++ b/sgfw/sgfw.go
@@ -7,6 +7,9 @@ import (
"sync"
"syscall"
"time"
+ "bufio"
+ "encoding/json"
+ "strings"
"github.com/op/go-logging"
@@ -107,8 +110,55 @@ func (fw *Firewall) runFilter() {
}
}
+type SocksJsonConfig struct {
+ SocksListener string
+ TorSocks string
+}
+
var commentRegexp = regexp.MustCompile("^[ \t]*#")
+const defaultSocksCfgPath = "/etc/fw-daemon-socks.json"
+
+func loadSocksConfiguration(configFilePath string) (*SocksJsonConfig, error) {
+ config := SocksJsonConfig{}
+ file, err := os.Open(configFilePath)
+ if err != nil {
+ return nil, err
+ }
+ scanner := bufio.NewScanner(file)
+ bs := ""
+ for scanner.Scan() {
+ line := scanner.Text()
+ if !commentRegexp.MatchString(line) {
+ bs += line + "\n"
+ }
+ }
+ if err := json.Unmarshal([]byte(bs), &config); err != nil {
+ return nil, err
+ }
+ return &config, nil
+}
+
+func getSocksChainConfig(config *SocksJsonConfig) *socksChainConfig {
+ // XXX
+ fields := strings.Split(config.TorSocks, "|")
+ torSocksNet := fields[0]
+ torSocksAddr := fields[1]
+ fields = strings.Split(config.SocksListener, "|")
+ socksListenNet := fields[0]
+ socksListenAddr := fields[1]
+ socksConfig := socksChainConfig{
+ TargetSocksNet: torSocksNet,
+ TargetSocksAddr: torSocksAddr,
+ ListenSocksNet: socksListenNet,
+ ListenSocksAddr: socksListenAddr,
+ }
+ log.Notice("Loaded Socks chain config:")
+ log.Notice(socksConfig)
+ return &socksConfig
+}
+
+
func Main() {
readConfig()
logBackend := setupLoggerBackend(FirewallConfig.LoggingLevel)
@@ -141,6 +191,26 @@ func Main() {
fw.loadRules()
+ /*
+ go func() {
+ http.ListenAndServe("localhost:6060", nil)
+ }()
+ */
+
+ wg := sync.WaitGroup{}
+
+ config, err := loadSocksConfiguration(defaultSocksCfgPath)
+ if err != nil && !os.IsNotExist(err) {
+ panic(err)
+ }
+ if config != nil {
+ socksConfig := getSocksChainConfig(config)
+ chain := NewSocksChain(socksConfig, &wg, fw)
+ chain.start()
+ } else {
+ log.Notice("Did not find SOCKS5 configuration file at", defaultSocksCfgPath, "; ignoring subsystem...")
+ }
+
fw.runFilter()
// observe process signals and either
diff --git a/sgfw/socks_server_chain.go b/sgfw/socks_server_chain.go
new file mode 100644
index 0000000..40ab7c3
--- /dev/null
+++ b/sgfw/socks_server_chain.go
@@ -0,0 +1,261 @@
+package sgfw
+
+import (
+ "io"
+ "net"
+ "os"
+ "sync"
+
+ "github.com/subgraph/go-procsnitch"
+ "strconv"
+)
+
+type socksChainConfig struct {
+ TargetSocksNet string
+ TargetSocksAddr string
+ ListenSocksNet string
+ ListenSocksAddr string
+}
+
+type socksChain struct {
+ cfg *socksChainConfig
+ fw *Firewall
+ listener net.Listener
+ wg *sync.WaitGroup
+ procInfo procsnitch.ProcInfo
+}
+
+type socksChainSession struct {
+ cfg *socksChainConfig
+ clientConn net.Conn
+ upstreamConn net.Conn
+ req *Request
+ bndAddr *Address
+ optData []byte
+ procInfo procsnitch.ProcInfo
+ server *socksChain
+}
+
+const (
+ socksVerdictDrop = 1
+ socksVerdictAccept = 2
+)
+
+type pendingSocksConnection struct {
+ pol *Policy
+ hname string
+ destIP net.IP
+ destPort uint16
+ pinfo *procsnitch.Info
+ verdict chan int
+}
+
+func (sc *pendingSocksConnection) policy() *Policy {
+ return sc.pol
+}
+
+func (sc *pendingSocksConnection) procInfo() *procsnitch.Info {
+ return sc.pinfo
+}
+
+func (sc *pendingSocksConnection) hostname() string {
+ return sc.hname
+}
+
+func (sc *pendingSocksConnection) dst() net.IP {
+ return sc.destIP
+}
+func (sc *pendingSocksConnection) dstPort() uint16 {
+ return sc.destPort
+}
+
+func (sc *pendingSocksConnection) deliverVerdict(v int) {
+ sc.verdict <- v
+ close(sc.verdict)
+}
+
+func (sc *pendingSocksConnection) accept() { sc.deliverVerdict(socksVerdictAccept) }
+
+func (sc *pendingSocksConnection) drop() { sc.deliverVerdict(socksVerdictDrop) }
+
+func (sc *pendingSocksConnection) print() string { return "socks connection" }
+
+func NewSocksChain(cfg *socksChainConfig, wg *sync.WaitGroup, fw *Firewall) *socksChain {
+ chain := socksChain{
+ cfg: cfg,
+ fw: fw,
+ wg: wg,
+ procInfo: procsnitch.SystemProcInfo{},
+ }
+ return &chain
+}
+
+// Start initializes the SOCKS 5 server and starts
+// accepting connections.
+func (s *socksChain) start() {
+ var err error
+ s.listener, err = net.Listen(s.cfg.ListenSocksNet, s.cfg.ListenSocksAddr)
+ if err != nil {
+ log.Errorf("ERR/socks: Failed to listen on the socks address: %v", err)
+ os.Exit(1)
+ }
+
+ s.wg.Add(1)
+ go s.socksAcceptLoop()
+}
+
+func (s *socksChain) socksAcceptLoop() error {
+ defer s.wg.Done()
+ defer s.listener.Close()
+
+ for {
+ conn, err := s.listener.Accept()
+ if err != nil {
+ if e, ok := err.(net.Error); ok && !e.Temporary() {
+ log.Infof("ERR/socks: Failed to Accept(): %v", err)
+ return err
+ }
+ continue
+ }
+ session := &socksChainSession{cfg: s.cfg, clientConn: conn, procInfo: s.procInfo, server: s}
+ go session.sessionWorker()
+ }
+}
+
+func (c *socksChainSession) sessionWorker() {
+ defer c.clientConn.Close()
+
+ clientAddr := c.clientConn.RemoteAddr()
+ log.Infof("INFO/socks: New connection from: %v", clientAddr)
+
+ // Do the SOCKS handshake with the client, and read the command.
+ var err error
+ if c.req, err = Handshake(c.clientConn); err != nil {
+ log.Infof("ERR/socks: Failed SOCKS5 handshake: %v", err)
+ return
+ }
+
+ switch c.req.Cmd {
+ case CommandTorResolve, CommandTorResolvePTR:
+ err = c.dispatchTorSOCKS()
+
+ // If we reach here, the request has been dispatched and completed.
+ if err == nil {
+ // Successfully even, send the response back with the addresc.
+ c.req.ReplyAddr(ReplySucceeded, c.bndAddr)
+ }
+ case CommandConnect:
+ if !c.filterConnect() {
+ c.req.Reply(ReplyConnectionRefused)
+ return
+ }
+ c.handleConnect()
+ default:
+ // Should *NEVER* happen, validated as part of handshake.
+ log.Infof("BUG/socks: Unsupported SOCKS command: 0x%02x", c.req.Cmd)
+ c.req.Reply(ReplyCommandNotSupported)
+ }
+}
+
+func (c *socksChainSession) addressDetails() (string, net.IP, uint16) {
+ addr := c.req.Addr
+ host, pstr := addr.HostPort()
+ port, err := strconv.ParseUint(pstr, 10, 16)
+ if err != nil || port == 0 || port > 0xFFFF {
+ log.Warningf("Illegal port value in socks address: %v", addr)
+ return "", nil, 0
+ }
+ if addr.Type() == 3 {
+ return host, nil, uint16(port)
+ }
+ ip := net.ParseIP(host)
+ if ip == nil {
+ log.Warningf("Failed to extract address information from socks address: %v", addr)
+ }
+ return "", ip, uint16(port)
+}
+
+func (c *socksChainSession) filterConnect() bool {
+ pinfo := procsnitch.FindProcessForConnection(c.clientConn, c.procInfo)
+ if pinfo == nil {
+ log.Warningf("No proc found for connection from: %s", c.clientConn.RemoteAddr())
+ return false
+ }
+
+ policy := c.server.fw.PolicyForPath(pinfo.ExePath)
+
+ hostname, ip, port := c.addressDetails()
+ if ip == nil && hostname == "" {
+ return false
+ }
+ result := policy.rules.filter(nil, ip, port, hostname, pinfo)
+ switch result {
+ case FILTER_DENY:
+ return false
+ case FILTER_ALLOW:
+ return true
+ case FILTER_PROMPT:
+ pending := &pendingSocksConnection{
+ pol: policy,
+ hname: hostname,
+ destIP: ip,
+ destPort: port,
+ pinfo: pinfo,
+ verdict: make(chan int),
+ }
+ policy.processPromptResult(pending)
+ v := <-pending.verdict
+ if v == socksVerdictAccept {
+ return true
+ }
+ }
+
+ return false
+
+}
+
+func (c *socksChainSession) handleConnect() {
+ err := c.dispatchTorSOCKS()
+ if err != nil {
+ return
+ }
+ c.req.Reply(ReplySucceeded)
+ defer c.upstreamConn.Close()
+
+ if c.optData != nil {
+ if _, err = c.upstreamConn.Write(c.optData); err != nil {
+ log.Infof("ERR/socks: Failed writing OptData: %v", err)
+ return
+ }
+ c.optData = nil
+ }
+
+ // A upstream connection has been established, push data back and forth
+ // till the session is done.
+ c.forwardTraffic()
+ log.Infof("INFO/socks: Closed SOCKS connection from: %v", c.clientConn.RemoteAddr())
+}
+
+func (c *socksChainSession) forwardTraffic() {
+ var wg sync.WaitGroup
+ wg.Add(2)
+
+ copyLoop := func(dst, src net.Conn) {
+ defer wg.Done()
+ defer dst.Close()
+
+ io.Copy(dst, src)
+ }
+ go copyLoop(c.upstreamConn, c.clientConn)
+ go copyLoop(c.clientConn, c.upstreamConn)
+
+ wg.Wait()
+}
+
+func (c *socksChainSession) dispatchTorSOCKS() (err error) {
+ c.upstreamConn, c.bndAddr, err = Redispatch(c.cfg.TargetSocksNet, c.cfg.TargetSocksAddr, c.req)
+ if err != nil {
+ c.req.Reply(ErrorToReplyCode(err))
+ }
+ return
+}
diff --git a/sgfw/socks_server_chain_test.go b/sgfw/socks_server_chain_test.go
new file mode 100644
index 0000000..c4cb149
--- /dev/null
+++ b/sgfw/socks_server_chain_test.go
@@ -0,0 +1,305 @@
+package sgfw
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "io"
+ "net"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/subgraph/fw-daemon/socks5"
+ "golang.org/x/net/proxy"
+)
+
+// MortalService can be killed at any time.
+type MortalService struct {
+ network string
+ address string
+ connectionCallback func(net.Conn) error
+
+ conns []net.Conn
+ quit chan bool
+ listener net.Listener
+ waitGroup *sync.WaitGroup
+}
+
+// NewMortalService creates a new MortalService
+func NewMortalService(network, address string, connectionCallback func(net.Conn) error) *MortalService {
+ l := MortalService{
+ network: network,
+ address: address,
+ connectionCallback: connectionCallback,
+
+ conns: make([]net.Conn, 0, 10),
+ quit: make(chan bool),
+ waitGroup: &sync.WaitGroup{},
+ }
+ return &l
+}
+
+// Stop will kill our listener and all it's connections
+func (l *MortalService) Stop() {
+ log.Infof("stopping listener service %s:%s", l.network, l.address)
+ close(l.quit)
+ if l.listener != nil {
+ l.listener.Close()
+ }
+ l.waitGroup.Wait()
+}
+
+func (l *MortalService) acceptLoop() {
+ defer l.waitGroup.Done()
+ defer func() {
+ log.Infof("stoping listener service %s:%s", l.network, l.address)
+ for i, conn := range l.conns {
+ if conn != nil {
+ log.Infof("Closing connection #%d", i)
+ conn.Close()
+ }
+ }
+ }()
+ defer l.listener.Close()
+
+ for {
+ conn, err := l.listener.Accept()
+ if nil != err {
+ if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() {
+ continue
+ } else {
+ log.Infof("MortalService connection accept failure: %s\n", err)
+ select {
+ case <-l.quit:
+ return
+ default:
+ }
+ continue
+ }
+ }
+
+ l.conns = append(l.conns, conn)
+ go l.handleConnection(conn, len(l.conns)-1)
+ }
+}
+
+func (l *MortalService) createDeadlinedListener() error {
+ if l.network == "tcp" {
+ tcpAddr, err := net.ResolveTCPAddr("tcp", l.address)
+ if err != nil {
+ return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err)
+ }
+ tcpListener, err := net.ListenTCP("tcp", tcpAddr)
+ if err != nil {
+ return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err)
+ }
+ tcpListener.SetDeadline(time.Now().Add(1e9))
+ l.listener = tcpListener
+ return nil
+ } else if l.network == "unix" {
+ unixAddr, err := net.ResolveUnixAddr("unix", l.address)
+ if err != nil {
+ return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err)
+ }
+ unixListener, err := net.ListenUnix("unix", unixAddr)
+ if err != nil {
+ return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err)
+ }
+ unixListener.SetDeadline(time.Now().Add(1e9))
+ l.listener = unixListener
+ return nil
+ } else {
+ panic("")
+ }
+ return nil
+}
+
+// Start the MortalService
+func (l *MortalService) Start() error {
+ var err error
+ err = l.createDeadlinedListener()
+ if err != nil {
+ return err
+ }
+ l.waitGroup.Add(1)
+ go l.acceptLoop()
+ return nil
+}
+
+func (l *MortalService) handleConnection(conn net.Conn, id int) error {
+ defer func() {
+ log.Infof("Closing connection #%d", id)
+ conn.Close()
+ l.conns[id] = nil
+ }()
+
+ log.Infof("Starting connection #%d", id)
+
+ for {
+ if err := l.connectionCallback(conn); err != nil {
+ log.Error(err.Error())
+ return err
+ }
+ return nil
+ }
+}
+
+type AccumulatingService struct {
+ net, address string
+ banner string
+ buffer bytes.Buffer
+ mortalService *MortalService
+ hasProtocolInfo bool
+ hasAuthenticate bool
+ receivedChan chan bool
+}
+
+func NewAccumulatingService(net, address, banner string) *AccumulatingService {
+ l := AccumulatingService{
+ net: net,
+ address: address,
+ banner: banner,
+ hasProtocolInfo: true,
+ hasAuthenticate: true,
+ receivedChan: make(chan bool, 0),
+ }
+ return &l
+}
+
+func (a *AccumulatingService) Start() {
+ a.mortalService = NewMortalService(a.net, a.address, a.SessionWorker)
+ a.mortalService.Start()
+}
+
+func (a *AccumulatingService) Stop() {
+ fmt.Println("AccumulatingService STOP")
+ a.mortalService.Stop()
+}
+
+func (a *AccumulatingService) WaitUntilReceived() {
+ <-a.receivedChan
+}
+
+func (a *AccumulatingService) SessionWorker(conn net.Conn) error {
+ connReader := bufio.NewReader(conn)
+ conn.Write([]byte(a.banner))
+ for {
+ line, err := connReader.ReadBytes('\n')
+ if err != nil {
+ fmt.Printf("AccumulatingService read error: %s\n", err)
+ }
+ lineStr := strings.TrimSpace(string(line))
+ a.buffer.WriteString(lineStr + "\n")
+ a.receivedChan <- true
+ }
+ return nil
+}
+
+func fakeSocksSessionWorker(clientConn net.Conn, targetNet, targetAddr string) error {
+ defer clientConn.Close()
+
+ clientAddr := clientConn.RemoteAddr()
+ fmt.Printf("INFO/socks: New connection from: %v\n", clientAddr)
+
+ // Do the SOCKS handshake with the client, and read the command.
+ req, err := socks5.Handshake(clientConn)
+ if err != nil {
+ panic(fmt.Sprintf("ERR/socks: Failed SOCKS5 handshake: %v", err))
+ }
+
+ var upstreamConn net.Conn
+ upstreamConn, err = net.Dial(targetNet, targetAddr)
+ if err != nil {
+ panic(err)
+ }
+ defer upstreamConn.Close()
+ req.Reply(socks5.ReplySucceeded)
+
+ // A upstream connection has been established, push data back and forth
+ // till the session is done.
+ var wg sync.WaitGroup
+ wg.Add(2)
+ copyLoop := func(dst, src net.Conn) {
+ defer wg.Done()
+ defer dst.Close()
+
+ io.Copy(dst, src)
+ }
+ go copyLoop(upstreamConn, clientConn)
+ go copyLoop(clientConn, upstreamConn)
+
+ wg.Wait()
+ fmt.Printf("INFO/socks: Closed SOCKS connection from: %v\n", clientAddr)
+ return nil
+}
+
+func TestSocksServerProxyChain(t *testing.T) {
+ // socks client ---> socks chain ---> socks server ---> service
+ socksChainNet := "tcp"
+ socksChainAddr := "127.0.0.1:7750"
+ socksServerNet := "tcp"
+ socksServerAddr := "127.0.0.1:8850"
+ serviceNet := "tcp"
+ serviceAddr := "127.0.0.1:9950"
+
+ banner := "meow 123\r\n"
+ // setup the service listener
+ service := NewAccumulatingService(serviceNet, serviceAddr, banner)
+ service.Start()
+ defer service.Stop()
+
+ // setup the "socks server"
+ session := func(clientConn net.Conn) error {
+ return fakeSocksSessionWorker(clientConn, serviceNet, serviceAddr)
+ }
+ socksService := NewMortalService(socksServerNet, socksServerAddr, session)
+ socksService.Start()
+ defer socksService.Stop()
+
+ // setup the SOCKS proxy chain
+ socksConfig := socksChainConfig{
+ TargetSocksNet: socksServerNet,
+ TargetSocksAddr: socksServerAddr,
+ ListenSocksNet: socksChainNet,
+ ListenSocksAddr: socksChainAddr,
+ }
+ wg := sync.WaitGroup{}
+ ds := dbusServer{}
+ chain := NewSocksChain(&socksConfig, &wg, &ds)
+ chain.start()
+
+ // setup the SOCKS client
+ auth := proxy.Auth{
+ User: "",
+ Password: "",
+ }
+ forward := proxy.NewPerHost(proxy.Direct, proxy.Direct)
+ socksClient, err := proxy.SOCKS5(socksChainNet, socksChainAddr, &auth, forward)
+ conn, err := socksClient.Dial(serviceNet, serviceAddr)
+ if err != nil {
+ panic(err)
+ }
+
+ // read a banner from the service
+ rd := bufio.NewReader(conn)
+ line := []byte{}
+ line, err = rd.ReadBytes('\n')
+ if err != nil {
+ panic(err)
+ }
+ if string(line) != banner {
+ t.Errorf("Did not receive expected banner. Got %s, wanted %s\n", string(line), banner)
+ t.Fail()
+ }
+
+ // send the service some data and verify it was received
+ clientData := "hello world\r\n"
+ conn.Write([]byte(clientData))
+ service.WaitUntilReceived()
+ if service.buffer.String() != strings.TrimSpace(clientData)+"\n" {
+ t.Errorf("Client sent %s but service only received %s\n", "hello world\n", service.buffer.String())
+ t.Fail()
+ }
+}