diff --git a/sgfw/client.go b/sgfw/client.go new file mode 100644 index 0000000..48be55c --- /dev/null +++ b/sgfw/client.go @@ -0,0 +1,153 @@ +/* + * client.go - SOCSK5 client implementation. + * + * To the extent possible under law, Yawning Angel has waived all copyright and + * related or neighboring rights to or-ctl-filter, using the creative commons + * "cc0" public domain dedication. See LICENSE or + * for full details. + */ + +package sgfw + +import ( + "fmt" + "io" + "net" + "time" +) + +// Redispatch dials the provided proxy and redispatches an existing request. +func Redispatch(proxyNet, proxyAddr string, req *Request) (conn net.Conn, bndAddr *Address, err error) { + defer func() { + if err != nil && conn != nil { + conn.Close() + } + }() + + conn, err = clientHandshake(proxyNet, proxyAddr, req) + if err != nil { + return nil, nil, err + } + bndAddr, err = clientCmd(conn, req) + return +} + +func clientHandshake(proxyNet, proxyAddr string, req *Request) (net.Conn, error) { + conn, err := net.Dial(proxyNet, proxyAddr) + if err != nil { + return nil, err + } + if err := conn.SetDeadline(time.Now().Add(requestTimeout)); err != nil { + return conn, err + } + authMethod, err := clientNegotiateAuth(conn, req) + if err != nil { + return conn, err + } + if err := clientAuthenticate(conn, req, authMethod); err != nil { + return conn, err + } + if err := conn.SetDeadline(time.Time{}); err != nil { + return conn, err + } + + return conn, nil +} + +func clientNegotiateAuth(conn net.Conn, req *Request) (byte, error) { + useRFC1929 := req.Auth.Uname != nil && req.Auth.Passwd != nil + // XXX: Validate uname/passwd lengths, though should always be valid. + + var buf [3]byte + buf[0] = version + buf[1] = 1 + if useRFC1929 { + buf[2] = authUsernamePassword + } else { + buf[2] = authNoneRequired + } + + if _, err := conn.Write(buf[:]); err != nil { + return authNoAcceptableMethods, err + } + + var resp [2]byte + if _, err := io.ReadFull(conn, resp[:]); err != nil { + return authNoAcceptableMethods, err + } + if err := validateByte("version", resp[0], version); err != nil { + return authNoAcceptableMethods, err + } + if err := validateByte("method", resp[1], buf[2]); err != nil { + return authNoAcceptableMethods, err + } + + return resp[1], nil +} + +func clientAuthenticate(conn net.Conn, req *Request, authMethod byte) error { + switch authMethod { + case authNoneRequired: + case authUsernamePassword: + var buf []byte + buf = append(buf, authRFC1929Ver) + buf = append(buf, byte(len(req.Auth.Uname))) + buf = append(buf, req.Auth.Uname...) + buf = append(buf, byte(len(req.Auth.Passwd))) + buf = append(buf, req.Auth.Passwd...) + if _, err := conn.Write(buf); err != nil { + return err + } + + var resp [2]byte + if _, err := io.ReadFull(conn, resp[:]); err != nil { + return err + } + if err := validateByte("version", resp[0], authRFC1929Ver); err != nil { + return err + } + if err := validateByte("status", resp[1], authRFC1929Success); err != nil { + return err + } + default: + panic(fmt.Sprintf("unknown authentication method: 0x%02x", authMethod)) + } + return nil +} + +func clientCmd(conn net.Conn, req *Request) (*Address, error) { + var buf []byte + buf = append(buf, version) + buf = append(buf, byte(req.Cmd)) + buf = append(buf, rsv) + buf = append(buf, req.Addr.raw...) + if _, err := conn.Write(buf); err != nil { + return nil, err + } + + var respHdr [3]byte + if _, err := io.ReadFull(conn, respHdr[:]); err != nil { + return nil, err + } + + if err := validateByte("version", respHdr[0], version); err != nil { + return nil, err + } + if err := validateByte("rep", respHdr[1], byte(ReplySucceeded)); err != nil { + return nil, clientError(respHdr[1]) + } + if err := validateByte("rsv", respHdr[2], rsv); err != nil { + return nil, err + } + + var bndAddr Address + if err := bndAddr.read(conn); err != nil { + return nil, err + } + + if err := conn.SetDeadline(time.Time{}); err != nil { + return nil, err + } + + return &bndAddr, nil +} diff --git a/sgfw/common.go b/sgfw/common.go new file mode 100644 index 0000000..1717722 --- /dev/null +++ b/sgfw/common.go @@ -0,0 +1,280 @@ +/* + * common.go - SOCSK5 common definitons/routines. + * + * To the extent possible under law, Yawning Angel has waived all copyright and + * related or neighboring rights to or-ctl-filter, using the creative commons + * "cc0" public domain dedication. See LICENSE or + * for full details. + */ + +// Package socks5 implements a SOCKS5 client/server. For more information see +// RFC 1928 and RFC 1929. +// +// Notes: +// * GSSAPI authentication, is NOT supported. +// * The authentication provided by the client is always accepted. +// * A lot of the code is shamelessly stolen from obfs4proxy. +package sgfw + +import ( + "errors" + "fmt" + "io" + "net" + "strconv" + "syscall" + "time" +) + +const ( + version = 0x05 + rsv = 0x00 + + atypIPv4 = 0x01 + atypDomainName = 0x03 + atypIPv6 = 0x04 + + authNoneRequired = 0x00 + authUsernamePassword = 0x02 + authNoAcceptableMethods = 0xff + + inboundTimeout = 5 * time.Second + requestTimeout = 30 * time.Second +) + +var errInvalidAtyp = errors.New("invalid address type") + +// ReplyCode is a SOCKS 5 reply code. +type ReplyCode byte + +// The various SOCKS 5 reply codes from RFC 1928. +const ( + ReplySucceeded ReplyCode = iota + ReplyGeneralFailure + ReplyConnectionNotAllowed + ReplyNetworkUnreachable + ReplyHostUnreachable + ReplyConnectionRefused + ReplyTTLExpired + ReplyCommandNotSupported + ReplyAddressNotSupported +) + +// Command is a SOCKS 5 command. +type Command byte + +// The various SOCKS 5 commands. +const ( + CommandConnect Command = 0x01 + CommandTorResolve Command = 0xf0 + CommandTorResolvePTR Command = 0xf1 +) + +// Address is a SOCKS 5 address + port. +type Address struct { + atyp uint8 + raw []byte + addrStr string + portStr string +} + +// FromString parses the provided "host:port" format address and populates the +// Address fields. +func (addr *Address) FromString(addrStr string) (err error) { + addr.addrStr, addr.portStr, err = net.SplitHostPort(addrStr) + if err != nil { + return + } + + var raw []byte + if ip := net.ParseIP(addr.addrStr); ip != nil { + if v4Addr := ip.To4(); v4Addr != nil { + raw = append(raw, atypIPv4) + raw = append(raw, v4Addr...) + } else if v6Addr := ip.To16(); v6Addr != nil { + raw = append(raw, atypIPv6) + raw = append(raw, v6Addr...) + } else { + return errors.New("unsupported IP address type") + } + } else { + // Must be a FQDN. + if len(addr.addrStr) > 255 { + return fmt.Errorf("invalid FQDN, len > 255 bytes (%d bytes)", len(addr.addrStr)) + } + raw = append(raw, atypDomainName) + raw = append(raw, addr.addrStr...) + } + + var port uint64 + if port, err = strconv.ParseUint(addr.portStr, 10, 16); err != nil { + return + } + raw = append(raw, byte(port>>8)) + raw = append(raw, byte(port&0xff)) + + addr.raw = raw + return +} + +// String returns the string representation of the address, in "host:port" +// format. +func (addr *Address) String() string { + return addr.addrStr + ":" + addr.portStr +} + +// HostPort returns the string representation of the addess, split into the +// host and port components. +func (addr *Address) HostPort() (string, string) { + return addr.addrStr, addr.portStr +} + +// Type returns the address type from the connect command this address was +// parsed from +func (addr *Address) Type() uint8 { + return addr.atyp +} + +func (addr *Address) read(conn net.Conn) (err error) { + // The address looks like: + // uint8_t atyp + // uint8_t addr[] (Length depends on atyp) + // uint16_t port + + // Read the atype. + var atyp byte + if atyp, err = readByte(conn); err != nil { + return + } + addr.raw = append(addr.raw, atyp) + + // Read the address. + var rawAddr []byte + switch atyp { + case atypIPv4: + rawAddr = make([]byte, net.IPv4len) + if _, err = io.ReadFull(conn, rawAddr); err != nil { + return + } + v4Addr := net.IPv4(rawAddr[0], rawAddr[1], rawAddr[2], rawAddr[3]) + addr.addrStr = v4Addr.String() + case atypDomainName: + var alen byte + if alen, err = readByte(conn); err != nil { + return + } + if alen == 0 { + return fmt.Errorf("domain name with 0 length") + } + rawAddr = make([]byte, alen) + addr.raw = append(addr.raw, alen) + if _, err = io.ReadFull(conn, rawAddr); err != nil { + return + } + addr.addrStr = string(rawAddr) + case atypIPv6: + rawAddr = make([]byte, net.IPv6len) + if _, err = io.ReadFull(conn, rawAddr); err != nil { + return + } + v6Addr := make(net.IP, net.IPv6len) + copy(v6Addr[:], rawAddr) + addr.addrStr = fmt.Sprintf("[%s]", v6Addr.String()) + default: + return errInvalidAtyp + } + addr.atyp = atyp + addr.raw = append(addr.raw, rawAddr...) + + // Read the port. + var rawPort [2]byte + if _, err = io.ReadFull(conn, rawPort[:]); err != nil { + return + } + port := int(rawPort[0])<<8 | int(rawPort[1]) + addr.portStr = fmt.Sprintf("%d", port) + addr.raw = append(addr.raw, rawPort[:]...) + + return +} + +// ErrorToReplyCode converts an error to the "best" reply code. +func ErrorToReplyCode(err error) ReplyCode { + if cErr, ok := err.(clientError); ok { + return ReplyCode(cErr) + } + opErr, ok := err.(*net.OpError) + if !ok { + return ReplyGeneralFailure + } + + errno, ok := opErr.Err.(syscall.Errno) + if !ok { + return ReplyGeneralFailure + } + switch errno { + case syscall.EADDRNOTAVAIL: + return ReplyAddressNotSupported + case syscall.ETIMEDOUT: + return ReplyTTLExpired + case syscall.ENETUNREACH: + return ReplyNetworkUnreachable + case syscall.EHOSTUNREACH: + return ReplyHostUnreachable + case syscall.ECONNREFUSED, syscall.ECONNRESET: + return ReplyConnectionRefused + default: + return ReplyGeneralFailure + } +} + +// Request describes a SOCKS 5 request. +type Request struct { + Auth AuthInfo + Cmd Command + Addr Address + + conn net.Conn +} + +type clientError ReplyCode + +func (e clientError) Error() string { + switch ReplyCode(e) { + case ReplySucceeded: + return "socks5: succeeded" + case ReplyGeneralFailure: + return "socks5: general failure" + case ReplyConnectionNotAllowed: + return "socks5: connection not allowed" + case ReplyNetworkUnreachable: + return "socks5: network unreachable" + case ReplyHostUnreachable: + return "socks5: host unreachable" + case ReplyConnectionRefused: + return "socks5: connection refused" + case ReplyTTLExpired: + return "socks5: ttl expired" + case ReplyCommandNotSupported: + return "socks5: command not supported" + case ReplyAddressNotSupported: + return "socks5: address not supported" + default: + return fmt.Sprintf("socks5: reply code: 0x%02x", e) + } +} + +func readByte(conn net.Conn) (byte, error) { + var tmp [1]byte + if _, err := conn.Read(tmp[:]); err != nil { + return 0, err + } + return tmp[0], nil +} + +func validateByte(descr string, val, expected byte) error { + if val != expected { + return fmt.Errorf("message field '%s' was 0x%02x (expected 0x%02x)", descr, val, expected) + } + return nil +} diff --git a/sgfw/rules.go b/sgfw/rules.go index 72c9db4..66b7e50 100644 --- a/sgfw/rules.go +++ b/sgfw/rules.go @@ -79,7 +79,7 @@ func (r *Rule) match(dst net.IP, dstPort uint16, hostname string) bool { if r.hostname != "" { return r.hostname == hostname } - return r.addr == binary.BigEndian.Uint32(dst) + return r.addr == binary.BigEndian.Uint32(dst.To4()) } func (rl *RuleList) filterPacket(p *nfqueue.Packet, pinfo *procsnitch.Info, hostname string) FilterResult { diff --git a/sgfw/server.go b/sgfw/server.go new file mode 100644 index 0000000..498a548 --- /dev/null +++ b/sgfw/server.go @@ -0,0 +1,200 @@ +/* + * server.go - SOCSK5 server implementation. + * + * To the extent possible under law, Yawning Angel has waived all copyright and + * related or neighboring rights to or-ctl-filter, using the creative commons + * "cc0" public domain dedication. See LICENSE or + * for full details. + */ + +package sgfw + +import ( + "bytes" + "fmt" + "io" + "net" + "time" +) + +// Handshake attempts to handle a incoming client handshake over the provided +// connection and receive the SOCKS5 request. The routine handles sending +// appropriate errors if applicable, but will not close the connection. +func Handshake(conn net.Conn) (*Request, error) { + // Arm the handshake timeout. + var err error + if err = conn.SetDeadline(time.Now().Add(inboundTimeout)); err != nil { + return nil, err + } + defer func() { + // Disarm the handshake timeout, only propagate the error if + // the handshake was successful. + nerr := conn.SetDeadline(time.Time{}) + if err == nil { + err = nerr + } + }() + + req := new(Request) + req.conn = conn + + // Negotiate the protocol version and authentication method. + var method byte + if method, err = req.negotiateAuth(); err != nil { + return nil, err + } + + // Authenticate if neccecary. + if err = req.authenticate(method); err != nil { + return nil, err + } + + // Read the client command. + if err = req.readCommand(); err != nil { + return nil, err + } + + return req, err +} + +// Reply sends a SOCKS5 reply to the corresponding request. The BND.ADDR and +// BND.PORT fields are always set to an address/port corresponding to +// "0.0.0.0:0". +func (req *Request) Reply(code ReplyCode) error { + return req.ReplyAddr(code, nil) +} + +// ReplyAddr sends a SOCKS5 reply to the corresponding request. The BND.ADDR +// and BND.PORT fields are specified by addr, or "0.0.0.0:0" if not provided. +func (req *Request) ReplyAddr(code ReplyCode, addr *Address) error { + // The server sends a reply message. + // uint8_t ver (0x05) + // uint8_t rep + // uint8_t rsv (0x00) + // uint8_t atyp + // uint8_t bnd_addr[] + // uint16_t bnd_port + + resp := []byte{version, byte(code), rsv} + if addr == nil { + var nilAddr [net.IPv4len + 2]byte + resp = append(resp, atypIPv4) + resp = append(resp, nilAddr[:]...) + } else { + resp = append(resp, addr.raw...) + } + + _, err := req.conn.Write(resp[:]) + return err + +} + +func (req *Request) negotiateAuth() (byte, error) { + // The client sends a version identifier/selection message. + // uint8_t ver (0x05) + // uint8_t nmethods (>= 1). + // uint8_t methods[nmethods] + + var err error + if err = req.readByteVerify("version", version); err != nil { + return 0, err + } + + // Read the number of methods, and the methods. + var nmethods byte + method := byte(authNoAcceptableMethods) + if nmethods, err = req.readByte(); err != nil { + return method, err + } + methods := make([]byte, nmethods) + if _, err := io.ReadFull(req.conn, methods); err != nil { + return 0, err + } + + // Pick the best authentication method, prioritizing authenticating + // over not if both options are present. + if bytes.IndexByte(methods, authUsernamePassword) != -1 { + method = authUsernamePassword + } else if bytes.IndexByte(methods, authNoneRequired) != -1 { + method = authNoneRequired + } + + // The server sends a method selection message. + // uint8_t ver (0x05) + // uint8_t method + msg := []byte{version, method} + if _, err = req.conn.Write(msg); err != nil { + return 0, err + } + + return method, nil +} + +func (req *Request) authenticate(method byte) error { + switch method { + case authNoneRequired: + return nil + case authUsernamePassword: + return req.authRFC1929() + case authNoAcceptableMethods: + return fmt.Errorf("no acceptable authentication methods") + default: + // This should never happen as only supported auth methods should be + // negotiated. + return fmt.Errorf("negotiated unsupported method 0x%02x", method) + } +} + +func (req *Request) readCommand() error { + // The client sends the request details. + // uint8_t ver (0x05) + // uint8_t cmd + // uint8_t rsv (0x00) + // uint8_t atyp + // uint8_t dst_addr[] + // uint16_t dst_port + + var err error + var cmd byte + if err = req.readByteVerify("version", version); err != nil { + req.Reply(ReplyGeneralFailure) + return err + } + if cmd, err = req.readByte(); err != nil { + req.Reply(ReplyGeneralFailure) + return err + } + switch Command(cmd) { + case CommandConnect, CommandTorResolve, CommandTorResolvePTR: + req.Cmd = Command(cmd) + default: + req.Reply(ReplyCommandNotSupported) + return fmt.Errorf("unsupported SOCKS command: 0x%02x", cmd) + } + if err = req.readByteVerify("reserved", rsv); err != nil { + req.Reply(ReplyGeneralFailure) + return err + } + + // Read the destination address/port. + err = req.Addr.read(req.conn) + if err == errInvalidAtyp { + req.Reply(ReplyAddressNotSupported) + } else if err != nil { + req.Reply(ReplyGeneralFailure) + } + + return err +} + +func (req *Request) readByte() (byte, error) { + return readByte(req.conn) +} + +func (req *Request) readByteVerify(descr string, expected byte) error { + val, err := req.readByte() + if err != nil { + return err + } + return validateByte(descr, val, expected) +} diff --git a/sgfw/server_rfc1929.go b/sgfw/server_rfc1929.go new file mode 100644 index 0000000..146559d --- /dev/null +++ b/sgfw/server_rfc1929.go @@ -0,0 +1,84 @@ +/* + * server_rfc1929.go - SOCSK 5 server authentication. + * + * To the extent possible under law, Yawning Angel has waived all copyright and + * related or neighboring rights to or-ctl-filter, using the creative commons + * "cc0" public domain dedication. See LICENSE or + * for full details. + */ + +package sgfw + +import ( + "fmt" + "io" +) + +const ( + authRFC1929Ver = 0x01 + authRFC1929Success = 0x00 + authRFC1929Fail = 0x01 +) + +// AuthInfo is the RFC 1929 Username/Password authentication data. +type AuthInfo struct { + Uname []byte + Passwd []byte +} + +func (req *Request) authRFC1929() (err error) { + sendErrResp := func() { + // Swallow write/flush errors, the auth failure is the relevant error. + resp := []byte{authRFC1929Ver, authRFC1929Fail} + req.conn.Write(resp[:]) + } + + // The client sends a Username/Password request. + // uint8_t ver (0x01) + // uint8_t ulen (>= 1) + // uint8_t uname[ulen] + // uint8_t plen (>= 1) + // uint8_t passwd[plen] + + if err = req.readByteVerify("auth version", authRFC1929Ver); err != nil { + sendErrResp() + return + } + + // Read the username. + var ulen byte + if ulen, err = req.readByte(); err != nil { + sendErrResp() + return + } else if ulen < 1 { + sendErrResp() + return fmt.Errorf("username with 0 length") + } + uname := make([]byte, ulen) + if _, err = io.ReadFull(req.conn, uname); err != nil { + sendErrResp() + return + } + + // Read the password. + var plen byte + if plen, err = req.readByte(); err != nil { + sendErrResp() + return + } else if plen < 1 { + sendErrResp() + return fmt.Errorf("password with 0 length") + } + passwd := make([]byte, plen) + if _, err = io.ReadFull(req.conn, passwd); err != nil { + sendErrResp() + return + } + + req.Auth.Uname = uname + req.Auth.Passwd = passwd + + resp := []byte{authRFC1929Ver, authRFC1929Success} + _, err = req.conn.Write(resp[:]) + return +} diff --git a/sgfw/sgfw.go b/sgfw/sgfw.go index c4a0b22..21679d5 100644 --- a/sgfw/sgfw.go +++ b/sgfw/sgfw.go @@ -7,6 +7,9 @@ import ( "sync" "syscall" "time" + "bufio" + "encoding/json" + "strings" "github.com/op/go-logging" @@ -107,8 +110,55 @@ func (fw *Firewall) runFilter() { } } +type SocksJsonConfig struct { + SocksListener string + TorSocks string +} + var commentRegexp = regexp.MustCompile("^[ \t]*#") +const defaultSocksCfgPath = "/etc/fw-daemon-socks.json" + +func loadSocksConfiguration(configFilePath string) (*SocksJsonConfig, error) { + config := SocksJsonConfig{} + file, err := os.Open(configFilePath) + if err != nil { + return nil, err + } + scanner := bufio.NewScanner(file) + bs := "" + for scanner.Scan() { + line := scanner.Text() + if !commentRegexp.MatchString(line) { + bs += line + "\n" + } + } + if err := json.Unmarshal([]byte(bs), &config); err != nil { + return nil, err + } + return &config, nil +} + +func getSocksChainConfig(config *SocksJsonConfig) *socksChainConfig { + // XXX + fields := strings.Split(config.TorSocks, "|") + torSocksNet := fields[0] + torSocksAddr := fields[1] + fields = strings.Split(config.SocksListener, "|") + socksListenNet := fields[0] + socksListenAddr := fields[1] + socksConfig := socksChainConfig{ + TargetSocksNet: torSocksNet, + TargetSocksAddr: torSocksAddr, + ListenSocksNet: socksListenNet, + ListenSocksAddr: socksListenAddr, + } + log.Notice("Loaded Socks chain config:") + log.Notice(socksConfig) + return &socksConfig +} + + func Main() { readConfig() logBackend := setupLoggerBackend(FirewallConfig.LoggingLevel) @@ -141,6 +191,26 @@ func Main() { fw.loadRules() + /* + go func() { + http.ListenAndServe("localhost:6060", nil) + }() + */ + + wg := sync.WaitGroup{} + + config, err := loadSocksConfiguration(defaultSocksCfgPath) + if err != nil && !os.IsNotExist(err) { + panic(err) + } + if config != nil { + socksConfig := getSocksChainConfig(config) + chain := NewSocksChain(socksConfig, &wg, fw) + chain.start() + } else { + log.Notice("Did not find SOCKS5 configuration file at", defaultSocksCfgPath, "; ignoring subsystem...") + } + fw.runFilter() // observe process signals and either diff --git a/sgfw/socks_server_chain.go b/sgfw/socks_server_chain.go new file mode 100644 index 0000000..40ab7c3 --- /dev/null +++ b/sgfw/socks_server_chain.go @@ -0,0 +1,261 @@ +package sgfw + +import ( + "io" + "net" + "os" + "sync" + + "github.com/subgraph/go-procsnitch" + "strconv" +) + +type socksChainConfig struct { + TargetSocksNet string + TargetSocksAddr string + ListenSocksNet string + ListenSocksAddr string +} + +type socksChain struct { + cfg *socksChainConfig + fw *Firewall + listener net.Listener + wg *sync.WaitGroup + procInfo procsnitch.ProcInfo +} + +type socksChainSession struct { + cfg *socksChainConfig + clientConn net.Conn + upstreamConn net.Conn + req *Request + bndAddr *Address + optData []byte + procInfo procsnitch.ProcInfo + server *socksChain +} + +const ( + socksVerdictDrop = 1 + socksVerdictAccept = 2 +) + +type pendingSocksConnection struct { + pol *Policy + hname string + destIP net.IP + destPort uint16 + pinfo *procsnitch.Info + verdict chan int +} + +func (sc *pendingSocksConnection) policy() *Policy { + return sc.pol +} + +func (sc *pendingSocksConnection) procInfo() *procsnitch.Info { + return sc.pinfo +} + +func (sc *pendingSocksConnection) hostname() string { + return sc.hname +} + +func (sc *pendingSocksConnection) dst() net.IP { + return sc.destIP +} +func (sc *pendingSocksConnection) dstPort() uint16 { + return sc.destPort +} + +func (sc *pendingSocksConnection) deliverVerdict(v int) { + sc.verdict <- v + close(sc.verdict) +} + +func (sc *pendingSocksConnection) accept() { sc.deliverVerdict(socksVerdictAccept) } + +func (sc *pendingSocksConnection) drop() { sc.deliverVerdict(socksVerdictDrop) } + +func (sc *pendingSocksConnection) print() string { return "socks connection" } + +func NewSocksChain(cfg *socksChainConfig, wg *sync.WaitGroup, fw *Firewall) *socksChain { + chain := socksChain{ + cfg: cfg, + fw: fw, + wg: wg, + procInfo: procsnitch.SystemProcInfo{}, + } + return &chain +} + +// Start initializes the SOCKS 5 server and starts +// accepting connections. +func (s *socksChain) start() { + var err error + s.listener, err = net.Listen(s.cfg.ListenSocksNet, s.cfg.ListenSocksAddr) + if err != nil { + log.Errorf("ERR/socks: Failed to listen on the socks address: %v", err) + os.Exit(1) + } + + s.wg.Add(1) + go s.socksAcceptLoop() +} + +func (s *socksChain) socksAcceptLoop() error { + defer s.wg.Done() + defer s.listener.Close() + + for { + conn, err := s.listener.Accept() + if err != nil { + if e, ok := err.(net.Error); ok && !e.Temporary() { + log.Infof("ERR/socks: Failed to Accept(): %v", err) + return err + } + continue + } + session := &socksChainSession{cfg: s.cfg, clientConn: conn, procInfo: s.procInfo, server: s} + go session.sessionWorker() + } +} + +func (c *socksChainSession) sessionWorker() { + defer c.clientConn.Close() + + clientAddr := c.clientConn.RemoteAddr() + log.Infof("INFO/socks: New connection from: %v", clientAddr) + + // Do the SOCKS handshake with the client, and read the command. + var err error + if c.req, err = Handshake(c.clientConn); err != nil { + log.Infof("ERR/socks: Failed SOCKS5 handshake: %v", err) + return + } + + switch c.req.Cmd { + case CommandTorResolve, CommandTorResolvePTR: + err = c.dispatchTorSOCKS() + + // If we reach here, the request has been dispatched and completed. + if err == nil { + // Successfully even, send the response back with the addresc. + c.req.ReplyAddr(ReplySucceeded, c.bndAddr) + } + case CommandConnect: + if !c.filterConnect() { + c.req.Reply(ReplyConnectionRefused) + return + } + c.handleConnect() + default: + // Should *NEVER* happen, validated as part of handshake. + log.Infof("BUG/socks: Unsupported SOCKS command: 0x%02x", c.req.Cmd) + c.req.Reply(ReplyCommandNotSupported) + } +} + +func (c *socksChainSession) addressDetails() (string, net.IP, uint16) { + addr := c.req.Addr + host, pstr := addr.HostPort() + port, err := strconv.ParseUint(pstr, 10, 16) + if err != nil || port == 0 || port > 0xFFFF { + log.Warningf("Illegal port value in socks address: %v", addr) + return "", nil, 0 + } + if addr.Type() == 3 { + return host, nil, uint16(port) + } + ip := net.ParseIP(host) + if ip == nil { + log.Warningf("Failed to extract address information from socks address: %v", addr) + } + return "", ip, uint16(port) +} + +func (c *socksChainSession) filterConnect() bool { + pinfo := procsnitch.FindProcessForConnection(c.clientConn, c.procInfo) + if pinfo == nil { + log.Warningf("No proc found for connection from: %s", c.clientConn.RemoteAddr()) + return false + } + + policy := c.server.fw.PolicyForPath(pinfo.ExePath) + + hostname, ip, port := c.addressDetails() + if ip == nil && hostname == "" { + return false + } + result := policy.rules.filter(nil, ip, port, hostname, pinfo) + switch result { + case FILTER_DENY: + return false + case FILTER_ALLOW: + return true + case FILTER_PROMPT: + pending := &pendingSocksConnection{ + pol: policy, + hname: hostname, + destIP: ip, + destPort: port, + pinfo: pinfo, + verdict: make(chan int), + } + policy.processPromptResult(pending) + v := <-pending.verdict + if v == socksVerdictAccept { + return true + } + } + + return false + +} + +func (c *socksChainSession) handleConnect() { + err := c.dispatchTorSOCKS() + if err != nil { + return + } + c.req.Reply(ReplySucceeded) + defer c.upstreamConn.Close() + + if c.optData != nil { + if _, err = c.upstreamConn.Write(c.optData); err != nil { + log.Infof("ERR/socks: Failed writing OptData: %v", err) + return + } + c.optData = nil + } + + // A upstream connection has been established, push data back and forth + // till the session is done. + c.forwardTraffic() + log.Infof("INFO/socks: Closed SOCKS connection from: %v", c.clientConn.RemoteAddr()) +} + +func (c *socksChainSession) forwardTraffic() { + var wg sync.WaitGroup + wg.Add(2) + + copyLoop := func(dst, src net.Conn) { + defer wg.Done() + defer dst.Close() + + io.Copy(dst, src) + } + go copyLoop(c.upstreamConn, c.clientConn) + go copyLoop(c.clientConn, c.upstreamConn) + + wg.Wait() +} + +func (c *socksChainSession) dispatchTorSOCKS() (err error) { + c.upstreamConn, c.bndAddr, err = Redispatch(c.cfg.TargetSocksNet, c.cfg.TargetSocksAddr, c.req) + if err != nil { + c.req.Reply(ErrorToReplyCode(err)) + } + return +} diff --git a/sgfw/socks_server_chain_test.go b/sgfw/socks_server_chain_test.go new file mode 100644 index 0000000..c4cb149 --- /dev/null +++ b/sgfw/socks_server_chain_test.go @@ -0,0 +1,305 @@ +package sgfw + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net" + "strings" + "sync" + "testing" + "time" + + "github.com/subgraph/fw-daemon/socks5" + "golang.org/x/net/proxy" +) + +// MortalService can be killed at any time. +type MortalService struct { + network string + address string + connectionCallback func(net.Conn) error + + conns []net.Conn + quit chan bool + listener net.Listener + waitGroup *sync.WaitGroup +} + +// NewMortalService creates a new MortalService +func NewMortalService(network, address string, connectionCallback func(net.Conn) error) *MortalService { + l := MortalService{ + network: network, + address: address, + connectionCallback: connectionCallback, + + conns: make([]net.Conn, 0, 10), + quit: make(chan bool), + waitGroup: &sync.WaitGroup{}, + } + return &l +} + +// Stop will kill our listener and all it's connections +func (l *MortalService) Stop() { + log.Infof("stopping listener service %s:%s", l.network, l.address) + close(l.quit) + if l.listener != nil { + l.listener.Close() + } + l.waitGroup.Wait() +} + +func (l *MortalService) acceptLoop() { + defer l.waitGroup.Done() + defer func() { + log.Infof("stoping listener service %s:%s", l.network, l.address) + for i, conn := range l.conns { + if conn != nil { + log.Infof("Closing connection #%d", i) + conn.Close() + } + } + }() + defer l.listener.Close() + + for { + conn, err := l.listener.Accept() + if nil != err { + if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() { + continue + } else { + log.Infof("MortalService connection accept failure: %s\n", err) + select { + case <-l.quit: + return + default: + } + continue + } + } + + l.conns = append(l.conns, conn) + go l.handleConnection(conn, len(l.conns)-1) + } +} + +func (l *MortalService) createDeadlinedListener() error { + if l.network == "tcp" { + tcpAddr, err := net.ResolveTCPAddr("tcp", l.address) + if err != nil { + return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err) + } + tcpListener, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err) + } + tcpListener.SetDeadline(time.Now().Add(1e9)) + l.listener = tcpListener + return nil + } else if l.network == "unix" { + unixAddr, err := net.ResolveUnixAddr("unix", l.address) + if err != nil { + return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err) + } + unixListener, err := net.ListenUnix("unix", unixAddr) + if err != nil { + return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err) + } + unixListener.SetDeadline(time.Now().Add(1e9)) + l.listener = unixListener + return nil + } else { + panic("") + } + return nil +} + +// Start the MortalService +func (l *MortalService) Start() error { + var err error + err = l.createDeadlinedListener() + if err != nil { + return err + } + l.waitGroup.Add(1) + go l.acceptLoop() + return nil +} + +func (l *MortalService) handleConnection(conn net.Conn, id int) error { + defer func() { + log.Infof("Closing connection #%d", id) + conn.Close() + l.conns[id] = nil + }() + + log.Infof("Starting connection #%d", id) + + for { + if err := l.connectionCallback(conn); err != nil { + log.Error(err.Error()) + return err + } + return nil + } +} + +type AccumulatingService struct { + net, address string + banner string + buffer bytes.Buffer + mortalService *MortalService + hasProtocolInfo bool + hasAuthenticate bool + receivedChan chan bool +} + +func NewAccumulatingService(net, address, banner string) *AccumulatingService { + l := AccumulatingService{ + net: net, + address: address, + banner: banner, + hasProtocolInfo: true, + hasAuthenticate: true, + receivedChan: make(chan bool, 0), + } + return &l +} + +func (a *AccumulatingService) Start() { + a.mortalService = NewMortalService(a.net, a.address, a.SessionWorker) + a.mortalService.Start() +} + +func (a *AccumulatingService) Stop() { + fmt.Println("AccumulatingService STOP") + a.mortalService.Stop() +} + +func (a *AccumulatingService) WaitUntilReceived() { + <-a.receivedChan +} + +func (a *AccumulatingService) SessionWorker(conn net.Conn) error { + connReader := bufio.NewReader(conn) + conn.Write([]byte(a.banner)) + for { + line, err := connReader.ReadBytes('\n') + if err != nil { + fmt.Printf("AccumulatingService read error: %s\n", err) + } + lineStr := strings.TrimSpace(string(line)) + a.buffer.WriteString(lineStr + "\n") + a.receivedChan <- true + } + return nil +} + +func fakeSocksSessionWorker(clientConn net.Conn, targetNet, targetAddr string) error { + defer clientConn.Close() + + clientAddr := clientConn.RemoteAddr() + fmt.Printf("INFO/socks: New connection from: %v\n", clientAddr) + + // Do the SOCKS handshake with the client, and read the command. + req, err := socks5.Handshake(clientConn) + if err != nil { + panic(fmt.Sprintf("ERR/socks: Failed SOCKS5 handshake: %v", err)) + } + + var upstreamConn net.Conn + upstreamConn, err = net.Dial(targetNet, targetAddr) + if err != nil { + panic(err) + } + defer upstreamConn.Close() + req.Reply(socks5.ReplySucceeded) + + // A upstream connection has been established, push data back and forth + // till the session is done. + var wg sync.WaitGroup + wg.Add(2) + copyLoop := func(dst, src net.Conn) { + defer wg.Done() + defer dst.Close() + + io.Copy(dst, src) + } + go copyLoop(upstreamConn, clientConn) + go copyLoop(clientConn, upstreamConn) + + wg.Wait() + fmt.Printf("INFO/socks: Closed SOCKS connection from: %v\n", clientAddr) + return nil +} + +func TestSocksServerProxyChain(t *testing.T) { + // socks client ---> socks chain ---> socks server ---> service + socksChainNet := "tcp" + socksChainAddr := "127.0.0.1:7750" + socksServerNet := "tcp" + socksServerAddr := "127.0.0.1:8850" + serviceNet := "tcp" + serviceAddr := "127.0.0.1:9950" + + banner := "meow 123\r\n" + // setup the service listener + service := NewAccumulatingService(serviceNet, serviceAddr, banner) + service.Start() + defer service.Stop() + + // setup the "socks server" + session := func(clientConn net.Conn) error { + return fakeSocksSessionWorker(clientConn, serviceNet, serviceAddr) + } + socksService := NewMortalService(socksServerNet, socksServerAddr, session) + socksService.Start() + defer socksService.Stop() + + // setup the SOCKS proxy chain + socksConfig := socksChainConfig{ + TargetSocksNet: socksServerNet, + TargetSocksAddr: socksServerAddr, + ListenSocksNet: socksChainNet, + ListenSocksAddr: socksChainAddr, + } + wg := sync.WaitGroup{} + ds := dbusServer{} + chain := NewSocksChain(&socksConfig, &wg, &ds) + chain.start() + + // setup the SOCKS client + auth := proxy.Auth{ + User: "", + Password: "", + } + forward := proxy.NewPerHost(proxy.Direct, proxy.Direct) + socksClient, err := proxy.SOCKS5(socksChainNet, socksChainAddr, &auth, forward) + conn, err := socksClient.Dial(serviceNet, serviceAddr) + if err != nil { + panic(err) + } + + // read a banner from the service + rd := bufio.NewReader(conn) + line := []byte{} + line, err = rd.ReadBytes('\n') + if err != nil { + panic(err) + } + if string(line) != banner { + t.Errorf("Did not receive expected banner. Got %s, wanted %s\n", string(line), banner) + t.Fail() + } + + // send the service some data and verify it was received + clientData := "hello world\r\n" + conn.Write([]byte(clientData)) + service.WaitUntilReceived() + if service.buffer.String() != strings.TrimSpace(clientData)+"\n" { + t.Errorf("Client sent %s but service only received %s\n", "hello world\n", service.buffer.String()) + t.Fail() + } +}