Skip to content

Commit 98d2c58

Browse files
authored
Rework resuming (#32)
* Add tracking of Session state on update/connect * Add delay to moving orphaned Links * Implement post-resume synchronisation * Add ResumeSynchronizationEvent NB: This is potentially breaking as it replaces a `sealed` attribute * Get session via PATCH instead of GET * Add LavalinkNode.getCachedPlayers() * Abandon resuming after failed reconnection attempt * Handle reconnect even on abnormal closure * Change how resume reconnect failure is handled * Defensively call onResumeReconnectFailed(node) * Fix removing cached player from new node instead of old Also manages the LinkState a bit more consistently * Create new Link instances upon resume synchronization Only if one doesn't already exist for the guild
1 parent 6ccb3ec commit 98d2c58

File tree

5 files changed

+128
-13
lines changed

5 files changed

+128
-13
lines changed

src/main/kotlin/dev/arbjerg/lavalink/client/LavalinkClient.kt

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
package dev.arbjerg.lavalink.client
22

3+
import dev.arbjerg.lavalink.client.event.ClientEvent
34
import dev.arbjerg.lavalink.client.loadbalancing.ILoadBalancer
45
import dev.arbjerg.lavalink.client.loadbalancing.VoiceRegion
56
import dev.arbjerg.lavalink.client.loadbalancing.builtin.DefaultLoadBalancer
6-
import dev.arbjerg.lavalink.client.event.ClientEvent
77
import dev.arbjerg.lavalink.client.player.LavalinkPlayer
88
import dev.arbjerg.lavalink.internal.ReconnectTask
99
import dev.arbjerg.lavalink.protocol.v4.VoiceState
@@ -110,6 +110,19 @@ class LavalinkClient(val userId: Long) : Closeable, Disposable {
110110
Link(guildId, loadBalancer.selectNode(region, guildId))
111111
}
112112
}
113+
/**
114+
* Get or crate a [Link] between a guild and a node.
115+
*
116+
* The requested [LavalinkNode] is only assigned if a new [Link] is created
117+
*
118+
* @param guildId The id of the guild
119+
* @param node the node to initially assign the [Link] to if a new one is created
120+
*/
121+
internal fun getOrCreateLink(guildId: Long, node: LavalinkNode): Link {
122+
return linkMap.getOrPut(guildId) {
123+
Link(guildId, node)
124+
}
125+
}
113126

