@@ -10,6 +10,7 @@ import com.google.protobuf.ByteString
1010import com.google.protobuf.Descriptors
1111import kotlinx.rpc.protoc.gen.core.AModelToKotlinCommonGenerator
1212import kotlinx.rpc.protoc.gen.core.CodeGenerator
13+ import kotlinx.rpc.protoc.gen.core.Comment
1314import kotlinx.rpc.protoc.gen.core.Config
1415import kotlinx.rpc.protoc.gen.core.INTERNAL_RPC_API_ANNO
1516import kotlinx.rpc.protoc.gen.core.PB_PKG
@@ -93,6 +94,12 @@ class ModelToProtobufKotlinCommonGenerator(
9394 )
9495 }
9596
97+ function(" copy" ,
98+ args = " body: ${declaration.internalClassFullName()} .() -> Unit = {}" ,
99+ returnType = declaration.name.safeFullName(),
100+ comment = Comment .leading(" Copies the original message, including unknown fields." )
101+ )
102+
96103 if (declaration.actualFields.isNotEmpty()) {
97104 newLine()
98105 }
@@ -179,6 +186,8 @@ class ModelToProtobufKotlinCommonGenerator(
179186 generateOneOfHashCode(declaration)
180187 generateEquals(declaration)
181188 generateToString(declaration)
189+ generateCopy(declaration)
190+ generateOneOfCopy(declaration)
182191
183192 declaration.nestedDeclarations.forEach { nested ->
184193 generateInternalMessage(nested)
@@ -378,6 +387,88 @@ class ModelToProtobufKotlinCommonGenerator(
378387 }
379388 }
380389
390+ private fun CodeGenerator.generateCopy (declaration : MessageDeclaration ) {
391+ if (! declaration.isUserFacing) {
392+ // e.g., internal map entries don't need a copy() method
393+ return
394+ }
395+ function(
396+ name = " copy" ,
397+ modifiers = " override" ,
398+ args = " body: ${declaration.internalClassName()} .() -> Unit" ,
399+ returnType = declaration.internalClassName(),
400+ ) {
401+ code(" val copy = ${declaration.internalClassName()} ()" )
402+ for (field in declaration.actualFields) {
403+ // write each field to the new copy object
404+ if (field.presenceIdx != null ) {
405+ // if the field has presence, we need to check if it was set in the original object.
406+ // if it was set, we copy it to the new object, otherwise we leave it unset.
407+ ifBranch(condition = " presenceMask[${field.presenceIdx} ]" , ifBlock = {
408+ code(" copy.${field.name} = ${field.type.copyCall(field.name)} " )
409+ })
410+ } else {
411+ // by default, we copy the field value
412+ code(" copy.${field.name} = ${field.type.copyCall(field.name)} " )
413+ }
414+ }
415+ code(" copy.apply(body)" )
416+ code(" return copy" )
417+ }
418+ }
419+
420+ private fun FieldType.copyCall (varName : String ): String {
421+ return when (this ) {
422+ is FieldType .IntegralType -> varName
423+ is FieldType .Enum -> varName
424+ is FieldType .List -> " $varName .map { ${value.copyCall(" it" )} }"
425+ is FieldType .Map -> " $varName .mapValues { ${entry.value.copyCall(" it.value" )} }"
426+ is FieldType .Message -> " $varName .copy()"
427+ is FieldType .OneOf -> " $varName ?.oneOfCopy()"
428+ }
429+ }
430+
431+ private fun CodeGenerator.generateOneOfCopy (declaration : MessageDeclaration ) {
432+ declaration.oneOfDeclarations.forEach { oneOf ->
433+ val oneOfFullName = oneOf.name.safeFullName()
434+ function(
435+ name = " oneOfCopy" ,
436+ returnType = oneOfFullName,
437+ contextReceiver = oneOfFullName,
438+ ) {
439+ // check if the type is copy by value (no need for deep copy)
440+ val copyByValue = { type: FieldType -> type is FieldType .IntegralType || type is FieldType .Enum }
441+
442+ // if all variants are integral or enum types, we can just return this directly.
443+ val fastPath = oneOf.variants.all { copyByValue(it.type) }
444+ if (fastPath) {
445+ code(" return this" )
446+ } else {
447+ // dispatch on all possible variants and copy its value
448+ whenBlock(
449+ prefix = " return" ,
450+ condition = " this"
451+ ) {
452+ oneOf.variants.forEach { variant ->
453+ val variantName = " $oneOfFullName .${variant.name} "
454+ whenCase(" is $variantName " ) {
455+ if (copyByValue(variant.type)) {
456+ // no need to reconstruct a new object, we can just return this
457+ code(" this" )
458+ } else {
459+ code(" $variantName (${variant.type.copyCall(" this.value" )} )" )
460+ }
461+ }
462+ }
463+ }
464+ }
465+
466+ }
467+ }
468+ }
469+
470+
471+
381472 private fun CodeGenerator.generatePresenceIndicesObject (declaration : MessageDeclaration ) {
382473 if (declaration.presenceMaskSize == 0 ) {
383474 return
@@ -1191,6 +1282,8 @@ class ModelToProtobufKotlinCommonGenerator(
11911282
11921283 additionalPublicImports.add(" kotlin.jvm.JvmInline" )
11931284 }
1285+
1286+
11941287 }
11951288 }
11961289
0 commit comments