Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 57 additions & 11 deletions torchx/specs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import warnings
from dataclasses import asdict, dataclass, field
from datetime import datetime
from enum import Enum
from enum import Enum, IntEnum
from json import JSONDecodeError
from string import Template
from typing import (
Expand Down Expand Up @@ -892,6 +892,30 @@ class runopt:
Represents the metadata about the specific run option
"""

class AutoAlias(IntEnum):
snake_case = 0x1
SNAKE_CASE = 0x2
camelCase = 0x4

@staticmethod
def convert_to_camel_case(alias: str) -> str:
words = re.split(r"[_\-\s]+|(?<=[a-z])(?=[A-Z])", alias)
words = [w for w in words if w] # Remove empty strings
if not words:
return ""
return words[0].lower() + "".join(w.capitalize() for w in words[1:])

@staticmethod
def convert_to_snake_case(alias: str) -> str:
alias = re.sub(r"[-\s]+", "_", alias)
alias = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", alias)
alias = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", alias)
return alias.lower()

@staticmethod
def convert_to_const_case(alias: str) -> str:
return runopt.AutoAlias.convert_to_snake_case(alias).upper()

class alias(str):
pass

Expand All @@ -902,8 +926,8 @@ class deprecated(str):
opt_type: Type[CfgVal]
is_required: bool
help: str
aliases: list[alias] | None = None
deprecated_aliases: list[deprecated] | None = None
aliases: set[alias] | None = None
deprecated_aliases: set[deprecated] | None = None

@property
def is_type_list_of_str(self) -> bool:
Expand Down Expand Up @@ -1189,15 +1213,28 @@ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]:
cfg[key] = val
return cfg

def _generate_aliases(
self, auto_alias: int, aliases: set[str]
) -> set[runopt.alias]:
generated_aliases = set()
for alias in aliases:
if auto_alias & runopt.AutoAlias.camelCase:
generated_aliases.add(runopt.AutoAlias.convert_to_camel_case(alias))
if auto_alias & runopt.AutoAlias.snake_case:
generated_aliases.add(runopt.AutoAlias.convert_to_snake_case(alias))
if auto_alias & runopt.AutoAlias.SNAKE_CASE:
generated_aliases.add(runopt.AutoAlias.convert_to_const_case(alias))
return generated_aliases

def _get_primary_key_and_aliases(
self,
cfg_key: list[str] | str,
) -> tuple[str, list[runopt.alias], list[runopt.deprecated]]:
cfg_key: list[str | int] | str,
) -> tuple[str, set[runopt.alias], set[runopt.deprecated]]:
"""
Returns the primary key and aliases for the given cfg_key.
"""
if isinstance(cfg_key, str):
return cfg_key, [], []
return cfg_key, set(), set()

if len(cfg_key) == 0:
raise ValueError("cfg_key must be a non-empty list")
Expand All @@ -1211,13 +1248,16 @@ def _get_primary_key_and_aliases(
stacklevel=2,
)
primary_key = None
aliases = list[runopt.alias]()
deprecated_aliases = list[runopt.deprecated]()
auto_alias = 0x0
aliases = set[runopt.alias]()
deprecated_aliases = set[runopt.deprecated]()
for name in cfg_key:
if isinstance(name, runopt.alias):
aliases.append(name)
aliases.add(name)
elif isinstance(name, runopt.deprecated):
deprecated_aliases.append(name)
deprecated_aliases.add(name)
elif isinstance(name, int):
auto_alias = auto_alias | name
else:
if primary_key is not None:
raise ValueError(
Expand All @@ -1228,11 +1268,17 @@ def _get_primary_key_and_aliases(
raise ValueError(
"Missing cfg_key. Please provide one other than the aliases."
)
if auto_alias != 0x0:
aliases_to_generate_for = aliases | {primary_key}
additional_aliases = self._generate_aliases(
auto_alias, aliases_to_generate_for
)
aliases.update(additional_aliases)
return primary_key, aliases, deprecated_aliases

def add(
self,
cfg_key: str | list[str],
cfg_key: str | list[str | int],
type_: Type[CfgVal],
help: str,
default: CfgVal = None,
Expand Down
22 changes: 22 additions & 0 deletions torchx/specs/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,28 @@ def test_runopts_add_with_deprecated_aliases(self) -> None:
"Run option `jobPriority` is deprecated, use `job_priority` instead",
)

def test_runopt_auto_aliases(self) -> None:
opts = runopts()
opts.add(
["job_priority", runopt.AutoAlias.camelCase],
type_=str,
help="run as user",
)
opts.add(
[
"model_type_name",
runopt.AutoAlias.camelCase | runopt.AutoAlias.SNAKE_CASE,
],
type_=str,
help="run as user",
)
self.assertEqual(2, len(opts._opts))
self.assertIsNotNone(opts.get("job_priority"))
self.assertIsNotNone(opts.get("jobPriority"))
self.assertIsNotNone(opts.get("model_type_name"))
self.assertIsNotNone(opts.get("modelTypeName"))
self.assertIsNotNone(opts.get("MODEL_TYPE_NAME"))

def get_runopts(self) -> runopts:
opts = runopts()
opts.add("run_as", type_=str, help="run as user", required=True)
Expand Down
Loading