agent-enviroments/builder/libs/seastar/tests/unit/loopback_socket.hh
2024-09-10 17:06:08 +03:00

350 lines
14 KiB
C++

/*
* This file is open source software, licensed to you under the terms
* of the Apache License, Version 2.0 (the "License"). See the NOTICE file
* distributed with this work for additional information regarding copyright
* ownership. You may not use this file except in compliance with the License.
*
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* Copyright (C) 2016 ScyllaDB
*/
#pragma once
#include <system_error>
#include <seastar/core/iostream.hh>
#include <seastar/core/circular_buffer.hh>
#include <seastar/core/shared_ptr.hh>
#include <seastar/core/queue.hh>
#include <seastar/core/loop.hh>
#include <seastar/core/sleep.hh>
#include <seastar/core/do_with.hh>
#include <seastar/net/stack.hh>
#include <seastar/core/sharded.hh>
namespace seastar {
/// Test hook for injecting faults into loopback connections.
///
/// Every query returns "no fault" by default; a test overrides only the hooks it
/// cares about. The send/receive hooks are consulted per operation by
/// loopback_buffer, connect_error()/connect_delay() by loopback_socket_impl.
struct loopback_error_injector {
    /// Kind of fault to inject: none, fail a single operation, or fail and abort the buffer.
    enum class error { none, one_shot, abort };

    virtual ~loopback_error_injector() = default;

    // Receive/send faults on the server-transmitted and client-transmitted directions.
    virtual error server_rcv_error() { return error::none; }
    virtual error server_snd_error() { return error::none; }
    virtual error client_rcv_error() { return error::none; }
    virtual error client_snd_error() { return error::none; }

    // Any value other than error::none makes connect() wait until shutdown() is called.
    virtual error connect_error() { return error::none; }

    // Extra delay applied before the client side of a connection is handed out.
    virtual std::chrono::microseconds connect_delay() { return std::chrono::microseconds::zero(); }
};
class loopback_buffer {
public:
enum class type : uint8_t {
CLIENT_TX,
SERVER_TX
};
private:
bool _aborted = false;
queue<temporary_buffer<char>> _q{1};
loopback_error_injector* _error_injector;
type _type;
std::optional<promise<>> _shutdown;
public:
loopback_buffer(loopback_error_injector* error_injection, type t) : _error_injector(error_injection), _type(t) {}
future<> push(temporary_buffer<char>&& b) {
if (_aborted) {
return make_exception_future<>(std::system_error(EPIPE, std::system_category()));
}
if (_error_injector) {
auto error = _type == type::CLIENT_TX ? _error_injector->client_snd_error() : _error_injector->server_snd_error();
if (error == loopback_error_injector::error::one_shot) {
return make_exception_future<>(std::runtime_error("test injected glitch on send"));
}
if (error == loopback_error_injector::error::abort) {
abort();
return make_exception_future<>(std::runtime_error("test injected error on send"));
}
}
return _q.push_eventually(std::move(b));
}
future<temporary_buffer<char>> pop() {
if (_aborted) {
return make_exception_future<temporary_buffer<char>>(std::system_error(EPIPE, std::system_category()));
}
if (_error_injector) {
auto error = _type == type::CLIENT_TX ? _error_injector->client_rcv_error() : _error_injector->server_rcv_error();
if (error == loopback_error_injector::error::one_shot) {
return make_exception_future<temporary_buffer<char>>(std::runtime_error("test injected glitch on receive"));
}
if (error == loopback_error_injector::error::abort) {
abort();
return make_exception_future<temporary_buffer<char>>(std::runtime_error("test injected error on receive"));
}
}
return _q.pop_eventually();
}
void abort() noexcept {
shutdown();
_aborted = true;
_q.abort(std::make_exception_ptr(std::system_error(EPIPE, std::system_category())));
}
void shutdown() noexcept {
// it can be called by both -- reader and writer socket impls
if (_shutdown.has_value()) {
_shutdown->set_value();
_shutdown.reset();
}
}
future<> wait_input_shutdown() {
assert(!_shutdown.has_value());
_shutdown.emplace();
return _shutdown->get_future();
}
};
/// Write side of a loopback connection: forwards every packet fragment to the
/// loopback_buffer that lives on the peer's shard.
class loopback_data_sink_impl : public data_sink_impl {
    lw_shared_ptr<foreign_ptr<lw_shared_ptr<loopback_buffer>>> _buffer;
    noncopyable_function<void()> _batch_flush_error;
public:
    explicit loopback_data_sink_impl(lw_shared_ptr<foreign_ptr<lw_shared_ptr<loopback_buffer>>> buffer, noncopyable_function<void()> flush_error)
        : _buffer(std::move(buffer))
        , _batch_flush_error(std::move(flush_error))
    {
    }

