Add initial implementation strategy documentation
Add comprehensive documentation for energy trading system: - Backend: FastAPI architecture, API routes, services, WebSocket - Frontend: React structure, components, state management - ML: Feature engineering, XGBoost price prediction, RL battery optimization
This commit is contained in:
750
BACKEND_IMPLEMENTATION.md
Normal file
750
BACKEND_IMPLEMENTATION.md
Normal file
@@ -0,0 +1,750 @@
|
|||||||
|
# Backend Implementation Strategy
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This document outlines the FastAPI backend for the energy trading system UI. The backend serves data, executes strategies, runs backtests, and provides real-time updates via WebSockets.
|
||||||
|
|
||||||
|
**Data Source**: `~/energy-test-data/data/processed/`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────────────────────────────────────────────────────┐
|
||||||
|
│ FastAPI Application │
|
||||||
|
├──────────────────────────────────────────────────────────────┤
|
||||||
|
│ │
|
||||||
|
│ ┌─────────────┬─────────────┬─────────────┬──────────────┐ │
|
||||||
|
│ │ API │ Services │ Tasks │ WebSocket │ │
|
||||||
|
│ │ Routes │ Layer │ (Celery) │ Manager │ │
|
||||||
|
│ └─────────────┴─────────────┴─────────────┴──────────────┘ │
|
||||||
|
│ ┌──────────┐ │
|
||||||
|
│ │ Data │ │
|
||||||
|
│ │ Cache │ │
|
||||||
|
│ └──────────┘ │
|
||||||
|
└──────────────────────────────────────────────────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌──────────────────────────────────────────────────────────────┐
|
||||||
|
│ Core Trading Engine (Imported) │
|
||||||
|
│ - Fundamental Strategy │
|
||||||
|
│ - Technical Analysis │
|
||||||
|
│ - ML Models (Price Prediction, RL Battery) │
|
||||||
|
│ - Backtesting Engine │
|
||||||
|
└──────────────────────────────────────────────────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌──────────────────────────────────────────────────────────────┐
|
||||||
|
│ Data Source │
|
||||||
|
│ ~/energy-test-data/data/processed/*.parquet │
|
||||||
|
└──────────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
backend/
|
||||||
|
├── app/
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── main.py # FastAPI app entry
|
||||||
|
│ ├── config.py # Configuration management
|
||||||
|
│ │
|
||||||
|
│ ├── api/
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── routes/
|
||||||
|
│ │ │ ├── __init__.py
|
||||||
|
│ │ │ ├── dashboard.py # Dashboard data endpoints
|
||||||
|
│ │ │ ├── backtest.py # Backtest execution
|
||||||
|
│ │ │ ├── models.py # ML model endpoints
|
||||||
|
│ │ │ ├── trading.py # Trading control
|
||||||
|
│ │ │ └── settings.py # Configuration management
|
||||||
|
│ │ └── websocket.py # WebSocket connection manager
|
||||||
|
│ │
|
||||||
|
│ ├── services/
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── data_service.py # Data loading and caching
|
||||||
|
│ │ ├── strategy_service.py # Strategy execution
|
||||||
|
│ │ ├── ml_service.py # ML model management
|
||||||
|
│ │ ├── trading_service.py # Trading operations
|
||||||
|
│ │ └── alert_service.py # Alert management
|
||||||
|
│ │
|
||||||
|
│ ├── tasks/
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── backtest_tasks.py # Async backtest execution
|
||||||
|
│ │ ├── training_tasks.py # ML model training
|
||||||
|
│ │ └── monitoring_tasks.py # Real-time data updates
|
||||||
|
│ │
|
||||||
|
│ ├── ml/ # ML models and training
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── features/
|
||||||
|
│ │ │ ├── __init__.py
|
||||||
|
│ │ │ ├── lag_features.py
|
||||||
|
│ │ │ ├── rolling_features.py
|
||||||
|
│ │ │ ├── time_features.py
|
||||||
|
│ │ │ ├── regional_features.py
|
||||||
|
│ │ │ └── battery_features.py
|
||||||
|
│ │ │
|
||||||
|
│ │ ├── price_prediction/
|
||||||
|
│ │ │ ├── __init__.py
|
||||||
|
│ │ │ ├── model.py
|
||||||
|
│ │ │ ├── trainer.py
|
||||||
|
│ │ │ └── predictor.py
|
||||||
|
│ │ │
|
||||||
|
│ │ ├── rl_battery/
|
||||||
|
│ │ │ ├── __init__.py
|
||||||
|
│ │ │ ├── environment.py
|
||||||
|
│ │ │ ├── agent.py
|
||||||
|
│ │ │ ├── trainer.py
|
||||||
|
│ │ │ └── policy.py
|
||||||
|
│ │ │
|
||||||
|
│ │ ├── model_management/
|
||||||
|
│ │ │ ├── __init__.py
|
||||||
|
│ │ │ ├── registry.py
|
||||||
|
│ │ │ ├── persistence.py
|
||||||
|
│ │ │ ├── versioning.py
|
||||||
|
│ │ │ └── comparison.py
|
||||||
|
│ │ │
|
||||||
|
│ │ ├── evaluation/
|
||||||
|
│ │ │ ├── __init__.py
|
||||||
|
│ │ │ ├── metrics.py
|
||||||
|
│ │ │ ├── backtest_evaluator.py
|
||||||
|
│ │ │ └── reports.py
|
||||||
|
│ │ │
|
||||||
|
│ │ ├── training/
|
||||||
|
│ │ │ ├── __init__.py
|
||||||
|
│ │ │ └── cli.py
|
||||||
|
│ │ │
|
||||||
|
│ │ └── utils/
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── data_split.py
|
||||||
|
│ │ ├── config.py
|
||||||
|
│ │ └── evaluation.py
|
||||||
|
│ │
|
||||||
|
│ ├── models/
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── schemas.py # Pydantic models
|
||||||
|
│ │ └── enums.py # Enumerations
|
||||||
|
│ │
|
||||||
|
│ ├── core/
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ └── constants.py # Constants and defaults
|
||||||
|
│ │
|
||||||
|
│ └── utils/
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── logger.py
|
||||||
|
│ └── helpers.py
|
||||||
|
│
|
||||||
|
├── tests/
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── conftest.py
|
||||||
|
│ ├── test_api/
|
||||||
|
│ ├── test_services/
|
||||||
|
│ └── test_websocket.py
|
||||||
|
│
|
||||||
|
├── models/ # Trained ML models storage
|
||||||
|
│ ├── price_prediction/
|
||||||
|
│ │ ├── model_1min.pkl
|
||||||
|
│ │ ├── model_5min.pkl
|
||||||
|
│ │ ├── model_15min.pkl
|
||||||
|
│ │ └── model_60min.pkl
|
||||||
|
│ └── rl_battery/
|
||||||
|
│ └── battery_policy.pkl
|
||||||
|
│
|
||||||
|
├── results/ # Backtest results storage
|
||||||
|
│ └── backtests/
|
||||||
|
│
|
||||||
|
├── .env.example
|
||||||
|
├── requirements.txt
|
||||||
|
├── pyproject.toml
|
||||||
|
└── Dockerfile
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
### app/config.py (Settings)
|
||||||
|
|
||||||
|
```python
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
# Application
|
||||||
|
APP_NAME: str = "Energy Trading API"
|
||||||
|
APP_VERSION: str = "1.0.0"
|
||||||
|
DEBUG: bool = True
|
||||||
|
|
||||||
|
# Server
|
||||||
|
HOST: str = "0.0.0.0"
|
||||||
|
PORT: int = 8000
|
||||||
|
|
||||||
|
# Data
|
||||||
|
DATA_PATH: str = "~/energy-test-data/data/processed"
|
||||||
|
DATA_PATH_RESOLVED: Path = Path(DATA_PATH).expanduser()
|
||||||
|
|
||||||
|
# CORS
|
||||||
|
CORS_ORIGINS: List[str] = [
|
||||||
|
"http://localhost:3000",
|
||||||
|
"http://localhost:5173",
|
||||||
|
]
|
||||||
|
|
||||||
|
# WebSocket
|
||||||
|
WS_HEARTBEAT_INTERVAL: int = 30
|
||||||
|
|
||||||
|
# Celery
|
||||||
|
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
|
||||||
|
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/0"
|
||||||
|
|
||||||
|
# Models
|
||||||
|
MODELS_PATH: str = "models"
|
||||||
|
RESULTS_PATH: str = "results"
|
||||||
|
|
||||||
|
# Battery
|
||||||
|
BATTERY_MIN_RESERVE: float = 0.10
|
||||||
|
BATTERY_MAX_CHARGE: float = 0.90
|
||||||
|
|
||||||
|
# Arbitrage
|
||||||
|
ARBITRAGE_MIN_SPREAD: float = 5.0 # EUR/MWh
|
||||||
|
|
||||||
|
# Mining
|
||||||
|
MINING_MARGIN_THRESHOLD: float = 5.0 # EUR/MWh
|
||||||
|
|
||||||
|
# ML
|
||||||
|
ML_PREDICTION_HORIZONS: List[int] = [1, 5, 15, 60]
|
||||||
|
ML_FEATURE_LAGS: List[int] = [1, 5, 10, 15, 30, 60]
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
case_sensitive = True
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Data Models (app/models/schemas.py)
|
||||||
|
|
||||||
|
### Enums
|
||||||
|
|
||||||
|
```python
|
||||||
|
class RegionEnum(str, Enum):
|
||||||
|
FR = "FR"
|
||||||
|
BE = "BE"
|
||||||
|
DE = "DE"
|
||||||
|
NL = "NL"
|
||||||
|
UK = "UK"
|
||||||
|
|
||||||
|
class FuelTypeEnum(str, Enum):
|
||||||
|
GAS = "gas"
|
||||||
|
NUCLEAR = "nuclear"
|
||||||
|
COAL = "coal"
|
||||||
|
SOLAR = "solar"
|
||||||
|
WIND = "wind"
|
||||||
|
HYDRO = "hydro"
|
||||||
|
|
||||||
|
class StrategyEnum(str, Enum):
|
||||||
|
FUNDAMENTAL = "fundamental"
|
||||||
|
TECHNICAL = "technical"
|
||||||
|
ML = "ml"
|
||||||
|
MINING = "mining"
|
||||||
|
|
||||||
|
class TradeTypeEnum(str, Enum):
|
||||||
|
BUY = "buy"
|
||||||
|
SELL = "sell"
|
||||||
|
CHARGE = "charge"
|
||||||
|
DISCHARGE = "discharge"
|
||||||
|
|
||||||
|
class BacktestStatusEnum(str, Enum):
|
||||||
|
PENDING = "pending"
|
||||||
|
RUNNING = "running"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
CANCELLED = "cancelled"
|
||||||
|
|
||||||
|
class ModelType(str, Enum):
|
||||||
|
PRICE_PREDICTION = "price_prediction"
|
||||||
|
RL_BATTERY = "rl_battery"
|
||||||
|
|
||||||
|
class AlertTypeEnum(str, Enum):
|
||||||
|
PRICE_SPIKE = "price_spike"
|
||||||
|
ARBITRAGE_OPPORTUNITY = "arbitrage_opportunity"
|
||||||
|
BATTERY_LOW = "battery_low"
|
||||||
|
BATTERY_FULL = "battery_full"
|
||||||
|
STRATEGY_ERROR = "strategy_error"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Key Schemas
|
||||||
|
|
||||||
|
```python
|
||||||
|
class PriceData(BaseModel):
|
||||||
|
timestamp: datetime
|
||||||
|
region: RegionEnum
|
||||||
|
day_ahead_price: float
|
||||||
|
real_time_price: float
|
||||||
|
volume_mw: float
|
||||||
|
|
||||||
|
class BatteryState(BaseModel):
|
||||||
|
timestamp: datetime
|
||||||
|
battery_id: str
|
||||||
|
capacity_mwh: float
|
||||||
|
charge_level_mwh: float
|
||||||
|
charge_rate_mw: float
|
||||||
|
discharge_rate_mw: float
|
||||||
|
efficiency: float
|
||||||
|
charge_level_pct: float = Field(default_factory=lambda: 0.0)
|
||||||
|
|
||||||
|
class BacktestConfig(BaseModel):
|
||||||
|
start_date: str
|
||||||
|
end_date: str
|
||||||
|
strategies: List[StrategyEnum] = Field(default_factory=list)
|
||||||
|
use_ml: bool = True
|
||||||
|
battery_min_reserve: Optional[float] = None
|
||||||
|
battery_max_charge: Optional[float] = None
|
||||||
|
arbitrage_min_spread: Optional[float] = None
|
||||||
|
|
||||||
|
class BacktestMetrics(BaseModel):
|
||||||
|
total_revenue: float
|
||||||
|
arbitrage_profit: float
|
||||||
|
battery_revenue: float
|
||||||
|
mining_profit: float
|
||||||
|
battery_utilization: float
|
||||||
|
price_capture_rate: float
|
||||||
|
win_rate: float
|
||||||
|
sharpe_ratio: float
|
||||||
|
max_drawdown: float
|
||||||
|
total_trades: int
|
||||||
|
|
||||||
|
class TrainingRequest(BaseModel):
|
||||||
|
model_type: ModelType
|
||||||
|
horizon: Optional[int] = None
|
||||||
|
start_date: str
|
||||||
|
end_date: str
|
||||||
|
hyperparameters: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
class PredictionResponse(BaseModel):
|
||||||
|
model_id: str
|
||||||
|
timestamp: datetime
|
||||||
|
prediction: float
|
||||||
|
confidence: Optional[float] = None
|
||||||
|
features_used: List[str] = Field(default_factory=list)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API Routes
|
||||||
|
|
||||||
|
### Dashboard API (`/api/v1/dashboard/*`)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# GET /api/v1/dashboard/summary
|
||||||
|
Response: DashboardSummary
|
||||||
|
|
||||||
|
# GET /api/v1/dashboard/prices
|
||||||
|
Response: { regions: { [region]: { timestamp, day_ahead_price, real_time_price, volume_mw } } }
|
||||||
|
|
||||||
|
# GET /api/v1/dashboard/prices/history?region={region}&start={start}&end={end}&limit={limit}
|
||||||
|
Response: { region, data: PriceData[] }
|
||||||
|
|
||||||
|
# GET /api/v1/dashboard/battery
|
||||||
|
Response: { batteries: BatteryState[] }
|
||||||
|
|
||||||
|
# GET /api/v1/dashboard/arbitrage?min_spread={min_spread}
|
||||||
|
Response: { opportunities: ArbitrageOpportunity[], count: int }
|
||||||
|
```
|
||||||
|
|
||||||
|
### Backtest API (`/api/v1/backtest/*`)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# POST /api/v1/backtest/start
|
||||||
|
Request: { config: BacktestConfig, name?: string }
|
||||||
|
Response: { backtest_id: string, status: BacktestStatus }
|
||||||
|
|
||||||
|
# GET /api/v1/backtest/{backtest_id}
|
||||||
|
Response: { status: BacktestStatus, results?: BacktestResult }
|
||||||
|
|
||||||
|
# GET /api/v1/backtest/{backtest_id}/results
|
||||||
|
Response: BacktestResult
|
||||||
|
|
||||||
|
# GET /api/v1/backtest/{backtest_id}/trades?limit={limit}
|
||||||
|
Response: { backtest_id, trades: Trade[], total: int }
|
||||||
|
|
||||||
|
# GET /api/v1/backtest/history
|
||||||
|
Response: { backtests: BacktestStatus[], total: int }
|
||||||
|
|
||||||
|
# DELETE /api/v1/backtest/{backtest_id}
|
||||||
|
Response: { message: string }
|
||||||
|
```
|
||||||
|
|
||||||
|
### Models API (`/api/v1/models/*`)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# GET /api/v1/models
|
||||||
|
Response: { models: ModelInfo[], total: int }
|
||||||
|
|
||||||
|
# POST /api/v1/models/train
|
||||||
|
Request: TrainingRequest
|
||||||
|
Response: { training_id: string, status: TrainingStatus }
|
||||||
|
|
||||||
|
# GET /api/v1/models/{model_id}/status
|
||||||
|
Response: TrainingStatus
|
||||||
|
|
||||||
|
# GET /api/v1/models/{model_id}/metrics
|
||||||
|
Response: { model_id, metrics: dict }
|
||||||
|
|
||||||
|
# POST /api/v1/models/predict
|
||||||
|
Request: { model_id, timestamp, features?: dict }
|
||||||
|
Response: PredictionResponse
|
||||||
|
```
|
||||||
|
|
||||||
|
### Trading API (`/api/v1/trading/*`)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# GET /api/v1/trading/strategies
|
||||||
|
Response: { strategies: StrategyStatus[] }
|
||||||
|
|
||||||
|
# POST /api/v1/trading/strategies
|
||||||
|
Request: { strategy: StrategyEnum, action: "start" | "stop" }
|
||||||
|
Response: { status: StrategyStatus }
|
||||||
|
|
||||||
|
# GET /api/v1/trading/positions
|
||||||
|
Response: { positions: TradingPosition[] }
|
||||||
|
```
|
||||||
|
|
||||||
|
### Settings API (`/api/v1/settings/*`)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# GET /api/v1/settings
|
||||||
|
Response: AppSettings
|
||||||
|
|
||||||
|
# POST /api/v1/settings
|
||||||
|
Request: Partial<AppSettings>
|
||||||
|
Response: { message, updated_fields: string[] }
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Services Interface
|
||||||
|
|
||||||
|
### DataService (app/services/data_service.py)
|
||||||
|
|
||||||
|
```python
|
||||||
|
class DataService:
|
||||||
|
"""Data loading and caching service."""
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""Load all datasets into memory."""
|
||||||
|
|
||||||
|
def get_latest_prices(self) -> Dict[str, Dict]:
|
||||||
|
"""Get latest prices for all regions."""
|
||||||
|
|
||||||
|
def get_price_history(self, region, start=None, end=None, limit=1000) -> List[Dict]:
|
||||||
|
"""Get price history for a region."""
|
||||||
|
|
||||||
|
def get_battery_states(self) -> List[Dict]:
|
||||||
|
"""Get current battery states."""
|
||||||
|
|
||||||
|
def get_arbitrage_opportunities(self, min_spread=None) -> List[Dict]:
|
||||||
|
"""Get current arbitrage opportunities."""
|
||||||
|
|
||||||
|
def get_dashboard_summary(self) -> Dict:
|
||||||
|
"""Get overall dashboard summary."""
|
||||||
|
```
|
||||||
|
|
||||||
|
### MLService (app/services/ml_service.py)
|
||||||
|
|
||||||
|
```python
|
||||||
|
class MLService:
|
||||||
|
"""Service for ML model management and inference."""
|
||||||
|
|
||||||
|
def list_models(self) -> List[ModelInfo]:
|
||||||
|
"""List all available trained models."""
|
||||||
|
|
||||||
|
def get_model_metrics(self, model_id: str) -> Dict[str, float]:
|
||||||
|
"""Get performance metrics for a model."""
|
||||||
|
|
||||||
|
def load_price_prediction_model(self, model_id: str):
|
||||||
|
"""Load price prediction model on-demand."""
|
||||||
|
|
||||||
|
def load_rl_battery_policy(self, model_id: str):
|
||||||
|
"""Load RL battery policy on-demand."""
|
||||||
|
|
||||||
|
def predict(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
timestamp: datetime,
|
||||||
|
features: Optional[Dict[str, Any]] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Run prediction with on-demand model loading."""
|
||||||
|
|
||||||
|
def predict_with_confidence(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
timestamp: datetime,
|
||||||
|
features: Optional[Dict[str, Any]] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Run prediction with confidence interval."""
|
||||||
|
|
||||||
|
def get_feature_importance(self, model_id: str) -> Dict[str, float]:
|
||||||
|
"""Get feature importance for a model."""
|
||||||
|
|
||||||
|
def get_model_info(self, model_id: str) -> Optional[ModelInfo]:
|
||||||
|
"""Get detailed info about a specific model."""
|
||||||
|
```
|
||||||
|
|
||||||
|
### StrategyService (app/services/strategy_service.py)
|
||||||
|
|
||||||
|
```python
|
||||||
|
class StrategyService:
|
||||||
|
"""Strategy execution service."""
|
||||||
|
|
||||||
|
async def execute_strategy(
|
||||||
|
self,
|
||||||
|
strategy: StrategyEnum,
|
||||||
|
config: Dict = None
|
||||||
|
) -> Dict:
|
||||||
|
"""Execute a trading strategy."""
|
||||||
|
|
||||||
|
async def get_strategy_status(self, strategy: StrategyEnum) -> StrategyStatus:
|
||||||
|
"""Get current status of a strategy."""
|
||||||
|
|
||||||
|
async def toggle_strategy(
|
||||||
|
self,
|
||||||
|
strategy: StrategyEnum,
|
||||||
|
action: str
|
||||||
|
) -> StrategyStatus:
|
||||||
|
"""Start or stop a strategy."""
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Tasks Interface
|
||||||
|
|
||||||
|
### Backtest Tasks (app/tasks/backtest_tasks.py)
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def run_backtest_task(backtest_id: str, config: Dict, name: str = None):
|
||||||
|
"""
|
||||||
|
Execute backtest in background.
|
||||||
|
|
||||||
|
Process:
|
||||||
|
1. Load data
|
||||||
|
2. Execute strategies
|
||||||
|
3. Calculate metrics
|
||||||
|
4. Save results
|
||||||
|
5. Emit WebSocket progress events
|
||||||
|
"""
|
||||||
|
```
|
||||||
|
|
||||||
|
### Training Tasks (app/tasks/training_tasks.py)
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def train_model_task(training_id: str, request: TrainingRequest):
|
||||||
|
"""
|
||||||
|
Execute ML model training via Celery task.
|
||||||
|
|
||||||
|
Dispatches to Celery for async processing.
|
||||||
|
Emits WebSocket events for progress updates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@shared_task(name="tasks.train_price_prediction")
|
||||||
|
def train_price_prediction(training_id: str, request_dict: dict):
|
||||||
|
"""Celery task for price prediction model training."""
|
||||||
|
|
||||||
|
@shared_task(name="tasks.train_rl_battery")
|
||||||
|
def train_rl_battery(training_id: str, request_dict: dict):
|
||||||
|
"""Celery task for RL battery policy training."""
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## WebSocket Interface
|
||||||
|
|
||||||
|
### ConnectionManager (app/api/websocket.py)
|
||||||
|
|
||||||
|
```python
|
||||||
|
class ConnectionManager:
|
||||||
|
"""WebSocket connection manager."""
|
||||||
|
|
||||||
|
async def connect(self, websocket: WebSocket):
|
||||||
|
"""Accept and track new connection."""
|
||||||
|
|
||||||
|
def disconnect(self, websocket: WebSocket):
|
||||||
|
"""Remove connection."""
|
||||||
|
|
||||||
|
async def broadcast(self, event_type: str, data: Any):
|
||||||
|
"""Broadcast event to all connected clients."""
|
||||||
|
|
||||||
|
# Specific event broadcasters
|
||||||
|
async def broadcast_price_update(self, region: str, price_data: Dict):
|
||||||
|
"""Broadcast price update."""
|
||||||
|
|
||||||
|
async def broadcast_battery_update(self, battery_id: str, battery_state: Dict):
|
||||||
|
"""Broadcast battery state update."""
|
||||||
|
|
||||||
|
async def broadcast_trade(self, trade: Dict):
|
||||||
|
"""Broadcast new trade execution."""
|
||||||
|
|
||||||
|
async def broadcast_alert(self, alert: Dict):
|
||||||
|
"""Broadcast new alert."""
|
||||||
|
|
||||||
|
async def broadcast_backtest_progress(self, backtest_id: str, progress: float, status: str):
|
||||||
|
"""Broadcast backtest progress."""
|
||||||
|
|
||||||
|
async def broadcast_model_training_progress(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
progress: float,
|
||||||
|
epoch: Optional[int] = None,
|
||||||
|
metrics: Optional[Dict] = None
|
||||||
|
):
|
||||||
|
"""Broadcast model training progress."""
|
||||||
|
```
|
||||||
|
|
||||||
|
### WebSocket Events
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Event types
|
||||||
|
"price_update" # Real-time price changes
|
||||||
|
"battery_update" # Battery state changes
|
||||||
|
"arbitrage_opportunity" # New arbitrage opportunity
|
||||||
|
"trade_executed" # Trade execution
|
||||||
|
"alert_triggered" # Alert triggered
|
||||||
|
"backtest_progress" # Backtest progress
|
||||||
|
"model_training_progress" # Training progress
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Main Application
|
||||||
|
|
||||||
|
### app/main.py
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI, WebSocket
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title=settings.APP_NAME,
|
||||||
|
version=settings.APP_VERSION,
|
||||||
|
docs_url="/docs",
|
||||||
|
redoc_url="/redoc",
|
||||||
|
)
|
||||||
|
|
||||||
|
# CORS middleware
|
||||||
|
app.add_middleware(CORSMiddleware, ...)
|
||||||
|
|
||||||
|
# Include routers
|
||||||
|
app.include_router(dashboard.router, prefix="/api/v1/dashboard", tags=["dashboard"])
|
||||||
|
app.include_router(backtest.router, prefix="/api/v1/backtest", tags=["backtest"])
|
||||||
|
app.include_router(models.router, prefix="/api/v1/models", tags=["models"])
|
||||||
|
app.include_router(trading.router, prefix="/api/v1/trading", tags=["trading"])
|
||||||
|
app.include_router(settings_routes.router, prefix="/api/v1/settings", tags=["settings"])
|
||||||
|
|
||||||
|
# Health check
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
return { "status": "healthy" }
|
||||||
|
|
||||||
|
# WebSocket endpoint
|
||||||
|
@app.websocket("/ws/real-time")
|
||||||
|
async def websocket_endpoint(websocket: WebSocket):
|
||||||
|
await manager.connect(websocket)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Dependencies
|
||||||
|
|
||||||
|
### requirements.txt
|
||||||
|
|
||||||
|
```
|
||||||
|
# FastAPI & Server
|
||||||
|
fastapi>=0.104.0
|
||||||
|
uvicorn[standard]>=0.24.0
|
||||||
|
pydantic>=2.4.0
|
||||||
|
pydantic-settings>=2.0.0
|
||||||
|
|
||||||
|
# Data Processing
|
||||||
|
pandas>=2.1.0
|
||||||
|
numpy>=1.24.0
|
||||||
|
pyarrow>=14.0.0
|
||||||
|
|
||||||
|
# Machine Learning
|
||||||
|
xgboost>=2.0.0
|
||||||
|
scikit-learn>=1.3.0
|
||||||
|
|
||||||
|
# Reinforcement Learning
|
||||||
|
gymnasium>=0.29.0
|
||||||
|
stable-baselines3>=2.0.0
|
||||||
|
|
||||||
|
# Background Tasks
|
||||||
|
celery>=5.3.0
|
||||||
|
redis>=5.0.0
|
||||||
|
|
||||||
|
# WebSockets
|
||||||
|
websockets>=12.0.0
|
||||||
|
|
||||||
|
# Database
|
||||||
|
sqlalchemy>=2.0.0
|
||||||
|
alembic>=1.12.0
|
||||||
|
|
||||||
|
# Utilities
|
||||||
|
python-multipart>=0.0.6
|
||||||
|
python-jose[cryptography]>=3.3.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
|
||||||
|
# Testing
|
||||||
|
pytest>=7.4.0
|
||||||
|
pytest-asyncio>=0.21.0
|
||||||
|
httpx>=0.25.0
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
loguru>=0.7.0
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Environment Variables
|
||||||
|
|
||||||
|
### .env.example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Application
|
||||||
|
APP_NAME=Energy Trading API
|
||||||
|
APP_VERSION=1.0.0
|
||||||
|
DEBUG=true
|
||||||
|
|
||||||
|
# Server
|
||||||
|
HOST=0.0.0.0
|
||||||
|
PORT=8000
|
||||||
|
|
||||||
|
# Data
|
||||||
|
DATA_PATH=~/energy-test-data/data/processed
|
||||||
|
|
||||||
|
# CORS
|
||||||
|
CORS_ORIGINS=http://localhost:3000,http://localhost:5173
|
||||||
|
|
||||||
|
# Celery
|
||||||
|
CELERY_BROKER_URL=redis://localhost:6379/0
|
||||||
|
CELERY_RESULT_BACKEND=redis://localhost:6379/0
|
||||||
|
|
||||||
|
# Paths
|
||||||
|
MODELS_PATH=models
|
||||||
|
RESULTS_PATH=results
|
||||||
|
|
||||||
|
# Battery
|
||||||
|
BATTERY_MIN_RESERVE=0.10
|
||||||
|
BATTERY_MAX_CHARGE=0.90
|
||||||
|
|
||||||
|
# Arbitrage
|
||||||
|
ARBITRAGE_MIN_SPREAD=5.0
|
||||||
|
|
||||||
|
# Mining
|
||||||
|
MINING_MARGIN_THRESHOLD=5.0
|
||||||
|
```
|
||||||
780
FRONTEND_IMPLEMENTATION.md
Normal file
780
FRONTEND_IMPLEMENTATION.md
Normal file
@@ -0,0 +1,780 @@
|
|||||||
|
# Frontend Implementation Strategy
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This document outlines the React frontend for the energy trading system UI. The frontend provides real-time monitoring, backtesting tools, ML model insights, and trading controls.
|
||||||
|
|
||||||
|
**Backend API**: `http://localhost:8000`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────────────────────────────────────────────────────┐
|
||||||
|
│ React Application │
|
||||||
|
├──────────────────────────────────────────────────────────────┤
|
||||||
|
│ │
|
||||||
|
│ ┌──────────────────────────────────────────────────────┐ │
|
||||||
|
│ │ Pages Layer │ │
|
||||||
|
│ │ Dashboard │ Backtest │ Models │ Trading │ Settings │ │
|
||||||
|
│ └──────────────────────────────────────────────────────┘ │
|
||||||
|
│ │ │
|
||||||
|
│ ┌──────────────────────────────────────────────────────┐ │
|
||||||
|
│ │ Components Layer │ │
|
||||||
|
│ │ Charts │ Forms │ Alerts │ Tables │ Controls │ │
|
||||||
|
│ └──────────────────────────────────────────────────────┘ │
|
||||||
|
│ │ │
|
||||||
|
│ ┌──────────────────────────────────────────────────────┐ │
|
||||||
|
│ │ Services Layer │ │
|
||||||
|
│ │ API Client │ WebSocket │ State Management │ │
|
||||||
|
│ └──────────────────────────────────────────────────────┘ │
|
||||||
|
└──────────────────────────────────────────────────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌──────────────────────────────────────────────────────────────┐
|
||||||
|
│ FastAPI Backend │
|
||||||
|
│ REST API + WebSocket │
|
||||||
|
└──────────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
frontend/
|
||||||
|
├── public/
|
||||||
|
│ ├── favicon.ico
|
||||||
|
│ └── index.html
|
||||||
|
│
|
||||||
|
├── src/
|
||||||
|
│ ├── App.tsx # Main app component
|
||||||
|
│ ├── main.tsx # Entry point
|
||||||
|
│ ├── index.css # Global styles
|
||||||
|
│ │
|
||||||
|
│ ├── components/
|
||||||
|
│ │ ├── common/
|
||||||
|
│ │ │ ├── Header.tsx
|
||||||
|
│ │ │ ├── Sidebar.tsx
|
||||||
|
│ │ │ ├── Loading.tsx
|
||||||
|
│ │ │ └── Error.tsx
|
||||||
|
│ │ │
|
||||||
|
│ │ ├── charts/
|
||||||
|
│ │ │ ├── PriceChart.tsx
|
||||||
|
│ │ │ ├── BatteryChart.tsx
|
||||||
|
│ │ │ ├── PnLChart.tsx
|
||||||
|
│ │ │ ├── GenerationChart.tsx
|
||||||
|
│ │ │ └── ModelMetricsChart.tsx
|
||||||
|
│ │ │
|
||||||
|
│ │ ├── alerts/
|
||||||
|
│ │ │ ├── AlertPanel.tsx
|
||||||
|
│ │ │ └── AlertItem.tsx
|
||||||
|
│ │ │
|
||||||
|
│ │ ├── tables/
|
||||||
|
│ │ │ ├── ArbitrageTable.tsx
|
||||||
|
│ │ │ ├── TradeLogTable.tsx
|
||||||
|
│ │ │ └── ModelListTable.tsx
|
||||||
|
│ │ │
|
||||||
|
│ │ └── forms/
|
||||||
|
│ │ ├── BacktestForm.tsx
|
||||||
|
│ │ ├── TrainingForm.tsx
|
||||||
|
│ │ └── SettingsForm.tsx
|
||||||
|
│ │
|
||||||
|
│ ├── pages/
|
||||||
|
│ │ ├── Dashboard.tsx
|
||||||
|
│ │ ├── Backtest.tsx
|
||||||
|
│ │ ├── Models.tsx
|
||||||
|
│ │ ├── Trading.tsx
|
||||||
|
│ │ └── Settings.tsx
|
||||||
|
│ │
|
||||||
|
│ ├── hooks/
|
||||||
|
│ │ ├── useWebSocket.ts
|
||||||
|
│ │ ├── useApi.ts
|
||||||
|
│ │ ├── useBacktest.ts
|
||||||
|
│ │ ├── useModels.ts
|
||||||
|
│ │ └── useTrading.ts
|
||||||
|
│ │
|
||||||
|
│ ├── services/
|
||||||
|
│ │ ├── api.ts # REST API client
|
||||||
|
│ │ ├── websocket.ts # WebSocket client
|
||||||
|
│ │ └── types.ts # TypeScript types
|
||||||
|
│ │
|
||||||
|
│ ├── store/
|
||||||
|
│ │ ├── index.ts # Zustand store setup
|
||||||
|
│ │ ├── dashboardSlice.ts
|
||||||
|
│ │ ├── backtestSlice.ts
|
||||||
|
│ │ ├── modelsSlice.ts
|
||||||
|
│ │ └── tradingSlice.ts
|
||||||
|
│ │
|
||||||
|
│ └── lib/
|
||||||
|
│ ├── utils.ts
|
||||||
|
│ ├── constants.ts
|
||||||
|
│ └── formatters.ts
|
||||||
|
│
|
||||||
|
├── tests/
|
||||||
|
├── .env.example
|
||||||
|
├── .env.local
|
||||||
|
├── package.json
|
||||||
|
├── tsconfig.json
|
||||||
|
├── vite.config.ts
|
||||||
|
└── tailwind.config.js
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
### vite.config.ts
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { defineConfig } from 'vite';
|
||||||
|
import react from '@vitejs/plugin-react';
|
||||||
|
import path from 'path';
|
||||||
|
|
||||||
|
export default defineConfig({
|
||||||
|
plugins: [react()],
|
||||||
|
resolve: {
|
||||||
|
alias: {
|
||||||
|
'@': path.resolve(__dirname, './src'),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
server: {
|
||||||
|
port: 3000,
|
||||||
|
proxy: {
|
||||||
|
'/api': {
|
||||||
|
target: 'http://localhost:8000',
|
||||||
|
changeOrigin: true,
|
||||||
|
},
|
||||||
|
'/ws': {
|
||||||
|
target: 'ws://localhost:8000',
|
||||||
|
ws: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
```
|
||||||
|
|
||||||
|
### .env.example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
VITE_API_URL=http://localhost:8000
|
||||||
|
VITE_WS_URL=ws://localhost:8000/ws/real-time
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## TypeScript Types
|
||||||
|
|
||||||
|
### services/types.ts
|
||||||
|
|
||||||
|
#### Enums
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
export enum Region {
|
||||||
|
FR = "FR",
|
||||||
|
BE = "BE",
|
||||||
|
DE = "DE",
|
||||||
|
NL = "NL",
|
||||||
|
UK = "UK",
|
||||||
|
}
|
||||||
|
|
||||||
|
export enum Strategy {
|
||||||
|
FUNDAMENTAL = "fundamental",
|
||||||
|
TECHNICAL = "technical",
|
||||||
|
ML = "ml",
|
||||||
|
MINING = "mining",
|
||||||
|
}
|
||||||
|
|
||||||
|
export enum TradeType {
|
||||||
|
BUY = "buy",
|
||||||
|
SELL = "sell",
|
||||||
|
CHARGE = "charge",
|
||||||
|
DISCHARGE = "discharge",
|
||||||
|
}
|
||||||
|
|
||||||
|
export enum BacktestStatus {
|
||||||
|
PENDING = "pending",
|
||||||
|
RUNNING = "running",
|
||||||
|
COMPLETED = "completed",
|
||||||
|
FAILED = "failed",
|
||||||
|
CANCELLED = "cancelled",
|
||||||
|
}
|
||||||
|
|
||||||
|
export enum ModelType {
|
||||||
|
PRICE_PREDICTION = "price_prediction",
|
||||||
|
RL_BATTERY = "rl_battery",
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Key Interfaces
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
export interface PriceData {
|
||||||
|
timestamp: string;
|
||||||
|
region: Region;
|
||||||
|
day_ahead_price: number;
|
||||||
|
real_time_price: number;
|
||||||
|
volume_mw: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface BatteryState {
|
||||||
|
battery_id: string;
|
||||||
|
timestamp: string;
|
||||||
|
capacity_mwh: number;
|
||||||
|
charge_level_mwh: number;
|
||||||
|
charge_level_pct: number;
|
||||||
|
charge_rate_mw: number;
|
||||||
|
discharge_rate_mw: number;
|
||||||
|
efficiency: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface BacktestConfig {
|
||||||
|
start_date: string;
|
||||||
|
end_date: string;
|
||||||
|
strategies: Strategy[];
|
||||||
|
use_ml: boolean;
|
||||||
|
battery_min_reserve?: number;
|
||||||
|
battery_max_charge?: number;
|
||||||
|
arbitrage_min_spread?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface BacktestMetrics {
|
||||||
|
total_revenue: number;
|
||||||
|
arbitrage_profit: number;
|
||||||
|
battery_revenue: number;
|
||||||
|
mining_profit: number;
|
||||||
|
battery_utilization: number;
|
||||||
|
price_capture_rate: number;
|
||||||
|
win_rate: number;
|
||||||
|
sharpe_ratio: number;
|
||||||
|
max_drawdown: number;
|
||||||
|
total_trades: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface BacktestStatus {
|
||||||
|
id: string;
|
||||||
|
status: BacktestStatus;
|
||||||
|
progress: number;
|
||||||
|
current_step: string;
|
||||||
|
started_at?: string;
|
||||||
|
completed_at?: string;
|
||||||
|
error?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ModelInfo {
|
||||||
|
id: string;
|
||||||
|
type: ModelType;
|
||||||
|
name: string;
|
||||||
|
horizon?: number;
|
||||||
|
created_at: string;
|
||||||
|
metrics: Record<string, number>;
|
||||||
|
status: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface TrainingRequest {
|
||||||
|
model_type: ModelType;
|
||||||
|
horizon?: number;
|
||||||
|
start_date: string;
|
||||||
|
end_date: string;
|
||||||
|
hyperparameters: Record<string, unknown>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface TrainingStatus {
|
||||||
|
id: string;
|
||||||
|
status: BacktestStatus;
|
||||||
|
progress: number;
|
||||||
|
current_epoch?: number;
|
||||||
|
total_epochs?: number;
|
||||||
|
metrics: Record<string, number>;
|
||||||
|
started_at?: string;
|
||||||
|
completed_at?: string;
|
||||||
|
error?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PredictionRequest {
|
||||||
|
model_id: string;
|
||||||
|
timestamp: string;
|
||||||
|
features?: Record<string, unknown>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PredictionResponse {
|
||||||
|
model_id: string;
|
||||||
|
timestamp: string;
|
||||||
|
prediction: number;
|
||||||
|
confidence?: number;
|
||||||
|
features_used: string[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface StrategyStatus {
|
||||||
|
strategy: Strategy;
|
||||||
|
running: boolean;
|
||||||
|
last_updated?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface TradingPosition {
|
||||||
|
region: Region;
|
||||||
|
position_mw: number;
|
||||||
|
battery_charge_pct: number;
|
||||||
|
pnl: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface Alert {
|
||||||
|
id: string;
|
||||||
|
timestamp: string;
|
||||||
|
type: AlertType;
|
||||||
|
severity: "info" | "warning" | "error";
|
||||||
|
message: string;
|
||||||
|
data: Record<string, unknown>;
|
||||||
|
acknowledged: boolean;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API Client Interface
|
||||||
|
|
||||||
|
### services/api.ts
|
||||||
|
|
||||||
|
#### Dashboard API
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
export const dashboardApi = {
|
||||||
|
getSummary: async (): Promise<DashboardSummary> => { },
|
||||||
|
|
||||||
|
getPrices: async (): Promise<Record<string, PriceData>> => { },
|
||||||
|
|
||||||
|
getPriceHistory: async (
|
||||||
|
region: string,
|
||||||
|
start?: string,
|
||||||
|
end?: string,
|
||||||
|
limit?: number
|
||||||
|
): Promise<{ region: string; data: PriceData[] }> => { },
|
||||||
|
|
||||||
|
getBatteryStates: async (): Promise<BatteryState[]> => { },
|
||||||
|
|
||||||
|
getArbitrage: async (minSpread?: number): Promise<{
|
||||||
|
opportunities: ArbitrageOpportunity[];
|
||||||
|
count: number;
|
||||||
|
}> => { },
|
||||||
|
};
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Backtest API
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
export const backtestApi = {
|
||||||
|
start: async (request: BacktestRequest): Promise<{
|
||||||
|
backtest_id: string;
|
||||||
|
status: BacktestStatus;
|
||||||
|
}> => { },
|
||||||
|
|
||||||
|
get: async (backtestId: string): Promise<{
|
||||||
|
status: BacktestStatus;
|
||||||
|
results?: BacktestResult;
|
||||||
|
}> => { },
|
||||||
|
|
||||||
|
getResults: async (backtestId: string): Promise<BacktestResult> => { },
|
||||||
|
|
||||||
|
getTrades: async (backtestId: string, limit?: number): Promise<{
|
||||||
|
backtest_id: string;
|
||||||
|
trades: Trade[];
|
||||||
|
total: number;
|
||||||
|
}> => { },
|
||||||
|
|
||||||
|
list: async (): Promise<{
|
||||||
|
backtests: BacktestStatus[];
|
||||||
|
total: number;
|
||||||
|
}> => { },
|
||||||
|
|
||||||
|
delete: async (backtestId: string): Promise<{ message: string }> => { },
|
||||||
|
};
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Models API
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
export const modelsApi = {
|
||||||
|
list: async (): Promise<{ models: ModelInfo[]; total: number }> => { },
|
||||||
|
|
||||||
|
train: async (request: TrainingRequest): Promise<{
|
||||||
|
training_id: string;
|
||||||
|
status: TrainingStatus;
|
||||||
|
}> => { },
|
||||||
|
|
||||||
|
getStatus: async (modelId: string): Promise<TrainingStatus> => { },
|
||||||
|
|
||||||
|
getMetrics: async (modelId: string): Promise<{
|
||||||
|
model_id: string;
|
||||||
|
metrics: Record<string, number>;
|
||||||
|
}> => { },
|
||||||
|
|
||||||
|
predict: async (request: PredictionRequest): Promise<PredictionResponse> => { },
|
||||||
|
};
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Trading API
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
export const tradingApi = {
|
||||||
|
getStrategies: async (): Promise<{ strategies: StrategyStatus[] }> => { },
|
||||||
|
|
||||||
|
toggleStrategy: async (control: {
|
||||||
|
strategy: Strategy;
|
||||||
|
action: "start" | "stop";
|
||||||
|
}): Promise<{ status: StrategyStatus }> => { },
|
||||||
|
|
||||||
|
getPositions: async (): Promise<{ positions: TradingPosition[] }> => { },
|
||||||
|
};
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Settings API
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
export const settingsApi = {
|
||||||
|
get: async (): Promise<AppSettings> => { },
|
||||||
|
|
||||||
|
update: async (settings: Partial<AppSettings>): Promise<{
|
||||||
|
message: string;
|
||||||
|
updated_fields: string[];
|
||||||
|
}> => { },
|
||||||
|
};
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## WebSocket Client Interface
|
||||||
|
|
||||||
|
### services/websocket.ts
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
class WebSocketService {
|
||||||
|
private ws: WebSocket | null = null;
|
||||||
|
private url: string;
|
||||||
|
private eventHandlers: Map<WebSocketEventType, Set<EventHandler>> = new Map();
|
||||||
|
private isConnected = false;
|
||||||
|
|
||||||
|
constructor(url: string = import.meta.env.VITE_WS_URL);
|
||||||
|
|
||||||
|
connect(): void;
|
||||||
|
disconnect(): void;
|
||||||
|
|
||||||
|
subscribe<T = unknown>(
|
||||||
|
eventType: WebSocketEventType,
|
||||||
|
handler: (data: T) => void
|
||||||
|
): () => void; // Returns unsubscribe function
|
||||||
|
|
||||||
|
getConnectionStatus(): boolean;
|
||||||
|
|
||||||
|
private handleMessage(message: WebSocketMessage): void;
|
||||||
|
private attemptReconnect(): void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const webSocketService = new WebSocketService();
|
||||||
|
```
|
||||||
|
|
||||||
|
### WebSocket Event Types
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
export type WebSocketEventType =
|
||||||
|
| "price_update"
|
||||||
|
| "battery_update"
|
||||||
|
| "arbitrage_opportunity"
|
||||||
|
| "trade_executed"
|
||||||
|
| "alert_triggered"
|
||||||
|
| "backtest_progress"
|
||||||
|
| "model_training_progress";
|
||||||
|
|
||||||
|
export interface WebSocketMessage<T = unknown> {
|
||||||
|
type: WebSocketEventType;
|
||||||
|
timestamp: string;
|
||||||
|
data: T;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## State Management (Zustand)
|
||||||
|
|
||||||
|
### store/index.ts
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
interface DashboardState {
|
||||||
|
summary: DashboardSummary | null;
|
||||||
|
prices: Record<string, PriceData>;
|
||||||
|
batteryStates: BatteryState[];
|
||||||
|
arbitrageOpportunities: ArbitrageOpportunity[];
|
||||||
|
alerts: Alert[];
|
||||||
|
|
||||||
|
updateSummary: (summary: DashboardSummary) => void;
|
||||||
|
updatePrices: (prices: Record<string, PriceData>) => void;
|
||||||
|
addAlert: (alert: Alert) => void;
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
|
||||||
|
interface BacktestState {
|
||||||
|
backtests: Record<string, BacktestStatus>;
|
||||||
|
currentBacktest: string | null;
|
||||||
|
isRunning: boolean;
|
||||||
|
|
||||||
|
updateBacktest: (status: BacktestStatus) => void;
|
||||||
|
setCurrentBacktest: (backtestId: string | null) => void;
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ModelsState {
|
||||||
|
models: ModelInfo[];
|
||||||
|
trainingJobs: Record<string, TrainingStatus>;
|
||||||
|
selectedModel: string | null;
|
||||||
|
|
||||||
|
setModels: (models: ModelInfo[]) => void;
|
||||||
|
updateTrainingJob: (job: TrainingStatus) => void;
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
|
||||||
|
interface TradingState {
|
||||||
|
strategies: Record<string, StrategyStatus>;
|
||||||
|
positions: TradingPosition[];
|
||||||
|
pnl: number;
|
||||||
|
|
||||||
|
updateStrategy: (strategy: StrategyStatus) => void;
|
||||||
|
updatePositions: (positions: TradingPosition[]) => void;
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useStore = create<AppStore>((set) => ({
|
||||||
|
// ... state and actions
|
||||||
|
}));
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Custom Hooks
|
||||||
|
|
||||||
|
### hooks/useWebSocket.ts
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
export function useWebSocket() {
|
||||||
|
const subscribe = <T = unknown>(
|
||||||
|
eventType: WebSocketEventType,
|
||||||
|
handler: (data: T) => void
|
||||||
|
): (() => void) => {
|
||||||
|
return webSocketService.subscribe<T>(eventType, handler);
|
||||||
|
};
|
||||||
|
|
||||||
|
const isConnected = webSocketService.getConnectionStatus();
|
||||||
|
|
||||||
|
return { subscribe, isConnected };
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### hooks/useApi.ts
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Dashboard hooks
|
||||||
|
export function useDashboardSummary();
|
||||||
|
export function usePrices();
|
||||||
|
export function useBatteryStates();
|
||||||
|
export function useArbitrageOpportunities(minSpread?: number);
|
||||||
|
|
||||||
|
// Backtest hooks
|
||||||
|
export function useStartBacktest();
|
||||||
|
export function useBacktest(backtestId: string);
|
||||||
|
export function useBacktestList();
|
||||||
|
|
||||||
|
// Models hooks
|
||||||
|
export function useModels();
|
||||||
|
export function useTrainModel();
|
||||||
|
|
||||||
|
// Trading hooks
|
||||||
|
export function useStrategies();
|
||||||
|
export function useToggleStrategy();
|
||||||
|
export function usePositions();
|
||||||
|
|
||||||
|
// Settings hooks
|
||||||
|
export function useSettings();
|
||||||
|
export function useUpdateSettings();
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Pages Interface
|
||||||
|
|
||||||
|
### Dashboard.tsx
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
export default function Dashboard() {
|
||||||
|
const { subscribe } = useWebSocket();
|
||||||
|
const { prices, batteryStates, arbitrageOpportunities, alerts } = useStore();
|
||||||
|
const { data: pricesData } = usePrices();
|
||||||
|
|
||||||
|
// Subscribe to real-time updates
|
||||||
|
useEffect(() => {
|
||||||
|
const unsubscribePrice = subscribe('price_update', (data) => { });
|
||||||
|
const unsubscribeBattery = subscribe('battery_update', (data) => { });
|
||||||
|
const unsubscribeAlert = subscribe('alert_triggered', (data) => { });
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
unsubscribePrice();
|
||||||
|
unsubscribeBattery();
|
||||||
|
unsubscribeAlert();
|
||||||
|
};
|
||||||
|
}, [subscribe]);
|
||||||
|
|
||||||
|
// Render stats cards, charts, tables
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Backtest.tsx
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
export default function Backtest() {
|
||||||
|
const [selectedBacktestId, setSelectedBacktestId] = useState<string | null>(null);
|
||||||
|
|
||||||
|
const { mutate: startBacktest, isPending } = useStartBacktest();
|
||||||
|
const { data: backtest } = useBacktest(selectedBacktestId || '');
|
||||||
|
const { data: backtestList } = useBacktestList();
|
||||||
|
|
||||||
|
const handleStartBacktest = (config: any) => {
|
||||||
|
startBacktest({ config, name: config.name }, {
|
||||||
|
onSuccess: (data) => {
|
||||||
|
setSelectedBacktestId(data.backtest_id);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
// Render form, results, progress
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Models.tsx
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
export default function Models() {
|
||||||
|
const [selectedModel, setSelectedModel] = useState<string | null>(null);
|
||||||
|
const [showTrainingForm, setShowTrainingForm] = useState(false);
|
||||||
|
|
||||||
|
const { data: modelsData } = useModels();
|
||||||
|
const { mutate: trainModel, isPending } = useTrainModel();
|
||||||
|
|
||||||
|
const handleTrainModel = (config: any) => {
|
||||||
|
trainModel(config, {
|
||||||
|
onSuccess: () => {
|
||||||
|
setShowTrainingForm(false);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
// Render model list, training form, model details
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Trading.tsx
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
export default function Trading() {
|
||||||
|
const { data: strategiesData } = useStrategies();
|
||||||
|
const { data: positionsData } = usePositions();
|
||||||
|
const { mutate: toggleStrategy } = useToggleStrategy();
|
||||||
|
|
||||||
|
const handleToggleStrategy = (strategyName: string, running: boolean) => {
|
||||||
|
toggleStrategy({
|
||||||
|
strategy: strategyName as any,
|
||||||
|
action: running ? 'stop' : 'start',
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
// Render strategy status, positions, controls
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Dependencies
|
||||||
|
|
||||||
|
### package.json
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"name": "energy-trading-ui",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"type": "module",
|
||||||
|
"scripts": {
|
||||||
|
"dev": "vite",
|
||||||
|
"build": "tsc && vite build",
|
||||||
|
"preview": "vite preview",
|
||||||
|
"test": "vitest",
|
||||||
|
"type-check": "tsc --noEmit"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"react": "^18.2.0",
|
||||||
|
"react-dom": "^18.2.0",
|
||||||
|
"react-router-dom": "^6.20.0",
|
||||||
|
"recharts": "^2.10.0",
|
||||||
|
"zustand": "^4.4.0",
|
||||||
|
"@tanstack/react-query": "^5.0.0",
|
||||||
|
"axios": "^1.6.0",
|
||||||
|
"date-fns": "^2.30.0",
|
||||||
|
"lucide-react": "^0.292.0"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@types/react": "^18.2.37",
|
||||||
|
"@types/react-dom": "^18.2.15",
|
||||||
|
"@vitejs/plugin-react": "^4.2.0",
|
||||||
|
"typescript": "^5.2.2",
|
||||||
|
"vite": "^5.0.0",
|
||||||
|
"vitest": "^1.0.0",
|
||||||
|
"tailwindcss": "^3.3.5"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Data Flow
|
||||||
|
|
||||||
|
### API Calls
|
||||||
|
|
||||||
|
```
|
||||||
|
Component → useApi Hook → API Client → Axios → Backend
|
||||||
|
↓
|
||||||
|
React Query Cache
|
||||||
|
↓
|
||||||
|
Automatic Refetch
|
||||||
|
```
|
||||||
|
|
||||||
|
### WebSocket Updates
|
||||||
|
|
||||||
|
```
|
||||||
|
WebSocket Event → WebSocketService → Event Handler → Store Update → Component Re-render
|
||||||
|
```
|
||||||
|
|
||||||
|
### State Synchronization
|
||||||
|
|
||||||
|
```
|
||||||
|
API Response → React Query Cache → Component Props → Zustand Store (optional)
|
||||||
|
WebSocket Event → Zustand Store → Component Re-render
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Key Integration Points
|
||||||
|
|
||||||
|
### Backend API Integration
|
||||||
|
|
||||||
|
- REST API endpoints are fully typed in `services/types.ts`
|
||||||
|
- API client methods in `services/api.ts` match backend routes
|
||||||
|
- WebSocket events match backend event types
|
||||||
|
|
||||||
|
### Error Handling
|
||||||
|
|
||||||
|
- API errors handled by React Query error states
|
||||||
|
- WebSocket reconnection automatic with exponential backoff
|
||||||
|
- Component-level error boundaries for unhandled errors
|
||||||
|
|
||||||
|
### Performance
|
||||||
|
|
||||||
|
- React Query caching with automatic refetch intervals
|
||||||
|
- WebSocket connection pooling
|
||||||
|
- Component memoization where appropriate
|
||||||
|
- Virtual scrolling for large data tables
|
||||||
679
ML_IMPLEMENTATION.md
Normal file
679
ML_IMPLEMENTATION.md
Normal file
@@ -0,0 +1,679 @@
|
|||||||
|
# ML Implementation Strategy
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This document outlines the machine learning components for the energy trading system, including gradient boosting price prediction models and Q-Learning reinforcement learning for battery optimization.
|
||||||
|
|
||||||
|
**Package Location**: `backend/app/ml/` (integrated within the FastAPI backend)
|
||||||
|
|
||||||
|
**Related Documents**:
|
||||||
|
- `BACKEND_IMPLEMENTATION.md` - API endpoints, services, and ML integration layer
|
||||||
|
- `FRONTEND_IMPLEMENTATION.md` - React frontend that consumes backend API
|
||||||
|
|
||||||
|
**Training Split Strategy**: Time-based split (first 7 days train, next 1.5 days validation, last 1.5 days test) to prevent look-ahead bias and ensure realistic evaluation.
|
||||||
|
|
||||||
|
**Data Source**: `~/energy-test-data/data/processed/*.parquet`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────────────────────────────────────────────────────┐
|
||||||
|
│ ML Pipeline │
|
||||||
|
├──────────────────────────────────────────────────────────────┤
|
||||||
|
│ │
|
||||||
|
│ ┌──────────────┬──────────────────────┬──────────────────┐ │
|
||||||
|
│ │ Feature │ │ │ │
|
||||||
|
│ │ Engineering │ Model Training │ Model Serving │ │
|
||||||
|
│ │ │ │ │ │
|
||||||
|
│ │ - Lags │ - XGBoost │ - Load Model │ │
|
||||||
|
│ │ - Rolling │ - Q-Learning │ - Predict │ │
|
||||||
|
│ │ - Time │ - Validation │ - Return Action │ │
|
||||||
|
│ │ - Regions │ - Evaluation │ │ │
|
||||||
|
│ └──────────────┴──────────────────────┴──────────────────┘ │
|
||||||
|
│ │ │
|
||||||
|
│ ┌──────────────────────┴──────────────────────────────┐ │
|
||||||
|
│ │ Model Management │ │
|
||||||
|
│ │ - Versioning - Persistence - Registry - Backup │ │
|
||||||
|
│ └──────────────────────────────────────────────────────┘ │
|
||||||
|
└──────────────────────────────────────────────────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌──────────────────────────────────────────────────────────────┐
|
||||||
|
│ FastAPI Backend (app.services.ml_service) │
|
||||||
|
│ /api/v1/models/train, /api/v1/models/predict │
|
||||||
|
└──────────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
backend/app/ml/
|
||||||
|
├── __init__.py
|
||||||
|
│
|
||||||
|
├── features/
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── lag_features.py # Price lag feature extraction
|
||||||
|
│ ├── rolling_features.py # Rolling statistics
|
||||||
|
│ ├── time_features.py # Time-of-day encoding
|
||||||
|
│ ├── regional_features.py # Cross-region differentials
|
||||||
|
│ └── battery_features.py # Battery state features
|
||||||
|
│
|
||||||
|
├── price_prediction/
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── model.py # XGBoost model wrapper
|
||||||
|
│ ├── trainer.py # Training pipeline
|
||||||
|
│ ├── evaluator.py # Evaluation metrics
|
||||||
|
│ └── predictor.py # Prediction interface
|
||||||
|
│
|
||||||
|
├── rl_battery/
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── environment.py # Battery MDP environment
|
||||||
|
│ ├── agent.py # Q-Learning agent
|
||||||
|
│ ├── trainer.py # RL training loop
|
||||||
|
│ └── policy.py # Policy inference
|
||||||
|
│
|
||||||
|
├── model_management/
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── registry.py # Model registry
|
||||||
|
│ ├── persistence.py # Save/load models
|
||||||
|
│ ├── versioning.py # Version handling
|
||||||
|
│ └── comparison.py # Model comparison
|
||||||
|
│
|
||||||
|
├── evaluation/
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── metrics.py # Common evaluation metrics
|
||||||
|
│ ├── backtest_evaluator.py # Backtest performance evaluation
|
||||||
|
│ └── reports.py # Generate evaluation reports
|
||||||
|
│
|
||||||
|
├── training/
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── cli.py # CLI commands for retraining
|
||||||
|
│ ├── pipeline.py # End-to-end training pipeline
|
||||||
|
│ └── scheduler.py # Training job scheduler
|
||||||
|
│
|
||||||
|
└── utils/
|
||||||
|
├── __init__.py
|
||||||
|
├── data_split.py # Time-based data splitting
|
||||||
|
├── preprocessing.py # Data preprocessing
|
||||||
|
├── config.py # ML configuration
|
||||||
|
└── evaluation.py # Evaluation metrics
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
### backend/app/ml/utils/config.py
|
||||||
|
|
||||||
|
```python
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PricePredictionConfig:
|
||||||
|
"""Configuration for price prediction models."""
|
||||||
|
|
||||||
|
# Data
|
||||||
|
data_path: str = "~/energy-test-data/data/processed"
|
||||||
|
target_column: str = "real_time_price"
|
||||||
|
|
||||||
|
# Training split (time-based)
|
||||||
|
train_end_pct: float = 0.70 # First 70% for training
|
||||||
|
val_end_pct: float = 0.85 # Next 15% for validation
|
||||||
|
# Last 15% for testing
|
||||||
|
|
||||||
|
# Features
|
||||||
|
price_lags: List[int] = None
|
||||||
|
rolling_windows: List[int] = None
|
||||||
|
include_time_features: bool = True
|
||||||
|
include_regional_features: bool = True
|
||||||
|
|
||||||
|
# Model
|
||||||
|
n_estimators: int = 200
|
||||||
|
max_depth: int = 6
|
||||||
|
learning_rate: float = 0.1
|
||||||
|
subsample: float = 0.8
|
||||||
|
colsample_bytree: float = 0.8
|
||||||
|
random_state: int = 42
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
early_stopping_rounds: int = 20
|
||||||
|
early_stopping_threshold: float = 0.001
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.price_lags is None:
|
||||||
|
self.price_lags = [1, 5, 10, 15, 30, 60]
|
||||||
|
if self.rolling_windows is None:
|
||||||
|
self.rolling_windows = [5, 10, 15, 30, 60]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RLBatteryConfig:
|
||||||
|
"""Configuration for RL battery optimization."""
|
||||||
|
|
||||||
|
# State space
|
||||||
|
charge_level_bins: int = 10
|
||||||
|
price_bins: int = 10
|
||||||
|
time_bins: int = 24 # Hours
|
||||||
|
|
||||||
|
# Action space
|
||||||
|
actions: List[str] = None
|
||||||
|
|
||||||
|
# Q-Learning
|
||||||
|
learning_rate: float = 0.1
|
||||||
|
discount_factor: float = 0.95
|
||||||
|
epsilon: float = 1.0
|
||||||
|
epsilon_decay: float = 0.995
|
||||||
|
epsilon_min: float = 0.05
|
||||||
|
|
||||||
|
# Training
|
||||||
|
episodes: int = 1000
|
||||||
|
max_steps: int = 14400 # 10 days * 1440 minutes
|
||||||
|
|
||||||
|
# Battery constraints
|
||||||
|
min_reserve: float = 0.10 # 10%
|
||||||
|
max_charge: float = 0.90 # 90%
|
||||||
|
efficiency: float = 0.90
|
||||||
|
|
||||||
|
# Reward scaling
|
||||||
|
reward_scale: float = 1.0
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.actions is None:
|
||||||
|
self.actions = ["charge", "hold", "discharge"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MLConfig:
|
||||||
|
"""Overall ML configuration."""
|
||||||
|
|
||||||
|
# Paths
|
||||||
|
models_path: str = "models"
|
||||||
|
results_path: str = "results"
|
||||||
|
|
||||||
|
# Price prediction
|
||||||
|
price_prediction: PricePredictionConfig = None
|
||||||
|
rl_battery: RLBatteryConfig = None
|
||||||
|
|
||||||
|
# Training
|
||||||
|
enable_gpu: bool = False
|
||||||
|
n_jobs: int = 4
|
||||||
|
verbose: bool = True
|
||||||
|
|
||||||
|
# Retraining
|
||||||
|
keep_backup: bool = True
|
||||||
|
max_backups: int = 5
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.price_prediction is None:
|
||||||
|
self.price_prediction = PricePredictionConfig()
|
||||||
|
if self.rl_battery is None:
|
||||||
|
self.rl_battery = RLBatteryConfig()
|
||||||
|
|
||||||
|
|
||||||
|
# Default configuration
|
||||||
|
default_config = MLConfig()
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Feature Engineering Interface
|
||||||
|
|
||||||
|
### Key Functions (backend/app/ml/features/__init__.py)
|
||||||
|
|
||||||
|
```python
|
||||||
|
def build_price_features(
|
||||||
|
df: pd.DataFrame,
|
||||||
|
price_col: str = "real_time_price",
|
||||||
|
lags: List[int] = None,
|
||||||
|
windows: List[int] = None,
|
||||||
|
regions: List[str] = None,
|
||||||
|
include_time: bool = True,
|
||||||
|
include_regional: bool = True,
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Build complete feature set for price prediction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: Input DataFrame
|
||||||
|
price_col: Name of price column
|
||||||
|
lags: List of lag periods
|
||||||
|
windows: List of rolling window sizes
|
||||||
|
regions: List of regions for differential features
|
||||||
|
include_time: Whether to include time features
|
||||||
|
include_regional: Whether to include regional features
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame with all features
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def build_battery_features(
|
||||||
|
df: pd.DataFrame,
|
||||||
|
price_df: pd.DataFrame,
|
||||||
|
battery_col: str = "charge_level_mwh",
|
||||||
|
capacity_col: str = "capacity_mwh",
|
||||||
|
timestamp_col: str = "timestamp",
|
||||||
|
battery_id_col: str = "battery_id",
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Build features for battery RL model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: Battery DataFrame
|
||||||
|
price_df: Price DataFrame
|
||||||
|
battery_col: Name of battery charge level column
|
||||||
|
capacity_col: Name of battery capacity column
|
||||||
|
timestamp_col: Name of timestamp column
|
||||||
|
battery_id_col: Name of battery ID column
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame with battery features
|
||||||
|
"""
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Price Prediction Interface
|
||||||
|
|
||||||
|
### PricePredictor (backend/app/ml/price_prediction/predictor.py)
|
||||||
|
|
||||||
|
```python
|
||||||
|
class PricePredictor:
|
||||||
|
"""Interface for making price predictions."""
|
||||||
|
|
||||||
|
def __init__(self, models_dir: str = "models/price_prediction"):
|
||||||
|
"""
|
||||||
|
Initialize predictor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
models_dir: Directory containing trained models
|
||||||
|
"""
|
||||||
|
|
||||||
|
def predict(
|
||||||
|
self,
|
||||||
|
current_data: pd.DataFrame,
|
||||||
|
horizon: int = 15,
|
||||||
|
region: Optional[str] = None,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Predict price for a specific horizon.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_data: Current/historical price data
|
||||||
|
horizon: Prediction horizon in minutes
|
||||||
|
region: Specific region to predict for (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Predicted price
|
||||||
|
"""
|
||||||
|
|
||||||
|
def predict_all_horizons(
|
||||||
|
self,
|
||||||
|
current_data: pd.DataFrame,
|
||||||
|
region: Optional[str] = None,
|
||||||
|
) -> Dict[int, float]:
|
||||||
|
"""
|
||||||
|
Predict prices for all available horizons.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping horizons to predictions
|
||||||
|
"""
|
||||||
|
|
||||||
|
def predict_with_confidence(
|
||||||
|
self,
|
||||||
|
current_data: pd.DataFrame,
|
||||||
|
horizon: int = 15,
|
||||||
|
region: Optional[str] = None,
|
||||||
|
) -> Dict:
|
||||||
|
"""
|
||||||
|
Predict price with confidence interval.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with prediction and confidence interval
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_feature_importance(self, horizon: int) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Get feature importance for a specific horizon.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame with feature importance
|
||||||
|
"""
|
||||||
|
```
|
||||||
|
|
||||||
|
### PricePredictionTrainer (backend/app/ml/price_prediction/trainer.py)
|
||||||
|
|
||||||
|
```python
|
||||||
|
class PricePredictionTrainer:
|
||||||
|
"""Training pipeline for price prediction models."""
|
||||||
|
|
||||||
|
def __init__(self, config: PricePredictionConfig = None):
|
||||||
|
"""Initialize trainer."""
|
||||||
|
|
||||||
|
def load_data(self) -> pd.DataFrame:
|
||||||
|
"""Load price data."""
|
||||||
|
|
||||||
|
def prepare_data(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
|
||||||
|
"""
|
||||||
|
Prepare data with features.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (features DataFrame, feature names)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def train_for_horizon(
|
||||||
|
self,
|
||||||
|
df_features: pd.DataFrame,
|
||||||
|
feature_cols: List[str],
|
||||||
|
horizon: int,
|
||||||
|
) -> Dict:
|
||||||
|
"""
|
||||||
|
Train model for a specific horizon.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Training results dictionary with metrics
|
||||||
|
"""
|
||||||
|
|
||||||
|
def train_all(self, horizons: List[int] = None) -> Dict:
|
||||||
|
"""
|
||||||
|
Train models for all horizons.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with all training results
|
||||||
|
"""
|
||||||
|
|
||||||
|
def save_models(self, output_dir: str = "models/price_prediction") -> None:
|
||||||
|
"""Save all trained models."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_models(
|
||||||
|
cls,
|
||||||
|
models_dir: str = "models/price_prediction",
|
||||||
|
horizons: List[int] = None,
|
||||||
|
) -> Dict[int, PricePredictionModel]:
|
||||||
|
"""
|
||||||
|
Load trained models.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping horizons to models
|
||||||
|
"""
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## RL Battery Optimization Interface
|
||||||
|
|
||||||
|
### BatteryPolicy (backend/app/ml/rl_battery/policy.py)
|
||||||
|
|
||||||
|
```python
|
||||||
|
class BatteryPolicy:
|
||||||
|
"""Interface for RL battery policy inference."""
|
||||||
|
|
||||||
|
def __init__(self, policy_path: str = "models/rl_battery"):
|
||||||
|
"""
|
||||||
|
Initialize policy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
policy_path: Path to trained policy
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_action(
|
||||||
|
self,
|
||||||
|
charge_level: float,
|
||||||
|
current_price: float,
|
||||||
|
price_forecast_1m: float = 0,
|
||||||
|
price_forecast_5m: float = 0,
|
||||||
|
price_forecast_15m: float = 0,
|
||||||
|
hour: int = 0,
|
||||||
|
) -> Dict:
|
||||||
|
"""
|
||||||
|
Get action for current state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with action, q_values
|
||||||
|
"""
|
||||||
|
```
|
||||||
|
|
||||||
|
### BatteryRLTrainer (backend/app/ml/rl_battery/trainer.py)
|
||||||
|
|
||||||
|
```python
|
||||||
|
class BatteryRLTrainer:
|
||||||
|
"""Training pipeline for RL battery policy."""
|
||||||
|
|
||||||
|
def __init__(self, config: RLBatteryConfig = None):
|
||||||
|
"""Initialize trainer."""
|
||||||
|
|
||||||
|
def load_data(self) -> None:
|
||||||
|
"""Load price data for environment."""
|
||||||
|
|
||||||
|
def train(self, n_episodes: int = 1000, region: str = "FR") -> Dict:
|
||||||
|
"""
|
||||||
|
Train RL agent.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Training results with metrics
|
||||||
|
"""
|
||||||
|
|
||||||
|
def save(self, output_dir: str = "models/rl_battery") -> None:
|
||||||
|
"""Save trained policy."""
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Model Management Interface
|
||||||
|
|
||||||
|
### ModelRegistry (backend/app/ml/model_management/registry.py)
|
||||||
|
|
||||||
|
```python
|
||||||
|
class ModelRegistry:
|
||||||
|
"""Registry for tracking model versions."""
|
||||||
|
|
||||||
|
def __init__(self, registry_path: str = "models/registry.json"):
|
||||||
|
"""Initialize registry."""
|
||||||
|
|
||||||
|
def register_model(
|
||||||
|
self,
|
||||||
|
model_type: str,
|
||||||
|
model_id: str,
|
||||||
|
version: str,
|
||||||
|
filepath: str,
|
||||||
|
metadata: Dict = None,
|
||||||
|
) -> None:
|
||||||
|
"""Register a model version."""
|
||||||
|
|
||||||
|
def get_latest_version(self, model_id: str) -> Optional[Dict]:
|
||||||
|
"""Get latest version of a model."""
|
||||||
|
|
||||||
|
def list_models(self) -> List[Dict]:
|
||||||
|
"""List all registered models."""
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Training Task Interface
|
||||||
|
|
||||||
|
### Training Jobs (app/tasks/training_tasks.py)
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def train_model_task(training_id: str, request: TrainingRequest):
|
||||||
|
"""
|
||||||
|
Execute ML model training via Celery task.
|
||||||
|
|
||||||
|
Dispatches to Celery for async processing of:
|
||||||
|
- Price prediction training
|
||||||
|
- RL battery policy training
|
||||||
|
|
||||||
|
Emits WebSocket events for progress updates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@shared_task(name="tasks.train_price_prediction")
|
||||||
|
def train_price_prediction(training_id: str, request_dict: dict):
|
||||||
|
"""
|
||||||
|
Celery task for price prediction model training.
|
||||||
|
|
||||||
|
Process:
|
||||||
|
1. Load and prepare data
|
||||||
|
2. Train XGBoost models for specified horizon
|
||||||
|
3. Save models
|
||||||
|
4. Register in model registry
|
||||||
|
5. Update training job status
|
||||||
|
"""
|
||||||
|
|
||||||
|
@shared_task(name="tasks.train_rl_battery")
|
||||||
|
def train_rl_battery(training_id: str, request_dict: dict):
|
||||||
|
"""
|
||||||
|
Celery task for RL battery policy training.
|
||||||
|
|
||||||
|
Process:
|
||||||
|
1. Load environment and data
|
||||||
|
2. Train Q-Learning agent
|
||||||
|
3. Save policy
|
||||||
|
4. Register in model registry
|
||||||
|
5. Update training job status
|
||||||
|
"""
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Key Integration Points
|
||||||
|
|
||||||
|
### ML Service Integration (app/services/ml_service.py)
|
||||||
|
|
||||||
|
The ML service provides the bridge between API routes and ML models:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class MLService:
|
||||||
|
"""Service for ML model management and inference."""
|
||||||
|
|
||||||
|
def list_models(self) -> List[ModelInfo]:
|
||||||
|
"""List all available trained models."""
|
||||||
|
|
||||||
|
def get_model_metrics(self, model_id: str) -> Dict[str, float]:
|
||||||
|
"""Get performance metrics for a model."""
|
||||||
|
|
||||||
|
def predict(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
timestamp: datetime,
|
||||||
|
features: Optional[Dict[str, Any]] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Run prediction with on-demand model loading.
|
||||||
|
|
||||||
|
Supports:
|
||||||
|
- Price prediction models
|
||||||
|
- RL battery policy
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_feature_importance(self, model_id: str) -> Dict[str, float]:
|
||||||
|
"""Get feature importance for a model."""
|
||||||
|
```
|
||||||
|
|
||||||
|
### WebSocket Events
|
||||||
|
|
||||||
|
Real-time training progress updates:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Event types
|
||||||
|
- "model_training_progress" # Training progress updates
|
||||||
|
|
||||||
|
# Payload
|
||||||
|
{
|
||||||
|
"model_id": str,
|
||||||
|
"progress": float, # 0.0 to 1.0
|
||||||
|
"epoch": int, # Current epoch (optional)
|
||||||
|
"metrics": dict # Current metrics
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Data Flow
|
||||||
|
|
||||||
|
### Training Pipeline
|
||||||
|
|
||||||
|
```
|
||||||
|
API Request (POST /api/v1/models/train)
|
||||||
|
↓
|
||||||
|
training_tasks.train_model_task()
|
||||||
|
↓
|
||||||
|
Celery Task (train_price_prediction / train_rl_battery)
|
||||||
|
↓
|
||||||
|
PricePredictionTrainer / BatteryRLTrainer
|
||||||
|
↓
|
||||||
|
Feature Engineering → Model Training → Evaluation
|
||||||
|
↓
|
||||||
|
Save Models → Register in Registry
|
||||||
|
↓
|
||||||
|
WebSocket Events (progress updates)
|
||||||
|
↓
|
||||||
|
Update Training Status
|
||||||
|
```
|
||||||
|
|
||||||
|
### Prediction Pipeline
|
||||||
|
|
||||||
|
```
|
||||||
|
API Request (POST /api/v1/models/predict)
|
||||||
|
↓
|
||||||
|
ml_service.predict()
|
||||||
|
↓
|
||||||
|
Load Model (on-demand)
|
||||||
|
↓
|
||||||
|
Feature Engineering
|
||||||
|
↓
|
||||||
|
Model.predict() / Policy.get_action()
|
||||||
|
↓
|
||||||
|
Return Prediction with Confidence
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Model Artifacts
|
||||||
|
|
||||||
|
### Price Prediction Models
|
||||||
|
|
||||||
|
Location: `models/price_prediction/`
|
||||||
|
|
||||||
|
```
|
||||||
|
model_1min.pkl # 1-minute horizon model
|
||||||
|
model_5min.pkl # 5-minute horizon model
|
||||||
|
model_15min.pkl # 15-minute horizon model
|
||||||
|
model_60min.pkl # 60-minute horizon model
|
||||||
|
training_results.json # Training metrics and metadata
|
||||||
|
```
|
||||||
|
|
||||||
|
### RL Battery Policy
|
||||||
|
|
||||||
|
Location: `models/rl_battery/`
|
||||||
|
|
||||||
|
```
|
||||||
|
battery_policy.pkl # Trained Q-Learning policy
|
||||||
|
training_results.json # Training metrics
|
||||||
|
```
|
||||||
|
|
||||||
|
### Model Registry
|
||||||
|
|
||||||
|
Location: `models/registry.json`
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"models": {
|
||||||
|
"price_prediction_15m": {
|
||||||
|
"type": "price_prediction",
|
||||||
|
"versions": ["v20260211_134500", "v20260210_100000"],
|
||||||
|
"latest": "v20260211_134500"
|
||||||
|
},
|
||||||
|
"battery_policy": {
|
||||||
|
"type": "rl_battery",
|
||||||
|
"versions": ["v20260211_140000"],
|
||||||
|
"latest": "v20260211_140000"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
Reference in New Issue
Block a user