LCOV - code coverage report
Current view: top level - extra/webserver/webserver/websocket - SPWebWebsocketManager.cc (source / functions)
Test: coverage.info            Lines:      93 hit of 165 total (56.4 %)
Date: 2024-05-12 00:16:13      Functions:  20 hit of 39 total (51.3 %)

          Line data    Source code
       1             : /**
       2             :  Copyright (c) 2024 Stappler LLC <admin@stappler.dev>
       3             : 
       4             :  Permission is hereby granted, free of charge, to any person obtaining a copy
       5             :  of this software and associated documentation files (the "Software"), to deal
       6             :  in the Software without restriction, including without limitation the rights
       7             :  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       8             :  copies of the Software, and to permit persons to whom the Software is
       9             :  furnished to do so, subject to the following conditions:
      10             : 
      11             :  The above copyright notice and this permission notice shall be included in
      12             :  all copies or substantial portions of the Software.
      13             : 
      14             :  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      15             :  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      16             :  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      17             :  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      18             :  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      19             :  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      20             :  THE SOFTWARE.
      21             :  **/
      22             : 
#include "SPWebWebsocketManager.h"

#include <algorithm>
#include <mutex>

#include "SPWebWebsocketConnection.h"
#include "SPWebRequestController.h"
#include "SPWebRoot.h"
      27             : 
      28             : namespace STAPPLER_VERSIONIZED stappler::web {
      29             : 
      30             : constexpr auto WEBSOCKET_GUID = StringView("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
      31             : 
      32          25 : String WebsocketManager::makeAcceptKey(StringView key) {
      33          25 :         auto digest = crypto::Sha1().update(key).update(WEBSOCKET_GUID).final();
      34             : 
      35          50 :         return base64::encode<Interface>(CoderSource(digest));
      36             : }
      37             : 
      38          25 : WebsocketManager::WebsocketManager(const Host &host) : _pool(getCurrentPool()), _host(host) { }
      39           0 : WebsocketManager::~WebsocketManager() { }
      40             : 
      41           0 : WebsocketHandler * WebsocketManager::onAccept(const Request &req, pool_t *) {
      42           0 :         return nullptr;
      43             : }
      44             : 
      45           0 : bool WebsocketManager::onBroadcast(const Value & val) {
      46           0 :         return false;
      47             : }
      48             : 
      49          25 : size_t WebsocketManager::size() const {
      50          50 :         return _count.load();
      51             : }
      52             : 
      53           0 : void WebsocketManager::receiveBroadcast(const Value &val) {
      54           0 :         if (onBroadcast(val)) {
      55           0 :                 _mutex.lock();
      56           0 :                 for (auto &it : _handlers) {
      57           0 :                         it->receiveBroadcast(val);
      58             :                 }
      59           0 :                 _mutex.unlock();
      60             :         }
      61           0 : }
      62             : 
      63          25 : static void *WebsocketManager_thread(WebsocketHandler *h) {
      64             : #if LINUX
      65          25 :         pthread_setname_np(pthread_self(), "WebSocketThread");
      66             : #endif
      67             : 
      68          25 :         auto m = h->manager();
      69          25 :         perform([&] {
      70          25 :                 m->run(h);
      71          25 :         }, h->connection()->getPool(), config::TAG_WEBSOCKET, h->connection());
      72             : 
      73          25 :         WebsocketConnection::destroy(h->connection());
      74             : 
      75          25 :         return NULL;
      76             : }
      77             : 
      78          25 : Status WebsocketManager::accept(Request &req) {
      79          25 :         auto version = req.getRequestHeader("sec-websocket-version");
      80          25 :         auto key = req.getRequestHeader("sec-websocket-key");
      81          25 :         auto decKey = base64::decode<Interface>(key);
      82          25 :         if (decKey.size() != 16 || version != "13") {
      83           0 :                 req.setErrorHeader("Sec-WebSocket-Version", "13");
      84           0 :                 return HTTP_BAD_REQUEST;
      85             :         }
      86             : 
      87          25 :         allocator_t *alloc = nullptr;
      88          25 :         pool_t *pool = nullptr;
      89             : 
      90             :         auto FailCleanup = [&] (Status code) SP_COVERAGE_TRIVIAL -> Status {
      91             :                 if (pool) {
      92             :                         pool::destroy(pool);
      93             :                 }
      94             : 
      95             :                 if (alloc) {
      96             :                         allocator::destroy(alloc);
      97             :                 }
      98             : 
      99             :                 return code;
     100          25 :         };
     101             : 
     102          25 :         alloc = allocator::create();
     103          25 :         pool = pool::create(alloc, memory::PoolFlags::None);
     104             : 
     105          25 :         allocator::max_free_set(alloc, 20_MiB);
     106             : 
     107          25 :         auto handler = onAccept(req, pool);
     108          25 :         if (handler) {
     109          25 :                 req.clearResponseHeaders();
     110          25 :                 req.setResponseHeader("Upgrade", "websocket");
     111          25 :                 req.setResponseHeader("Connection", "Upgrade");
     112          25 :                 req.setResponseHeader("Sec-WebSocket-Accept", makeAcceptKey(key));
     113             : 
     114          25 :                 if (auto conn = req.config()->convertToWebsocket(handler, alloc, pool)) {
     115          25 :                         auto accessRole = req.getAccessRole();
     116             : 
     117          25 :                         conn->setAccessRole(accessRole);
     118             : 
     119          25 :                         handler->setConnection(conn);
     120          25 :                         std::thread thread(WebsocketManager_thread, handler);
     121          25 :                         thread.detach();
     122          25 :                         return HTTP_OK;
     123          25 :                 }
     124             :         }
     125           0 :         if (req.getInfo().status == HTTP_OK) {
     126           0 :                 return FailCleanup(HTTP_BAD_REQUEST);
     127             :         }
     128           0 :         return FailCleanup(req.getInfo().status);
     129          25 : }
     130             : 
     131          25 : void WebsocketManager::run(WebsocketHandler *h) {
     132          25 :         auto c = h->connection();
     133          25 :         c->run(h, [&, this] {
     134          25 :                 addHandler(h);
     135          50 :         }, [&, this] {
     136          25 :                 removeHandler(h);
     137          25 :         });
     138          25 : }
     139             : 
     140          25 : void WebsocketManager::addHandler(WebsocketHandler * h) {
     141          25 :         _mutex.lock();
     142          25 :         _handlers.emplace_back(h);
     143          25 :         ++ _count;
     144          25 :         _mutex.unlock();
     145          25 : }
     146             : 
     147          25 : void WebsocketManager::removeHandler(WebsocketHandler * h) {
     148          25 :         _mutex.lock();
     149          25 :         auto it = _handlers.begin();
     150          25 :         while (it != _handlers.end() && *it != h) {
     151           0 :                 ++ it;
     152             :         }
     153          25 :         if (it != _handlers.end()) {
     154          25 :                 _handlers.erase(it);
     155             :         }
     156          25 :         -- _count;
     157          25 :         _mutex.unlock();
     158          25 : }
     159             : 
     160          25 : WebsocketHandler::WebsocketHandler(WebsocketManager *m, pool_t *p, StringView url, TimeInterval ttl, size_t max)
     161          25 : : _pool(p), _manager(m), _url(url.pdup(_pool)), _ttl(ttl), _maxInputFrameSize(max), _broadcastMutex() { }
     162             : 
     163           0 : WebsocketHandler::~WebsocketHandler() { }
     164             : 
     165             : // Data frame was received from network
     166           0 : bool WebsocketHandler::handleFrame(WebsocketFrameType, const Bytes &) { return true; }
     167             : 
     168             : // Message was received from broadcast
     169           0 : bool WebsocketHandler::handleMessage(const Value &) { return true; }
     170             : 
     171          50 : void WebsocketHandler::sendBroadcast(Value &&val) const {
     172             :         Value bcast {
     173         100 :                 std::make_pair("server", Value(_manager->host().getHostInfo().hostname)),
     174         100 :                 std::make_pair("url", Value(_url)),
     175         100 :                 std::make_pair("data", Value(std::move(val))),
     176         350 :         };
     177             : 
     178          50 :         performWithStorage([&] (const db::Transaction &t) {
     179          50 :                 t.getAdapter().broadcast(bcast);
     180          50 :         });
     181          50 : }
     182             : 
     183           0 : void WebsocketHandler::setEncodeFormat(const data::EncodeFormat &fmt) {
     184           0 :         _format = fmt;
     185           0 : }
     186             : 
     187        2850 : bool WebsocketHandler::send(StringView str) {
     188        2850 :         return _conn->write(WebsocketFrameType::Text, (const uint8_t *)str.data(), str.size());
     189             : }
     190           0 : bool WebsocketHandler::send(BytesView bytes) {
     191           0 :         return _conn->write(WebsocketFrameType::Binary, bytes.data(), bytes.size());
     192             : }
     193          25 : bool WebsocketHandler::send(const Value &data) {
     194          25 :         if (_format.isTextual()) {
     195          25 :                 StringStream stream;
     196          25 :                 stream << _format << data;
     197          25 :                 return send(StringView(stream.weak()));
     198          25 :         } else {
     199           0 :                 return send(data::write(data, _format));
     200             :         }
     201             : }
     202             : 
     203         500 : void WebsocketHandler::performWithStorage(const Callback<void(const db::Transaction &)> &cb) const {
     204         500 :         _manager->host().performWithStorage(cb);
     205         500 : }
     206             : 
     207           0 : bool WebsocketHandler::performAsync(const Callback<void(AsyncTask &)> &cb) const {
     208           0 :         return _conn->performAsync(cb);
     209             : }
     210             : 
     211           0 : pool_t *WebsocketHandler::pool() const {
     212           0 :         return _conn->getHandlePool();
     213             : }
     214             : 
     215          25 : void WebsocketHandler::setConnection(WebsocketConnection *c) {
     216          25 :         _conn = c;
     217          25 : }
     218             : 
     219           0 : void WebsocketHandler::receiveBroadcast(const Value &data) {
     220           0 :         if (_conn->isEnabled()) {
     221           0 :                 _broadcastMutex.lock();
     222           0 :                 if (!_broadcastsPool) {
     223           0 :                         _broadcastsPool = memory::pool::create(_pool);
     224             :                 }
     225           0 :                 if (_broadcastsPool) {
     226           0 :                         perform([&, this] {
     227           0 :                                 if (!_broadcastsMessages) {
     228           0 :                                         _broadcastsMessages = new (_broadcastsPool) Vector<Value>(_broadcastsPool);
     229             :                                 }
     230             : 
     231           0 :                                 _broadcastsMessages->emplace_back(data);
     232           0 :                         }, _broadcastsPool, config::TAG_WEBSOCKET, _conn);
     233             :                 }
     234           0 :                 _broadcastMutex.unlock();
     235           0 :                 _conn->wakeup();
     236             :         }
     237           0 : }
     238             : 
     239           0 : bool WebsocketHandler::processBroadcasts() {
     240             :         pool_t *pool;
     241             :         Vector<Value> * vec;
     242             : 
     243           0 :         _broadcastMutex.lock();
     244             : 
     245           0 :         pool = _broadcastsPool;
     246           0 :         vec = _broadcastsMessages;
     247             : 
     248           0 :         _broadcastsPool = nullptr;
     249           0 :         _broadcastsMessages = nullptr;
     250             : 
     251           0 :         _broadcastMutex.unlock();
     252             : 
     253           0 :         bool ret = true;
     254           0 :         if (pool) {
     255           0 :                 perform([&, this] {
     256           0 :                         sendPendingNotifications(pool);
     257           0 :                         if (vec) {
     258           0 :                                 for (auto & it : (*vec)) {
     259           0 :                                         if (!handleMessage(it)) {
     260           0 :                                                 ret = false;
     261           0 :                                                 break;
     262             :                                         }
     263             :                                 }
     264             :                         }
     265           0 :                 }, pool, config::TAG_WEBSOCKET, _conn);
     266           0 :                 pool::destroy(pool);
     267             :         }
     268             : 
     269           0 :         return ret;
     270             : }
     271             : 
     272         249 : void WebsocketHandler::sendPendingNotifications(pool_t *pool) {
     273        1245 :         perform([&, this] {
     274         249 :                 _manager->host().getRoot()->setErrorNotification(pool, [this] (Value && data) {
     275           0 :                         send(Value({
     276           0 :                                 std::make_pair("error", Value(std::move(data)))
     277           0 :                         }));
     278           0 :                 }, [this] (Value && data) {
     279           0 :                         send(Value({
     280           0 :                                 std::make_pair("debug", Value(std::move(data)))
     281           0 :                         }));
     282           0 :                 });
     283         498 :         }, pool, config::TAG_WEBSOCKET, _conn);
     284         249 : }
     285             : 
     286             : }

Generated by: LCOV version 1.14