0
0
Fork 0

fix(ws) concurrent access to websocket

This commit is contained in:
kerms 2024-06-22 14:28:48 +08:00
parent 8fba1208f3
commit 209f62866b
10 changed files with 219 additions and 53 deletions

View File

@ -0,0 +1,9 @@
file(GLOB SOURCES
*.c
)
idf_component_register(
SRCS ${SOURCES}
INCLUDE_DIRS "."
REQUIRES utils wt_common
)

View File

@ -39,12 +39,15 @@ void *memory_pool_get(uint32_t tick_wait)
{
void *ptr = NULL;
xQueueReceive(buf_queue, &ptr, tick_wait);
assert(ptr);
return ptr;
}
void memory_pool_put(void *ptr)
{
//printf("put buf %d\n", uxQueueMessagesWaiting(buf_queue));
#ifdef WT_DEBUG_MODE
printf("put buf %d\n", uxQueueMessagesWaiting(buf_queue));
#endif
if (unlikely(xQueueSend(buf_queue, &ptr, 0) != pdTRUE)) {
assert(0);
}

View File

@ -13,6 +13,7 @@
#include "memory_pool.h"
#include "request_runner.h"
#include "uart_tcp_bridge.h"
#include "global_module.h"
#include <assert.h>
@ -27,6 +28,9 @@ void app_main()
wifi_manager_init();
DAP_Setup();
global_module_init();
start_webserver();
xTaskCreate(tcp_server_task, "tcp_server", 4096, NULL, 14, NULL);

View File

@ -5,5 +5,6 @@ file(GLOB SOURCES *.c
idf_component_register(
SRCS ${SOURCES}
INCLUDE_DIRS "."
REQUIRES json request_runner
REQUIRES json request_runner wt_common
PRIV_REQUIRES memory_pool
)

View File

@ -8,6 +8,13 @@
typedef struct api_json_req_t {
cJSON *in;
cJSON *out;
union {
struct {
uint8_t big_buffer: 1;
uint8_t reserved: 7;
};
uint8_t out_flag;
};
} api_json_req_t;
typedef struct api_json_module_req_t {
@ -26,6 +33,9 @@ typedef enum api_json_req_status_e {
API_JSON_ASYNC = 1,
API_JSON_BAD_REQUEST = 2,
API_JSON_INTERNAL_ERR = 3,
API_JSON_UNSUPPORTED_CMD = 4,
API_JSON_PROPERTY_ERR = 5,
API_JSON_BUSY = 6,
} api_json_req_status_e;
typedef int (*api_json_on_req)(uint16_t cmd, api_json_req_t *req, api_json_module_async_t *rsp);
@ -41,8 +51,8 @@ void api_json_module_dump();
int api_json_module_add(api_json_init_func);
#define API_JSON_MODULE_REGISTER(PRI, INIT) \
__attribute__((used, constructor(PRI))) void cons_ ## INIT(); \
#define API_JSON_MODULE_REGISTER(INIT) \
__attribute__((used, constructor)) void cons_ ## INIT(); \
void cons_ ## INIT() { api_json_module_add(INIT); }
int api_json_module_call(uint8_t id, uint16_t cmd, api_json_req_t *in, api_json_module_async_t *out);

View File

@ -20,9 +20,6 @@ httpd_resp_send(req, (const char *)filename##_start, file_size);
static esp_err_t html_base_get_handler(httpd_req_t *req)
{
char *buf;
size_t buf_len;
/* this "hash" actually use the first 4 chars as an int32_t "hash" */
const int *URI_HASH = (const int *)req->uri;

View File

@ -16,13 +16,25 @@
#include <freertos/FreeRTOS.h>
#include <freertos/queue.h>
#include <lwipopts.h>
#include <lwip/netdb.h>
#define TAG __FILE_NAME__
#define MSG_BUSY_ERROR "{\"error\":\"Resource busy\"}"
#define MSG_JSON_ERROR "{\"error\":\"JSON parse error\"}"
#define MSG_BAD_REQUEST_ERROR "{\"error\":\"Bad json request\"}"
#define MSG_SEND_JSON_ERROR "{\"error\":\"JSON generation error\"}"
#define MSG_INTERNAL_ERROR "{\"error\":\"Internal error\"}"
// Error message templates
#define ERROR_MSG_TEMPLATE(err_msg, err_code) "{\"error\":\"" err_msg "\", \"code\":" #err_code "}"
// Error messages
#define MSG_BAD_REQUEST_ERROR ERROR_MSG_TEMPLATE("Bad json request", 2)
#define MSG_JSON_ERROR ERROR_MSG_TEMPLATE("JSON parse error", 3)
#define MSG_SEND_JSON_ERROR ERROR_MSG_TEMPLATE("JSON generation error", 3)
#define MSG_INTERNAL_ERROR ERROR_MSG_TEMPLATE("Internal error", 3)
#define MSG_UNSUPPORTED_CMD ERROR_MSG_TEMPLATE("Unsupported cmd", 4)
#define MSG_PROPERTY_ERROR ERROR_MSG_TEMPLATE("Property error", 5)
#define MSG_BUSY_ERROR ERROR_MSG_TEMPLATE("Resource busy", 6)
#define GET_FD_IDX(fd) ((fd) - LWIP_SOCKET_OFFSET)
#define WS_MODULE_ID 3
@ -42,9 +54,17 @@ struct ws_ctx_t {
struct ws_client_info_t {
httpd_handle_t hd;
int fd; /* range 64 - max_socket ~ 64 */
} clients[CONFIG_LWIP_MAX_SOCKETS + 1];
} clients[CONFIG_LWIP_MAX_SOCKETS];
TaskHandle_t task_heartbeat;
/* valid <client_count> fd are stored at beginning of the array
* use GET_FD_IDX to get the index in clients */
uint8_t valid_fd[CONFIG_LWIP_MAX_SOCKETS];
int8_t client_count;
struct {
SemaphoreHandle_t mutex;
StaticSemaphore_t xMutexBuffer;
} lock[CONFIG_LWIP_MAX_SOCKETS];
} ws_ctx;
static int ws_on_text_data(httpd_req_t *req, ws_msg_t *ws_msg);
@ -52,6 +72,9 @@ static int ws_on_binary_data(httpd_req_t *req, ws_msg_t *ws_msg);
static int ws_on_socket_open(httpd_req_t *req);
static int ws_on_close(httpd_req_t *req, httpd_ws_frame_t *ws_pkt, void *msg);
/* send with lock per fd */
static inline int ws_send_frame_safe(httpd_handle_t hd, int fd, httpd_ws_frame_t *frame);
static void ws_async_resp(void *arg);
static void async_send_out_cb(void *arg, int module_status);
static void json_to_text(ws_msg_t *msg);
@ -69,10 +92,11 @@ static esp_err_t ws_req_handler(httpd_req_t *req)
if (unlikely(req->method == HTTP_GET)) {
return ws_on_socket_open(req);
}
#ifdef WT_DEBUG_MODE
ESP_LOGI(TAG, "ws_handler: httpd_handle_t=%p, sockfd=%d, client_info:%d, client_count: %d", req->handle,
httpd_req_to_sockfd(req), httpd_ws_get_fd_info(req->handle, httpd_req_to_sockfd(req)),
ws_ctx.client_count);
#endif
int err = ESP_OK;
httpd_ws_frame_t *ws_pkt;
@ -85,7 +109,7 @@ static esp_err_t ws_req_handler(httpd_req_t *req)
resp_pkt.len = strlen(MSG_BUSY_ERROR);
resp_pkt.payload = (uint8_t *)MSG_BUSY_ERROR;
resp_pkt.final = 1;
httpd_ws_send_frame_async(req->handle, httpd_req_to_sockfd(req), &resp_pkt);
ws_send_frame_safe(req->handle, httpd_req_to_sockfd(req), &resp_pkt);
goto end;
}
ws_pkt = &ws_msg->ws_pkt;
@ -97,17 +121,19 @@ static esp_err_t ws_req_handler(httpd_req_t *req)
ESP_LOGE(TAG, "ws recv len error");
return ws_on_close(req, ws_pkt, ws_msg);
}
#ifdef WT_DEBUG_MODE
ESP_LOGI(TAG, "frame len: %d, type: %d", ws_pkt->len, ws_pkt->type);
#endif
if (unlikely(ws_pkt->len > PAYLOAD_LEN)) {
ESP_LOGE(TAG, "frame len is too big");
return ws_on_close(req, ws_pkt, ws_msg);
}
ws_pkt->payload = ws_msg->payload;
switch (ws_pkt->type) {
case HTTPD_WS_TYPE_CONTINUE:
ESP_LOGE(TAG, "WS Continue not handled");
goto end;
case HTTPD_WS_TYPE_TEXT:
ws_pkt->payload = ws_msg->payload;
/* read incoming data */
err = httpd_ws_recv_frame(req, ws_pkt, ws_pkt->len);
if (unlikely(err != ESP_OK)) {
@ -118,11 +144,23 @@ static esp_err_t ws_req_handler(httpd_req_t *req)
case HTTPD_WS_TYPE_BINARY:
return ws_on_binary_data(req, ws_msg);
case HTTPD_WS_TYPE_CLOSE:
if ((err = httpd_ws_recv_frame(req, ws_pkt, 126)) != ESP_OK) {
ESP_LOGE(TAG, "Cannot receive the full CLOSE frame");
goto end;
}
return ws_on_close(req, ws_pkt, ws_msg);
case HTTPD_WS_TYPE_PING:
/* Now turn the frame to PONG */
ESP_LOGI(TAG, "PING received");
if ((err = httpd_ws_recv_frame(req, ws_pkt, 126)) != ESP_OK) {
ESP_LOGE(TAG, "Cannot receive the full PONG frame");
goto end;
}
ws_pkt->type = HTTPD_WS_TYPE_PONG;
err = httpd_ws_send_frame(req, ws_pkt);
err = ws_send_frame_safe(req->handle, httpd_req_to_sockfd(req), ws_pkt);
if (err) {
ESP_LOGE(TAG, "Cannot send PONG frame %s", esp_err_to_name(err));
}
goto end;
case HTTPD_WS_TYPE_PONG:
err = ESP_OK;
@ -159,6 +197,7 @@ int ws_on_text_data(httpd_req_t *req, ws_msg_t *ws_msg)
}
ws_msg->json.out = NULL;
ws_msg->json.out_flag = 0;
ret = api_json_route(&ws_msg->json, &ws_msg->async);
if (ret == API_JSON_ASYNC) {
ws_msg->hd = req->handle;
@ -175,6 +214,7 @@ int ws_on_text_data(httpd_req_t *req, ws_msg_t *ws_msg)
ws_set_err_msg(ws_pkt, ret);
goto end;
} else if (ws_msg->json.out == NULL) {
/* API exec ok */
goto end;
}
@ -184,25 +224,34 @@ int ws_on_text_data(httpd_req_t *req, ws_msg_t *ws_msg)
end:
cJSON_Delete(ws_msg->json.in);
put_buf:
httpd_ws_send_frame_async(req->handle, httpd_req_to_sockfd(req), ws_pkt);
ws_send_frame_safe(req->handle, httpd_req_to_sockfd(req), ws_pkt);
memory_pool_put(ws_msg);
return err;
}
void ws_set_err_msg(httpd_ws_frame_t *p_frame, api_json_req_status_e ret)
{
p_frame->final = 1;
switch (ret) {
case API_JSON_BAD_REQUEST:
p_frame->len = strlen(MSG_BAD_REQUEST_ERROR);
p_frame->payload = (uint8_t *)MSG_BAD_REQUEST_ERROR;
p_frame->final = 1;
break;
case API_JSON_INTERNAL_ERR:
p_frame->len = strlen(MSG_INTERNAL_ERROR);
p_frame->payload = (uint8_t *)MSG_INTERNAL_ERROR;
p_frame->final = 1;
case API_JSON_BUSY:
p_frame->len = strlen(MSG_BUSY_ERROR);
p_frame->payload = (uint8_t *)MSG_BUSY_ERROR;
break;
case API_JSON_UNSUPPORTED_CMD:
p_frame->len = strlen(MSG_UNSUPPORTED_CMD);
p_frame->payload = (uint8_t *)MSG_UNSUPPORTED_CMD;
break;
case API_JSON_PROPERTY_ERR:
p_frame->len = strlen(MSG_PROPERTY_ERROR);
p_frame->payload = (uint8_t *)MSG_PROPERTY_ERROR;
break;
default:
p_frame->len = strlen(MSG_INTERNAL_ERROR);
p_frame->payload = (uint8_t *)MSG_INTERNAL_ERROR;
return;
}
}
@ -235,7 +284,7 @@ static int ws_on_close(httpd_req_t *req, httpd_ws_frame_t *ws_pkt, void *msg)
ws_pkt->type = HTTPD_WS_TYPE_CLOSE;
ESP_LOGI(TAG, "ws %d closed", httpd_req_to_sockfd(req));
ws_rm_fd(httpd_req_to_sockfd(req));
int err = httpd_ws_send_frame(req, ws_pkt);
int err = httpd_ws_send_frame_async(req, httpd_req_to_sockfd(req), ws_pkt);
if (err) {
ESP_LOGE(TAG, "on close %s", esp_err_to_name(err));
}
@ -252,7 +301,7 @@ static void ws_async_resp(void *arg)
int err;
ESP_LOGI(TAG, "ws async fd : %d", fd);
err = httpd_ws_send_frame_async(hd, fd, &req->ws_pkt);
err = ws_send_frame_safe(hd, fd, &req->ws_pkt);
if (unlikely(err)) {
ESP_LOGE(TAG, "%s", esp_err_to_name(err));
}
@ -286,7 +335,7 @@ void json_to_text(ws_msg_t *ws_msg)
int err;
httpd_ws_frame_t *ws_pkt = &ws_msg->ws_pkt;
/* api function returns something, send it to http client */
err = !cJSON_PrintPreallocated(ws_msg->json.out, (char *)ws_msg->payload, PAYLOAD_LEN - 5, 0);
err = !cJSON_PrintPreallocated(ws_msg->json.out, (char *)ws_pkt->payload, PAYLOAD_LEN - 5, 0);
cJSON_Delete(ws_msg->json.out);
if (unlikely(err)) {
ws_pkt->len = strlen(MSG_SEND_JSON_ERROR);
@ -301,12 +350,14 @@ void json_to_text(ws_msg_t *ws_msg)
* */
static inline int8_t ws_add_fd(httpd_handle_t hd, int fd)
{
if (ws_ctx.client_count > CONFIG_LWIP_MAX_SOCKETS) {
if (ws_ctx.client_count >= CONFIG_LWIP_MAX_SOCKETS) {
return 1;
}
ws_ctx.clients[ws_ctx.client_count].fd = fd;
ws_ctx.clients[ws_ctx.client_count].hd = hd;
uint8_t idx = GET_FD_IDX(fd);
ws_ctx.clients[idx].fd = fd;
ws_ctx.clients[idx].hd = hd;
ws_ctx.valid_fd[ws_ctx.client_count] = fd;
ws_ctx.client_count++;
return 0;
}
@ -318,16 +369,25 @@ static inline int8_t ws_add_fd(httpd_handle_t hd, int fd)
static inline void ws_rm_fd(int fd)
{
for (int i = 0; i < ws_ctx.client_count; ++i) {
if (ws_ctx.clients[i].fd == fd) {
ws_ctx.clients[i].fd = ws_ctx.clients[ws_ctx.client_count - 1].fd;
ws_ctx.clients[i].hd = ws_ctx.clients[ws_ctx.client_count - 1].hd;
if (ws_ctx.valid_fd[i] != fd) {
continue;
}
ws_ctx.valid_fd[i] = ws_ctx.valid_fd[ws_ctx.client_count - 1];
ws_ctx.client_count--;
return;
}
}
static inline int ws_send_frame_safe(httpd_handle_t hd, int fd, httpd_ws_frame_t *frame)
{
int err;
xSemaphoreTake(ws_ctx.lock[GET_FD_IDX(fd)].mutex, portMAX_DELAY);
err = httpd_ws_send_frame_async(hd, fd, frame);
xSemaphoreGive(ws_ctx.lock[GET_FD_IDX(fd)].mutex);
return err;
}
static void send_heartbeat(void *arg)
static inline void ws_broadcast_heartbeat()
{
static httpd_ws_frame_t ws_pkt = {
.len = 0,
@ -335,23 +395,14 @@ static void send_heartbeat(void *arg)
.type = HTTPD_WS_TYPE_TEXT,
};
struct ws_client_info_t *client_info = arg;
int err;
err = httpd_ws_send_frame_async(client_info->hd, client_info->fd, &ws_pkt);
if (err) {
ws_rm_fd(client_info->fd);
httpd_sess_trigger_close(client_info->hd, client_info->fd);
ESP_LOGE(TAG, "hb send err: %s", esp_err_to_name(err));
}
}
static inline void ws_broadcast_heartbeat()
{
int err;
for (int i = 0; i < ws_ctx.client_count; ++i) {
err = httpd_queue_work(ws_ctx.clients[i].hd, send_heartbeat, &ws_ctx.clients[i]);
uint8_t idx = GET_FD_IDX(ws_ctx.valid_fd[i]);
err = ws_send_frame_safe(ws_ctx.clients[idx].hd, ws_ctx.clients[idx].fd, &ws_pkt);
if (err) {
ESP_LOGE(TAG, "hb queue work err: %s", esp_err_to_name(err));
ws_rm_fd(ws_ctx.clients[idx].fd);
httpd_sess_trigger_close(ws_ctx.clients[idx].hd, ws_ctx.clients[idx].fd);
ESP_LOGE(TAG, "hb send err: %s", esp_err_to_name(err));
}
}
}
@ -385,6 +436,11 @@ static int WS_REQ_INIT(const httpd_uri_t **uri_conf)
{
*uri_conf = &uri_api;
xTaskCreate(heartbeat_task, "hb task", 2048, NULL, 3, &ws_ctx.task_heartbeat);
ws_ctx.client_count = 0;
for (int i = 0; i < CONFIG_LWIP_MAX_SOCKETS; ++i) {
ws_ctx.lock[i].mutex = xSemaphoreCreateMutexStatic(&ws_ctx.lock[i].xMutexBuffer);
}
return 0;
}

View File

@ -130,6 +130,4 @@ static int wifi_api_json_init(api_json_module_cfg_t *cfg)
return 0;
}
API_JSON_MODULE_REGISTER(0x90, wifi_api_json_init)
API_JSON_MODULE_REGISTER(wifi_api_json_init)

View File

@ -0,0 +1,3 @@
idf_component_register(
INCLUDE_DIRS "."
)

View File

@ -0,0 +1,85 @@
/*
* SPDX-FileCopyrightText: 2024 kerms <kerms@niazo.org>
*
* SPDX-License-Identifier: Apache-2.0
*/
#ifndef WT_DATA_DEF_H_GUARD
#define WT_DATA_DEF_H_GUARD
#include <stdint.h>
typedef enum wt_data_type_t {
WT_DATA_RESERVED = 0x00,
WT_DATA_EVENT = 0x02,
/* broadcast data */
WT_DATA_RAW_BROADCAST = 0x10,
WT_DATA_CMD_BROADCAST = 0x11,
/* targeted data */
WT_DATA_ROUTE_HDR = 0x20,
WT_DATA_RAW = 0x21,
WT_DATA_CMD = 0x22,
WT_DATA_RESPONSE = 0x23,
/* standard protocols */
WT_DATA_PROTOBUF = 0x40,
WT_DATA_JSON = 0x41,
WT_DATA_MQTT = 0x42,
WT_USER_DATA_TYPE_BEGIN = 0xA0,
WT_USER_DATA_TYPE_END = 0xFE,
WT_DATA_TYPE_MAX = 0xFF,
} __attribute__((packed)) wt_data_type_t;
_Static_assert(sizeof(wt_data_type_t) == 1, "wt_data_type_t must be 1 byte");
typedef struct wt_bin_data_hdr_t {
wt_data_type_t data_type; /* type of the hdr+payload */
union {
/* when targeted message -> bin data handle */
struct {
uint8_t module_id; /* src when broadcast, else target module */
uint8_t sub_id; /* src when broadcast, else target sub_id */
};
/* not used, only for make the union == 3B */
struct {
uint8_t dummy1;
uint8_t dummy2;
uint8_t dummy3;
} dummy;
};
} wt_bin_data_hdr_t;
_Static_assert(sizeof(wt_bin_data_hdr_t) == 4, "wt_data_4B_hdr_t must be 4 byte");
typedef struct wt_bin_data_t {
wt_bin_data_hdr_t hdr;
uint8_t payload[0];
} wt_bin_data_t;
typedef struct wt_bin_data_internal_t {
struct {
uint64_t Dummy1;
uint64_t Dummy2;
} Dummy; /* 16 byte padding for httpd_ws_frame */
struct {
uint16_t data_len;
uint8_t src_module;
uint8_t src_sub_module;
};
wt_bin_data_t data;
} wt_bin_data_internal_t;
typedef union wt_port_info {
uint8_t name[32]; /* for ease identification */
struct { /* for socket */
uint32_t foreign_ip;
uint16_t local_port;
uint16_t foreign_port;
};
struct { /* for peripheral port number */
uint8_t periph_num;
};
} wt_port_info_t;
#endif //WT_DATA_DEF_H_GUARD