Skip to content

Commit 2e53c49

Browse files
Make context switch for deserialized traces.
!!!WARNING!!! !!!WARNING!!! !!!WARNING!!! This change CHANGES API but this change IS NOT ENFORCED statically. After this change all accesses to deserialized trace points MUST be wrapped into LazyTraceReader.inContext() or withTraceContext() calls or preceded by switchGlobalContext() call. !!!WARNING!!! !!!WARNING!!! !!!WARNING!!!
1 parent 37f2938 commit 2e53c49

File tree

3 files changed

+156
-71
lines changed

3 files changed

+156
-71
lines changed

common/src/main/org/jetbrains/lincheck/trace/TraceContext.kt

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,32 @@ import org.jetbrains.lincheck.descriptors.MethodSignature
1919
import org.jetbrains.lincheck.descriptors.VariableDescriptor
2020
import org.jetbrains.lincheck.descriptors.Types
2121

22-
val TRACE_CONTEXT: TraceContext = TraceContext()
22+
var TRACE_CONTEXT: TraceContext = TraceContext()
23+
private set
24+
25+
val TRACE_CONTEXT_LOCK = Any()
26+
27+
fun setGlobalTraceContext(traceContext: TraceContext) {
28+
TRACE_CONTEXT = traceContext
29+
}
30+
31+
/**
32+
* Run [block] with global trace context set to [context], restore global after run.
33+
*
34+
* This function is effectively `synchronized`, it prevents concurrent access
35+
* to global trace context to avoid races.
36+
*/
37+
inline fun <R> withTraceContext(context: TraceContext, block: () -> R): R {
38+
synchronized(TRACE_CONTEXT_LOCK) {
39+
val saveContext = TRACE_CONTEXT
40+
setGlobalTraceContext(context)
41+
try {
42+
return block()
43+
} finally {
44+
setGlobalTraceContext(saveContext)
45+
}
46+
}
47+
}
2348

2449
const val UNKNOWN_CODE_LOCATION_ID = -1
2550

trace/src/main/org/jetbrains/lincheck/trace/Deserialization.kt

Lines changed: 123 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,7 @@ class LazyTraceReader(
193193
tracePoint = readTracePointWithChildAddresses()
194194
)
195195

196-
// TODO: Create new
197-
val context: TraceContext = TRACE_CONTEXT
196+
val context: TraceContext = TraceContext()
198197

199198
private val dataStream: SeekableInputStream
200199
private val data: SeekableDataInput
@@ -220,81 +219,125 @@ class LazyTraceReader(
220219
index?.close()
221220
}
222221

