diff --git a/sgfw/tlsguard.go b/sgfw/tlsguard.go index 3aa464d..9ba290a 100644 --- a/sgfw/tlsguard.go +++ b/sgfw/tlsguard.go @@ -1,4 +1,3 @@ - package sgfw import ( @@ -282,7 +281,7 @@ func connectionReader(conn net.Conn, is_client bool, c chan connReader, done cha select { case <-done: - //fmt.Println("++ DONE: ", is_client) + // fmt.Println("++ DONE: ", is_client) if len(buffered) > 0 { //fmt.Println("++ DONE BUT DISPOSING OF BUFFERED DATA") c <- connReader{client: is_client, data: buffered, rtype: 0, err: nil} @@ -321,7 +320,7 @@ func connectionReader(conn net.Conn, is_client bool, c chan connReader, done cha rtype = int(header[0]) mlen = int(int(header[3])<<8 | int(header[4])) - // fmt.Printf("TLS data chunk header read: type = %#x, maj = %v, min = %v, len = %v\n", rtype, header[1], header[2], mlen) + //fmt.Printf("TLS data chunk header read: type = %#x, maj = %v, min = %v, len = %v\n", rtype, header[1], header[2], mlen) /* 16384+1024 if compression is not null */ /* or 16384+2048 if ciphertext */ @@ -335,12 +334,15 @@ func connectionReader(conn net.Conn, is_client bool, c chan connReader, done cha ntimeouts = 0 } else if stage == 2 { remainder := make([]byte, mlen) - conn.SetReadDeadline(time.Now().Add(1 * time.Second)) - _, err := io.ReadFull(conn, remainder) + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + numRead, err := io.ReadFull(conn, remainder) conn.SetReadDeadline(time.Time{}) if err != nil { if err, ok := err.(net.Error); ok && err.Timeout() { ret_error = err + if numRead > 0 { + buffered = append(buffered, remainder[:numRead]...) + } } else { ntimeouts++ if ntimeouts == TLSGUARD_READ_TIMEOUT { @@ -348,9 +350,9 @@ func connectionReader(conn net.Conn, is_client bool, c chan connReader, done cha } } continue + } else { + buffered = append(buffered, remainder...) } - - buffered = append(buffered, remainder...) //fmt.Printf("------- CHUNK READ: client: %v, err = %v, bytes = %v\n", is_client, err, len(buffered)) cr := connReader{client: is_client, data: buffered, rtype: rtype, err: err} c <- cr @@ -387,7 +389,7 @@ func TLSGuard(conn, conn2 net.Conn, fqdn string) error { //conn client //conn2 server - //fmt.Println("-------- STARTING HANDSHAKE LOOP") + // fmt.Println("-------- STARTING HANDSHAKE LOOP") crChan := make(chan connReader) dChan := make(chan bool, 10) dChan2 := make(chan bool, 10) @@ -420,7 +422,7 @@ select_loop: other = conn2 } - // fmt.Printf("++++ SELECT: %v, %v, %v\n", cr.client, cr.err, len(cr.data)) + //fmt.Printf("++++ SELECT: %v, %v, %v\n", cr.client, cr.err, len(cr.data)) if cr.err == nil && cr.data == nil { //fmt.Println("DONE channel notification received") ndone++ @@ -433,7 +435,7 @@ select_loop: /* We expect only a single byte of data */ if cr.rtype == SSL3_RT_CHANGE_CIPHER_SPEC { - //fmt.Println("CHANGE CIPHER_SPEC: ", cr.data[TLS_RECORD_HDR_LEN]) + // fmt.Println("CHANGE CIPHER_SPEC: ", cr.data[TLS_RECORD_HDR_LEN]) if len(cr.data) != 6 { return errors.New(fmt.Sprintf("TLSGuard dropped connection with strange change cipher spec data length (%v bytes)", len(cr.data))) } @@ -456,7 +458,7 @@ select_loop: } alert_desc := int(int(cr.data[5])<<8 | int(cr.data[6])) - //fmt.Println("ALERT DESCRIPTION: ", alert_desc) + fmt.Println("ALERT DESCRIPTION: ", alert_desc) if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_FATAL { return errors.New(fmt.Sprintf("TLSGuard dropped connection after fatal error alert detected")) @@ -479,14 +481,14 @@ select_loop: if (client_sess || server_sess) && (client_change_cipher || server_change_cipher) { - if handshakeMessageLenInt > len(cr.data)+9 { - // log.Notice("TLSGuard saw what looks like a resumed encrypted session... passing connection through") + //if handshakeMessageLenInt > len(cr.data)+9 { + log.Notice("TLSGuard saw what looks like a resumed encrypted session... passing connection through") other.Write(cr.data) dChan <- true dChan2 <- true x509Valid = true break select_loop - } + //} } @@ -503,6 +505,9 @@ select_loop: if s != SSL3_MT_CLIENT_HELLO { server_expected = []uint{SSL3_MT_CERTIFICATE, SSL3_MT_HELLO_REQUEST} + //SRC = "CLIENT" + } else { + //SRC = "SERVER" } hello_offset := 4 @@ -657,7 +662,7 @@ select_loop: } verifyOptions.Intermediates = pool - // fmt.Println("ATTEMPTING TO VERIFY: ", fqdn) + //fmt.Println("ATTEMPTING TO VERIFY: ", fqdn) _, err := c.Verify(verifyOptions) //fmt.Println("ATTEMPTING TO VERIFY RESULT: ", err) if err != nil { @@ -685,13 +690,11 @@ select_loop: // fmt.Printf("Sending chunk of type %d to client.\n", s) } else if cr.err != nil { ndone++ - if cr.client { fmt.Println("Client read error: ", cr.err) } else { fmt.Println("Server read error: ", cr.err) } - return cr.err } @@ -700,10 +703,10 @@ select_loop: //fmt.Println("WAITING; ndone = ", ndone) for ndone < 2 { - //fmt.Println("WAITING; ndone = ", ndone) + // fmt.Println("WAITING; ndone = ", ndone) select { case cr := <-crChan: - //fmt.Printf("CHAN DATA: %v, %v, %v\n", cr.client, cr.err, len(cr.data)) + // fmt.Printf("CHAN DATA: %v, %v, %v\n", cr.client, cr.err, len(cr.data)) if cr.err != nil || cr.data == nil { ndone++ } else if cr.client {