|
|
@ -3,104 +3,132 @@ package sgfw
|
|
|
|
import (
|
|
|
|
import (
|
|
|
|
"crypto/x509"
|
|
|
|
"crypto/x509"
|
|
|
|
"errors"
|
|
|
|
"errors"
|
|
|
|
|
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"io"
|
|
|
|
"net"
|
|
|
|
"net"
|
|
|
|
|
|
|
|
"time"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
func TLSGuard(conn, conn2 net.Conn, fqdn string) error {
|
|
|
|
const TLSGUARD_READ_TIMEOUT = 2 * time.Second
|
|
|
|
// Should this be a requirement?
|
|
|
|
const TLSGUARD_MIN_TLS_VER_MAJ = 3
|
|
|
|
// if strings.HasSuffix(request.DestAddr.FQDN, "onion") {
|
|
|
|
const TLSGUARD_MIN_TLS_VER_MIN = 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const SSL3_RT_CHANGE_CIPHER_SPEC = 20
|
|
|
|
|
|
|
|
const SSL3_RT_ALERT = 21
|
|
|
|
|
|
|
|
const SSL3_RT_HANDSHAKE = 22
|
|
|
|
|
|
|
|
const SSL3_RT_APPLICATION_DATA = 23
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const SSL3_MT_SERVER_HELLO = 2
|
|
|
|
|
|
|
|
const SSL3_MT_CERTIFICATE = 11
|
|
|
|
|
|
|
|
const SSL3_MT_CERTIFICATE_REQUEST = 13
|
|
|
|
|
|
|
|
const SSL3_MT_SERVER_DONE = 14
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func readTLSChunk(conn net.Conn) ([]byte, int, error) {
|
|
|
|
|
|
|
|
conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT))
|
|
|
|
|
|
|
|
cbytes, err := readNBytes(conn, 5)
|
|
|
|
|
|
|
|
conn.SetReadDeadline(time.Time{})
|
|
|
|
|
|
|
|
|
|
|
|
handshakeByte, err := readNBytes(conn, 1)
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
fmt.Println("TLS data chunk read failure: ", err)
|
|
|
|
|
|
|
|
return nil, 0, err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if handshakeByte[0] != 0x16 {
|
|
|
|
if int(cbytes[1]) < TLSGUARD_MIN_TLS_VER_MAJ {
|
|
|
|
return errors.New("Blocked client from attempting non-TLS connection")
|
|
|
|
return nil, 0, errors.New("TLS protocol major version less than expected minimum")
|
|
|
|
|
|
|
|
} else if int(cbytes[2]) < TLSGUARD_MIN_TLS_VER_MIN {
|
|
|
|
|
|
|
|
return nil, 0, errors.New("TLS protocol minor version less than expected minimum")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
vers, err := readNBytes(conn, 2)
|
|
|
|
cbyte := cbytes[0]
|
|
|
|
if err != nil {
|
|
|
|
mlen := int(int(cbytes[3])<<8 | int(cbytes[4]))
|
|
|
|
return err
|
|
|
|
// fmt.Printf("TLS data chunk header read: type = %#x, maj = %v, min = %v, len = %v\n", cbyte, cbytes[1], cbytes[2], mlen)
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
conn.SetReadDeadline(time.Now().Add(TLSGUARD_READ_TIMEOUT))
|
|
|
|
|
|
|
|
cbytes2, err := readNBytes(conn, mlen)
|
|
|
|
|
|
|
|
conn.SetReadDeadline(time.Time{})
|
|
|
|
|
|
|
|
|
|
|
|
length, err := readNBytes(conn, 2)
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
return nil, 0, err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ffslen := int(int(length[0])<<8 | int(length[1]))
|
|
|
|
cbytes = append(cbytes, cbytes2...)
|
|
|
|
|
|
|
|
return cbytes, int(cbyte), nil
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func TLSGuard(conn, conn2 net.Conn, fqdn string) error {
|
|
|
|
|
|
|
|
// Should this be a requirement?
|
|
|
|
|
|
|
|
// if strings.HasSuffix(request.DestAddr.FQDN, "onion") {
|
|
|
|
|
|
|
|
|
|
|
|
ffs, err := readNBytes(conn, ffslen)
|
|
|
|
//conn client
|
|
|
|
|
|
|
|
//conn2 server
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Read the opening message from the client
|
|
|
|
|
|
|
|
chunk, rtype, err := readTLSChunk(conn)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Transmit client hello
|
|
|
|
if rtype != SSL3_RT_HANDSHAKE {
|
|
|
|
conn2.Write(handshakeByte)
|
|
|
|
return errors.New("Blocked client from attempting non-TLS connection")
|
|
|
|
conn2.Write(vers)
|
|
|
|
}
|
|
|
|
conn2.Write(length)
|
|
|
|
|
|
|
|
conn2.Write(ffs)
|
|
|
|
// Pass it on through to the server
|
|
|
|
|
|
|
|
conn2.Write(chunk)
|
|
|
|
|
|
|
|
|
|
|
|
// Read ServerHello
|
|
|
|
// Read ServerHello
|
|
|
|
bytesRead := 0
|
|
|
|
|
|
|
|
var s byte // 0x0e is done
|
|
|
|
|
|
|
|
var responseBuf []byte = []byte{}
|
|
|
|
|
|
|
|
valid := false
|
|
|
|
valid := false
|
|
|
|
sendToClient := false
|
|
|
|
loop := 1
|
|
|
|
|
|
|
|
|
|
|
|
for sendToClient == false {
|
|
|
|
passthru := false
|
|
|
|
// Handshake byte
|
|
|
|
|
|
|
|
serverhandshakeByte, err := readNBytes(conn2, 1)
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
responseBuf = append(responseBuf, serverhandshakeByte[0])
|
|
|
|
for 1 == 1 {
|
|
|
|
bytesRead += 1
|
|
|
|
loop++
|
|
|
|
|
|
|
|
|
|
|
|
if serverhandshakeByte[0] != 0x16 {
|
|
|
|
// fmt.Printf("SSL LOOP %v; trying to read: conn2\n", loop)
|
|
|
|
return errors.New("Expected TLS server handshake byte was not received")
|
|
|
|
chunk, rtype, err = readTLSChunk(conn2)
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Protocol version, 2 bytes
|
|
|
|
|
|
|
|
serverProtocolVer, err := readNBytes(conn2, 2)
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
fmt.Printf("OTHER loop %v: trying to read: conn\n", loop)
|
|
|
|
}
|
|
|
|
chunk, rtype, err2 := readTLSChunk(conn)
|
|
|
|
|
|
|
|
fmt.Printf("read: %v, %v, %v\n", err2, rtype, len(chunk))
|
|
|
|
|
|
|
|
|
|
|
|
bytesRead += 2
|
|
|
|
if err2 == nil {
|
|
|
|
responseBuf = append(responseBuf, serverProtocolVer...)
|
|
|
|
conn2.Write(chunk)
|
|
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Record length, 2 bytes
|
|
|
|
|
|
|
|
serverRecordLen, err := readNBytes(conn2, 2)
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
|
|
return err
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bytesRead += 2
|
|
|
|
if rtype == SSL3_RT_CHANGE_CIPHER_SPEC || rtype == SSL3_RT_APPLICATION_DATA ||
|
|
|
|
responseBuf = append(responseBuf, serverRecordLen...)
|
|
|
|
rtype == SSL3_RT_ALERT {
|
|
|
|
serverRecordLenInt := int(int(serverRecordLen[0])<<8 | int(serverRecordLen[1]))
|
|
|
|
// fmt.Println("OTHER DATA; PASSING THRU")
|
|
|
|
|
|
|
|
passthru = true
|
|
|
|
|
|
|
|
} else if rtype == SSL3_RT_HANDSHAKE {
|
|
|
|
|
|
|
|
passthru = false
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
return errors.New(fmt.Sprintf("Expected TLS server handshake byte was not received [%#x vs 0x16]", rtype))
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Record type byte
|
|
|
|
if passthru {
|
|
|
|
serverMsg, err := readNBytes(conn2, serverRecordLenInt)
|
|
|
|
// fmt.Println("passthru writing buf again and continuing:")
|
|
|
|
if err != nil {
|
|
|
|
conn.Write(chunk)
|
|
|
|
return err
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bytesRead += len(serverMsg)
|
|
|
|
serverMsg := chunk[5:]
|
|
|
|
responseBuf = append(responseBuf, serverMsg...)
|
|
|
|
s := serverMsg[0]
|
|
|
|
s = serverMsg[0]
|
|
|
|
fmt.Printf("s = %#x\n", s)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if s == SSL3_MT_CERTIFICATE {
|
|
|
|
// Message len, 3 bytes
|
|
|
|
// Message len, 3 bytes
|
|
|
|
serverMessageLen := serverMsg[1:4]
|
|
|
|
serverMessageLen := serverMsg[1:4]
|
|
|
|
serverMessageLenInt := int(int(serverMessageLen[0])<<16 | int(serverMessageLen[1])<<8 | int(serverMessageLen[2]))
|
|
|
|
serverMessageLenInt := int(int(serverMessageLen[0])<<16 | int(serverMessageLen[1])<<8 | int(serverMessageLen[2]))
|
|
|
|
|
|
|
|
// fmt.Printf("chunk len = %v, serverMsgLen = %v, slint = %v\n", len(chunk), len(serverMsg), serverMessageLenInt)
|
|
|
|
// serverHelloBody, err := readNBytes(conn2, serverMessageLenInt)
|
|
|
|
if len(serverMsg) < serverMessageLenInt {
|
|
|
|
|
|
|
|
return errors.New(fmt.Sprintf("len(serverMsg) %v < serverMessageLenInt %v!\n", len(serverMsg), serverMessageLenInt))
|
|
|
|
|
|
|
|
}
|
|
|
|
serverHelloBody := serverMsg[4 : 4+serverMessageLenInt]
|
|
|
|
serverHelloBody := serverMsg[4 : 4+serverMessageLenInt]
|
|
|
|
|
|
|
|
|
|
|
|
if s == 0x0b {
|
|
|
|
|
|
|
|
certChainLen := int(int(serverHelloBody[0])<<16 | int(serverHelloBody[1])<<8 | int(serverHelloBody[2]))
|
|
|
|
certChainLen := int(int(serverHelloBody[0])<<16 | int(serverHelloBody[1])<<8 | int(serverHelloBody[2]))
|
|
|
|
remaining := certChainLen
|
|
|
|
remaining := certChainLen
|
|
|
|
pos := serverHelloBody[3:certChainLen]
|
|
|
|
pos := serverHelloBody[3:certChainLen]
|
|
|
@ -136,26 +164,24 @@ func TLSGuard(conn, conn2 net.Conn, fqdn string) error {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
verifyOptions.Intermediates = pool
|
|
|
|
verifyOptions.Intermediates = pool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// fmt.Println("ATTEMPTING TO VERIFY: ", fqdn)
|
|
|
|
_, err = c.Verify(verifyOptions)
|
|
|
|
_, err = c.Verify(verifyOptions)
|
|
|
|
|
|
|
|
// fmt.Println("ATTEMPTING TO VERIFY RESULT: ", err)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
return err
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
valid = true
|
|
|
|
valid = true
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// else if s == 0x0d { fmt.Printf("found a client cert request, sending buf to client\n") }
|
|
|
|
// else if s == 0x0d { fmt.Printf("found a client cert request, sending buf to client\n") }
|
|
|
|
} else if s == 0x0e {
|
|
|
|
} else if s == SSL3_MT_SERVER_DONE {
|
|
|
|
sendToClient = true
|
|
|
|
conn.Write(chunk)
|
|
|
|
} else if s == 0x0d {
|
|
|
|
break
|
|
|
|
sendToClient = true
|
|
|
|
} else if s == SSL3_MT_CERTIFICATE_REQUEST {
|
|
|
|
}
|
|
|
|
break
|
|
|
|
|
|
|
|
}
|
|
|
|
// fmt.Printf("Version bytes: %x %x\n", responseBuf[1], responseBuf[2])
|
|
|
|
// fmt.Printf("Sending chunk of type %d to client.\n", s)
|
|
|
|
// fmt.Printf("Len bytes: %x %x\n", responseBuf[3], responseBuf[4])
|
|
|
|
|
|
|
|
// fmt.Printf("Message type: %x\n", responseBuf[5])
|
|
|
|
conn.Write(chunk)
|
|
|
|
// fmt.Printf("Message len: %x %x %x\n", responseBuf[6], responseBuf[7], responseBuf[8])
|
|
|
|
|
|
|
|
// fmt.Printf("Message body: %v\n", responseBuf[9:])
|
|
|
|
|
|
|
|
conn.Write(responseBuf)
|
|
|
|
|
|
|
|
responseBuf = []byte{}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if !valid {
|
|
|
|
if !valid {
|
|
|
|