@@ -82,13 +82,16 @@ class ProtobufTransformer:
8282 def __init__ (
8383 self ,
8484 graphql_schema : GraphQLSchema ,
85+ selection_query : DocumentNode ,
8586 root_type : str | None = None ,
8687 flatten_naming : bool = False ,
8788 package_name : str | None = None ,
8889 naming_config : dict [str , Any ] | None = None ,
8990 expanded_instances : bool = False ,
90- selection_query : DocumentNode | None = None ,
9191 ):
92+ if selection_query is None :
93+ raise ValueError ("selection_query is required" )
94+
9295 self .graphql_schema = graphql_schema
9396 self .root_type = root_type
9497 self .flatten_naming = flatten_naming
@@ -151,7 +154,7 @@ def transform(self) -> str:
151154 # If no fields reference that object type directly (non-flattened), the type definition is no longer needed.
152155 # However, unions and enums cannot be flattened and must remain as separate type definitions.
153156 (
154- proto_schema . flattened_fields ,
157+ flattened_fields ,
155158 referenced_type_names ,
156159 flattened_root_types ,
157160 ) = self ._build_flattened_fields (message_types )
@@ -167,7 +170,17 @@ def transform(self) -> str:
167170 proto_schema .unions = self ._build_unions (union_types )
168171 proto_schema .messages = self ._build_messages (message_types )
169172
170- template_name = "proto_flattened.j2" if self .flatten_naming else "proto_standard.j2"
173+ if self .flatten_naming :
174+ root_message_name = self ._get_query_operation_name ()
175+ root_message_source = f"query: { root_message_name } "
176+ root_message = ProtoMessage (
177+ name = root_message_name ,
178+ fields = flattened_fields ,
179+ source = root_message_source ,
180+ )
181+ proto_schema .messages .append (root_message )
182+
183+ template_name = "proto_standard.j2"
171184 template = self .env .get_template (template_name )
172185
173186 template_vars = self ._build_template_vars (proto_schema )
@@ -185,9 +198,7 @@ def check_message(message: ProtoMessage) -> bool:
185198 return True
186199 return any (check_message (nested ) for nested in message .nested_messages )
187200
188- return any (field .validation_rules for field in proto_schema .flattened_fields ) or any (
189- check_message (message ) for message in proto_schema .messages
190- )
201+ return any (check_message (message ) for message in proto_schema .messages )
191202
192203 def _has_source_option (self , proto_schema : ProtoSchema ) -> bool :
193204 """Check if any type in the schema has a source option."""
@@ -199,9 +210,6 @@ def _get_query_operation_name(self) -> str:
199210 """Extract the operation name from the selection query, defaulting to appropriate fallback."""
200211 default_name = "Message" if self .flatten_naming else "Query"
201212
202- if not self .selection_query :
203- return default_name
204-
205213 for definition in self .selection_query .definitions :
206214 if not isinstance (definition , OperationDefinitionNode ) or definition .operation != OperationType .QUERY :
207215 continue
@@ -226,7 +234,6 @@ def _build_template_vars(self, proto_schema: ProtoSchema) -> dict[str, Any]:
226234 template_vars = proto_schema .model_dump ()
227235 template_vars ["imports" ] = imports
228236 template_vars ["has_source_option" ] = has_source_option
229- template_vars ["message_name" ] = self ._get_query_operation_name ()
230237
231238 return template_vars
232239
@@ -259,15 +266,18 @@ def _build_messages(self, message_types: list[GraphQLObjectType | GraphQLInterfa
259266 fields , nested_messages = self ._build_message_fields (message_type )
260267
261268 message_name = message_type .name
269+ source = message_type .name
270+
262271 if message_type .name == "Query" :
263272 message_name = self ._get_query_operation_name ()
273+ source = f"query: { message_name } "
264274
265275 messages .append (
266276 ProtoMessage (
267277 name = message_name ,
268278 fields = fields ,
269279 description = message_type .description ,
270- source = message_type . name ,
280+ source = source ,
271281 nested_messages = nested_messages ,
272282 )
273283 )
0 commit comments