Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ select = [
"SIM", # flake8-simplify
"ARG", # flake8-unused-arguments
]
ignore = ["E501", "N805", "RUF005", "S603", "SIM105"]
ignore = ["E501", "RUF005", "S603", "SIM105", "UP013"]

[tool.pytest.ini_options]
log_format = "%(message)s"
Expand Down
7 changes: 3 additions & 4 deletions ssh_zone_handler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def __init__(self, config: ZoneHandlerConf) -> None:
self.config: ZoneHandlerConf = config
self.journal_user: Final[str] = config.system.journalctl_user
self.login_user: Final[str] = config.system.login_user
self.server: Final[str] = config.system.server_type
self.service_user: Final[str] = config.system.server_user
service_unit: Final[str] = config.system.systemd_unit

Expand Down Expand Up @@ -67,7 +66,6 @@ def generate(self) -> None:
all_rules += self.__log_rule()
all_rules += self._server_command_rules()

rule: str
for rule in all_rules:
print(rule)

Expand Down Expand Up @@ -95,7 +93,8 @@ def __zone_list(self, username: str) -> Sequence[str]:

@staticmethod
def __parse(
ssh_command: str, user_zones: Sequence[str]
ssh_command: str,
user_zones: Sequence[str],
) -> tuple[str | None, list[str]]:
args: list[str] = ssh_command.split()
command: str | None = None
Expand Down Expand Up @@ -145,7 +144,7 @@ def __logs(self, zones: list[str]) -> None:
logging.info("Outputting logs for the following zone(s): %s", zones_str)

result: CompletedProcess[str] = self._runner(command, failure)
log_lines: list[str] = result.stdout.split("\n")
log_lines = result.stdout.split("\n")

line: str
for line in self._filter_logs(log_lines, zones):
Expand Down
6 changes: 1 addition & 5 deletions ssh_zone_handler/bind.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ def __lookup(self, zone: str, failure: str) -> str | None:

result: CompletedProcess[str] = self._runner(command, failure)

line: str
matched: re.Match[str] | None
pattern = re.compile(r"^([^:]+): (.+)$")
for line in result.stdout.split("\n"):
matched = pattern.match(line)
Expand Down Expand Up @@ -75,15 +73,13 @@ def _dump(self, zone: str) -> None:

run_failure = f'Failed to dump content of zone "{zone}"'
result: CompletedProcess[str] = self._runner(command, run_failure)
zone_content: str = result.stdout.rstrip()
zone_content = result.stdout.rstrip()

print(zone_content)

