Skip to content

Commit d1d0068

Browse files
committed
core: implements ZonesConstraints and uses it in the LMR pathfinding
close #13250 Signed-off-by: Angelina Kuntz <[email protected]>
1 parent 0b25cf0 commit d1d0068

File tree

9 files changed

+162
-3
lines changed

9 files changed

+162
-3
lines changed

core/src/main/kotlin/fr/sncf/osrd/api/RollingStockParser.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import fr.sncf.osrd.train.RollingStock.*
1414
/** Parse the rolling stock model into something the backend can work with */
1515
fun parseRawRollingStock(
1616
rawPhysicsConsist: PhysicsConsistModel,
17-
loadingGaugeType: RJSLoadingGaugeType = RJSLoadingGaugeType.G1,
17+
loadingGaugeType: RJSLoadingGaugeType? = RJSLoadingGaugeType.G1,
1818
rollingStockSupportedSignalingSystems: List<String> = listOf(),
1919
): RollingStock {
2020
// Parse effort_curves

core/src/main/kotlin/fr/sncf/osrd/api/pathfinding/PathfindingBlocksEndpoint.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ fun runPathfinding(infra: FullInfra, request: PathfindingBlockRequest): Pathfind
114114
request.rollingStockLoadingGauge,
115115
request.rollingStockSupportedElectrifications,
116116
request.rollingStockSupportedSignalingSystems,
117+
null,
117118
)
118119

119120
val heuristics =