114127
/**
115128
* Returns a [Link] if it exists in the cache.
@@ -169,6 +182,21 @@ class LavalinkClient(val userId: Long) : Closeable, Disposable {
169182
return
170183
}
171184

185+
val session = node.cachedSession
186+
val canResume = session != null && session.resuming && session.timeoutSeconds > 0
187+
if (canResume) {
188+
// This causes onResumeReconnectFailed(node) to be called if the next reconnect fails
189+
node.ws.onResumableConnectionDisconnected()
190+
} else {
191+
transferNodes(node)
192+
}
193+
}
194+
195+
internal fun onResumeReconnectFailed(node: LavalinkNode) {
196+
transferNodes(node)
197+
}
198+
199+
private fun transferNodes(node: LavalinkNode) {
172200
linkMap.forEach { (_, link) ->
173201
if (link.node == node) {
174202
val voiceRegion = link.cachedPlayer?.voiceRegion

src/main/kotlin/dev/arbjerg/lavalink/client/LavalinkNode.kt

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package dev.arbjerg.lavalink.client
22

33
import dev.arbjerg.lavalink.client.event.ClientEvent
4+
import dev.arbjerg.lavalink.client.event.ResumeSynchronizationEvent
45
import dev.arbjerg.lavalink.client.http.HttpBuilder
56
import dev.arbjerg.lavalink.client.player.*
67
import dev.arbjerg.lavalink.client.player.Track
@@ -17,12 +18,14 @@ import kotlinx.serialization.serializer
1718
import okhttp3.Call
1819
import okhttp3.OkHttpClient
1920
import okhttp3.Response
21+
import org.slf4j.LoggerFactory
2022
import reactor.core.Disposable
2123
import reactor.core.publisher.Flux
2224
import reactor.core.publisher.Mono
2325
import reactor.core.publisher.Sinks
2426
import reactor.core.publisher.Sinks.Many
2527
import reactor.kotlin.core.publisher.toMono
28+
import reactor.util.retry.Retry
2629
import java.io.Closeable
2730
import java.io.IOException
2831
import java.time.Duration
@@ -39,12 +42,14 @@ class LavalinkNode(
3942
private val nodeOptions: NodeOptions,
4043
val lavalink: LavalinkClient
4144
) : Disposable, Closeable {
45+
private val logger = LoggerFactory.getLogger(LavalinkNode::class.java)
4246
// "safe" uri with all paths removed
4347
val baseUri = "${nodeOptions.serverUri.scheme}://${nodeOptions.serverUri.host}:${nodeOptions.serverUri.port}"
4448

4549
val name = nodeOptions.name
4650
val regionFilter = nodeOptions.regionFilter
4751
val password = nodeOptions.password
52+
internal var cachedSession: Session? = null
4853

4954
var sessionId: String? = nodeOptions.sessionId
5055
internal set
@@ -237,10 +242,12 @@ class LavalinkNode(
237242
}
238243

239244
/**
240-
* Enables resuming. This causes Lavalink to continue playing for [duration], during which
245+
* Enables resuming. This causes Lavalink to continue playing for [timeout] amount of time, during which
241246
* we can reconnect without losing our session data. */
242247
fun enableResuming(timeout: Duration): Mono<Session> {
243-
return rest.patchSession(Session(resuming = true, timeout.seconds))
248+
return rest.patchSession(Session(resuming = true, timeout.seconds)).doOnSuccess {
249+
cachedSession = it
250+
}
244251
}
245252

246253
/**
@@ -249,7 +256,9 @@ class LavalinkNode(
249256
* This is the default behavior, reversing calls to [enableResuming].
250257
*/
251258
fun disableResuming(): Mono<Session> {
252-
return rest.patchSession(Session(resuming = false, timeoutSeconds = 0))
259+
return rest.patchSession(Session(resuming = false, timeoutSeconds = 0)).doOnSuccess {
260+
cachedSession = it
261+
}
253262
}
254263

