Skip to content

Commit 15952b8

Browse files
committed
feat(protobuf): add source tracking for query types and flattened fields
Signed-off-by: Ahmed Mohamed <[email protected]>
1 parent b6c4ff4 commit 15952b8

File tree

6 files changed

+212
-166
lines changed

6 files changed

+212
-166
lines changed

src/s2dm/cli.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ def jsonschema(
587587
def protobuf(
588588
ctx: click.Context,
589589
schemas: list[Path],
590-
selection_query: Path | None,
590+
selection_query: Path,
591591
output: Path,
592592
root_type: str | None,
593593
flatten_naming: bool,
@@ -598,13 +598,11 @@ def protobuf(
598598
naming_config = ctx.obj.get("naming_config")
599599
graphql_schema = load_schema_with_naming(schemas, naming_config)
600600

601-
query_document = None
602-
if selection_query:
603-
query_document = parse(selection_query.read_text())
604-
graphql_schema = prune_schema_using_query_selection(graphql_schema, query_document)
601+
query_document = parse(selection_query.read_text())
602+
graphql_schema = prune_schema_using_query_selection(graphql_schema, query_document)
605603

606604
result = translate_to_protobuf(
607-
graphql_schema, root_type, flatten_naming, package_name, naming_config, expanded_instances, query_document
605+
graphql_schema, query_document, root_type, flatten_naming, package_name, naming_config, expanded_instances
608606
)
609607
_ = output.write_text(result)
610608

src/s2dm/exporters/protobuf/models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,4 +74,3 @@ class ProtoSchema(BaseModel):
7474
messages: list[ProtoMessage] = Field(default_factory=list)
7575
unions: list[ProtoUnion] = Field(default_factory=list)
7676
flatten_mode: bool = False
77-
flattened_fields: list[ProtoField] = Field(default_factory=list)

src/s2dm/exporters/protobuf/protobuf.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99

1010
def transform(
1111
graphql_schema: GraphQLSchema,
12+
selection_query: DocumentNode,
1213
root_type: str | None = None,
1314
flatten_naming: bool = False,
1415
package_name: str | None = None,
1516
naming_config: dict[str, Any] | None = None,
1617
expanded_instances: bool = False,
17-
selection_query: DocumentNode | None = None,
1818
) -> str:
1919
"""
2020
Transform a GraphQL schema object to Protocol Buffers format.
@@ -26,10 +26,13 @@ def transform(
2626
package_name: Optional package name for the .proto file
2727
naming_config: Optional naming configuration
2828
expanded_instances: If True, expand instance tags into nested structures
29-
selection_query: Optional selection query document to determine root-level types
29+
selection_query: Required selection query document to determine root-level types
3030
3131
Returns:
3232
str: Protocol Buffers representation as a string
33+
34+
Raises:
35+
ValueError: If selection_query is not provided
3336
"""
3437
log.info(f"Transforming GraphQL schema to Protobuf with {len(graphql_schema.type_map)} types")
3538

@@ -39,7 +42,7 @@ def transform(
3942
log.info(f"Using root type: {root_type}")
4043

4144
transformer = ProtobufTransformer(
42-
graphql_schema, root_type, flatten_naming, package_name, naming_config, expanded_instances, selection_query
45+
graphql_schema, selection_query, root_type, flatten_naming, package_name, naming_config, expanded_instances
4346
)
4447
proto_content = transformer.transform()
4548

@@ -50,12 +53,12 @@ def transform(
5053

5154
def translate_to_protobuf(
5255
schema: GraphQLSchema,
56+
selection_query: DocumentNode,
5357
root_type: str | None = None,
5458
flatten_naming: bool = False,
5559
package_name: str | None = None,
5660
naming_config: dict[str, Any] | None = None,
5761
expanded_instances: bool = False,
58-
selection_query: DocumentNode | None = None,
5962
) -> str:
6063
"""
6164
Translate a GraphQL schema to Protocol Buffers format.
@@ -67,11 +70,14 @@ def translate_to_protobuf(
6770
package_name: Optional package name for the .proto file
6871
naming_config: Optional naming configuration
6972
expanded_instances: If True, expand instance tags into nested structures
70-
selection_query: Optional selection query document to determine root-level types
73+
selection_query: Required selection query document to determine root-level types
7174
7275
Returns:
7376
str: Protocol Buffers (.proto) representation as a string
77+
78+
Raises:
79+
ValueError: If selection_query is not provided
7480
"""
7581
return transform(
76-
schema, root_type, flatten_naming, package_name, naming_config, expanded_instances, selection_query
82+
schema, selection_query, root_type, flatten_naming, package_name, naming_config, expanded_instances
7783
)

src/s2dm/exporters/protobuf/templates/proto_flattened.j2

Lines changed: 0 additions & 82 deletions
This file was deleted.

src/s2dm/exporters/protobuf/transformer.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,16 @@ class ProtobufTransformer:
8282
def __init__(
8383
self,
8484
graphql_schema: GraphQLSchema,
85+
selection_query: DocumentNode,
8586
root_type: str | None = None,
8687
flatten_naming: bool = False,
8788
package_name: str | None = None,
8889
naming_config: dict[str, Any] | None = None,
8990
expanded_instances: bool = False,
90-
selection_query: DocumentNode | None = None,
9191
):
92+
if selection_query is None:
93+
raise ValueError("selection_query is required")
94+
9295
self.graphql_schema = graphql_schema
9396
self.root_type = root_type
9497
self.flatten_naming = flatten_naming
@@ -151,7 +154,7 @@ def transform(self) -> str:
151154
# If no fields reference that object type directly (non-flattened), the type definition is no longer needed.
152155
# However, unions and enums cannot be flattened and must remain as separate type definitions.
153156
(
154-
proto_schema.flattened_fields,
157+
flattened_fields,
155158
referenced_type_names,
156159
flattened_root_types,
157160
) = self._build_flattened_fields(message_types)
@@ -167,7 +170,17 @@ def transform(self) -> str:
167170
proto_schema.unions = self._build_unions(union_types)
168171
proto_schema.messages = self._build_messages(message_types)
169172

170-
template_name = "proto_flattened.j2" if self.flatten_naming else "proto_standard.j2"
173+
if self.flatten_naming:
174+
root_message_name = self._get_query_operation_name()
175+
root_message_source = f"query: {root_message_name}"
176+
root_message = ProtoMessage(
177+
name=root_message_name,
178+
fields=flattened_fields,
179+
source=root_message_source,
180+
)
181+
proto_schema.messages.append(root_message)
182+
183+
template_name = "proto_standard.j2"
171184
template = self.env.get_template(template_name)
172185

173186
template_vars = self._build_template_vars(proto_schema)
@@ -185,9 +198,7 @@ def check_message(message: ProtoMessage) -> bool:
185198
return True
186199
return any(check_message(nested) for nested in message.nested_messages)
187200

188-
return any(field.validation_rules for field in proto_schema.flattened_fields) or any(
189-
check_message(message) for message in proto_schema.messages
190-
)
201+
return any(check_message(message) for message in proto_schema.messages)
191202

192203
def _has_source_option(self, proto_schema: ProtoSchema) -> bool:
193204
"""Check if any type in the schema has a source option."""
@@ -199,9 +210,6 @@ def _get_query_operation_name(self) -> str:
199210
"""Extract the operation name from the selection query, defaulting to appropriate fallback."""
200211
default_name = "Message" if self.flatten_naming else "Query"
201212

202-
if not self.selection_query:
203-
return default_name
204-
205213
for definition in self.selection_query.definitions:
206214
if not isinstance(definition, OperationDefinitionNode) or definition.operation != OperationType.QUERY:
207215
continue
@@ -226,7 +234,6 @@ def _build_template_vars(self, proto_schema: ProtoSchema) -> dict[str, Any]:
226234
template_vars = proto_schema.model_dump()
227235
template_vars["imports"] = imports
228236
template_vars["has_source_option"] = has_source_option
229-
template_vars["message_name"] = self._get_query_operation_name()
230237

231238
return template_vars
232239

@@ -259,15 +266,18 @@ def _build_messages(self, message_types: list[GraphQLObjectType | GraphQLInterfa
259266
fields, nested_messages = self._build_message_fields(message_type)
260267

261268
message_name = message_type.name
269+
source = message_type.name
270+
262271
if message_type.name == "Query":
263272
message_name = self._get_query_operation_name()
273+
source = f"query: {message_name}"
264274

265275
messages.append(
266276
ProtoMessage(
267277
name=message_name,
268278
fields=fields,
269279
description=message_type.description,
270-
source=message_type.name,
280+
source=source,
271281
nested_messages=nested_messages,
272282
)
273283
)

0 commit comments

Comments
 (0)