    /// Pushes each fragment of `data`, in order, onto the remote buffer.
    future<> put(net::packet data) override {
        return do_with(data.release(), [this] (std::vector<temporary_buffer<char>>& fragments) {
            return do_for_each(fragments, [this] (temporary_buffer<char>& fragment) {
                const auto owner = _buffer->get_owner_shard();
                return smp::submit_to(owner, [this, ptr = fragment.get(), len = fragment.size()] {
                    // Build the buffer from (pointer, size) on the owner shard.
                    return (*_buffer)->push(temporary_buffer<char>(ptr, len));
                });
            });
        });
    }

    /// Sends an empty buffer as the end-of-stream marker; EPIPE (peer already gone) is ignored.
    future<> close() override {
        const auto owner = _buffer->get_owner_shard();
        return smp::submit_to(owner, [this] {
            return (*_buffer)->push(temporary_buffer<char>()).handle_exception_type([] (std::system_error& e) {
                if (e.code().value() != EPIPE) {
                    throw e;
                }
            });
        });
    }

    bool can_batch_flushes() const noexcept override { return true; }

    void on_batch_flush_error() noexcept override { _batch_flush_error(); }
};
/// Read side of a loopback connection: pops buffers from a local loopback_buffer
/// and records whether the stream ended (empty buffer) or failed.
class loopback_data_source_impl : public data_source_impl {
    bool _eof = false;
    lw_shared_ptr<loopback_buffer> _buffer;
public:
    explicit loopback_data_source_impl(lw_shared_ptr<loopback_buffer> buffer)
        : _buffer(std::move(buffer)) {
    }

    /// Returns the next buffer; a failed pop or an empty buffer marks the stream as finished.
    future<temporary_buffer<char>> get() override {
        return _buffer->pop().then_wrapped([this] (future<temporary_buffer<char>>&& f) {
            if (f.failed()) {
                _eof = true;
                return std::move(f);
            }
            // future::get() consumes the value, so re-wrap it in a fresh ready future.
            auto chunk = f.get();
            _eof = chunk.empty();
            return make_ready_future<temporary_buffer<char>>(std::move(chunk));
        });
    }

    /// After a clean EOF only signal shutdown; otherwise abort the buffer so the writer sees EPIPE.
    future<> close() override {
        if (_eof) {
            _buffer->shutdown();
        } else {
            _buffer->abort();
        }
        return make_ready_future<>();
    }
};
/// connected_socket implementation over a pair of loopback buffers:
/// `_tx` lives on the peer's shard (we write into it), `_rx` is local (we read from it).
class loopback_connected_socket_impl : public net::connected_socket_impl {
    lw_shared_ptr<foreign_ptr<lw_shared_ptr<loopback_buffer>>> _tx;
    lw_shared_ptr<loopback_buffer> _rx;
public:
    loopback_connected_socket_impl(foreign_ptr<lw_shared_ptr<loopback_buffer>> tx, lw_shared_ptr<loopback_buffer> rx)
        : _tx(make_lw_shared(std::move(tx))), _rx(std::move(rx)) {
    }

    data_source source() override {
        return data_source(std::make_unique<loopback_data_source_impl>(_rx));
    }

