"""Interactive session helper functions for CloudOS."""
import pandas as pd
import sys
import re
import json
import time
from datetime import datetime, timedelta, timezone
from rich.table import Table
from rich.console import Console
from rich.panel import Panel
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from cloudos_cli.utils.requests import retry_requests_get
# [docs]
def validate_instance_type(instance_type, execution_platform='aws'):
    """Validate instance type format for the given execution platform.

    Parameters
    ----------
    instance_type : str
        Instance type to validate
    execution_platform : str
        'aws' or 'azure' (any other value skips validation)

    Returns
    -------
    tuple
        (is_valid: bool, error_message: str or None)
    """
    if not instance_type or not isinstance(instance_type, str):
        return False, "Instance type must be a non-empty string"
    if execution_platform == 'aws':
        # AWS EC2 instance format: <family><generation>.<size>
        # Examples: c5.xlarge, m5.2xlarge, r6i.large, t3.medium, g4dn.xlarge
        # Family: c, m, r, t, g, p, i, d, x, z, h, etc. (1-4 chars)
        # Generation: digit(s) optionally followed by letter(s) for variants
        # Size: nano, micro, small, medium, large, xlarge, 2xlarge, ..., metal
        aws_pattern = r'^[a-z]{1,4}\d+[a-z]*\.(\d+)?(nano|micro|small|medium|large|xlarge|metal)$'
        if not re.match(aws_pattern, instance_type, re.IGNORECASE):
            return False, (f"Invalid AWS instance type format: '{instance_type}'. "
                           f"Expected format: <family><generation>.<size> (e.g., c5.xlarge, m5.2xlarge)")
    elif execution_platform == 'azure':
        # Azure VM format: Standard_<series><version>_<size> or Basic_<series><version>
        # Examples: Standard_F1s, Standard_D4as_v4, Standard_B2ms, Basic_A1
        # The series may be more than one letter (NC, ND, DS, HB, ...), so match
        # [A-Z]+ rather than a single [A-Z]: Standard_NC6 and Standard_DS3_v2
        # are valid sizes the previous pattern rejected.
        azure_pattern = r'^(Standard|Basic)_[A-Z]+\d+[a-z]*(_v\d+)?$'
        if not re.match(azure_pattern, instance_type):
            return False, (f"Invalid Azure instance type format: '{instance_type}'. "
                           f"Expected format: Standard_<series><size> (e.g., Standard_F1s, Standard_D4as_v4)")
    else:
        # Unknown platform - skip validation
        return True, None
    return True, None
