# ML Implementation Strategy

## Overview

This document outlines the machine learning components for the energy trading system: gradient boosting (XGBoost) price prediction models and a Q-Learning reinforcement learning agent for battery optimization.

**Package Location**: `backend/app/ml/` (integrated within the FastAPI backend)

**Related Documents**:
- `BACKEND_IMPLEMENTATION.md` - API endpoints, services, and ML integration layer
- `FRONTEND_IMPLEMENTATION.md` - React frontend that consumes backend API

**Training Split Strategy**: Time-based split (first 7 days train, next 1.5 days validation, last 1.5 days test) to prevent look-ahead bias and ensure realistic evaluation. On the 10-day dataset this corresponds to the 70/15/15 split configured in `PricePredictionConfig`.

**Data Source**: `~/energy-test-data/data/processed/*.parquet`

---

## Architecture

```
┌──────────────────────────────────────────────────────────────┐
│                         ML Pipeline                          │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  ┌──────────────┬──────────────────────┬──────────────────┐  │
│  │ Feature      │                      │                  │  │
│  │ Engineering  │ Model Training       │ Model Serving    │  │
│  │              │                      │                  │  │
│  │ - Lags       │ - XGBoost            │ - Load Model     │  │
│  │ - Rolling    │ - Q-Learning         │ - Predict        │  │
│  │ - Time       │ - Validation         │ - Return Action  │  │
│  │ - Regions    │ - Evaluation         │                  │  │
│  └──────────────┴──────────────────────┴──────────────────┘  │
│                           │                                  │
│    ┌──────────────────────┴──────────────────────────────┐   │
│    │                  Model Management                   │   │
│    │  - Versioning  - Persistence  - Registry  - Backup  │   │
│    └─────────────────────────────────────────────────────┘   │
└──────────────────────────────────────────────────────────────┘
                                │
                                ▼
┌──────────────────────────────────────────────────────────────┐
│          FastAPI Backend (app.services.ml_service)           │
│         /api/v1/models/train, /api/v1/models/predict         │
└──────────────────────────────────────────────────────────────┘
```

---

## Project Structure

```
backend/app/ml/
├── __init__.py
│
├── features/
│   ├── __init__.py
│   ├── lag_features.py          # Price lag feature extraction
│   ├── rolling_features.py      # Rolling statistics
│   ├── time_features.py         # Time-of-day encoding
│   ├── regional_features.py     # Cross-region differentials
│   └── battery_features.py      # Battery state features
│
├── price_prediction/
│   ├── __init__.py
│   ├── model.py                 # XGBoost model wrapper
│   ├── trainer.py               # Training pipeline
│   ├── evaluator.py             # Evaluation metrics
│   └── predictor.py             # Prediction interface
│
├── rl_battery/
│   ├── __init__.py
│   ├── environment.py           # Battery MDP environment
│   ├── agent.py                 # Q-Learning agent
│   ├── trainer.py               # RL training loop
│   └── policy.py                # Policy inference
│
├── model_management/
│   ├── __init__.py
│   ├── registry.py              # Model registry
│   ├── persistence.py           # Save/load models
│   ├── versioning.py            # Version handling
│   └── comparison.py            # Model comparison
│
├── evaluation/
│   ├── __init__.py
│   ├── metrics.py               # Common evaluation metrics
│   ├── backtest_evaluator.py    # Backtest performance evaluation
│   └── reports.py               # Generate evaluation reports
│
├── training/
│   ├── __init__.py
│   ├── cli.py                   # CLI commands for retraining
│   ├── pipeline.py              # End-to-end training pipeline
│   └── scheduler.py             # Training job scheduler
│
└── utils/
    ├── __init__.py
    ├── data_split.py            # Time-based data splitting
    ├── preprocessing.py         # Data preprocessing
    ├── config.py                # ML configuration
    └── evaluation.py            # Evaluation metrics
```

---

## Configuration

### backend/app/ml/utils/config.py

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class PricePredictionConfig:
    """Configuration for price prediction models."""

    # Data
    data_path: str = "~/energy-test-data/data/processed"
    target_column: str = "real_time_price"

    # Training split (time-based)
    train_end_pct: float = 0.70  # First 70% for training
    val_end_pct: float = 0.85    # Next 15% for validation
    # Last 15% for testing

    # Features (None is replaced with the defaults in __post_init__,
    # because dataclass fields cannot take mutable defaults directly)
    price_lags: Optional[List[int]] = None
    rolling_windows: Optional[List[int]] = None
    include_time_features: bool = True
    include_regional_features: bool = True

    # Model
    n_estimators: int = 200
    max_depth: int = 6
    learning_rate: float = 0.1
    subsample: float = 0.8
    colsample_bytree: float = 0.8
    random_state: int = 42

    # Early stopping
    early_stopping_rounds: int = 20
    early_stopping_threshold: float = 0.001

    def __post_init__(self):
        if self.price_lags is None:
            self.price_lags = [1, 5, 10, 15, 30, 60]
        if self.rolling_windows is None:
            self.rolling_windows = [5, 10, 15, 30, 60]


@dataclass
class RLBatteryConfig:
    """Configuration for RL battery optimization."""

    # State space
    charge_level_bins: int = 10
    price_bins: int = 10
    time_bins: int = 24  # Hours

    # Action space (defaults filled in __post_init__)
    actions: Optional[List[str]] = None

    # Q-Learning
    learning_rate: float = 0.1
    discount_factor: float = 0.95
    epsilon: float = 1.0
    epsilon_decay: float = 0.995
    epsilon_min: float = 0.05

    # Training
    episodes: int = 1000
    max_steps: int = 14400  # 10 days * 1440 minutes

    # Battery constraints
    min_reserve: float = 0.10  # 10%
    max_charge: float = 0.90   # 90%
    efficiency: float = 0.90

    # Reward scaling
    reward_scale: float = 1.0

    def __post_init__(self):
        if self.actions is None:
            self.actions = ["charge", "hold", "discharge"]


@dataclass
class MLConfig:
    """Overall ML configuration."""

    # Paths
    models_path: str = "models"
    results_path: str = "results"

    # Sub-configurations (defaults filled in __post_init__)
    price_prediction: Optional[PricePredictionConfig] = None
    rl_battery: Optional[RLBatteryConfig] = None

    # Training
    enable_gpu: bool = False
    n_jobs: int = 4
    verbose: bool = True

    # Retraining
    keep_backup: bool = True
    max_backups: int = 5

    def __post_init__(self):
        if self.price_prediction is None:
            self.price_prediction = PricePredictionConfig()
        if self.rl_battery is None:
            self.rl_battery = RLBatteryConfig()


# Default configuration
default_config = MLConfig()
```

---

## Feature Engineering Interface

### Key Functions (backend/app/ml/features/__init__.py)

```python
from typing import List, Optional

import pandas as pd


def build_price_features(
    df: pd.DataFrame,
    price_col: str = "real_time_price",
    lags: Optional[List[int]] = None,
    windows: Optional[List[int]] = None,
    regions: Optional[List[str]] = None,
    include_time: bool = True,
    include_regional: bool = True,
) -> pd.DataFrame:
    """
    Build complete feature set for price prediction.

    Args:
        df: Input DataFrame
        price_col: Name of price column
        lags: List of lag periods
        windows: List of rolling window sizes
        regions: List of regions for differential features
        include_time: Whether to include time features
        include_regional: Whether to include regional features

    Returns:
        DataFrame with all features
    """


def build_battery_features(
    df: pd.DataFrame,
    price_df: pd.DataFrame,
    battery_col: str = "charge_level_mwh",
    capacity_col: str = "capacity_mwh",
    timestamp_col: str = "timestamp",
    battery_id_col: str = "battery_id",
) -> pd.DataFrame:
    """
    Build features for battery RL model.

    Args:
        df: Battery DataFrame
        price_df: Price DataFrame
        battery_col: Name of battery charge level column
        capacity_col: Name of battery capacity column
        timestamp_col: Name of timestamp column
        battery_id_col: Name of battery ID column

    Returns:
        DataFrame with battery features
    """
```

---

## Price Prediction Interface

### PricePredictor (backend/app/ml/price_prediction/predictor.py)

```python
from typing import Dict, Optional

import pandas as pd


class PricePredictor:
    """Interface for making price predictions."""

    def __init__(self, models_dir: str = "models/price_prediction"):
        """
        Initialize predictor.

        Args:
            models_dir: Directory containing trained models
        """

    def predict(
        self,
        current_data: pd.DataFrame,
        horizon: int = 15,
        region: Optional[str] = None,
    ) -> float:
        """
        Predict price for a specific horizon.

        Args:
            current_data: Current/historical price data
            horizon: Prediction horizon in minutes
            region: Specific region to predict for (optional)

        Returns:
            Predicted price
        """

    def predict_all_horizons(
        self,
        current_data: pd.DataFrame,
        region: Optional[str] = None,
    ) -> Dict[int, float]:
        """
        Predict prices for all available horizons.

        Returns:
            Dictionary mapping horizons to predictions
        """

    def predict_with_confidence(
        self,
        current_data: pd.DataFrame,
        horizon: int = 15,
        region: Optional[str] = None,
    ) -> Dict:
        """
        Predict price with confidence interval.

        Returns:
            Dictionary with prediction and confidence interval
        """

    def get_feature_importance(self, horizon: int) -> pd.DataFrame:
        """
        Get feature importance for a specific horizon.

        Returns:
            DataFrame with feature importance
        """
```

### PricePredictionTrainer (backend/app/ml/price_prediction/trainer.py)

```python
from typing import Dict, List, Optional, Tuple

import pandas as pd


class PricePredictionTrainer:
    """Training pipeline for price prediction models."""

    def __init__(self, config: Optional[PricePredictionConfig] = None):
        """Initialize trainer."""

    def load_data(self) -> pd.DataFrame:
        """Load price data."""

    def prepare_data(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
        """
        Prepare data with features.

        Returns:
            Tuple of (features DataFrame, feature names)
        """

    def train_for_horizon(
        self,
        df_features: pd.DataFrame,
        feature_cols: List[str],
        horizon: int,
    ) -> Dict:
        """
        Train model for a specific horizon.

        Returns:
            Training results dictionary with metrics
        """

    def train_all(self, horizons: Optional[List[int]] = None) -> Dict:
        """
        Train models for all horizons.

        Returns:
            Dictionary with all training results
        """

    def save_models(self, output_dir: str = "models/price_prediction") -> None:
        """Save all trained models."""

    @classmethod
    def load_models(
        cls,
        models_dir: str = "models/price_prediction",
        horizons: Optional[List[int]] = None,
    ) -> Dict[int, PricePredictionModel]:
        """
        Load trained models.

        Returns:
            Dictionary mapping horizons to models
        """
```

---

## RL Battery Optimization Interface

### BatteryPolicy (backend/app/ml/rl_battery/policy.py)

```python
from typing import Dict


class BatteryPolicy:
    """Interface for RL battery policy inference."""

    def __init__(self, policy_path: str = "models/rl_battery"):
        """
        Initialize policy.

        Args:
            policy_path: Path to trained policy
        """

    def get_action(
        self,
        charge_level: float,
        current_price: float,
        price_forecast_1m: float = 0,
        price_forecast_5m: float = 0,
        price_forecast_15m: float = 0,
        hour: int = 0,
    ) -> Dict:
        """
        Get action for current state.

        Returns:
            Dictionary with action, q_values
        """
```

### BatteryRLTrainer (backend/app/ml/rl_battery/trainer.py)

```python
from typing import Dict, Optional


class BatteryRLTrainer:
    """Training pipeline for RL battery policy."""

    def __init__(self, config: Optional[RLBatteryConfig] = None):
        """Initialize trainer."""

    def load_data(self) -> None:
        """Load price data for environment."""

    def train(self, n_episodes: int = 1000, region: str = "FR") -> Dict:
        """
        Train RL agent.

        Returns:
            Training results with metrics
        """

    def save(self, output_dir: str = "models/rl_battery") -> None:
        """Save trained policy."""
```

---

## Model Management Interface

### ModelRegistry (backend/app/ml/model_management/registry.py)

```python
from typing import Dict, List, Optional


class ModelRegistry:
    """Registry for tracking model versions."""

    def __init__(self, registry_path: str = "models/registry.json"):
        """Initialize registry."""

    def register_model(
        self,
        model_type: str,
        model_id: str,
        version: str,
        filepath: str,
        metadata: Optional[Dict] = None,
    ) -> None:
        """Register a model version."""

    def get_latest_version(self, model_id: str) -> Optional[Dict]:
        """Get latest version of a model."""

    def list_models(self) -> List[Dict]:
        """List all registered models."""
```

---

## Training Task Interface

### Training Jobs (app/tasks/training_tasks.py)

```python
from celery import shared_task


async def train_model_task(training_id: str, request: TrainingRequest):
    """
    Execute ML model training via Celery task.

    Dispatches to Celery for async processing of:
    - Price prediction training
    - RL battery policy training

    Emits WebSocket events for progress updates.
    """


@shared_task(name="tasks.train_price_prediction")
def train_price_prediction(training_id: str, request_dict: dict):
    """
    Celery task for price prediction model training.

    Process:
    1. Load and prepare data
    2. Train XGBoost models for the specified horizons
    3. Save models
    4. Register in model registry
    5. Update training job status
    """


@shared_task(name="tasks.train_rl_battery")
def train_rl_battery(training_id: str, request_dict: dict):
    """
    Celery task for RL battery policy training.

    Process:
    1. Load environment and data
    2. Train Q-Learning agent
    3. Save policy
    4. Register in model registry
    5. Update training job status
    """
```

---

## Key Integration Points

### ML Service Integration (app/services/ml_service.py)

The ML service provides the bridge between API routes and ML models:

```python
from datetime import datetime
from typing import Any, Dict, List, Optional


class MLService:
    """Service for ML model management and inference."""

    def list_models(self) -> List[ModelInfo]:
        """List all available trained models."""

    def get_model_metrics(self, model_id: str) -> Dict[str, float]:
        """Get performance metrics for a model."""

    def predict(
        self,
        model_id: str,
        timestamp: datetime,
        features: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Run prediction with on-demand model loading.

        Supports:
        - Price prediction models
        - RL battery policy
        """

    def get_feature_importance(self, model_id: str) -> Dict[str, float]:
        """Get feature importance for a model."""
```

### WebSocket Events

Real-time training progress updates:

```python
# Event types
"model_training_progress"  # Training progress updates

# Payload (field name -> type)
{
    "model_id": str,
    "progress": float,  # 0.0 to 1.0
    "epoch": int,       # Current epoch (optional)
    "metrics": dict,    # Current metrics
}
```

---

## Data Flow

### Training Pipeline

```
API Request (POST /api/v1/models/train)
    ↓
training_tasks.train_model_task()
    ↓
Celery Task (train_price_prediction / train_rl_battery)
    ↓
PricePredictionTrainer / BatteryRLTrainer
    ↓
Feature Engineering → Model Training → Evaluation
    ↓
Save Models → Register in Registry
    ↓
WebSocket Events (progress updates)
    ↓
Update Training Status
```

### Prediction Pipeline

```
API Request (POST /api/v1/models/predict)
    ↓
ml_service.predict()
    ↓
Load Model (on-demand)
    ↓
Feature Engineering
    ↓
Model.predict() / Policy.get_action()
    ↓
Return Prediction with Confidence
```

---

## Model Artifacts

### Price Prediction Models

Location: `models/price_prediction/`

```
model_1min.pkl          # 1-minute horizon model
model_5min.pkl          # 5-minute horizon model
model_15min.pkl         # 15-minute horizon model
model_60min.pkl         # 60-minute horizon model
training_results.json   # Training metrics and metadata
```

### RL Battery Policy

Location: `models/rl_battery/`

```
battery_policy.pkl      # Trained Q-Learning policy
training_results.json   # Training metrics
```

### Model Registry

Location: `models/registry.json`

```json
{
  "models": {
    "price_prediction_15m": {
      "type": "price_prediction",
      "versions": ["v20260211_134500", "v20260210_100000"],
      "latest": "v20260211_134500"
    },
    "battery_policy": {
      "type": "rl_battery",
      "versions": ["v20260211_140000"],
      "latest": "v20260211_140000"
    }
  }
}
```
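
To illustrate how a consumer might read the registry layout shown above, here is a minimal sketch that resolves the latest version tag for a model ID. The helper name `resolve_latest` is hypothetical and is not part of `ModelRegistry`. The fallback to the maximum version string is an assumption: it only works if version tags follow the zero-padded `vYYYYMMDD_HHMMSS` pattern seen in the sample, so that they sort chronologically as plain strings.

```python
import json
from typing import Optional

# Sample registry mirroring the documented layout, inlined so the sketch
# runs without touching models/registry.json.
REGISTRY_JSON = """
{
  "models": {
    "price_prediction_15m": {
      "type": "price_prediction",
      "versions": ["v20260211_134500", "v20260210_100000"],
      "latest": "v20260211_134500"
    },
    "battery_policy": {
      "type": "rl_battery",
      "versions": ["v20260211_140000"],
      "latest": "v20260211_140000"
    }
  }
}
"""


def resolve_latest(registry: dict, model_id: str) -> Optional[str]:
    """Hypothetical helper: return the latest version tag, or None if unknown."""
    entry = registry.get("models", {}).get(model_id)
    if entry is None:
        return None
    # Prefer the explicit "latest" field. Assumption: if it is missing, the
    # lexicographically largest tag is the newest (vYYYYMMDD_HHMMSS format).
    return entry.get("latest") or max(entry.get("versions", []), default=None)


registry = json.loads(REGISTRY_JSON)
latest_price_model = resolve_latest(registry, "price_prediction_15m")
latest_policy = resolve_latest(registry, "battery_policy")
```

Returning `None` for an unknown model ID matches the `Optional` return type that `get_latest_version` declares, which lets callers distinguish "never trained" from a lookup error.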