import subprocess
from typing import Optional
from datetime import datetime, timedelta
import sqlite3
from pathlib import Path
import time
import logging
from .exceptions import CredentialError, AuthenticationError
# Library-style module logger: attaching a NullHandler suppresses Python's
# "last resort" stderr output when the host application has not configured logging.
logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())
class AWSsso:
    r"""
    A class to handle AWS SSO authentication and credential management.

    This class provides functionality to refresh AWS SSO credentials and track their validity.
    A refresh runs ``aws sso login`` through the configured AWS CLI executable, and each
    successful refresh is recorded in a SQLite database so the last refresh time is
    restored by later instances. Validity is tracked locally: credentials are considered
    expired once ``refresh_window_hours`` have elapsed since the last successful refresh
    (the SSO token's real expiry is not queried).

    Examples:
        ```python
        # Example 1: Basic usage
        sso = AWSsso(
            aws_exec_file_path=r'C:\Program Files\Amazon\AWSCLIV2\aws.exe',
            db_path=Path('./data/credentials.db'),
            refresh_window_hours=6
        )
        # Get credential expiration time
        expiration = sso.get_expiration_time()
        print(f"Credentials will expire at: {expiration}")

        # Example 2: Custom configuration
        sso = AWSsso(
            aws_exec_file_path=r'C:\Program Files\Amazon\AWSCLIV2\aws.exe',
            db_path=Path('./data/custom_credentials.db'),
            refresh_window_hours=12,  # Refresh every 12 hours
            max_retries=5,            # More retries
            retry_delay=10            # Longer delay between retries
        )
        # Ensure credentials are valid before AWS operations
        try:
            sso.ensure_valid_credentials()
            # Proceed with AWS operations
        except AuthenticationError as e:
            print(f"Failed to authenticate: {e}")

        # Monitor credential status
        last_refresh = sso.get_last_refresh_time()
        if last_refresh:
            time_since_refresh = datetime.now() - last_refresh
            print(f"Time since last refresh: {time_since_refresh}")
        ```
    """

    def __init__(
        self,
        aws_exec_file_path: str = r'C:\Program Files\Amazon\AWSCLIV2\aws.exe',
        db_path: Path = Path("./data/aws_credentials.db"),
        refresh_window_hours: int = 6,
        max_retries: int = 3,
        retry_delay: int = 5
    ):
        """
        Initialize the AWS SSO handler.

        Args:
            aws_exec_file_path (str): Path to the AWS CLI executable
            db_path (Path): Path to the credentials database (a str is also accepted)
            refresh_window_hours (int): Hours between credential refreshes
            max_retries (int): Maximum number of authentication attempts (at least 1)
            retry_delay (int): Delay between retries in seconds

        Raises:
            ValueError: If refresh_window_hours or retry_delay is negative,
                or max_retries is less than 1
            CredentialError: If there's an error initializing the database
        """
        # Validate first so that a bad configuration has no filesystem side effects.
        if refresh_window_hours < 0:
            raise ValueError(f"refresh_window_hours must be non-negative, got {refresh_window_hours}")
        if max_retries < 1:
            raise ValueError(f"max_retries must be at least 1, got {max_retries}")
        if retry_delay < 0:
            raise ValueError(f"retry_delay must be non-negative, got {retry_delay}")

        self.aws_exec_file_path = aws_exec_file_path
        # Coerce to Path: _init_db relies on Path.parent / Path.mkdir.
        self.db_path = Path(db_path)
        self.refresh_window_hours = refresh_window_hours
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self._last_refresh_time: Optional[datetime] = None
        self._expiration_time: Optional[datetime] = None
        self._db_connection: Optional[sqlite3.Connection] = None
        self._init_db()

    def _init_db(self) -> None:
        """
        Initialize the SQLite database for storing credential timestamps.

        Creates the parent directory and the ``credential_timestamps`` table if needed.
        If the table already holds timestamps, the most recent one seeds the cached
        last-refresh and expiration times.

        Raises:
            CredentialError: If there's an error creating the directory, connecting to
                the database, creating the table, or parsing a stored timestamp
        """
        try:
            try:
                # exist_ok=True makes this a no-op when the directory already exists.
                self.db_path.parent.mkdir(parents=True, exist_ok=True)
            except OSError as e:
                raise CredentialError(f"Error creating database directory: {str(e)}") from e

            try:
                conn = sqlite3.connect(self.db_path)
            except sqlite3.Error as e:
                raise CredentialError(f"Error connecting to database: {str(e)}") from e

            try:
                cursor = conn.cursor()
                cursor.execute('''
                    CREATE TABLE IF NOT EXISTS credential_timestamps (
                        id INTEGER PRIMARY KEY,
                        last_refresh TIMESTAMP NOT NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                ''')
                conn.commit()

                # The highest id is the most recently inserted refresh.
                cursor.execute('SELECT last_refresh FROM credential_timestamps ORDER BY id DESC LIMIT 1')
                result = cursor.fetchone()
                if result:
                    self._last_refresh_time = datetime.fromisoformat(result[0])
                    self._expiration_time = self._last_refresh_time + timedelta(hours=self.refresh_window_hours)
                    logger.info(f"Initialized timestamps from database. Last refresh: {self._last_refresh_time}")
                else:
                    logger.info("No existing timestamps found in database")
                logger.info(f"Initialized credential database at {self.db_path}")
            except sqlite3.Error as e:
                raise CredentialError(f"Error initializing database tables: {str(e)}") from e
            finally:
                conn.close()
        except Exception as e:
            # Surface every initialization failure as one CredentialError with context;
            # this also covers a malformed stored timestamp (ValueError from fromisoformat).
            error_msg = f"Error initializing credential database: {str(e)}"
            logger.error(error_msg)
            raise CredentialError(error_msg) from e

    def __enter__(self) -> 'AWSsso':
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context manager exit - ensure database connection is closed."""
        self._close_db_connection()

    def _get_db_connection(self) -> sqlite3.Connection:
        """Get a database connection, creating and caching one if it doesn't exist.

        Returns:
            sqlite3.Connection: A database connection instance

        Raises:
            CredentialError: If there's an error creating the connection
        """
        if self._db_connection is None:
            try:
                self._db_connection = sqlite3.connect(self.db_path)
            except sqlite3.Error as e:
                raise CredentialError(f"Error creating database connection: {str(e)}") from e
        return self._db_connection

    def _close_db_connection(self) -> None:
        """Close the cached database connection if it exists.

        Errors while closing are logged, not raised; the cache is cleared either way.
        """
        if self._db_connection is not None:
            try:
                self._db_connection.close()
            except Exception as e:
                logger.error(f"Error closing database connection: {e}")
            finally:
                self._db_connection = None

    @property
    def should_refresh_credentials(self) -> bool:
        """Check if credentials need to be refreshed.

        Returns:
            bool: True if no refresh has been recorded or the refresh window has
                elapsed, False otherwise
        """
        if self._expiration_time is None:
            return True
        return datetime.now() > self._expiration_time

    def get_expiration_time(self) -> Optional[datetime]:
        """Get the expiration time of the current credentials.

        Returns:
            Optional[datetime]: The expiration time of the credentials, or None if not set
        """
        return self._expiration_time

    def get_last_refresh_time(self) -> Optional[datetime]:
        """Get the last time the credentials were refreshed.

        Returns:
            Optional[datetime]: The last refresh time, or None if not set
        """
        return self._last_refresh_time

    def ensure_valid_credentials(self) -> bool:
        """Ensure that the AWS SSO credentials are valid.

        If the refresh window has elapsed (or no refresh is recorded), the
        credentials are refreshed; otherwise nothing is run.

        Returns:
            bool: True if credentials are valid or were successfully refreshed

        Raises:
            AuthenticationError: If there's an error during SSO authentication
            CredentialError: If there's an error updating the timestamp
        """
        if self.should_refresh_credentials:
            return self.refresh_credentials()
        return True

    def refresh_credentials(self) -> bool:
        """
        Refresh AWS SSO credentials by running ``aws sso login``.

        The command is attempted up to ``max_retries`` times, sleeping ``retry_delay``
        seconds after each failed attempt except the last. This method always runs the
        login command; use ensure_valid_credentials() to refresh only when needed.

        Returns:
            bool: True once the credentials were successfully refreshed

        Raises:
            AuthenticationError: If the AWS CLI executable is not found, or every
                attempt fails
            CredentialError: If there's an error updating the timestamp
        """
        command = [self.aws_exec_file_path, 'sso', 'login']
        for attempt in range(1, self.max_retries + 1):
            try:
                # No shell=True: with a list argument it is not portable (on POSIX only
                # the first item reached the shell, dropping 'sso login'), and a missing
                # executable surfaced as a shell exit code instead of FileNotFoundError.
                # NOTE(review): a .cmd/.bat shim (e.g. AWS CLI v1 on Windows) may need
                # its full path now that cmd.exe no longer resolves it.
                # NOTE(review): capture_output hides the CLI's URL/device-code prompt
                # from the user; kept for backward compatibility.
                subprocess.run(command, check=True, capture_output=True, text=True)
            except FileNotFoundError as e:
                error_msg = f"AWS CLI executable not found at: {self.aws_exec_file_path}"
                logger.error(error_msg)
                raise AuthenticationError(error_msg) from e
            except subprocess.CalledProcessError as e:
                # stderr is captured, so log it; otherwise the CLI's reason is lost.
                reason = str(e)
                stderr = (e.stderr or '').strip()
                if stderr:
                    reason = f"{reason} stderr: {stderr}"
                if attempt == self.max_retries:
                    logger.error(f"SSO authentication attempt {attempt} failed: {reason}")
                    raise AuthenticationError(
                        f"Failed to refresh SSO credentials after {self.max_retries} attempts: {str(e)}"
                    ) from e
                logger.warning(f"SSO authentication attempt {attempt} failed: {reason}")
                time.sleep(self.retry_delay)
                continue

            # check=True guarantees a zero exit status when we get here.
            logger.info("Successfully authenticated with AWS SSO")
            # Update cached timestamps before persisting, so they reflect the
            # successful login even if the database write fails.
            self._last_refresh_time = datetime.now()
            self._expiration_time = self._last_refresh_time + timedelta(hours=self.refresh_window_hours)
            self._store_refresh_time(self._last_refresh_time)
            return True

        # Only reachable if max_retries was changed to below 1 after construction.
        raise AuthenticationError(f"Failed to refresh SSO credentials after {self.max_retries} attempts")

    def _store_refresh_time(self, refresh_time: datetime) -> None:
        """Persist *refresh_time* as a new row in ``credential_timestamps``.

        The database connection is closed afterwards, whether or not the write succeeded.

        Raises:
            CredentialError: If the connection cannot be opened or the insert fails
        """
        try:
            conn = self._get_db_connection()
            conn.execute(
                'INSERT INTO credential_timestamps (last_refresh) VALUES (?)',
                (refresh_time.isoformat(),)
            )
            conn.commit()
        except (sqlite3.Error, CredentialError) as e:
            # Wrap exactly once (the old nested handlers wrapped and logged twice).
            error_msg = f"Error updating SSO credential timestamp: {str(e)}"
            logger.error(error_msg)
            raise CredentialError(error_msg) from e
        finally:
            self._close_db_connection()
        logger.info("Successfully updated SSO credential refresh timestamp")