Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 100 additions & 6 deletions src/AsyncTCP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,8 +580,11 @@ extern "C" {
/*
Async TCP Client
*/

#if ASYNC_TCP_SSL_ENABLED
AsyncClient::AsyncClient(tcp_pcb* pcb, tcp_pcb* server_pcb)
#else
AsyncClient::AsyncClient(tcp_pcb* pcb)
#endif
: _connect_cb(0)
, _connect_cb_arg(0)
, _discard_cb(0)
Expand Down Expand Up @@ -628,6 +631,20 @@ AsyncClient::AsyncClient(tcp_pcb* pcb)
tcp_sent(_pcb, &_tcp_sent);
tcp_err(_pcb, &_tcp_error);
tcp_poll(_pcb, &_tcp_poll, 1);
#if ASYNC_TCP_SSL_ENABLED
if(server_pcb){
if(tcp_ssl_new_server_client(_pcb, this, server_pcb) < 0){
_close();
return;
}
tcp_ssl_data(_pcb, &_s_data);
tcp_ssl_handshake(_pcb, &_s_handshake);
tcp_ssl_err(_pcb, &_s_ssl_error);

_pcb_secure = true;
_handshake_done = false;
}
#endif
}
}

Expand Down Expand Up @@ -662,7 +679,7 @@ AsyncClient& AsyncClient::operator=(const AsyncClient& other){
_handshake_done = false;
tcp_ssl_arg(_pcb, this);
tcp_ssl_data(_pcb, &_s_data);
tcp_ssl_handshake(_pcb, &_s_handshake);
tcp_ssl_handshake(_pcb, &_s_handshake);
tcp_ssl_err(_pcb, &_s_ssl_error);
} else {
_pcb_secure = false;
Expand Down Expand Up @@ -775,6 +792,7 @@ bool AsyncClient::connect(IPAddress ip, uint16_t port){
tcp_sent(pcb, &_tcp_sent);
tcp_poll(pcb, &_tcp_poll, 1);
//_tcp_connect(pcb, &addr, port,(tcp_connected_fn)&_s_connected);
log_d("_tcp_connect");
_tcp_connect(pcb, _closed_slot, &addr, port,(tcp_connected_fn)&_tcp_connected);
return true;
}
Expand Down Expand Up @@ -923,7 +941,7 @@ void AsyncClient::ackPacket(struct pbuf * pb){
* */

int8_t AsyncClient::_close(){
//ets_printf("X: 0x%08x\n", (uint32_t)this);
log_d("X: 0x%08x", (uint32_t)this);
int8_t err = ERR_OK;
if(_pcb) {
//log_i("");
Expand Down Expand Up @@ -1091,6 +1109,7 @@ int8_t AsyncClient::_sent(tcp_pcb* pcb, uint16_t len) {
}

int8_t AsyncClient::_recv(tcp_pcb* pcb, pbuf* pb, int8_t err) {
log_d("_recv_recv_recv_recv_recv_recv_recv");
while(pb != NULL) {
_rx_last_packet = millis();
pbuf *nxt = pb->next;
Expand All @@ -1105,7 +1124,9 @@ int8_t AsyncClient::_recv(tcp_pcb* pcb, pbuf* pb, int8_t err) {
pbuf_free(pb);
// handle errors
if(err < 0){
if (err != MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
if (err == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
tcp_ssl_free(pcb);
} else {
log_e("_recv err: %d\n", err);
_close();
}
Expand Down Expand Up @@ -1462,6 +1483,9 @@ AsyncServer::AsyncServer(IPAddress addr, uint16_t port)
, _pcb(0)
, _connect_cb(0)
, _connect_cb_arg(0)
#if ASYNC_TCP_SSL_ENABLED
, _secure(false)
#endif // ASYNC_TCP_SSL_ENABLED
{}

AsyncServer::AsyncServer(uint16_t port)
Expand All @@ -1471,6 +1495,9 @@ AsyncServer::AsyncServer(uint16_t port)
, _pcb(0)
, _connect_cb(0)
, _connect_cb_arg(0)
#if ASYNC_TCP_SSL_ENABLED
, _secure(false)
#endif // ASYNC_TCP_SSL_ENABLED
{}

AsyncServer::~AsyncServer(){
Expand All @@ -1486,6 +1513,9 @@ void AsyncServer::begin(){
if(_pcb) {
return;
}
#if ASYNC_TCP_SSL_ENABLED
_secure = false;
#endif // ASYNC_TCP_SSL_ENABLED

if(!_start_async_task()){
log_e("failed to start task");
Expand Down Expand Up @@ -1515,12 +1545,18 @@ void AsyncServer::begin(){
log_e("listen_pcb == NULL");
return;
}
log_d("pcb 0x%08x", _pcb);
tcp_arg(_pcb, (void*) this);
tcp_accept(_pcb, &_s_accept);
}

void AsyncServer::end(){
if(_pcb){
#if ASYNC_TCP_SSL_ENABLED
if(_secure){
tcp_ssl_free(_pcb);
}
#endif // ASYNC_TCP_SSL_ENABLED
tcp_arg(_pcb, NULL);
tcp_accept(_pcb, NULL);
if(tcp_close(_pcb) != ERR_OK){
Expand All @@ -1532,10 +1568,15 @@ void AsyncServer::end(){

//runs on LwIP thread
int8_t AsyncServer::_accept(tcp_pcb* pcb, int8_t err){
log_d("pcb 0x%08x", pcb);
//ets_printf("+A: 0x%08x\n", pcb);
if(_connect_cb){
if(_connect_cb){
#if ASYNC_TCP_SSL_ENABLED
AsyncClient *c = new AsyncClient(pcb, _pcb);
#else
AsyncClient *c = new AsyncClient(pcb);
if(c){
#endif
if (c) {
c->setNoDelay(_noDelay);
return _tcp_accept(this, c);
}
Expand All @@ -1548,6 +1589,13 @@ int8_t AsyncServer::_accept(tcp_pcb* pcb, int8_t err){
}

int8_t AsyncServer::_accepted(AsyncClient* client){
#if ASYNC_TCP_SSL_ENABLED
if (_secure) {
client->onConnect([this](void * arg, AsyncClient *c) {
_connect_cb(_connect_cb_arg, c);
}, this);
} else
#endif
if(_connect_cb){
_connect_cb(_connect_cb_arg, client);
}
Expand Down Expand Up @@ -1576,3 +1624,49 @@ int8_t AsyncServer::_s_accept(void * arg, tcp_pcb * pcb, int8_t err){
int8_t AsyncServer::_s_accepted(void *arg, AsyncClient* client){
return reinterpret_cast<AsyncServer*>(arg)->_accepted(client);
}

#if ASYNC_TCP_SSL_ENABLED
void AsyncServer::beginSecure(const char *cert, const char *private_key_file, const char *password) {
if(_pcb) {
return;
}
_secure = true;

if(!_start_async_task()){
log_e("failed to start task");
return;
}
int8_t err;
_pcb = tcp_new_ip_type(IPADDR_TYPE_V4);
if (!_pcb){
log_e("_pcb == NULL");
return;
}

ip_addr_t local_addr;
local_addr.type = IPADDR_TYPE_V4;
local_addr.u_addr.ip4.addr = (uint32_t) _addr;
err = _tcp_bind(_pcb, &local_addr, _port);

if (err != ERR_OK) {
_tcp_close(_pcb, -1);
log_e("bind error: %d", err);
return;
}

static uint8_t backlog = 5;
_pcb = _tcp_listen_with_backlog(_pcb, backlog);
if (!_pcb) {
log_e("listen_pcb == NULL");
return;
}
log_d("pcb 0x%08x", _pcb);
if (tcp_ssl_new_server(_pcb, this, cert, strlen(cert) + 1, private_key_file, strlen(private_key_file) + 1, password) == 0) {
log_d("start accepting clients");
tcp_arg(_pcb, (void*) this);
tcp_accept(_pcb, &_s_accept);
} else {
end();
}
}
#endif // ASYNC_TCP_SSL_ENABLED
14 changes: 12 additions & 2 deletions src/AsyncTCP.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,17 @@ typedef std::function<void(void*, AsyncClient*, uint32_t time)> AcTimeoutHandler
struct tcp_pcb;
struct ip_addr;

class AsyncServer;

class AsyncClient {
public:
friend class AsyncServer;

#if ASYNC_TCP_SSL_ENABLED
AsyncClient(tcp_pcb* pcb = 0, tcp_pcb* server_pcb = 0);
#else
AsyncClient(tcp_pcb* pcb = 0);
#endif
~AsyncClient();

AsyncClient & operator=(const AsyncClient &other);
Expand Down Expand Up @@ -236,9 +244,8 @@ class AsyncServer {
~AsyncServer();
void onClient(AcConnectHandler cb, void* arg);
#if ASYNC_TCP_SSL_ENABLED
// Dummy, so it compiles with ESP Async WebServer library enabled.
void onSslFileRequest(AcSSlFileHandler cb, void* arg) {};
void beginSecure(const char *cert, const char *private_key_file, const char *password) {};
void beginSecure(const char *cert, const char *private_key_file, const char *password);
#endif
void begin();
void end();
Expand All @@ -257,6 +264,9 @@ class AsyncServer {
tcp_pcb* _pcb;
AcConnectHandler _connect_cb;
void* _connect_cb_arg;
#if ASYNC_TCP_SSL_ENABLED
bool _secure;
#endif

int8_t _accept(tcp_pcb* newpcb, int8_t err);
int8_t _accepted(AsyncClient* client);
Expand Down
Loading