    data_sink sink() override {
        // The sink holds its own references to the buffers and may outlive this
        // socket impl, so its flush-error callback must not capture `this`.
        // Aborting rx is exactly what shutdown_input() does.
        return data_sink(std::make_unique<loopback_data_sink_impl>(_tx, [rx = _rx] { rx->abort(); }));
    }

    void shutdown_input() override {
        _rx->abort();
    }

    /// Aborts the peer's buffer on its owner shard (fire-and-forget).
    void shutdown_output() override {
        (void)smp::submit_to(_tx->get_owner_shard(), [tx = _tx] {
            (*tx)->abort();
        });
    }

    // Socket options are not meaningful for loopback; report fixed values.
    void set_nodelay(bool nodelay) override {
    }
    bool get_nodelay() const override {
        return true;
    }
    void set_keepalive(bool keepalive) override {}
    bool get_keepalive() const override {
        return false;
    }
    void set_keepalive_parameters(const net::keepalive_params&) override {}
    net::keepalive_params get_keepalive_parameters() const override {
        return net::tcp_keepalive_params {std::chrono::seconds(0), std::chrono::seconds(0), 0};
    }
    void set_sockopt(int level, int optname, const void* data, size_t len) override {
        throw std::runtime_error("Setting custom socket options is not supported for loopback");
    }
    int get_sockopt(int level, int optname, void* data, size_t len) const override {
        throw std::runtime_error("Getting custom socket options is not supported for loopback");
    }

    socket_address local_address() const noexcept override {
        // dummy
        return {};
    }
    socket_address remote_address() const noexcept override {
        // dummy
        return {};
    }

    /// Resolves when the local receive buffer is shut down or aborted.
    future<> wait_input_shutdown() override {
        return _rx->wait_input_shutdown();
    }
};
/// Server socket that accepts connections from a per-shard pending queue filled
/// by loopback_connection_factory.
class loopback_server_socket_impl : public net::server_socket_impl {
    lw_shared_ptr<queue<connected_socket>> _pending;
public:
    explicit loopback_server_socket_impl(lw_shared_ptr<queue<connected_socket>> q)
        : _pending(std::move(q)) {
    }

    /// Waits for the next queued connection; the peer address is a default (dummy) value.
    future<accept_result> accept() override {
        auto next = _pending->pop_eventually();
        return next.then([] (connected_socket&& cs) {
            return accept_result{std::move(cs), socket_address()};
        });
    }

    /// Fails current and future accept() calls with ECONNABORTED.
    void abort_accept() override {
        auto reason = std::make_exception_ptr(std::system_error(ECONNABORTED, std::system_category()));
        _pending->abort(std::move(reason));
    }