255264
/**
@@ -429,6 +438,41 @@ class LavalinkNode(
429438
lavalink.transferOrphansTo(this)
430439
}
431440

441+
internal fun synchronizeAfterResume() {
442+
getPlayers()
443+
.retryWhen(Retry.fixedDelay(3, Duration.ofSeconds(1)))
444+
.map { players ->
445+
val remoteGuildIds = players.map { it.guildId }
446+
447+
players.forEach { player ->
448+
playerCache[player.guildId] = player
449+
450+
val link = lavalink.getOrCreateLink(player.guildId, node = this)
451+
if (link.node != this) return@forEach
452+
453+
link.state = if (player.state.connected) {
454+
LinkState.CONNECTED
455+
} else {
456+
LinkState.DISCONNECTED
457+
}
458+
}
459+
460+
val missingIds = playerCache.keys().toList() - remoteGuildIds.toSet()
461+
missingIds.forEach { guildId ->
462+
playerCache.remove(guildId)
463+
val link = lavalink.getLinkIfCached(guildId) ?: return@forEach
464+
if (link.node == this) link.state = LinkState.DISCONNECTED
465+
}
466+
467+
ResumeSynchronizationEvent(this, failureReason = null)
468+
}.doOnError {
469+
logger.error("Failure while attempting synchronization with $this", it)
470+
sink.tryEmitNext(ResumeSynchronizationEvent(this, failureReason = it))
471+
}.subscribe {
472+
sink.tryEmitNext(it)
473+
}
474+
}
475+
432476
override fun equals(other: Any?): Boolean {
433477
if (this === other) return true
434478
if (javaClass != other?.javaClass) return false

src/main/kotlin/dev/arbjerg/lavalink/client/event/events.kt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import dev.arbjerg.lavalink.client.player.toCustom
77
import dev.arbjerg.lavalink.protocol.v4.*
88
import dev.arbjerg.lavalink.protocol.v4.Message.EmittedEvent.TrackEndEvent.AudioTrackEndReason
99

10-
internal fun Message.toClientEvent(node: LavalinkNode) = when (this) {
10+
internal fun Message.toClientEvent(node: LavalinkNode): ClientEvent = when (this) {
1111
is Message.ReadyEvent -> ReadyEvent(node, resumed, sessionId)
1212
is Message.EmittedEvent.TrackEndEvent -> TrackEndEvent(node, guildId.toLong(), track.toCustom(), reason)
1313
is Message.EmittedEvent.TrackExceptionEvent -> TrackExceptionEvent(node, guildId.toLong(), track.toCustom(), exception.toCustom())
@@ -18,12 +18,19 @@ internal fun Message.toClientEvent(node: LavalinkNode) = when (this) {
1818
is Message.StatsEvent -> StatsEvent(node, frameStats, players, playingPlayers, uptime, memory, cpu)
1919
}
2020

21-
sealed class ClientEvent(open val node: LavalinkNode)
21+
abstract class ClientEvent(open val node: LavalinkNode)
2222

2323
// Normal events
2424
data class ReadyEvent(override val node: LavalinkNode, val resumed: Boolean, val sessionId: String)
2525
: ClientEvent(node)
2626

27+
/**
28+
* Represents a successful or failed synchronization after a [ReadyEvent] with [ReadyEvent.resumed] set to true.
29+
*
30+
* Whether it is successful depends on whether [failureReason] is null.
31+
*/
32+
data class ResumeSynchronizationEvent(override val node: LavalinkNode, val failureReason: Throwable?) : ClientEvent(node)
33+
2734
data class PlayerUpdateEvent(override val node: LavalinkNode, val guildId: Long, val state: PlayerState)
2835
: ClientEvent(node)
2936

src/main/kotlin/dev/arbjerg/lavalink/internal/LavalinkRestClient.kt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,15 @@ class LavalinkRestClient(val node: LavalinkNode) {
7676
}.toMono()
7777
}
7878

79+
fun getSession(): Mono<Session> {
80+
return newRequest {
81+
path("/v4/sessions/${node.sessionId}")
82+
// Using patch with an empty object is a dirty hack because GET is not supported for this resource
83+
// 7 years younger me should have known better ~Freya
84+
patch("{}".toRequestBody("application/json".toMediaType()))
85+
}.toMono()
86+
}
87+
7988
/**
8089
* Make a request to the lavalink node. This is internal to keep it looking nice in kotlin. Java compatibility is in the node class.
8190
*/

src/main/kotlin/dev/arbjerg/lavalink/internal/LavalinkSocket.kt

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import java.io.EOFException
1818
import java.net.ConnectException
1919
import java.net.SocketException
2020
import java.net.SocketTimeoutException
21+
import java.util.concurrent.atomic.AtomicBoolean
2122

