Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/extension/linting.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@
# Area
'Europe',
'WARP',
'Cloudflare'
'Cloudflare',
'CloudRift'
}

# Add multi-word terms that should be treated as a single entity
Expand Down
17 changes: 17 additions & 0 deletions docs/source/getting-started/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ Install SkyPilot using pip:
# Seeweb is only supported for Python >= 3.10
pip install "skypilot[seeweb]"
pip install "skypilot[primeintellect]"
pip install "skypilot[cloudrift]"


pip install "skypilot[all]"

Expand Down Expand Up @@ -79,6 +81,8 @@ Install SkyPilot using pip:
# Seeweb is only supported for Python >= 3.10
pip install "skypilot[seeweb]"
pip install "skypilot-nightly[primeintellect]"
pip install "skypilot-nightly[cloudrift]"


pip install "skypilot-nightly[all]"

Expand Down Expand Up @@ -117,6 +121,7 @@ Install SkyPilot using pip:
# Seeweb is only supported for Python >= 3.10
pip install -e ".[seeweb]"
pip install -e ".[primeintellect]"
pip install -e ".[cloudrift]"

pip install -e ".[all]"

Expand Down Expand Up @@ -246,6 +251,7 @@ This will produce a summary like:
Seeweb: enabled
vSphere: enabled
Cloudflare (for R2 object store): enabled
CloudRift: enabled
Kubernetes: enabled

If any cloud's credentials or dependencies are missing, ``sky check`` will
Expand Down Expand Up @@ -495,6 +501,17 @@ Vast |community-badge|
echo "<your_api_key_here>" > ~/.config/vastai/vast_api_key


CloudRift |community-badge|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

`CloudRift <https://www.cloudrift.ai/>`__ is a cloud provider that provides access to elastic GPU-accelerated VMs. To configure CloudRift access, go to the `Account <https://console.cloudrift.ai/keys/>`_ page on your CloudRift console to get your **API key**. Then, run:

.. code-block:: shell

mkdir -p ~/.config/cloudrift
echo "<your_api_key_here>" > ~/.config/cloudrift/cloudrift_api_key



Fluidstack |community-badge|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
2 changes: 2 additions & 0 deletions sky/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]):
IBM = clouds.IBM
AWS = clouds.AWS
Azure = clouds.Azure
CloudRift = clouds.CloudRift
Cudo = clouds.Cudo
GCP = clouds.GCP
Lambda = clouds.Lambda
Expand All @@ -157,6 +158,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]):
'__version__',
'AWS',
'Azure',
'CloudRift',
'Cudo',
'GCP',
'IBM',
Expand Down
18 changes: 18 additions & 0 deletions sky/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from sky.provision.kubernetes import utils as kubernetes_utils
from sky.provision.lambda_cloud import lambda_utils
from sky.provision.primeintellect import utils as primeintellect_utils
from sky.provision.cloudrift import utils as cloudrift_utils
from sky.utils import auth_utils
from sky.utils import common_utils
from sky.utils import subprocess_utils
Expand Down Expand Up @@ -480,3 +481,20 @@ def setup_seeweb_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
config['auth']['remote_key_name'] = remote_name

return config


# -------------------------------- CloudRift --------------------------------- #
def setup_cloudrift_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
    """Sets up SSH authentication for CloudRift.

    - Obtains the local SSH key pair via ``get_or_generate_keys()``.
    - Records the SSH login user and the local public key path under
      ``config['auth']``.
    - Hands the config to ``configure_ssh_info`` and returns its result.

    Args:
        config: Cluster config dict; must already contain an ``'auth'``
            mapping, which is updated in place.

    Returns:
        The config returned by ``configure_ssh_info``.
    """
    _, public_key_path = get_or_generate_keys()

    # NOTE(review): this function does not register the public key with the
    # CloudRift account through the API client; confirm that the key reaches
    # the instance through another path (e.g. the provisioning request).

    # Login user for CloudRift instances.
    config['auth']['ssh_user'] = 'riftuser'
    config['auth']['ssh_public_key'] = public_key_path
    return configure_ssh_info(config)
