mirror of https://github.com/subgraph/fw-daemon
				
				
				
			
							parent
							
								
									3bb8d65ed1
								
							
						
					
					
						commit
						1e84a6e168
					
				| @ -0,0 +1,153 @@ | |||||||
|  | /* | ||||||
|  |  * client.go - SOCSK5 client implementation. | ||||||
|  |  * | ||||||
|  |  * To the extent possible under law, Yawning Angel has waived all copyright and | ||||||
|  |  * related or neighboring rights to or-ctl-filter, using the creative commons | ||||||
|  |  * "cc0" public domain dedication. See LICENSE or | ||||||
|  |  * <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
 | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
|  | package sgfw | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"net" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Redispatch dials the provided proxy and redispatches an existing request.
 | ||||||
|  | func Redispatch(proxyNet, proxyAddr string, req *Request) (conn net.Conn, bndAddr *Address, err error) { | ||||||
|  | 	defer func() { | ||||||
|  | 		if err != nil && conn != nil { | ||||||
|  | 			conn.Close() | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	conn, err = clientHandshake(proxyNet, proxyAddr, req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, nil, err | ||||||
|  | 	} | ||||||
|  | 	bndAddr, err = clientCmd(conn, req) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func clientHandshake(proxyNet, proxyAddr string, req *Request) (net.Conn, error) { | ||||||
|  | 	conn, err := net.Dial(proxyNet, proxyAddr) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if err := conn.SetDeadline(time.Now().Add(requestTimeout)); err != nil { | ||||||
|  | 		return conn, err | ||||||
|  | 	} | ||||||
|  | 	authMethod, err := clientNegotiateAuth(conn, req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return conn, err | ||||||
|  | 	} | ||||||
|  | 	if err := clientAuthenticate(conn, req, authMethod); err != nil { | ||||||
|  | 		return conn, err | ||||||
|  | 	} | ||||||
|  | 	if err := conn.SetDeadline(time.Time{}); err != nil { | ||||||
|  | 		return conn, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return conn, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func clientNegotiateAuth(conn net.Conn, req *Request) (byte, error) { | ||||||
|  | 	useRFC1929 := req.Auth.Uname != nil && req.Auth.Passwd != nil | ||||||
|  | 	// XXX: Validate uname/passwd lengths, though should always be valid.
 | ||||||
|  | 
 | ||||||
|  | 	var buf [3]byte | ||||||
|  | 	buf[0] = version | ||||||
|  | 	buf[1] = 1 | ||||||
|  | 	if useRFC1929 { | ||||||
|  | 		buf[2] = authUsernamePassword | ||||||
|  | 	} else { | ||||||
|  | 		buf[2] = authNoneRequired | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if _, err := conn.Write(buf[:]); err != nil { | ||||||
|  | 		return authNoAcceptableMethods, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var resp [2]byte | ||||||
|  | 	if _, err := io.ReadFull(conn, resp[:]); err != nil { | ||||||
|  | 		return authNoAcceptableMethods, err | ||||||
|  | 	} | ||||||
|  | 	if err := validateByte("version", resp[0], version); err != nil { | ||||||
|  | 		return authNoAcceptableMethods, err | ||||||
|  | 	} | ||||||
|  | 	if err := validateByte("method", resp[1], buf[2]); err != nil { | ||||||
|  | 		return authNoAcceptableMethods, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return resp[1], nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func clientAuthenticate(conn net.Conn, req *Request, authMethod byte) error { | ||||||
|  | 	switch authMethod { | ||||||
|  | 	case authNoneRequired: | ||||||
|  | 	case authUsernamePassword: | ||||||
|  | 		var buf []byte | ||||||
|  | 		buf = append(buf, authRFC1929Ver) | ||||||
|  | 		buf = append(buf, byte(len(req.Auth.Uname))) | ||||||
|  | 		buf = append(buf, req.Auth.Uname...) | ||||||
|  | 		buf = append(buf, byte(len(req.Auth.Passwd))) | ||||||
|  | 		buf = append(buf, req.Auth.Passwd...) | ||||||
|  | 		if _, err := conn.Write(buf); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		var resp [2]byte | ||||||
|  | 		if _, err := io.ReadFull(conn, resp[:]); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		if err := validateByte("version", resp[0], authRFC1929Ver); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		if err := validateByte("status", resp[1], authRFC1929Success); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	default: | ||||||
|  | 		panic(fmt.Sprintf("unknown authentication method: 0x%02x", authMethod)) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func clientCmd(conn net.Conn, req *Request) (*Address, error) { | ||||||
|  | 	var buf []byte | ||||||
|  | 	buf = append(buf, version) | ||||||
|  | 	buf = append(buf, byte(req.Cmd)) | ||||||
|  | 	buf = append(buf, rsv) | ||||||
|  | 	buf = append(buf, req.Addr.raw...) | ||||||
|  | 	if _, err := conn.Write(buf); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var respHdr [3]byte | ||||||
|  | 	if _, err := io.ReadFull(conn, respHdr[:]); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := validateByte("version", respHdr[0], version); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if err := validateByte("rep", respHdr[1], byte(ReplySucceeded)); err != nil { | ||||||
|  | 		return nil, clientError(respHdr[1]) | ||||||
|  | 	} | ||||||
|  | 	if err := validateByte("rsv", respHdr[2], rsv); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var bndAddr Address | ||||||
|  | 	if err := bndAddr.read(conn); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := conn.SetDeadline(time.Time{}); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return &bndAddr, nil | ||||||
|  | } | ||||||
| @ -0,0 +1,280 @@ | |||||||
|  | /* | ||||||
|  |  * common.go - SOCSK5 common definitons/routines. | ||||||
|  |  * | ||||||
|  |  * To the extent possible under law, Yawning Angel has waived all copyright and | ||||||
|  |  * related or neighboring rights to or-ctl-filter, using the creative commons | ||||||
|  |  * "cc0" public domain dedication. See LICENSE or | ||||||
|  |  * <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
 | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
|  | // Package socks5 implements a SOCKS5 client/server.  For more information see
 | ||||||
|  | // RFC 1928 and RFC 1929.
 | ||||||
|  | //
 | ||||||
|  | // Notes:
 | ||||||
|  | //  * GSSAPI authentication, is NOT supported.
 | ||||||
|  | //  * The authentication provided by the client is always accepted.
 | ||||||
|  | //  * A lot of the code is shamelessly stolen from obfs4proxy.
 | ||||||
|  | package sgfw | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"net" | ||||||
|  | 	"strconv" | ||||||
|  | 	"syscall" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	version = 0x05 | ||||||
|  | 	rsv     = 0x00 | ||||||
|  | 
 | ||||||
|  | 	atypIPv4       = 0x01 | ||||||
|  | 	atypDomainName = 0x03 | ||||||
|  | 	atypIPv6       = 0x04 | ||||||
|  | 
 | ||||||
|  | 	authNoneRequired        = 0x00 | ||||||
|  | 	authUsernamePassword    = 0x02 | ||||||
|  | 	authNoAcceptableMethods = 0xff | ||||||
|  | 
 | ||||||
|  | 	inboundTimeout = 5 * time.Second | ||||||
|  | 	requestTimeout = 30 * time.Second | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var errInvalidAtyp = errors.New("invalid address type") | ||||||
|  | 
 | ||||||
|  | // ReplyCode is a SOCKS 5 reply code.
 | ||||||
|  | type ReplyCode byte | ||||||
|  | 
 | ||||||
|  | // The various SOCKS 5 reply codes from RFC 1928.
 | ||||||
|  | const ( | ||||||
|  | 	ReplySucceeded ReplyCode = iota | ||||||
|  | 	ReplyGeneralFailure | ||||||
|  | 	ReplyConnectionNotAllowed | ||||||
|  | 	ReplyNetworkUnreachable | ||||||
|  | 	ReplyHostUnreachable | ||||||
|  | 	ReplyConnectionRefused | ||||||
|  | 	ReplyTTLExpired | ||||||
|  | 	ReplyCommandNotSupported | ||||||
|  | 	ReplyAddressNotSupported | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Command is a SOCKS 5 command.
 | ||||||
|  | type Command byte | ||||||
|  | 
 | ||||||
|  | // The various SOCKS 5 commands.
 | ||||||
|  | const ( | ||||||
|  | 	CommandConnect       Command = 0x01 | ||||||
|  | 	CommandTorResolve    Command = 0xf0 | ||||||
|  | 	CommandTorResolvePTR Command = 0xf1 | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Address is a SOCKS 5 address + port.
 | ||||||
|  | type Address struct { | ||||||
|  | 	atyp    uint8 | ||||||
|  | 	raw     []byte | ||||||
|  | 	addrStr string | ||||||
|  | 	portStr string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // FromString parses the provided "host:port" format address and populates the
 | ||||||
|  | // Address fields.
 | ||||||
|  | func (addr *Address) FromString(addrStr string) (err error) { | ||||||
|  | 	addr.addrStr, addr.portStr, err = net.SplitHostPort(addrStr) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var raw []byte | ||||||
|  | 	if ip := net.ParseIP(addr.addrStr); ip != nil { | ||||||
|  | 		if v4Addr := ip.To4(); v4Addr != nil { | ||||||
|  | 			raw = append(raw, atypIPv4) | ||||||
|  | 			raw = append(raw, v4Addr...) | ||||||
|  | 		} else if v6Addr := ip.To16(); v6Addr != nil { | ||||||
|  | 			raw = append(raw, atypIPv6) | ||||||
|  | 			raw = append(raw, v6Addr...) | ||||||
|  | 		} else { | ||||||
|  | 			return errors.New("unsupported IP address type") | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		// Must be a FQDN.
 | ||||||
|  | 		if len(addr.addrStr) > 255 { | ||||||
|  | 			return fmt.Errorf("invalid FQDN, len > 255 bytes (%d bytes)", len(addr.addrStr)) | ||||||
|  | 		} | ||||||
|  | 		raw = append(raw, atypDomainName) | ||||||
|  | 		raw = append(raw, addr.addrStr...) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var port uint64 | ||||||
|  | 	if port, err = strconv.ParseUint(addr.portStr, 10, 16); err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	raw = append(raw, byte(port>>8)) | ||||||
|  | 	raw = append(raw, byte(port&0xff)) | ||||||
|  | 
 | ||||||
|  | 	addr.raw = raw | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // String returns the string representation of the address, in "host:port"
 | ||||||
|  | // format.
 | ||||||
|  | func (addr *Address) String() string { | ||||||
|  | 	return addr.addrStr + ":" + addr.portStr | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // HostPort returns the string representation of the addess, split into the
 | ||||||
|  | // host and port components.
 | ||||||
|  | func (addr *Address) HostPort() (string, string) { | ||||||
|  | 	return addr.addrStr, addr.portStr | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Type returns the address type from the connect command this address was
 | ||||||
|  | // parsed from
 | ||||||
|  | func (addr *Address) Type() uint8 { | ||||||
|  | 	return addr.atyp | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (addr *Address) read(conn net.Conn) (err error) { | ||||||
|  | 	// The address looks like:
 | ||||||
|  | 	//  uint8_t atyp
 | ||||||
|  | 	//  uint8_t addr[] (Length depends on atyp)
 | ||||||
|  | 	//  uint16_t port
 | ||||||
|  | 
 | ||||||
|  | 	// Read the atype.
 | ||||||
|  | 	var atyp byte | ||||||
|  | 	if atyp, err = readByte(conn); err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	addr.raw = append(addr.raw, atyp) | ||||||
|  | 
 | ||||||
|  | 	// Read the address.
 | ||||||
|  | 	var rawAddr []byte | ||||||
|  | 	switch atyp { | ||||||
|  | 	case atypIPv4: | ||||||
|  | 		rawAddr = make([]byte, net.IPv4len) | ||||||
|  | 		if _, err = io.ReadFull(conn, rawAddr); err != nil { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		v4Addr := net.IPv4(rawAddr[0], rawAddr[1], rawAddr[2], rawAddr[3]) | ||||||
|  | 		addr.addrStr = v4Addr.String() | ||||||
|  | 	case atypDomainName: | ||||||
|  | 		var alen byte | ||||||
|  | 		if alen, err = readByte(conn); err != nil { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		if alen == 0 { | ||||||
|  | 			return fmt.Errorf("domain name with 0 length") | ||||||
|  | 		} | ||||||
|  | 		rawAddr = make([]byte, alen) | ||||||
|  | 		addr.raw = append(addr.raw, alen) | ||||||
|  | 		if _, err = io.ReadFull(conn, rawAddr); err != nil { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		addr.addrStr = string(rawAddr) | ||||||
|  | 	case atypIPv6: | ||||||
|  | 		rawAddr = make([]byte, net.IPv6len) | ||||||
|  | 		if _, err = io.ReadFull(conn, rawAddr); err != nil { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		v6Addr := make(net.IP, net.IPv6len) | ||||||
|  | 		copy(v6Addr[:], rawAddr) | ||||||
|  | 		addr.addrStr = fmt.Sprintf("[%s]", v6Addr.String()) | ||||||
|  | 	default: | ||||||
|  | 		return errInvalidAtyp | ||||||
|  | 	} | ||||||
|  | 	addr.atyp = atyp | ||||||
|  | 	addr.raw = append(addr.raw, rawAddr...) | ||||||
|  | 
 | ||||||
|  | 	// Read the port.
 | ||||||
|  | 	var rawPort [2]byte | ||||||
|  | 	if _, err = io.ReadFull(conn, rawPort[:]); err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	port := int(rawPort[0])<<8 | int(rawPort[1]) | ||||||
|  | 	addr.portStr = fmt.Sprintf("%d", port) | ||||||
|  | 	addr.raw = append(addr.raw, rawPort[:]...) | ||||||
|  | 
 | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ErrorToReplyCode converts an error to the "best" reply code.
 | ||||||
|  | func ErrorToReplyCode(err error) ReplyCode { | ||||||
|  | 	if cErr, ok := err.(clientError); ok { | ||||||
|  | 		return ReplyCode(cErr) | ||||||
|  | 	} | ||||||
|  | 	opErr, ok := err.(*net.OpError) | ||||||
|  | 	if !ok { | ||||||
|  | 		return ReplyGeneralFailure | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	errno, ok := opErr.Err.(syscall.Errno) | ||||||
|  | 	if !ok { | ||||||
|  | 		return ReplyGeneralFailure | ||||||
|  | 	} | ||||||
|  | 	switch errno { | ||||||
|  | 	case syscall.EADDRNOTAVAIL: | ||||||
|  | 		return ReplyAddressNotSupported | ||||||
|  | 	case syscall.ETIMEDOUT: | ||||||
|  | 		return ReplyTTLExpired | ||||||
|  | 	case syscall.ENETUNREACH: | ||||||
|  | 		return ReplyNetworkUnreachable | ||||||
|  | 	case syscall.EHOSTUNREACH: | ||||||
|  | 		return ReplyHostUnreachable | ||||||
|  | 	case syscall.ECONNREFUSED, syscall.ECONNRESET: | ||||||
|  | 		return ReplyConnectionRefused | ||||||
|  | 	default: | ||||||
|  | 		return ReplyGeneralFailure | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Request describes a SOCKS 5 request.
 | ||||||
|  | type Request struct { | ||||||
|  | 	Auth AuthInfo | ||||||
|  | 	Cmd  Command | ||||||
|  | 	Addr Address | ||||||
|  | 
 | ||||||
|  | 	conn net.Conn | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type clientError ReplyCode | ||||||
|  | 
 | ||||||
|  | func (e clientError) Error() string { | ||||||
|  | 	switch ReplyCode(e) { | ||||||
|  | 	case ReplySucceeded: | ||||||
|  | 		return "socks5: succeeded" | ||||||
|  | 	case ReplyGeneralFailure: | ||||||
|  | 		return "socks5: general failure" | ||||||
|  | 	case ReplyConnectionNotAllowed: | ||||||
|  | 		return "socks5: connection not allowed" | ||||||
|  | 	case ReplyNetworkUnreachable: | ||||||
|  | 		return "socks5: network unreachable" | ||||||
|  | 	case ReplyHostUnreachable: | ||||||
|  | 		return "socks5: host unreachable" | ||||||
|  | 	case ReplyConnectionRefused: | ||||||
|  | 		return "socks5: connection refused" | ||||||
|  | 	case ReplyTTLExpired: | ||||||
|  | 		return "socks5: ttl expired" | ||||||
|  | 	case ReplyCommandNotSupported: | ||||||
|  | 		return "socks5: command not supported" | ||||||
|  | 	case ReplyAddressNotSupported: | ||||||
|  | 		return "socks5: address not supported" | ||||||
|  | 	default: | ||||||
|  | 		return fmt.Sprintf("socks5: reply code: 0x%02x", e) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func readByte(conn net.Conn) (byte, error) { | ||||||
|  | 	var tmp [1]byte | ||||||
|  | 	if _, err := conn.Read(tmp[:]); err != nil { | ||||||
|  | 		return 0, err | ||||||
|  | 	} | ||||||
|  | 	return tmp[0], nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func validateByte(descr string, val, expected byte) error { | ||||||
|  | 	if val != expected { | ||||||
|  | 		return fmt.Errorf("message field '%s' was 0x%02x (expected 0x%02x)", descr, val, expected) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
| @ -0,0 +1,200 @@ | |||||||
|  | /* | ||||||
|  |  * server.go - SOCSK5 server implementation. | ||||||
|  |  * | ||||||
|  |  * To the extent possible under law, Yawning Angel has waived all copyright and | ||||||
|  |  * related or neighboring rights to or-ctl-filter, using the creative commons | ||||||
|  |  * "cc0" public domain dedication. See LICENSE or | ||||||
|  |  * <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
 | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
|  | package sgfw | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"net" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Handshake attempts to handle a incoming client handshake over the provided
 | ||||||
|  | // connection and receive the SOCKS5 request.  The routine handles sending
 | ||||||
|  | // appropriate errors if applicable, but will not close the connection.
 | ||||||
|  | func Handshake(conn net.Conn) (*Request, error) { | ||||||
|  | 	// Arm the handshake timeout.
 | ||||||
|  | 	var err error | ||||||
|  | 	if err = conn.SetDeadline(time.Now().Add(inboundTimeout)); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	defer func() { | ||||||
|  | 		// Disarm the handshake timeout, only propagate the error if
 | ||||||
|  | 		// the handshake was successful.
 | ||||||
|  | 		nerr := conn.SetDeadline(time.Time{}) | ||||||
|  | 		if err == nil { | ||||||
|  | 			err = nerr | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	req := new(Request) | ||||||
|  | 	req.conn = conn | ||||||
|  | 
 | ||||||
|  | 	// Negotiate the protocol version and authentication method.
 | ||||||
|  | 	var method byte | ||||||
|  | 	if method, err = req.negotiateAuth(); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Authenticate if neccecary.
 | ||||||
|  | 	if err = req.authenticate(method); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Read the client command.
 | ||||||
|  | 	if err = req.readCommand(); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return req, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Reply sends a SOCKS5 reply to the corresponding request.  The BND.ADDR and
 | ||||||
|  | // BND.PORT fields are always set to an address/port corresponding to
 | ||||||
|  | // "0.0.0.0:0".
 | ||||||
|  | func (req *Request) Reply(code ReplyCode) error { | ||||||
|  | 	return req.ReplyAddr(code, nil) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ReplyAddr sends a SOCKS5 reply to the corresponding request.  The BND.ADDR
 | ||||||
|  | // and BND.PORT fields are specified by addr, or "0.0.0.0:0" if not provided.
 | ||||||
|  | func (req *Request) ReplyAddr(code ReplyCode, addr *Address) error { | ||||||
|  | 	// The server sends a reply message.
 | ||||||
|  | 	//  uint8_t ver (0x05)
 | ||||||
|  | 	//  uint8_t rep
 | ||||||
|  | 	//  uint8_t rsv (0x00)
 | ||||||
|  | 	//  uint8_t atyp
 | ||||||
|  | 	//  uint8_t bnd_addr[]
 | ||||||
|  | 	//  uint16_t bnd_port
 | ||||||
|  | 
 | ||||||
|  | 	resp := []byte{version, byte(code), rsv} | ||||||
|  | 	if addr == nil { | ||||||
|  | 		var nilAddr [net.IPv4len + 2]byte | ||||||
|  | 		resp = append(resp, atypIPv4) | ||||||
|  | 		resp = append(resp, nilAddr[:]...) | ||||||
|  | 	} else { | ||||||
|  | 		resp = append(resp, addr.raw...) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	_, err := req.conn.Write(resp[:]) | ||||||
|  | 	return err | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (req *Request) negotiateAuth() (byte, error) { | ||||||
|  | 	// The client sends a version identifier/selection message.
 | ||||||
|  | 	//	uint8_t ver (0x05)
 | ||||||
|  | 	//  uint8_t nmethods (>= 1).
 | ||||||
|  | 	//  uint8_t methods[nmethods]
 | ||||||
|  | 
 | ||||||
|  | 	var err error | ||||||
|  | 	if err = req.readByteVerify("version", version); err != nil { | ||||||
|  | 		return 0, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Read the number of methods, and the methods.
 | ||||||
|  | 	var nmethods byte | ||||||
|  | 	method := byte(authNoAcceptableMethods) | ||||||
|  | 	if nmethods, err = req.readByte(); err != nil { | ||||||
|  | 		return method, err | ||||||
|  | 	} | ||||||
|  | 	methods := make([]byte, nmethods) | ||||||
|  | 	if _, err := io.ReadFull(req.conn, methods); err != nil { | ||||||
|  | 		return 0, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Pick the best authentication method, prioritizing authenticating
 | ||||||
|  | 	// over not if both options are present.
 | ||||||
|  | 	if bytes.IndexByte(methods, authUsernamePassword) != -1 { | ||||||
|  | 		method = authUsernamePassword | ||||||
|  | 	} else if bytes.IndexByte(methods, authNoneRequired) != -1 { | ||||||
|  | 		method = authNoneRequired | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// The server sends a method selection message.
 | ||||||
|  | 	//  uint8_t ver (0x05)
 | ||||||
|  | 	//  uint8_t method
 | ||||||
|  | 	msg := []byte{version, method} | ||||||
|  | 	if _, err = req.conn.Write(msg); err != nil { | ||||||
|  | 		return 0, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return method, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (req *Request) authenticate(method byte) error { | ||||||
|  | 	switch method { | ||||||
|  | 	case authNoneRequired: | ||||||
|  | 		return nil | ||||||
|  | 	case authUsernamePassword: | ||||||
|  | 		return req.authRFC1929() | ||||||
|  | 	case authNoAcceptableMethods: | ||||||
|  | 		return fmt.Errorf("no acceptable authentication methods") | ||||||
|  | 	default: | ||||||
|  | 		// This should never happen as only supported auth methods should be
 | ||||||
|  | 		// negotiated.
 | ||||||
|  | 		return fmt.Errorf("negotiated unsupported method 0x%02x", method) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (req *Request) readCommand() error { | ||||||
|  | 	// The client sends the request details.
 | ||||||
|  | 	//  uint8_t ver (0x05)
 | ||||||
|  | 	//  uint8_t cmd
 | ||||||
|  | 	//  uint8_t rsv (0x00)
 | ||||||
|  | 	//  uint8_t atyp
 | ||||||
|  | 	//  uint8_t dst_addr[]
 | ||||||
|  | 	//  uint16_t dst_port
 | ||||||
|  | 
 | ||||||
|  | 	var err error | ||||||
|  | 	var cmd byte | ||||||
|  | 	if err = req.readByteVerify("version", version); err != nil { | ||||||
|  | 		req.Reply(ReplyGeneralFailure) | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if cmd, err = req.readByte(); err != nil { | ||||||
|  | 		req.Reply(ReplyGeneralFailure) | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	switch Command(cmd) { | ||||||
|  | 	case CommandConnect, CommandTorResolve, CommandTorResolvePTR: | ||||||
|  | 		req.Cmd = Command(cmd) | ||||||
|  | 	default: | ||||||
|  | 		req.Reply(ReplyCommandNotSupported) | ||||||
|  | 		return fmt.Errorf("unsupported SOCKS command: 0x%02x", cmd) | ||||||
|  | 	} | ||||||
|  | 	if err = req.readByteVerify("reserved", rsv); err != nil { | ||||||
|  | 		req.Reply(ReplyGeneralFailure) | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Read the destination address/port.
 | ||||||
|  | 	err = req.Addr.read(req.conn) | ||||||
|  | 	if err == errInvalidAtyp { | ||||||
|  | 		req.Reply(ReplyAddressNotSupported) | ||||||
|  | 	} else if err != nil { | ||||||
|  | 		req.Reply(ReplyGeneralFailure) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (req *Request) readByte() (byte, error) { | ||||||
|  | 	return readByte(req.conn) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (req *Request) readByteVerify(descr string, expected byte) error { | ||||||
|  | 	val, err := req.readByte() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	return validateByte(descr, val, expected) | ||||||
|  | } | ||||||
| @ -0,0 +1,84 @@ | |||||||
|  | /* | ||||||
|  |  * server_rfc1929.go - SOCSK 5 server authentication. | ||||||
|  |  * | ||||||
|  |  * To the extent possible under law, Yawning Angel has waived all copyright and | ||||||
|  |  * related or neighboring rights to or-ctl-filter, using the creative commons | ||||||
|  |  * "cc0" public domain dedication. See LICENSE or | ||||||
|  |  * <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
 | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
|  | package sgfw | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	authRFC1929Ver     = 0x01 | ||||||
|  | 	authRFC1929Success = 0x00 | ||||||
|  | 	authRFC1929Fail    = 0x01 | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // AuthInfo is the RFC 1929 Username/Password authentication data.
 | ||||||
|  | type AuthInfo struct { | ||||||
|  | 	Uname  []byte | ||||||
|  | 	Passwd []byte | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (req *Request) authRFC1929() (err error) { | ||||||
|  | 	sendErrResp := func() { | ||||||
|  | 		// Swallow write/flush errors, the auth failure is the relevant error.
 | ||||||
|  | 		resp := []byte{authRFC1929Ver, authRFC1929Fail} | ||||||
|  | 		req.conn.Write(resp[:]) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// The client sends a Username/Password request.
 | ||||||
|  | 	//  uint8_t ver (0x01)
 | ||||||
|  | 	//  uint8_t ulen (>= 1)
 | ||||||
|  | 	//  uint8_t uname[ulen]
 | ||||||
|  | 	//  uint8_t plen (>= 1)
 | ||||||
|  | 	//  uint8_t passwd[plen]
 | ||||||
|  | 
 | ||||||
|  | 	if err = req.readByteVerify("auth version", authRFC1929Ver); err != nil { | ||||||
|  | 		sendErrResp() | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Read the username.
 | ||||||
|  | 	var ulen byte | ||||||
|  | 	if ulen, err = req.readByte(); err != nil { | ||||||
|  | 		sendErrResp() | ||||||
|  | 		return | ||||||
|  | 	} else if ulen < 1 { | ||||||
|  | 		sendErrResp() | ||||||
|  | 		return fmt.Errorf("username with 0 length") | ||||||
|  | 	} | ||||||
|  | 	uname := make([]byte, ulen) | ||||||
|  | 	if _, err = io.ReadFull(req.conn, uname); err != nil { | ||||||
|  | 		sendErrResp() | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Read the password.
 | ||||||
|  | 	var plen byte | ||||||
|  | 	if plen, err = req.readByte(); err != nil { | ||||||
|  | 		sendErrResp() | ||||||
|  | 		return | ||||||
|  | 	} else if plen < 1 { | ||||||
|  | 		sendErrResp() | ||||||
|  | 		return fmt.Errorf("password with 0 length") | ||||||
|  | 	} | ||||||
|  | 	passwd := make([]byte, plen) | ||||||
|  | 	if _, err = io.ReadFull(req.conn, passwd); err != nil { | ||||||
|  | 		sendErrResp() | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	req.Auth.Uname = uname | ||||||
|  | 	req.Auth.Passwd = passwd | ||||||
|  | 
 | ||||||
|  | 	resp := []byte{authRFC1929Ver, authRFC1929Success} | ||||||
|  | 	_, err = req.conn.Write(resp[:]) | ||||||
|  | 	return | ||||||
|  | } | ||||||
| @ -0,0 +1,261 @@ | |||||||
|  | package sgfw | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"io" | ||||||
|  | 	"net" | ||||||
|  | 	"os" | ||||||
|  | 	"sync" | ||||||
|  | 
 | ||||||
|  | 	"github.com/subgraph/go-procsnitch" | ||||||
|  | 	"strconv" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type socksChainConfig struct { | ||||||
|  | 	TargetSocksNet  string | ||||||
|  | 	TargetSocksAddr string | ||||||
|  | 	ListenSocksNet  string | ||||||
|  | 	ListenSocksAddr string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type socksChain struct { | ||||||
|  | 	cfg      *socksChainConfig | ||||||
|  | 	fw       *Firewall | ||||||
|  | 	listener net.Listener | ||||||
|  | 	wg       *sync.WaitGroup | ||||||
|  | 	procInfo procsnitch.ProcInfo | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type socksChainSession struct { | ||||||
|  | 	cfg          *socksChainConfig | ||||||
|  | 	clientConn   net.Conn | ||||||
|  | 	upstreamConn net.Conn | ||||||
|  | 	req          *Request | ||||||
|  | 	bndAddr      *Address | ||||||
|  | 	optData      []byte | ||||||
|  | 	procInfo     procsnitch.ProcInfo | ||||||
|  | 	server       *socksChain | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	socksVerdictDrop   = 1 | ||||||
|  | 	socksVerdictAccept = 2 | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type pendingSocksConnection struct { | ||||||
|  | 	pol      *Policy | ||||||
|  | 	hname    string | ||||||
|  | 	destIP   net.IP | ||||||
|  | 	destPort uint16 | ||||||
|  | 	pinfo    *procsnitch.Info | ||||||
|  | 	verdict  chan int | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (sc *pendingSocksConnection) policy() *Policy { | ||||||
|  | 	return sc.pol | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (sc *pendingSocksConnection) procInfo() *procsnitch.Info { | ||||||
|  | 	return sc.pinfo | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (sc *pendingSocksConnection) hostname() string { | ||||||
|  | 	return sc.hname | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (sc *pendingSocksConnection) dst() net.IP { | ||||||
|  | 	return sc.destIP | ||||||
|  | } | ||||||
|  | func (sc *pendingSocksConnection) dstPort() uint16 { | ||||||
|  | 	return sc.destPort | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (sc *pendingSocksConnection) deliverVerdict(v int) { | ||||||
|  | 	sc.verdict <- v | ||||||
|  | 	close(sc.verdict) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (sc *pendingSocksConnection) accept() { sc.deliverVerdict(socksVerdictAccept) } | ||||||
|  | 
 | ||||||
|  | func (sc *pendingSocksConnection) drop() { sc.deliverVerdict(socksVerdictDrop) } | ||||||
|  | 
 | ||||||
|  | func (sc *pendingSocksConnection) print() string { return "socks connection" } | ||||||
|  | 
 | ||||||
|  | func NewSocksChain(cfg *socksChainConfig, wg *sync.WaitGroup, fw *Firewall) *socksChain { | ||||||
|  | 	chain := socksChain{ | ||||||
|  | 		cfg:      cfg, | ||||||
|  | 		fw:       fw, | ||||||
|  | 		wg:       wg, | ||||||
|  | 		procInfo: procsnitch.SystemProcInfo{}, | ||||||
|  | 	} | ||||||
|  | 	return &chain | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Start initializes the SOCKS 5 server and starts
 | ||||||
|  | // accepting connections.
 | ||||||
|  | func (s *socksChain) start() { | ||||||
|  | 	var err error | ||||||
|  | 	s.listener, err = net.Listen(s.cfg.ListenSocksNet, s.cfg.ListenSocksAddr) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Errorf("ERR/socks: Failed to listen on the socks address: %v", err) | ||||||
|  | 		os.Exit(1) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	s.wg.Add(1) | ||||||
|  | 	go s.socksAcceptLoop() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *socksChain) socksAcceptLoop() error { | ||||||
|  | 	defer s.wg.Done() | ||||||
|  | 	defer s.listener.Close() | ||||||
|  | 
 | ||||||
|  | 	for { | ||||||
|  | 		conn, err := s.listener.Accept() | ||||||
|  | 		if err != nil { | ||||||
|  | 			if e, ok := err.(net.Error); ok && !e.Temporary() { | ||||||
|  | 				log.Infof("ERR/socks: Failed to Accept(): %v", err) | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		session := &socksChainSession{cfg: s.cfg, clientConn: conn, procInfo: s.procInfo, server: s} | ||||||
|  | 		go session.sessionWorker() | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *socksChainSession) sessionWorker() { | ||||||
|  | 	defer c.clientConn.Close() | ||||||
|  | 
 | ||||||
|  | 	clientAddr := c.clientConn.RemoteAddr() | ||||||
|  | 	log.Infof("INFO/socks: New connection from: %v", clientAddr) | ||||||
|  | 
 | ||||||
|  | 	// Do the SOCKS handshake with the client, and read the command.
 | ||||||
|  | 	var err error | ||||||
|  | 	if c.req, err = Handshake(c.clientConn); err != nil { | ||||||
|  | 		log.Infof("ERR/socks: Failed SOCKS5 handshake: %v", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	switch c.req.Cmd { | ||||||
|  | 	case CommandTorResolve, CommandTorResolvePTR: | ||||||
|  | 		err = c.dispatchTorSOCKS() | ||||||
|  | 
 | ||||||
|  | 		// If we reach here, the request has been dispatched and completed.
 | ||||||
|  | 		if err == nil { | ||||||
|  | 			// Successfully even, send the response back with the addresc.
 | ||||||
|  | 			c.req.ReplyAddr(ReplySucceeded, c.bndAddr) | ||||||
|  | 		} | ||||||
|  | 	case CommandConnect: | ||||||
|  | 		if !c.filterConnect() { | ||||||
|  | 			c.req.Reply(ReplyConnectionRefused) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		c.handleConnect() | ||||||
|  | 	default: | ||||||
|  | 		// Should *NEVER* happen, validated as part of handshake.
 | ||||||
|  | 		log.Infof("BUG/socks: Unsupported SOCKS command: 0x%02x", c.req.Cmd) | ||||||
|  | 		c.req.Reply(ReplyCommandNotSupported) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *socksChainSession) addressDetails() (string, net.IP, uint16) { | ||||||
|  | 	addr := c.req.Addr | ||||||
|  | 	host, pstr := addr.HostPort() | ||||||
|  | 	port, err := strconv.ParseUint(pstr, 10, 16) | ||||||
|  | 	if err != nil || port == 0 || port > 0xFFFF { | ||||||
|  | 		log.Warningf("Illegal port value in socks address: %v", addr) | ||||||
|  | 		return "", nil, 0 | ||||||
|  | 	} | ||||||
|  | 	if addr.Type() == 3 { | ||||||
|  | 		return host, nil, uint16(port) | ||||||
|  | 	} | ||||||
|  | 	ip := net.ParseIP(host) | ||||||
|  | 	if ip == nil { | ||||||
|  | 		log.Warningf("Failed to extract address information from socks address: %v", addr) | ||||||
|  | 	} | ||||||
|  | 	return "", ip, uint16(port) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *socksChainSession) filterConnect() bool { | ||||||
|  | 	pinfo := procsnitch.FindProcessForConnection(c.clientConn, c.procInfo) | ||||||
|  | 	if pinfo == nil { | ||||||
|  | 		log.Warningf("No proc found for connection from: %s", c.clientConn.RemoteAddr()) | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	policy := c.server.fw.PolicyForPath(pinfo.ExePath) | ||||||
|  | 
 | ||||||
|  | 	hostname, ip, port := c.addressDetails() | ||||||
|  | 	if ip == nil && hostname == "" { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	result := policy.rules.filter(nil, ip, port, hostname, pinfo) | ||||||
|  | 	switch result { | ||||||
|  | 	case FILTER_DENY: | ||||||
|  | 		return false | ||||||
|  | 	case FILTER_ALLOW: | ||||||
|  | 		return true | ||||||
|  | 	case FILTER_PROMPT: | ||||||
|  | 		pending := &pendingSocksConnection{ | ||||||
|  | 			pol:      policy, | ||||||
|  | 			hname:    hostname, | ||||||
|  | 			destIP:   ip, | ||||||
|  | 			destPort: port, | ||||||
|  | 			pinfo:    pinfo, | ||||||
|  | 			verdict:  make(chan int), | ||||||
|  | 		} | ||||||
|  | 		policy.processPromptResult(pending) | ||||||
|  | 		v := <-pending.verdict | ||||||
|  | 		if v == socksVerdictAccept { | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return false | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *socksChainSession) handleConnect() { | ||||||
|  | 	err := c.dispatchTorSOCKS() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	c.req.Reply(ReplySucceeded) | ||||||
|  | 	defer c.upstreamConn.Close() | ||||||
|  | 
 | ||||||
|  | 	if c.optData != nil { | ||||||
|  | 		if _, err = c.upstreamConn.Write(c.optData); err != nil { | ||||||
|  | 			log.Infof("ERR/socks: Failed writing OptData: %v", err) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		c.optData = nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// A upstream connection has been established, push data back and forth
 | ||||||
|  | 	// till the session is done.
 | ||||||
|  | 	c.forwardTraffic() | ||||||
|  | 	log.Infof("INFO/socks: Closed SOCKS connection from: %v", c.clientConn.RemoteAddr()) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *socksChainSession) forwardTraffic() { | ||||||
|  | 	var wg sync.WaitGroup | ||||||
|  | 	wg.Add(2) | ||||||
|  | 
 | ||||||
|  | 	copyLoop := func(dst, src net.Conn) { | ||||||
|  | 		defer wg.Done() | ||||||
|  | 		defer dst.Close() | ||||||
|  | 
 | ||||||
|  | 		io.Copy(dst, src) | ||||||
|  | 	} | ||||||
|  | 	go copyLoop(c.upstreamConn, c.clientConn) | ||||||
|  | 	go copyLoop(c.clientConn, c.upstreamConn) | ||||||
|  | 
 | ||||||
|  | 	wg.Wait() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *socksChainSession) dispatchTorSOCKS() (err error) { | ||||||
|  | 	c.upstreamConn, c.bndAddr, err = Redispatch(c.cfg.TargetSocksNet, c.cfg.TargetSocksAddr, c.req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		c.req.Reply(ErrorToReplyCode(err)) | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
| @ -0,0 +1,305 @@ | |||||||
|  | package sgfw | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"bufio" | ||||||
|  | 	"bytes" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"net" | ||||||
|  | 	"strings" | ||||||
|  | 	"sync" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/subgraph/fw-daemon/socks5" | ||||||
|  | 	"golang.org/x/net/proxy" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // MortalService can be killed at any time.
 | ||||||
|  | type MortalService struct { | ||||||
|  | 	network            string | ||||||
|  | 	address            string | ||||||
|  | 	connectionCallback func(net.Conn) error | ||||||
|  | 
 | ||||||
|  | 	conns     []net.Conn | ||||||
|  | 	quit      chan bool | ||||||
|  | 	listener  net.Listener | ||||||
|  | 	waitGroup *sync.WaitGroup | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // NewMortalService creates a new MortalService
 | ||||||
|  | func NewMortalService(network, address string, connectionCallback func(net.Conn) error) *MortalService { | ||||||
|  | 	l := MortalService{ | ||||||
|  | 		network:            network, | ||||||
|  | 		address:            address, | ||||||
|  | 		connectionCallback: connectionCallback, | ||||||
|  | 
 | ||||||
|  | 		conns:     make([]net.Conn, 0, 10), | ||||||
|  | 		quit:      make(chan bool), | ||||||
|  | 		waitGroup: &sync.WaitGroup{}, | ||||||
|  | 	} | ||||||
|  | 	return &l | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Stop will kill our listener and all it's connections
 | ||||||
|  | func (l *MortalService) Stop() { | ||||||
|  | 	log.Infof("stopping listener service %s:%s", l.network, l.address) | ||||||
|  | 	close(l.quit) | ||||||
|  | 	if l.listener != nil { | ||||||
|  | 		l.listener.Close() | ||||||
|  | 	} | ||||||
|  | 	l.waitGroup.Wait() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *MortalService) acceptLoop() { | ||||||
|  | 	defer l.waitGroup.Done() | ||||||
|  | 	defer func() { | ||||||
|  | 		log.Infof("stoping listener service %s:%s", l.network, l.address) | ||||||
|  | 		for i, conn := range l.conns { | ||||||
|  | 			if conn != nil { | ||||||
|  | 				log.Infof("Closing connection #%d", i) | ||||||
|  | 				conn.Close() | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | 	defer l.listener.Close() | ||||||
|  | 
 | ||||||
|  | 	for { | ||||||
|  | 		conn, err := l.listener.Accept() | ||||||
|  | 		if nil != err { | ||||||
|  | 			if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() { | ||||||
|  | 				continue | ||||||
|  | 			} else { | ||||||
|  | 				log.Infof("MortalService connection accept failure: %s\n", err) | ||||||
|  | 				select { | ||||||
|  | 				case <-l.quit: | ||||||
|  | 					return | ||||||
|  | 				default: | ||||||
|  | 				} | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		l.conns = append(l.conns, conn) | ||||||
|  | 		go l.handleConnection(conn, len(l.conns)-1) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *MortalService) createDeadlinedListener() error { | ||||||
|  | 	if l.network == "tcp" { | ||||||
|  | 		tcpAddr, err := net.ResolveTCPAddr("tcp", l.address) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err) | ||||||
|  | 		} | ||||||
|  | 		tcpListener, err := net.ListenTCP("tcp", tcpAddr) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err) | ||||||
|  | 		} | ||||||
|  | 		tcpListener.SetDeadline(time.Now().Add(1e9)) | ||||||
|  | 		l.listener = tcpListener | ||||||
|  | 		return nil | ||||||
|  | 	} else if l.network == "unix" { | ||||||
|  | 		unixAddr, err := net.ResolveUnixAddr("unix", l.address) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err) | ||||||
|  | 		} | ||||||
|  | 		unixListener, err := net.ListenUnix("unix", unixAddr) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return fmt.Errorf("MortalService.createDeadlinedListener %s %s failure: %s", l.network, l.address, err) | ||||||
|  | 		} | ||||||
|  | 		unixListener.SetDeadline(time.Now().Add(1e9)) | ||||||
|  | 		l.listener = unixListener | ||||||
|  | 		return nil | ||||||
|  | 	} else { | ||||||
|  | 		panic("") | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Start the MortalService
 | ||||||
|  | func (l *MortalService) Start() error { | ||||||
|  | 	var err error | ||||||
|  | 	err = l.createDeadlinedListener() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	l.waitGroup.Add(1) | ||||||
|  | 	go l.acceptLoop() | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *MortalService) handleConnection(conn net.Conn, id int) error { | ||||||
|  | 	defer func() { | ||||||
|  | 		log.Infof("Closing connection #%d", id) | ||||||
|  | 		conn.Close() | ||||||
|  | 		l.conns[id] = nil | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	log.Infof("Starting connection #%d", id) | ||||||
|  | 
 | ||||||
|  | 	for { | ||||||
|  | 		if err := l.connectionCallback(conn); err != nil { | ||||||
|  | 			log.Error(err.Error()) | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type AccumulatingService struct { | ||||||
|  | 	net, address    string | ||||||
|  | 	banner          string | ||||||
|  | 	buffer          bytes.Buffer | ||||||
|  | 	mortalService   *MortalService | ||||||
|  | 	hasProtocolInfo bool | ||||||
|  | 	hasAuthenticate bool | ||||||
|  | 	receivedChan    chan bool | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func NewAccumulatingService(net, address, banner string) *AccumulatingService { | ||||||
|  | 	l := AccumulatingService{ | ||||||
|  | 		net:             net, | ||||||
|  | 		address:         address, | ||||||
|  | 		banner:          banner, | ||||||
|  | 		hasProtocolInfo: true, | ||||||
|  | 		hasAuthenticate: true, | ||||||
|  | 		receivedChan:    make(chan bool, 0), | ||||||
|  | 	} | ||||||
|  | 	return &l | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (a *AccumulatingService) Start() { | ||||||
|  | 	a.mortalService = NewMortalService(a.net, a.address, a.SessionWorker) | ||||||
|  | 	a.mortalService.Start() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (a *AccumulatingService) Stop() { | ||||||
|  | 	fmt.Println("AccumulatingService STOP") | ||||||
|  | 	a.mortalService.Stop() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (a *AccumulatingService) WaitUntilReceived() { | ||||||
|  | 	<-a.receivedChan | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (a *AccumulatingService) SessionWorker(conn net.Conn) error { | ||||||
|  | 	connReader := bufio.NewReader(conn) | ||||||
|  | 	conn.Write([]byte(a.banner)) | ||||||
|  | 	for { | ||||||
|  | 		line, err := connReader.ReadBytes('\n') | ||||||
|  | 		if err != nil { | ||||||
|  | 			fmt.Printf("AccumulatingService read error: %s\n", err) | ||||||
|  | 		} | ||||||
|  | 		lineStr := strings.TrimSpace(string(line)) | ||||||
|  | 		a.buffer.WriteString(lineStr + "\n") | ||||||
|  | 		a.receivedChan <- true | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func fakeSocksSessionWorker(clientConn net.Conn, targetNet, targetAddr string) error { | ||||||
|  | 	defer clientConn.Close() | ||||||
|  | 
 | ||||||
|  | 	clientAddr := clientConn.RemoteAddr() | ||||||
|  | 	fmt.Printf("INFO/socks: New connection from: %v\n", clientAddr) | ||||||
|  | 
 | ||||||
|  | 	// Do the SOCKS handshake with the client, and read the command.
 | ||||||
|  | 	req, err := socks5.Handshake(clientConn) | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(fmt.Sprintf("ERR/socks: Failed SOCKS5 handshake: %v", err)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var upstreamConn net.Conn | ||||||
|  | 	upstreamConn, err = net.Dial(targetNet, targetAddr) | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  | 	defer upstreamConn.Close() | ||||||
|  | 	req.Reply(socks5.ReplySucceeded) | ||||||
|  | 
 | ||||||
|  | 	// A upstream connection has been established, push data back and forth
 | ||||||
|  | 	// till the session is done.
 | ||||||
|  | 	var wg sync.WaitGroup | ||||||
|  | 	wg.Add(2) | ||||||
|  | 	copyLoop := func(dst, src net.Conn) { | ||||||
|  | 		defer wg.Done() | ||||||
|  | 		defer dst.Close() | ||||||
|  | 
 | ||||||
|  | 		io.Copy(dst, src) | ||||||
|  | 	} | ||||||
|  | 	go copyLoop(upstreamConn, clientConn) | ||||||
|  | 	go copyLoop(clientConn, upstreamConn) | ||||||
|  | 
 | ||||||
|  | 	wg.Wait() | ||||||
|  | 	fmt.Printf("INFO/socks: Closed SOCKS connection from: %v\n", clientAddr) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestSocksServerProxyChain(t *testing.T) { | ||||||
|  | 	// socks client ---> socks chain ---> socks server ---> service
 | ||||||
|  | 	socksChainNet := "tcp" | ||||||
|  | 	socksChainAddr := "127.0.0.1:7750" | ||||||
|  | 	socksServerNet := "tcp" | ||||||
|  | 	socksServerAddr := "127.0.0.1:8850" | ||||||
|  | 	serviceNet := "tcp" | ||||||
|  | 	serviceAddr := "127.0.0.1:9950" | ||||||
|  | 
 | ||||||
|  | 	banner := "meow 123\r\n" | ||||||
|  | 	// setup the service listener
 | ||||||
|  | 	service := NewAccumulatingService(serviceNet, serviceAddr, banner) | ||||||
|  | 	service.Start() | ||||||
|  | 	defer service.Stop() | ||||||
|  | 
 | ||||||
|  | 	// setup the "socks server"
 | ||||||
|  | 	session := func(clientConn net.Conn) error { | ||||||
|  | 		return fakeSocksSessionWorker(clientConn, serviceNet, serviceAddr) | ||||||
|  | 	} | ||||||
|  | 	socksService := NewMortalService(socksServerNet, socksServerAddr, session) | ||||||
|  | 	socksService.Start() | ||||||
|  | 	defer socksService.Stop() | ||||||
|  | 
 | ||||||
|  | 	// setup the SOCKS proxy chain
 | ||||||
|  | 	socksConfig := socksChainConfig{ | ||||||
|  | 		TargetSocksNet:  socksServerNet, | ||||||
|  | 		TargetSocksAddr: socksServerAddr, | ||||||
|  | 		ListenSocksNet:  socksChainNet, | ||||||
|  | 		ListenSocksAddr: socksChainAddr, | ||||||
|  | 	} | ||||||
|  | 	wg := sync.WaitGroup{} | ||||||
|  | 	ds := dbusServer{} | ||||||
|  | 	chain := NewSocksChain(&socksConfig, &wg, &ds) | ||||||
|  | 	chain.start() | ||||||
|  | 
 | ||||||
|  | 	// setup the SOCKS client
 | ||||||
|  | 	auth := proxy.Auth{ | ||||||
|  | 		User:     "", | ||||||
|  | 		Password: "", | ||||||
|  | 	} | ||||||
|  | 	forward := proxy.NewPerHost(proxy.Direct, proxy.Direct) | ||||||
|  | 	socksClient, err := proxy.SOCKS5(socksChainNet, socksChainAddr, &auth, forward) | ||||||
|  | 	conn, err := socksClient.Dial(serviceNet, serviceAddr) | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// read a banner from the service
 | ||||||
|  | 	rd := bufio.NewReader(conn) | ||||||
|  | 	line := []byte{} | ||||||
|  | 	line, err = rd.ReadBytes('\n') | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  | 	if string(line) != banner { | ||||||
|  | 		t.Errorf("Did not receive expected banner. Got %s, wanted %s\n", string(line), banner) | ||||||
|  | 		t.Fail() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// send the service some data and verify it was received
 | ||||||
|  | 	clientData := "hello world\r\n" | ||||||
|  | 	conn.Write([]byte(clientData)) | ||||||
|  | 	service.WaitUntilReceived() | ||||||
|  | 	if service.buffer.String() != strings.TrimSpace(clientData)+"\n" { | ||||||
|  | 		t.Errorf("Client sent %s but service only received %s\n", "hello world\n", service.buffer.String()) | ||||||
|  | 		t.Fail() | ||||||
|  | 	} | ||||||
|  | } | ||||||
					Loading…
					
					
				
		Reference in new issue
	
	 shw
						shw