Skip to content
Open
172 changes: 168 additions & 4 deletions src/oumi/core/configs/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# limitations under the License.

import dataclasses
import inspect
import logging
import re
from collections.abc import Iterator
from enum import Enum
from io import StringIO
from pathlib import Path
from typing import Any, Optional, TypeVar, Union, cast
Expand All @@ -28,6 +30,92 @@

_CLI_IGNORED_PREFIXES = ["--local-rank"]

# Set of primitive types that OmegaConf can handle directly
_PRIMITIVE_TYPES = {str, int, float, bool, type(None), bytes, Path, Enum}


def _is_primitive_type(value: Any) -> bool:
"""Check if a value is of a primitive type that OmegaConf can handle."""
return (
type(value) in _PRIMITIVE_TYPES
or isinstance(value, Path)
or isinstance(value, Enum)
)


def _handle_non_primitives(config: Any, removed_paths: set, path: str = "") -> Any:
"""Recursively process config object to handle non-primitive values.

Args:
config: The config object to process
removed_paths: Set to track paths of removed non-primitive values
path: The current path in the config (for logging)

Returns:
The processed config with non-primitive values removed
"""
if isinstance(config, list):
return [
_handle_non_primitives(item, removed_paths, f"{path}[{i}]")
for i, item in enumerate(config)
]

if isinstance(config, dict):
result = {}
for key, value in config.items():
current_path = f"{path}.{key}" if path else key
if _is_primitive_type(value):
result[key] = value
else:
# Recursively process nested dictionaries and other non-primitive values
processed_value = _handle_non_primitives(
value, removed_paths, current_path
)
if processed_value is not None:
result[key] = processed_value
else:
removed_paths.add(current_path)
result[key] = None
return result

if _is_primitive_type(config):
return config

if hasattr(config, "__dataclass_fields__"):
result = {}
for field_name in config.__dataclass_fields__:
field_value = getattr(config, field_name)
current_path = f"{path}.{field_name}" if path else field_name
processed_value = _handle_non_primitives(
field_value, removed_paths, current_path
)
if processed_value is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO if we redact a value, we should keep the key and have the value be None. You did something similar up in line 79.

result[field_name] = processed_value
else:
removed_paths.add(current_path)
result[field_name] = None
return result

# Try to convert functions to their source code
if callable(config):
try:
if hasattr(config, "__name__") and config.__name__ == "<lambda>":
removed_paths.add(path)
return None

# Lambda functions and built-in functions can't have source extracted
source = inspect.getsource(config)
# Only return source if we successfully got it
return source
except (TypeError, OSError):
# Can't get source for lambdas, built-ins, or C extensions
removed_paths.add(path)
return None

# For any other type, remove it and track the path
removed_paths.add(path)
return None


def _filter_ignored_args(arg_list: list[str]) -> list[str]:
"""Filters out ignored CLI arguments."""
Expand All @@ -54,11 +142,37 @@ def _read_config_without_interpolation(config_path: str) -> str:
return stringified_config


@dataclasses.dataclass
@dataclasses.dataclass(eq=False)
class BaseConfig:
def to_yaml(self, config_path: Union[str, Path, StringIO]) -> None:
"""Saves the configuration to a YAML file."""
OmegaConf.save(config=self, f=config_path)
"""Saves the configuration to a YAML file.

Non-primitive values are removed and warnings are logged.

Args:
config_path: Path to save the config to
"""
# Convert dataclass fields to a dictionary first
config_dict = {}
for field_name, field_value in self:
config_dict[field_name] = field_value

# Process non-primitive values before creating OmegaConf structure
removed_paths = set()
processed_config = _handle_non_primitives(
config_dict, removed_paths=removed_paths
)

# Log warnings for removed values
if removed_paths:
logger = logging.getLogger(__name__)
logger.warning(
"The following non-primitive values were removed from the config "
"as they cannot be saved to YAML:\n"
+ "\n".join(f"- {path}" for path in sorted(removed_paths))
)

OmegaConf.save(config=processed_config, f=config_path)

@classmethod
def from_yaml(
Expand Down Expand Up @@ -182,7 +296,18 @@ def print_config(self, logger: Optional[logging.Logger] = None) -> None:
if logger is None:
logger = logging.getLogger(__name__)

config_yaml = OmegaConf.to_yaml(self, resolve=True)
# Convert dataclass fields to a dictionary first
config_dict = {}
for field_name, field_value in self:
config_dict[field_name] = field_value

# Process non-primitive values before creating OmegaConf structure
removed_paths = set()
processed_config = _handle_non_primitives(
config_dict, removed_paths=removed_paths
)

config_yaml = OmegaConf.to_yaml(processed_config, resolve=True)
logger.info(f"Configuration:\n{config_yaml}")

def finalize_and_validate(self) -> None:
Expand Down Expand Up @@ -211,3 +336,42 @@ def __iter__(self) -> Iterator[tuple[str, Any]]:
"""
for param in dataclasses.fields(self):
yield param.name, getattr(self, param.name)

def __eq__(self, other: object) -> bool:
"""Custom equality comparison that handles callable objects specially."""
if not isinstance(other, self.__class__):
return False

for field_name, field_value in self:
other_value = getattr(other, field_name)

# Special handling for callable objects
if callable(field_value) and callable(other_value):
if (
hasattr(field_value, "__name__")
and hasattr(other_value, "__name__")
and field_value.__name__ == "<lambda>"
and other_value.__name__ == "<lambda>"
):
# Consider all lambda functions equal for config comparison purposes
continue

# For regular functions, try to compare by source code
try:
field_source = inspect.getsource(field_value).strip()
other_source = inspect.getsource(other_value).strip()
if field_source != other_source:
return False
except (TypeError, OSError):
# If we can't get source, fall back to identity comparison
if field_value != other_value:
return False
elif callable(field_value) or callable(other_value):
# One is callable, the other is not
return False
else:
# Normal comparison for non-callable values
if field_value != other_value:
return False

return True
Loading
Loading