def _map_session_type_to_friendly_name(session_type):
"""Map internal session type names to user-friendly display names.
Parameters
----------
session_type : str
Internal session type (e.g., 'awsJupyterNotebook')
Returns
-------
str
User-friendly type name (e.g., 'Jupyter')
"""
type_mapping = {
'awsJupyterNotebook': 'Jupyter',
'azureJupyterNotebook': 'Jupyter',
'awsVSCode': 'VS Code',
'azureVSCode': 'VS Code',
'awsRstudio': 'RStudio',
'azureRstudio': 'RStudio',
'awsSpark': 'Spark',
'awsJupyterSparkNotebook': 'Spark',
'azureJupyterSparkNotebook': 'Spark',
'azureSpark': 'Spark',
'awsRStudio': 'RStudio', # Handle both capitalizations
'azureRStudio': 'RStudio',
'awsWindowsSession': 'Windows'
}
return type_mapping.get(session_type, session_type)
# [docs]
def create_interactive_session_list_table(sessions, pagination_metadata=None, selected_columns=None, page_size=10, fetch_page_callback=None):
    """Create a rich table displaying interactive sessions with interactive pagination.

    Client-side pagination is 1-indexed throughout. (Previously the page
    counter started at 0, which displayed "Page: 0", made page 1 unreachable
    again after navigating forward because the "prev" guard was `> 1`, and
    let "next" step one page past the end onto an empty page.)

    Parameters
    ----------
    sessions : list
        List of session objects from the API
    pagination_metadata : dict, optional
        Pagination information from the API response
    selected_columns : str or list, optional
        Comma-separated string or list of column names to display.
        If None, uses responsive column selection based on terminal width.
        Available columns: id, name, status, type, instance, cost, owner
    page_size : int, optional
        Number of sessions per page for interactive pagination. Default=10.
    fetch_page_callback : callable, optional
        Callback function to fetch a specific page of results.
        Should accept page number (1-indexed) and return dict with 'sessions' and 'pagination_metadata' keys.
    """
    console = Console()
    # Define all available columns with their configuration
    all_columns = {
        'id': {
            'header': 'ID',
            'style': 'cyan',
            'no_wrap': True,
            'max_width': 24,
            'accessor': '_id'
        },
        'name': {
            'header': 'Name',
            'style': 'green',
            'overflow': 'ellipsis',
            'max_width': 25,
            'accessor': 'name'
        },
        'status': {
            'header': 'Status',
            'style': 'yellow',
            'no_wrap': True,
            'max_width': 12,
            'accessor': 'status'
        },
        'type': {
            'header': 'Type',
            'style': 'magenta',
            'overflow': 'fold',
            'max_width': 20,
            'accessor': 'interactiveSessionType'
        },
        'instance': {
            'header': 'Instance',
            'style': 'cyan',
            'overflow': 'ellipsis',
            'max_width': 15,
            'accessor': 'resources.instanceType'
        },
        'cost': {
            'header': 'Cost',
            'style': 'green',
            'no_wrap': True,
            'max_width': 12,
            'accessor': 'totalCostInUsd'
        },
        'owner': {
            'header': 'Owner',
            'style': 'white',
            'overflow': 'ellipsis',
            'max_width': 20,
            'accessor': 'user.name'
        },
        'project': {
            'header': 'Project',
            'style': 'cyan',
            'overflow': 'ellipsis',
            'max_width': 20,
            'accessor': 'project.name'
        },
        'created_at': {
            'header': 'Created At',
            'style': 'white',
            'overflow': 'ellipsis',
            'max_width': 20,
            'accessor': 'createdAt'
        },
        'runtime': {
            'header': 'Total Running Time',
            'style': 'white',
            'no_wrap': True,
            'max_width': 18,
            'accessor': 'totalRunningTimeInSeconds'
        },
        'saved_at': {
            'header': 'Last Time Saved',
            'style': 'white',
            'overflow': 'ellipsis',
            'max_width': 20,
            'accessor': 'lastSavedAt'
        },
        'resources': {
            'header': 'Resources',
            'style': 'cyan',
            'overflow': 'ellipsis',
            'max_width': 30,
            'accessor': 'resources.instanceType'
        },
        'backend': {
            'header': 'Backend',
            'style': 'magenta',
            'overflow': 'fold',
            'max_width': 15,
            'accessor': 'interactiveSessionType'
        },
        'version': {
            'header': 'Version',
            'style': 'white',
            'no_wrap': True,
            'max_width': 15,
            'accessor': 'rVersion'
        },
        'spot': {
            'header': 'Spot',
            'style': 'cyan',
            'no_wrap': True,
            'max_width': 6,
            'accessor': 'resources.isCostSaving'
        },
        'cost_limit': {
            'header': 'Cost Limit Left',
            'style': 'yellow',
            'no_wrap': True,
            'max_width': 15,
            'accessor': 'execution'
        },
        'time_left': {
            'header': 'Time Until Shutdown',
            'style': 'magenta',
            'no_wrap': True,
            'max_width': 20,
            'accessor': 'execution.autoShutdownAtDate'
        }
    }
    # Determine columns to display
    if selected_columns:
        if isinstance(selected_columns, str):
            selected_columns = [col.strip() for col in selected_columns.split(',')]
        columns_to_show = selected_columns
    else:
        # Responsive column selection based on terminal width
        terminal_width = console.width
        if terminal_width < 60:
            columns_to_show = ['status', 'name', 'id']
        elif terminal_width < 90:
            columns_to_show = ['status', 'name', 'type', 'id', 'owner']
        elif terminal_width < 130:
            columns_to_show = ['status', 'name', 'type', 'instance', 'cost', 'id', 'owner']
        else:
            columns_to_show = ['id', 'name', 'status', 'type', 'instance', 'cost', 'owner']
    # Handle empty results
    if len(sessions) == 0:
        console.print('[yellow]No interactive sessions found.[/yellow]')
        return

    def _build_rows(session_list):
        # Convert raw session objects into formatted display rows for the
        # selected columns (unknown column names are silently skipped).
        built = []
        for session in session_list:
            row_data = []
            for col_name in columns_to_show:
                if col_name not in all_columns:
                    continue
                accessor = all_columns[col_name]['accessor']
                value = _get_nested_value(session, accessor)
                row_data.append(_format_session_field(col_name, value))
            built.append(row_data)
        return built

    rows = _build_rows(sessions)
    # Interactive pagination - use API pagination metadata if available
    if pagination_metadata:
        # Server-side pagination (API pages are 1-indexed)
        current_api_page = pagination_metadata.get('page', 1)
        total_sessions = pagination_metadata.get('count', len(sessions))
        total_pages = pagination_metadata.get('totalPages', 1)
    else:
        # Client-side pagination (fallback), also 1-indexed
        current_api_page = 1
        total_sessions = len(sessions)
        total_pages = (len(sessions) + page_size - 1) // page_size if len(sessions) > 0 else 1
    show_error = None  # Track error messages to display
    while True:
        if fetch_page_callback and pagination_metadata:
            # Server-side pagination - sessions list contains current page data
            page_rows = rows[:]
        else:
            # Client-side pagination: slice the local rows for the 1-indexed page
            start = (current_api_page - 1) * page_size
            page_rows = rows[start:start + page_size]
        # Clear console first
        console.clear()
        # Create table
        table = Table(title='Interactive Sessions')
        # Add columns to table
        for col_name in columns_to_show:
            if col_name not in all_columns:
                continue
            col_config = all_columns[col_name]
            table.add_column(
                col_config['header'],
                style=col_config.get('style', 'white'),
                no_wrap=col_config.get('no_wrap', False)
            )
        # Add rows to table
        for row in page_rows:
            table.add_row(*row)
        # Print table
        console.print(table)
        # Display pagination info
        console.print(f"\n[cyan]Total sessions:[/cyan] {total_sessions}")
        if total_pages > 1:
            console.print(f"[cyan]Page:[/cyan] {current_api_page} of {total_pages}")
            console.print(f"[cyan]Sessions on this page:[/cyan] {len(page_rows)}")
        # Show error message if any
        if show_error:
            console.print(show_error)
            show_error = None  # Reset error after displaying
        # Show pagination controls
        if total_pages > 1:
            # Check if we're in an interactive environment
            if not sys.stdin.isatty():
                console.print("\n[yellow]Note: Pagination not available in non-interactive mode. Showing page 1 of {0}.[/yellow]".format(total_pages))
                console.print("[yellow]Run in an interactive terminal to navigate through all pages.[/yellow]")
                break
            console.print(f"\n[bold cyan]n[/] = next, [bold cyan]p[/] = prev, [bold cyan]q[/] = quit")
            # Get user input for navigation
            try:
                choice = input(">>> ").strip().lower()
            except (EOFError, KeyboardInterrupt):
                # Handle non-interactive environments or user interrupt
                console.print("\n[yellow]Pagination interrupted.[/yellow]")
                break
            if choice in ("q", "quit"):
                break
            elif choice in ("n", "next"):
                if current_api_page < total_pages:
                    if fetch_page_callback:
                        try:
                            next_page_data = fetch_page_callback(current_api_page + 1)
                            sessions = next_page_data.get('sessions', [])
                            pagination_metadata = next_page_data.get('pagination_metadata', {})
                            current_api_page = pagination_metadata.get('page', current_api_page + 1)
                            total_pages = pagination_metadata.get('totalPages', total_pages)
                            rows = _build_rows(sessions)
                        except Exception as e:
                            show_error = f"[red]Error fetching next page: {str(e)}[/red]"
                    else:
                        current_api_page += 1
                else:
                    show_error = "[red]Invalid choice. Already on the last page.[/red]"
            elif choice in ("p", "prev"):
                if current_api_page > 1:
                    if fetch_page_callback:
                        try:
                            prev_page_data = fetch_page_callback(current_api_page - 1)
                            sessions = prev_page_data.get('sessions', [])
                            pagination_metadata = prev_page_data.get('pagination_metadata', {})
                            current_api_page = pagination_metadata.get('page', current_api_page - 1)
                            total_pages = pagination_metadata.get('totalPages', total_pages)
                            rows = _build_rows(sessions)
                        except Exception as e:
                            show_error = f"[red]Error fetching previous page: {str(e)}[/red]"
                    else:
                        current_api_page -= 1
                else:
                    show_error = "[red]Invalid choice. Already on the first page.[/red]"
            else:
                show_error = "[red]Invalid choice. Please enter 'n' (next), 'p' (prev), or 'q' (quit).[/red]"
        else:
            # Only one page, no need for input, just exit
            break
