diff --git a/src/AsyncTCP.cpp b/src/AsyncTCP.cpp index 169c098a..bbf9403b 100644 --- a/src/AsyncTCP.cpp +++ b/src/AsyncTCP.cpp @@ -580,8 +580,11 @@ extern "C" { /* Async TCP Client */ - +#if ASYNC_TCP_SSL_ENABLED +AsyncClient::AsyncClient(tcp_pcb* pcb, tcp_pcb* server_pcb) +#else AsyncClient::AsyncClient(tcp_pcb* pcb) +#endif : _connect_cb(0) , _connect_cb_arg(0) , _discard_cb(0) @@ -628,6 +631,20 @@ AsyncClient::AsyncClient(tcp_pcb* pcb) tcp_sent(_pcb, &_tcp_sent); tcp_err(_pcb, &_tcp_error); tcp_poll(_pcb, &_tcp_poll, 1); +#if ASYNC_TCP_SSL_ENABLED + if(server_pcb){ + if(tcp_ssl_new_server_client(_pcb, this, server_pcb) < 0){ + _close(); + return; + } + tcp_ssl_data(_pcb, &_s_data); + tcp_ssl_handshake(_pcb, &_s_handshake); + tcp_ssl_err(_pcb, &_s_ssl_error); + + _pcb_secure = true; + _handshake_done = false; + } +#endif } } @@ -662,7 +679,7 @@ AsyncClient& AsyncClient::operator=(const AsyncClient& other){ _handshake_done = false; tcp_ssl_arg(_pcb, this); tcp_ssl_data(_pcb, &_s_data); - tcp_ssl_handshake(_pcb, &_s_handshake); + tcp_ssl_handshake(_pcb, &_s_handshake); tcp_ssl_err(_pcb, &_s_ssl_error); } else { _pcb_secure = false; @@ -775,6 +792,7 @@ bool AsyncClient::connect(IPAddress ip, uint16_t port){ tcp_sent(pcb, &_tcp_sent); tcp_poll(pcb, &_tcp_poll, 1); //_tcp_connect(pcb, &addr, port,(tcp_connected_fn)&_s_connected); + log_d("_tcp_connect"); _tcp_connect(pcb, _closed_slot, &addr, port,(tcp_connected_fn)&_tcp_connected); return true; } @@ -923,7 +941,7 @@ void AsyncClient::ackPacket(struct pbuf * pb){ * */ int8_t AsyncClient::_close(){ - //ets_printf("X: 0x%08x\n", (uint32_t)this); + log_d("X: 0x%08x", (uint32_t)this); int8_t err = ERR_OK; if(_pcb) { //log_i(""); @@ -1091,6 +1109,7 @@ int8_t AsyncClient::_sent(tcp_pcb* pcb, uint16_t len) { } int8_t AsyncClient::_recv(tcp_pcb* pcb, pbuf* pb, int8_t err) { + log_d("_recv_recv_recv_recv_recv_recv_recv"); while(pb != NULL) { _rx_last_packet = millis(); pbuf *nxt = pb->next; @@ -1105,7 +1124,9 @@ int8_t AsyncClient::_recv(tcp_pcb* pcb, pbuf* pb, int8_t err) { pbuf_free(pb); // handle errors if(err < 0){ - if (err != MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { + if (err == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { + tcp_ssl_free(pcb); + } else { log_e("_recv err: %d\n", err); _close(); } @@ -1462,6 +1483,9 @@ AsyncServer::AsyncServer(IPAddress addr, uint16_t port) , _pcb(0) , _connect_cb(0) , _connect_cb_arg(0) +#if ASYNC_TCP_SSL_ENABLED +, _secure(false) +#endif // ASYNC_TCP_SSL_ENABLED {} AsyncServer::AsyncServer(uint16_t port) @@ -1471,6 +1495,9 @@ AsyncServer::AsyncServer(uint16_t port) , _pcb(0) , _connect_cb(0) , _connect_cb_arg(0) +#if ASYNC_TCP_SSL_ENABLED +, _secure(false) +#endif // ASYNC_TCP_SSL_ENABLED {} AsyncServer::~AsyncServer(){ @@ -1486,6 +1513,9 @@ void AsyncServer::begin(){ if(_pcb) { return; } +#if ASYNC_TCP_SSL_ENABLED + _secure = false; +#endif // ASYNC_TCP_SSL_ENABLED if(!_start_async_task()){ log_e("failed to start task"); @@ -1515,12 +1545,18 @@ void AsyncServer::begin(){ log_e("listen_pcb == NULL"); return; } + log_d("pcb 0x%08x", _pcb); tcp_arg(_pcb, (void*) this); tcp_accept(_pcb, &_s_accept); } void AsyncServer::end(){ if(_pcb){ +#if ASYNC_TCP_SSL_ENABLED + if(_secure){ + tcp_ssl_free(_pcb); + } +#endif // ASYNC_TCP_SSL_ENABLED tcp_arg(_pcb, NULL); tcp_accept(_pcb, NULL); if(tcp_close(_pcb) != ERR_OK){ @@ -1532,10 +1568,15 @@ void AsyncServer::end(){ //runs on LwIP thread int8_t AsyncServer::_accept(tcp_pcb* pcb, int8_t err){ + log_d("pcb 0x%08x", pcb); //ets_printf("+A: 0x%08x\n", pcb); - if(_connect_cb){ + if(_connect_cb){ +#if ASYNC_TCP_SSL_ENABLED + AsyncClient *c = new AsyncClient(pcb, _pcb); +#else AsyncClient *c = new AsyncClient(pcb); - if(c){ +#endif + if (c) { c->setNoDelay(_noDelay); return _tcp_accept(this, c); } @@ -1548,6 +1589,13 @@ int8_t AsyncServer::_accept(tcp_pcb* pcb, int8_t err){ } int8_t AsyncServer::_accepted(AsyncClient* client){ +#if ASYNC_TCP_SSL_ENABLED + if (_secure) { + client->onConnect([this](void * arg, AsyncClient *c) { + _connect_cb(_connect_cb_arg, c); + }, this); + } else +#endif if(_connect_cb){ _connect_cb(_connect_cb_arg, client); } @@ -1576,3 +1624,49 @@ int8_t AsyncServer::_s_accept(void * arg, tcp_pcb * pcb, int8_t err){ int8_t AsyncServer::_s_accepted(void *arg, AsyncClient* client){ return reinterpret_cast(arg)->_accepted(client); } + +#if ASYNC_TCP_SSL_ENABLED +void AsyncServer::beginSecure(const char *cert, const char *private_key_file, const char *password) { + if(_pcb) { + return; + } + _secure = true; + + if(!_start_async_task()){ + log_e("failed to start task"); + return; + } + int8_t err; + _pcb = tcp_new_ip_type(IPADDR_TYPE_V4); + if (!_pcb){ + log_e("_pcb == NULL"); + return; + } + + ip_addr_t local_addr; + local_addr.type = IPADDR_TYPE_V4; + local_addr.u_addr.ip4.addr = (uint32_t) _addr; + err = _tcp_bind(_pcb, &local_addr, _port); + + if (err != ERR_OK) { + _tcp_close(_pcb, -1); + log_e("bind error: %d", err); + return; + } + + static uint8_t backlog = 5; + _pcb = _tcp_listen_with_backlog(_pcb, backlog); + if (!_pcb) { + log_e("listen_pcb == NULL"); + return; + } + log_d("pcb 0x%08x", _pcb); + if (tcp_ssl_new_server(_pcb, this, cert, strlen(cert) + 1, private_key_file, strlen(private_key_file) + 1, password) == 0) { + log_d("start accepting clients"); + tcp_arg(_pcb, (void*) this); + tcp_accept(_pcb, &_s_accept); + } else { + end(); + } +} +#endif // ASYNC_TCP_SSL_ENABLED diff --git a/src/AsyncTCP.h b/src/AsyncTCP.h index 48d9b68a..cf5cfddd 100644 --- a/src/AsyncTCP.h +++ b/src/AsyncTCP.h @@ -58,9 +58,17 @@ typedef std::function AcTimeoutHandler struct tcp_pcb; struct ip_addr; +class AsyncServer; + class AsyncClient { public: + friend class AsyncServer; + +#if ASYNC_TCP_SSL_ENABLED + AsyncClient(tcp_pcb* pcb = 0, tcp_pcb* server_pcb = 0); +#else AsyncClient(tcp_pcb* pcb = 0); +#endif ~AsyncClient(); AsyncClient & operator=(const AsyncClient &other); @@ -236,9 +244,8 @@ class AsyncServer { ~AsyncServer(); void onClient(AcConnectHandler cb, void* arg); #if ASYNC_TCP_SSL_ENABLED - // Dummy, so it compiles with ESP Async WebServer library enabled. void onSslFileRequest(AcSSlFileHandler cb, void* arg) {}; - void beginSecure(const char *cert, const char *private_key_file, const char *password) {}; + void beginSecure(const char *cert, const char *private_key_file, const char *password); #endif void begin(); void end(); @@ -257,6 +264,9 @@ class AsyncServer { tcp_pcb* _pcb; AcConnectHandler _connect_cb; void* _connect_cb_arg; +#if ASYNC_TCP_SSL_ENABLED + bool _secure; +#endif int8_t _accept(tcp_pcb* newpcb, int8_t err); int8_t _accepted(AsyncClient* client); diff --git a/src/tcp_mbedtls.c b/src/tcp_mbedtls.c index c6d69bde..6b7afc5b 100644 --- a/src/tcp_mbedtls.c +++ b/src/tcp_mbedtls.c @@ -10,7 +10,7 @@ extern esp_err_t _tcp_output4ssl(struct tcp_pcb * pcb, void* client); extern esp_err_t _tcp_write4ssl(struct tcp_pcb * pcb, const char* data, size_t size, uint8_t apiflags, void* client); -#if 0 +#if 1 #define TCP_SSL_DEBUG(...) do { ets_printf("T %s- ", pcTaskGetTaskName(xTaskGetCurrentTaskHandle())); ets_printf(__VA_ARGS__); } while(0) #else #define TCP_SSL_DEBUG(...) @@ -61,13 +61,17 @@ struct tcp_ssl_pcb { struct tcp_pcb *tcp; int fd; mbedtls_ssl_context ssl_ctx; + bool has_ssl_conf; mbedtls_ssl_config ssl_conf; - mbedtls_x509_crt ca_cert; bool has_ca_cert; - mbedtls_x509_crt client_cert; + mbedtls_x509_crt ca_cert; bool has_client_cert; + mbedtls_x509_crt client_cert; + bool has_client_key; mbedtls_pk_context client_key; + bool has_drbg_ctx; mbedtls_ctr_drbg_context drbg_ctx; + bool has_entropy_ctx; mbedtls_entropy_context entropy_ctx; uint8_t type; // int handshake; @@ -191,6 +195,10 @@ tcp_ssl_t * tcp_ssl_new(struct tcp_pcb *tcp, void* arg) { new_item->next = NULL; new_item->has_ca_cert = false; new_item->has_client_cert = false; + new_item->has_client_key = false; + new_item->has_entropy_ctx = false; + new_item->has_ssl_conf = false; + new_item->has_drbg_ctx = false; if(tcp_ssl_array == NULL){ tcp_ssl_array = new_item; @@ -236,6 +244,9 @@ int tcp_ssl_new_client(struct tcp_pcb *tcp, void *arg, const char* hostname, con mbedtls_ctr_drbg_init(&tcp_ssl->drbg_ctx); mbedtls_ssl_init(&tcp_ssl->ssl_ctx); mbedtls_ssl_config_init(&tcp_ssl->ssl_conf); + tcp_ssl->has_entropy_ctx = true; + tcp_ssl->has_drbg_ctx = true; + tcp_ssl->has_ssl_conf = true; if(root_ca != NULL) { mbedtls_x509_crt_init(&tcp_ssl->ca_cert); tcp_ssl->has_ca_cert = true; @@ -244,6 +255,7 @@ int tcp_ssl_new_client(struct tcp_pcb *tcp, void *arg, const char* hostname, con mbedtls_x509_crt_init(&tcp_ssl->client_cert); mbedtls_pk_init(&tcp_ssl->client_key); tcp_ssl->has_client_cert = true; + tcp_ssl->has_client_key = true; } mbedtls_ctr_drbg_seed(&tcp_ssl->drbg_ctx, mbedtls_entropy_func, @@ -326,6 +338,141 @@ int tcp_ssl_new_client(struct tcp_pcb *tcp, void *arg, const char* hostname, con return ERR_OK; } +int tcp_ssl_new_server(struct tcp_pcb *tcp, void *arg, const char *cert, const size_t cert_len, const char *private_key, const size_t private_key_len, const char *password) { + tcp_ssl_t* tcp_ssl; + + if(tcp == NULL) { + return -1; + } + + if(tcp_ssl_get(tcp) != NULL){ + return -1; + } + + tcp_ssl = tcp_ssl_new(tcp, arg); + if(tcp_ssl == NULL){ + return -1; + } + + int ret; + mbedtls_ssl_init( &tcp_ssl->ssl_ctx ); + mbedtls_ssl_config_init( &tcp_ssl->ssl_conf ); + mbedtls_x509_crt_init( &tcp_ssl->ca_cert ); + mbedtls_pk_init( &tcp_ssl->client_key ); + mbedtls_entropy_init( &tcp_ssl->entropy_ctx ); + mbedtls_ctr_drbg_init( &tcp_ssl->drbg_ctx ); + + tcp_ssl->has_entropy_ctx = true; + tcp_ssl->has_ssl_conf = true; + tcp_ssl->has_ca_cert = true; + tcp_ssl->has_client_key = true; + tcp_ssl->has_drbg_ctx = true; + + /* + * 1. Load the certificates and private RSA key + */ + TCP_SSL_DEBUG("Loading the server cert\n"); + ret = mbedtls_x509_crt_parse(&tcp_ssl->ca_cert, (const unsigned char *) cert, cert_len); + if (ret != 0) { + TCP_SSL_DEBUG("failed loading server cert, returned %d\n", ret); + tcp_ssl_free(tcp); + return handle_error(ret); + } + + TCP_SSL_DEBUG("Loading the server key\n"); + ret = mbedtls_pk_parse_key(&tcp_ssl->client_key, (const unsigned char *) private_key, private_key_len, NULL, 0); + if (ret != 0) { + TCP_SSL_DEBUG("failed loading server private key, returned %d\n", ret); + tcp_ssl_free(tcp); + return handle_error(ret); + } + + /* + * 3. Seed the RNG + */ + TCP_SSL_DEBUG("Seeding the random number generator...\n" ); + ret = mbedtls_ctr_drbg_seed(&tcp_ssl->drbg_ctx, mbedtls_entropy_func, &tcp_ssl->entropy_ctx, + (const unsigned char *) pers, + sizeof(pers)); + if (ret != 0) { + TCP_SSL_DEBUG("failed seeding the random number generator, returned %d\n", ret); + tcp_ssl_free(tcp); + return handle_error(ret); + } + + /* + * 4. Setup stuff + */ + TCP_SSL_DEBUG("Setting up the SSL data...\n" ); + ret = mbedtls_ssl_config_defaults( &tcp_ssl->ssl_conf, + MBEDTLS_SSL_IS_SERVER, + MBEDTLS_SSL_TRANSPORT_STREAM, + MBEDTLS_SSL_PRESET_DEFAULT ); + if (ret != 0) { + TCP_SSL_DEBUG("failed mbedtls_ssl_config_defaults returned %d\n", ret); + tcp_ssl_free(tcp); + return handle_error(ret); + } + + mbedtls_ssl_conf_rng(&tcp_ssl->ssl_conf, mbedtls_ctr_drbg_random, &tcp_ssl->drbg_ctx); + + + mbedtls_ssl_conf_ca_chain(&tcp_ssl->ssl_conf, tcp_ssl->ca_cert.next, NULL); + ret = mbedtls_ssl_conf_own_cert(&tcp_ssl->ssl_conf, &tcp_ssl->ca_cert, &tcp_ssl->client_key); + if (ret != 0) { + TCP_SSL_DEBUG("failed mbedtls_ssl_conf_own_cert returned %d\n", ret); + tcp_ssl_free(tcp); + return handle_error(ret); + } + + ret = mbedtls_ssl_setup(&tcp_ssl->ssl_ctx, &tcp_ssl->ssl_conf); + if (ret != 0) { + TCP_SSL_DEBUG("failed mbedtls_ssl_setup returned %d\n", ret); + tcp_ssl_free(tcp); + return handle_error(ret); + } + + TCP_SSL_DEBUG("tcp_ssl_new_server completed succesfully\n"); + + return ERR_OK; +} + +int tcp_ssl_new_server_client(struct tcp_pcb *tcp, void *arg, struct tcp_pcb *server_tcp) { + tcp_ssl_t* tcp_ssl; + tcp_ssl_t* server_tcp_ssl; + + if(tcp == NULL || server_tcp == NULL) { + return -1; + } + + if(tcp_ssl_get(tcp) != NULL){ + return -1; + } + + server_tcp_ssl = tcp_ssl_get(server_tcp); + if (server_tcp_ssl == NULL) { + return -1; + } + + tcp_ssl = tcp_ssl_new(tcp, arg); + if(tcp_ssl == NULL){ + return -1; + } + + int ret; + + mbedtls_ssl_init(&tcp_ssl->ssl_ctx); + ret = mbedtls_ssl_setup(&tcp_ssl->ssl_ctx, &server_tcp_ssl->ssl_conf); + if (ret != 0) { + TCP_SSL_DEBUG("failed: mbedtls_ssl_setup returned -0x%04x\n", -ret ); + return handle_error(ret); + } + + mbedtls_ssl_set_bio(&tcp_ssl->ssl_ctx, (void*)tcp_ssl, tcp_ssl_send, tcp_ssl_recv, NULL); + + return ERR_OK; +} + // Open an SSL connection using a PSK (pre-shared-key) cipher suite. int tcp_ssl_new_psk_client(struct tcp_pcb *tcp, void *arg, const char* psk_ident, const char* pskey) { tcp_ssl_t* tcp_ssl; @@ -352,6 +499,10 @@ int tcp_ssl_new_psk_client(struct tcp_pcb *tcp, void *arg, const char* psk_ident mbedtls_ssl_init(&tcp_ssl->ssl_ctx); mbedtls_ssl_config_init(&tcp_ssl->ssl_conf); + tcp_ssl->has_entropy_ctx = true; + tcp_ssl->has_ssl_conf = true; + tcp_ssl->has_drbg_ctx = true; + mbedtls_ctr_drbg_seed(&tcp_ssl->drbg_ctx, mbedtls_entropy_func, &tcp_ssl->entropy_ctx, (const uint8_t*)pers, sizeof(pers)); @@ -526,51 +677,72 @@ int tcp_ssl_read(struct tcp_pcb *tcp, struct pbuf *p) { tcp_ssl->tcp_pbuf = NULL; + TCP_SSL_DEBUG("tcp_ssl_read: return total_bytes: %d\r\n", total_bytes >= 0 ? 0 : total_bytes); return total_bytes >= 0 ? 0 : total_bytes; // return error code } +int tcp_ssl_handshake_step(struct tcp_pcb *tcp) { + TCP_SSL_DEBUG("tcp_ssl_handshake_step(%x)\n", tcp); + if(tcp == NULL) { + return -1; + } + + tcp_ssl_t * tcp_ssl = tcp_ssl_get(tcp); + if(tcp_ssl == NULL){ + return 0; + } + + return ERR_OK; +} + int tcp_ssl_free(struct tcp_pcb *tcp) { TCP_SSL_DEBUG("tcp_ssl_free(%x)\n", tcp); if(tcp == NULL) { return -1; } tcp_ssl_t * item = tcp_ssl_array; - if(item->tcp == tcp){ + if (item == NULL) { + return ERR_TCP_SSL_INVALID_CLIENTFD_DATA;//item not found + } + + if (item->tcp == tcp) { tcp_ssl_array = tcp_ssl_array->next; - if(item->tcp_pbuf != NULL) { - pbuf_free(item->tcp_pbuf); + } else { + while(item->next && item->next->tcp != tcp) + item = item->next; + + if(item->next == NULL || item->next->tcp != tcp){ + return ERR_TCP_SSL_INVALID_CLIENTFD_DATA;//item not found } - mbedtls_ssl_free(&item->ssl_ctx); + + tcp_ssl_t * thisItem = item->next; + item->next = item->next->next; + item = thisItem; + } + + if(item->tcp_pbuf != NULL) { + pbuf_free(item->tcp_pbuf); + } + mbedtls_ssl_free(&item->ssl_ctx); + if(item->has_ssl_conf) { mbedtls_ssl_config_free(&item->ssl_conf); + } + if(item->has_drbg_ctx) { mbedtls_ctr_drbg_free(&item->drbg_ctx); + } + if(item->has_entropy_ctx) { mbedtls_entropy_free(&item->entropy_ctx); - if(item->has_ca_cert) { - mbedtls_x509_crt_free(&item->ca_cert); - } - if (item->has_client_cert) { - mbedtls_x509_crt_free(&item->client_cert); - mbedtls_pk_free(&item->client_key); - } - free(item); - return 0; } - - while(item->next && item->next->tcp != tcp) - item = item->next; - - if(item->next == NULL){ - return ERR_TCP_SSL_INVALID_CLIENTFD_DATA;//item not found + if(item->has_ca_cert) { + mbedtls_x509_crt_free(&item->ca_cert); + } + if (item->has_client_cert) { + mbedtls_x509_crt_free(&item->client_cert); } - tcp_ssl_t * i = item->next; - item->next = i->next; - if(i->tcp_pbuf != NULL){ - pbuf_free(i->tcp_pbuf); + if (item->has_client_key) { + mbedtls_pk_free(&item->client_key); } - mbedtls_ssl_free(&i->ssl_ctx); - mbedtls_ssl_config_free(&i->ssl_conf); - mbedtls_ctr_drbg_free(&i->drbg_ctx); - mbedtls_entropy_free(&i->entropy_ctx); - free(i); + free(item); return 0; } @@ -593,10 +765,10 @@ void tcp_ssl_data(struct tcp_pcb *tcp, tcp_ssl_data_cb_t arg){ } } -void tcp_ssl_handshake(struct tcp_pcb *tcp, tcp_ssl_handshake_cb_t arg){ +void tcp_ssl_handshake(struct tcp_pcb *tcp, tcp_ssl_handshake_cb_t ssl_handshake_cb){ tcp_ssl_t * item = tcp_ssl_get(tcp); if(item) { - item->on_handshake = arg; + item->on_handshake = ssl_handshake_cb; } } diff --git a/src/tcp_mbedtls.h b/src/tcp_mbedtls.h index 7a741fdc..746de97e 100644 --- a/src/tcp_mbedtls.h +++ b/src/tcp_mbedtls.h @@ -32,6 +32,8 @@ typedef void (* tcp_ssl_error_cb_t)(void *arg, struct tcp_pcb *tcp, int8_t error uint8_t tcp_ssl_has_client(); int tcp_ssl_new_client(struct tcp_pcb *tcp, void *arg, const char* hostname, const char* root_ca, const size_t root_ca_len, const char* cli_cert, const size_t cli_cert_len, const char* cli_key, const size_t cli_key_len); +int tcp_ssl_new_server(struct tcp_pcb *tcp, void *arg, const char *cert, const size_t cert_len, const char *private_key, const size_t private_key_len, const char *password); +int tcp_ssl_new_server_client(struct tcp_pcb *tcp, void *arg, struct tcp_pcb *server_tcp); int tcp_ssl_new_psk_client(struct tcp_pcb *tcp, void *arg, const char* psk_ident, const char* psk); int tcp_ssl_write(struct tcp_pcb *tcp, uint8_t *data, size_t len); int tcp_ssl_read(struct tcp_pcb *tcp, struct pbuf *p); @@ -40,7 +42,7 @@ int tcp_ssl_free(struct tcp_pcb *tcp); bool tcp_ssl_has(struct tcp_pcb *tcp); void tcp_ssl_arg(struct tcp_pcb *tcp, void * arg); void tcp_ssl_data(struct tcp_pcb *tcp, tcp_ssl_data_cb_t arg); -void tcp_ssl_handshake(struct tcp_pcb *tcp, tcp_ssl_handshake_cb_t arg); +void tcp_ssl_handshake(struct tcp_pcb *tcp, tcp_ssl_handshake_cb_t ssl_handshake_cb); void tcp_ssl_err(struct tcp_pcb *tcp, tcp_ssl_error_cb_t arg); #ifdef __cplusplus