Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions tests/unit/vertexai/model_garden/test_model_garden.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,6 +1355,44 @@ def test_list_deployable_models(self, list_publisher_models_mock):
"google/gemma-2-2b",
]

def test_list_models(self, list_publisher_models_mock):
"""Tests listing models."""
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)

mg_models = model_garden.list_models()
list_publisher_models_mock.assert_called_with(
types.ListPublisherModelsRequest(
parent="publishers/*",
list_all_versions=True,
filter="is_hf_wildcard(false)",
)
)

assert mg_models == [
"google/paligemma@001",
"google/paligemma@002",
"google/paligemma@003",
"google/paligemma@004",
]

hf_models = model_garden.list_models(list_hf_models=True)
list_publisher_models_mock.assert_called_with(
types.ListPublisherModelsRequest(
parent="publishers/*",
list_all_versions=True,
filter="is_hf_wildcard(true)",
)
)
assert hf_models == [
"google/gemma-2-2b",
"google/gemma-2-2b",
"google/gemma-2-2b",
"google/gemma-2-2b",
]

def test_batch_prediction_success(self, batch_prediction_mock):
aiplatform.init(
project=_TEST_PROJECT,
Expand Down
3 changes: 2 additions & 1 deletion vertexai/model_garden/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@
OpenModel = _model_garden.OpenModel
PartnerModel = _model_garden.PartnerModel
list_deployable_models = _model_garden.list_deployable_models
list_models = _model_garden.list_models

__all__ = ("OpenModel", "PartnerModel", "list_deployable_models")
__all__ = ("OpenModel", "PartnerModel", "list_deployable_models", "list_models")
45 changes: 45 additions & 0 deletions vertexai/model_garden/_model_garden.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def list_deployable_models(
`{publisher}/{model}@{version}` or Hugging Face model ID in the format
of `{organization}/{model}`.
"""

filter_str = _NATIVE_MODEL_FILTER
if list_hf_models:
filter_str = " AND ".join([_HF_WILDCARD_FILTER, _VERIFIED_DEPLOYMENT_FILTER])
Expand Down Expand Up @@ -93,6 +94,50 @@ def list_deployable_models(
return output


def list_models(
*, list_hf_models: bool = False, model_filter: Optional[str] = None
) -> List[str]:
"""Lists the models in Model Garden.

Args:
list_hf_models: Whether to list the Hugging Face models.
model_filter: Optional. A string to filter the models by.

Returns:
The names of the models in Model Garden in the format of
`{publisher}/{model}@{version}` or Hugging Face model ID in the format
of `{organization}/{model}`.
"""
filter_str = _NATIVE_MODEL_FILTER
if list_hf_models:
filter_str = _HF_WILDCARD_FILTER
if model_filter:
filter_str = (
f'{filter_str} AND (model_user_id=~"(?i).*{model_filter}.*" OR'
f' display_name=~"(?i).*{model_filter}.*")'
)

request = types.ListPublisherModelsRequest(
parent="publishers/*",
list_all_versions=True,
filter=filter_str,
)
client = initializer.global_config.create_client(
client_class=_ModelGardenClientWithOverride,
credentials=initializer.global_config.credentials,
location_override="us-central1",
)
response = client.list_publisher_models(request)
output = []
for page in response.pages:
for model in page.publisher_models:
output.append(
re.sub(r"publishers/(hf-|)|models/", "", model.name)
+ ("" if list_hf_models else ("@" + model.version_id))
)
return output


def _is_hugging_face_model(model_name: str) -> bool:
"""Returns whether the model is a Hugging Face model."""
return re.match(r"^(?P<publisher>[^/]+)/(?P<model>[^/@]+)$", model_name)
Expand Down
Loading