# [docs]
def process_interactive_session_list(sessions, all_fields=False):
    """Convert raw interactive-session objects from the API into a pandas DataFrame.

    Parameters
    ----------
    sessions : list
        List of session objects from the API
    all_fields : bool, default=False
        If True, include all fields from the API response.
        If False, include only the most relevant fields.

    Returns
    -------
    df : pandas.DataFrame
        DataFrame with session data
    """
    if all_fields:
        # Flatten every field the API returned
        return pd.json_normalize(sessions)

    def _owner_display_name(session):
        # The API exposes the owner's given/family names as 'name'/'surname'
        # (not 'firstName'/'lastName').
        user = session.get('user', {})
        if not user:
            return ''
        return f"{user.get('name', '')} {user.get('surname', '')}".strip()

    records = [
        {
            '_id': s.get('_id', ''),
            'name': s.get('name', ''),
            'status': s.get('status', ''),
            'interactiveSessionType': _map_session_type_to_friendly_name(s.get('interactiveSessionType', '')),
            'user': _owner_display_name(s),
            'instanceType': s.get('resources', {}).get('instanceType', ''),
            'totalCostInUsd': s.get('totalCostInUsd', 0),
        }
        for s in sessions
    ]
    return pd.DataFrame(records)
def _get_nested_value(obj, path):
"""Get a nested value from an object using dot notation.
Parameters
----------
obj : dict
The object to extract from
path : str
Dot-separated path (e.g., 'user.firstName')
Returns
-------
value
The value at the path, or empty string if not found
"""
parts = path.split('.')
value = obj
for part in parts:
if isinstance(value, dict):
value = value.get(part)
else:
return ''
return value if value is not None else ''
def _format_session_field(field_name, value):
"""Format a session field for display.
Parameters
----------
field_name : str
The name of the field
value
The value to format
Returns
-------
str
The formatted value
"""
if value == '' or value is None:
return '-'
if field_name == 'status':
# Color code status and map display values
status_lower = str(value).lower()
# Map API statuses to display values
# API 'ready' and 'aborted' are mapped to user-friendly names
display_status = 'running' if status_lower == 'ready' else ('paused' if status_lower == 'aborted' else value)
if status_lower in ['ready', 'running']:
return f'[bold green]{display_status}[/bold green]'
elif status_lower in ['paused', 'aborted']:
return f'[bold red]{display_status}[/bold red]'
elif status_lower in ['setup', 'initialising', 'initializing', 'scheduled']:
return f'[bold yellow]{display_status}[/bold yellow]'
else:
return str(display_status)
elif field_name == 'cost':
# Format cost with currency symbol
try:
cost = float(value)
return f'${cost:.2f}'
except (ValueError, TypeError):
return str(value)
elif field_name == 'id':
# Return full ID without truncation (MongoDB ObjectIds are always 24 chars)
# Full ID is needed for status command and other operations
return str(value)
elif field_name == 'name':
# Truncate long names
value_str = str(value)
if len(value_str) > 25:
return value_str[:22] + '…'
return value_str
elif field_name == 'runtime':
# Convert seconds to human-readable format (e.g., "1h 52m 52s")
try:
total_seconds = int(float(value))
hours = total_seconds // 3600
minutes = (total_seconds % 3600) // 60
seconds = total_seconds % 60
if hours > 0:
return f'{hours}h {minutes}m {seconds}s'
elif minutes > 0:
return f'{minutes}m {seconds}s'
else:
return f'{seconds}s'
except (ValueError, TypeError):
return str(value)
elif field_name == 'created_at' or field_name == 'saved_at':
# Format ISO8601 datetime to readable format
try:
dt = datetime.fromisoformat(str(value).replace('Z', '+00:00'))
return dt.strftime('%Y-%m-%d %H:%M')
except (ValueError, TypeError, ImportError):
return str(value)[:19] if value else '-'
elif field_name == 'version':
# Version is only available for RStudio sessions
if value and str(value).lower() != 'none':
return f'R {value}'
return '-'
elif field_name == 'type':
# Map internal type names to user-friendly names
return _map_session_type_to_friendly_name(str(value))
elif field_name == 'spot':
# Indicate if instance is cost-saving (spot)
if value is True:
return '[bold cyan]Yes[/bold cyan]'
elif value is False:
return 'No'
else:
return '-'
elif field_name == 'cost_limit':
# Calculate remaining cost limit (execution object contains computeCostLimit and computeCostSpent)
if isinstance(value, dict):
cost_limit = value.get('computeCostLimit', -1)
cost_spent = value.get('computeCostSpent', 0)
# -1 means unlimited
if cost_limit == -1:
return 'Unlimited'
try:
remaining = float(cost_limit) - float(cost_spent)
if remaining < 0:
remaining = 0
return f'${remaining:.2f}'
except (ValueError, TypeError):
return '-'
return '-'
elif field_name == 'time_left':
# Calculate time until auto-shutdown
if value and value != 'null' and str(value).strip():
try:
shutdown_time = datetime.fromisoformat(str(value).replace('Z', '+00:00'))
now = datetime.now(timezone.utc)
if shutdown_time > now:
time_diff = shutdown_time - now
total_seconds = int(time_diff.total_seconds())
hours = total_seconds // 3600
minutes = (total_seconds % 3600) // 60
if hours > 24:
days = hours // 24
remaining_hours = hours % 24
return f'{days}d {remaining_hours}h'
elif hours > 0:
return f'{hours}h {minutes}m'
else:
return f'{minutes}m'
else:
return '[red]Expired[/red]'
except (ValueError, TypeError, ImportError):
return '-'
return '-'
return str(value)
# [docs]
def save_interactive_session_list_to_csv(df, outfile, count=None):
    """Write the interactive session DataFrame to *outfile* as CSV.

    Parameters
    ----------
    df : pandas.DataFrame
        The session data to save
    outfile : str
        Path to the output CSV file
    count : int, optional
        Total number of sessions on this page for display message
    """
    df.to_csv(outfile, index=False)
    # Assemble the status lines first, then emit them together.
    messages = []
    if count is not None:
        messages.append(f'\tInteractive session list collected with a total of {count} sessions on this page.')
    messages.append(f'\tInteractive session list saved to {outfile}')
    for line in messages:
        print(line)