2223
class LavalinkSocket(private val node: LavalinkNode) : WebSocketListener(), Closeable {
2324
private val logger = LoggerFactory.getLogger(LavalinkSocket::class.java)
@@ -26,11 +27,15 @@ class LavalinkSocket(private val node: LavalinkNode) : WebSocketListener(), Clos
2627

2728
var mayReconnect = true
2829
var lastReconnectAttempt = 0L
30+
@Volatile
2931
private var reconnectsAttempted = 0
3032
val reconnectInterval: Int
3133
get() = reconnectsAttempted * 2000 - 2000
3234
var open: Boolean = false
3335
private set
36+
@Volatile
37+
private var hasEverConnected = false
38+
private val isAttemptingResume = AtomicBoolean(node.sessionId != null)
3439

3540
init {
3641
connect()
@@ -40,6 +45,8 @@ class LavalinkSocket(private val node: LavalinkNode) : WebSocketListener(), Clos
4045
logger.info("${node.name} has been connected!")
4146
open = true
4247
reconnectsAttempted = 0
48+
hasEverConnected = true
49+
isAttemptingResume.set(false)
4350
}
4451

4552
override fun onMessage(webSocket: WebSocket, text: String) {
@@ -75,6 +82,17 @@ class LavalinkSocket(private val node: LavalinkNode) : WebSocketListener(), Clos
7582
.subscribe()
7683
}
7784

85+
if (!resumed) {
86+
node.cachedSession = null
87+
}
88+
if (node.cachedSession == null) {
89+
node.rest.getSession().subscribe { node.cachedSession = it }
90+
}
91+
92+
if (resumed) {
93+
node.synchronizeAfterResume()
94+
}
95+
7896
// Move players from older, unavailable nodes to ourselves.
7997
node.transferOrphansToSelf()
8098
}
@@ -138,21 +156,18 @@ class LavalinkSocket(private val node: LavalinkNode) : WebSocketListener(), Clos
138156
if (mayReconnect) {
139157
logger.info("${node.name} disconnected, reconnecting in ${reconnectInterval / 1000} seconds")
140158
}
141-
142-
node.available = false
143-
open = false
144159
}
145160

146161
override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
147-
handleFailureTrhowable(t)
162+
handleFailureThrowable(t)
148163

149164
node.available = false
150165
open = false
151166

152167
node.lavalink.onNodeDisconnected(node)
153168
}
154169

155-
private fun handleFailureTrhowable(t: Throwable) {
170+
private fun handleFailureThrowable(t: Throwable) {
156171
when(t) {
157172
is EOFException -> {
158173
logger.debug("Got disconnected from ${node.name}, trying to reconnect", t)
@@ -180,10 +195,19 @@ class LavalinkSocket(private val node: LavalinkNode) : WebSocketListener(), Clos
180195
logger.error("Unknown error on ${node.name}", t)
181196
}
182197
}
198+
199+
if (hasEverConnected && isAttemptingResume.getAndSet(false)) {
200+
try {
201+
node.lavalink.onResumeReconnectFailed(node)
202+
} catch (e: Exception) {
203+
logger.error("Exception after giving up on resuming", e)
204+
}
205+
}
183206
}
184207

185208
override fun onClosing(webSocket: WebSocket, code: Int, reason: String) {
186209
node.available = false
210+
open = false
187211
node.lavalink.onNodeDisconnected(node)
188212

189213
if (code == 1000) {
@@ -203,7 +227,6 @@ class LavalinkSocket(private val node: LavalinkNode) : WebSocketListener(), Clos
203227
reason
204228
)
205229
}
206-
207230
}
208231

209232
fun attemptReconnect() {
@@ -224,7 +247,7 @@ class LavalinkSocket(private val node: LavalinkNode) : WebSocketListener(), Clos
224247
.addHeader("Client-Name", "Lavalink-Client/${CLIENT_VERSION}")
225248
.addHeader("User-Id", node.lavalink.userId.toString())
226249
.apply {
227-
if (node.sessionId != null) {
250+
if (node.sessionId != null && isAttemptingResume.get()) {
228251
addHeader("Session-Id", node.sessionId!!)
229252
}
230253
}
@@ -241,4 +264,8 @@ class LavalinkSocket(private val node: LavalinkNode) : WebSocketListener(), Clos
241264
socket?.close(1000, "Client shutdown")
242265
socket?.cancel()
243266
}
267+
268+
internal fun onResumableConnectionDisconnected() {
269+
isAttemptingResume.set(true)
270+
}
244271
}

0 commit comments

Comments
 (0)