Skip to content

Commit dd433ce

Browse files
committed
pb: Add copy method generation
1 parent e91237c commit dd433ce

File tree

3 files changed

+99
-1
lines changed

3 files changed

+99
-1
lines changed

protoc-gen/common/src/main/kotlin/kotlinx/rpc/protoc/gen/core/CodeGenerator.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ open class CodeGenerator(
126126
condition: String,
127127
block: (CodeGenerator.() -> Unit),
128128
) {
129-
scope("$condition ->", block = block)
129+
scope("$condition ->", block = block, nlAfterClosed = false)
130130
}
131131

132132
private fun scopeWithSuffix(

protoc-gen/common/src/main/kotlin/kotlinx/rpc/protoc/gen/core/comments.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,13 @@ class Comment(
8282
leading.isEmpty() &&
8383
trailing.isEmpty()
8484
}
85+
86+
companion object {
87+
fun leading(comment: String): Comment = Comment(emptyList(), listOf(comment), emptyList())
88+
}
8589
}
8690

91+
8792
fun Descriptors.FileDescriptor.extractComments(): Map<String, Comment> {
8893
return toProto().sourceCodeInfo.locationList.associate {
8994
val leading = it.leadingComments ?: ""

protoc-gen/protobuf/src/main/kotlin/kotlinx/rpc/protoc/gen/ModelToProtobufKotlinCommonGenerator.kt

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import com.google.protobuf.ByteString
1010
import com.google.protobuf.Descriptors
1111
import kotlinx.rpc.protoc.gen.core.AModelToKotlinCommonGenerator
1212
import kotlinx.rpc.protoc.gen.core.CodeGenerator
13+
import kotlinx.rpc.protoc.gen.core.Comment
1314
import kotlinx.rpc.protoc.gen.core.Config
1415
import kotlinx.rpc.protoc.gen.core.INTERNAL_RPC_API_ANNO
1516
import kotlinx.rpc.protoc.gen.core.PB_PKG
@@ -93,6 +94,12 @@ class ModelToProtobufKotlinCommonGenerator(
9394
)
9495
}
9596

97+
function("copy",
98+
args = "body: ${declaration.internalClassFullName()}.() -> Unit = {}",
99+
returnType = declaration.name.safeFullName(),
100+
comment = Comment.leading("Copies the original message, including unknown fields.")
101+
)
102+
96103
if (declaration.actualFields.isNotEmpty()) {
97104
newLine()
98105
}
@@ -179,6 +186,8 @@ class ModelToProtobufKotlinCommonGenerator(
179186
generateOneOfHashCode(declaration)
180187
generateEquals(declaration)
181188
generateToString(declaration)
189+
generateCopy(declaration)
190+
generateOneOfCopy(declaration)
182191

183192
declaration.nestedDeclarations.forEach { nested ->
184193
generateInternalMessage(nested)
@@ -378,6 +387,88 @@ class ModelToProtobufKotlinCommonGenerator(
378387
}
379388
}
380389

390+
private fun CodeGenerator.generateCopy(declaration: MessageDeclaration) {
391+
if (!declaration.isUserFacing) {
392+
// e.g., internal map entries don't need a copy() method
393+
return
394+
}
395+
function(
396+
name = "copy",
397+
modifiers = "override",
398+
args = "body: ${declaration.internalClassName()}.() -> Unit",
399+
returnType = declaration.internalClassName(),
400+
) {
401+
code("val copy = ${declaration.internalClassName()}()")
402+
for (field in declaration.actualFields) {
403+
// write each field to the new copy object
404+
if (field.presenceIdx != null) {
405+
// if the field has presence, we need to check if it was set in the original object.
406+
// if it was set, we copy it to the new object, otherwise we leave it unset.
407+
ifBranch(condition = "presenceMask[${field.presenceIdx}]", ifBlock = {
408+
code("copy.${field.name} = ${field.type.copyCall(field.name)}")
409+
})
410+
} else {
411+
// by default, we copy the field value
412+
code("copy.${field.name} = ${field.type.copyCall(field.name)}")
413+
}
414+
}
415+
code("copy.apply(body)")
416+
code("return copy")
417+
}
418+
}
419+
420+
private fun FieldType.copyCall(varName: String): String {
421+
return when (this) {
422+
is FieldType.IntegralType -> varName
423+
is FieldType.Enum -> varName
424+
is FieldType.List -> "$varName.map { ${value.copyCall("it")} }"
425+
is FieldType.Map -> "$varName.mapValues { ${entry.value.copyCall("it.value")} }"
426+
is FieldType.Message -> "$varName.copy()"
427+
is FieldType.OneOf -> "$varName?.oneOfCopy()"
428+
}
429+
}
430+
431+
private fun CodeGenerator.generateOneOfCopy(declaration: MessageDeclaration) {
432+
declaration.oneOfDeclarations.forEach { oneOf ->
433+
val oneOfFullName = oneOf.name.safeFullName()
434+
function(
435+
name = "oneOfCopy",
436+
returnType = oneOfFullName,
437+
contextReceiver = oneOfFullName,
438+
) {
439+
// check if the type is copy by value (no need for deep copy)
440+
val copyByValue = { type: FieldType -> type is FieldType.IntegralType || type is FieldType.Enum }
441+
442+
// if all variants are integral or enum types, we can just return this directly.
443+
val fastPath = oneOf.variants.all { copyByValue(it.type) }
444+
if (fastPath) {
445+
code("return this")
446+
} else {
447+
// dispatch on all possible variants and copy its value
448+
whenBlock(
449+
prefix = "return",
450+
condition = "this"
451+
) {
452+
oneOf.variants.forEach { variant ->
453+
val variantName = "$oneOfFullName.${variant.name}"
454+
whenCase("is $variantName") {
455+
if (copyByValue(variant.type)) {
456+
// no need to reconstruct a new object, we can just return this
457+
code("this")
458+
} else {
459+
code("$variantName(${variant.type.copyCall("this.value")})")
460+
}
461+
}
462+
}
463+
}
464+
}
465+
466+
}
467+
}
468+
}
469+
470+
471+
381472
private fun CodeGenerator.generatePresenceIndicesObject(declaration: MessageDeclaration) {
382473
if (declaration.presenceMaskSize == 0) {
383474
return
@@ -1191,6 +1282,8 @@ class ModelToProtobufKotlinCommonGenerator(
11911282

11921283
additionalPublicImports.add("kotlin.jvm.JvmInline")
11931284
}
1285+
1286+
11941287
}
11951288
}
11961289

0 commit comments

Comments
 (0)