    socket_address local_address() const override {
        // dummy
        return socket_address();
    }
};
class loopback_connection_factory {
unsigned _shard = 0;
unsigned _shards_count;
std::vector<lw_shared_ptr<queue<connected_socket>>> _pending;
public:
explicit loopback_connection_factory(unsigned shards_count = smp::count)
: _shards_count(shards_count)
{
_pending.resize(shards_count);
}
server_socket get_server_socket() {
assert(this_shard_id() < _shards_count);
if (!_pending[this_shard_id()]) {
_pending[this_shard_id()] = make_lw_shared<queue<connected_socket>>(10);
}
return server_socket(std::make_unique<loopback_server_socket_impl>(_pending[this_shard_id()]));
}
future<> make_new_server_connection(foreign_ptr<lw_shared_ptr<loopback_buffer>> b1, lw_shared_ptr<loopback_buffer> b2) {
assert(this_shard_id() < _shards_count);
if (!_pending[this_shard_id()]) {
_pending[this_shard_id()] = make_lw_shared<queue<connected_socket>>(10);
}
return _pending[this_shard_id()]->push_eventually(connected_socket(std::make_unique<loopback_connected_socket_impl>(std::move(b1), b2)));
}
connected_socket make_new_client_connection(lw_shared_ptr<loopback_buffer> b1, foreign_ptr<lw_shared_ptr<loopback_buffer>> b2) {
return connected_socket(std::make_unique<loopback_connected_socket_impl>(std::move(b2), b1));
}
unsigned next_shard() {
return _shard++ % _shards_count;
}
void destroy_shard(unsigned shard) {
assert(shard < _shards_count);
_pending[shard] = nullptr;
}
future<> destroy_all_shards() {
return parallel_for_each(boost::irange(0u, _shards_count), [this](shard_id shard) {
return smp::submit_to(shard, [this] {
destroy_shard(this_shard_id());
});
});
}
};
/// Client-side loopback socket. connect() creates the buffer pair of a new
/// connection, registers the server half with the factory on a round-robin
/// shard and returns the client half.
class loopback_socket_impl : public net::socket_impl {
    loopback_connection_factory& _factory;
    loopback_error_injector* _error_injector;
    // SERVER_TX buffer created on this (client) shard; the client reads from it.
    lw_shared_ptr<loopback_buffer> _b1;
    // CLIENT_TX buffer created on the server's shard; the client writes to it.
    foreign_ptr<lw_shared_ptr<loopback_buffer>> _b2;
    // Engaged while an injected connect error keeps connect() pending until shutdown().
    std::optional<promise<connected_socket>> _connect_abort;
public:
    loopback_socket_impl(loopback_connection_factory& factory, loopback_error_injector* error_injector = nullptr)
        : _factory(factory), _error_injector(error_injector)
    { }

    future<connected_socket> connect(socket_address sa, socket_address local, seastar::transport proto = seastar::transport::TCP) override {
        if (_error_injector) {
            auto error = _error_injector->connect_error();
            if (error != loopback_error_injector::error::none) {
                // Never completes on its own; shutdown() fails it with ECONNABORTED.
                _connect_abort.emplace();
                return _connect_abort->get_future();
            }
        }
        auto shard = _factory.next_shard();
        _b1 = make_lw_shared<loopback_buffer>(_error_injector, loopback_buffer::type::SERVER_TX);
        return smp::submit_to(shard, [this, b1 = make_foreign(_b1)] () mutable {
            auto b2 = make_lw_shared<loopback_buffer>(_error_injector, loopback_buffer::type::CLIENT_TX);
            _b2 = make_foreign(b2);
            return _factory.make_new_server_connection(std::move(b1), b2).then([b2] {
                return make_foreign(b2);
            });
        }).then([this] (foreign_ptr<lw_shared_ptr<loopback_buffer>> b2) {
            if (_error_injector) {
                auto delay = _error_injector->connect_delay();
                if (delay != std::chrono::microseconds(0)) {
                    return seastar::sleep(delay).then([this, b2 = std::move(b2)] () mutable {
                        return _factory.make_new_client_connection(_b1, std::move(b2));
                    });
                }
            }
            return make_ready_future<connected_socket>(_factory.make_new_client_connection(_b1, std::move(b2)));
        });
    }

    void set_reuseaddr(bool reuseaddr) override {}
    bool get_reuseaddr() const override { return false; }

    /// Fails a pending injected-error connect() with ECONNABORTED; otherwise aborts
    /// both buffers of the connection.
    void shutdown() override {
        if (_connect_abort) {
            _connect_abort->set_exception(std::make_exception_ptr(std::system_error(ECONNABORTED, std::system_category())));
            _connect_abort = std::nullopt;
            return;
        }
        // The buffers may be missing: shutdown() before connect(), before the server
        // shard has published _b2, or a second shutdown() after _b2 was moved out.
        if (_b1) {
            _b1->abort();
        }
        if (_b2) {
            // Read the owner shard before moving _b2 into the lambda.
            const auto owner = _b2.get_owner_shard();
            (void)smp::submit_to(owner, [b2 = std::move(_b2)] {
                b2->abort();
            });
        }
    }
};
}