# [docs]
def parse_shutdown_duration(duration_str):
    """Convert a relative duration ('30m', '2h', '1d') into an absolute
    ISO8601 UTC timestamp in the future.

    Parameters
    ----------
    duration_str : str
        Duration string (e.g., "2h", "30m", "1d")

    Returns
    -------
    str
        ISO8601 formatted datetime string (future time, 'Z' suffix)

    Raises
    ------
    ValueError
        If the string is not <integer><m|h|d>.
    """
    parsed = re.match(r'^(\d+)([mhd])$', duration_str.lower())
    if parsed is None:
        raise ValueError(f"Invalid duration format: {duration_str}. Use format like '2h', '30m', '1d'")
    amount = int(parsed.group(1))
    # Map the unit letter onto the matching timedelta keyword.
    unit_keyword = {'m': 'minutes', 'h': 'hours', 'd': 'days'}[parsed.group(2)]
    deadline = datetime.now(timezone.utc) + timedelta(**{unit_keyword: amount})
    return deadline.isoformat().replace('+00:00', 'Z')
# [docs]
def parse_watch_timeout_duration(duration_str):
    """Convert a duration string ('30s', '30m', '2h', '1d') to seconds.

    Parameters
    ----------
    duration_str : str
        Duration string (e.g., "30m", "2h", "1d", "30s")

    Returns
    -------
    int
        Duration in seconds

    Raises
    ------
    ValueError
        If the string is not <integer><s|m|h|d>.
    """
    parsed = re.match(r'^(\d+)([smhd])$', duration_str.lower())
    if parsed is None:
        raise ValueError(f"Invalid duration format: {duration_str}. Use format like '30s', '30m', '2h', '1d'")
    seconds_per_unit = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}
    return int(parsed.group(1)) * seconds_per_unit[parsed.group(2)]
# [docs]
def parse_data_file(data_file_str):
    """Parse a data-file specification: either an S3 URI or a CloudOS dataset path.

    Supports mounting both S3 files and CloudOS dataset files into the session.

    Parameters
    ----------
    data_file_str : str
        Format:
            - S3 file: s3://bucket_name/path/to/file.txt
            - CloudOS dataset: project_name/dataset_path or project_name > dataset_path
        Examples:
            - s3://lifebit-featured-datasets/pipelines/phewas/data.csv
            - leila-test/Data/3_vcf_list.txt

    Returns
    -------
    dict
        Parsed data item. For S3:
            {"type": "s3", "s3_bucket": "...", "s3_prefix": "..."}
        For CloudOS dataset:
            {"type": "cloudos", "project_name": "...", "dataset_path": "..."}

    Raises
    ------
    ValueError
        If format is invalid
    """
    if data_file_str.startswith('s3://'):
        # S3 URI: everything up to the first '/' is the bucket, the rest is the key.
        remainder = data_file_str[len('s3://'):]
        bucket, slash, key = remainder.partition('/')
        if not bucket:
            raise ValueError(f"Invalid S3 path: {data_file_str}. Expected: s3://bucket_name/path/to/file")
        return {
            "type": "s3",
            "s3_bucket": bucket,
            # No '/' at all means the whole path was just the bucket: use "/"
            "s3_prefix": key if slash else "/"
        }
    # Otherwise treat it as a CloudOS dataset path; '>' takes precedence over '/'.
    for candidate in ('>', '/'):
        if candidate in data_file_str:
            # Split only on the first separator so nested paths stay intact.
            project_name, _, dataset_path = data_file_str.partition(candidate)
            return {
                "type": "cloudos",
                "project_name": project_name.strip(),
                "dataset_path": dataset_path.strip()
            }
    raise ValueError(
        f"Invalid data file format: {data_file_str}. Expected one of:\n"
        f" - S3 file: s3://bucket/path/file.txt\n"
        f" - CloudOS dataset: project_name/dataset_path or project_name > dataset_path"
    )
# [docs]
def resolve_data_file_id(datasets_api, dataset_path: str) -> dict:
    """Resolve nested dataset path to actual file ID.

    Searches across all datasets in the project to find the target file.
    This allows paths like 'Data/file.txt' to work even if 'Data' is a folder
    within a dataset (not a dataset name itself).

    Search strategy, in order:
      1. Try the path as-is (assuming its first component is a dataset name).
      2. Try each dataset with its name prepended to the path.
      3. Try each dataset's top-level content, descending into a folder that
         matches the path's first component.

    Parameters
    ----------
    datasets_api : Datasets
        Initialized Datasets API instance (with correct project_name)
    dataset_path : str
        Nested path to file within the project (e.g., 'Data/file.txt' or 'Folder/subfolder/file.txt')
        Can start with a dataset name or a folder name within any dataset.

    Returns
    -------
    dict
        Data item object with resolved file ID:
        {"kind": "File", "item": "<fileId>", "name": "<fileName>"}

    Raises
    ------
    ValueError
        If file not found in any dataset/folder, or if any other error
        occurs while resolving (wrapped with the original message).
    """
    try:
        # The file we are after is the last path component.
        path_parts = dataset_path.strip('/').split('/')
        file_name = path_parts[-1]
        # First, try the path as-is (assuming first part is a dataset name)
        try:
            result = datasets_api.list_folder_content(dataset_path)
            # Check if it's in the files list
            for file_item in result.get('files', []):
                if file_item.get('name') == file_name:
                    return {
                        "kind": "File",
                        "item": file_item.get('_id'),
                        "name": file_item.get('name')
                    }
            # If we got here, quick path didn't work, continue to search
        except (Exception):
            # First path attempt failed, try searching across all datasets
            pass
        # If the quick path didn't work, search across all datasets
        # This handles the case where the first part is a folder, not a dataset name
        # NOTE(review): top-level entries are exposed under the 'folders' key;
        # they are treated here as the project's datasets.
        project_content = datasets_api.list_project_content()
        datasets = project_content.get('folders', [])
        if not datasets:
            raise ValueError(f"No datasets found in project. Cannot locate path '{dataset_path}'")
        # Try to find the file in each dataset
        found_files = []
        for dataset in datasets:
            dataset_name = dataset.get('name')
            try:
                # Try with the dataset name prepended to the path
                full_path = f"{dataset_name}/{dataset_path}"
                result = datasets_api.list_folder_content(full_path)
                # Check files list
                for file_item in result.get('files', []):
                    if file_item.get('name') == file_name:
                        found_files.append({
                            "kind": "File",
                            "item": file_item.get('_id'),
                            "name": file_item.get('name')
                        })
                # Return first match (most direct path)
                # NOTE(review): when no file matched, found_files is empty and
                # this indexing raises IndexError, which the broad `except`
                # below swallows to advance to the next dataset. The loop's
                # control flow depends on that — restructure with care.
                return found_files[0]
            except Exception:
                # This dataset doesn't contain the path, continue
                continue
        # Also try searching without dataset prefix (path is from root of datasets)
        for dataset in datasets:
            try:
                dataset_name = dataset.get('name')
                # List what's in this dataset at the top level
                dataset_content = datasets_api.list_datasets_content(dataset_name)
                # Check if the target file is directly in this dataset's files
                for file_item in dataset_content.get('files', []):
                    if file_item.get('name') == file_name:
                        found_files.append({
                            "kind": "File",
                            "item": file_item.get('_id'),
                            "name": file_item.get('name')
                        })
                # Check folders and navigate if needed
                for folder in dataset_content.get('folders', []):
                    if folder.get('name') == path_parts[0]:
                        # This dataset has the target folder
                        full_path = f"{dataset_name}/{dataset_path}"
                        try:
                            result = datasets_api.list_folder_content(full_path)
                            for file_item in result.get('files', []):
                                if file_item.get('name') == file_name:
                                    return {
                                        "kind": "File",
                                        "item": file_item.get('_id'),
                                        "name": file_item.get('name')
                                    }
                        except Exception:
                            continue
            except Exception:
                continue
        # If we found files, return the first one
        if found_files:
            return found_files[0]
        # Nothing found - provide helpful error message
        available_datasets = [d.get('name') for d in datasets]
        raise ValueError(
            f"File at path '{dataset_path}' not found in any dataset. "
            f"Available datasets: {available_datasets}. "
            f"Try using 'cloudos datasets ls' to explore your data structure."
        )
    except ValueError:
        # Re-raise our own descriptive errors untouched.
        raise
    except Exception as e:
        # Wrap anything unexpected (API/network failures) in a ValueError.
        raise ValueError(f"Error resolving dataset file at path '{dataset_path}': {str(e)}")