@staticmethod
def _filter_logs(log_lines: list[str], zones: list[str]) -> Iterator[str]:
line: str
for line in log_lines:
zone: str
for zone in zones:
if (
f"zone {zone}/IN" in line
Expand Down
20 changes: 10 additions & 10 deletions ssh_zone_handler/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def ssh_keys(config_file: Path = CONFIG_FILE) -> None:
logging.debug(str(cfe))
sys.exit(1)

szh = SshZoneAuthorizedKeys(config)
szh.output()
szh_authorized_keys = SshZoneAuthorizedKeys(config)
szh_authorized_keys.output()


def sudoers(config_file: Path = CONFIG_FILE) -> None:
Expand All @@ -109,14 +109,14 @@ def sudoers(config_file: Path = CONFIG_FILE) -> None:
except ConfigFileError as cfe:
_error_out(str(cfe))

szh: BindSudoers | KnotSudoers
szh_sudoers: BindSudoers | KnotSudoers
if config.system.server_type == "bind":
szh = BindSudoers(config)
szh_sudoers = BindSudoers(config)
elif config.system.server_type == "knot":
szh = KnotSudoers(config)
szh_sudoers = KnotSudoers(config)
else:
_error_out("Unsupported server configured")
szh.generate()
szh_sudoers.generate()


def wrapper(config_file: Path = CONFIG_FILE) -> None:
Expand Down Expand Up @@ -145,15 +145,15 @@ def wrapper(config_file: Path = CONFIG_FILE) -> None:
except KeyError:
pass

szh: BindCommand | KnotCommand
szh_command: BindCommand | KnotCommand
if config.system.server_type == "bind":
szh = BindCommand(config)
szh_command = BindCommand(config)
elif config.system.server_type == "knot":
szh = KnotCommand(config)
szh_command = KnotCommand(config)
else:
_error_out("Unsupported server configured")

try:
szh.invoke(ssh_command, username)
szh_command.invoke(ssh_command, username)
except InvokeError as error:
_error_out(str(error))
4 changes: 1 addition & 3 deletions ssh_zone_handler/knot.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,14 @@ def _dump(self, zone: str) -> None:
run_failure = f'Failed to dump content of zone "{zone}"'

result: CompletedProcess[str] = self._runner(command, run_failure)
zone_content: str = result.stdout.rstrip()
zone_content = result.stdout.rstrip()
zone_content = self.__filter_dump(zone_content, zone)

print(zone_content)

@staticmethod
def _filter_logs(log_lines: list[str], zones: list[str]) -> Iterator[str]:
line: str
for line in log_lines:
zone: str
for zone in zones:
if f"[{zone}.]" in line:
yield line
Expand Down
8 changes: 6 additions & 2 deletions ssh_zone_handler/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Custom types"""

from typing import Annotated, Final, Literal
from typing import Annotated, Final, Literal, TypedDict

from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator
from typing_extensions import Self
Expand All @@ -12,10 +12,11 @@
Ptr4Zone = Annotated[str, Field(pattern=r"^[0-9/]+\.([0-9]+\.)+in-addr\.arpa$")]
Ptr6Zone = Annotated[str, Field(pattern=r"^([a-f0-9]\.)+ip6\.arpa$")]
Zone = FwdZone | Ptr4Zone | Ptr6Zone
ServiceDefault = TypedDict("ServiceDefault", {"unit": ServiceUnit, "user": SystemUser})

SSHKey = Annotated[str, Field(pattern=r"^(ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNT|ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzOD|ecdsa-sha2-nistp521 AAAAE2VjZHNhLXNoYTItbmlzdHA1Mj|[email protected] AAAAInNrLWVjZHNhLXNoYTItbmlzdHAyNTZAb3BlbnNzaC5jb2|ssh-ed25519 AAAAC3NzaC1lZDI1NTE5|[email protected] AAAAGnNrLXNzaC1lZDI1NTE5QG9wZW5zc2guY29t|ssh-rsa AAAAB3NzaC1yc2)[0-9A-Za-z+/]+[=]{0,3}(\s.*)?$")] # fmt: skip

SERVICE_DEFAULTS: Final[dict[str, dict[str, str]]] = {
SERVICE_DEFAULTS: Final[dict[str, ServiceDefault]] = {
"bind": {
"unit": "named.service",
"user": "bind",
Expand All @@ -39,6 +40,7 @@ class SystemConf(BaseModel, extra="forbid", frozen=True):
systemd_unit: ServiceUnit = Field(default="", validate_default=True)

@field_validator("server_user", mode="before")
@classmethod
def _default_user(cls, user: str, values: ValidationInfo) -> str:
if not user:
try:
Expand All @@ -48,6 +50,7 @@ def _default_user(cls, user: str, values: ValidationInfo) -> str:
return user

@field_validator("systemd_unit", mode="before")
@classmethod
def _default_unit(cls, systemd_unit: str, values: ValidationInfo) -> str:
if not systemd_unit:
try:
Expand All @@ -66,6 +69,7 @@ class UserConf(BaseModel, extra="forbid", frozen=True):
zones: list[Zone]

@field_validator("ssh_keys", mode="after")
@classmethod
def _clean_ssh_keys(cls, ssh_keys: list[SSHKey]) -> list[SSHKey]:
cleaned_keys: list[SSHKey] = []
for ssh_key in ssh_keys:
Expand Down