Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class WikipediaFirebaseMessagingService : FirebaseMessagingService() {
L.e(t)
}) {
for (lang in WikipediaApp.instance.languageState.appLanguageCodes) {
val csrfToken = withContext(Dispatchers.IO) { CsrfTokenClient.getToken(WikiSite.forLanguageCode(lang)).blockingSingle() }
val csrfToken = withContext(Dispatchers.IO) { CsrfTokenClient.getTokenBlocking(WikiSite.forLanguageCode(lang)) }
if (lang == WikipediaApp.instance.appOrSystemLanguageCode) {
subscribeWithCsrf(csrfToken)
}
Expand Down
52 changes: 52 additions & 0 deletions app/src/main/java/org/wikipedia/csrf/CsrfTokenClient.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package org.wikipedia.csrf
import io.reactivex.rxjava3.core.Completable
import io.reactivex.rxjava3.core.Observable
import io.reactivex.rxjava3.schedulers.Schedulers
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import org.wikipedia.WikipediaApp
import org.wikipedia.auth.AccountUtil
import org.wikipedia.concurrency.FlowEventBus
Expand All @@ -22,6 +24,56 @@ object CsrfTokenClient {
private const val ANON_TOKEN = "+\\"
private const val MAX_RETRIES = 3

suspend fun getTokenBlocking(site: WikiSite, type: String = "csrf", svc: Service? = null): String {
var token = ""
withContext(Dispatchers.IO) {
try {
MUTEX.acquire()
val service = svc ?: ServiceFactory.get(site)
var lastError: Throwable? = null
for (retry in 0 until MAX_RETRIES) {
if (retry > 0) {
// Log in explicitly
try {
LoginClient().loginBlocking(site, AccountUtil.userName, AccountUtil.password!!, "")
} catch (e: Exception) {
L.e(e)
lastError = e
}
}
try {
val tokenResponse = service.getToken(type)
token = if (type == "rollback") {
tokenResponse.query?.rollbackToken().orEmpty()
} else {
tokenResponse.query?.csrfToken().orEmpty()
}
if (AccountUtil.isLoggedIn && token == ANON_TOKEN) {
throw RuntimeException("App believes we're logged in, but got anonymous token.")
}
} catch (e: Exception) {
L.e(e)
lastError = e
}
if (token.isEmpty() || (AccountUtil.isLoggedIn && token == ANON_TOKEN)) {
continue
}
break
}
if (token.isEmpty() || (AccountUtil.isLoggedIn && token == ANON_TOKEN)) {
if (token == ANON_TOKEN) {
bailWithLogout()
}
throw lastError ?: IOException("Invalid token, or login failure.")
}
} finally {
MUTEX.release()
}
}
return token
}

// TODO: remove this after all usages are converted to coroutines
fun getToken(site: WikiSite, type: String = "csrf", svc: Service? = null): Observable<String> {
return Observable.create { emitter ->
var token = ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class NotificationPollBroadcastReceiver : BroadcastReceiver() {

suspend fun markRead(wiki: WikiSite, notifications: List<Notification>, unread: Boolean) {
withContext(Dispatchers.IO) {
val token = CsrfTokenClient.getToken(wiki).blockingSingle()
val token = CsrfTokenClient.getTokenBlocking(wiki)
notifications.windowed(50, partialWindows = true).forEach { window ->
val idListStr = window.joinToString("|")
ServiceFactory.get(wiki).markRead(token, if (unread) null else idListStr, if (unread) idListStr else null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import androidx.work.NetworkType
import androidx.work.OneTimeWorkRequestBuilder
import androidx.work.WorkManager
import androidx.work.WorkerParameters
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import org.wikipedia.WikipediaApp
import org.wikipedia.csrf.CsrfTokenClient
import org.wikipedia.dataclient.ServiceFactory
Expand All @@ -32,7 +30,9 @@ class PollNotificationWorker(
Result.success()
} catch (t: Throwable) {
if (t is MwException && t.error.title == "login-required") {
assertLoggedIn()
// Attempt to get a dummy CSRF token, which should automatically re-log us in explicitly,
// and should automatically log us out if the credentials are no longer valid.
CsrfTokenClient.getTokenBlocking(WikipediaApp.instance.wikiSite)
}
L.e(t)
Result.failure()
Expand All @@ -59,19 +59,6 @@ class PollNotificationWorker(
}
}

private suspend fun assertLoggedIn() {
// Attempt to get a dummy CSRF token, which should automatically re-log us in explicitly,
// and should automatically log us out if the credentials are no longer valid.
try {
withContext(Dispatchers.IO) {
CsrfTokenClient.getToken(WikipediaApp.instance.wikiSite).blockingSingle()
}
} catch (e: Throwable) {
// Ignore the exception.
L.e(e)
}
}

companion object {
fun schedulePollNotificationJob(context: Context) {
val constraints = Constraints.Builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class SuggestedEditsImageRecsFragmentViewModel(bundle: Bundle) : ViewModel() {
private suspend fun invalidateRecommendation(token: String?, accepted: Boolean, revId: Long, reasonCodes: List<Int>?) {

withContext(Dispatchers.IO) {
val csrfToken = token ?: CsrfTokenClient.getToken(pageTitle.wikiSite).blockingSingle()
val csrfToken = token ?: CsrfTokenClient.getTokenBlocking(pageTitle.wikiSite)

// Attempt to call the AddImageFeedback API first, and if it fails, try the
// growthinvalidateimagerecommendation API instead.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class SuggestedEditsImageTagsViewModel : ViewModel() {
viewModelScope.launch(CoroutineExceptionHandler { _, throwable ->
_actionState.value = Resource.Error(throwable)
}) {
val csrfToken = CsrfTokenClient.getToken(Constants.commonsWikiSite).blockingSingle()
val csrfToken = CsrfTokenClient.getTokenBlocking(Constants.commonsWikiSite)
val mId = "M" + page.pageId
var claimStr = "{\"claims\":["
var commentStr = "/* add-depicts: "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ object NotificationDirectReplyHelper {
L.e(throwable)
fallBackToTalkPage(context, title)
}) {
val token = async { CsrfTokenClient.getToken(wiki).blockingFirst() }
val token = async { CsrfTokenClient.getTokenBlocking(wiki) }
val talkPageResponse = async { ServiceFactory.getRest(wiki).getTalkPage(title.prefixedText) }
val topic = talkPageResponse.await().topics?.find {
it.id > 0 && it.html?.trim().orEmpty() == StringUtil.removeUnderscores(title.fragment)
Expand Down
10 changes: 2 additions & 8 deletions app/src/main/java/org/wikipedia/talk/TalkTopicsViewModel.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@ import androidx.lifecycle.ViewModel
import androidx.lifecycle.ViewModelProvider
import androidx.lifecycle.viewModelScope
import kotlinx.coroutines.CoroutineExceptionHandler
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import org.wikipedia.WikipediaApp
import org.wikipedia.analytics.eventplatform.WatchlistAnalyticsHelper
import org.wikipedia.csrf.CsrfTokenClient
Expand Down Expand Up @@ -126,9 +124,7 @@ class TalkTopicsViewModel(var pageTitle: PageTitle, private val sidePanel: Boole

fun undoSave(newRevisionId: Long, undoneSubject: CharSequence, undoneBody: CharSequence) {
viewModelScope.launch(actionHandler) {
val token = withContext(Dispatchers.IO) {
CsrfTokenClient.getToken(pageTitle.wikiSite).blockingFirst()
}
val token = CsrfTokenClient.getTokenBlocking(pageTitle.wikiSite)
val undoResponse = ServiceFactory.get(pageTitle.wikiSite).postUndoEdit(title = pageTitle.prefixedText, undoRevId = newRevisionId, token = token)
actionState.value = ActionState.UndoEdit(undoResponse, undoneSubject, undoneBody)
}
Expand Down Expand Up @@ -156,9 +152,7 @@ class TalkTopicsViewModel(var pageTitle: PageTitle, private val sidePanel: Boole

fun subscribeTopic(commentName: String, subscribed: Boolean) {
viewModelScope.launch(actionHandler) {
val token = withContext(Dispatchers.IO) {
CsrfTokenClient.getToken(pageTitle.wikiSite).blockingFirst()
}
val token = CsrfTokenClient.getTokenBlocking(pageTitle.wikiSite)
ServiceFactory.get(pageTitle.wikiSite).subscribeTalkPageTopic(pageTitle.prefixedText, commentName, token, if (!subscribed) true else null)
}
}
Expand Down
30 changes: 0 additions & 30 deletions app/src/test/java/org/wikipedia/csrf/CsrfTokenClientTest.java

This file was deleted.

71 changes: 71 additions & 0 deletions app/src/test/java/org/wikipedia/csrf/CsrfTokenClientTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package org.wikipedia.csrf

import kotlinx.coroutines.runBlocking
import org.hamcrest.MatcherAssert
import org.hamcrest.Matchers
import org.junit.Test
import org.wikipedia.test.MockRetrofitTest

class CsrfTokenClientTest : MockRetrofitTest() {
@Test
@Throws(Throwable::class)
fun testRequestSuccess() {
val expected = "b6f7bd58c013ab30735cb19ecc0aa08258122cba+\\"
enqueueFromFile("csrf_token.json")
CsrfTokenClient.getToken(wikiSite, "csrf", apiService).test().await()
.assertComplete().assertNoErrors()
.assertValue { result -> result == expected }
}

@Test
fun testRequestSuccessCoroutine() {
val expected = "b6f7bd58c013ab30735cb19ecc0aa08258122cba+\\"
enqueueFromFile("csrf_token.json")
runBlocking {
val result = CsrfTokenClient.getTokenBlocking(wikiSite, "csrf", apiService)
assert(result == expected)
}
}

@Test
@Throws(Throwable::class)
fun testRequestResponseApiError() {
enqueueFromFile("api_error.json")
CsrfTokenClient.getToken(wikiSite, "csrf", apiService).test().await()
.assertError(Exception::class.java)
}

@Test
@Throws(Throwable::class)
fun testRequestResponseApiErrorCoroutine() {
enqueueFromFile("api_error.json")
runBlocking {
try {
CsrfTokenClient.getTokenBlocking(wikiSite, "csrf", apiService)
} catch (e: Exception) {
MatcherAssert.assertThat(e, Matchers.notNullValue())
}
}
}

@Test
@Throws(Throwable::class)
fun testRequestResponseFailure() {
enqueue404()
CsrfTokenClient.getToken(wikiSite, "csrf", apiService).test().await()
.assertError(Exception::class.java)
}

@Test
@Throws(Throwable::class)
fun testRequestResponseFailureCoroutine() {
enqueue404()
runBlocking {
try {
CsrfTokenClient.getTokenBlocking(wikiSite, "csrf", apiService)
} catch (e: Exception) {
MatcherAssert.assertThat(e, Matchers.notNullValue())
}
}
}
}
Loading