# [docs]
def parse_link_path(link_path_str):
    """Parse a folder-link specification: S3 URI, CloudOS path, or legacy colon format.

    Links an S3 folder or CloudOS folder to the session for read/write access.

    Parameters
    ----------
    link_path_str : str
        Format (one of):
            - S3 path: s3://bucketName/s3Prefix (e.g., s3://my-bucket/data/)
            - CloudOS folder: project/folder_path (e.g., leila-test/Data)
            - Legacy format (deprecated): mountName:bucketName:s3Prefix

    Returns
    -------
    dict
        With a "type" key of 's3' or 'cloudos' plus:
            For S3: "s3_bucket", "s3_prefix" (and "mount_name" for legacy input)
            For CloudOS: "project_name", "folder_path"

    Raises
    ------
    ValueError
        For Azure blob URLs (unsupported) or unrecognised formats.
    """
    # Azure blob storage is rejected up front with a helpful message.
    # (az:// always errors; https:// only when it is a blob.core.windows.net URL.)
    is_azure_blob_url = link_path_str.startswith('https://') and '.blob.core.windows.net' in link_path_str
    if link_path_str.startswith('az://') or is_azure_blob_url:
        raise ValueError(
            f"Azure blob storage paths are not supported for linking. "
            f"Folder linking is not supported on Azure execution platforms. "
            f"Please use CloudOS file explorer to access your data directly."
        )
    if link_path_str.startswith('s3://'):
        # S3 URI: bucket up to the first '/', the remainder is the prefix.
        bucket, _, prefix = link_path_str[len('s3://'):].partition('/')
        # S3 folder prefixes are normalised to end with '/'.
        if prefix and not prefix.endswith('/'):
            prefix += '/'
        return {
            "type": "s3",
            "s3_bucket": bucket,
            "s3_prefix": prefix
        }
    if ':' in link_path_str and '//' not in link_path_str:
        # Legacy format: mountName:bucketName:s3Prefix
        pieces = link_path_str.split(':')
        if len(pieces) != 3:
            raise ValueError(f"Invalid link format: {link_path_str}. Expected: mountName:bucketName:s3Prefix")
        mount_name, bucket, prefix = pieces
        if prefix and not prefix.endswith('/'):
            prefix += '/'
        return {
            "type": "s3",
            "mount_name": mount_name,
            "s3_bucket": bucket,
            "s3_prefix": prefix
        }
    # Otherwise treat as a CloudOS folder path; '>' takes precedence over '/'.
    for candidate in ('>', '/'):
        if candidate in link_path_str:
            project_name, _, folder_path = link_path_str.partition(candidate)
            return {
                "type": "cloudos",
                "project_name": project_name.strip(),
                "folder_path": folder_path.strip()
            }
    raise ValueError(
        f"Invalid link path format: {link_path_str}. Expected one of:\n"
        f" - S3 path: s3://bucket/prefix/\n"
        f" - CloudOS folder: project/folder/path\n"
        f" - Legacy format (deprecated): mountName:bucketName:prefix"
    )
# [docs]
def _build_spark_cluster_config(name, master_type, core_type, workers, is_spot):
    """Build the EMR cluster sub-configuration for a Spark session (AWS only)."""
    # Shared per-instance EBS storage block; copied so the master and core
    # entries do not alias the same dict.
    instance_storage = {
        "type": "gp2",
        "sizeInGbs": 50,
        "volumesPerInstance": 1
    }
    # Autoscaling bounds: never scale below the requested worker count and
    # allow bursting up to double it (with a floor of 10).
    autoscaling = {
        "minCapacity": workers,
        "maxCapacity": max(workers * 2, 10)
    }
    return {
        "name": f"{name}-cluster",
        "releaseLabel": "emr-7.3.0",
        "ebsRootVolumeSizeInGb": 100,
        "instances": {
            "master": {
                "type": master_type,
                "costSaving": is_spot,
                "storage": dict(instance_storage)
            },
            "core": {
                "type": core_type,
                "costSaving": is_spot,
                "storage": dict(instance_storage),
                "minNumberOfInstances": workers,
                "autoscaling": dict(autoscaling)
            },
            "tasks": []
        },
        "autoscaling": dict(autoscaling),
        "id": None
    }


