Incorporated TLSGuard and turned it on by default for all outbound SOCKS5 connections.

Fixed display of nil IPs (when only hostname is passed via SOCKS5 connect).
shw_dev
shw 7 years ago
parent f945481c2e
commit 2e7b7debeb

@ -331,7 +331,13 @@ func addRequest(listStore *gtk.ListStore, path, proto string, pid int, ipaddr, h
colVals[1] = path
colVals[2] = proto
colVals[3] = pid
colVals[4] = ipaddr
if ipaddr == "" {
colVals[4] = "---"
} else {
colVals[4] = ipaddr
}
colVals[5] = hostname
colVals[6] = port
colVals[7] = uid

@ -136,13 +136,19 @@ func (p *prompter) processConnection(pc pendingConnection) {
}
policy := pc.policy()
dststr := ""
if pc.dst() != nil {
dststr = pc.dst().String()
}
call := p.dbusObj.Call("com.subgraph.FirewallPrompt.RequestPrompt", 0,
policy.application,
policy.icon,
policy.path,
addr,
int32(pc.dstPort()),
pc.dst().String(),
dststr,
pc.src().String(),
pc.proto(),
int32(pc.procInfo().UID),

@ -172,7 +172,7 @@ func (c *socksChainSession) sessionWorker() {
// If we reach here, the request has been dispatched and completed.
if err == nil {
// Successfully even, send the response back with the addresc.
// Successfully even, send the response back with the address.
c.req.ReplyAddr(ReplySucceeded, c.bndAddr)
}
case CommandConnect:
@ -364,6 +364,15 @@ func (c *socksChainSession) handleConnect() {
}
func (c *socksChainSession) forwardTraffic() {
err := TLSGuard(c.clientConn, c.upstreamConn, c.req.Addr.addrStr)
if err != nil {
log.Error("Dropping traffic due to TLSGuard violation: ", err)
return
} else {
log.Notice("TLSGuard approved certificate presented for connection to: ", c.req.Addr.addrStr)
}
var wg sync.WaitGroup
wg.Add(2)

@ -0,0 +1,180 @@
package sgfw
import (
"crypto/x509"
"io"
"net"
"errors"
)
func TLSGuard(conn, conn2 net.Conn, fqdn string) error {
// Should this be a requirement?
// if strings.HasSuffix(request.DestAddr.FQDN, "onion") {
handshakeByte, err := readNBytes(conn, 1)
if err != nil {
return err
}
if handshakeByte[0] != 0x16 {
return errors.New("Blocked client from attempting non-TLS connection")
}
vers, err := readNBytes(conn, 2)
if err != nil {
return err
}
length, err := readNBytes(conn, 2)
if err != nil {
return err
}
ffslen := int(int(length[0])<<8 | int(length[1]))
ffs, err := readNBytes(conn, ffslen)
if err != nil {
return err
}
// Transmit client hello
conn2.Write(handshakeByte)
conn2.Write(vers)
conn2.Write(length)
conn2.Write(ffs)
// Read ServerHello
bytesRead := 0
var s byte // 0x0e is done
var responseBuf []byte = []byte{}
valid := false
sendToClient := false
for sendToClient == false {
// Handshake byte
serverhandshakeByte, err := readNBytes(conn2, 1)
if err != nil {
return nil
}
responseBuf = append(responseBuf, serverhandshakeByte[0])
bytesRead += 1
if serverhandshakeByte[0] != 0x16 {
return errors.New("Expected TLS server handshake byte was not received")
}
// Protocol version, 2 bytes
serverProtocolVer, err := readNBytes(conn2, 2)
if err != nil {
return err
}
bytesRead += 2
responseBuf = append(responseBuf, serverProtocolVer...)
// Record length, 2 bytes
serverRecordLen, err := readNBytes(conn2, 2)
if err != nil {
return err
}
bytesRead += 2
responseBuf = append(responseBuf, serverRecordLen...)
serverRecordLenInt := int(int(serverRecordLen[0])<<8 | int(serverRecordLen[1]))
// Record type byte
serverMsg, err := readNBytes(conn2, serverRecordLenInt)
if err != nil {
return err
}
bytesRead += len(serverMsg)
responseBuf = append(responseBuf, serverMsg...)
s = serverMsg[0]
// Message len, 3 bytes
serverMessageLen := serverMsg[1:4]
serverMessageLenInt := int(int(serverMessageLen[0])<<16 | int(serverMessageLen[1])<<8 | int(serverMessageLen[2]))
// serverHelloBody, err := readNBytes(conn2, serverMessageLenInt)
serverHelloBody := serverMsg[4 : 4+serverMessageLenInt]
if s == 0x0b {
certChainLen := int(int(serverHelloBody[0])<<16 | int(serverHelloBody[1])<<8 | int(serverHelloBody[2]))
remaining := certChainLen
pos := serverHelloBody[3:certChainLen]
// var certChain []*x509.Certificate
var verifyOptions x509.VerifyOptions
if fqdn != "" {
verifyOptions.DNSName = fqdn
}
pool := x509.NewCertPool()
var c *x509.Certificate
for remaining > 0 {
certLen := int(int(pos[0])<<16 | int(pos[1])<<8 | int(pos[2]))
// fmt.Printf("Certs chain len %d, cert 1 len %d:\n", certChainLen, certLen)
cert := pos[3 : 3+certLen]
certs, err := x509.ParseCertificates(cert)
if remaining == certChainLen {
c = certs[0]
} else {
pool.AddCert(certs[0])
}
// certChain = append(certChain, certs[0])
if err != nil {
return err
}
remaining = remaining - certLen - 3
if remaining > 0 {
pos = pos[3+certLen:]
}
}
verifyOptions.Intermediates = pool
_, err = c.Verify(verifyOptions)
if err != nil {
return err
} else {
valid = true
}
// else if s == 0x0d { fmt.Printf("found a client cert request, sending buf to client\n") }
} else if s == 0x0e {
sendToClient = true
} else if s == 0x0d {
sendToClient = true
}
// fmt.Printf("Version bytes: %x %x\n", responseBuf[1], responseBuf[2])
// fmt.Printf("Len bytes: %x %x\n", responseBuf[3], responseBuf[4])
// fmt.Printf("Message type: %x\n", responseBuf[5])
// fmt.Printf("Message len: %x %x %x\n", responseBuf[6], responseBuf[7], responseBuf[8])
// fmt.Printf("Message body: %v\n", responseBuf[9:])
conn.Write(responseBuf)
responseBuf = []byte{}
}
if !valid {
return errors.New("Unknown error: TLS connection could not be validated")
}
return nil
}
func readNBytes(conn net.Conn, numBytes int) ([]byte, error) {
res := make([]byte, 0)
temp := make([]byte, 1)
for i := 0; i < numBytes; i++ {
_, err := io.ReadAtLeast(conn, temp, 1)
if err != nil {
return res, err
}
res = append(res, temp[0])
}
return res, nil
}
Loading…
Cancel
Save