diff --git a/README.md b/README.md index 5d28fda..e03a32f 100644 --- a/README.md +++ b/README.md @@ -24,3 +24,38 @@ uvx env-example # Exclude specific directories relative to the project root uvx env-example --exclude-dir other/scripts ``` + +# Example +```python +from pydantic import BaseSettings + + +class AppSettings(BaseSettings): + model_config = { + "env_prefix": "APP__" + } + debug: bool + log_level: str + +class DatabaseSettings(BaseSettings): + model_config = { + "env_prefix": "DB__" + } + host: str + port: int + username: str + password: str +``` + +env-example will generate the following `.env.example` file: +```shell +# AppSettings +APP__DEBUG= +APP__LOG_LEVEL= + +# DatabaseSettings +DB__HOST= +DB__PORT= +DB__USERNAME= +DB__PASSWORD= +``` diff --git a/src/env_example/main.py b/src/env_example/main.py index 4faf5ab..d95e706 100644 --- a/src/env_example/main.py +++ b/src/env_example/main.py @@ -7,10 +7,11 @@ Call, ClassDef, Constant, + Dict, Name, ) from collections import defaultdict -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import Iterator, Self @@ -49,13 +50,6 @@ def __lt__(self, other: Self) -> bool: BASE_SETTINGS_FQN = QualifiedName(("pydantic_settings", "BaseSettings")) -@dataclass(frozen=True) -class SettingField: - name: str - settings_class: str - prefix: str | None = None - - @dataclass(frozen=True) class ModuleImport: module: str @@ -64,10 +58,33 @@ class ModuleImport: @dataclass(frozen=True) class NameImport: - module: str + module: str | None name: str + level: int alias: str | None = None + def __post_init__(self): + if not self.module and not self.level: + raise ValueError("Absolute imports must have a module component") + + def get_qualified_parent_module( + self, + current: QualifiedName, + ) -> QualifiedName: + """Resolve the qualified parent module for a relative import""" + if not self.level: + assert self.module + return QualifiedName.from_str(self.module) + + parent_qn = current + for _ in range(self.level): + parent_qn = parent_qn.parent + + if not self.module: + return parent_qn + + return parent_qn.child(self.module) + type ImportItem = ModuleImport | NameImport @@ -78,23 +95,10 @@ class ParsedModule: classes: dict[str, ast.ClassDef] -class InheritanceHierarchy: - def __init__(self) -> None: - self._children: defaultdict[QualifiedName, set[QualifiedName]] = ( - defaultdict(set) - ) - - def add_relation(self, parent: QualifiedName, child: QualifiedName): - self._children[parent].add(child) - - def transitive_subclasses( - self, class_name: QualifiedName - ) -> set[QualifiedName]: - reachable = set() - for child in self._children[class_name]: - reachable.add(child) - reachable.update(self.transitive_subclasses(child)) - return reachable +@dataclass +class ParsedSettings: + prefix: str | None = None + fields: set[str] = field(default_factory=set) def main() -> None: @@ -119,60 +123,55 @@ def generate_env_example( exclude_relative: list[Path] | None, ) -> None: """ - Orchestrator function + Orchestrator function. 1. Parse the modules and map the package structure of the project - 2. Build the inheritance hierarchy - 3. Calculate all transitive subclasses of BaseSettings - 4. Extract fields for all settings classes - 5. Write them to an .env.example file + 2. Build a class inheritance lookup + 3. Parse settings for the subclasses of BaseSettings + 4. Write them to an .env.example file """ exclude_absolute: set[Path] = ( {p.resolve() for p in exclude_relative} if exclude_relative else set() ) - module_hierarchy: dict[QualifiedName, ParsedModule] = {} + module_lookup: dict[QualifiedName, ParsedModule] = {} for fqn, ast_module in walk_project( root=project_root, exclude_paths=exclude_absolute, ): classes = filter_module_by_type(ast_module, ast.ClassDef) - module_hierarchy[fqn] = ParsedModule( + module_lookup[fqn] = ParsedModule( ast_module=ast_module, classes={cd.name: cd for cd in classes}, ) - inheritance = InheritanceHierarchy() - for fqn, parsed_module in module_hierarchy.items(): + child_lookup: defaultdict[QualifiedName, list[QualifiedName]] = ( + defaultdict(list) + ) + for fqn, parsed_module in module_lookup.items(): for class_def in parsed_module.classes.values(): class_fqn = fqn.child(class_def.name) for base in get_bases_from_class(class_def): parent = find_source_or_external_import( searched_symbol=base, search_module=fqn, - module_lookup=module_hierarchy, + module_lookup=module_lookup, ) if parent: - inheritance.add_relation( - parent=parent, - child=class_fqn, - ) - - settings_subclasses = inheritance.transitive_subclasses(BASE_SETTINGS_FQN) - - fields_per_class: dict[str, list[SettingField]] = {} - for fqn in sorted(settings_subclasses): - class_def = module_hierarchy[fqn.parent].classes[fqn.leaf] - fields_per_class[class_def.name] = extract_fields_from_settings( - class_def + child_lookup[parent].append(class_fqn) + + parsed_settings = defaultdict(ParsedSettings) + children = child_lookup[BASE_SETTINGS_FQN] + for child in children: + gather_settings_for_subtree( + node=child, + child_lookup=child_lookup, + module_lookup=module_lookup, + parsed_settings=parsed_settings, ) - env_example_txt = build_env_example(fields_per_class) + env_example_txt = build_env_example(parsed_settings) if env_example_txt: - write_to_file(env_example_txt, project_root / OUTPUT_FILE) - - -def write_to_file(text: str, file: Path) -> None: - file.write_text(text) + (project_root / OUTPUT_FILE).write_text(env_example_txt) def walk_project( @@ -199,7 +198,7 @@ def walk_dir( ) for item in sorted(dir.iterdir()): - if is_package and item.is_file() and item.suffix == ".py": + if item.is_file() and item.suffix == ".py": module = ast.parse(item.read_text()) module_fqn = ( new_parent @@ -213,11 +212,40 @@ def walk_dir( and item.name not in ALWAYS_EXCLUDE_DIRS and item not in exclude_paths ): - yield from walk_dir(item, parent_package=parent_package) + yield from walk_dir(item, parent_package=new_parent) yield from walk_dir(root, parent_package=QualifiedName(())) +def gather_settings_for_subtree( + node: QualifiedName, + child_lookup: defaultdict[QualifiedName, list[QualifiedName]], + module_lookup: dict[QualifiedName, ParsedModule], + parsed_settings: defaultdict[QualifiedName, ParsedSettings], +) -> None: + """ + Recursively parses fieldsfrom settings classes and adds them to + an aggregator for both the currently considered class and its children. + """ + class_def = module_lookup[node.parent].classes[node.leaf] + fields = parse_fields_from_settings(class_def) + prefix = parse_settings_prefix(class_def) + + parsed_settings[node].prefix = prefix + parsed_settings[node].fields.update(fields) + + for child in child_lookup[node]: + # add parent fields for the child settings class + parsed_settings[child].fields.update(parsed_settings[node].fields) + + gather_settings_for_subtree( + node=child, + child_lookup=child_lookup, + module_lookup=module_lookup, + parsed_settings=parsed_settings, + ) + + def find_source_or_external_import( searched_symbol: QualifiedName, search_module: QualifiedName, @@ -229,11 +257,8 @@ def find_source_or_external_import( - the fqn to the implementation of the searched symbol - None if none of the above, and no imports can be followed """ - match searched_symbol.parts: - case (symbol_object_name,): - symbol_module_ref = None - case (*_, symbol_module_ref, symbol_object_name): - pass + *module_parts, symbol_object_name = searched_symbol.parts + symbol_module_ref = ".".join(module_parts) or None parsed_module = module_lookup.get(search_module) @@ -249,10 +274,13 @@ def find_source_or_external_import( imports = resolve_import_statements(parsed_module.ast_module) for imp in imports: match imp: - case NameImport(module, name, _) if name == symbol_object_name: + case NameImport(module, name) if name == symbol_object_name: + resolved = imp.get_qualified_parent_module( + current=search_module + ) return find_source_or_external_import( searched_symbol=QualifiedName((name,)), - search_module=QualifiedName.from_str(module), + search_module=resolved, module_lookup=module_lookup, ) case ModuleImport(module, alias) if ( @@ -267,51 +295,82 @@ def find_source_or_external_import( return None -def build_env_example(fields_per_class: dict[str, list[SettingField]]) -> str: - if not fields_per_class: +def build_env_example( + parsed_settings: dict[QualifiedName, ParsedSettings], +) -> str: + if not parsed_settings: return "" sections = [ - f"# {class_name}\n" + f"# {qn.leaf}\n" + "\n".join( - f"{field.prefix or ''}{field.name}=".upper() for field in fields + f"{parsed.prefix or ''}{field}=".upper() + for field in sorted(parsed.fields) + ) + for qn, parsed in sorted( + parsed_settings.items(), key=lambda x: x[0].leaf ) - for class_name, fields in fields_per_class.items() ] return "\n\n".join(sections) + "\n" -def extract_fields_from_settings(cd: ClassDef) -> list[SettingField]: +def parse_settings_prefix(cd: ClassDef) -> str | None: + """ + Parses the model_config configuration to find the configured + prefix. model_config can be given as a SettingConfigDict and + as a plain dict. we cover both cases. + """ prefixes: list[str] = [] for item in cd.body: - if not isinstance(item, (Assign, AnnAssign)): + if isinstance(item, AnnAssign): + target = item.target + value = item.value + elif isinstance(item, Assign) and len(item.targets) == 1: + target = item.targets[0] + value = item.value + else: continue - value = item.value - if not isinstance(value, Call): + if not (isinstance(target, Name) and target.id == "model_config"): continue - if not ( - isinstance(value.func, Name) - and value.func.id == SETTINGS_CONFIG_CLASS - ): - continue - - for kw in value.keywords: - if ( - kw.arg == ENV_PREFIX_ARG - and isinstance(kw.value, Constant) - and isinstance(kw.value.value, str) + if isinstance(value, Call): + # SettingsConfigDict case + if not ( + isinstance(value.func, Name) + and value.func.id == SETTINGS_CONFIG_CLASS ): - prefixes.append(kw.value.value) + continue + for kw in value.keywords: + if ( + kw.arg == ENV_PREFIX_ARG + and isinstance(kw.value, Constant) + and isinstance(kw.value.value, str) + ): + prefixes.append(kw.value.value) + + elif isinstance(value, Dict): + # plain dict case + for key, val in zip(value.keys, value.values): + if ( + isinstance(key, Constant) + and key.value == ENV_PREFIX_ARG + and isinstance(val, Constant) + and isinstance(val.value, str) + ): + prefixes.append(val.value) if len(prefixes) > 1: raise ValueError( - f"Multiple prefixes found for class {cd.name}: {(prefixes,)}" + f"Multiple prefixes found for class {cd.name}: {prefixes}" ) prefix = prefixes[0] if prefixes else None - fields: list[SettingField] = [] + return prefix + + +def parse_fields_from_settings(cd: ClassDef) -> list[str]: + fields: list[str] = [] for elem in cd.body: if not isinstance(elem, AnnAssign): @@ -319,13 +378,7 @@ def extract_fields_from_settings(cd: ClassDef) -> list[SettingField]: if not isinstance(elem.target, Name): continue name: str = elem.target.id - fields.append( - SettingField( - name=name, - settings_class=cd.name, - prefix=prefix, - ) - ) + fields.append(name) return fields @@ -335,8 +388,15 @@ def get_bases_from_class(cd: ClassDef) -> list[QualifiedName]: for base in cd.bases: if isinstance(base, Name): bases.append(QualifiedName((base.id,))) - elif isinstance(base, Attribute) and isinstance(base.value, Name): - bases.append(QualifiedName((base.value.id, base.attr))) + elif isinstance(base, Attribute): + parts: list[str] = [base.attr] + node = base.value + while isinstance(node, Attribute): + parts.append(node.attr) + node = node.value + if isinstance(node, Name): + parts.append(node.id) + bases.append(QualifiedName(tuple(reversed(parts)))) return bases @@ -352,10 +412,13 @@ def resolve_import_statements(module: ast.Module) -> list[ImportItem]: ModuleImport(module=name.name, alias=name.asname) for name in item.names ) - elif isinstance(item, ast.ImportFrom) and item.module: + elif isinstance(item, ast.ImportFrom): imports.extend( NameImport( - module=item.module, name=name.name, alias=name.asname + module=item.module, + name=name.name, + alias=name.asname, + level=item.level, ) for name in item.names ) diff --git a/tests/cases/default_exclude/.env.example.expected b/tests/cases/default_exclude/.env.example.expected index f199de9..6e3a486 100644 --- a/tests/cases/default_exclude/.env.example.expected +++ b/tests/cases/default_exclude/.env.example.expected @@ -1,2 +1,2 @@ -# IncludedSettings +# Settings FIELD= diff --git a/tests/cases/default_exclude/project/package/included/module.py b/tests/cases/default_exclude/project/package/included/module.py index 23fb179..8af5378 100644 --- a/tests/cases/default_exclude/project/package/included/module.py +++ b/tests/cases/default_exclude/project/package/included/module.py @@ -1,5 +1,5 @@ from pydantic_settings import BaseSettings -class IncludedSettings(BaseSettings): +class Settings(BaseSettings): field: int diff --git a/tests/cases/default_exclude/project/package/module.py b/tests/cases/default_exclude/project/package/site-packages/__init__.py similarity index 100% rename from tests/cases/default_exclude/project/package/module.py rename to tests/cases/default_exclude/project/package/site-packages/__init__.py diff --git a/tests/cases/default_exclude/project/package/site-packages/module.py b/tests/cases/default_exclude/project/package/site-packages/module.py new file mode 100644 index 0000000..291afa9 --- /dev/null +++ b/tests/cases/default_exclude/project/package/site-packages/module.py @@ -0,0 +1,5 @@ +from pydantic_settings import BaseSettings + + +class ExcludedSettings(BaseSettings): + field: int diff --git a/tests/cases/main_file/.env.example.expected b/tests/cases/main_file/.env.example.expected new file mode 100644 index 0000000..264e77c --- /dev/null +++ b/tests/cases/main_file/.env.example.expected @@ -0,0 +1,9 @@ +# InheritedSettings +CHILD_FIELD= +PACKAGE_FIELD= + +# MainSettings +MAIN_FIELD= + +# Settings +PACKAGE_FIELD= diff --git a/tests/cases/main_file/project/main.py b/tests/cases/main_file/project/main.py new file mode 100644 index 0000000..7a394af --- /dev/null +++ b/tests/cases/main_file/project/main.py @@ -0,0 +1,10 @@ +from package.module import Settings +from pydantic_settings import BaseSettings + + +class MainSettings(BaseSettings): + main_field: int + + +class InheritedSettings(Settings): + child_field: int diff --git a/tests/cases/main_file/project/package/__init__.py b/tests/cases/main_file/project/package/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/cases/main_file/project/package/module.py b/tests/cases/main_file/project/package/module.py new file mode 100644 index 0000000..08c9e5a --- /dev/null +++ b/tests/cases/main_file/project/package/module.py @@ -0,0 +1,5 @@ +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings): + package_field: int diff --git a/tests/cases/multiple_prefixes/project/package/module.py b/tests/cases/multiple_prefixes/project/package/module.py index ad2933e..c0fb58d 100644 --- a/tests/cases/multiple_prefixes/project/package/module.py +++ b/tests/cases/multiple_prefixes/project/package/module.py @@ -3,5 +3,5 @@ class Settings(BaseSettings): model_config = SettingsConfigDict(env_prefix="first_prefix_") - other_config = SettingsConfigDict(env_prefix="second_prefix_") + model_config = SettingsConfigDict(env_prefix="second_prefix_") field: int diff --git a/tests/cases/prefix/.env.example.expected b/tests/cases/prefix/.env.example.expected index acaf632..1b20255 100644 --- a/tests/cases/prefix/.env.example.expected +++ b/tests/cases/prefix/.env.example.expected @@ -1,2 +1,5 @@ +# DictSettings +DICT_PREFIX__OTHER_FIELD= + # Settings MY_PREFIX__FIELD= diff --git a/tests/cases/prefix/project/package/module.py b/tests/cases/prefix/project/package/module.py index f7b7f5d..b957a4f 100644 --- a/tests/cases/prefix/project/package/module.py +++ b/tests/cases/prefix/project/package/module.py @@ -4,3 +4,8 @@ class Settings(BaseSettings): model_config = SettingsConfigDict(env_prefix="my_prefix__") field: int + + +class DictSettings(BaseSettings): + model_config = {"env_prefix": "dict_prefix__"} + other_field: str diff --git a/tests/cases/qualified_module_import/.env.example.expected b/tests/cases/qualified_module_import/.env.example.expected new file mode 100644 index 0000000..f22f355 --- /dev/null +++ b/tests/cases/qualified_module_import/.env.example.expected @@ -0,0 +1,6 @@ +# ChildSettings +CHILD_FIELD= +PARENT_FIELD= + +# ParentSettings +PARENT_FIELD= diff --git a/tests/cases/qualified_module_import/project/package/__init__.py b/tests/cases/qualified_module_import/project/package/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/cases/qualified_module_import/project/package/consumer.py b/tests/cases/qualified_module_import/project/package/consumer.py new file mode 100644 index 0000000..bd951ba --- /dev/null +++ b/tests/cases/qualified_module_import/project/package/consumer.py @@ -0,0 +1,5 @@ +import package.subpackage.base + + +class ChildSettings(package.subpackage.base.ParentSettings): + child_field: str diff --git a/tests/cases/qualified_module_import/project/package/subpackage/__init__.py b/tests/cases/qualified_module_import/project/package/subpackage/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/cases/qualified_module_import/project/package/subpackage/base.py b/tests/cases/qualified_module_import/project/package/subpackage/base.py new file mode 100644 index 0000000..a5ba1bc --- /dev/null +++ b/tests/cases/qualified_module_import/project/package/subpackage/base.py @@ -0,0 +1,5 @@ +from pydantic_settings import BaseSettings + + +class ParentSettings(BaseSettings): + parent_field: str diff --git a/tests/cases/reexport_inheritance/.env.example.expected b/tests/cases/reexport_inheritance/.env.example.expected index 1bd040f..f22f355 100644 --- a/tests/cases/reexport_inheritance/.env.example.expected +++ b/tests/cases/reexport_inheritance/.env.example.expected @@ -1,5 +1,6 @@ -# ParentSettings -PARENT_FIELD= - # ChildSettings CHILD_FIELD= +PARENT_FIELD= + +# ParentSettings +PARENT_FIELD= diff --git a/tests/cases/relative_import/.env.example.expected b/tests/cases/relative_import/.env.example.expected new file mode 100644 index 0000000..0fb8839 --- /dev/null +++ b/tests/cases/relative_import/.env.example.expected @@ -0,0 +1,11 @@ +# ChildSettings +CHILD_FIELD= +MIDDLE_FIELD= +PARENT_FIELD= + +# MiddleSettings +MIDDLE_FIELD= +PARENT_FIELD= + +# ParentSettings +PARENT_FIELD= diff --git a/tests/cases/relative_import/project/package/__init__.py b/tests/cases/relative_import/project/package/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/cases/relative_import/project/package/base.py b/tests/cases/relative_import/project/package/base.py new file mode 100644 index 0000000..a5ba1bc --- /dev/null +++ b/tests/cases/relative_import/project/package/base.py @@ -0,0 +1,5 @@ +from pydantic_settings import BaseSettings + + +class ParentSettings(BaseSettings): + parent_field: str diff --git a/tests/cases/relative_import/project/package/subpackage/__init__.py b/tests/cases/relative_import/project/package/subpackage/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/cases/relative_import/project/package/subpackage/consumer.py b/tests/cases/relative_import/project/package/subpackage/consumer.py new file mode 100644 index 0000000..8e2473c --- /dev/null +++ b/tests/cases/relative_import/project/package/subpackage/consumer.py @@ -0,0 +1,5 @@ +from .middle import MiddleSettings + + +class ChildSettings(MiddleSettings): + child_field: str diff --git a/tests/cases/relative_import/project/package/subpackage/middle.py b/tests/cases/relative_import/project/package/subpackage/middle.py new file mode 100644 index 0000000..7277930 --- /dev/null +++ b/tests/cases/relative_import/project/package/subpackage/middle.py @@ -0,0 +1,5 @@ +from ..base import ParentSettings + + +class MiddleSettings(ParentSettings): + middle_field: str diff --git a/tests/cases/transitive_inheritance/.env.example.expected b/tests/cases/transitive_inheritance/.env.example.expected index 4e0bc99..7b4e0db 100644 --- a/tests/cases/transitive_inheritance/.env.example.expected +++ b/tests/cases/transitive_inheritance/.env.example.expected @@ -1,4 +1,5 @@ # ChildSettings +FIELD= OTHER_FIELD= # ParentSettings diff --git a/tests/cases/two_level_transitive_inheritance/.env.example.expected b/tests/cases/two_level_transitive_inheritance/.env.example.expected index 1127475..7f0c731 100644 --- a/tests/cases/two_level_transitive_inheritance/.env.example.expected +++ b/tests/cases/two_level_transitive_inheritance/.env.example.expected @@ -1,8 +1,11 @@ # ChildSettings CHILD_FIELD= +GRANDPARENT_FIELD= +PARENT_FIELD= # GrandparentSettings GRANDPARENT_FIELD= # ParentSettings +GRANDPARENT_FIELD= PARENT_FIELD= diff --git a/tests/test_run.py b/tests/test_run.py index fbec5b2..05e43dd 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -26,6 +26,9 @@ Case(name="transitive_inheritance", exclude_dirs=None), Case(name="two_level_transitive_inheritance", exclude_dirs=None), Case(name="reexport_inheritance", exclude_dirs=None), + Case(name="qualified_module_import", exclude_dirs=None), + Case(name="relative_import", exclude_dirs=None), + Case(name="main_file", exclude_dirs=None), ]