diff --git a/sgfw/tlsguard.go b/sgfw/tlsguard.go index dd13353..185d8f1 100644 --- a/sgfw/tlsguard.go +++ b/sgfw/tlsguard.go @@ -1,3 +1,4 @@ + package sgfw import ( @@ -9,10 +10,9 @@ import ( "io" "net" "time" - "math/rand" ) -const TLSGUARD_READ_TIMEOUT = 8 * time.Second +const TLSGUARD_READ_TIMEOUT = 8 // seconds const TLSGUARD_MIN_TLS_VER_MAJ = 3 const TLSGUARD_MIN_TLS_VER_MIN = 1 @@ -125,7 +125,6 @@ type connReader struct { data []byte rtype int err error - numb int } var cipherSuiteMap map[uint16]string = map[uint16]string{ @@ -266,48 +265,46 @@ func stripTLSData(record []byte, start_ind, end_ind int, len_ind int, len_size i return result } -func connectionReader(conn net.Conn, is_client bool, c chan connReader, done chan bool, num int) { +func connectionReader(conn net.Conn, is_client bool, c chan connReader, done chan bool) { var ret_error error = nil buffered := []byte{} mlen := 0 rtype := 0 stage := 1 - proceed := true + ntimeouts := 0 for { if ret_error != nil { - cr := connReader{client: is_client, data: nil, rtype: 0, err: ret_error, numb: num} + cr := connReader{client: is_client, data: nil, rtype: 0, err: ret_error} c <- cr break } - //fmt.Printf("why i am here %v %d\n", is_client, num) - //if is_client == true && proceed == false { - if proceed == false { - if len(buffered) > 0 { - c <- connReader{client: is_client, data: buffered, rtype:0, err: nil, numb: num} - } - c <- connReader{client: is_client, data: nil, rtype: 0, err: nil} - return - } + select { case <-done: - // fmt.Printf("++ DONE %d: %v\n", num, is_client) + //fmt.Println("++ DONE: ", is_client) if len(buffered) > 0 { - // fmt.Printf("++ DONE BUT DISPOSING OF BUFFERED DATA num: %d\n", num) - c <- connReader{client: is_client, data: buffered, rtype: 0, err: nil, numb: num} + //fmt.Println("++ DONE BUT DISPOSING OF BUFFERED DATA") + c <- connReader{client: is_client, data: buffered, rtype: 0, err: nil} } - c <- connReader{client: is_client, data: nil, rtype: 0, err: nil, numb: num} + + c <- connReader{client: is_client, data: nil, rtype: 0, err: nil} return default: - if stage == 1 && proceed == true { + if stage == 1 { header := make([]byte, TLS_RECORD_HDR_LEN) - conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT)) - // fmt.Printf("About to read here stage 1 %v %d\n", is_client, num) + conn.SetReadDeadline(time.Now().Add(1 * time.Second)) _, err := io.ReadFull(conn, header) - // fmt.Printf("Read here stage 1 %v num %d \n", is_client, num) conn.SetReadDeadline(time.Time{}) if err != nil { - ret_error = err + if err, ok := err.(net.Error); ok && err.Timeout() { + ret_error = err + } else { + ntimeouts++ + if ntimeouts == TLSGUARD_READ_TIMEOUT { + ret_error = err + } + } continue } @@ -324,7 +321,7 @@ func connectionReader(conn net.Conn, is_client bool, c chan connReader, done cha rtype = int(header[0]) mlen = int(int(header[3])<<8 | int(header[4])) - //fmt.Printf("TLS data chunk header read: type = %#x, maj = %v, min = %v, len = %v\n", rtype, header[1], header[2], mlen) + // fmt.Printf("TLS data chunk header read: type = %#x, maj = %v, min = %v, len = %v\n", rtype, header[1], header[2], mlen) /* 16384+1024 if compression is not null */ /* or 16384+2048 if ciphertext */ @@ -335,30 +332,34 @@ func connectionReader(conn net.Conn, is_client bool, c chan connReader, done cha buffered = header stage++ + ntimeouts = 0 } else if stage == 2 { remainder := make([]byte, mlen) - // fmt.Printf("About to read here stage 2 %v num %d\n", is_client, num) - conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT)) + conn.SetReadDeadline(time.Now().Add(1 * time.Second)) _, err := io.ReadFull(conn, remainder) conn.SetReadDeadline(time.Time{}) if err != nil { - ret_error = err + if err, ok := err.(net.Error); ok && err.Timeout() { + ret_error = err + } else { + ntimeouts++ + if ntimeouts == TLSGUARD_READ_TIMEOUT { + ret_error = err + } + } continue } buffered = append(buffered, remainder...) // fmt.Printf("------- CHUNK READ: client: %v, err = %v, bytes = %v\n", is_client, err, len(buffered)) - cr := connReader{client: is_client, data: buffered, rtype: rtype, err: err, numb: num} + cr := connReader{client: is_client, data: buffered, rtype: rtype, err: err} c <- cr buffered = []byte{} rtype = 0 mlen = 0 stage = 1 - //proceed = false - if is_client { - proceed = false - } + ntimeouts = 0 } } @@ -390,11 +391,8 @@ func TLSGuard(conn, conn2 net.Conn, fqdn string) error { crChan := make(chan connReader) dChan := make(chan bool, 10) dChan2 := make(chan bool, 10) - rand.Seed(time.Now().UTC().UnixNano()) - connectThread1 := rand.Intn(1000) - connectThread2 := rand.Intn(1000) - go connectionReader(conn, true, crChan, dChan, connectThread1) - go connectionReader(conn2, false, crChan, dChan2, connectThread2) + go connectionReader(conn, true, crChan, dChan) + go connectionReader(conn2, false, crChan, dChan2) client_expected := []uint{SSL3_MT_CLIENT_HELLO} server_expected := []uint{SSL3_MT_SERVER_HELLO} @@ -407,7 +405,7 @@ func TLSGuard(conn, conn2 net.Conn, fqdn string) error { select_loop: for { if ndone == 2 { - // fmt.Println("DONE channel got both notifications. Terminating loop.") + //fmt.Println("DONE channel got both notifications. Terminating loop.") close(dChan) close(dChan2) close(crChan) @@ -422,9 +420,9 @@ select_loop: other = conn2 } - //fmt.Printf("++++ SELECT: %v, %v, %v\n", cr.client, cr.err, len(cr.data)) + // fmt.Printf("++++ SELECT: %v, %v, %v\n", cr.client, cr.err, len(cr.data)) if cr.err == nil && cr.data == nil { - // fmt.Println("DONE channel notification received") + //fmt.Println("DONE channel notification received") ndone++ continue } @@ -435,7 +433,7 @@ select_loop: /* We expect only a single byte of data */ if cr.rtype == SSL3_RT_CHANGE_CIPHER_SPEC { - // fmt.Println("CHANGE CIPHER_SPEC: ", cr.data[TLS_RECORD_HDR_LEN]) + // fmt.Println("CHANGE CIPHER_SPEC: ", cr.data[TLS_RECORD_HDR_LEN]) if len(cr.data) != 6 { return errors.New(fmt.Sprintf("TLSGuard dropped connection with strange change cipher spec data length (%v bytes)", len(cr.data))) } @@ -447,9 +445,6 @@ select_loop: client_change_cipher = true } else { server_change_cipher = true - x509Valid = true - dChan <- true - dChan2 <- true } } else if cr.rtype == SSL3_RT_ALERT { if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_WARNING { @@ -461,7 +456,7 @@ select_loop: } alert_desc := int(int(cr.data[5])<<8 | int(cr.data[6])) - fmt.Println("ALERT DESCRIPTION: ", alert_desc) + // fmt.Println("ALERT DESCRIPTION: ", alert_desc) if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_FATAL { return errors.New(fmt.Sprintf("TLSGuard dropped connection after fatal error alert detected")) @@ -484,12 +479,12 @@ select_loop: if (client_sess || server_sess) && (client_change_cipher || server_change_cipher) { - if handshakeMessageLenInt > len(cr.data)+9 { - log.Notice("TLSGuard saw what looks like a resumed encrypted session... passing connection through") - x509Valid = true + if handshakeMessageLenInt > len(cr.data)+9 { + // log.Notice("TLSGuard saw what looks like a resumed encrypted session... passing connection through") other.Write(cr.data) - dChan2 <- true dChan <- true + dChan2 <- true + x509Valid = true break select_loop } @@ -504,27 +499,29 @@ select_loop: if (cr.client && s == SSL3_MT_CLIENT_HELLO) || (!cr.client && s == SSL3_MT_SERVER_HELLO) { // rewrite := false // rewrite_buf := []byte{} - //SRC := "" - if s != SSL3_MT_CLIENT_HELLO { - //SRC = "CLIENT" - //} else { + /* SRC := "" + + if s == SSL3_MT_CLIENT_HELLO { + SRC = "CLIENT" + } else { server_expected = []uint{SSL3_MT_CERTIFICATE, SSL3_MT_HELLO_REQUEST} - // SRC = "SERVER" + SRC = "SERVER" } +*/ hello_offset := 4 // 2 byte protocol version - // fmt.Println(SRC, "HELLO VERSION = ", handshakeMsg[hello_offset:hello_offset+2]) + // fmt.Println(SRC, "HELLO VERSION = ", handshakeMsg[hello_offset:hello_offset+2]) hello_offset += 2 // 4 byte Random/GMT time //gmtbytes := binary.BigEndian.Uint32(handshakeMsg[hello_offset : hello_offset+4]) //gmt := time.Unix(int64(gmtbytes), 0) - //fmt.Println(SRC, "HELLO GMT = ", gmt) + // fmt.Println(SRC, "HELLO GMT = ", gmt) hello_offset += 4 // 28 bytes Random/random_bytes hello_offset += 28 // 1 byte (32-bit session ID) sess_len := uint(handshakeMsg[hello_offset]) - // fmt.Println(SRC, "HELLO SESSION ID = ", sess_len) + // fmt.Println(SRC, "HELLO SESSION ID = ", sess_len) if cr.client && sess_len > 0 { client_sess = true @@ -532,71 +529,71 @@ select_loop: server_sess = true } - /* - hello_offset += int(sess_len) + 1 - // 2 byte cipher suite array - cs := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) - noCS := cs - fmt.Printf("cs = %v / %#x\n", noCS, noCS) - - saved_ciphersuite_size_off := hello_offset - - if !cr.client { - fmt.Printf("SERVER selected ciphersuite: %#x (%s)\n", cs, getCipherSuiteName(uint(cs))) - hello_offset += 2 - } else { - - for csind := 0; csind < int(noCS/2); csind++ { - off := hello_offset + 2 + (csind * 2) - cs = binary.BigEndian.Uint16(handshakeMsg[off : off+2]) - cname := getCipherSuiteName(uint(cs)) - fmt.Printf("%s HELLO CIPHERSUITE: %d/%d: %#x (%s)\n", SRC, csind+1, noCS/2, cs, cname) - - if isBadCipher(cname) { - fmt.Println("BAD CIPHER: ", cname) - } - - } - - hello_offset += 2 + int(noCS) - } - - clen := uint(handshakeMsg[hello_offset]) - hello_offset++ - - if !cr.client { - fmt.Println("SERVER selected compression method: ", clen) - } else { - fmt.Println(SRC, "HELLO COMPRESSION METHODS LEN = ", clen) - fmt.Println(SRC, "HELLO COMPRESSION METHODS: ", handshakeMsg[hello_offset:hello_offset+int(clen)]) - hello_offset += int(clen) - } - - var extlen uint16 = 0 - - if hello_offset == len(handshakeMsg) { - fmt.Println("Message didn't have any extensions present") - } else { - extlen = binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) - fmt.Println(SRC, "HELLO EXTENSIONS LENGTH: ", extlen) - hello_offset += 2 - } - - if cr.client { - ext_ctr := 0 - - for ext_ctr < int(extlen)-2 { - exttype := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) - hello_offset += 2 - ext_ctr += 2 - // fmt.Printf("PROGRESS: %v of %v, %v of %v\n", ext_ctr, extlen, hello_offset, len(handshakeMsg)) - fmt.Printf("EXTTYPE = %#x (%s)\n", exttype, gettlsExtensionName(uint(exttype))) - inner_len := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) - hello_offset += int(inner_len) + 2 - ext_ctr += int(inner_len) + 2 - } - - }*/ + /* + hello_offset += int(sess_len) + 1 + // 2 byte cipher suite array + cs := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) + noCS := cs + fmt.Printf("cs = %v / %#x\n", noCS, noCS) + + saved_ciphersuite_size_off := hello_offset + + if !cr.client { + fmt.Printf("SERVER selected ciphersuite: %#x (%s)\n", cs, getCipherSuiteName(uint(cs))) + hello_offset += 2 + } else { + + for csind := 0; csind < int(noCS/2); csind++ { + off := hello_offset + 2 + (csind * 2) + cs = binary.BigEndian.Uint16(handshakeMsg[off : off+2]) + cname := getCipherSuiteName(uint(cs)) + fmt.Printf("%s HELLO CIPHERSUITE: %d/%d: %#x (%s)\n", SRC, csind+1, noCS/2, cs, cname) + + if isBadCipher(cname) { + fmt.Println("BAD CIPHER: ", cname) + } + + } + + hello_offset += 2 + int(noCS) + } + + clen := uint(handshakeMsg[hello_offset]) + hello_offset++ + + if !cr.client { + fmt.Println("SERVER selected compression method: ", clen) + } else { + fmt.Println(SRC, "HELLO COMPRESSION METHODS LEN = ", clen) + fmt.Println(SRC, "HELLO COMPRESSION METHODS: ", handshakeMsg[hello_offset:hello_offset+int(clen)]) + hello_offset += int(clen) + } + + var extlen uint16 = 0 + + if hello_offset == len(handshakeMsg) { + fmt.Println("Message didn't have any extensions present") + } else { + extlen = binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) + fmt.Println(SRC, "HELLO EXTENSIONS LENGTH: ", extlen) + hello_offset += 2 + } + + if cr.client { + ext_ctr := 0 + + for ext_ctr < int(extlen)-2 { + exttype := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) + hello_offset += 2 + ext_ctr += 2 + // fmt.Printf("PROGRESS: %v of %v, %v of %v\n", ext_ctr, extlen, hello_offset, len(handshakeMsg)) + fmt.Printf("EXTTYPE = %#x (%s)\n", exttype, gettlsExtensionName(uint(exttype))) + inner_len := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) + hello_offset += int(inner_len) + 2 + ext_ctr += int(inner_len) + 2 + } + + }*/ other.Write(cr.data) continue @@ -610,16 +607,9 @@ select_loop: if !cr.client && isExpected(SSL3_MT_SERVER_HELLO, server_expected) { server_expected = []uint{SSL3_MT_CERTIFICATE} } + if !cr.client && s == SSL3_MT_HELLO_REQUEST { - // fmt.Println("Server sent hello request") -/* if server_change_cipher { - x509Valid = true - other.Write(cr.data) - dChan <- true - dChan2 <- true - break select_loop - } -*/ + //fmt.Println("Server sent hello request") other.Write(cr.data) continue } @@ -670,13 +660,14 @@ select_loop: } verifyOptions.Intermediates = pool - // fmt.Println("ATTEMPTING TO VERIFY: ", fqdn) + //fmt.Println("ATTEMPTING TO VERIFY: ", fqdn) _, err := c.Verify(verifyOptions) - // fmt.Println("ATTEMPTING TO VERIFY RESULT: ", err) + //fmt.Println("ATTEMPTING TO VERIFY RESULT: ", err) if err != nil { return err } else { x509Valid = true + // Added in. other.Write(cr.data) dChan <- true dChan2 <- true @@ -687,10 +678,10 @@ select_loop: other.Write(cr.data) if x509Valid || (s == SSL3_MT_SERVER_DONE) || (s == SSL3_MT_CERTIFICATE_REQUEST) { - // fmt.Println("BREAKING OUT OF LOOP 1") + //fmt.Println("BREAKING OUT OF LOOP 1") dChan <- true dChan2 <- true - // fmt.Println("BREAKING OUT OF LOOP 2") + //fmt.Println("BREAKING OUT OF LOOP 2") break select_loop } @@ -715,7 +706,7 @@ select_loop: // fmt.Println("WAITING; ndone = ", ndone) select { case cr := <-crChan: - // fmt.Printf("CHAN DATA: %v, %v, %v\n", cr.client, cr.err, len(cr.data)) + // fmt.Printf("CHAN DATA: %v, %v, %v\n", cr.client, cr.err, len(cr.data)) if cr.err != nil || cr.data == nil { ndone++ } else if cr.client { @@ -727,11 +718,11 @@ select_loop: } } - // fmt.Println("______ ndone = 2\n") + //fmt.Println("______ ndone = 2\n") // dChan <- true - //close(dChan) - //close(dChan2) + close(dChan) + close(dChan2) if !x509Valid { return errors.New("Unknown error: TLS connection could not be validated") @@ -740,3 +731,4 @@ select_loop: return nil } +