def build_session_payload(
        name,
        backend,
        project_id,
        execution_platform='aws',
        instance_type='c5.xlarge',
        storage_size=500,
        is_spot=False,
        is_shared=False,
        cost_limit=-1,
        shutdown_at=None,
        data_files=None,
        s3_mounts=None,
        r_version=None,
        spark_master_type=None,
        spark_core_type=None,
        spark_workers=1
):
    """Build the complex session creation payload for the API.

    Parameters
    ----------
    name : str
        Session name (1-100 characters)
    backend : str
        Backend type: regular, vscode, spark, rstudio
    project_id : str
        Project MongoDB ObjectId
    execution_platform : str, optional
        Execution platform: 'aws' (default) or 'azure'
    instance_type : str
        Instance type (EC2 for AWS, e.g., c5.xlarge; Azure VM size, e.g., Standard_F1s)
    storage_size : int
        Storage in GB (default: 500, range: 100-5000)
    is_spot : bool
        Use spot instances (AWS only, default: False)
    is_shared : bool
        Make session shared (default: False)
    cost_limit : float
        Compute cost limit in USD (default: -1 for unlimited)
    shutdown_at : str
        ISO8601 datetime for auto-shutdown (optional; defaults to 24h from now)
    data_files : list
        List of data file dicts. For AWS: CloudOS or S3. For Azure: CloudOS only.
    s3_mounts : list
        List of S3 mount dicts (AWS only, ignored for Azure)
    r_version : str
        R version for RStudio (required for rstudio backend)
    spark_master_type : str
        Spark master instance type (required for spark backend, AWS only)
    spark_core_type : str
        Spark core instance type (required for spark backend, AWS only)
    spark_workers : int
        Initial number of Spark workers (default: 1, AWS only)

    Returns
    -------
    dict
        Complete payload for API request

    Raises
    ------
    ValueError
        If any validation above fails.
    """
    # Validate inputs
    if not 1 <= len(name) <= 100:
        raise ValueError("Session name must be 1-100 characters")
    if not 100 <= storage_size <= 5000:
        raise ValueError("Storage size must be between 100-5000 GB")
    if backend not in ['regular', 'vscode', 'spark', 'rstudio']:
        raise ValueError("Invalid backend type")
    if execution_platform not in ['aws', 'azure']:
        raise ValueError("Execution platform must be 'aws' or 'azure'")
    # Spark is AWS only
    if backend == 'spark' and execution_platform != 'aws':
        raise ValueError("Spark backend is only available on AWS")
    if backend == 'rstudio' and not r_version:
        raise ValueError("R version (--r-version) is required for RStudio backend")
    if backend == 'spark' and (not spark_master_type or not spark_core_type):
        raise ValueError("Spark master and core instance types are required for Spark backend")
    # Default shutdown to 24 hours if not provided
    if not shutdown_at:
        shutdown_at = (datetime.now(timezone.utc) + timedelta(hours=24)).isoformat().replace('+00:00', 'Z')
    # Build interactiveSessionConfiguration
    config = {
        "name": name,
        "backend": backend,
        "executionPlatform": execution_platform,
        "instanceType": instance_type,
        "isCostSaving": is_spot,
        "storageSizeInGb": storage_size,
        "storageMode": "regular",
        "visibility": "workspace" if is_shared else "private",
        "execution": {
            "computeCostLimit": cost_limit,
            "autoShutdownAtDate": shutdown_at
        }
    }
    # Add backend-specific fields
    if backend == 'rstudio':
        config['rVersion'] = r_version
    if backend == 'spark':
        config['cluster'] = _build_spark_cluster_config(
            name, spark_master_type, spark_core_type, spark_workers, is_spot)
    # Build complete payload.
    # For Azure, S3 mounts are not supported (fuseFileSystems must be empty).
    # Parentheses made explicit: the previous `s3_mounts or [] if ... else []`
    # only worked because `or` binds tighter than the conditional expression.
    payload = {
        "interactiveSessionConfiguration": config,
        "dataItems": data_files or [],
        "fileSystemIds": [],  # Always empty (legacy compatibility)
        "fuseFileSystems": (s3_mounts or []) if execution_platform == 'aws' else [],
        "projectId": project_id
    }
    return payload
# [docs]
def build_resume_payload(
    instance_type=None,
    storage_size=None,
    cost_limit=None,
    shutdown_at=None,
    data_files=None,
    s3_mounts=None
):
    """Build the resume session payload for the API.

    Every argument is optional; only the fields actually supplied are
    included in the resulting payload.

    Parameters
    ----------
    instance_type : str, optional
        New instance type (if changing)
    storage_size : int, optional
        New storage size in GB (if changing)
    cost_limit : float, optional
        New compute cost limit (if changing)
    shutdown_at : str, optional
        New auto-shutdown datetime in ISO8601 format (if changing)
    data_files : list, optional
        Additional data files to mount
    s3_mounts : list, optional
        Additional S3 mounts (AWS only)

    Returns
    -------
    dict
        Resume payload for API request
    """
    # Collect only the execution-level overrides that were supplied.
    execution_changes = {
        field: value
        for field, value in (
            ("computeCostLimit", cost_limit),
            ("autoShutdownAtDate", shutdown_at),
        )
        if value is not None
    }
    # Collect top-level configuration overrides the same way.
    config_changes = {
        field: value
        for field, value in (
            ("instanceType", instance_type),
            ("storageSizeInGb", storage_size),
        )
        if value is not None
    }
    if execution_changes:
        config_changes["execution"] = execution_changes

    payload = {
        "dataItems": data_files or [],
        "fileSystemIds": []  # Always empty (deprecated)
    }
    # newInteractiveSessionConfiguration is emitted only when something changed.
    if config_changes:
        payload["newInteractiveSessionConfiguration"] = config_changes
    # S3 mounts (AWS only) are appended only when provided.
    if s3_mounts:
        payload["fuseFileSystems"] = s3_mounts
    return payload
# ============================================================================
# Interactive Session Status Helper Functions
# ============================================================================
# Status color mapping for Rich terminal output (display status -> Rich color).
# Statuses absent from this dict (e.g. 'setup', 'initialising') fall back to
# whatever default styling the caller applies.
STATUS_COLORS = {
    'running': 'green',
    'paused': 'red',
    'terminated': 'red',
    'provisioning': 'yellow',
    'scheduled': 'yellow',
}
# Terminal states where watch mode should exit (display-status values).
TERMINAL_STATES = {'running', 'paused', 'terminated'}
# Status mapping from raw API values to the user-friendly display values
# shown by the list command; applied via map_status().
API_STATUS_MAPPING = {
    'ready': 'running',  # API returns 'ready' for running sessions
    'aborted': 'paused',  # API returns 'aborted' for paused sessions
    'setup': 'setup',
    'initialising': 'initialising',
    'initializing': 'initialising',  # normalise US spelling to the UK form
    'scheduled': 'scheduled',
    'running': 'running',  # Some endpoints may return 'running'
    'stopped': 'paused',  # Some endpoints may return 'stopped' - map to 'paused'
    'terminated': 'terminated',
}
# Pre-running statuses (watch mode only valid for these)
PRE_RUNNING_STATUSES = {'setup', 'initialising', 'scheduled'}
[docs]
def map_status(api_status: str) -> str:
    """Translate a raw API status into the user-friendly display status.

    Mirrors the list command's display values (e.g. 'ready' -> 'running',
    'aborted' -> 'paused'). Unknown statuses are passed through unchanged.
    """
    try:
        return API_STATUS_MAPPING[api_status]
    except KeyError:
        return api_status
