Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions editoast/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions editoast/core_client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mocking_client = ["dep:http"]
approx.workspace = true
chrono.workspace = true
common = { workspace = true }
dashmap.workspace = true
deadpool.workspace = true
editoast_derive.workspace = true
educe.workspace = true
Expand Down
21 changes: 13 additions & 8 deletions editoast/core_client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use serde::de::DeserializeOwned;
use std::collections::HashMap;
use std::fmt::Display;
use std::marker::PhantomData;
use std::time::Duration;
use thiserror::Error;
use tracing::trace;

Expand Down Expand Up @@ -63,6 +64,7 @@ impl CoreClient {
path: &str,
body: Option<&B>,
worker_id: Option<String>,
override_timeout: Option<Duration>,
) -> Result<R::Response, Error> {
trace!(
target: "editoast::coreclient",
Expand All @@ -76,7 +78,7 @@ impl CoreClient {
// TODO: tracing: use correlation id

let response = client
.call_with_response(worker_id, path, &body, true, None)
.call_with_response(worker_id, path, &body, true, override_timeout)
.await
.map_err(Error::MqClientError)?;

Expand Down Expand Up @@ -139,13 +141,11 @@ where
Self: Serialize + Sized + Sync,
R: CoreResponse,
{
/// A shorthand for [Self::url]
/// The URL path of this request
const URL_PATH: &'static str;

/// Returns the URL for this request, by default returns [Self::URL_PATH]
fn url(&self) -> &str {
Self::URL_PATH
}
/// An optional timeout override for this request
const OVERRIDE_TIMEOUT: Option<Duration> = None;

/// Returns the worker id used for the request. Must be provided.
fn worker_id(&self) -> Option<String>;
Expand All @@ -158,8 +158,13 @@ where
/// manage itself its expected errors. Maybe a bound error type defaulting
/// to CoreError and a trait function handle_errors would suffice?
async fn fetch(&self, core: &CoreClient) -> Result<R::Response, Error> {
core.fetch::<Self, R>(self.url(), Some(self), self.worker_id())
.await
core.fetch::<Self, R>(
Self::URL_PATH,
Some(self),
self.worker_id(),
Self::OVERRIDE_TIMEOUT,
)
.await
}
}

Expand Down
44 changes: 26 additions & 18 deletions editoast/core_client/src/mq_client.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use dashmap::DashMap;
use deadpool::managed::Manager;
use deadpool::managed::Metrics;
use deadpool::managed::Pool;
Expand All @@ -18,15 +19,13 @@ use lapin::types::FieldTable;
use lapin::types::ShortString;
use serde::Serialize;
use serde_json::to_vec;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::RwLock;
use tokio::sync::oneshot;
use tokio::task;
use tokio::time::Duration;
use tokio::time::timeout;
use tracing::Instrument;
use url::Url;
use uuid::Uuid;
Expand All @@ -43,13 +42,15 @@ pub struct RabbitMQClient {
pub struct ChannelManager {
connection: Arc<RwLock<Option<Connection>>>,
hostname: String,
response_tracker: Arc<DashMap<String, oneshot::Sender<Delivery>>>,
}

impl ChannelManager {
pub fn new(connection: Arc<RwLock<Option<Connection>>>, hostname: String) -> Self {
ChannelManager {
connection,
hostname,
response_tracker: Arc::new(DashMap::new()),
}
}
}
Expand All @@ -72,7 +73,12 @@ impl Manager for ChannelManager {
.await
.map_err(|_| ChannelManagerError::Lapin)?;

Ok(ChannelWorker::new(Arc::new(channel), self.hostname.clone()).await)
Ok(ChannelWorker::new(
Arc::new(channel),
self.hostname.clone(),
self.response_tracker.clone(),
)
.await)
} else {
Err(ChannelManagerError::ConnectionNotFound)
}
Expand All @@ -94,15 +100,19 @@ impl Manager for ChannelManager {
#[derive(Debug)]
pub struct ChannelWorker {
channel: Arc<Channel>,
response_tracker: Arc<RwLock<HashMap<String, oneshot::Sender<Delivery>>>>,
response_tracker: Arc<DashMap<String, oneshot::Sender<Delivery>>>,
consumer_tag: String,
}

impl ChannelWorker {
pub async fn new(channel: Arc<Channel>, hostname: String) -> Self {
pub async fn new(
channel: Arc<Channel>,
hostname: String,
response_tracker: Arc<DashMap<String, oneshot::Sender<Delivery>>>,
) -> Self {
let worker = ChannelWorker {
channel,
response_tracker: Arc::new(RwLock::new(HashMap::new())),
response_tracker,
consumer_tag: format!("{}-{}", hostname, Uuid::new_v4()),
};
worker.dispatching_loop().await;
Expand All @@ -118,8 +128,7 @@ impl ChannelWorker {
correlation_id: String,
tx: oneshot::Sender<Delivery>,
) {
let mut response_tracker = self.response_tracker.write().await;
response_tracker.insert(correlation_id, tx);
self.response_tracker.insert(correlation_id, tx);
}

pub fn should_reuse(&self) -> bool {
Expand Down Expand Up @@ -148,8 +157,7 @@ impl ChannelWorker {
while let Some(delivery) = consumer.next().await {
let delivery = delivery.expect("Error in receiving message");
if let Some(correlation_id) = delivery.properties.correlation_id().as_ref() {
let mut tracker = response_tracker.write().await;
if let Some(sender) = tracker.remove(correlation_id.as_str()) {
if let Some((_, sender)) = response_tracker.remove(correlation_id.as_str()) {
let _ = sender.send(delivery);
}
} else {
Expand Down Expand Up @@ -181,6 +189,8 @@ pub enum MqClientError {
Serialization(#[educe(PartialEq(ignore))] serde_json::Error),
#[error("Cannot parse response status")]
StatusParsing,
#[error("Response channel was closed due to a delivery error")]
ResponseChannelClosed,
#[error("Response timeout")]
ResponseTimeout,
#[error("Connection does not exist")]
Expand Down Expand Up @@ -393,12 +403,13 @@ impl RabbitMQClient {
path: &str,
published_payload: &Option<T>,
mandatory: bool,
override_timeout: Option<u64>,
override_timeout: Option<Duration>,
) -> Result<MQResponse, MqClientError>
where
T: Serialize,
{
let correlation_id = Uuid::new_v4().to_string();
let timeout = override_timeout.unwrap_or_else(|| Duration::from_secs(self.timeout));

// Get the next channel
let channel_worker = self
Expand All @@ -425,6 +436,7 @@ impl RabbitMQClient {
let properties = BasicProperties::default()
.with_reply_to(ShortString::from("amq.rabbitmq.reply-to"))
.with_correlation_id(ShortString::from(correlation_id.clone()))
.with_expiration(timeout.as_millis().to_string().into())
.with_headers(headers);

let (tx, rx) = oneshot::channel();
Expand All @@ -451,12 +463,7 @@ impl RabbitMQClient {
// Release from the pool
drop(channel_worker);

match timeout(
Duration::from_secs(override_timeout.unwrap_or(self.timeout)),
rx,
)
.await
{
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(delivery)) => {
let status = delivery
.properties
Expand All @@ -472,7 +479,8 @@ impl RabbitMQClient {
status,
})
}
Ok(Err(_)) | Err(_) => Err(MqClientError::ResponseTimeout),
Ok(Err(_)) => Err(MqClientError::ResponseChannelClosed),
Err(_) => Err(MqClientError::ResponseTimeout),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions editoast/core_client/src/worker_load.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub struct WorkerLoadRequest {

impl AsCoreRequest<()> for WorkerLoadRequest {
const URL_PATH: &'static str = "/worker_load";
const OVERRIDE_TIMEOUT: Option<std::time::Duration> = Some(std::time::Duration::from_secs(10));

fn worker_id(&self) -> Option<String> {
match self.timetable {
Expand Down
56 changes: 24 additions & 32 deletions editoast/src/views/worker_load.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use axum::Extension;
use axum::extract::Json;
use axum::extract::State;
use core_client::AsCoreRequest;
use core_client::Error as CoreClientError;
use core_client::mq_client::MqClientError;
use core_client::worker_load::WorkerLoadRequest;
use editoast_derive::EditoastError;
use editoast_models::prelude::*;
Expand Down Expand Up @@ -85,7 +87,6 @@ pub(in crate::views) async fn worker_load(
State(AppState {
db_pool,
core_client,
osrdyne_client,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
Expand Down Expand Up @@ -123,29 +124,22 @@ pub(in crate::views) async fn worker_load(
.await?;
}

// Fetch status of the worker
let worker_key = match timetable_id {
Some(timetable_id) => format!("{infra_id}-{timetable_id}"),
None => infra_id.to_string(),
let infra_request = WorkerLoadRequest {
infra: infra.id,
expected_version: infra.version,
timetable: timetable_id,
};

let status = infra_request.fetch(core_client.as_ref()).await;

let status = match status {
Ok(()) => WorkerStatus::Ready,
Err(CoreClientError::MqClientError(MqClientError::ResponseTimeout)) => {
// Treat timeout as NotReady
WorkerStatus::NotReady
}
Err(e) => return Err(e.into()),
};
let status = osrdyne_client
.get_worker_status(&worker_key)
.await
.map_err(WorkerLoadError::FetchStatusError)?
.into();

if status == WorkerStatus::Error || status == WorkerStatus::NotReady {
let infra_request = WorkerLoadRequest {
infra: infra.id,
expected_version: infra.version,
timetable: timetable_id,
};

// Send message to load worker in background
tokio::spawn(async move {
let _ = infra_request.fetch(core_client.as_ref()).await;
});
}

Ok(Json(status))
}
Expand All @@ -156,7 +150,6 @@ mod tests {
use core_client::CoreClient;
use core_client::mocking::MockingClient;
use database::DbConnectionPoolV2;
use osrdyne_client::OsrdyneClient;
use reqwest::StatusCode;

use crate::models::fixtures::create_empty_infra;
Expand All @@ -167,16 +160,15 @@ mod tests {
let db_pool = DbConnectionPoolV2::for_tests();
let empty_infra = create_empty_infra(&mut db_pool.get_ok()).await;

let osrdyne_client = OsrdyneClient::mock()
.with_status(
&empty_infra.id.to_string(),
osrdyne_client::WorkerStatus::Ready,
)
.build();
let mut core = MockingClient::default();
core.stub("/status")
.response(StatusCode::OK)
.body("")
.finish();

let app = TestAppBuilder::new()
.db_pool(db_pool)
.core_client(CoreClient::Mocked(MockingClient::default()))
.osrdyne_client(osrdyne_client)
.core_client(CoreClient::Mocked(core))
.build();
let req = app.post("/worker_load").json(&WorkerLoadForm {
infra_id: empty_infra.id,
Expand Down
2 changes: 1 addition & 1 deletion osrdyne/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ reqwest = { version = "0.12.24", default-features = false, features = [
"rustls-tls",
] }
schemars = "1.0.5"
serde = "1.0.228"
serde = { version = "1.0.228", features = ["derive", "rc"] }
serde_json = "1.0.145"
sha2 = "0.10.9"
smol_str = { version = "0.3.4", features = ["serde"] }
Expand Down
Loading