Further improvements to tlsguard by uT / remove dumb debugging stuff I left in there

master
dma 7 years ago
parent 96061fb18d
commit 2012b070c7

@ -1,3 +1,4 @@
package sgfw package sgfw
import ( import (
@ -9,10 +10,9 @@ import (
"io" "io"
"net" "net"
"time" "time"
"math/rand"
) )
const TLSGUARD_READ_TIMEOUT = 8 * time.Second const TLSGUARD_READ_TIMEOUT = 8 // seconds
const TLSGUARD_MIN_TLS_VER_MAJ = 3 const TLSGUARD_MIN_TLS_VER_MAJ = 3
const TLSGUARD_MIN_TLS_VER_MIN = 1 const TLSGUARD_MIN_TLS_VER_MIN = 1
@ -125,7 +125,6 @@ type connReader struct {
data []byte data []byte
rtype int rtype int
err error err error
numb int
} }
var cipherSuiteMap map[uint16]string = map[uint16]string{ var cipherSuiteMap map[uint16]string = map[uint16]string{
@ -266,48 +265,46 @@ func stripTLSData(record []byte, start_ind, end_ind int, len_ind int, len_size i
return result return result
} }
func connectionReader(conn net.Conn, is_client bool, c chan connReader, done chan bool, num int) { func connectionReader(conn net.Conn, is_client bool, c chan connReader, done chan bool) {
var ret_error error = nil var ret_error error = nil
buffered := []byte{} buffered := []byte{}
mlen := 0 mlen := 0
rtype := 0 rtype := 0
stage := 1 stage := 1
proceed := true ntimeouts := 0
for { for {
if ret_error != nil { if ret_error != nil {
cr := connReader{client: is_client, data: nil, rtype: 0, err: ret_error, numb: num} cr := connReader{client: is_client, data: nil, rtype: 0, err: ret_error}
c <- cr c <- cr
break break
} }
//fmt.Printf("why i am here %v %d\n", is_client, num)
//if is_client == true && proceed == false {
if proceed == false {
if len(buffered) > 0 {
c <- connReader{client: is_client, data: buffered, rtype:0, err: nil, numb: num}
}
c <- connReader{client: is_client, data: nil, rtype: 0, err: nil}
return
}
select { select {
case <-done: case <-done:
// fmt.Printf("++ DONE %d: %v\n", num, is_client) //fmt.Println("++ DONE: ", is_client)
if len(buffered) > 0 { if len(buffered) > 0 {
// fmt.Printf("++ DONE BUT DISPOSING OF BUFFERED DATA num: %d\n", num) //fmt.Println("++ DONE BUT DISPOSING OF BUFFERED DATA")
c <- connReader{client: is_client, data: buffered, rtype: 0, err: nil, numb: num} c <- connReader{client: is_client, data: buffered, rtype: 0, err: nil}
} }
c <- connReader{client: is_client, data: nil, rtype: 0, err: nil, numb: num}
c <- connReader{client: is_client, data: nil, rtype: 0, err: nil}
return return
default: default:
if stage == 1 && proceed == true { if stage == 1 {
header := make([]byte, TLS_RECORD_HDR_LEN) header := make([]byte, TLS_RECORD_HDR_LEN)
conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT)) conn.SetReadDeadline(time.Now().Add(1 * time.Second))
// fmt.Printf("About to read here stage 1 %v %d\n", is_client, num)
_, err := io.ReadFull(conn, header) _, err := io.ReadFull(conn, header)
// fmt.Printf("Read here stage 1 %v num %d \n", is_client, num)
conn.SetReadDeadline(time.Time{}) conn.SetReadDeadline(time.Time{})
if err != nil { if err != nil {
ret_error = err if err, ok := err.(net.Error); ok && err.Timeout() {
ret_error = err
} else {
ntimeouts++
if ntimeouts == TLSGUARD_READ_TIMEOUT {
ret_error = err
}
}
continue continue
} }
@ -324,7 +321,7 @@ func connectionReader(conn net.Conn, is_client bool, c chan connReader, done cha
rtype = int(header[0]) rtype = int(header[0])
mlen = int(int(header[3])<<8 | int(header[4])) mlen = int(int(header[3])<<8 | int(header[4]))
//fmt.Printf("TLS data chunk header read: type = %#x, maj = %v, min = %v, len = %v\n", rtype, header[1], header[2], mlen) // fmt.Printf("TLS data chunk header read: type = %#x, maj = %v, min = %v, len = %v\n", rtype, header[1], header[2], mlen)
/* 16384+1024 if compression is not null */ /* 16384+1024 if compression is not null */
/* or 16384+2048 if ciphertext */ /* or 16384+2048 if ciphertext */
@ -335,30 +332,34 @@ func connectionReader(conn net.Conn, is_client bool, c chan connReader, done cha
buffered = header buffered = header
stage++ stage++
ntimeouts = 0
} else if stage == 2 { } else if stage == 2 {
remainder := make([]byte, mlen) remainder := make([]byte, mlen)
// fmt.Printf("About to read here stage 2 %v num %d\n", is_client, num) conn.SetReadDeadline(time.Now().Add(1 * time.Second))
conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT))
_, err := io.ReadFull(conn, remainder) _, err := io.ReadFull(conn, remainder)
conn.SetReadDeadline(time.Time{}) conn.SetReadDeadline(time.Time{})
if err != nil { if err != nil {
ret_error = err if err, ok := err.(net.Error); ok && err.Timeout() {
ret_error = err
} else {
ntimeouts++
if ntimeouts == TLSGUARD_READ_TIMEOUT {
ret_error = err
}
}
continue continue
} }
buffered = append(buffered, remainder...) buffered = append(buffered, remainder...)
// fmt.Printf("------- CHUNK READ: client: %v, err = %v, bytes = %v\n", is_client, err, len(buffered)) // fmt.Printf("------- CHUNK READ: client: %v, err = %v, bytes = %v\n", is_client, err, len(buffered))
cr := connReader{client: is_client, data: buffered, rtype: rtype, err: err, numb: num} cr := connReader{client: is_client, data: buffered, rtype: rtype, err: err}
c <- cr c <- cr
buffered = []byte{} buffered = []byte{}
rtype = 0 rtype = 0
mlen = 0 mlen = 0
stage = 1 stage = 1
//proceed = false ntimeouts = 0
if is_client {
proceed = false
}
} }
} }
@ -390,11 +391,8 @@ func TLSGuard(conn, conn2 net.Conn, fqdn string) error {
crChan := make(chan connReader) crChan := make(chan connReader)
dChan := make(chan bool, 10) dChan := make(chan bool, 10)
dChan2 := make(chan bool, 10) dChan2 := make(chan bool, 10)
rand.Seed(time.Now().UTC().UnixNano()) go connectionReader(conn, true, crChan, dChan)
connectThread1 := rand.Intn(1000) go connectionReader(conn2, false, crChan, dChan2)
connectThread2 := rand.Intn(1000)
go connectionReader(conn, true, crChan, dChan, connectThread1)
go connectionReader(conn2, false, crChan, dChan2, connectThread2)
client_expected := []uint{SSL3_MT_CLIENT_HELLO} client_expected := []uint{SSL3_MT_CLIENT_HELLO}
server_expected := []uint{SSL3_MT_SERVER_HELLO} server_expected := []uint{SSL3_MT_SERVER_HELLO}
@ -407,7 +405,7 @@ func TLSGuard(conn, conn2 net.Conn, fqdn string) error {
select_loop: select_loop:
for { for {
if ndone == 2 { if ndone == 2 {
// fmt.Println("DONE channel got both notifications. Terminating loop.") //fmt.Println("DONE channel got both notifications. Terminating loop.")
close(dChan) close(dChan)
close(dChan2) close(dChan2)
close(crChan) close(crChan)
@ -422,9 +420,9 @@ select_loop:
other = conn2 other = conn2
} }
//fmt.Printf("++++ SELECT: %v, %v, %v\n", cr.client, cr.err, len(cr.data)) // fmt.Printf("++++ SELECT: %v, %v, %v\n", cr.client, cr.err, len(cr.data))
if cr.err == nil && cr.data == nil { if cr.err == nil && cr.data == nil {
// fmt.Println("DONE channel notification received") //fmt.Println("DONE channel notification received")
ndone++ ndone++
continue continue
} }
@ -435,7 +433,7 @@ select_loop:
/* We expect only a single byte of data */ /* We expect only a single byte of data */
if cr.rtype == SSL3_RT_CHANGE_CIPHER_SPEC { if cr.rtype == SSL3_RT_CHANGE_CIPHER_SPEC {
// fmt.Println("CHANGE CIPHER_SPEC: ", cr.data[TLS_RECORD_HDR_LEN]) // fmt.Println("CHANGE CIPHER_SPEC: ", cr.data[TLS_RECORD_HDR_LEN])
if len(cr.data) != 6 { if len(cr.data) != 6 {
return errors.New(fmt.Sprintf("TLSGuard dropped connection with strange change cipher spec data length (%v bytes)", len(cr.data))) return errors.New(fmt.Sprintf("TLSGuard dropped connection with strange change cipher spec data length (%v bytes)", len(cr.data)))
} }
@ -447,9 +445,6 @@ select_loop:
client_change_cipher = true client_change_cipher = true
} else { } else {
server_change_cipher = true server_change_cipher = true
x509Valid = true
dChan <- true
dChan2 <- true
} }
} else if cr.rtype == SSL3_RT_ALERT { } else if cr.rtype == SSL3_RT_ALERT {
if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_WARNING { if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_WARNING {
@ -461,7 +456,7 @@ select_loop:
} }
alert_desc := int(int(cr.data[5])<<8 | int(cr.data[6])) alert_desc := int(int(cr.data[5])<<8 | int(cr.data[6]))
fmt.Println("ALERT DESCRIPTION: ", alert_desc) // fmt.Println("ALERT DESCRIPTION: ", alert_desc)
if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_FATAL { if cr.data[TLS_RECORD_HDR_LEN] == SSL3_AL_FATAL {
return errors.New(fmt.Sprintf("TLSGuard dropped connection after fatal error alert detected")) return errors.New(fmt.Sprintf("TLSGuard dropped connection after fatal error alert detected"))
@ -484,12 +479,12 @@ select_loop:
if (client_sess || server_sess) && (client_change_cipher || server_change_cipher) { if (client_sess || server_sess) && (client_change_cipher || server_change_cipher) {
if handshakeMessageLenInt > len(cr.data)+9 { if handshakeMessageLenInt > len(cr.data)+9 {
log.Notice("TLSGuard saw what looks like a resumed encrypted session... passing connection through") // log.Notice("TLSGuard saw what looks like a resumed encrypted session... passing connection through")
x509Valid = true
other.Write(cr.data) other.Write(cr.data)
dChan2 <- true
dChan <- true dChan <- true
dChan2 <- true
x509Valid = true
break select_loop break select_loop
} }
@ -504,27 +499,29 @@ select_loop:
if (cr.client && s == SSL3_MT_CLIENT_HELLO) || (!cr.client && s == SSL3_MT_SERVER_HELLO) { if (cr.client && s == SSL3_MT_CLIENT_HELLO) || (!cr.client && s == SSL3_MT_SERVER_HELLO) {
// rewrite := false // rewrite := false
// rewrite_buf := []byte{} // rewrite_buf := []byte{}
//SRC := "" /* SRC := ""
if s != SSL3_MT_CLIENT_HELLO {
//SRC = "CLIENT" if s == SSL3_MT_CLIENT_HELLO {
//} else { SRC = "CLIENT"
} else {
server_expected = []uint{SSL3_MT_CERTIFICATE, SSL3_MT_HELLO_REQUEST} server_expected = []uint{SSL3_MT_CERTIFICATE, SSL3_MT_HELLO_REQUEST}
// SRC = "SERVER" SRC = "SERVER"
} }
*/
hello_offset := 4 hello_offset := 4
// 2 byte protocol version // 2 byte protocol version
// fmt.Println(SRC, "HELLO VERSION = ", handshakeMsg[hello_offset:hello_offset+2]) // fmt.Println(SRC, "HELLO VERSION = ", handshakeMsg[hello_offset:hello_offset+2])
hello_offset += 2 hello_offset += 2
// 4 byte Random/GMT time // 4 byte Random/GMT time
//gmtbytes := binary.BigEndian.Uint32(handshakeMsg[hello_offset : hello_offset+4]) //gmtbytes := binary.BigEndian.Uint32(handshakeMsg[hello_offset : hello_offset+4])
//gmt := time.Unix(int64(gmtbytes), 0) //gmt := time.Unix(int64(gmtbytes), 0)
//fmt.Println(SRC, "HELLO GMT = ", gmt) // fmt.Println(SRC, "HELLO GMT = ", gmt)
hello_offset += 4 hello_offset += 4
// 28 bytes Random/random_bytes // 28 bytes Random/random_bytes
hello_offset += 28 hello_offset += 28
// 1 byte (32-bit session ID) // 1 byte (32-bit session ID)
sess_len := uint(handshakeMsg[hello_offset]) sess_len := uint(handshakeMsg[hello_offset])
// fmt.Println(SRC, "HELLO SESSION ID = ", sess_len) // fmt.Println(SRC, "HELLO SESSION ID = ", sess_len)
if cr.client && sess_len > 0 { if cr.client && sess_len > 0 {
client_sess = true client_sess = true
@ -532,71 +529,71 @@ select_loop:
server_sess = true server_sess = true
} }
/* /*
hello_offset += int(sess_len) + 1 hello_offset += int(sess_len) + 1
// 2 byte cipher suite array // 2 byte cipher suite array
cs := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) cs := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2])
noCS := cs noCS := cs
fmt.Printf("cs = %v / %#x\n", noCS, noCS) fmt.Printf("cs = %v / %#x\n", noCS, noCS)
saved_ciphersuite_size_off := hello_offset saved_ciphersuite_size_off := hello_offset
if !cr.client { if !cr.client {
fmt.Printf("SERVER selected ciphersuite: %#x (%s)\n", cs, getCipherSuiteName(uint(cs))) fmt.Printf("SERVER selected ciphersuite: %#x (%s)\n", cs, getCipherSuiteName(uint(cs)))
hello_offset += 2 hello_offset += 2
} else { } else {
for csind := 0; csind < int(noCS/2); csind++ { for csind := 0; csind < int(noCS/2); csind++ {
off := hello_offset + 2 + (csind * 2) off := hello_offset + 2 + (csind * 2)
cs = binary.BigEndian.Uint16(handshakeMsg[off : off+2]) cs = binary.BigEndian.Uint16(handshakeMsg[off : off+2])
cname := getCipherSuiteName(uint(cs)) cname := getCipherSuiteName(uint(cs))
fmt.Printf("%s HELLO CIPHERSUITE: %d/%d: %#x (%s)\n", SRC, csind+1, noCS/2, cs, cname) fmt.Printf("%s HELLO CIPHERSUITE: %d/%d: %#x (%s)\n", SRC, csind+1, noCS/2, cs, cname)
if isBadCipher(cname) { if isBadCipher(cname) {
fmt.Println("BAD CIPHER: ", cname) fmt.Println("BAD CIPHER: ", cname)
} }
} }
hello_offset += 2 + int(noCS) hello_offset += 2 + int(noCS)
} }
clen := uint(handshakeMsg[hello_offset]) clen := uint(handshakeMsg[hello_offset])
hello_offset++ hello_offset++
if !cr.client { if !cr.client {
fmt.Println("SERVER selected compression method: ", clen) fmt.Println("SERVER selected compression method: ", clen)
} else { } else {
fmt.Println(SRC, "HELLO COMPRESSION METHODS LEN = ", clen) fmt.Println(SRC, "HELLO COMPRESSION METHODS LEN = ", clen)
fmt.Println(SRC, "HELLO COMPRESSION METHODS: ", handshakeMsg[hello_offset:hello_offset+int(clen)]) fmt.Println(SRC, "HELLO COMPRESSION METHODS: ", handshakeMsg[hello_offset:hello_offset+int(clen)])
hello_offset += int(clen) hello_offset += int(clen)
} }
var extlen uint16 = 0 var extlen uint16 = 0
if hello_offset == len(handshakeMsg) { if hello_offset == len(handshakeMsg) {
fmt.Println("Message didn't have any extensions present") fmt.Println("Message didn't have any extensions present")
} else { } else {
extlen = binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) extlen = binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2])
fmt.Println(SRC, "HELLO EXTENSIONS LENGTH: ", extlen) fmt.Println(SRC, "HELLO EXTENSIONS LENGTH: ", extlen)
hello_offset += 2 hello_offset += 2
} }
if cr.client { if cr.client {
ext_ctr := 0 ext_ctr := 0
for ext_ctr < int(extlen)-2 { for ext_ctr < int(extlen)-2 {
exttype := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) exttype := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2])
hello_offset += 2 hello_offset += 2
ext_ctr += 2 ext_ctr += 2
// fmt.Printf("PROGRESS: %v of %v, %v of %v\n", ext_ctr, extlen, hello_offset, len(handshakeMsg)) // fmt.Printf("PROGRESS: %v of %v, %v of %v\n", ext_ctr, extlen, hello_offset, len(handshakeMsg))
fmt.Printf("EXTTYPE = %#x (%s)\n", exttype, gettlsExtensionName(uint(exttype))) fmt.Printf("EXTTYPE = %#x (%s)\n", exttype, gettlsExtensionName(uint(exttype)))
inner_len := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2]) inner_len := binary.BigEndian.Uint16(handshakeMsg[hello_offset : hello_offset+2])
hello_offset += int(inner_len) + 2 hello_offset += int(inner_len) + 2
ext_ctr += int(inner_len) + 2 ext_ctr += int(inner_len) + 2
} }
}*/ }*/
other.Write(cr.data) other.Write(cr.data)
continue continue
@ -610,16 +607,9 @@ select_loop:
if !cr.client && isExpected(SSL3_MT_SERVER_HELLO, server_expected) { if !cr.client && isExpected(SSL3_MT_SERVER_HELLO, server_expected) {
server_expected = []uint{SSL3_MT_CERTIFICATE} server_expected = []uint{SSL3_MT_CERTIFICATE}
} }
if !cr.client && s == SSL3_MT_HELLO_REQUEST { if !cr.client && s == SSL3_MT_HELLO_REQUEST {
// fmt.Println("Server sent hello request") //fmt.Println("Server sent hello request")
/* if server_change_cipher {
x509Valid = true
other.Write(cr.data)
dChan <- true
dChan2 <- true
break select_loop
}
*/
other.Write(cr.data) other.Write(cr.data)
continue continue
} }
@ -670,13 +660,14 @@ select_loop:
} }
verifyOptions.Intermediates = pool verifyOptions.Intermediates = pool
// fmt.Println("ATTEMPTING TO VERIFY: ", fqdn) //fmt.Println("ATTEMPTING TO VERIFY: ", fqdn)
_, err := c.Verify(verifyOptions) _, err := c.Verify(verifyOptions)
// fmt.Println("ATTEMPTING TO VERIFY RESULT: ", err) //fmt.Println("ATTEMPTING TO VERIFY RESULT: ", err)
if err != nil { if err != nil {
return err return err
} else { } else {
x509Valid = true x509Valid = true
// Added in.
other.Write(cr.data) other.Write(cr.data)
dChan <- true dChan <- true
dChan2 <- true dChan2 <- true
@ -687,10 +678,10 @@ select_loop:
other.Write(cr.data) other.Write(cr.data)
if x509Valid || (s == SSL3_MT_SERVER_DONE) || (s == SSL3_MT_CERTIFICATE_REQUEST) { if x509Valid || (s == SSL3_MT_SERVER_DONE) || (s == SSL3_MT_CERTIFICATE_REQUEST) {
// fmt.Println("BREAKING OUT OF LOOP 1") //fmt.Println("BREAKING OUT OF LOOP 1")
dChan <- true dChan <- true
dChan2 <- true dChan2 <- true
// fmt.Println("BREAKING OUT OF LOOP 2") //fmt.Println("BREAKING OUT OF LOOP 2")
break select_loop break select_loop
} }
@ -715,7 +706,7 @@ select_loop:
// fmt.Println("WAITING; ndone = ", ndone) // fmt.Println("WAITING; ndone = ", ndone)
select { select {
case cr := <-crChan: case cr := <-crChan:
// fmt.Printf("CHAN DATA: %v, %v, %v\n", cr.client, cr.err, len(cr.data)) // fmt.Printf("CHAN DATA: %v, %v, %v\n", cr.client, cr.err, len(cr.data))
if cr.err != nil || cr.data == nil { if cr.err != nil || cr.data == nil {
ndone++ ndone++
} else if cr.client { } else if cr.client {
@ -727,11 +718,11 @@ select_loop:
} }
} }
// fmt.Println("______ ndone = 2\n") //fmt.Println("______ ndone = 2\n")
// dChan <- true // dChan <- true
//close(dChan) close(dChan)
//close(dChan2) close(dChan2)
if !x509Valid { if !x509Valid {
return errors.New("Unknown error: TLS connection could not be validated") return errors.New("Unknown error: TLS connection could not be validated")
@ -740,3 +731,4 @@ select_loop:
return nil return nil
} }

Loading…
Cancel
Save