[docs]
def validate_session_id(session_id: str) -> bool:
    """Validate session ID format (24-character hex string).

    Parameters
    ----------
    session_id : str
        Candidate session ID; ``None``/empty is rejected.

    Returns
    -------
    bool
        True if the ID is exactly 24 hex characters (case-insensitive).
    """
    if not session_id:
        return False
    # fullmatch instead of match with '$': '$' also matches just before a
    # trailing newline, so 'abc...def\n' used to validate incorrectly.
    return re.fullmatch(r'[a-f0-9]{24}', session_id, re.IGNORECASE) is not None
[docs]
class InteractiveSessionAPI:
    """API client for interactive session operations.

    Authenticates requests with the ``apikey`` header against the CloudOS
    v2 interactive-sessions endpoints.
    """

    # Timeout (seconds) reported in the Timeout error message below.
    # NOTE(review): this value is never passed to retry_requests_get, so the
    # effective timeout is whatever that helper applies internally — confirm.
    REQUEST_TIMEOUT = 30  # seconds

    def __init__(self, cloudos_url: str, apikey: str, verify_ssl: bool = True):
        """Initialize API client.
        Parameters
        ----------
        cloudos_url : str
            Base CloudOS platform URL
        apikey : str
            API key for authentication
        verify_ssl : bool
            Whether to verify SSL certificates
        """
        # Strip trailing slash so URL building never produces '//'.
        self.cloudos_url = cloudos_url.rstrip('/')
        self.apikey = apikey
        self.verify_ssl = verify_ssl
        self.session = self._create_session()

    def _create_session(self) -> requests.Session:
        """Create requests session with retry strategy.

        Retries GET requests up to 3 times with exponential backoff on
        429/5xx responses.

        NOTE(review): get_session_status() below uses retry_requests_get
        rather than this session, so this retry configuration appears unused
        here — verify whether other (unseen) methods rely on self.session.
        """
        session = requests.Session()
        # Configure retry strategy with exponential backoff
        retry_strategy = Retry(
            total=3,
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=['GET']
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        # Mount the retry adapter for both plain and TLS endpoints.
        session.mount('http://', adapter)
        session.mount('https://', adapter)
        return session

    def get_session_status(self, session_id: str, team_id: str) -> dict:
        """Retrieve session status from API endpoint.
        GET /api/v2/interactive-sessions/{sessionId}?teamId={teamId}
        Parameters
        ----------
        session_id : str
            Session ID (24-character hex)
        team_id : str
            Team/workspace ID
        Returns
        -------
        dict
            Session status response
        Raises
        ------
        PermissionError
            If authentication fails (401, 403)
        ValueError
            If session not found (404)
        RuntimeError
            For other API errors
        """
        url = f"{self.cloudos_url}/api/v2/interactive-sessions/{session_id}"
        params = {'teamId': team_id}
        headers = {
            'apikey': self.apikey,
            'Content-Type': 'application/json'
        }
        try:
            # Project helper; presumably performs its own retry/timeout
            # handling — TODO confirm against cloudos_cli.utils.requests.
            response = retry_requests_get(
                url,
                params=params,
                headers=headers,
                verify=self.verify_ssl
            )
            # Map HTTP status codes to typed exceptions the CLI can catch.
            if response.status_code == 200:
                return response.json()
            elif response.status_code == 401:
                raise PermissionError("Unauthorized: Invalid API key or credentials")
            elif response.status_code == 403:
                raise PermissionError("Forbidden: Insufficient permissions for this session")
            elif response.status_code == 404:
                raise ValueError(
                    f"Session not found. Verify session ID ({session_id}) "
                    f"and team ID ({team_id})"
                )
            elif response.status_code == 500:
                raise RuntimeError("Server error: Unable to retrieve session status")
            else:
                # Any other status is surfaced verbatim with the response body.
                raise RuntimeError(
                    f"API error (HTTP {response.status_code}): {response.text}"
                )
        except requests.exceptions.Timeout:
            raise RuntimeError(f"API request timeout after {self.REQUEST_TIMEOUT} seconds")
        except requests.exceptions.ConnectionError as e:
            raise RuntimeError(f"Failed to connect to CloudOS: {str(e)}")
[docs]
class WatchModeManager:
    """Manages watch mode polling and display."""

    def __init__(self, api_client: InteractiveSessionAPI,
                 session_id: str, team_id: str, interval: int = 10):
        """Initialize watch mode manager.
        Parameters
        ----------
        api_client : InteractiveSessionAPI
            API client instance
        session_id : str
            Session ID to monitor
        team_id : str
            Team ID
        interval : int
            Polling interval in seconds (default: 10)
        """
        self.api_client = api_client
        self.session_id = session_id
        self.team_id = team_id
        self.interval = interval
        # Elapsed time is measured from construction, not from watch() start.
        self.start_time = time.time()

    def watch(self, verbose: bool = False) -> dict:
        """Continuously poll session status until reaching terminal state.
        Terminal states: running, paused, terminated
        Handles Ctrl+C gracefully.

        Parameters
        ----------
        verbose : bool
            When True, print a one-line spinner/status update on every poll.

        Returns
        -------
        dict
            The final API response once a terminal state is reached.
        """
        spinner_chars = ['◜', '◝', '◞', '◟']
        spinner_index = 0
        try:
            while True:
                # Fetch status
                response = self.api_client.get_session_status(
                    self.session_id, self.team_id
                )
                # NOTE(review): the raw API status is compared against
                # TERMINAL_STATES (display values) without map_status(). If
                # the API returns 'ready'/'aborted' as API_STATUS_MAPPING
                # suggests, those would never match here — confirm.
                status = response.get('status', '')
                elapsed = int(time.time() - self.start_time)
                # Display progress; '\r' keeps the spinner on a single line.
                spinner = spinner_chars[spinner_index % len(spinner_chars)]
                if verbose:
                    print(
                        f"\r{spinner} Status: {status:<12} | "
                        f"Elapsed: {elapsed}s",
                        end='',
                        flush=True
                    )
                # Check if reached terminal state
                if status in TERMINAL_STATES:
                    print()  # New line after spinner
                    if status == 'running':
                        print(
                            "✓ Session is now running and ready to use!"
                        )
                    else:
                        print(
                            f"⚠ Session reached terminal state: {status}"
                        )
                    return response
                # Wait before next poll
                spinner_index += 1
                time.sleep(self.interval)
        except KeyboardInterrupt:
            # Surface the interrupt to the caller after a clean message.
            print("\n⚠ Watch mode interrupted by user.")
            raise

    def get_elapsed_time(self) -> str:
        """Get formatted elapsed time."""
        elapsed = int(time.time() - self.start_time)
        # format_duration is defined elsewhere in this module — presumably
        # renders whole seconds as a human-readable duration; verify.
        return format_duration(elapsed)
[docs]
def export_session_status_json(session_data: dict, output_file: str = None) -> str:
    """Export session status as JSON.

    Parameters
    ----------
    session_data : dict
        Raw API response
    output_file : str, optional
        Path to save JSON file. If None, returns JSON string.

    Returns
    -------
    str
        JSON formatted string
    """
    # default=str stringifies non-JSON-native values (e.g. datetimes).
    serialized = json.dumps(session_data, indent=2, default=str)
    if output_file:
        with open(output_file, 'w') as handle:
            handle.write(serialized)
    return serialized
[docs]
def export_session_status_csv(session_data: dict, output_file: str = None) -> str:
    """Export session status as CSV.

    Parameters
    ----------
    session_data : dict
        Transformed session data (from transform_session_response)
    output_file : str, optional
        Path to save CSV file. If None, returns CSV string.

    Returns
    -------
    str
        CSV formatted string
    """
    rendered = OutputFormatter.format_csv(session_data)
    if output_file:
        with open(output_file, 'w') as handle:
            handle.write(rendered)
    return rendered
# ============================================================================
# Wrapper Functions for CLI Integration
# ============================================================================
[docs]
def get_interactive_session_status(cloudos_url: str, apikey: str, session_id: str,
                                   team_id: str, verify_ssl: bool = True,
                                   verbose: bool = False) -> dict:
    """Wrapper function to fetch session status from API.

    Parameters
    ----------
    cloudos_url : str
        CloudOS platform URL
    apikey : str
        API key for authentication
    session_id : str
        Session ID (24-char hex)
    team_id : str
        Team/workspace ID
    verify_ssl : bool
        Whether to verify SSL certificates
    verbose : bool
        Whether to print verbose output (currently accepted but unused)

    Returns
    -------
    dict
        Raw API response

    Raises
    ------
    ValueError
        If session not found
    PermissionError
        If authentication fails
    RuntimeError
        For other API errors
    """
    client = InteractiveSessionAPI(cloudos_url, apikey, verify_ssl)
    return client.get_session_status(session_id, team_id)
[docs]
def confirm_session_stop(session_data: dict, no_upload: bool = False, force: bool = False) -> None:
    """Display session termination confirmation details.

    Renders a Rich table summarising what is about to happen to the session
    (data handling, termination mode, cost) before the stop is performed.

    Parameters
    ----------
    session_data : dict
        Session data from API response
    no_upload : bool
        Whether data upload on close is disabled
    force : bool
        Whether force abort is enabled
    """
    name = session_data.get('name', 'Unknown')
    sid = session_data.get('_id', 'Unknown')
    display_status = map_status(session_data.get('status', 'unknown'))
    hourly_cost = session_data.get('costPerHour', 0)

    summary = Table(title=f"About to stop session: {name}", title_style="bold yellow")
    summary.add_column("Property", style="cyan", no_wrap=True)
    summary.add_column("Value", style="green")

    # Assemble all rows first, then add them in one pass.
    rows = [
        ("Session ID", sid),
        ("Current Status", display_status),
        ("Data Action",
         "⚠ Will NOT be saved (--no-upload)" if no_upload
         else "Will be saved before stopping"),
        ("Termination",
         "⚠ FORCED (skip graceful shutdown)" if force
         else "Graceful shutdown"),
    ]
    if hourly_cost:
        rows.append(("Cost/Hour", f"${hourly_cost:.2f}"))
    for prop, value in rows:
        summary.add_row(prop, value)
    Console().print(summary)
[docs]
def poll_session_termination(cloudos_url: str, apikey: str, session_id: str, team_id: str,
                             max_wait: int = 300, poll_interval: int = 5, verify_ssl: bool = True) -> dict:
    """Poll session status until it reaches a terminal state.

    Parameters
    ----------
    cloudos_url : str
        CloudOS API URL
    apikey : str
        API key for authentication
    session_id : str
        Session ID to monitor
    team_id : str
        Team/workspace ID
    max_wait : int
        Maximum time to wait in seconds (default: 300 = 5 minutes)
    poll_interval : int
        Polling interval in seconds (default: 5)
    verify_ssl : bool
        Whether to verify SSL certificates

    Returns
    -------
    dict
        Final session status response

    Raises
    ------
    TimeoutError
        If session doesn't reach terminal state within max_wait
    """
    console = Console()
    started_at = time.time()
    last_seen = None
    with console.status("[bold yellow]Pausing session...", spinner='dots'):
        while True:
            # Elapsed time is sampled before the fetch, so at least one
            # status request is always made even with max_wait=0.
            waited = time.time() - started_at
            response = get_interactive_session_status(
                cloudos_url=cloudos_url,
                apikey=apikey,
                session_id=session_id,
                team_id=team_id,
                verify_ssl=verify_ssl,
                verbose=False
            )
            status_now = map_status(response.get('status', ''))
            # Only log when the status actually changes.
            if status_now != last_seen:
                console.log(f"Status: {status_now}")
                last_seen = status_now
            if status_now in ('paused', 'terminated'):
                console.print("[bold green]✓ Session paused successfully")
                return response
            if waited > max_wait:
                raise TimeoutError(
                    f"Session did not reach terminal state within {max_wait} seconds. "
                    f"Current status: {status_now}"
                )
            time.sleep(poll_interval)
[docs]
def fetch_interactive_session_page(cl, workspace_id, page_num, limit, filter_status, filter_only_mine, archived, verify_ssl):
    """Helper function to fetch a specific page of interactive sessions.

    Parameters
    ----------
    cl : Cloudos
        CloudOS API client instance
    workspace_id : str
        Workspace ID
    page_num : int
        Page number to fetch
    limit : int
        Number of results per page
    filter_status : tuple or None
        Status filters
    filter_only_mine : bool
        Whether to filter only user's sessions
    archived : bool
        Whether to include archived sessions
    verify_ssl : bool or str
        SSL verification setting

    Returns
    -------
    dict
        API response with sessions and pagination metadata
    """
    # The API expects a list of statuses (or None for "no filter").
    status_filter = list(filter_status) if filter_status else None
    return cl.get_interactive_session_list(
        workspace_id,
        page=page_num,
        limit=limit,
        status=status_filter,
        owner_only=filter_only_mine,
        include_archived=archived,
        verify=verify_ssl
    )