core/src/main/kotlin/fr/sncf/osrd/api/stdcm/STDCMEndpoint.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class STDCMEndpoint(
137137
parseMarginValue(request.margin),
138138
Pathfinding.TIMEOUT,
139139
temporarySpeedLimitManager,
140+
request.zones,
140141
)
141142
if (path == null || hasDuplicateTracks(infra, path.blocks)) {
142143
val response = PathNotFound()

core/src/main/kotlin/fr/sncf/osrd/api/stdcm/STDCMRequest.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class STDCMRequest(
3131
// Pathfinding inputs
3232
/// List of waypoints. Each waypoint is a list of track offsets
3333
@Json(name = "path_items") val pathItems: List<STDCMPathItem>,
34-
@Json(name = "rolling_stock_loading_gauge") val rollingStockLoadingGauge: RJSLoadingGaugeType,
34+
@Json(name = "rolling_stock_loading_gauge") val rollingStockLoadingGauge: RJSLoadingGaugeType?,
3535
@Json(name = "rolling_stock_supported_signaling_systems")
3636
val rollingStockSupportedSignalingSystems: List<String>,
3737

@@ -57,6 +57,7 @@ class STDCMRequest(
5757
@Json(name = "temporary_speed_limits")
5858
val temporarySpeedLimits: Collection<STDCMTemporarySpeedLimit>,
5959
@Json(name = "work_schedules") val workSchedules: Collection<WorkSchedule> = listOf(),
60+
val zones: HashSet<String>?,
6061
)
6162

6263
data class STDCMTemporarySpeedLimit(

core/src/main/kotlin/fr/sncf/osrd/pathfinding/constraints/ConstraintCombiner.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,15 @@ class ConstraintCombiner<EdgeT, OffsetType>(
2727
fun initConstraints(
2828
fullInfra: FullInfra,
2929
rollingStock: RollingStock,
30+
zones: HashSet<String>?,
3031
): List<PathfindingConstraint<Block>> {
3132
return initConstraintsFromRSProps(
3233
fullInfra,
3334
rollingStock.isThermal,
3435
rollingStock.loadingGaugeType,
3536
rollingStock.modeNames.toList(),
3637
rollingStock.supportedSignalingSystems.toList(),
38+
zones,
3739
)
3840
}
3941

@@ -43,6 +45,7 @@ fun initConstraintsFromRSProps(
4345
rollingStockLoadingGauge: RJSLoadingGaugeType,
4446
rollingStockSupportedElectrification: List<String>,
4547
rollingStockSupportedSignalingSystems: List<String>,
48+
zones: HashSet<String>?,
4649
): List<PathfindingConstraint<Block>> {
4750
val res = mutableListOf<PathfindingConstraint<Block>>()
4851
if (!rollingStockIsThermal) {
@@ -60,5 +63,6 @@ fun initConstraintsFromRSProps(
6063
infra.signalingSimulator.sigModuleManager.findSignalingSystem(it)
6164
}
6265
res.add(SignalingSystemConstraints(infra.blockInfra, listOf(sigSystemIds)))
66+
res.add(ZonesConstraints(infra.blockInfra, infra.rawInfra, zones))
6367
return res
6468
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package fr.sncf.osrd.pathfinding.constraints
2+
3+
import fr.sncf.osrd.graph.PathfindingConstraint
4+
import fr.sncf.osrd.path.implementations.buildTrainPathFromBlock
5+
import fr.sncf.osrd.path.interfaces.TrainPath
6+
import fr.sncf.osrd.pathfinding.Pathfinding
7+
import fr.sncf.osrd.sim_infra.api.Block
8+
import fr.sncf.osrd.sim_infra.api.BlockId
9+
import fr.sncf.osrd.sim_infra.api.BlockInfra
10+
import fr.sncf.osrd.sim_infra.api.RawSignalingInfra
11+
import fr.sncf.osrd.utils.units.Offset
12+
13+
data class ZonesConstraints(
14+
val blockInfra: BlockInfra,
15+
val infra: RawSignalingInfra,
16+
val zones: HashSet<String>?,
17+
) : PathfindingConstraint<Block> {
18+
override fun apply(edge: BlockId): Collection<Pathfinding.Range<Block>> {
19+
val res = HashSet<Pathfinding.Range<Block>>()
20+
val path = buildTrainPathFromBlock(infra, blockInfra, edge)
21+
res.addAll(getBlockedRanges(zones, path))
22+
return res
23+
}
24+
25+
private fun getBlockedRanges(
26+
zones: HashSet<String>?,
27+
path: TrainPath,
28+
): Collection<Pathfinding.Range<Block>> {
29+
30+
val res = HashSet<Pathfinding.Range<Block>>()
31+
val invalidTrackSections =
32+
infra.trackSections.filter { zones?.contains(infra.getTrackSectionName(it)) == false }
33+
34+
val invalidTrackSectionsInsideBlock =
35+
path.getChunks().filter {
36+
val trackSection = infra.getTrackFromChunk(it.value.value)
37+
invalidTrackSections.contains(trackSection)
38+
}
39+
40+
for ((_, from, to) in invalidTrackSectionsInsideBlock) {
41+
res.add(Pathfinding.Range(Offset(from.distance), Offset(to.distance)))
42+
}
43+
return res
44+
}
45+
}

core/src/main/kotlin/fr/sncf/osrd/stdcm/graph/STDCMPathfinding.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ fun findPath(
5959
standardAllowance: AllowanceValue?,
6060
pathfindingTimeout: Double,
6161
temporarySpeedLimitManager: TemporarySpeedLimitManager,
62+
zones: HashSet<String>?,
6263
): STDCMResult? {
6364
return STDCMPathfinding(
6465
fullInfra,
@@ -74,6 +75,7 @@ fun findPath(
7475
standardAllowance,
7576
pathfindingTimeout,
7677
temporarySpeedLimitManager,
78+
zones,
7779
)
7880
.findPath()
7981
}
@@ -92,6 +94,7 @@ class STDCMPathfinding(
9294
standardAllowance: AllowanceValue?,
9395
private val pathfindingTimeout: Double = Pathfinding.TIMEOUT,
9496
private val temporarySpeedLimitManager: TemporarySpeedLimitManager,
97+
private val zones: HashSet<String>?,
9598
) {
9699

97100
private var starts: Set<STDCMNode> = HashSet()
@@ -116,7 +119,7 @@ class STDCMPathfinding(
116119
runInputSanityChecks()
117120

118121
val constraints =
119-
ConstraintCombiner(initConstraints(fullInfra, rollingStock).toMutableList())
122+
ConstraintCombiner(initConstraints(fullInfra, rollingStock, zones).toMutableList())
120123

121124
assert(steps.last().stop) { "The last stop is supposed to be an actual stop" }
122125
starts = getStartNodes(graph, listOf(constraints))
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
package fr.sncf.osrd.pathfinding.constraints
2+
3+
import fr.sncf.osrd.pathfinding.Pathfinding
4+
import fr.sncf.osrd.sim_infra.api.Block
5+
import fr.sncf.osrd.sim_infra.api.BlockId
6+
import fr.sncf.osrd.sim_infra.api.TrackChunk
7+
import fr.sncf.osrd.utils.Direction
8+
import fr.sncf.osrd.utils.Helpers
9+
import fr.sncf.osrd.utils.units.Length
10+
import fr.sncf.osrd.utils.units.Offset
11+
import fr.sncf.osrd.utils.units.meters
12+
import java.util.stream.Stream
13+
import org.assertj.core.api.Assertions
14+
import org.junit.jupiter.api.BeforeAll
15+
import org.junit.jupiter.api.TestInstance
16+
import org.junit.jupiter.params.ParameterizedTest
17+
import org.junit.jupiter.params.provider.Arguments
18+
import org.junit.jupiter.params.provider.MethodSource
19+
20+
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
21+
class ZonesConstraintsTest {
22+
23+
private var zonesConstraints: ZonesConstraints? = null
24+
25+
private var ta0Chunk0Length: Length<TrackChunk> = Length(0.meters)
26+
private var ta0Chunk0block: BlockId? = null
27+
private var ta0Chunk1block: BlockId? = null
28+
private var ta1Chunk0block: BlockId? = null
29+
30+
@BeforeAll
31+
fun setup() {
32+
val infra = Helpers.smallInfra
33+
val zones =
34+
hashSetOf(
35+
"TA1",
36+
"TA2",
37+
"TA3",
38+
"TA4",
39+
"TA5",
40+
"TA6",
41+
"TA7",
42+
"TG0",
43+
"TG1",
44+
"TG2",
45+
"TG3",
46+
"TG4",
47+
"TG5",
48+
"TH0",
49+
"TH1",
50+
"TE1",
51+
"TE2",
52+
"TE3",
53+
"TD0",
54+
"TD1",
55+
"TD2",
56+
"TD3",
57+
)
58+
zonesConstraints = ZonesConstraints(infra.blockInfra, infra.rawInfra, zones)
59+
val ta0 = infra.rawInfra.getTrackSectionFromName("TA0")!!
60+
val ta0Chunks = infra.rawInfra.getTrackSectionChunks(ta0)
61+
assert(ta0Chunks.size == 2)
62+
val ta0Chunk0 =
63+
if (infra.rawInfra.getTrackChunkOffset(ta0Chunks[0]) <= Offset(0.meters)) ta0Chunks[0]
64+
else ta0Chunks[1]
65+
val ta0Chunk1 =
66+
if (infra.rawInfra.getTrackChunkOffset(ta0Chunks[0]) <= Offset(0.meters)) ta0Chunks[1]
67+
else ta0Chunks[0]
68+
ta0Chunk0Length = infra.rawInfra.getTrackChunkLength(ta0Chunk0)
69+
ta0Chunk0block =
70+
infra.blockInfra.getBlocksFromTrackChunk(ta0Chunk0, Direction.INCREASING).getAtIndex(0)
71+
ta0Chunk1block =
72+
infra.blockInfra.getBlocksFromTrackChunk(ta0Chunk1, Direction.INCREASING).getAtIndex(0)
73+
val ta1 = infra.rawInfra.getTrackSectionFromName("TA1")!!
74+
val ta1Chunks = infra.rawInfra.getTrackSectionChunks(ta1)
75+
val ta1Chunk0 = ta1Chunks[0]
76+
ta1Chunk0block =
77+
infra.blockInfra.getBlocksFromTrackChunk(ta1Chunk0, Direction.INCREASING).getAtIndex(0)
78+
}
79+
80+
@ParameterizedTest
81+
@MethodSource("testZonesArgs")
82+
fun testDeadSectionAndElectrificationBlockedRanges(
83+
blockId: BlockId,
84+
expectedBlockedRanges: Collection<Pathfinding.Range<Block>>,
85+
) {
86+
val blockedRanges = zonesConstraints!!.apply(blockId)
87+
Assertions.assertThat(blockedRanges).isEqualTo(expectedBlockedRanges)
88+
}
89+
90+
private fun testZonesArgs(): Stream<Arguments> {
91+
return Stream.of(
92+
Arguments.of(
93+
ta0Chunk0block!!.index.toInt(),
94+
setOf(Pathfinding.Range(Length(0.meters), ta0Chunk0Length)),
95+
),
96+
Arguments.of(
97+
ta0Chunk1block!!.index.toInt(),
98+
setOf(Pathfinding.Range(Length<TrackChunk>(0.meters), Length(180.meters))),
99+
),
100+
Arguments.of(ta1Chunk0block!!.index.toInt(), HashSet<Any>()),
101+
)
102+
}
103+
}

core/src/test/kotlin/fr/sncf/osrd/stdcm/STDCMPathfindingBuilder.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ data class STDCMPathfindingBuilder(
177177
standardAllowance,
178178
pathfindingTimeout,
179179
temporarySpeedLimitManager,
180+
null,
180181
)
181182
}
182183
}

0 commit comments

Comments
 (0)