Read more TLS messages during handshake

shw-merge
dma 7 years ago
parent 38fabc3327
commit 7d3e31a005

@ -3,104 +3,132 @@ package sgfw
import ( import (
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt"
"io" "io"
"net" "net"
"time"
) )
func TLSGuard(conn, conn2 net.Conn, fqdn string) error { const TLSGUARD_READ_TIMEOUT = 2 * time.Second
// Should this be a requirement? const TLSGUARD_MIN_TLS_VER_MAJ = 3
// if strings.HasSuffix(request.DestAddr.FQDN, "onion") { const TLSGUARD_MIN_TLS_VER_MIN = 1
const SSL3_RT_CHANGE_CIPHER_SPEC = 20
const SSL3_RT_ALERT = 21
const SSL3_RT_HANDSHAKE = 22
const SSL3_RT_APPLICATION_DATA = 23
const SSL3_MT_SERVER_HELLO = 2
const SSL3_MT_CERTIFICATE = 11
const SSL3_MT_CERTIFICATE_REQUEST = 13
const SSL3_MT_SERVER_DONE = 14
func readTLSChunk(conn net.Conn) ([]byte, int, error) {
conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT))
cbytes, err := readNBytes(conn, 5)
conn.SetReadDeadline(time.Time{})
handshakeByte, err := readNBytes(conn, 1)
if err != nil { if err != nil {
return err fmt.Println("TLS data chunk read failure: ", err)
return nil, 0, err
} }
if handshakeByte[0] != 0x16 { if int(cbytes[1]) < TLSGUARD_MIN_TLS_VER_MAJ {
return errors.New("Blocked client from attempting non-TLS connection") return nil, 0, errors.New("TLS protocol major version less than expected minimum")
} else if int(cbytes[2]) < TLSGUARD_MIN_TLS_VER_MIN {
return nil, 0, errors.New("TLS protocol minor version less than expected minimum")
} }
vers, err := readNBytes(conn, 2) cbyte := cbytes[0]
if err != nil { mlen := int(int(cbytes[3])<<8 | int(cbytes[4]))
return err // fmt.Printf("TLS data chunk header read: type = %#x, maj = %v, min = %v, len = %v\n", cbyte, cbytes[1], cbytes[2], mlen)
}
conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT))
cbytes2, err := readNBytes(conn, mlen)
conn.SetReadDeadline(time.Time{})
length, err := readNBytes(conn, 2)
if err != nil { if err != nil {
return err return nil, 0, err
} }
ffslen := int(int(length[0])<<8 | int(length[1])) cbytes = append(cbytes, cbytes2...)
return cbytes, int(cbyte), nil
}
func TLSGuard(conn, conn2 net.Conn, fqdn string) error {
// Should this be a requirement?
// if strings.HasSuffix(request.DestAddr.FQDN, "onion") {
//conn client
//conn2 server
ffs, err := readNBytes(conn, ffslen) // Read the opening message from the client
chunk, rtype, err := readTLSChunk(conn)
if err != nil { if err != nil {
return err return err
} }
// Transmit client hello if rtype != SSL3_RT_HANDSHAKE {
conn2.Write(handshakeByte) return errors.New("Blocked client from attempting non-TLS connection")
conn2.Write(vers) }
conn2.Write(length)
conn2.Write(ffs) // Pass it on through to the server
conn2.Write(chunk)
// Read ServerHello // Read ServerHello
bytesRead := 0
var s byte // 0x0e is done
var responseBuf []byte = []byte{}
valid := false valid := false
sendToClient := false loop := 1
for sendToClient == false { passthru := false
// Handshake byte
serverhandshakeByte, err := readNBytes(conn2, 1)
if err != nil {
return nil
}
responseBuf = append(responseBuf, serverhandshakeByte[0]) for 1 == 1 {
bytesRead += 1 loop++
if serverhandshakeByte[0] != 0x16 { // fmt.Printf("SSL LOOP %v; trying to read: conn2\n", loop)
return errors.New("Expected TLS server handshake byte was not received") chunk, rtype, err = readTLSChunk(conn2)
}
// Protocol version, 2 bytes
serverProtocolVer, err := readNBytes(conn2, 2)
if err != nil { if err != nil {
return err fmt.Printf("OTHER loop %v: trying to read: conn\n", loop)
} chunk, rtype, err2 := readTLSChunk(conn)
fmt.Printf("read: %v, %v, %v\n", err2, rtype, len(chunk))
bytesRead += 2 if err2 == nil {
responseBuf = append(responseBuf, serverProtocolVer...) conn2.Write(chunk)
continue
}
// Record length, 2 bytes
serverRecordLen, err := readNBytes(conn2, 2)
if err != nil {
return err return err
} }
bytesRead += 2 if rtype == SSL3_RT_CHANGE_CIPHER_SPEC || rtype == SSL3_RT_APPLICATION_DATA ||
responseBuf = append(responseBuf, serverRecordLen...) rtype == SSL3_RT_ALERT {
serverRecordLenInt := int(int(serverRecordLen[0])<<8 | int(serverRecordLen[1])) // fmt.Println("OTHER DATA; PASSING THRU")
passthru = true
// Record type byte } else if rtype == SSL3_RT_HANDSHAKE {
serverMsg, err := readNBytes(conn2, serverRecordLenInt) passthru = false
if err != nil { } else {
return err return errors.New(fmt.Sprintf("Expected TLS server handshake byte was not received [%#x vs 0x16]", rtype))
} }
bytesRead += len(serverMsg) if passthru {
responseBuf = append(responseBuf, serverMsg...) // fmt.Println("passthru writing buf again and continuing:")
s = serverMsg[0] conn.Write(chunk)
continue
// Message len, 3 bytes }
serverMessageLen := serverMsg[1:4]
serverMessageLenInt := int(int(serverMessageLen[0])<<16 | int(serverMessageLen[1])<<8 | int(serverMessageLen[2]))
// serverHelloBody, err := readNBytes(conn2, serverMessageLenInt)
serverHelloBody := serverMsg[4 : 4+serverMessageLenInt]
if s == 0x0b { serverMsg := chunk[5:]
s := serverMsg[0]
fmt.Printf("s = %#x\n", s)
if s == SSL3_MT_CERTIFICATE {
// Message len, 3 bytes
serverMessageLen := serverMsg[1:4]
serverMessageLenInt := int(int(serverMessageLen[0])<<16 | int(serverMessageLen[1])<<8 | int(serverMessageLen[2]))
// fmt.Printf("chunk len = %v, serverMsgLen = %v, slint = %v\n", len(chunk), len(serverMsg), serverMessageLenInt)
if len(serverMsg) < serverMessageLenInt {
return errors.New(fmt.Sprintf("len(serverMsg) %v < serverMessageLenInt %v!\n", len(serverMsg), serverMessageLenInt))
}
serverHelloBody := serverMsg[4 : 4+serverMessageLenInt]
certChainLen := int(int(serverHelloBody[0])<<16 | int(serverHelloBody[1])<<8 | int(serverHelloBody[2])) certChainLen := int(int(serverHelloBody[0])<<16 | int(serverHelloBody[1])<<8 | int(serverHelloBody[2]))
remaining := certChainLen remaining := certChainLen
pos := serverHelloBody[3:certChainLen] pos := serverHelloBody[3:certChainLen]
@ -117,7 +145,7 @@ func TLSGuard(conn, conn2 net.Conn, fqdn string) error {
for remaining > 0 { for remaining > 0 {
certLen := int(int(pos[0])<<16 | int(pos[1])<<8 | int(pos[2])) certLen := int(int(pos[0])<<16 | int(pos[1])<<8 | int(pos[2]))
// fmt.Printf("Certs chain len %d, cert 1 len %d:\n", certChainLen, certLen) // fmt.Printf("Certs chain len %d, cert 1 len %d:\n", certChainLen, certLen)
cert := pos[3 : 3+certLen] cert := pos[3 : 3+certLen]
certs, err := x509.ParseCertificates(cert) certs, err := x509.ParseCertificates(cert)
if remaining == certChainLen { if remaining == certChainLen {
@ -136,26 +164,24 @@ func TLSGuard(conn, conn2 net.Conn, fqdn string) error {
} }
verifyOptions.Intermediates = pool verifyOptions.Intermediates = pool
// fmt.Println("ATTEMPTING TO VERIFY: ", fqdn)
_, err = c.Verify(verifyOptions) _, err = c.Verify(verifyOptions)
// fmt.Println("ATTEMPTING TO VERIFY RESULT: ", err)
if err != nil { if err != nil {
return err return err
} else { } else {
valid = true valid = true
} }
// else if s == 0x0d { fmt.Printf("found a client cert request, sending buf to client\n") } // else if s == 0x0d { fmt.Printf("found a client cert request, sending buf to client\n") }
} else if s == 0x0e { } else if s == SSL3_MT_SERVER_DONE {
sendToClient = true conn.Write(chunk)
} else if s == 0x0d { break
sendToClient = true } else if s == SSL3_MT_CERTIFICATE_REQUEST {
break
} }
// fmt.Printf("Sending chunk of type %d to client.\n", s)
// fmt.Printf("Version bytes: %x %x\n", responseBuf[1], responseBuf[2]) conn.Write(chunk)
// fmt.Printf("Len bytes: %x %x\n", responseBuf[3], responseBuf[4])
// fmt.Printf("Message type: %x\n", responseBuf[5])
// fmt.Printf("Message len: %x %x %x\n", responseBuf[6], responseBuf[7], responseBuf[8])
// fmt.Printf("Message body: %v\n", responseBuf[9:])
conn.Write(responseBuf)
responseBuf = []byte{}
} }
if !valid { if !valid {

Loading…
Cancel
Save