222+
/**
223+
* Run block with trace context belonging to this reader and restore global context afterward.
224+
*
225+
* @See withTraceContext
226+
*/
227+
inline fun <R> inContext(block: () -> R): R = withTraceContext(context, block)
228+
229+
/**
230+
* Switch global context to one belonging to this reader.
231+
*
232+
* This function is inherently unsafe: it doesn't prevent future switches of global context
233+
* by other code or threads.
234+
*
235+
* Please, use [inContext] instead if possible.
236+
*/
237+
@Deprecated("Use inContext if possible.")
238+
fun switchGlobalContext() = setGlobalTraceContext(context)
239+
240+
/**
241+
* Read top-level call tracepoints for each thread.
242+
*
243+
* This method runs in context of this reader's context and doesn't require [inContext].
244+
*/
223245
fun readRoots(): List<TRTracePoint> {
224-
var start = System.currentTimeMillis()
225-
loadContext()
226-
Logger.debug { "Context loaded in ${System.currentTimeMillis() - start} ms" }
227-
start = System.currentTimeMillis()
228-
229-
val roots = mutableMapOf<Int, TRTracePoint>()
230-
231-
dataBlocks.forEach {
232-
val (threadId, blocks) = it
233-
data.seek(blocks.first().physicalStart)
234-
val kind = data.readKind()
235-
check(kind == ObjectKind.BLOCK_START) { "Thread $threadId block 0 has wrong start: $kind" }
236-
val blockId = data.readInt()
237-
check(blockId == threadId) { "Thread $threadId block 0 has wrong idt: $blockId" }
238-
239-
val tracepoints = mutableListOf<TRTracePoint>()
240-
loadTracePoints(
241-
threadId = threadId,
242-
maxRead = Integer.MAX_VALUE,
243-
reader = this::readTracePointWithPostprocessor,
244-
registrator = { _, tracePoint, _ ->
245-
tracepoints.add(tracePoint)
246-
}
247-
)
248-
if (tracepoints.isEmpty()) {
249-
System.err.println("Thread $threadId doesn't write any tracepoints")
250-
} else {
251-
if (tracepoints.size > 1) {
252-
System.err.println("Thread $threadId wrote too many root tracepoints: ${tracepoints.size}")
246+
return withTraceContext(context) {
247+
var start = System.currentTimeMillis()
248+
loadContext()
249+
Logger.debug { "Context loaded in ${System.currentTimeMillis() - start} ms" }
250+
start = System.currentTimeMillis()
251+
252+
val roots = mutableMapOf<Int, TRTracePoint>()
253+
254+
dataBlocks.forEach {
255+
val (threadId, blocks) = it
256+
data.seek(blocks.first().physicalStart)
257+
val kind = data.readKind()
258+
check(kind == ObjectKind.BLOCK_START) { "Thread $threadId block 0 has wrong start: $kind" }
259+
val blockId = data.readInt()
260+
check(blockId == threadId) { "Thread $threadId block 0 has wrong idt: $blockId" }
261+
262+
val tracepoints = mutableListOf<TRTracePoint>()
263+
loadTracePoints(
264+
threadId = threadId,
265+
maxRead = Integer.MAX_VALUE,
266+
reader = this::readTracePointWithPostprocessor,
267+
registrator = { _, tracePoint, _ ->
268+
tracepoints.add(tracePoint)
269+
}
270+
)
271+
if (tracepoints.isEmpty()) {
272+
System.err.println("Thread $threadId doesn't write any tracepoints")
273+
} else {
274+
if (tracepoints.size > 1) {
275+
System.err.println("Thread $threadId wrote too many root tracepoints: ${tracepoints.size}")
276+
}
277+
roots[threadId] = tracepoints.first()
253278
}
254-
roots[threadId] = tracepoints.first()
255279
}
256-
}
257-
Logger.debug { "Roots loaded in ${System.currentTimeMillis() - start} ms" }
280+
Logger.debug { "Roots loaded in ${System.currentTimeMillis() - start} ms" }
258281

259-
return roots.entries.sortedBy { it.key }.map { (_, tracePoint) -> tracePoint }
282+
roots.entries.sortedBy { it.key }.map { (_, tracePoint) -> tracePoint }
283+
}
260284
}
261285

286+
/**
287+
* Load one level of children of giver call tracepoint.
288+
*
289+
* This method runs in context of this reader's context and doesn't require [inContext].
290+
*/
262291
fun loadAllChildren(parent: TRMethodCallTracePoint) {
263-
val (start, end) = callTracepointChildren[parent.eventId]
264-
?: error("TRMethodCallTracePoint ${parent.eventId} is not found in index")
292+
withTraceContext(context) {
293+
val (start, end) = callTracepointChildren[parent.eventId]
294+
?: error("TRMethodCallTracePoint ${parent.eventId} is not found in index")
265295

266-
data.seek(calculatePhysicalOffset(parent.threadId, start))
296+
data.seek(calculatePhysicalOffset(parent.threadId, start))
267297

268-
loadTracePoints(
269-
threadId = parent.threadId,
270-
maxRead = Integer.MAX_VALUE,
271-
reader = this::readTracePointWithPostprocessor,
272-
registrator = { idx, tracePoint, _ ->
273-
parent.loadChild(idx, tracePoint)
274-
}
275-
)
298+
loadTracePoints(
299+
threadId = parent.threadId,
300+
maxRead = Integer.MAX_VALUE,
301+
reader = this::readTracePointWithPostprocessor,
302+
registrator = { idx, tracePoint, _ ->
303+
parent.loadChild(idx, tracePoint)
304+
}
305+
)
276306

277-
val actualFooterPos = data.position() - 1 // 1 is size of object kind
278-
check(actualFooterPos == calculatePhysicalOffset(parent.threadId, end)) {
279-
"Input contains broken data: expected Tracepoint Footer for event ${parent.eventId} at position $end, got $actualFooterPos"
307+
val actualFooterPos = data.position() - 1 // 1 is size of object kind
308+
check(actualFooterPos == calculatePhysicalOffset(parent.threadId, end)) {
309+
"Input contains broken data: expected Tracepoint Footer for event ${parent.eventId} at position $end, got $actualFooterPos"
310+
}
280311
}
281312
}
282313

314+
/**
315+
* Load child of giver call tracepoint at given index.
316+
*
317+
* This method runs in context of this reader's context and doesn't require [inContext].
318+
*/
283319
fun loadChild(parent: TRMethodCallTracePoint, childIdx: Int): Unit = loadChildrenRange(parent, childIdx, 1)
284320

321+
/**
322+
* Load range of children of giver call tracepoint at given span.
323+
*
324+
* This method runs in context of this reader's context and doesn't require [inContext].
325+
*/
285326
fun loadChildrenRange(parent: TRMethodCallTracePoint, from: Int, count: Int) {
286-
require(from in 0 ..< parent.events.size) { "From index $from must be in range 0..<${parent.events.size}" }
287-
require(count in 1 .. parent.events.size - from) { "Count $count must be in range 1..${parent.events.size - from}" }
327+
withTraceContext(context) {
328+
require(from in 0..<parent.events.size) { "From index $from must be in range 0..<${parent.events.size}" }
329+
require(count in 1..parent.events.size - from) { "Count $count must be in range 1..${parent.events.size - from}" }
288330

289-
data.seek(calculatePhysicalOffset(parent.threadId, parent.getChildAddress(from)))
290-
loadTracePoints(
291-
threadId = parent.threadId,
292-
maxRead = count,
293-
reader = this::readTracePointWithPostprocessor,
294-
registrator = { idx, tracePoint, _ ->
295-
parent.loadChild(idx + from, tracePoint)
296-
}
297-
)
331+
data.seek(calculatePhysicalOffset(parent.threadId, parent.getChildAddress(from)))
332+
loadTracePoints(
333+
threadId = parent.threadId,
334+
maxRead = count,
335+
reader = this::readTracePointWithPostprocessor,
336+
registrator = { idx, tracePoint, _ ->
337+
parent.loadChild(idx + from, tracePoint)
338+
}
339+
)
340+
}
298341
}
299342

300343
fun getChildAndRestorePosition(parent: TRMethodCallTracePoint, childIdx: Int): TRTracePoint? {
@@ -436,7 +479,6 @@ class LazyTraceReader(
436479
}
437480

438481
private fun loadContextWithoutIndex() {
439-
val context = TRACE_CONTEXT
440482
// Two Longs is header
441483
data.seek((Long.SIZE_BYTES * 2).toLong())
442484
loadAllObjectsDeep(
@@ -570,14 +612,31 @@ class LazyTraceReader(
570612
data class TraceWithContext(
571613
val context: TraceContext,
572614
val roots: List<TRTracePoint>
573-
)
615+
) {
616+
/**
617+
* Run block with trace context belonging to this trace and restore global context afterward.
618+
*
619+
* @See withTraceContext
620+
*/
621+
inline fun <R> inContext(block: () -> R): R = withTraceContext(context, block)
622+
623+
/**
624+
* Switch global context to one belonging to this trace.
625+
*
626+
* This function is inherently unsafe: it doesn't prevent future switches of global context
627+
* by other code or threads.
628+
*
629+
* Please, use [inContext] instead if possible.
630+
*/
631+
@Deprecated("Use inContext if possible.")
632+
fun switchGlobalContext() = setGlobalTraceContext(context)
633+
}
574634

575635
fun loadRecordedTrace(inp: InputStream): TraceWithContext {
576636
DataInputStream(inp.buffered(INPUT_BUFFER_SIZE)).use { input ->
577637
checkDataHeader(input)
578638

579-
// TODO: Create empty fresh context
580-
val context = TRACE_CONTEXT
639+
val context = TraceContext()
581640
val roots = mutableMapOf<Int, MutableList<TRTracePoint>>()
582641

583642
loadAllObjectsDeep(

trace/src/main/org/jetbrains/lincheck/trace/Printing.kt

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@ fun printRecorderTrace(fileName: String?, context: TraceContext, rootCallsPerThr
2424
)
2525

2626
fun printRecorderTrace(output: OutputStream, context: TraceContext, rootCallsPerThread: List<TRTracePoint>, verbose: Boolean) {
27-
check(context == TRACE_CONTEXT) { "Now only global TRACE_CONTEXT is supported" }
28-
PrintStream(output.buffered(OUTPUT_BUFFER_SIZE)).use { output ->
29-
val appendable = DefaultTRTextAppendable(output, verbose)
30-
rootCallsPerThread.forEachIndexed { i, root ->
31-
output.println("# Thread ${i+1}")
32-
printTRPoint(appendable, root, 0)
27+
withTraceContext(context) {
28+
PrintStream(output.buffered(OUTPUT_BUFFER_SIZE)).use { output ->
29+
val appendable = DefaultTRTextAppendable(output, verbose)
30+
rootCallsPerThread.forEachIndexed { i, root ->
31+
output.println("# Thread ${i + 1}")
32+
printTRPoint(appendable, root, 0)
33+
}
3334
}
3435
}
3536
}

0 commit comments

Comments
 (0)