From f616f54b2c82fc33f2b20bf99815046cce70295e Mon Sep 17 00:00:00 2001 From: Stephen Watt Date: Tue, 10 Oct 2017 23:09:59 -0400 Subject: [PATCH] Fix TLSGuard handshake timeout issue by breaking total timeout period into one second polling intervals. --- sgfw/tlsguard.go | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/sgfw/tlsguard.go b/sgfw/tlsguard.go index 9370a06..6e2f93c 100644 --- a/sgfw/tlsguard.go +++ b/sgfw/tlsguard.go @@ -11,7 +11,7 @@ import ( "time" ) -const TLSGUARD_READ_TIMEOUT = 8 * time.Second +const TLSGUARD_READ_TIMEOUT = 8 // seconds const TLSGUARD_MIN_TLS_VER_MAJ = 3 const TLSGUARD_MIN_TLS_VER_MIN = 1 @@ -270,6 +270,7 @@ func connectionReader(conn net.Conn, is_client bool, c chan connReader, done cha mlen := 0 rtype := 0 stage := 1 + ntimeouts := 0 for { if ret_error != nil { @@ -291,11 +292,18 @@ func connectionReader(conn net.Conn, is_client bool, c chan connReader, done cha default: if stage == 1 { header := make([]byte, TLS_RECORD_HDR_LEN) - conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT)) + conn.SetReadDeadline(time.Now().Add(1 * time.Second)) _, err := io.ReadFull(conn, header) conn.SetReadDeadline(time.Time{}) if err != nil { - ret_error = err + if err, ok := err.(net.Error); ok && err.Timeout() { + ret_error = err + } else { + ntimeouts++ + if ntimeouts == TLSGUARD_READ_TIMEOUT { + ret_error = err + } + } continue } @@ -323,13 +331,21 @@ func connectionReader(conn net.Conn, is_client bool, c chan connReader, done cha buffered = header stage++ + ntimeouts = 0 } else if stage == 2 { remainder := make([]byte, mlen) - conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT)) + conn.SetReadDeadline(time.Now().Add(1 * time.Second)) _, err := io.ReadFull(conn, remainder) conn.SetReadDeadline(time.Time{}) if err != nil { - ret_error = err + if err, ok := err.(net.Error); ok && err.Timeout() { + ret_error = err + } else { + ntimeouts++ + if ntimeouts == TLSGUARD_READ_TIMEOUT { + ret_error = err + } + } continue } @@ -342,6 +358,7 @@ func connectionReader(conn net.Conn, is_client bool, c chan connReader, done cha rtype = 0 mlen = 0 stage = 1 + ntimeouts = 0 } } @@ -592,6 +609,8 @@ select_loop: if !cr.client && s == SSL3_MT_HELLO_REQUEST { fmt.Println("Server sent hello request") + other.Write(cr.data) + continue } if s > SSL3_MT_CERTIFICATE_STATUS { @@ -647,6 +666,11 @@ select_loop: return err } else { x509Valid = true + // Added in. + other.Write(cr.data) + dChan <- true + dChan2 <- true + break select_loop } }