33 changes: 33 additions & 0 deletions core/src/main/java/fr/sncf/osrd/api/StatusEndpoint.kt
@@ -0,0 +1,33 @@
package fr.sncf.osrd.api

import com.squareup.moshi.JsonAdapter
import com.squareup.moshi.Moshi
import com.squareup.moshi.Types
import org.takes.Request
import org.takes.Response
import org.takes.Take
import org.takes.rs.RsJson
import org.takes.rs.RsWithBody

class StatusEndpoint : Take {
override fun act(req: Request): Response {
val response = HashMap<String, String>()
response["ready"] = "true"
return RsJson(RsWithBody(adapter.toJson(response)))
}

companion object {
private val adapter: JsonAdapter<MutableMap<String, String>>

init {
val moshi = Moshi.Builder().build()
val type =
Types.newParameterizedType(
MutableMap::class.java,
String::class.java,
String::class.java,
)
adapter = moshi.adapter<MutableMap<String, String>>(type)
}
}
}
1 change: 1 addition & 0 deletions core/src/main/java/fr/sncf/osrd/cli/WorkerCommand.kt
@@ -147,6 +147,7 @@ class WorkerCommand : CliCommand {
"/etcs_braking_curves" to
ETCSBrakingCurvesEndpoint(infraManager, electricalProfileSetManager),
"/version" to VersionEndpoint(),
"/status" to StatusEndpoint(),
"/stdcm" to STDCMEndpoint(infraManager, timetableCache),
"/worker_load" to WorkerLoadEndpoint(infraManager, timetableCache),
)
1 change: 1 addition & 0 deletions editoast/Cargo.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion editoast/Cargo.toml
@@ -41,6 +41,7 @@ clap = { version = "4.5", features = ["derive", "env"] }
common = { path = "./common" }
core_client = { path = "./core_client" }
darling = "0.21"
dashmap = "6.1.0"
database = { path = "./database" }
deadpool = { version = "0.12.3", default-features = false, features = [
"managed",
@@ -158,7 +159,7 @@ clap.workspace = true
colored = "3.0.0"
common = { workspace = true }
core_client = { workspace = true, features = ["mocking_client"] }
dashmap = "6.1.0"
dashmap.workspace = true
database.workspace = true
deadpool.workspace = true
deadpool-redis.workspace = true
1 change: 1 addition & 0 deletions editoast/core_client/Cargo.toml
@@ -11,6 +11,7 @@ mocking_client = []
approx.workspace = true
chrono.workspace = true
common = { workspace = true }
dashmap.workspace = true
deadpool.workspace = true
editoast_derive.workspace = true
educe.workspace = true
33 changes: 30 additions & 3 deletions editoast/core_client/src/lib.rs
@@ -5,6 +5,7 @@ pub mod path_properties;
pub mod pathfinding;
pub mod signal_projection;
pub mod simulation;
pub mod status;
pub mod stdcm;
pub mod version;
pub mod worker_load;
@@ -19,6 +20,7 @@ use serde::de::DeserializeOwned;
use std::collections::HashMap;
use std::fmt::Display;
use std::marker::PhantomData;
use std::time::Duration;
use thiserror::Error;
use tracing::trace;

@@ -63,6 +65,8 @@ impl CoreClient {
path: &str,
body: Option<&B>,
worker_id: Option<String>,
override_timeout: Option<Duration>,
no_worker_load: bool,
) -> Result<R::Response, Error> {
trace!(
target: "editoast::coreclient",
@@ -76,7 +80,14 @@
// TODO: tracing: use correlation id

let response = client
.call_with_response(worker_id, path, &body, true, None)
.call_with_response(
worker_id,
path,
&body,
true,
override_timeout,
no_worker_load,
)
.await
.map_err(Error::MqClientError)?;

@@ -150,6 +161,16 @@ where
/// Returns the worker id used for the request. Must be provided.
fn worker_id(&self) -> Option<String>;

/// Should this request avoid starting a new worker if none is running?
fn no_worker_load(&self) -> bool {
false
}

/// Returns the timeout override for this request, if any.
fn override_timeout(&self) -> Option<Duration> {
None
}

Comment on lines +164 to +173 (Contributor): nit: associated constants maybe?
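
A rough sketch of what that nit could look like, purely hypothetical and not part of this PR: the two per-request knobs become associated constants with defaults rather than methods with default bodies. AsCoreRequestSketch and StatusRequestSketch are illustrative names only, not types from this crate.

use std::time::Duration;

// Hypothetical sketch of the reviewer's suggestion: constants with defaults
// replace the `no_worker_load()` and `override_timeout()` methods.
pub trait AsCoreRequestSketch {
    /// Should this request avoid starting a new worker if none is running?
    const NO_WORKER_LOAD: bool = false;

    /// Timeout override for this request, if any.
    const OVERRIDE_TIMEOUT: Option<Duration> = None;
}

struct StatusRequestSketch;

impl AsCoreRequestSketch for StatusRequestSketch {
    // A status probe would opt out of worker startup and cap its wait at 10 s.
    const NO_WORKER_LOAD: bool = true;
    const OVERRIDE_TIMEOUT: Option<Duration> = Some(Duration::from_secs(10));
}

The trade-off is that associated constants cannot depend on request fields, which is fine for these two flags but would not work for worker_id.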

/// Sends this request using the given [CoreClient] and returns the response content on success
///
/// Raises a [enum@Error] if the request is not a success.
@@ -158,8 +179,14 @@
/// manage itself its expected errors. Maybe a bound error type defaulting
/// to CoreError and a trait function handle_errors would suffice?
async fn fetch(&self, core: &CoreClient) -> Result<R::Response, Error> {
core.fetch::<Self, R>(self.url(), Some(self), self.worker_id())
.await
core.fetch::<Self, R>(
self.url(),
Some(self),
self.worker_id(),
self.override_timeout(),
self.no_worker_load(),
)
.await
}
}
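
The TODO fragment in fetch's doc comment above mentions a bound error type defaulting to CoreError plus a handle_errors trait function. A minimal sketch of that direction, assuming each request wants to reinterpret its expected core failures; CoreErrorSketch and AsCoreRequestErrors are stand-in names, not crate types.

use std::error::Error as StdError;
use std::fmt;

// Stand-in for the crate's real core error type (hypothetical).
#[derive(Debug)]
pub struct CoreErrorSketch(pub String);

impl fmt::Display for CoreErrorSketch {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "core error: {}", self.0)
    }
}

impl StdError for CoreErrorSketch {}

pub trait AsCoreRequestErrors {
    /// Request-specific error type; conceptually defaults to the core error.
    type Error: From<CoreErrorSketch> + StdError;

    /// Hook letting each request translate an expected core failure into its own error type.
    fn handle_errors(&self, raw: CoreErrorSketch) -> Self::Error {
        Self::Error::from(raw)
    }
}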

52 changes: 34 additions & 18 deletions editoast/core_client/src/mq_client.rs
@@ -1,3 +1,4 @@
use dashmap::DashMap;
use deadpool::managed::Manager;
use deadpool::managed::Metrics;
use deadpool::managed::Pool;
@@ -18,15 +19,13 @@ use lapin::types::FieldTable;
use lapin::types::ShortString;
use serde::Serialize;
use serde_json::to_vec;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::RwLock;
use tokio::sync::oneshot;
use tokio::task;
use tokio::time::Duration;
use tokio::time::timeout;
use tracing::Instrument;
use url::Url;
use uuid::Uuid;
@@ -43,13 +42,15 @@ pub struct RabbitMQClient {
pub struct ChannelManager {
connection: Arc<RwLock<Option<Connection>>>,
hostname: String,
response_tracker: Arc<DashMap<String, oneshot::Sender<Delivery>>>,
}

impl ChannelManager {
pub fn new(connection: Arc<RwLock<Option<Connection>>>, hostname: String) -> Self {
ChannelManager {
connection,
hostname,
response_tracker: Arc::new(DashMap::new()),
}
}
}
@@ -72,7 +73,12 @@ impl Manager for ChannelManager {
.await
.map_err(|_| ChannelManagerError::Lapin)?;

Ok(ChannelWorker::new(Arc::new(channel), self.hostname.clone()).await)
Ok(ChannelWorker::new(
Arc::new(channel),
self.hostname.clone(),
self.response_tracker.clone(),
)
.await)
} else {
Err(ChannelManagerError::ConnectionNotFound)
}
@@ -94,15 +100,19 @@ impl Manager for ChannelManager {
#[derive(Debug)]
pub struct ChannelWorker {
channel: Arc<Channel>,
response_tracker: Arc<RwLock<HashMap<String, oneshot::Sender<Delivery>>>>,
response_tracker: Arc<DashMap<String, oneshot::Sender<Delivery>>>,
consumer_tag: String,
}

impl ChannelWorker {
pub async fn new(channel: Arc<Channel>, hostname: String) -> Self {
pub async fn new(
channel: Arc<Channel>,
hostname: String,
response_tracker: Arc<DashMap<String, oneshot::Sender<Delivery>>>,
) -> Self {
let worker = ChannelWorker {
channel,
response_tracker: Arc::new(RwLock::new(HashMap::new())),
response_tracker,
consumer_tag: format!("{}-{}", hostname, Uuid::new_v4()),
};
worker.dispatching_loop().await;
@@ -118,8 +128,7 @@ impl ChannelWorker {
correlation_id: String,
tx: oneshot::Sender<Delivery>,
) {
let mut response_tracker = self.response_tracker.write().await;
response_tracker.insert(correlation_id, tx);
self.response_tracker.insert(correlation_id, tx);
}

pub fn should_reuse(&self) -> bool {
@@ -148,8 +157,7 @@ impl ChannelWorker {
while let Some(delivery) = consumer.next().await {
let delivery = delivery.expect("Error in receiving message");
if let Some(correlation_id) = delivery.properties.correlation_id().as_ref() {
let mut tracker = response_tracker.write().await;
if let Some(sender) = tracker.remove(correlation_id.as_str()) {
if let Some((_, sender)) = response_tracker.remove(correlation_id.as_str()) {
let _ = sender.send(delivery);
}
} else {
@@ -181,6 +189,8 @@ pub enum MqClientError {
Serialization(#[educe(PartialEq(ignore))] serde_json::Error),
#[error("Cannot parse response status")]
StatusParsing,
#[error("Response channel was closed due to a delivery error")]
ResponseChannelClosed,
#[error("Response timeout")]
ResponseTimeout,
#[error("Connection does not exist")]
@@ -335,6 +345,7 @@ impl RabbitMQClient {
published_payload: &T,
mandatory: bool,
correlation_id: Option<String>,
no_worker_load: bool,
) -> Result<(), MqClientError>
where
T: Serialize,
@@ -360,6 +371,9 @@
let mut headers = FieldTable::default();
headers.insert("x-rpc-path".into(), path.into());
attach_tracing_info(&mut headers);
if no_worker_load {
headers.insert("x-osrdyne-no-start".into(), true.into());
}

let mut properties = BasicProperties::default().with_headers(headers);
if let Some(id) = correlation_id {
@@ -392,12 +406,14 @@
path: &str,
published_payload: &Option<T>,
mandatory: bool,
override_timeout: Option<u64>,
override_timeout: Option<Duration>,
no_worker_load: bool,
) -> Result<MQResponse, MqClientError>
where
T: Serialize,
{
let correlation_id = Uuid::new_v4().to_string();
let timeout = override_timeout.unwrap_or_else(|| Duration::from_secs(self.timeout));

// Get the next channel
let channel_worker = self
@@ -420,10 +436,14 @@
let mut headers = FieldTable::default();
headers.insert("x-rpc-path".into(), path.into());
attach_tracing_info(&mut headers);
if no_worker_load {
headers.insert("x-osrdyne-no-start".into(), true.into());
}

let properties = BasicProperties::default()
.with_reply_to(ShortString::from("amq.rabbitmq.reply-to"))
.with_correlation_id(ShortString::from(correlation_id.clone()))
.with_expiration(timeout.as_millis().to_string().into())
.with_headers(headers);

let (tx, rx) = oneshot::channel();
@@ -450,12 +470,7 @@
// Release from the pool
drop(channel_worker);

match timeout(
Duration::from_secs(override_timeout.unwrap_or(self.timeout)),
rx,
)
.await
{
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(delivery)) => {
let status = delivery
.properties
@@ -471,7 +486,8 @@
status,
})
}
Ok(Err(_)) | Err(_) => Err(MqClientError::ResponseTimeout),
Ok(Err(_)) => Err(MqClientError::ResponseChannelClosed),
Err(_) => Err(MqClientError::ResponseTimeout),
}
}
}
29 changes: 29 additions & 0 deletions editoast/core_client/src/status.rs
@@ -0,0 +1,29 @@
use serde::Serialize;

use super::AsCoreRequest;

/// A Core status request
#[derive(Debug, Serialize)]
pub struct StatusRequest {
pub infra: i64,
pub timetable: Option<i64>,
}

impl AsCoreRequest<()> for StatusRequest {
const URL_PATH: &'static str = "/status";

fn worker_id(&self) -> Option<String> {
Comment (Contributor): For later: maybe enum core_client::WorkerKey instead of Option<String>?

match self.timetable {
Some(timetable) => Some(format!("{}-{}", self.infra, timetable)),
None => Some(self.infra.to_string()),
}
}

fn no_worker_load(&self) -> bool {
true
}

fn override_timeout(&self) -> Option<std::time::Duration> {
Some(std::time::Duration::from_secs(10))
}
}
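
Following up on the reviewer's note about a core_client::WorkerKey enum: a minimal sketch, assuming the key only ever takes the "infra" or "infra-timetable" forms produced by worker_id above. WorkerKey and worker_key are hypothetical names, not part of this PR.

use std::fmt;

// Hypothetical routing-key type replacing the formatted Option<String>.
pub enum WorkerKey {
    Infra { infra: i64 },
    InfraTimetable { infra: i64, timetable: i64 },
}

impl fmt::Display for WorkerKey {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            WorkerKey::Infra { infra } => write!(f, "{infra}"),
            WorkerKey::InfraTimetable { infra, timetable } => write!(f, "{infra}-{timetable}"),
        }
    }
}

// StatusRequest::worker_id could then become, e.g.:
// fn worker_key(&self) -> WorkerKey {
//     match self.timetable {
//         Some(timetable) => WorkerKey::InfraTimetable { infra: self.infra, timetable },
//         None => WorkerKey::Infra { infra: self.infra },
//     }
// }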