Line data Source code
1 : /**
2 : Copyright (c) 2024 Stappler LLC <admin@stappler.dev>
3 :
4 : Permission is hereby granted, free of charge, to any person obtaining a copy
5 : of this software and associated documentation files (the "Software"), to deal
6 : in the Software without restriction, including without limitation the rights
7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 : copies of the Software, and to permit persons to whom the Software is
9 : furnished to do so, subject to the following conditions:
10 :
11 : The above copyright notice and this permission notice shall be included in
12 : all copies or substantial portions of the Software.
13 :
14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 : THE SOFTWARE.
21 : **/
22 :
23 : #include "SPWebWebsocketManager.h"
24 : #include "SPWebWebsocketConnection.h"
25 : #include "SPWebRequestController.h"
26 : #include "SPWebRoot.h"
27 :
28 : namespace STAPPLER_VERSIONIZED stappler::web {
29 :
30 : constexpr auto WEBSOCKET_GUID = StringView("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
31 :
32 25 : String WebsocketManager::makeAcceptKey(StringView key) {
33 25 : auto digest = crypto::Sha1().update(key).update(WEBSOCKET_GUID).final();
34 :
35 50 : return base64::encode<Interface>(CoderSource(digest));
36 : }
37 :
38 25 : WebsocketManager::WebsocketManager(const Host &host) : _pool(getCurrentPool()), _host(host) { }
39 0 : WebsocketManager::~WebsocketManager() { }
40 :
41 0 : WebsocketHandler * WebsocketManager::onAccept(const Request &req, pool_t *) {
42 0 : return nullptr;
43 : }
44 :
45 0 : bool WebsocketManager::onBroadcast(const Value & val) {
46 0 : return false;
47 : }
48 :
49 25 : size_t WebsocketManager::size() const {
50 50 : return _count.load();
51 : }
52 :
53 0 : void WebsocketManager::receiveBroadcast(const Value &val) {
54 0 : if (onBroadcast(val)) {
55 0 : _mutex.lock();
56 0 : for (auto &it : _handlers) {
57 0 : it->receiveBroadcast(val);
58 : }
59 0 : _mutex.unlock();
60 : }
61 0 : }
62 :
63 25 : static void *WebsocketManager_thread(WebsocketHandler *h) {
64 : #if LINUX
65 25 : pthread_setname_np(pthread_self(), "WebSocketThread");
66 : #endif
67 :
68 25 : auto m = h->manager();
69 25 : perform([&] {
70 25 : m->run(h);
71 25 : }, h->connection()->getPool(), config::TAG_WEBSOCKET, h->connection());
72 :
73 25 : WebsocketConnection::destroy(h->connection());
74 :
75 25 : return NULL;
76 : }
77 :
78 25 : Status WebsocketManager::accept(Request &req) {
79 25 : auto version = req.getRequestHeader("sec-websocket-version");
80 25 : auto key = req.getRequestHeader("sec-websocket-key");
81 25 : auto decKey = base64::decode<Interface>(key);
82 25 : if (decKey.size() != 16 || version != "13") {
83 0 : req.setErrorHeader("Sec-WebSocket-Version", "13");
84 0 : return HTTP_BAD_REQUEST;
85 : }
86 :
87 25 : allocator_t *alloc = nullptr;
88 25 : pool_t *pool = nullptr;
89 :
90 : auto FailCleanup = [&] (Status code) SP_COVERAGE_TRIVIAL -> Status {
91 : if (pool) {
92 : pool::destroy(pool);
93 : }
94 :
95 : if (alloc) {
96 : allocator::destroy(alloc);
97 : }
98 :
99 : return code;
100 25 : };
101 :
102 25 : alloc = allocator::create();
103 25 : pool = pool::create(alloc, memory::PoolFlags::None);
104 :
105 25 : allocator::max_free_set(alloc, 20_MiB);
106 :
107 25 : auto handler = onAccept(req, pool);
108 25 : if (handler) {
109 25 : req.clearResponseHeaders();
110 25 : req.setResponseHeader("Upgrade", "websocket");
111 25 : req.setResponseHeader("Connection", "Upgrade");
112 25 : req.setResponseHeader("Sec-WebSocket-Accept", makeAcceptKey(key));
113 :
114 25 : if (auto conn = req.config()->convertToWebsocket(handler, alloc, pool)) {
115 25 : auto accessRole = req.getAccessRole();
116 :
117 25 : conn->setAccessRole(accessRole);
118 :
119 25 : handler->setConnection(conn);
120 25 : std::thread thread(WebsocketManager_thread, handler);
121 25 : thread.detach();
122 25 : return HTTP_OK;
123 25 : }
124 : }
125 0 : if (req.getInfo().status == HTTP_OK) {
126 0 : return FailCleanup(HTTP_BAD_REQUEST);
127 : }
128 0 : return FailCleanup(req.getInfo().status);
129 25 : }
130 :
131 25 : void WebsocketManager::run(WebsocketHandler *h) {
132 25 : auto c = h->connection();
133 25 : c->run(h, [&, this] {
134 25 : addHandler(h);
135 50 : }, [&, this] {
136 25 : removeHandler(h);
137 25 : });
138 25 : }
139 :
140 25 : void WebsocketManager::addHandler(WebsocketHandler * h) {
141 25 : _mutex.lock();
142 25 : _handlers.emplace_back(h);
143 25 : ++ _count;
144 25 : _mutex.unlock();
145 25 : }
146 :
147 25 : void WebsocketManager::removeHandler(WebsocketHandler * h) {
148 25 : _mutex.lock();
149 25 : auto it = _handlers.begin();
150 25 : while (it != _handlers.end() && *it != h) {
151 0 : ++ it;
152 : }
153 25 : if (it != _handlers.end()) {
154 25 : _handlers.erase(it);
155 : }
156 25 : -- _count;
157 25 : _mutex.unlock();
158 25 : }
159 :
160 25 : WebsocketHandler::WebsocketHandler(WebsocketManager *m, pool_t *p, StringView url, TimeInterval ttl, size_t max)
161 25 : : _pool(p), _manager(m), _url(url.pdup(_pool)), _ttl(ttl), _maxInputFrameSize(max), _broadcastMutex() { }
162 :
163 0 : WebsocketHandler::~WebsocketHandler() { }
164 :
165 : // Data frame was received from network
166 0 : bool WebsocketHandler::handleFrame(WebsocketFrameType, const Bytes &) { return true; }
167 :
168 : // Message was received from broadcast
169 0 : bool WebsocketHandler::handleMessage(const Value &) { return true; }
170 :
171 50 : void WebsocketHandler::sendBroadcast(Value &&val) const {
172 : Value bcast {
173 100 : std::make_pair("server", Value(_manager->host().getHostInfo().hostname)),
174 100 : std::make_pair("url", Value(_url)),
175 100 : std::make_pair("data", Value(std::move(val))),
176 350 : };
177 :
178 50 : performWithStorage([&] (const db::Transaction &t) {
179 50 : t.getAdapter().broadcast(bcast);
180 50 : });
181 50 : }
182 :
183 0 : void WebsocketHandler::setEncodeFormat(const data::EncodeFormat &fmt) {
184 0 : _format = fmt;
185 0 : }
186 :
187 2850 : bool WebsocketHandler::send(StringView str) {
188 2850 : return _conn->write(WebsocketFrameType::Text, (const uint8_t *)str.data(), str.size());
189 : }
190 0 : bool WebsocketHandler::send(BytesView bytes) {
191 0 : return _conn->write(WebsocketFrameType::Binary, bytes.data(), bytes.size());
192 : }
193 25 : bool WebsocketHandler::send(const Value &data) {
194 25 : if (_format.isTextual()) {
195 25 : StringStream stream;
196 25 : stream << _format << data;
197 25 : return send(StringView(stream.weak()));
198 25 : } else {
199 0 : return send(data::write(data, _format));
200 : }
201 : }
202 :
203 500 : void WebsocketHandler::performWithStorage(const Callback<void(const db::Transaction &)> &cb) const {
204 500 : _manager->host().performWithStorage(cb);
205 500 : }
206 :
207 0 : bool WebsocketHandler::performAsync(const Callback<void(AsyncTask &)> &cb) const {
208 0 : return _conn->performAsync(cb);
209 : }
210 :
211 0 : pool_t *WebsocketHandler::pool() const {
212 0 : return _conn->getHandlePool();
213 : }
214 :
215 25 : void WebsocketHandler::setConnection(WebsocketConnection *c) {
216 25 : _conn = c;
217 25 : }
218 :
219 0 : void WebsocketHandler::receiveBroadcast(const Value &data) {
220 0 : if (_conn->isEnabled()) {
221 0 : _broadcastMutex.lock();
222 0 : if (!_broadcastsPool) {
223 0 : _broadcastsPool = memory::pool::create(_pool);
224 : }
225 0 : if (_broadcastsPool) {
226 0 : perform([&, this] {
227 0 : if (!_broadcastsMessages) {
228 0 : _broadcastsMessages = new (_broadcastsPool) Vector<Value>(_broadcastsPool);
229 : }
230 :
231 0 : _broadcastsMessages->emplace_back(data);
232 0 : }, _broadcastsPool, config::TAG_WEBSOCKET, _conn);
233 : }
234 0 : _broadcastMutex.unlock();
235 0 : _conn->wakeup();
236 : }
237 0 : }
238 :
239 0 : bool WebsocketHandler::processBroadcasts() {
240 : pool_t *pool;
241 : Vector<Value> * vec;
242 :
243 0 : _broadcastMutex.lock();
244 :
245 0 : pool = _broadcastsPool;
246 0 : vec = _broadcastsMessages;
247 :
248 0 : _broadcastsPool = nullptr;
249 0 : _broadcastsMessages = nullptr;
250 :
251 0 : _broadcastMutex.unlock();
252 :
253 0 : bool ret = true;
254 0 : if (pool) {
255 0 : perform([&, this] {
256 0 : sendPendingNotifications(pool);
257 0 : if (vec) {
258 0 : for (auto & it : (*vec)) {
259 0 : if (!handleMessage(it)) {
260 0 : ret = false;
261 0 : break;
262 : }
263 : }
264 : }
265 0 : }, pool, config::TAG_WEBSOCKET, _conn);
266 0 : pool::destroy(pool);
267 : }
268 :
269 0 : return ret;
270 : }
271 :
272 249 : void WebsocketHandler::sendPendingNotifications(pool_t *pool) {
273 1245 : perform([&, this] {
274 249 : _manager->host().getRoot()->setErrorNotification(pool, [this] (Value && data) {
275 0 : send(Value({
276 0 : std::make_pair("error", Value(std::move(data)))
277 0 : }));
278 0 : }, [this] (Value && data) {
279 0 : send(Value({
280 0 : std::make_pair("debug", Value(std::move(data)))
281 0 : }));
282 0 : });
283 498 : }, pool, config::TAG_WEBSOCKET, _conn);
284 249 : }
285 :
286 : }
|