diff --git a/pyproject.toml b/pyproject.toml index de17765..8dce1a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ select = [ "SIM", # flake8-simplify "ARG", # flake8-unused-arguments ] -ignore = ["E501", "N805", "RUF005", "S603", "SIM105"] +ignore = ["E501", "RUF005", "S603", "SIM105", "UP013"] [tool.pytest.ini_options] log_format = "%(message)s" diff --git a/ssh_zone_handler/base.py b/ssh_zone_handler/base.py index f7fe3b0..079a6a1 100644 --- a/ssh_zone_handler/base.py +++ b/ssh_zone_handler/base.py @@ -21,7 +21,6 @@ def __init__(self, config: ZoneHandlerConf) -> None: self.config: ZoneHandlerConf = config self.journal_user: Final[str] = config.system.journalctl_user self.login_user: Final[str] = config.system.login_user - self.server: Final[str] = config.system.server_type self.service_user: Final[str] = config.system.server_user service_unit: Final[str] = config.system.systemd_unit @@ -67,7 +66,6 @@ def generate(self) -> None: all_rules += self.__log_rule() all_rules += self._server_command_rules() - rule: str for rule in all_rules: print(rule) @@ -95,7 +93,8 @@ def __zone_list(self, username: str) -> Sequence[str]: @staticmethod def __parse( - ssh_command: str, user_zones: Sequence[str] + ssh_command: str, + user_zones: Sequence[str], ) -> tuple[str | None, list[str]]: args: list[str] = ssh_command.split() command: str | None = None @@ -145,7 +144,7 @@ def __logs(self, zones: list[str]) -> None: logging.info("Outputting logs for the following zone(s): %s", zones_str) result: CompletedProcess[str] = self._runner(command, failure) - log_lines: list[str] = result.stdout.split("\n") + log_lines = result.stdout.split("\n") line: str for line in self._filter_logs(log_lines, zones): diff --git a/ssh_zone_handler/bind.py b/ssh_zone_handler/bind.py index 6ffc43f..c482808 100644 --- a/ssh_zone_handler/bind.py +++ b/ssh_zone_handler/bind.py @@ -44,8 +44,6 @@ def __lookup(self, zone: str, failure: str) -> str | None: result: CompletedProcess[str] = self._runner(command, failure) - line: str - matched: re.Match[str] | None pattern = re.compile(r"^([^:]+): (.+)$") for line in result.stdout.split("\n"): matched = pattern.match(line) @@ -75,15 +73,13 @@ def _dump(self, zone: str) -> None: run_failure = f'Failed to dump content of zone "{zone}"' result: CompletedProcess[str] = self._runner(command, run_failure) - zone_content: str = result.stdout.rstrip() + zone_content = result.stdout.rstrip() print(zone_content) @staticmethod def _filter_logs(log_lines: list[str], zones: list[str]) -> Iterator[str]: - line: str for line in log_lines: - zone: str for zone in zones: if ( f"zone {zone}/IN" in line diff --git a/ssh_zone_handler/cli.py b/ssh_zone_handler/cli.py index b642fa9..d4e71dd 100644 --- a/ssh_zone_handler/cli.py +++ b/ssh_zone_handler/cli.py @@ -91,8 +91,8 @@ def ssh_keys(config_file: Path = CONFIG_FILE) -> None: logging.debug(str(cfe)) sys.exit(1) - szh = SshZoneAuthorizedKeys(config) - szh.output() + szh_authorized_keys = SshZoneAuthorizedKeys(config) + szh_authorized_keys.output() def sudoers(config_file: Path = CONFIG_FILE) -> None: @@ -109,14 +109,14 @@ def sudoers(config_file: Path = CONFIG_FILE) -> None: except ConfigFileError as cfe: _error_out(str(cfe)) - szh: BindSudoers | KnotSudoers + szh_sudoers: BindSudoers | KnotSudoers if config.system.server_type == "bind": - szh = BindSudoers(config) + szh_sudoers = BindSudoers(config) elif config.system.server_type == "knot": - szh = KnotSudoers(config) + szh_sudoers = KnotSudoers(config) else: _error_out("Unsupported server configured") - szh.generate() + szh_sudoers.generate() def wrapper(config_file: Path = CONFIG_FILE) -> None: @@ -145,15 +145,15 @@ def wrapper(config_file: Path = CONFIG_FILE) -> None: except KeyError: pass - szh: BindCommand | KnotCommand + szh_command: BindCommand | KnotCommand if config.system.server_type == "bind": - szh = BindCommand(config) + szh_command = BindCommand(config) elif config.system.server_type == "knot": - szh = KnotCommand(config) + szh_command = KnotCommand(config) else: _error_out("Unsupported server configured") try: - szh.invoke(ssh_command, username) + szh_command.invoke(ssh_command, username) except InvokeError as error: _error_out(str(error)) diff --git a/ssh_zone_handler/knot.py b/ssh_zone_handler/knot.py index 80ffeb9..6e4184d 100644 --- a/ssh_zone_handler/knot.py +++ b/ssh_zone_handler/knot.py @@ -60,16 +60,14 @@ def _dump(self, zone: str) -> None: run_failure = f'Failed to dump content of zone "{zone}"' result: CompletedProcess[str] = self._runner(command, run_failure) - zone_content: str = result.stdout.rstrip() + zone_content = result.stdout.rstrip() zone_content = self.__filter_dump(zone_content, zone) print(zone_content) @staticmethod def _filter_logs(log_lines: list[str], zones: list[str]) -> Iterator[str]: - line: str for line in log_lines: - zone: str for zone in zones: if f"[{zone}.]" in line: yield line diff --git a/ssh_zone_handler/types.py b/ssh_zone_handler/types.py index 17b4342..de7e3b5 100644 --- a/ssh_zone_handler/types.py +++ b/ssh_zone_handler/types.py @@ -1,6 +1,6 @@ """Custom types""" -from typing import Annotated, Final, Literal +from typing import Annotated, Final, Literal, TypedDict from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator from typing_extensions import Self @@ -12,10 +12,11 @@ Ptr4Zone = Annotated[str, Field(pattern=r"^[0-9/]+\.([0-9]+\.)+in-addr\.arpa$")] Ptr6Zone = Annotated[str, Field(pattern=r"^([a-f0-9]\.)+ip6\.arpa$")] Zone = FwdZone | Ptr4Zone | Ptr6Zone +ServiceDefault = TypedDict("ServiceDefault", {"unit": ServiceUnit, "user": SystemUser}) SSHKey = Annotated[str, Field(pattern=r"^(ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNT|ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzOD|ecdsa-sha2-nistp521 AAAAE2VjZHNhLXNoYTItbmlzdHA1Mj|sk-ecdsa-sha2-nistp256@openssh.com AAAAInNrLWVjZHNhLXNoYTItbmlzdHAyNTZAb3BlbnNzaC5jb2|ssh-ed25519 AAAAC3NzaC1lZDI1NTE5|sk-ssh-ed25519@openssh.com AAAAGnNrLXNzaC1lZDI1NTE5QG9wZW5zc2guY29t|ssh-rsa AAAAB3NzaC1yc2)[0-9A-Za-z+/]+[=]{0,3}(\s.*)?$")] # fmt: skip -SERVICE_DEFAULTS: Final[dict[str, dict[str, str]]] = { +SERVICE_DEFAULTS: Final[dict[str, ServiceDefault]] = { "bind": { "unit": "named.service", "user": "bind", @@ -39,6 +40,7 @@ class SystemConf(BaseModel, extra="forbid", frozen=True): systemd_unit: ServiceUnit = Field(default="", validate_default=True) @field_validator("server_user", mode="before") + @classmethod def _default_user(cls, user: str, values: ValidationInfo) -> str: if not user: try: @@ -48,6 +50,7 @@ def _default_user(cls, user: str, values: ValidationInfo) -> str: return user @field_validator("systemd_unit", mode="before") + @classmethod def _default_unit(cls, systemd_unit: str, values: ValidationInfo) -> str: if not systemd_unit: try: @@ -66,6 +69,7 @@ class UserConf(BaseModel, extra="forbid", frozen=True): zones: list[Zone] @field_validator("ssh_keys", mode="after") + @classmethod def _clean_ssh_keys(cls, ssh_keys: list[SSHKey]) -> list[SSHKey]: cleaned_keys: list[SSHKey] = [] for ssh_key in ssh_keys: