Add initial implementation strategy documentation
Add comprehensive documentation for energy trading system: - Backend: FastAPI architecture, API routes, services, WebSocket - Frontend: React structure, components, state management - ML: Feature engineering, XGBoost price prediction, RL battery optimization
This commit is contained in:
679
ML_IMPLEMENTATION.md
Normal file
679
ML_IMPLEMENTATION.md
Normal file
@@ -0,0 +1,679 @@
|
||||
# ML Implementation Strategy
|
||||
|
||||
## Overview
|
||||
|
||||
This document outlines the machine learning components for the energy trading system, including gradient boosting price prediction models and Q-Learning reinforcement learning for battery optimization.
|
||||
|
||||
**Package Location**: `backend/app/ml/` (integrated within the FastAPI backend)
|
||||
|
||||
**Related Documents**:
|
||||
- `BACKEND_IMPLEMENTATION.md` - API endpoints, services, and ML integration layer
|
||||
- `FRONTEND_IMPLEMENTATION.md` - React frontend that consumes backend API
|
||||
|
||||
**Training Split Strategy**: Time-based split (first 7 days train, next 1.5 days validation, last 1.5 days test) to prevent look-ahead bias and ensure realistic evaluation.
|
||||
|
||||
**Data Source**: `~/energy-test-data/data/processed/*.parquet`
|
||||
|
||||
---
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌──────────────────────────────────────────────────────────────┐
|
||||
│ ML Pipeline │
|
||||
├──────────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ ┌──────────────┬──────────────────────┬──────────────────┐ │
|
||||
│ │ Feature │ │ │ │
|
||||
│ │ Engineering │ Model Training │ Model Serving │ │
|
||||
│ │ │ │ │ │
|
||||
│ │ - Lags │ - XGBoost │ - Load Model │ │
|
||||
│ │ - Rolling │ - Q-Learning │ - Predict │ │
|
||||
│ │ - Time │ - Validation │ - Return Action │ │
|
||||
│ │ - Regions │ - Evaluation │ │ │
|
||||
│ └──────────────┴──────────────────────┴──────────────────┘ │
|
||||
│ │ │
|
||||
│ ┌──────────────────────┴──────────────────────────────┐ │
|
||||
│ │ Model Management │ │
|
||||
│ │ - Versioning - Persistence - Registry - Backup │ │
|
||||
│ └──────────────────────────────────────────────────────┘ │
|
||||
└──────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────────────────────────────────┐
|
||||
│ FastAPI Backend (app.services.ml_service) │
|
||||
│ /api/v1/models/train, /api/v1/models/predict │
|
||||
└──────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
backend/app/ml/
|
||||
├── __init__.py
|
||||
│
|
||||
├── features/
|
||||
│ ├── __init__.py
|
||||
│ ├── lag_features.py # Price lag feature extraction
|
||||
│ ├── rolling_features.py # Rolling statistics
|
||||
│ ├── time_features.py # Time-of-day encoding
|
||||
│ ├── regional_features.py # Cross-region differentials
|
||||
│ └── battery_features.py # Battery state features
|
||||
│
|
||||
├── price_prediction/
|
||||
│ ├── __init__.py
|
||||
│ ├── model.py # XGBoost model wrapper
|
||||
│ ├── trainer.py # Training pipeline
|
||||
│ ├── evaluator.py # Evaluation metrics
|
||||
│ └── predictor.py # Prediction interface
|
||||
│
|
||||
├── rl_battery/
|
||||
│ ├── __init__.py
|
||||
│ ├── environment.py # Battery MDP environment
|
||||
│ ├── agent.py # Q-Learning agent
|
||||
│ ├── trainer.py # RL training loop
|
||||
│ └── policy.py # Policy inference
|
||||
│
|
||||
├── model_management/
|
||||
│ ├── __init__.py
|
||||
│ ├── registry.py # Model registry
|
||||
│ ├── persistence.py # Save/load models
|
||||
│ ├── versioning.py # Version handling
|
||||
│ └── comparison.py # Model comparison
|
||||
│
|
||||
├── evaluation/
|
||||
│ ├── __init__.py
|
||||
│ ├── metrics.py # Common evaluation metrics
|
||||
│ ├── backtest_evaluator.py # Backtest performance evaluation
|
||||
│ └── reports.py # Generate evaluation reports
|
||||
│
|
||||
├── training/
|
||||
│ ├── __init__.py
|
||||
│ ├── cli.py # CLI commands for retraining
|
||||
│ ├── pipeline.py # End-to-end training pipeline
|
||||
│ └── scheduler.py # Training job scheduler
|
||||
│
|
||||
└── utils/
|
||||
├── __init__.py
|
||||
├── data_split.py # Time-based data splitting
|
||||
├── preprocessing.py # Data preprocessing
|
||||
├── config.py # ML configuration
|
||||
└── evaluation.py # Evaluation metrics
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
### backend/app/ml/utils/config.py
|
||||
|
||||
```python
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
@dataclass
|
||||
class PricePredictionConfig:
|
||||
"""Configuration for price prediction models."""
|
||||
|
||||
# Data
|
||||
data_path: str = "~/energy-test-data/data/processed"
|
||||
target_column: str = "real_time_price"
|
||||
|
||||
# Training split (time-based)
|
||||
train_end_pct: float = 0.70 # First 70% for training
|
||||
val_end_pct: float = 0.85 # Next 15% for validation
|
||||
# Last 15% for testing
|
||||
|
||||
# Features
|
||||
price_lags: List[int] = None
|
||||
rolling_windows: List[int] = None
|
||||
include_time_features: bool = True
|
||||
include_regional_features: bool = True
|
||||
|
||||
# Model
|
||||
n_estimators: int = 200
|
||||
max_depth: int = 6
|
||||
learning_rate: float = 0.1
|
||||
subsample: float = 0.8
|
||||
colsample_bytree: float = 0.8
|
||||
random_state: int = 42
|
||||
|
||||
# Early stopping
|
||||
early_stopping_rounds: int = 20
|
||||
early_stopping_threshold: float = 0.001
|
||||
|
||||
def __post_init__(self):
|
||||
if self.price_lags is None:
|
||||
self.price_lags = [1, 5, 10, 15, 30, 60]
|
||||
if self.rolling_windows is None:
|
||||
self.rolling_windows = [5, 10, 15, 30, 60]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLBatteryConfig:
|
||||
"""Configuration for RL battery optimization."""
|
||||
|
||||
# State space
|
||||
charge_level_bins: int = 10
|
||||
price_bins: int = 10
|
||||
time_bins: int = 24 # Hours
|
||||
|
||||
# Action space
|
||||
actions: List[str] = None
|
||||
|
||||
# Q-Learning
|
||||
learning_rate: float = 0.1
|
||||
discount_factor: float = 0.95
|
||||
epsilon: float = 1.0
|
||||
epsilon_decay: float = 0.995
|
||||
epsilon_min: float = 0.05
|
||||
|
||||
# Training
|
||||
episodes: int = 1000
|
||||
max_steps: int = 14400 # 10 days * 1440 minutes
|
||||
|
||||
# Battery constraints
|
||||
min_reserve: float = 0.10 # 10%
|
||||
max_charge: float = 0.90 # 90%
|
||||
efficiency: float = 0.90
|
||||
|
||||
# Reward scaling
|
||||
reward_scale: float = 1.0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.actions is None:
|
||||
self.actions = ["charge", "hold", "discharge"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLConfig:
|
||||
"""Overall ML configuration."""
|
||||
|
||||
# Paths
|
||||
models_path: str = "models"
|
||||
results_path: str = "results"
|
||||
|
||||
# Price prediction
|
||||
price_prediction: PricePredictionConfig = None
|
||||
rl_battery: RLBatteryConfig = None
|
||||
|
||||
# Training
|
||||
enable_gpu: bool = False
|
||||
n_jobs: int = 4
|
||||
verbose: bool = True
|
||||
|
||||
# Retraining
|
||||
keep_backup: bool = True
|
||||
max_backups: int = 5
|
||||
|
||||
def __post_init__(self):
|
||||
if self.price_prediction is None:
|
||||
self.price_prediction = PricePredictionConfig()
|
||||
if self.rl_battery is None:
|
||||
self.rl_battery = RLBatteryConfig()
|
||||
|
||||
|
||||
# Default configuration
|
||||
default_config = MLConfig()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Feature Engineering Interface
|
||||
|
||||
### Key Functions (backend/app/ml/features/__init__.py)
|
||||
|
||||
```python
|
||||
def build_price_features(
|
||||
df: pd.DataFrame,
|
||||
price_col: str = "real_time_price",
|
||||
lags: List[int] = None,
|
||||
windows: List[int] = None,
|
||||
regions: List[str] = None,
|
||||
include_time: bool = True,
|
||||
include_regional: bool = True,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Build complete feature set for price prediction.
|
||||
|
||||
Args:
|
||||
df: Input DataFrame
|
||||
price_col: Name of price column
|
||||
lags: List of lag periods
|
||||
windows: List of rolling window sizes
|
||||
regions: List of regions for differential features
|
||||
include_time: Whether to include time features
|
||||
include_regional: Whether to include regional features
|
||||
|
||||
Returns:
|
||||
DataFrame with all features
|
||||
"""
|
||||
|
||||
|
||||
def build_battery_features(
|
||||
df: pd.DataFrame,
|
||||
price_df: pd.DataFrame,
|
||||
battery_col: str = "charge_level_mwh",
|
||||
capacity_col: str = "capacity_mwh",
|
||||
timestamp_col: str = "timestamp",
|
||||
battery_id_col: str = "battery_id",
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Build features for battery RL model.
|
||||
|
||||
Args:
|
||||
df: Battery DataFrame
|
||||
price_df: Price DataFrame
|
||||
battery_col: Name of battery charge level column
|
||||
capacity_col: Name of battery capacity column
|
||||
timestamp_col: Name of timestamp column
|
||||
battery_id_col: Name of battery ID column
|
||||
|
||||
Returns:
|
||||
DataFrame with battery features
|
||||
"""
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Price Prediction Interface
|
||||
|
||||
### PricePredictor (backend/app/ml/price_prediction/predictor.py)
|
||||
|
||||
```python
|
||||
class PricePredictor:
|
||||
"""Interface for making price predictions."""
|
||||
|
||||
def __init__(self, models_dir: str = "models/price_prediction"):
|
||||
"""
|
||||
Initialize predictor.
|
||||
|
||||
Args:
|
||||
models_dir: Directory containing trained models
|
||||
"""
|
||||
|
||||
def predict(
|
||||
self,
|
||||
current_data: pd.DataFrame,
|
||||
horizon: int = 15,
|
||||
region: Optional[str] = None,
|
||||
) -> float:
|
||||
"""
|
||||
Predict price for a specific horizon.
|
||||
|
||||
Args:
|
||||
current_data: Current/historical price data
|
||||
horizon: Prediction horizon in minutes
|
||||
region: Specific region to predict for (optional)
|
||||
|
||||
Returns:
|
||||
Predicted price
|
||||
"""
|
||||
|
||||
def predict_all_horizons(
|
||||
self,
|
||||
current_data: pd.DataFrame,
|
||||
region: Optional[str] = None,
|
||||
) -> Dict[int, float]:
|
||||
"""
|
||||
Predict prices for all available horizons.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping horizons to predictions
|
||||
"""
|
||||
|
||||
def predict_with_confidence(
|
||||
self,
|
||||
current_data: pd.DataFrame,
|
||||
horizon: int = 15,
|
||||
region: Optional[str] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
Predict price with confidence interval.
|
||||
|
||||
Returns:
|
||||
Dictionary with prediction and confidence interval
|
||||
"""
|
||||
|
||||
def get_feature_importance(self, horizon: int) -> pd.DataFrame:
|
||||
"""
|
||||
Get feature importance for a specific horizon.
|
||||
|
||||
Returns:
|
||||
DataFrame with feature importance
|
||||
"""
|
||||
```
|
||||
|
||||
### PricePredictionTrainer (backend/app/ml/price_prediction/trainer.py)
|
||||
|
||||
```python
|
||||
class PricePredictionTrainer:
|
||||
"""Training pipeline for price prediction models."""
|
||||
|
||||
def __init__(self, config: PricePredictionConfig = None):
|
||||
"""Initialize trainer."""
|
||||
|
||||
def load_data(self) -> pd.DataFrame:
|
||||
"""Load price data."""
|
||||
|
||||
def prepare_data(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
|
||||
"""
|
||||
Prepare data with features.
|
||||
|
||||
Returns:
|
||||
Tuple of (features DataFrame, feature names)
|
||||
"""
|
||||
|
||||
def train_for_horizon(
|
||||
self,
|
||||
df_features: pd.DataFrame,
|
||||
feature_cols: List[str],
|
||||
horizon: int,
|
||||
) -> Dict:
|
||||
"""
|
||||
Train model for a specific horizon.
|
||||
|
||||
Returns:
|
||||
Training results dictionary with metrics
|
||||
"""
|
||||
|
||||
def train_all(self, horizons: List[int] = None) -> Dict:
|
||||
"""
|
||||
Train models for all horizons.
|
||||
|
||||
Returns:
|
||||
Dictionary with all training results
|
||||
"""
|
||||
|
||||
def save_models(self, output_dir: str = "models/price_prediction") -> None:
|
||||
"""Save all trained models."""
|
||||
|
||||
@classmethod
|
||||
def load_models(
|
||||
cls,
|
||||
models_dir: str = "models/price_prediction",
|
||||
horizons: List[int] = None,
|
||||
) -> Dict[int, PricePredictionModel]:
|
||||
"""
|
||||
Load trained models.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping horizons to models
|
||||
"""
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## RL Battery Optimization Interface
|
||||
|
||||
### BatteryPolicy (backend/app/ml/rl_battery/policy.py)
|
||||
|
||||
```python
|
||||
class BatteryPolicy:
|
||||
"""Interface for RL battery policy inference."""
|
||||
|
||||
def __init__(self, policy_path: str = "models/rl_battery"):
|
||||
"""
|
||||
Initialize policy.
|
||||
|
||||
Args:
|
||||
policy_path: Path to trained policy
|
||||
"""
|
||||
|
||||
def get_action(
|
||||
self,
|
||||
charge_level: float,
|
||||
current_price: float,
|
||||
price_forecast_1m: float = 0,
|
||||
price_forecast_5m: float = 0,
|
||||
price_forecast_15m: float = 0,
|
||||
hour: int = 0,
|
||||
) -> Dict:
|
||||
"""
|
||||
Get action for current state.
|
||||
|
||||
Returns:
|
||||
Dictionary with action, q_values
|
||||
"""
|
||||
```
|
||||
|
||||
### BatteryRLTrainer (backend/app/ml/rl_battery/trainer.py)
|
||||
|
||||
```python
|
||||
class BatteryRLTrainer:
|
||||
"""Training pipeline for RL battery policy."""
|
||||
|
||||
def __init__(self, config: RLBatteryConfig = None):
|
||||
"""Initialize trainer."""
|
||||
|
||||
def load_data(self) -> None:
|
||||
"""Load price data for environment."""
|
||||
|
||||
def train(self, n_episodes: int = 1000, region: str = "FR") -> Dict:
|
||||
"""
|
||||
Train RL agent.
|
||||
|
||||
Returns:
|
||||
Training results with metrics
|
||||
"""
|
||||
|
||||
def save(self, output_dir: str = "models/rl_battery") -> None:
|
||||
"""Save trained policy."""
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Model Management Interface
|
||||
|
||||
### ModelRegistry (backend/app/ml/model_management/registry.py)
|
||||
|
||||
```python
|
||||
class ModelRegistry:
|
||||
"""Registry for tracking model versions."""
|
||||
|
||||
def __init__(self, registry_path: str = "models/registry.json"):
|
||||
"""Initialize registry."""
|
||||
|
||||
def register_model(
|
||||
self,
|
||||
model_type: str,
|
||||
model_id: str,
|
||||
version: str,
|
||||
filepath: str,
|
||||
metadata: Dict = None,
|
||||
) -> None:
|
||||
"""Register a model version."""
|
||||
|
||||
def get_latest_version(self, model_id: str) -> Optional[Dict]:
|
||||
"""Get latest version of a model."""
|
||||
|
||||
def list_models(self) -> List[Dict]:
|
||||
"""List all registered models."""
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Training Task Interface
|
||||
|
||||
### Training Jobs (app/tasks/training_tasks.py)
|
||||
|
||||
```python
|
||||
async def train_model_task(training_id: str, request: TrainingRequest):
|
||||
"""
|
||||
Execute ML model training via Celery task.
|
||||
|
||||
Dispatches to Celery for async processing of:
|
||||
- Price prediction training
|
||||
- RL battery policy training
|
||||
|
||||
Emits WebSocket events for progress updates.
|
||||
"""
|
||||
|
||||
@shared_task(name="tasks.train_price_prediction")
|
||||
def train_price_prediction(training_id: str, request_dict: dict):
|
||||
"""
|
||||
Celery task for price prediction model training.
|
||||
|
||||
Process:
|
||||
1. Load and prepare data
|
||||
2. Train XGBoost models for specified horizon
|
||||
3. Save models
|
||||
4. Register in model registry
|
||||
5. Update training job status
|
||||
"""
|
||||
|
||||
@shared_task(name="tasks.train_rl_battery")
|
||||
def train_rl_battery(training_id: str, request_dict: dict):
|
||||
"""
|
||||
Celery task for RL battery policy training.
|
||||
|
||||
Process:
|
||||
1. Load environment and data
|
||||
2. Train Q-Learning agent
|
||||
3. Save policy
|
||||
4. Register in model registry
|
||||
5. Update training job status
|
||||
"""
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Integration Points
|
||||
|
||||
### ML Service Integration (app/services/ml_service.py)
|
||||
|
||||
The ML service provides the bridge between API routes and ML models:
|
||||
|
||||
```python
|
||||
class MLService:
|
||||
"""Service for ML model management and inference."""
|
||||
|
||||
def list_models(self) -> List[ModelInfo]:
|
||||
"""List all available trained models."""
|
||||
|
||||
def get_model_metrics(self, model_id: str) -> Dict[str, float]:
|
||||
"""Get performance metrics for a model."""
|
||||
|
||||
def predict(
|
||||
self,
|
||||
model_id: str,
|
||||
timestamp: datetime,
|
||||
features: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run prediction with on-demand model loading.
|
||||
|
||||
Supports:
|
||||
- Price prediction models
|
||||
- RL battery policy
|
||||
"""
|
||||
|
||||
def get_feature_importance(self, model_id: str) -> Dict[str, float]:
|
||||
"""Get feature importance for a model."""
|
||||
```
|
||||
|
||||
### WebSocket Events
|
||||
|
||||
Real-time training progress updates:
|
||||
|
||||
```python
|
||||
# Event types
|
||||
- "model_training_progress" # Training progress updates
|
||||
|
||||
# Payload
|
||||
{
|
||||
"model_id": str,
|
||||
"progress": float, # 0.0 to 1.0
|
||||
"epoch": int, # Current epoch (optional)
|
||||
"metrics": dict # Current metrics
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Data Flow
|
||||
|
||||
### Training Pipeline
|
||||
|
||||
```
|
||||
API Request (POST /api/v1/models/train)
|
||||
↓
|
||||
training_tasks.train_model_task()
|
||||
↓
|
||||
Celery Task (train_price_prediction / train_rl_battery)
|
||||
↓
|
||||
PricePredictionTrainer / BatteryRLTrainer
|
||||
↓
|
||||
Feature Engineering → Model Training → Evaluation
|
||||
↓
|
||||
Save Models → Register in Registry
|
||||
↓
|
||||
WebSocket Events (progress updates)
|
||||
↓
|
||||
Update Training Status
|
||||
```
|
||||
|
||||
### Prediction Pipeline
|
||||
|
||||
```
|
||||
API Request (POST /api/v1/models/predict)
|
||||
↓
|
||||
ml_service.predict()
|
||||
↓
|
||||
Load Model (on-demand)
|
||||
↓
|
||||
Feature Engineering
|
||||
↓
|
||||
Model.predict() / Policy.get_action()
|
||||
↓
|
||||
Return Prediction with Confidence
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Model Artifacts
|
||||
|
||||
### Price Prediction Models
|
||||
|
||||
Location: `models/price_prediction/`
|
||||
|
||||
```
|
||||
model_1min.pkl # 1-minute horizon model
|
||||
model_5min.pkl # 5-minute horizon model
|
||||
model_15min.pkl # 15-minute horizon model
|
||||
model_60min.pkl # 60-minute horizon model
|
||||
training_results.json # Training metrics and metadata
|
||||
```
|
||||
|
||||
### RL Battery Policy
|
||||
|
||||
Location: `models/rl_battery/`
|
||||
|
||||
```
|
||||
battery_policy.pkl # Trained Q-Learning policy
|
||||
training_results.json # Training metrics
|
||||
```
|
||||
|
||||
### Model Registry
|
||||
|
||||
Location: `models/registry.json`
|
||||
|
||||
```json
|
||||
{
|
||||
"models": {
|
||||
"price_prediction_15m": {
|
||||
"type": "price_prediction",
|
||||
"versions": ["v20260211_134500", "v20260210_100000"],
|
||||
"latest": "v20260211_134500"
|
||||
},
|
||||
"battery_policy": {
|
||||
"type": "rl_battery",
|
||||
"versions": ["v20260211_140000"],
|
||||
"latest": "v20260211_140000"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
Reference in New Issue
Block a user