Source code for aws_sso.aws_sso

import subprocess
from typing import Optional
from datetime import datetime, timedelta
import sqlite3
from pathlib import Path
import time
import logging
from .exceptions import CredentialError, AuthenticationError

# Module-level logger, named after this module's import path.
# A no-op NullHandler is attached so that this module never decides on its
# own where log output goes; the application that uses it is expected to
# configure real handlers if it wants to see these records.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

class AWSsso:
    r"""
    Handle AWS SSO authentication and credential-refresh bookkeeping.

    The class runs ``aws sso login`` through the AWS CLI and records the time
    of every successful login in a small SQLite database. Credentials are
    considered valid for ``refresh_window_hours`` after the last recorded
    login; once that window has passed they are refreshed again.

    Examples:
        ```python
        # Example 1: Basic usage
        sso = AWSsso(
            aws_exec_file_path=r'C:\Program Files\Amazon\AWSCLIV2\aws.exe',
            db_path=Path('./data/credentials.db'),
            refresh_window_hours=6
        )

        # Get credential expiration time
        expiration = sso.get_expiration_time()
        print(f"Credentials will expire at: {expiration}")

        # Example 2: Custom configuration
        sso = AWSsso(
            aws_exec_file_path=r'C:\Program Files\Amazon\AWSCLIV2\aws.exe',
            db_path=Path('./data/custom_credentials.db'),
            refresh_window_hours=12,  # Refresh every 12 hours
            max_retries=5,            # More retries
            retry_delay=10            # Longer delay between retries
        )

        # Ensure credentials are valid before AWS operations
        try:
            sso.ensure_valid_credentials()
            # Proceed with AWS operations
        except AuthenticationError as e:
            print(f"Failed to authenticate: {e}")

        # Monitor credential status
        last_refresh = sso.get_last_refresh_time()
        if last_refresh:
            time_since_refresh = datetime.now() - last_refresh
            print(f"Time since last refresh: {time_since_refresh}")
        ```
    """

    def __init__(
        self,
        aws_exec_file_path: str = r'C:\Program Files\Amazon\AWSCLIV2\aws.exe',
        db_path: Path = Path("./data/aws_credentials.db"),
        refresh_window_hours: int = 6,
        max_retries: int = 3,
        retry_delay: int = 5
    ):
        """
        Initialize the AWS SSO handler.

        Args:
            aws_exec_file_path (str): Path to the AWS CLI executable
            db_path (Path): Path to the credentials database (a plain string
                is also accepted and converted to a ``Path``)
            refresh_window_hours (int): Hours between credential refreshes
            max_retries (int): Maximum number of authentication attempts
            retry_delay (int): Delay between retries in seconds

        Raises:
            ValueError: If configuration parameters are invalid
            CredentialError: If there's an error initializing the database
        """
        # Validate before touching the filesystem so that a bad configuration
        # never leaves a half-created database behind.
        if max_retries < 1:
            raise ValueError("max_retries must be at least 1")
        if retry_delay < 0:
            raise ValueError("retry_delay must not be negative")
        if refresh_window_hours < 0:
            raise ValueError("refresh_window_hours must not be negative")

        self.aws_exec_file_path = aws_exec_file_path
        # _init_db relies on Path attributes (.parent), so normalise strings.
        self.db_path = Path(db_path)
        self.refresh_window_hours = refresh_window_hours
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self._last_refresh_time: Optional[datetime] = None
        self._expiration_time: Optional[datetime] = None
        self._db_connection: Optional[sqlite3.Connection] = None
        self._init_db()

    def _init_db(self) -> None:
        """
        Initialize the SQLite database for storing credential timestamps.

        Creates the parent directory and the ``credential_timestamps`` table
        when they are missing. If the table already holds rows, the most
        recent one seeds the cached last-refresh and expiration times.

        Raises:
            CredentialError: If there's an error creating the database or table
        """
        try:
            try:
                # exist_ok makes this a no-op when the directory is present.
                self.db_path.parent.mkdir(parents=True, exist_ok=True)
            except OSError as e:
                raise CredentialError(f"Error creating database directory: {e}") from e

            try:
                conn = sqlite3.connect(self.db_path)
            except sqlite3.Error as e:
                raise CredentialError(f"Error connecting to database: {e}") from e

            try:
                cursor = conn.cursor()
                cursor.execute('''
                    CREATE TABLE IF NOT EXISTS credential_timestamps (
                        id INTEGER PRIMARY KEY,
                        last_refresh TIMESTAMP NOT NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                ''')
                conn.commit()

                # The highest id is the most recently inserted refresh.
                cursor.execute('SELECT last_refresh FROM credential_timestamps ORDER BY id DESC LIMIT 1')
                row = cursor.fetchone()
                if row:
                    self._last_refresh_time = datetime.fromisoformat(row[0])
                    self._expiration_time = self._last_refresh_time + timedelta(hours=self.refresh_window_hours)
                    logger.info(f"Initialized timestamps from database. Last refresh: {self._last_refresh_time}")
                else:
                    logger.info("No existing timestamps found in database")

                logger.info(f"Initialized credential database at {self.db_path}")
            except sqlite3.Error as e:
                raise CredentialError(f"Error initializing database tables: {e}") from e
            finally:
                # This connection is only used for setup; it is never cached.
                conn.close()
        except Exception as e:
            # Single boundary: everything that goes wrong during setup is
            # logged once and surfaced as a CredentialError.
            error_msg = f"Error initializing credential database: {e}"
            logger.error(error_msg)
            raise CredentialError(error_msg) from e

    def __enter__(self) -> 'AWSsso':
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context manager exit - ensure database connection is closed."""
        self._close_db_connection()

    def _get_db_connection(self) -> sqlite3.Connection:
        """Get a database connection, creating one if it doesn't exist.

        Returns:
            sqlite3.Connection: A database connection instance

        Raises:
            CredentialError: If there's an error creating the connection
        """
        if not self._db_connection:
            try:
                self._db_connection = sqlite3.connect(self.db_path)
            except sqlite3.Error as e:
                raise CredentialError(f"Error creating database connection: {e}") from e
        return self._db_connection

    def _close_db_connection(self) -> None:
        """Close the database connection if it exists."""
        if self._db_connection:
            try:
                self._db_connection.close()
            except Exception as e:
                # Closing is best-effort; a failure here must not mask the
                # caller's own result or exception.
                logger.error(f"Error closing database connection: {e}")
            finally:
                self._db_connection = None

    @property
    def should_refresh_credentials(self) -> bool:
        """Check if credentials need to be refreshed.

        Returns:
            bool: True if no expiration time is known or it has passed,
                False otherwise
        """
        if self._expiration_time is None:
            return True
        return datetime.now() > self._expiration_time

    def get_expiration_time(self) -> Optional[datetime]:
        """Get the expiration time of the current credentials.

        Returns:
            Optional[datetime]: The expiration time of the credentials, or None if not set
        """
        return self._expiration_time

    def get_last_refresh_time(self) -> Optional[datetime]:
        """Get the last time the credentials were refreshed.

        Returns:
            Optional[datetime]: The last refresh time, or None if not set
        """
        return self._last_refresh_time

    def ensure_valid_credentials(self) -> bool:
        """Ensure that the AWS SSO credentials are valid.

        If the refresh window has passed (or no refresh was ever recorded),
        the credentials are refreshed.

        Returns:
            bool: True if credentials are valid or were successfully refreshed

        Raises:
            AuthenticationError: If there's an error during SSO authentication
            CredentialError: If there's an error updating the timestamp
        """
        if self.should_refresh_credentials:
            return self.refresh_credentials()
        return True

    def _store_refresh_timestamp(self) -> None:
        """Persist the cached last-refresh time to the database.

        Raises:
            CredentialError: If the timestamp cannot be written
        """
        try:
            conn = self._get_db_connection()
            try:
                cursor = conn.cursor()
                cursor.execute(
                    'INSERT INTO credential_timestamps (last_refresh) VALUES (?)',
                    (self._last_refresh_time.isoformat(),)
                )
                conn.commit()
            finally:
                # The connection is not kept open between refreshes.
                self._close_db_connection()
        except Exception as e:
            # Wrap and log exactly once, whatever the underlying failure was.
            error_msg = f"Error updating SSO credential timestamp: {e}"
            logger.error(error_msg)
            raise CredentialError(error_msg) from e
        logger.info("Successfully updated SSO credential refresh timestamp")

    def refresh_credentials(self) -> bool:
        """
        Refresh AWS SSO credentials by running ``aws sso login``.

        The login is attempted up to ``max_retries`` times, sleeping
        ``retry_delay`` seconds between failed attempts. On success the cached
        timestamps are updated and the refresh time is written to the database.

        Returns:
            bool: True if credentials were successfully refreshed

        Raises:
            AuthenticationError: If there's an error during SSO authentication
            CredentialError: If there's an error updating the timestamp
        """
        for attempt in range(self.max_retries):
            try:
                # No shell: with shell=True and an argument list, POSIX systems
                # run only the first element and silently drop 'sso login'.
                # Running the executable directly also lets a missing binary
                # surface as FileNotFoundError.
                subprocess.run(
                    [self.aws_exec_file_path, 'sso', 'login'],
                    check=True,
                    capture_output=True,
                    text=True
                )
            except subprocess.CalledProcessError as e:
                if attempt == self.max_retries - 1:
                    raise AuthenticationError(
                        f"Failed to refresh SSO credentials after {self.max_retries} attempts: {e}"
                    ) from e
                logger.warning(f"SSO authentication attempt {attempt + 1} failed: {e}")
                time.sleep(self.retry_delay)
                continue
            except FileNotFoundError as e:
                # Retrying cannot help when the executable does not exist.
                error_msg = f"AWS CLI executable not found at: {self.aws_exec_file_path}"
                logger.error(error_msg)
                raise AuthenticationError(error_msg) from e
            except OSError as e:
                # e.g. the path exists but is not executable.
                error_msg = f"Unable to execute AWS CLI at: {self.aws_exec_file_path}: {e}"
                logger.error(error_msg)
                raise AuthenticationError(error_msg) from e

            # check=True guarantees a zero exit status at this point.
            logger.info("Successfully authenticated with AWS SSO")

            # Update the in-memory view first; it reflects the successful
            # login even if persisting the timestamp fails below.
            self._last_refresh_time = datetime.now()
            self._expiration_time = self._last_refresh_time + timedelta(hours=self.refresh_window_hours)

            self._store_refresh_timestamp()
            return True

        # Only reachable if max_retries was lowered below 1 after construction.
        raise AuthenticationError(f"Failed to refresh SSO credentials after {self.max_retries} attempts")