diff --git a/sgfw/tlsguard.go b/sgfw/tlsguard.go index 0fe2781..9af434a 100644 --- a/sgfw/tlsguard.go +++ b/sgfw/tlsguard.go @@ -3,104 +3,132 @@ package sgfw import ( "crypto/x509" "errors" + "fmt" "io" "net" + "time" ) -func TLSGuard(conn, conn2 net.Conn, fqdn string) error { - // Should this be a requirement? - // if strings.HasSuffix(request.DestAddr.FQDN, "onion") { +const TLSGUARD_READ_TIMEOUT = 2 * time.Second +const TLSGUARD_MIN_TLS_VER_MAJ = 3 +const TLSGUARD_MIN_TLS_VER_MIN = 1 + +const SSL3_RT_CHANGE_CIPHER_SPEC = 20 +const SSL3_RT_ALERT = 21 +const SSL3_RT_HANDSHAKE = 22 +const SSL3_RT_APPLICATION_DATA = 23 + +const SSL3_MT_SERVER_HELLO = 2 +const SSL3_MT_CERTIFICATE = 11 +const SSL3_MT_CERTIFICATE_REQUEST = 13 +const SSL3_MT_SERVER_DONE = 14 + +func readTLSChunk(conn net.Conn) ([]byte, int, error) { + conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT)) + cbytes, err := readNBytes(conn, 5) + conn.SetReadDeadline(time.Time{}) - handshakeByte, err := readNBytes(conn, 1) if err != nil { - return err + fmt.Println("TLS data chunk read failure: ", err) + return nil, 0, err } - if handshakeByte[0] != 0x16 { - return errors.New("Blocked client from attempting non-TLS connection") + if int(cbytes[1]) < TLSGUARD_MIN_TLS_VER_MAJ { + return nil, 0, errors.New("TLS protocol major version less than expected minimum") + } else if int(cbytes[2]) < TLSGUARD_MIN_TLS_VER_MIN { + return nil, 0, errors.New("TLS protocol minor version less than expected minimum") } - vers, err := readNBytes(conn, 2) - if err != nil { - return err - } + cbyte := cbytes[0] + mlen := int(int(cbytes[3])<<8 | int(cbytes[4])) +// fmt.Printf("TLS data chunk header read: type = %#x, maj = %v, min = %v, len = %v\n", cbyte, cbytes[1], cbytes[2], mlen) + + conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT)) + cbytes2, err := readNBytes(conn, mlen) + conn.SetReadDeadline(time.Time{}) - length, err := readNBytes(conn, 2) if err != nil { - return err + return nil, 0, err } - ffslen := int(int(length[0])<<8 | int(length[1])) + cbytes = append(cbytes, cbytes2...) + return cbytes, int(cbyte), nil +} + +func TLSGuard(conn, conn2 net.Conn, fqdn string) error { + // Should this be a requirement? + // if strings.HasSuffix(request.DestAddr.FQDN, "onion") { + + //conn client + //conn2 server - ffs, err := readNBytes(conn, ffslen) + // Read the opening message from the client + chunk, rtype, err := readTLSChunk(conn) if err != nil { return err } - // Transmit client hello - conn2.Write(handshakeByte) - conn2.Write(vers) - conn2.Write(length) - conn2.Write(ffs) + if rtype != SSL3_RT_HANDSHAKE { + return errors.New("Blocked client from attempting non-TLS connection") + } + + // Pass it on through to the server + conn2.Write(chunk) // Read ServerHello - bytesRead := 0 - var s byte // 0x0e is done - var responseBuf []byte = []byte{} valid := false - sendToClient := false + loop := 1 - for sendToClient == false { - // Handshake byte - serverhandshakeByte, err := readNBytes(conn2, 1) - if err != nil { - return nil - } + passthru := false - responseBuf = append(responseBuf, serverhandshakeByte[0]) - bytesRead += 1 + for 1 == 1 { + loop++ - if serverhandshakeByte[0] != 0x16 { - return errors.New("Expected TLS server handshake byte was not received") - } +// fmt.Printf("SSL LOOP %v; trying to read: conn2\n", loop) + chunk, rtype, err = readTLSChunk(conn2) - // Protocol version, 2 bytes - serverProtocolVer, err := readNBytes(conn2, 2) if err != nil { - return err - } + fmt.Printf("OTHER loop %v: trying to read: conn\n", loop) + chunk, rtype, err2 := readTLSChunk(conn) + fmt.Printf("read: %v, %v, %v\n", err2, rtype, len(chunk)) - bytesRead += 2 - responseBuf = append(responseBuf, serverProtocolVer...) + if err2 == nil { + conn2.Write(chunk) + continue + } - // Record length, 2 bytes - serverRecordLen, err := readNBytes(conn2, 2) - if err != nil { return err } - bytesRead += 2 - responseBuf = append(responseBuf, serverRecordLen...) - serverRecordLenInt := int(int(serverRecordLen[0])<<8 | int(serverRecordLen[1])) - - // Record type byte - serverMsg, err := readNBytes(conn2, serverRecordLenInt) - if err != nil { - return err + if rtype == SSL3_RT_CHANGE_CIPHER_SPEC || rtype == SSL3_RT_APPLICATION_DATA || + rtype == SSL3_RT_ALERT { +// fmt.Println("OTHER DATA; PASSING THRU") + passthru = true + } else if rtype == SSL3_RT_HANDSHAKE { + passthru = false + } else { + return errors.New(fmt.Sprintf("Expected TLS server handshake byte was not received [%#x vs 0x16]", rtype)) } - bytesRead += len(serverMsg) - responseBuf = append(responseBuf, serverMsg...) - s = serverMsg[0] - - // Message len, 3 bytes - serverMessageLen := serverMsg[1:4] - serverMessageLenInt := int(int(serverMessageLen[0])<<16 | int(serverMessageLen[1])<<8 | int(serverMessageLen[2])) - - // serverHelloBody, err := readNBytes(conn2, serverMessageLenInt) - serverHelloBody := serverMsg[4 : 4+serverMessageLenInt] + if passthru { +// fmt.Println("passthru writing buf again and continuing:") + conn.Write(chunk) + continue + } - if s == 0x0b { + serverMsg := chunk[5:] + s := serverMsg[0] + fmt.Printf("s = %#x\n", s) + + if s == SSL3_MT_CERTIFICATE { + // Message len, 3 bytes + serverMessageLen := serverMsg[1:4] + serverMessageLenInt := int(int(serverMessageLen[0])<<16 | int(serverMessageLen[1])<<8 | int(serverMessageLen[2])) +// fmt.Printf("chunk len = %v, serverMsgLen = %v, slint = %v\n", len(chunk), len(serverMsg), serverMessageLenInt) + if len(serverMsg) < serverMessageLenInt { + return errors.New(fmt.Sprintf("len(serverMsg) %v < serverMessageLenInt %v!\n", len(serverMsg), serverMessageLenInt)) + } + serverHelloBody := serverMsg[4 : 4+serverMessageLenInt] certChainLen := int(int(serverHelloBody[0])<<16 | int(serverHelloBody[1])<<8 | int(serverHelloBody[2])) remaining := certChainLen pos := serverHelloBody[3:certChainLen] @@ -117,7 +145,7 @@ func TLSGuard(conn, conn2 net.Conn, fqdn string) error { for remaining > 0 { certLen := int(int(pos[0])<<16 | int(pos[1])<<8 | int(pos[2])) - // fmt.Printf("Certs chain len %d, cert 1 len %d:\n", certChainLen, certLen) + // fmt.Printf("Certs chain len %d, cert 1 len %d:\n", certChainLen, certLen) cert := pos[3 : 3+certLen] certs, err := x509.ParseCertificates(cert) if remaining == certChainLen { @@ -136,26 +164,24 @@ func TLSGuard(conn, conn2 net.Conn, fqdn string) error { } verifyOptions.Intermediates = pool +// fmt.Println("ATTEMPTING TO VERIFY: ", fqdn) _, err = c.Verify(verifyOptions) +// fmt.Println("ATTEMPTING TO VERIFY RESULT: ", err) if err != nil { return err } else { valid = true } // else if s == 0x0d { fmt.Printf("found a client cert request, sending buf to client\n") } - } else if s == 0x0e { - sendToClient = true - } else if s == 0x0d { - sendToClient = true + } else if s == SSL3_MT_SERVER_DONE { + conn.Write(chunk) + break + } else if s == SSL3_MT_CERTIFICATE_REQUEST { + break } +// fmt.Printf("Sending chunk of type %d to client.\n", s) - // fmt.Printf("Version bytes: %x %x\n", responseBuf[1], responseBuf[2]) - // fmt.Printf("Len bytes: %x %x\n", responseBuf[3], responseBuf[4]) - // fmt.Printf("Message type: %x\n", responseBuf[5]) - // fmt.Printf("Message len: %x %x %x\n", responseBuf[6], responseBuf[7], responseBuf[8]) - // fmt.Printf("Message body: %v\n", responseBuf[9:]) - conn.Write(responseBuf) - responseBuf = []byte{} + conn.Write(chunk) } if !valid {