Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue with DTLS: Unable to Achieve Handshake Between Client and Serve #1323

Open
hamma96 opened this issue Jul 24, 2024 · 0 comments
Open

Issue with DTLS: Unable to Achieve Handshake Between Client and Serve #1323

hamma96 opened this issue Jul 24, 2024 · 0 comments

Comments

@hamma96
Copy link

hamma96 commented Jul 24, 2024

I have been working for a week on creating a DTLS (Datagram Transport Layer Security) client-server setup, but I am consistently failing to achieve a successful handshake. Despite multiple attempts and configurations, the handshake process does not complete as expected.

`import socket
import logging
from OpenSSL import SSL
from openssl_psk import patch_context
import time
import threading
import hashlib

patch_context()

logging.basicConfig(level=logging.INFO)

def psk_client_callback(connection, hint):
logging.info(f"[TLSClient] PSK client callback called with hint: {hint}")
identity = b'client-identity'
key = b'1a2b3c4d5e6f'
logging.info(f"[TLSClient] Returning identity: {identity}, key: {key}")
return (identity, key)

class TLSClient:
def init(self, config):
self.context = SSL.Context(SSL.DTLS_METHOD)
self.context.set_cipher_list(b'PSK-AES256-CBC-SHA')
self.context.set_psk_client_callback(psk_client_callback)
self.context.set_options(SSL.OP_NO_RENEGOTIATION)
self.context.set_info_callback(lambda conn, where, ret: print(f"[TLSClient] Info: where={where}, ret={ret}, state={conn.get_state_string()}"))
self.client_socket = None
self.config = config
self.ssl_conn = None
self.callback_running = False
self._running = False

def log_handshake_progress(self, conn):
    state = conn.get_state_string()
    pending = conn.pending()
    cipher_name = conn.get_cipher_name()
    version = conn.get_protocol_version_name()
    logging.info(f"[TLSClient] Handshake state: {state}, Pending: {pending}, Cipher: {cipher_name}, Version: {version}")


def start_client(self):
    try:
        self._running = True
        self.client_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        #self.client_socket.setblocking(False)
        self.client_socket.connect(self.config['address'])
        
        self.context.set_timeout(30)
        self.ssl_conn = SSL.Connection(self.context, self.client_socket)            
        self.ssl_conn.set_connect_state()

        logging.info("[TLSClient] Starting DTLS handshake...")
        while self._running:
            try:   

                self.log_handshake_progress(self.ssl_conn)
                self.ssl_conn.do_handshake()
            
            except SSL.WantReadError:
                self.log_handshake_progress(self.ssl_conn)
                pass
            else:
                logging.info("[TLSClient] else handshake.")
                self._running = False
            
        self.log_handshake_progress(self.ssl_conn)
        logging.info("[TLSClient] DTLS handshake completed.")

        # Send a message to the server
        message = b"Hello from Client!"
        self.ssl_conn.send(message)
        logging.info(f"[TLSClient] Sent to server: {message}")

        # Receive a response from the server
        data = self.ssl_conn.recv(self.config['buffer_size'])
        logging.info(f"[TLSClient] Received from server: {data.decode()}")

        self.ssl_conn.shutdown()
        self.ssl_conn.close()

    except SSL.Error as e:
        logging.error(f"[TLSClient] SSL error: {e}")
    except Exception as e:
        logging.error(f"[TLSClient] Error: {e}")
    finally:
        self.callback_running = False  # Stop callback thread
        if self.client_socket:
            self.client_socket.close()
        logging.info("[TLSClient] Client stopped")

def psk_server_callback(connection, identity):
logging.info(f"[TLSServer] PSK server callback called with identity: {identity}")
if identity == b'client-identity':
key = b'1a2b3c4d5e6f'
logging.info(f"[TLSServer] Returning key: {key}")
return key
return None

class TLSServer:
def init(self, config):
self.context = SSL.Context(SSL.DTLS_METHOD)
self.context.set_cipher_list(b'PSK-AES256-CBC-SHA')
self.context.set_psk_server_callback(psk_server_callback)
self.context.set_options(SSL.OP_NO_QUERY_MTU)
self.context.set_info_callback(lambda conn, where, ret: print(f"[TLSServer] Info: where={where}, ret={ret}, state={conn.get_state_string()}"))
# Setup cookie generation and verification
self.context.set_cookie_generate_callback(self.generate_cookie)
self.context.set_cookie_verify_callback(self.verify_cookie)

    self.server_socket = None
    self._running = False
    self.config = config
    self.ssl_conn = None

def generate_cookie(self, ssl):
        logging.info("[TLSServer] generate_cookie")
        return b"xyzzy"

def verify_cookie(self, ssl, cookie):
        logging.info("[TLSServer] verify_cookie")
        return cookie == b"xyzzy"

def log_handshake_progress(self, conn: SSL.Connection):
    state = conn.get_state_string()
    pending = conn.pending()
    cipher_name = conn.get_cipher_name()
    version = conn.get_protocol_version_name()
    logging.info(f"[TLSServer] Handshake state: {state}, Pending: {pending}, Cipher: {cipher_name}, Version: {version}")

def start_server(self):
    try:
        self._running = True
        self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        #self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        #self.server_socket.setblocking(False)
        self.server_socket.bind(self.config['address'])

        logging.info("[TLSServer] Server is running and waiting for connections...")
        s_handshaking = False
        self.context.set_timeout(30)
        s_listening = True

        import select
        while self._running:
            try:
                #ready_sockets, _, _ = select.select([self.server_socket], [], [])
                #sock = self.server_socket
                #for sock in ready_sockets:
                data, addr = self.server_socket.recvfrom(self.config['buffer_size'])
                ssl_conn = SSL.Connection(self.context, self.server_socket)          
                ssl_conn.set_accept_state()
                ssl_conn.set_tlsext_host_name(self.config['address'][0].encode())
                ssl_conn.set_ciphertext_mtu(1500)
                #self.invoke_client_callback(data, addr)
                self.log_handshake_progress(ssl_conn)
                if len(data) > 0 and data[0] == 22 and data[13] == 1:
                    logging.info("[TLSServer] Received ClientHello from client")
                    logging.info(f"[TLSServer] Received initial data from {addr}: {data}")
                    if s_listening:
                        try:
                            ssl_conn.DTLSv1_listen()
                            logging.info("[TLSServer] After DTLSv1_listen")
                        except SSL.WantReadError:
                            logging.info("[TLSServer] WantReadError during DTLSv1_listen")
                            continue
                        else:
                            s_listening = False
                            s_handshaking = True
                            logging.info("[TLSServer] s_listening=False")
                            ssl_conn.bio_write(data)
                        

                    logging.info(f"[TLSServer] Starting DTLS handshake with {addr}...")
                    while s_handshaking:
                        try:
                            self.log_handshake_progress(ssl_conn)
                            ssl_conn.do_handshake()
                            break
                        except SSL.WantReadError:
                            self.log_handshake_progress(ssl_conn)
                            self._running = False
                            s_handshaking = False
                            pass
                        except SSL.Error as e:
                            logging.error(f"[TLSServer] SSL error occurred during handshake: {e}")
                            self.log_handshake_progress(ssl_conn)
                            self._running = False
                            s_handshaking = False
                            break
                    self.log_handshake_progress(ssl_conn)
                    logging.info(f"[TLSServer] DTLS handshake with {addr} completed.")


            except SSL.Error as e:
                logging.error(f"[TLSServer] SSL error occurred: {e}")
            except Exception as e:
                logging.error(f"[TLSServer] An error occurred: {e}")

    except socket.error as e:
        logging.error(f"[TLSServer] Socket error: {e}")
    finally:
        self.cleanup()



def cleanup(self):
    self._running = False
    if self.server_socket:
        self.server_socket.close()
    logging.info("[TLSServer] Server cleaned up and stopped.")

if name == "main":
server_config = {
'address': ('localhost', 4433),
'buffer_size': 4096
}

client_config = {
    'address': ('localhost', 4433),
    'buffer_size': 4096
}

server = TLSServer(server_config)
server_thread = threading.Thread(target=server.start_server)
server_thread.start()

time.sleep(1)

""" client = TLSClient(client_config)
client_thread = threading.Thread(target=client.start_client)
client_thread.start() """

time.sleep(120)

server._running = False
#client_thread.join()
server_thread.join()`
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Development

No branches or pull requests

1 participant