2 changes: 2 additions & 0 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,6 +1131,8 @@ def _add_auth_to_cluster_config(cloud: clouds.Cloud, tmp_yaml_path: str):
config = auth.setup_primeintellect_authentication(config)
elif isinstance(cloud, clouds.Seeweb):
config = auth.setup_seeweb_authentication(config)
elif isinstance(cloud, clouds.CloudRift):
config = auth.setup_cloudrift_authentication(config)
else:
assert False, cloud
yaml_utils.dump_yaml(tmp_yaml_path, config)
Expand Down
1 change: 1 addition & 0 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ def _get_cluster_config_template(cloud):
cloud_to_template = {
clouds.AWS: 'aws-ray.yml.j2',
clouds.Azure: 'azure-ray.yml.j2',
clouds.CloudRift: 'cloudrift-ray.yml.j2',
clouds.Cudo: 'cudo-ray.yml.j2',
clouds.GCP: 'gcp-ray.yml.j2',
clouds.Lambda: 'lambda-ray.yml.j2',
Expand Down
118 changes: 118 additions & 0 deletions sky/catalog/cloudrift_catalog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""CloudRift Catalog

This module loads the service catalog file and can be used to
query instance types and pricing information for CloudRift.
"""

import typing
from typing import Dict, List, Optional, Tuple, Union

from sky.catalog import common
from sky.utils import ux_utils

if typing.TYPE_CHECKING:
from sky.clouds import cloud

# Column layout of the CloudRift catalog CSV. It mirrors the header row
# written by sky/catalog/data_fetchers/fetch_cloudrift.py so that the empty
# fallback frame below exposes the same columns as a real catalog.
_CATALOG_COLUMNS = [
    'InstanceType',
    'AcceleratorName',
    'AcceleratorCount',
    'vCPUs',
    'MemoryGiB',
    'GpuInfo',
    'Region',
    'SpotPrice',
    'Price',
    'AvailabilityZone',
]

# Load the catalog dataframe. If the CSV has not been published yet (or
# cannot be read), fall back to an empty frame so importing this module
# still succeeds.
try:
    _df = common.read_catalog('cloudrift/vms.csv')
except (FileNotFoundError, ValueError):
    import pandas as pd
    _df = pd.DataFrame(columns=_CATALOG_COLUMNS)


def instance_type_exists(instance_type: str) -> bool:
    """Check ``instance_type`` against the CloudRift catalog.

    Delegates to ``common.instance_type_exists_impl`` with the module-level
    catalog dataframe.
    """
    is_known = common.instance_type_exists_impl(_df, instance_type)
    return is_known


def validate_region_zone(
        region: Optional[str],
        zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
    """Validate a (region, zone) pair against the CloudRift catalog.

    Raises:
        ValueError: if ``zone`` is given; this module rejects zones.

    Returns:
        The result of ``common.validate_region_zone_impl`` for this catalog.
    """
    zone_requested = zone is not None
    if zone_requested:
        # Raise with the traceback suppressed for a cleaner user-facing error.
        with ux_utils.print_exception_no_traceback():
            raise ValueError('CloudRift does not support zones.')
    validated = common.validate_region_zone_impl('cloudrift', _df, region,
                                                 zone)
    return validated


def get_hourly_cost(instance_type: str,
                    use_spot: bool = False,
                    region: Optional[str] = None,
                    zone: Optional[str] = None) -> float:
    """Look up the hourly price of ``instance_type`` in the catalog.

    The lookup is delegated to ``common.get_hourly_cost_impl``.

    Raises:
        ValueError: if ``zone`` is given; this module rejects zones.
    """
    zone_requested = zone is not None
    if zone_requested:
        with ux_utils.print_exception_no_traceback():
            raise ValueError('CloudRift does not support zones.')
    hourly_cost = common.get_hourly_cost_impl(_df, instance_type, use_spot,
                                              region, zone)
    return hourly_cost


def get_vcpus_mem_from_instance_type(
        instance_type: str) -> Tuple[Optional[float], Optional[float]]:
    """Return the (vCPUs, memory) pair recorded for ``instance_type``.

    Delegates to ``common.get_vcpus_mem_from_instance_type_impl``; per the
    return annotation, either element may be ``None``.
    """
    vcpus_and_mem = common.get_vcpus_mem_from_instance_type_impl(
        _df, instance_type)
    return vcpus_and_mem


def get_default_instance_type(cpus: Optional[str] = None,
                              memory: Optional[str] = None,
                              disk_tier: Optional[str] = None,
                              region: Optional[str] = None,
                              zone: Optional[str] = None) -> Optional[str]:
    """Pick the default CloudRift instance type for a request.

    When the caller places no constraint on CPUs, memory, region or zone,
    the single-A100 instance type is preferred if the catalog lists it.
    Otherwise the choice is delegated to
    ``common.get_instance_type_for_cpus_mem_impl`` so that the requested
    constraints are honoured rather than bypassed by the hard-coded default.

    Args:
        cpus: Requested vCPU constraint, or None.
        memory: Requested memory constraint, or None.
        disk_tier: Ignored by this catalog.
        region: Region to restrict the search to, or None.
        zone: Zone to restrict the search to, or None.

    Returns:
        An instance type name, or None if the catalog helper finds no match.
    """
    del disk_tier  # Unused.
    preferred_default = 'gpu-a100x1-80gb'
    unconstrained = all(
        constraint is None for constraint in (cpus, memory, region, zone))
    # Only take the shortcut when nothing was requested: the preferred type
    # is not checked against any cpus/memory/region/zone constraint.
    if unconstrained and instance_type_exists(preferred_default):
        return preferred_default
    return common.get_instance_type_for_cpus_mem_impl(_df, cpus, memory,
                                                      region, zone)


def get_accelerators_from_instance_type(
        instance_type: str) -> Optional[Dict[str, Union[int, float]]]:
    """Return the accelerators attached to ``instance_type``.

    Delegates to ``common.get_accelerators_from_instance_type_impl``; per
    the return annotation the result maps accelerator name to count, or is
    ``None``.
    """
    accelerators = common.get_accelerators_from_instance_type_impl(
        _df, instance_type)
    return accelerators


def get_instance_type_for_accelerator(
        acc_name: str,
        acc_count: int,
        cpus: Optional[str] = None,
        memory: Optional[str] = None,
        use_spot: bool = False,
        region: Optional[str] = None,
        zone: Optional[str] = None) -> Tuple[Optional[List[str]], List[str]]:
    """Find instance types offering the requested accelerator.

    The search is delegated to
    ``common.get_instance_type_for_accelerator_impl`` over the CloudRift
    catalog.

    Raises:
        ValueError: if ``zone`` is given; this module rejects zones.
    """
    zone_requested = zone is not None
    if zone_requested:
        with ux_utils.print_exception_no_traceback():
            raise ValueError('CloudRift does not support zones.')
    # Collect the keyword arguments once, then forward them unchanged.
    lookup_kwargs = dict(
        df=_df,
        acc_name=acc_name,
        acc_count=acc_count,
        cpus=cpus,
        memory=memory,
        use_spot=use_spot,
        region=region,
        zone=zone,
    )
    return common.get_instance_type_for_accelerator_impl(**lookup_kwargs)


def get_region_zones_for_instance_type(instance_type: str,
                                       use_spot: bool) -> List['cloud.Region']:
    """Return the regions in which ``instance_type`` appears in the catalog.

    Filters the catalog down to rows for ``instance_type`` and passes them
    to ``common.get_region_zones``.
    """
    is_requested_type = _df['InstanceType'] == instance_type
    matching_rows = _df[is_requested_type]
    return common.get_region_zones(matching_rows, use_spot)


def list_accelerators(
        gpus_only: bool,
        name_filter: Optional[str],
        region_filter: Optional[str],
        quantity_filter: Optional[int],
        case_sensitive: bool = True,
        all_regions: bool = False,
        require_price: bool = True) -> Dict[str, List[common.InstanceTypeInfo]]:
    """List CloudRift accelerator offerings from the catalog.

    All filters are forwarded to ``common.list_accelerators_impl``;
    ``require_price`` is accepted for interface compatibility but not used.
    """
    del require_price  # Unused.
    impl_args = (
        'CloudRift',
        _df,
        gpus_only,
        name_filter,
        region_filter,
        quantity_filter,
        case_sensitive,
        all_regions,
    )
    return common.list_accelerators_impl(*impl_args)
136 changes: 136 additions & 0 deletions sky/catalog/data_fetchers/fetch_cloudrift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""A script that fetches CloudRift instance types and generates a CSV catalog.

Usage:
python fetch_cloudrift.py
"""

import csv
import json
import os
import sys
from typing import Dict, List

# Add the parent directory to the path so we can import sky modules
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..')))

from sky.provision.cloudrift.utils import get_cloudrift_client

# Number of bytes in one GiB (2**30); used to convert byte counts to GiB.
BYTES_TO_GIB = 1024**3


def extract_region_from_dc(dc_name: str, providers_data: List[Dict]) -> str:
    """Derive a region identifier for a datacenter name.

    The datacenter is first looked up in ``providers_data``; if a matching
    entry carries a non-empty ``country_code``, that code is the region.
    Otherwise the region is derived from the name itself by dropping a
    trailing numeric ``-<n>`` suffix.

    Args:
        dc_name: The datacenter name, e.g. 'us-east-nc-nr-1'.
        providers_data: Provider dictionaries; each may hold a
            ``'datacenters'`` list of dicts with ``'name'`` and
            ``'country_code'`` keys.

    Returns:
        The country code when one is available for the datacenter; otherwise
        the name without its numeric suffix (e.g. 'us-east-nc-nr-1' ->
        'us-east-nc-nr'), or the name unchanged if it has no such suffix.
    """
    for provider in providers_data or []:
        # `or []` tolerates an explicit None value as well as a missing key.
        for datacenter in provider.get('datacenters') or []:
            if datacenter.get('name') != dc_name:
                continue
            country_code = datacenter.get('country_code')
            if country_code:
                return country_code
            # Matching entry without a usable country code: keep looking and
            # eventually fall back to the name-based rule instead of
            # returning an empty region.

    # Name-based fallback: strip a trailing '-<digits>' component. Requiring
    # the separator keeps a purely numeric name from collapsing to ''.
    prefix, separator, suffix = dc_name.rpartition('-')
    if separator and suffix.isdigit():
        return prefix
    return dc_name


def create_catalog(output_dir: str) -> None:
    """Query the CloudRift API and write the catalog to ``vms.csv``.

    One row is written per (GPU instance variant, datacenter) pair. A GPU
    variant with no datacenters still gets a single row with empty region
    and availability zone. Variants without GPUs are skipped.

    Args:
        output_dir: Existing directory in which ``vms.csv`` is created
            (overwritten if present).
    """
    client = get_cloudrift_client()
    instance_types = client.get_instance_types()
    # Provider data is used to map datacenter names to regions.
    providers = client.get_providers()

    header = [
        'InstanceType',
        'AcceleratorName',
        'AcceleratorCount',
        'vCPUs',
        'MemoryGiB',
        'GpuInfo',
        'Region',
        'SpotPrice',
        'Price',
        'AvailabilityZone',
    ]

    catalog_path = os.path.join(output_dir, 'vms.csv')
    # newline='' leaves line endings to the csv module, as its documentation
    # requires; without it extra '\r' characters appear on Windows.
    with open(catalog_path, mode='w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f, delimiter=',', quotechar='"')
        writer.writerow(header)

        for instance_type in instance_types:
            # The accelerator name comes from the instance type's brand_short.
            accelerator_name = instance_type.get('brand_short', '')

            for variant in instance_type.get('variants') or []:
                gpu_count = variant.get('gpu_count') or 0
                # Skip instances without GPUs.
                if gpu_count == 0:
                    continue

                instance_name = variant.get('name')
                vcpus = variant.get('logical_cpu_count') or 0
                memory_gib = (variant.get('dram') or 0) / BYTES_TO_GIB
                # The API reports the price in cents; convert to dollars.
                price = (variant.get('cost_per_hour') or 0) / 100.0

                # (region, zone) pairs for this variant; the datacenter name
                # doubles as the availability zone. With no datacenters the
                # variant is still listed once with empty location fields.
                datacenters = variant.get('nodes_per_dc') or {}
                locations = [(extract_region_from_dc(dc_name, providers),
                              dc_name) for dc_name in datacenters]
                if not locations:
                    locations = [('', '')]

                for region, zone in locations:
                    writer.writerow([
                        instance_name,
                        accelerator_name,
                        gpu_count,
                        vcpus,
                        round(memory_gib, 1),
                        accelerator_name,
                        region,
                        # CloudRift doesn't have spot instances yet.
                        # NOTE(review): 0.0 is a placeholder; confirm catalog
                        # readers do not treat it as a real (free) spot price.
                        0.0,
                        price,
                        zone,
                    ])


if __name__ == '__main__':
    # Write the catalog into ./cloudrift relative to the working directory.
    _output_dir = 'cloudrift'
    os.makedirs(_output_dir, exist_ok=True)
    create_catalog(_output_dir)
    print('CloudRift catalog saved to cloudrift/vms.csv')
2 changes: 2 additions & 0 deletions sky/clouds/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# isort: split
from sky.clouds.aws import AWS
from sky.clouds.azure import Azure
from sky.clouds.cloudrift import CloudRift
from sky.clouds.cudo import Cudo
from sky.clouds.do import DO
from sky.clouds.fluidstack import Fluidstack
Expand All @@ -40,6 +41,7 @@
'AWS',
'Azure',
'Cloud',
'CloudRift',
'Cudo',
'DummyCloud',
'GCP',
Expand Down
Loading
Loading