diff --git a/torchx/runner/config.py b/torchx/runner/config.py index 788acb276..0366d691e 100644 --- a/torchx/runner/config.py +++ b/torchx/runner/config.py @@ -278,14 +278,14 @@ def dump( continue # serialize list elements with `;` delimiter (consistent with torchx cli) - if opt.opt_type == List[str]: + if opt.is_type_list_of_str: # deal with empty or None default lists if opt.default: # pyre-ignore[6] opt.default type checked already as List[str] val = ";".join(opt.default) else: val = _NONE - elif opt.opt_type == Dict[str, str]: + elif opt.is_type_dict_of_str: # deal with empty or None default lists if opt.default: # pyre-ignore[16] opt.default type checked already as Dict[str, str] @@ -536,26 +536,26 @@ def load(scheduler: str, f: TextIO, cfg: Dict[str, CfgVal]) -> None: # this also handles empty or None lists cfg[name] = None else: - runopt = runopts.get(name) + opt = runopts.get(name) - if runopt is None: + if opt is None: log.warning( f"`{name} = {value}` was declared in the [{section}] section " f" of the config file but is not a runopt of `{scheduler}` scheduler." f" Remove the entry from the config file to no longer see this warning" ) else: - if runopt.opt_type is bool: + if opt.opt_type is bool: # need to handle bool specially since str -> bool is based on # str emptiness not value (e.g. bool("False") == True) cfg[name] = config.getboolean(section, name) - elif runopt.opt_type is List[str]: + elif opt.is_type_list_of_str: cfg[name] = value.split(";") - elif runopt.opt_type is Dict[str, str]: + elif opt.is_type_dict_of_str: cfg[name] = { s.split(":", 1)[0]: s.split(":", 1)[1] for s in value.replace(",", ";").split(";") } else: # pyre-ignore[29] - cfg[name] = runopt.opt_type(value) + cfg[name] = opt.opt_type(value) diff --git a/torchx/runner/test/config_test.py b/torchx/runner/test/config_test.py index c573bd3bb..901018c9a 100644 --- a/torchx/runner/test/config_test.py +++ b/torchx/runner/test/config_test.py @@ -95,22 +95,34 @@ def _run_opts(self) -> runopts: ) opts.add( "l", - type_=List[str], + type_=list[str], default=["a", "b", "c"], help="a list option", ) opts.add( - "l_none", + "l_typing", type_=List[str], + default=["a", "b", "c"], + help="a typing.List option", + ) + opts.add( + "l_none", + type_=list[str], default=None, help="a None list option", ) opts.add( "d", - type_=Dict[str, str], + type_=dict[str, str], default={"foo": "bar"}, help="a dict option", ) + opts.add( + "d_typing", + type_=Dict[str, str], + default={"foo": "bar"}, + help="a typing.Dict option", + ) opts.add( "d_none", type_=Dict[str, str], @@ -151,6 +163,10 @@ def _run_opts(self) -> runopts: [test] s = my_default i = 100 +l = abc;def +l_typing = ghi;jkl +d = a:b,c:d +d_typing = e:f,g:h """ _MY_CONFIG2 = """# @@ -387,6 +403,10 @@ def test_apply_dirs(self, _) -> None: self.assertEqual("runtime_value", cfg.get("s")) self.assertEqual(100, cfg.get("i")) self.assertEqual(1.2, cfg.get("f")) + self.assertEqual({"a": "b", "c": "d"}, cfg.get("d")) + self.assertEqual({"e": "f", "g": "h"}, cfg.get("d_typing")) + self.assertEqual(["abc", "def"], cfg.get("l")) + self.assertEqual(["ghi", "jkl"], cfg.get("l_typing")) def test_dump_invalid_scheduler(self) -> None: with self.assertRaises(ValueError): @@ -460,7 +480,7 @@ def test_dump_and_load_all_runopt_types(self, _) -> None: # all runopts in the TestScheduler have defaults, just check against those for opt_name, opt in TestScheduler("test").run_opts(): - self.assertEqual(cfg.get(opt_name), opt.default) + self.assertEqual(opt.default, cfg.get(opt_name)) def test_dump_and_load_all_registered_schedulers(self) -> None: # dump all the runopts for all registered schedulers diff --git a/torchx/specs/api.py b/torchx/specs/api.py index c28212bc5..e3e954a5b 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -789,6 +789,60 @@ class runopt: is_required: bool help: str + @property + def is_type_list_of_str(self) -> bool: + """ + Checks if the option type is a list of strings. + + Returns: + bool: True if the option type is either List[str] or list[str], False otherwise. + """ + return self.opt_type in (List[str], list[str]) + + @property + def is_type_dict_of_str(self) -> bool: + """ + Checks if the option type is a dict of string keys to string values. + + Returns: + bool: True if the option type is either Dict[str, str] or dict[str, str], False otherwise. + """ + return self.opt_type in (Dict[str, str], dict[str, str]) + + def cast_to_type(self, value: str) -> CfgVal: + """Casts the given `value` (in its string representation) to the type of this run option. + Below are the cast rules for each option type and value literal: + + 1. opt_type=str, value="foo" -> "foo" + 1. opt_type=bool, value="True"/"False" -> True/False + 1. opt_type=int, value="1" -> 1 + 1. opt_type=float, value="1.1" -> 1.1 + 1. opt_type=list[str]/List[str], value="a,b,c" or value="a;b;c" -> ["a", "b", "c"] + 1. opt_type=dict[str,str]/Dict[str,str], + value="key1:val1,key2:val2" or value="key1:val1;key2:val2" -> {"key1": "val1", "key2": "val2"} + + NOTE: dict parsing uses ":" as the kv separator (rather than the standard "=") because "=" is used + at the top-level cfg to parse runopts (notice the plural) from the CLI. Originally torchx only supported + primitives and list[str] as CfgVal but dict[str,str] was added in https://github.com/pytorch/torchx/pull/855 + """ + + if self.opt_type is None: + raise ValueError("runopt's opt_type cannot be `None`") + elif self.opt_type == bool: + return value.lower() == "true" + elif self.opt_type in (List[str], list[str]): + # lists may be ; or , delimited + # also deal with trailing "," by removing empty strings + return [v for v in value.replace(";", ",").split(",") if v] + elif self.opt_type in (Dict[str, str], dict[str, str]): + return { + s.split(":", 1)[0]: s.split(":", 1)[1] + for s in value.replace(";", ",").split(",") + } + else: + assert self.opt_type in (str, int, float) + return self.opt_type(value) + class runopts: """ @@ -948,27 +1002,11 @@ def cfg_from_str(self, cfg_str: str) -> Dict[str, CfgVal]: """ - def _cast_to_type(value: str, opt_type: Type[CfgVal]) -> CfgVal: - if opt_type == bool: - return value.lower() == "true" - elif opt_type in (List[str], list[str]): - # lists may be ; or , delimited - # also deal with trailing "," by removing empty strings - return [v for v in value.replace(";", ",").split(",") if v] - elif opt_type in (Dict[str, str], dict[str, str]): - return { - s.split(":", 1)[0]: s.split(":", 1)[1] - for s in value.replace(";", ",").split(",") - } - else: - # pyre-ignore[19, 6] type won't be dict here as we handled it above - return opt_type(value) - cfg: Dict[str, CfgVal] = {} for key, val in to_dict(cfg_str).items(): - runopt_ = self.get(key) - if runopt_: - cfg[key] = _cast_to_type(val, runopt_.opt_type) + opt = self.get(key) + if opt: + cfg[key] = opt.cast_to_type(val) else: logger.warning( f"{YELLOW_BOLD}Unknown run option passed to scheduler: {key}={val}{RESET}" @@ -982,16 +1020,16 @@ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]: cfg: Dict[str, CfgVal] = {} cfg_dict = json.loads(json_repr) for key, val in cfg_dict.items(): - runopt_ = self.get(key) - if runopt_: + opt = self.get(key) + if opt: # Optional runopt cfg values default their value to None, # but use `_type` to specify their type when provided. # Make sure not to treat None's as lists/dictionaries if val is None: cfg[key] = val - elif runopt_.opt_type == List[str]: + elif opt.is_type_list_of_str: cfg[key] = [str(v) for v in val] - elif runopt_.opt_type == Dict[str, str]: + elif opt.is_type_dict_of_str: cfg[key] = {str(k): str(v) for k, v in val.items()} else: cfg[key] = val diff --git a/torchx/specs/test/api_test.py b/torchx/specs/test/api_test.py index e02898943..2490e89b2 100644 --- a/torchx/specs/test/api_test.py +++ b/torchx/specs/test/api_test.py @@ -38,6 +38,7 @@ RetryPolicy, Role, RoleStatus, + runopt, runopts, ) @@ -437,6 +438,16 @@ def test_valid_values(self) -> None: self.assertTrue(cfg.get("preemptible")) self.assertIsNone(cfg.get("unknown")) + def test_runopt_cast_to_type_typing_list(self) -> None: + opt = runopt(default="", opt_type=List[str], is_required=False, help="help") + self.assertEqual(["a", "b", "c"], opt.cast_to_type("a,b,c")) + self.assertEqual(["abc", "def", "ghi"], opt.cast_to_type("abc;def;ghi")) + + def test_runopt_cast_to_type_builtin_list(self) -> None: + opt = runopt(default="", opt_type=list[str], is_required=False, help="help") + self.assertEqual(["a", "b", "c"], opt.cast_to_type("a,b,c")) + self.assertEqual(["abc", "def", "ghi"], opt.cast_to_type("abc;def;ghi")) + def test_runopts_add(self) -> None: """ tests for various add option variations