From a22a13f6f41e90a1ef3f0396adb1ee5ce377e620 Mon Sep 17 00:00:00 2001 From: kbt-devops Date: Wed, 11 Feb 2026 02:16:25 +0700 Subject: [PATCH] Add initial implementation strategy documentation Add comprehensive documentation for energy trading system: - Backend: FastAPI architecture, API routes, services, WebSocket - Frontend: React structure, components, state management - ML: Feature engineering, XGBoost price prediction, RL battery optimization --- BACKEND_IMPLEMENTATION.md | 750 +++++++++++++++++++++++++++++++++++ FRONTEND_IMPLEMENTATION.md | 780 +++++++++++++++++++++++++++++++++++++ ML_IMPLEMENTATION.md | 679 ++++++++++++++++++++++++++++++++ 3 files changed, 2209 insertions(+) create mode 100644 BACKEND_IMPLEMENTATION.md create mode 100644 FRONTEND_IMPLEMENTATION.md create mode 100644 ML_IMPLEMENTATION.md diff --git a/BACKEND_IMPLEMENTATION.md b/BACKEND_IMPLEMENTATION.md new file mode 100644 index 0000000..84aa587 --- /dev/null +++ b/BACKEND_IMPLEMENTATION.md @@ -0,0 +1,750 @@ +# Backend Implementation Strategy + +## Overview + +This document outlines the FastAPI backend for the energy trading system UI. The backend serves data, executes strategies, runs backtests, and provides real-time updates via WebSockets. + +**Data Source**: `~/energy-test-data/data/processed/` + +--- + +## Architecture + +``` +┌──────────────────────────────────────────────────────────────┐ +│ FastAPI Application │ +├──────────────────────────────────────────────────────────────┤ +│ │ +│ ┌─────────────┬─────────────┬─────────────┬──────────────┐ │ +│ │ API │ Services │ Tasks │ WebSocket │ │ +│ │ Routes │ Layer │ (Celery) │ Manager │ │ +│ └─────────────┴─────────────┴─────────────┴──────────────┘ │ +│ ┌──────────┐ │ +│ │ Data │ │ +│ │ Cache │ │ +│ └──────────┘ │ +└──────────────────────────────────────────────────────────────┘ + │ + ▼ +┌──────────────────────────────────────────────────────────────┐ +│ Core Trading Engine (Imported) │ +│ - Fundamental Strategy │ +│ - Technical Analysis │ +│ - ML Models (Price Prediction, RL Battery) │ +│ - Backtesting Engine │ +└──────────────────────────────────────────────────────────────┘ + │ + ▼ +┌──────────────────────────────────────────────────────────────┐ +│ Data Source │ +│ ~/energy-test-data/data/processed/*.parquet │ +└──────────────────────────────────────────────────────────────┘ +``` + +--- + +## Project Structure + +``` +backend/ +├── app/ +│ ├── __init__.py +│ ├── main.py # FastAPI app entry +│ ├── config.py # Configuration management +│ │ +│ ├── api/ +│ │ ├── __init__.py +│ │ ├── routes/ +│ │ │ ├── __init__.py +│ │ │ ├── dashboard.py # Dashboard data endpoints +│ │ │ ├── backtest.py # Backtest execution +│ │ │ ├── models.py # ML model endpoints +│ │ │ ├── trading.py # Trading control +│ │ │ └── settings.py # Configuration management +│ │ └── websocket.py # WebSocket connection manager +│ │ +│ ├── services/ +│ │ ├── __init__.py +│ │ ├── data_service.py # Data loading and caching +│ │ ├── strategy_service.py # Strategy execution +│ │ ├── ml_service.py # ML model management +│ │ ├── trading_service.py # Trading operations +│ │ └── alert_service.py # Alert management +│ │ +│ ├── tasks/ +│ │ ├── __init__.py +│ │ ├── backtest_tasks.py # Async backtest execution +│ │ ├── training_tasks.py # ML model training +│ │ └── monitoring_tasks.py # Real-time data updates +│ │ +│ ├── ml/ # ML models and training +│ │ ├── __init__.py +│ │ ├── features/ +│ │ │ ├── __init__.py +│ │ │ ├── lag_features.py +│ │ │ ├── rolling_features.py +│ │ │ ├── time_features.py +│ │ │ ├── regional_features.py +│ │ │ └── battery_features.py +│ │ │ +│ │ ├── price_prediction/ +│ │ │ ├── __init__.py +│ │ │ ├── model.py +│ │ │ ├── trainer.py +│ │ │ └── predictor.py +│ │ │ +│ │ ├── rl_battery/ +│ │ │ ├── __init__.py +│ │ │ ├── environment.py +│ │ │ ├── agent.py +│ │ │ ├── trainer.py +│ │ │ └── policy.py +│ │ │ +│ │ ├── model_management/ +│ │ │ ├── __init__.py +│ │ │ ├── registry.py +│ │ │ ├── persistence.py +│ │ │ ├── versioning.py +│ │ │ └── comparison.py +│ │ │ +│ │ ├── evaluation/ +│ │ │ ├── __init__.py +│ │ │ ├── metrics.py +│ │ │ ├── backtest_evaluator.py +│ │ │ └── reports.py +│ │ │ +│ │ ├── training/ +│ │ │ ├── __init__.py +│ │ │ └── cli.py +│ │ │ +│ │ └── utils/ +│ │ ├── __init__.py +│ │ ├── data_split.py +│ │ ├── config.py +│ │ └── evaluation.py +│ │ +│ ├── models/ +│ │ ├── __init__.py +│ │ ├── schemas.py # Pydantic models +│ │ └── enums.py # Enumerations +│ │ +│ ├── core/ +│ │ ├── __init__.py +│ │ └── constants.py # Constants and defaults +│ │ +│ └── utils/ +│ ├── __init__.py +│ ├── logger.py +│ └── helpers.py +│ +├── tests/ +│ ├── __init__.py +│ ├── conftest.py +│ ├── test_api/ +│ ├── test_services/ +│ └── test_websocket.py +│ +├── models/ # Trained ML models storage +│ ├── price_prediction/ +│ │ ├── model_1min.pkl +│ │ ├── model_5min.pkl +│ │ ├── model_15min.pkl +│ │ └── model_60min.pkl +│ └── rl_battery/ +│ └── battery_policy.pkl +│ +├── results/ # Backtest results storage +│ └── backtests/ +│ +├── .env.example +├── requirements.txt +├── pyproject.toml +└── Dockerfile +``` + +--- + +## Configuration + +### app/config.py (Settings) + +```python +from pydantic_settings import BaseSettings +from pathlib import Path +from typing import List + +class Settings(BaseSettings): + # Application + APP_NAME: str = "Energy Trading API" + APP_VERSION: str = "1.0.0" + DEBUG: bool = True + + # Server + HOST: str = "0.0.0.0" + PORT: int = 8000 + + # Data + DATA_PATH: str = "~/energy-test-data/data/processed" + DATA_PATH_RESOLVED: Path = Path(DATA_PATH).expanduser() + + # CORS + CORS_ORIGINS: List[str] = [ + "http://localhost:3000", + "http://localhost:5173", + ] + + # WebSocket + WS_HEARTBEAT_INTERVAL: int = 30 + + # Celery + CELERY_BROKER_URL: str = "redis://localhost:6379/0" + CELERY_RESULT_BACKEND: str = "redis://localhost:6379/0" + + # Models + MODELS_PATH: str = "models" + RESULTS_PATH: str = "results" + + # Battery + BATTERY_MIN_RESERVE: float = 0.10 + BATTERY_MAX_CHARGE: float = 0.90 + + # Arbitrage + ARBITRAGE_MIN_SPREAD: float = 5.0 # EUR/MWh + + # Mining + MINING_MARGIN_THRESHOLD: float = 5.0 # EUR/MWh + + # ML + ML_PREDICTION_HORIZONS: List[int] = [1, 5, 15, 60] + ML_FEATURE_LAGS: List[int] = [1, 5, 10, 15, 30, 60] + + class Config: + env_file = ".env" + case_sensitive = True + +settings = Settings() +``` + +--- + +## Data Models (app/models/schemas.py) + +### Enums + +```python +class RegionEnum(str, Enum): + FR = "FR" + BE = "BE" + DE = "DE" + NL = "NL" + UK = "UK" + +class FuelTypeEnum(str, Enum): + GAS = "gas" + NUCLEAR = "nuclear" + COAL = "coal" + SOLAR = "solar" + WIND = "wind" + HYDRO = "hydro" + +class StrategyEnum(str, Enum): + FUNDAMENTAL = "fundamental" + TECHNICAL = "technical" + ML = "ml" + MINING = "mining" + +class TradeTypeEnum(str, Enum): + BUY = "buy" + SELL = "sell" + CHARGE = "charge" + DISCHARGE = "discharge" + +class BacktestStatusEnum(str, Enum): + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + +class ModelType(str, Enum): + PRICE_PREDICTION = "price_prediction" + RL_BATTERY = "rl_battery" + +class AlertTypeEnum(str, Enum): + PRICE_SPIKE = "price_spike" + ARBITRAGE_OPPORTUNITY = "arbitrage_opportunity" + BATTERY_LOW = "battery_low" + BATTERY_FULL = "battery_full" + STRATEGY_ERROR = "strategy_error" +``` + +### Key Schemas + +```python +class PriceData(BaseModel): + timestamp: datetime + region: RegionEnum + day_ahead_price: float + real_time_price: float + volume_mw: float + +class BatteryState(BaseModel): + timestamp: datetime + battery_id: str + capacity_mwh: float + charge_level_mwh: float + charge_rate_mw: float + discharge_rate_mw: float + efficiency: float + charge_level_pct: float = Field(default_factory=lambda: 0.0) + +class BacktestConfig(BaseModel): + start_date: str + end_date: str + strategies: List[StrategyEnum] = Field(default_factory=list) + use_ml: bool = True + battery_min_reserve: Optional[float] = None + battery_max_charge: Optional[float] = None + arbitrage_min_spread: Optional[float] = None + +class BacktestMetrics(BaseModel): + total_revenue: float + arbitrage_profit: float + battery_revenue: float + mining_profit: float + battery_utilization: float + price_capture_rate: float + win_rate: float + sharpe_ratio: float + max_drawdown: float + total_trades: int + +class TrainingRequest(BaseModel): + model_type: ModelType + horizon: Optional[int] = None + start_date: str + end_date: str + hyperparameters: Dict[str, Any] = Field(default_factory=dict) + +class PredictionResponse(BaseModel): + model_id: str + timestamp: datetime + prediction: float + confidence: Optional[float] = None + features_used: List[str] = Field(default_factory=list) +``` + +--- + +## API Routes + +### Dashboard API (`/api/v1/dashboard/*`) + +```python +# GET /api/v1/dashboard/summary +Response: DashboardSummary + +# GET /api/v1/dashboard/prices +Response: { regions: { [region]: { timestamp, day_ahead_price, real_time_price, volume_mw } } } + +# GET /api/v1/dashboard/prices/history?region={region}&start={start}&end={end}&limit={limit} +Response: { region, data: PriceData[] } + +# GET /api/v1/dashboard/battery +Response: { batteries: BatteryState[] } + +# GET /api/v1/dashboard/arbitrage?min_spread={min_spread} +Response: { opportunities: ArbitrageOpportunity[], count: int } +``` + +### Backtest API (`/api/v1/backtest/*`) + +```python +# POST /api/v1/backtest/start +Request: { config: BacktestConfig, name?: string } +Response: { backtest_id: string, status: BacktestStatus } + +# GET /api/v1/backtest/{backtest_id} +Response: { status: BacktestStatus, results?: BacktestResult } + +# GET /api/v1/backtest/{backtest_id}/results +Response: BacktestResult + +# GET /api/v1/backtest/{backtest_id}/trades?limit={limit} +Response: { backtest_id, trades: Trade[], total: int } + +# GET /api/v1/backtest/history +Response: { backtests: BacktestStatus[], total: int } + +# DELETE /api/v1/backtest/{backtest_id} +Response: { message: string } +``` + +### Models API (`/api/v1/models/*`) + +```python +# GET /api/v1/models +Response: { models: ModelInfo[], total: int } + +# POST /api/v1/models/train +Request: TrainingRequest +Response: { training_id: string, status: TrainingStatus } + +# GET /api/v1/models/{model_id}/status +Response: TrainingStatus + +# GET /api/v1/models/{model_id}/metrics +Response: { model_id, metrics: dict } + +# POST /api/v1/models/predict +Request: { model_id, timestamp, features?: dict } +Response: PredictionResponse +``` + +### Trading API (`/api/v1/trading/*`) + +```python +# GET /api/v1/trading/strategies +Response: { strategies: StrategyStatus[] } + +# POST /api/v1/trading/strategies +Request: { strategy: StrategyEnum, action: "start" | "stop" } +Response: { status: StrategyStatus } + +# GET /api/v1/trading/positions +Response: { positions: TradingPosition[] } +``` + +### Settings API (`/api/v1/settings/*`) + +```python +# GET /api/v1/settings +Response: AppSettings + +# POST /api/v1/settings +Request: Partial +Response: { message, updated_fields: string[] } +``` + +--- + +## Services Interface + +### DataService (app/services/data_service.py) + +```python +class DataService: + """Data loading and caching service.""" + + async def initialize(self): + """Load all datasets into memory.""" + + def get_latest_prices(self) -> Dict[str, Dict]: + """Get latest prices for all regions.""" + + def get_price_history(self, region, start=None, end=None, limit=1000) -> List[Dict]: + """Get price history for a region.""" + + def get_battery_states(self) -> List[Dict]: + """Get current battery states.""" + + def get_arbitrage_opportunities(self, min_spread=None) -> List[Dict]: + """Get current arbitrage opportunities.""" + + def get_dashboard_summary(self) -> Dict: + """Get overall dashboard summary.""" +``` + +### MLService (app/services/ml_service.py) + +```python +class MLService: + """Service for ML model management and inference.""" + + def list_models(self) -> List[ModelInfo]: + """List all available trained models.""" + + def get_model_metrics(self, model_id: str) -> Dict[str, float]: + """Get performance metrics for a model.""" + + def load_price_prediction_model(self, model_id: str): + """Load price prediction model on-demand.""" + + def load_rl_battery_policy(self, model_id: str): + """Load RL battery policy on-demand.""" + + def predict( + self, + model_id: str, + timestamp: datetime, + features: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """Run prediction with on-demand model loading.""" + + def predict_with_confidence( + self, + model_id: str, + timestamp: datetime, + features: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """Run prediction with confidence interval.""" + + def get_feature_importance(self, model_id: str) -> Dict[str, float]: + """Get feature importance for a model.""" + + def get_model_info(self, model_id: str) -> Optional[ModelInfo]: + """Get detailed info about a specific model.""" +``` + +### StrategyService (app/services/strategy_service.py) + +```python +class StrategyService: + """Strategy execution service.""" + + async def execute_strategy( + self, + strategy: StrategyEnum, + config: Dict = None + ) -> Dict: + """Execute a trading strategy.""" + + async def get_strategy_status(self, strategy: StrategyEnum) -> StrategyStatus: + """Get current status of a strategy.""" + + async def toggle_strategy( + self, + strategy: StrategyEnum, + action: str + ) -> StrategyStatus: + """Start or stop a strategy.""" +``` + +--- + +## Tasks Interface + +### Backtest Tasks (app/tasks/backtest_tasks.py) + +```python +async def run_backtest_task(backtest_id: str, config: Dict, name: str = None): + """ + Execute backtest in background. + + Process: + 1. Load data + 2. Execute strategies + 3. Calculate metrics + 4. Save results + 5. Emit WebSocket progress events + """ +``` + +### Training Tasks (app/tasks/training_tasks.py) + +```python +async def train_model_task(training_id: str, request: TrainingRequest): + """ + Execute ML model training via Celery task. + + Dispatches to Celery for async processing. + Emits WebSocket events for progress updates. + """ + +@shared_task(name="tasks.train_price_prediction") +def train_price_prediction(training_id: str, request_dict: dict): + """Celery task for price prediction model training.""" + +@shared_task(name="tasks.train_rl_battery") +def train_rl_battery(training_id: str, request_dict: dict): + """Celery task for RL battery policy training.""" +``` + +--- + +## WebSocket Interface + +### ConnectionManager (app/api/websocket.py) + +```python +class ConnectionManager: + """WebSocket connection manager.""" + + async def connect(self, websocket: WebSocket): + """Accept and track new connection.""" + + def disconnect(self, websocket: WebSocket): + """Remove connection.""" + + async def broadcast(self, event_type: str, data: Any): + """Broadcast event to all connected clients.""" + + # Specific event broadcasters + async def broadcast_price_update(self, region: str, price_data: Dict): + """Broadcast price update.""" + + async def broadcast_battery_update(self, battery_id: str, battery_state: Dict): + """Broadcast battery state update.""" + + async def broadcast_trade(self, trade: Dict): + """Broadcast new trade execution.""" + + async def broadcast_alert(self, alert: Dict): + """Broadcast new alert.""" + + async def broadcast_backtest_progress(self, backtest_id: str, progress: float, status: str): + """Broadcast backtest progress.""" + + async def broadcast_model_training_progress( + self, + model_id: str, + progress: float, + epoch: Optional[int] = None, + metrics: Optional[Dict] = None + ): + """Broadcast model training progress.""" +``` + +### WebSocket Events + +```python +# Event types +"price_update" # Real-time price changes +"battery_update" # Battery state changes +"arbitrage_opportunity" # New arbitrage opportunity +"trade_executed" # Trade execution +"alert_triggered" # Alert triggered +"backtest_progress" # Backtest progress +"model_training_progress" # Training progress +``` + +--- + +## Main Application + +### app/main.py + +```python +from fastapi import FastAPI, WebSocket +from fastapi.middleware.cors import CORSMiddleware +from contextlib import asynccontextmanager + +app = FastAPI( + title=settings.APP_NAME, + version=settings.APP_VERSION, + docs_url="/docs", + redoc_url="/redoc", +) + +# CORS middleware +app.add_middleware(CORSMiddleware, ...) + +# Include routers +app.include_router(dashboard.router, prefix="/api/v1/dashboard", tags=["dashboard"]) +app.include_router(backtest.router, prefix="/api/v1/backtest", tags=["backtest"]) +app.include_router(models.router, prefix="/api/v1/models", tags=["models"]) +app.include_router(trading.router, prefix="/api/v1/trading", tags=["trading"]) +app.include_router(settings_routes.router, prefix="/api/v1/settings", tags=["settings"]) + +# Health check +@app.get("/health") +async def health_check(): + return { "status": "healthy" } + +# WebSocket endpoint +@app.websocket("/ws/real-time") +async def websocket_endpoint(websocket: WebSocket): + await manager.connect(websocket) +``` + +--- + +## Dependencies + +### requirements.txt + +``` +# FastAPI & Server +fastapi>=0.104.0 +uvicorn[standard]>=0.24.0 +pydantic>=2.4.0 +pydantic-settings>=2.0.0 + +# Data Processing +pandas>=2.1.0 +numpy>=1.24.0 +pyarrow>=14.0.0 + +# Machine Learning +xgboost>=2.0.0 +scikit-learn>=1.3.0 + +# Reinforcement Learning +gymnasium>=0.29.0 +stable-baselines3>=2.0.0 + +# Background Tasks +celery>=5.3.0 +redis>=5.0.0 + +# WebSockets +websockets>=12.0.0 + +# Database +sqlalchemy>=2.0.0 +alembic>=1.12.0 + +# Utilities +python-multipart>=0.0.6 +python-jose[cryptography]>=3.3.0 +python-dotenv>=1.0.0 + +# Testing +pytest>=7.4.0 +pytest-asyncio>=0.21.0 +httpx>=0.25.0 + +# Logging +loguru>=0.7.0 +``` + +--- + +## Environment Variables + +### .env.example + +```bash +# Application +APP_NAME=Energy Trading API +APP_VERSION=1.0.0 +DEBUG=true + +# Server +HOST=0.0.0.0 +PORT=8000 + +# Data +DATA_PATH=~/energy-test-data/data/processed + +# CORS +CORS_ORIGINS=http://localhost:3000,http://localhost:5173 + +# Celery +CELERY_BROKER_URL=redis://localhost:6379/0 +CELERY_RESULT_BACKEND=redis://localhost:6379/0 + +# Paths +MODELS_PATH=models +RESULTS_PATH=results + +# Battery +BATTERY_MIN_RESERVE=0.10 +BATTERY_MAX_CHARGE=0.90 + +# Arbitrage +ARBITRAGE_MIN_SPREAD=5.0 + +# Mining +MINING_MARGIN_THRESHOLD=5.0 +``` diff --git a/FRONTEND_IMPLEMENTATION.md b/FRONTEND_IMPLEMENTATION.md new file mode 100644 index 0000000..7d0cee9 --- /dev/null +++ b/FRONTEND_IMPLEMENTATION.md @@ -0,0 +1,780 @@ +# Frontend Implementation Strategy + +## Overview + +This document outlines the React frontend for the energy trading system UI. The frontend provides real-time monitoring, backtesting tools, ML model insights, and trading controls. + +**Backend API**: `http://localhost:8000` + +--- + +## Architecture + +``` +┌──────────────────────────────────────────────────────────────┐ +│ React Application │ +├──────────────────────────────────────────────────────────────┤ +│ │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ Pages Layer │ │ +│ │ Dashboard │ Backtest │ Models │ Trading │ Settings │ │ +│ └──────────────────────────────────────────────────────┘ │ +│ │ │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ Components Layer │ │ +│ │ Charts │ Forms │ Alerts │ Tables │ Controls │ │ +│ └──────────────────────────────────────────────────────┘ │ +│ │ │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ Services Layer │ │ +│ │ API Client │ WebSocket │ State Management │ │ +│ └──────────────────────────────────────────────────────┘ │ +└──────────────────────────────────────────────────────────────┘ + │ + ▼ +┌──────────────────────────────────────────────────────────────┐ +│ FastAPI Backend │ +│ REST API + WebSocket │ +└──────────────────────────────────────────────────────────────┘ +``` + +--- + +## Project Structure + +``` +frontend/ +├── public/ +│ ├── favicon.ico +│ └── index.html +│ +├── src/ +│ ├── App.tsx # Main app component +│ ├── main.tsx # Entry point +│ ├── index.css # Global styles +│ │ +│ ├── components/ +│ │ ├── common/ +│ │ │ ├── Header.tsx +│ │ │ ├── Sidebar.tsx +│ │ │ ├── Loading.tsx +│ │ │ └── Error.tsx +│ │ │ +│ │ ├── charts/ +│ │ │ ├── PriceChart.tsx +│ │ │ ├── BatteryChart.tsx +│ │ │ ├── PnLChart.tsx +│ │ │ ├── GenerationChart.tsx +│ │ │ └── ModelMetricsChart.tsx +│ │ │ +│ │ ├── alerts/ +│ │ │ ├── AlertPanel.tsx +│ │ │ └── AlertItem.tsx +│ │ │ +│ │ ├── tables/ +│ │ │ ├── ArbitrageTable.tsx +│ │ │ ├── TradeLogTable.tsx +│ │ │ └── ModelListTable.tsx +│ │ │ +│ │ └── forms/ +│ │ ├── BacktestForm.tsx +│ │ ├── TrainingForm.tsx +│ │ └── SettingsForm.tsx +│ │ +│ ├── pages/ +│ │ ├── Dashboard.tsx +│ │ ├── Backtest.tsx +│ │ ├── Models.tsx +│ │ ├── Trading.tsx +│ │ └── Settings.tsx +│ │ +│ ├── hooks/ +│ │ ├── useWebSocket.ts +│ │ ├── useApi.ts +│ │ ├── useBacktest.ts +│ │ ├── useModels.ts +│ │ └── useTrading.ts +│ │ +│ ├── services/ +│ │ ├── api.ts # REST API client +│ │ ├── websocket.ts # WebSocket client +│ │ └── types.ts # TypeScript types +│ │ +│ ├── store/ +│ │ ├── index.ts # Zustand store setup +│ │ ├── dashboardSlice.ts +│ │ ├── backtestSlice.ts +│ │ ├── modelsSlice.ts +│ │ └── tradingSlice.ts +│ │ +│ └── lib/ +│ ├── utils.ts +│ ├── constants.ts +│ └── formatters.ts +│ +├── tests/ +├── .env.example +├── .env.local +├── package.json +├── tsconfig.json +├── vite.config.ts +└── tailwind.config.js +``` + +--- + +## Configuration + +### vite.config.ts + +```typescript +import { defineConfig } from 'vite'; +import react from '@vitejs/plugin-react'; +import path from 'path'; + +export default defineConfig({ + plugins: [react()], + resolve: { + alias: { + '@': path.resolve(__dirname, './src'), + }, + }, + server: { + port: 3000, + proxy: { + '/api': { + target: 'http://localhost:8000', + changeOrigin: true, + }, + '/ws': { + target: 'ws://localhost:8000', + ws: true, + }, + }, + }, +}); +``` + +### .env.example + +```bash +VITE_API_URL=http://localhost:8000 +VITE_WS_URL=ws://localhost:8000/ws/real-time +``` + +--- + +## TypeScript Types + +### services/types.ts + +#### Enums + +```typescript +export enum Region { + FR = "FR", + BE = "BE", + DE = "DE", + NL = "NL", + UK = "UK", +} + +export enum Strategy { + FUNDAMENTAL = "fundamental", + TECHNICAL = "technical", + ML = "ml", + MINING = "mining", +} + +export enum TradeType { + BUY = "buy", + SELL = "sell", + CHARGE = "charge", + DISCHARGE = "discharge", +} + +export enum BacktestStatus { + PENDING = "pending", + RUNNING = "running", + COMPLETED = "completed", + FAILED = "failed", + CANCELLED = "cancelled", +} + +export enum ModelType { + PRICE_PREDICTION = "price_prediction", + RL_BATTERY = "rl_battery", +} +``` + +#### Key Interfaces + +```typescript +export interface PriceData { + timestamp: string; + region: Region; + day_ahead_price: number; + real_time_price: number; + volume_mw: number; +} + +export interface BatteryState { + battery_id: string; + timestamp: string; + capacity_mwh: number; + charge_level_mwh: number; + charge_level_pct: number; + charge_rate_mw: number; + discharge_rate_mw: number; + efficiency: number; +} + +export interface BacktestConfig { + start_date: string; + end_date: string; + strategies: Strategy[]; + use_ml: boolean; + battery_min_reserve?: number; + battery_max_charge?: number; + arbitrage_min_spread?: number; +} + +export interface BacktestMetrics { + total_revenue: number; + arbitrage_profit: number; + battery_revenue: number; + mining_profit: number; + battery_utilization: number; + price_capture_rate: number; + win_rate: number; + sharpe_ratio: number; + max_drawdown: number; + total_trades: number; +} + +export interface BacktestStatus { + id: string; + status: BacktestStatus; + progress: number; + current_step: string; + started_at?: string; + completed_at?: string; + error?: string; +} + +export interface ModelInfo { + id: string; + type: ModelType; + name: string; + horizon?: number; + created_at: string; + metrics: Record; + status: string; +} + +export interface TrainingRequest { + model_type: ModelType; + horizon?: number; + start_date: string; + end_date: string; + hyperparameters: Record; +} + +export interface TrainingStatus { + id: string; + status: BacktestStatus; + progress: number; + current_epoch?: number; + total_epochs?: number; + metrics: Record; + started_at?: string; + completed_at?: string; + error?: string; +} + +export interface PredictionRequest { + model_id: string; + timestamp: string; + features?: Record; +} + +export interface PredictionResponse { + model_id: string; + timestamp: string; + prediction: number; + confidence?: number; + features_used: string[]; +} + +export interface StrategyStatus { + strategy: Strategy; + running: boolean; + last_updated?: string; +} + +export interface TradingPosition { + region: Region; + position_mw: number; + battery_charge_pct: number; + pnl: number; +} + +export interface Alert { + id: string; + timestamp: string; + type: AlertType; + severity: "info" | "warning" | "error"; + message: string; + data: Record; + acknowledged: boolean; +} +``` + +--- + +## API Client Interface + +### services/api.ts + +#### Dashboard API + +```typescript +export const dashboardApi = { + getSummary: async (): Promise => { }, + + getPrices: async (): Promise> => { }, + + getPriceHistory: async ( + region: string, + start?: string, + end?: string, + limit?: number + ): Promise<{ region: string; data: PriceData[] }> => { }, + + getBatteryStates: async (): Promise => { }, + + getArbitrage: async (minSpread?: number): Promise<{ + opportunities: ArbitrageOpportunity[]; + count: number; + }> => { }, +}; +``` + +#### Backtest API + +```typescript +export const backtestApi = { + start: async (request: BacktestRequest): Promise<{ + backtest_id: string; + status: BacktestStatus; + }> => { }, + + get: async (backtestId: string): Promise<{ + status: BacktestStatus; + results?: BacktestResult; + }> => { }, + + getResults: async (backtestId: string): Promise => { }, + + getTrades: async (backtestId: string, limit?: number): Promise<{ + backtest_id: string; + trades: Trade[]; + total: number; + }> => { }, + + list: async (): Promise<{ + backtests: BacktestStatus[]; + total: number; + }> => { }, + + delete: async (backtestId: string): Promise<{ message: string }> => { }, +}; +``` + +#### Models API + +```typescript +export const modelsApi = { + list: async (): Promise<{ models: ModelInfo[]; total: number }> => { }, + + train: async (request: TrainingRequest): Promise<{ + training_id: string; + status: TrainingStatus; + }> => { }, + + getStatus: async (modelId: string): Promise => { }, + + getMetrics: async (modelId: string): Promise<{ + model_id: string; + metrics: Record; + }> => { }, + + predict: async (request: PredictionRequest): Promise => { }, +}; +``` + +#### Trading API + +```typescript +export const tradingApi = { + getStrategies: async (): Promise<{ strategies: StrategyStatus[] }> => { }, + + toggleStrategy: async (control: { + strategy: Strategy; + action: "start" | "stop"; + }): Promise<{ status: StrategyStatus }> => { }, + + getPositions: async (): Promise<{ positions: TradingPosition[] }> => { }, +}; +``` + +#### Settings API + +```typescript +export const settingsApi = { + get: async (): Promise => { }, + + update: async (settings: Partial): Promise<{ + message: string; + updated_fields: string[]; + }> => { }, +}; +``` + +--- + +## WebSocket Client Interface + +### services/websocket.ts + +```typescript +class WebSocketService { + private ws: WebSocket | null = null; + private url: string; + private eventHandlers: Map> = new Map(); + private isConnected = false; + + constructor(url: string = import.meta.env.VITE_WS_URL); + + connect(): void; + disconnect(): void; + + subscribe( + eventType: WebSocketEventType, + handler: (data: T) => void + ): () => void; // Returns unsubscribe function + + getConnectionStatus(): boolean; + + private handleMessage(message: WebSocketMessage): void; + private attemptReconnect(): void; +} + +export const webSocketService = new WebSocketService(); +``` + +### WebSocket Event Types + +```typescript +export type WebSocketEventType = + | "price_update" + | "battery_update" + | "arbitrage_opportunity" + | "trade_executed" + | "alert_triggered" + | "backtest_progress" + | "model_training_progress"; + +export interface WebSocketMessage { + type: WebSocketEventType; + timestamp: string; + data: T; +} +``` + +--- + +## State Management (Zustand) + +### store/index.ts + +```typescript +interface DashboardState { + summary: DashboardSummary | null; + prices: Record; + batteryStates: BatteryState[]; + arbitrageOpportunities: ArbitrageOpportunity[]; + alerts: Alert[]; + + updateSummary: (summary: DashboardSummary) => void; + updatePrices: (prices: Record) => void; + addAlert: (alert: Alert) => void; + // ... +} + +interface BacktestState { + backtests: Record; + currentBacktest: string | null; + isRunning: boolean; + + updateBacktest: (status: BacktestStatus) => void; + setCurrentBacktest: (backtestId: string | null) => void; + // ... +} + +interface ModelsState { + models: ModelInfo[]; + trainingJobs: Record; + selectedModel: string | null; + + setModels: (models: ModelInfo[]) => void; + updateTrainingJob: (job: TrainingStatus) => void; + // ... +} + +interface TradingState { + strategies: Record; + positions: TradingPosition[]; + pnl: number; + + updateStrategy: (strategy: StrategyStatus) => void; + updatePositions: (positions: TradingPosition[]) => void; + // ... +} + +export const useStore = create((set) => ({ + // ... state and actions +})); +``` + +--- + +## Custom Hooks + +### hooks/useWebSocket.ts + +```typescript +export function useWebSocket() { + const subscribe = ( + eventType: WebSocketEventType, + handler: (data: T) => void + ): (() => void) => { + return webSocketService.subscribe(eventType, handler); + }; + + const isConnected = webSocketService.getConnectionStatus(); + + return { subscribe, isConnected }; +} +``` + +### hooks/useApi.ts + +```typescript +// Dashboard hooks +export function useDashboardSummary(); +export function usePrices(); +export function useBatteryStates(); +export function useArbitrageOpportunities(minSpread?: number); + +// Backtest hooks +export function useStartBacktest(); +export function useBacktest(backtestId: string); +export function useBacktestList(); + +// Models hooks +export function useModels(); +export function useTrainModel(); + +// Trading hooks +export function useStrategies(); +export function useToggleStrategy(); +export function usePositions(); + +// Settings hooks +export function useSettings(); +export function useUpdateSettings(); +``` + +--- + +## Pages Interface + +### Dashboard.tsx + +```typescript +export default function Dashboard() { + const { subscribe } = useWebSocket(); + const { prices, batteryStates, arbitrageOpportunities, alerts } = useStore(); + const { data: pricesData } = usePrices(); + + // Subscribe to real-time updates + useEffect(() => { + const unsubscribePrice = subscribe('price_update', (data) => { }); + const unsubscribeBattery = subscribe('battery_update', (data) => { }); + const unsubscribeAlert = subscribe('alert_triggered', (data) => { }); + + return () => { + unsubscribePrice(); + unsubscribeBattery(); + unsubscribeAlert(); + }; + }, [subscribe]); + + // Render stats cards, charts, tables +} +``` + +### Backtest.tsx + +```typescript +export default function Backtest() { + const [selectedBacktestId, setSelectedBacktestId] = useState(null); + + const { mutate: startBacktest, isPending } = useStartBacktest(); + const { data: backtest } = useBacktest(selectedBacktestId || ''); + const { data: backtestList } = useBacktestList(); + + const handleStartBacktest = (config: any) => { + startBacktest({ config, name: config.name }, { + onSuccess: (data) => { + setSelectedBacktestId(data.backtest_id); + }, + }); + }; + + // Render form, results, progress +} +``` + +### Models.tsx + +```typescript +export default function Models() { + const [selectedModel, setSelectedModel] = useState(null); + const [showTrainingForm, setShowTrainingForm] = useState(false); + + const { data: modelsData } = useModels(); + const { mutate: trainModel, isPending } = useTrainModel(); + + const handleTrainModel = (config: any) => { + trainModel(config, { + onSuccess: () => { + setShowTrainingForm(false); + }, + }); + }; + + // Render model list, training form, model details +} +``` + +### Trading.tsx + +```typescript +export default function Trading() { + const { data: strategiesData } = useStrategies(); + const { data: positionsData } = usePositions(); + const { mutate: toggleStrategy } = useToggleStrategy(); + + const handleToggleStrategy = (strategyName: string, running: boolean) => { + toggleStrategy({ + strategy: strategyName as any, + action: running ? 'stop' : 'start', + }); + }; + + // Render strategy status, positions, controls +} +``` + +--- + +## Dependencies + +### package.json + +```json +{ + "name": "energy-trading-ui", + "version": "1.0.0", + "type": "module", + "scripts": { + "dev": "vite", + "build": "tsc && vite build", + "preview": "vite preview", + "test": "vitest", + "type-check": "tsc --noEmit" + }, + "dependencies": { + "react": "^18.2.0", + "react-dom": "^18.2.0", + "react-router-dom": "^6.20.0", + "recharts": "^2.10.0", + "zustand": "^4.4.0", + "@tanstack/react-query": "^5.0.0", + "axios": "^1.6.0", + "date-fns": "^2.30.0", + "lucide-react": "^0.292.0" + }, + "devDependencies": { + "@types/react": "^18.2.37", + "@types/react-dom": "^18.2.15", + "@vitejs/plugin-react": "^4.2.0", + "typescript": "^5.2.2", + "vite": "^5.0.0", + "vitest": "^1.0.0", + "tailwindcss": "^3.3.5" + } +} +``` + +--- + +## Data Flow + +### API Calls + +``` +Component → useApi Hook → API Client → Axios → Backend + ↓ + React Query Cache + ↓ + Automatic Refetch +``` + +### WebSocket Updates + +``` +WebSocket Event → WebSocketService → Event Handler → Store Update → Component Re-render +``` + +### State Synchronization + +``` +API Response → React Query Cache → Component Props → Zustand Store (optional) +WebSocket Event → Zustand Store → Component Re-render +``` + +--- + +## Key Integration Points + +### Backend API Integration + +- REST API endpoints are fully typed in `services/types.ts` +- API client methods in `services/api.ts` match backend routes +- WebSocket events match backend event types + +### Error Handling + +- API errors handled by React Query error states +- WebSocket reconnection automatic with exponential backoff +- Component-level error boundaries for unhandled errors + +### Performance + +- React Query caching with automatic refetch intervals +- WebSocket connection pooling +- Component memoization where appropriate +- Virtual scrolling for large data tables diff --git a/ML_IMPLEMENTATION.md b/ML_IMPLEMENTATION.md new file mode 100644 index 0000000..f1274fd --- /dev/null +++ b/ML_IMPLEMENTATION.md @@ -0,0 +1,679 @@ +# ML Implementation Strategy + +## Overview + +This document outlines the machine learning components for the energy trading system, including gradient boosting price prediction models and Q-Learning reinforcement learning for battery optimization. + +**Package Location**: `backend/app/ml/` (integrated within the FastAPI backend) + +**Related Documents**: +- `BACKEND_IMPLEMENTATION.md` - API endpoints, services, and ML integration layer +- `FRONTEND_IMPLEMENTATION.md` - React frontend that consumes backend API + +**Training Split Strategy**: Time-based split (first 7 days train, next 1.5 days validation, last 1.5 days test) to prevent look-ahead bias and ensure realistic evaluation. + +**Data Source**: `~/energy-test-data/data/processed/*.parquet` + +--- + +## Architecture + +``` +┌──────────────────────────────────────────────────────────────┐ +│ ML Pipeline │ +├──────────────────────────────────────────────────────────────┤ +│ │ +│ ┌──────────────┬──────────────────────┬──────────────────┐ │ +│ │ Feature │ │ │ │ +│ │ Engineering │ Model Training │ Model Serving │ │ +│ │ │ │ │ │ +│ │ - Lags │ - XGBoost │ - Load Model │ │ +│ │ - Rolling │ - Q-Learning │ - Predict │ │ +│ │ - Time │ - Validation │ - Return Action │ │ +│ │ - Regions │ - Evaluation │ │ │ +│ └──────────────┴──────────────────────┴──────────────────┘ │ +│ │ │ +│ ┌──────────────────────┴──────────────────────────────┐ │ +│ │ Model Management │ │ +│ │ - Versioning - Persistence - Registry - Backup │ │ +│ └──────────────────────────────────────────────────────┘ │ +└──────────────────────────────────────────────────────────────┘ + │ + ▼ +┌──────────────────────────────────────────────────────────────┐ +│ FastAPI Backend (app.services.ml_service) │ +│ /api/v1/models/train, /api/v1/models/predict │ +└──────────────────────────────────────────────────────────────┘ +``` + +--- + +## Project Structure + +``` +backend/app/ml/ +├── __init__.py +│ +├── features/ +│ ├── __init__.py +│ ├── lag_features.py # Price lag feature extraction +│ ├── rolling_features.py # Rolling statistics +│ ├── time_features.py # Time-of-day encoding +│ ├── regional_features.py # Cross-region differentials +│ └── battery_features.py # Battery state features +│ +├── price_prediction/ +│ ├── __init__.py +│ ├── model.py # XGBoost model wrapper +│ ├── trainer.py # Training pipeline +│ ├── evaluator.py # Evaluation metrics +│ └── predictor.py # Prediction interface +│ +├── rl_battery/ +│ ├── __init__.py +│ ├── environment.py # Battery MDP environment +│ ├── agent.py # Q-Learning agent +│ ├── trainer.py # RL training loop +│ └── policy.py # Policy inference +│ +├── model_management/ +│ ├── __init__.py +│ ├── registry.py # Model registry +│ ├── persistence.py # Save/load models +│ ├── versioning.py # Version handling +│ └── comparison.py # Model comparison +│ +├── evaluation/ +│ ├── __init__.py +│ ├── metrics.py # Common evaluation metrics +│ ├── backtest_evaluator.py # Backtest performance evaluation +│ └── reports.py # Generate evaluation reports +│ +├── training/ +│ ├── __init__.py +│ ├── cli.py # CLI commands for retraining +│ ├── pipeline.py # End-to-end training pipeline +│ └── scheduler.py # Training job scheduler +│ +└── utils/ + ├── __init__.py + ├── data_split.py # Time-based data splitting + ├── preprocessing.py # Data preprocessing + ├── config.py # ML configuration + └── evaluation.py # Evaluation metrics +``` + +--- + +## Configuration + +### backend/app/ml/utils/config.py + +```python +from dataclasses import dataclass +from typing import List, Dict, Any +from pathlib import Path + +@dataclass +class PricePredictionConfig: + """Configuration for price prediction models.""" + + # Data + data_path: str = "~/energy-test-data/data/processed" + target_column: str = "real_time_price" + + # Training split (time-based) + train_end_pct: float = 0.70 # First 70% for training + val_end_pct: float = 0.85 # Next 15% for validation + # Last 15% for testing + + # Features + price_lags: List[int] = None + rolling_windows: List[int] = None + include_time_features: bool = True + include_regional_features: bool = True + + # Model + n_estimators: int = 200 + max_depth: int = 6 + learning_rate: float = 0.1 + subsample: float = 0.8 + colsample_bytree: float = 0.8 + random_state: int = 42 + + # Early stopping + early_stopping_rounds: int = 20 + early_stopping_threshold: float = 0.001 + + def __post_init__(self): + if self.price_lags is None: + self.price_lags = [1, 5, 10, 15, 30, 60] + if self.rolling_windows is None: + self.rolling_windows = [5, 10, 15, 30, 60] + + +@dataclass +class RLBatteryConfig: + """Configuration for RL battery optimization.""" + + # State space + charge_level_bins: int = 10 + price_bins: int = 10 + time_bins: int = 24 # Hours + + # Action space + actions: List[str] = None + + # Q-Learning + learning_rate: float = 0.1 + discount_factor: float = 0.95 + epsilon: float = 1.0 + epsilon_decay: float = 0.995 + epsilon_min: float = 0.05 + + # Training + episodes: int = 1000 + max_steps: int = 14400 # 10 days * 1440 minutes + + # Battery constraints + min_reserve: float = 0.10 # 10% + max_charge: float = 0.90 # 90% + efficiency: float = 0.90 + + # Reward scaling + reward_scale: float = 1.0 + + def __post_init__(self): + if self.actions is None: + self.actions = ["charge", "hold", "discharge"] + + +@dataclass +class MLConfig: + """Overall ML configuration.""" + + # Paths + models_path: str = "models" + results_path: str = "results" + + # Price prediction + price_prediction: PricePredictionConfig = None + rl_battery: RLBatteryConfig = None + + # Training + enable_gpu: bool = False + n_jobs: int = 4 + verbose: bool = True + + # Retraining + keep_backup: bool = True + max_backups: int = 5 + + def __post_init__(self): + if self.price_prediction is None: + self.price_prediction = PricePredictionConfig() + if self.rl_battery is None: + self.rl_battery = RLBatteryConfig() + + +# Default configuration +default_config = MLConfig() +``` + +--- + +## Feature Engineering Interface + +### Key Functions (backend/app/ml/features/__init__.py) + +```python +def build_price_features( + df: pd.DataFrame, + price_col: str = "real_time_price", + lags: List[int] = None, + windows: List[int] = None, + regions: List[str] = None, + include_time: bool = True, + include_regional: bool = True, +) -> pd.DataFrame: + """ + Build complete feature set for price prediction. + + Args: + df: Input DataFrame + price_col: Name of price column + lags: List of lag periods + windows: List of rolling window sizes + regions: List of regions for differential features + include_time: Whether to include time features + include_regional: Whether to include regional features + + Returns: + DataFrame with all features + """ + + +def build_battery_features( + df: pd.DataFrame, + price_df: pd.DataFrame, + battery_col: str = "charge_level_mwh", + capacity_col: str = "capacity_mwh", + timestamp_col: str = "timestamp", + battery_id_col: str = "battery_id", +) -> pd.DataFrame: + """ + Build features for battery RL model. + + Args: + df: Battery DataFrame + price_df: Price DataFrame + battery_col: Name of battery charge level column + capacity_col: Name of battery capacity column + timestamp_col: Name of timestamp column + battery_id_col: Name of battery ID column + + Returns: + DataFrame with battery features + """ +``` + +--- + +## Price Prediction Interface + +### PricePredictor (backend/app/ml/price_prediction/predictor.py) + +```python +class PricePredictor: + """Interface for making price predictions.""" + + def __init__(self, models_dir: str = "models/price_prediction"): + """ + Initialize predictor. + + Args: + models_dir: Directory containing trained models + """ + + def predict( + self, + current_data: pd.DataFrame, + horizon: int = 15, + region: Optional[str] = None, + ) -> float: + """ + Predict price for a specific horizon. + + Args: + current_data: Current/historical price data + horizon: Prediction horizon in minutes + region: Specific region to predict for (optional) + + Returns: + Predicted price + """ + + def predict_all_horizons( + self, + current_data: pd.DataFrame, + region: Optional[str] = None, + ) -> Dict[int, float]: + """ + Predict prices for all available horizons. + + Returns: + Dictionary mapping horizons to predictions + """ + + def predict_with_confidence( + self, + current_data: pd.DataFrame, + horizon: int = 15, + region: Optional[str] = None, + ) -> Dict: + """ + Predict price with confidence interval. + + Returns: + Dictionary with prediction and confidence interval + """ + + def get_feature_importance(self, horizon: int) -> pd.DataFrame: + """ + Get feature importance for a specific horizon. + + Returns: + DataFrame with feature importance + """ +``` + +### PricePredictionTrainer (backend/app/ml/price_prediction/trainer.py) + +```python +class PricePredictionTrainer: + """Training pipeline for price prediction models.""" + + def __init__(self, config: PricePredictionConfig = None): + """Initialize trainer.""" + + def load_data(self) -> pd.DataFrame: + """Load price data.""" + + def prepare_data(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]: + """ + Prepare data with features. + + Returns: + Tuple of (features DataFrame, feature names) + """ + + def train_for_horizon( + self, + df_features: pd.DataFrame, + feature_cols: List[str], + horizon: int, + ) -> Dict: + """ + Train model for a specific horizon. + + Returns: + Training results dictionary with metrics + """ + + def train_all(self, horizons: List[int] = None) -> Dict: + """ + Train models for all horizons. + + Returns: + Dictionary with all training results + """ + + def save_models(self, output_dir: str = "models/price_prediction") -> None: + """Save all trained models.""" + + @classmethod + def load_models( + cls, + models_dir: str = "models/price_prediction", + horizons: List[int] = None, + ) -> Dict[int, PricePredictionModel]: + """ + Load trained models. + + Returns: + Dictionary mapping horizons to models + """ +``` + +--- + +## RL Battery Optimization Interface + +### BatteryPolicy (backend/app/ml/rl_battery/policy.py) + +```python +class BatteryPolicy: + """Interface for RL battery policy inference.""" + + def __init__(self, policy_path: str = "models/rl_battery"): + """ + Initialize policy. + + Args: + policy_path: Path to trained policy + """ + + def get_action( + self, + charge_level: float, + current_price: float, + price_forecast_1m: float = 0, + price_forecast_5m: float = 0, + price_forecast_15m: float = 0, + hour: int = 0, + ) -> Dict: + """ + Get action for current state. + + Returns: + Dictionary with action, q_values + """ +``` + +### BatteryRLTrainer (backend/app/ml/rl_battery/trainer.py) + +```python +class BatteryRLTrainer: + """Training pipeline for RL battery policy.""" + + def __init__(self, config: RLBatteryConfig = None): + """Initialize trainer.""" + + def load_data(self) -> None: + """Load price data for environment.""" + + def train(self, n_episodes: int = 1000, region: str = "FR") -> Dict: + """ + Train RL agent. + + Returns: + Training results with metrics + """ + + def save(self, output_dir: str = "models/rl_battery") -> None: + """Save trained policy.""" +``` + +--- + +## Model Management Interface + +### ModelRegistry (backend/app/ml/model_management/registry.py) + +```python +class ModelRegistry: + """Registry for tracking model versions.""" + + def __init__(self, registry_path: str = "models/registry.json"): + """Initialize registry.""" + + def register_model( + self, + model_type: str, + model_id: str, + version: str, + filepath: str, + metadata: Dict = None, + ) -> None: + """Register a model version.""" + + def get_latest_version(self, model_id: str) -> Optional[Dict]: + """Get latest version of a model.""" + + def list_models(self) -> List[Dict]: + """List all registered models.""" +``` + +--- + +## Training Task Interface + +### Training Jobs (app/tasks/training_tasks.py) + +```python +async def train_model_task(training_id: str, request: TrainingRequest): + """ + Execute ML model training via Celery task. + + Dispatches to Celery for async processing of: + - Price prediction training + - RL battery policy training + + Emits WebSocket events for progress updates. + """ + +@shared_task(name="tasks.train_price_prediction") +def train_price_prediction(training_id: str, request_dict: dict): + """ + Celery task for price prediction model training. + + Process: + 1. Load and prepare data + 2. Train XGBoost models for specified horizon + 3. Save models + 4. Register in model registry + 5. Update training job status + """ + +@shared_task(name="tasks.train_rl_battery") +def train_rl_battery(training_id: str, request_dict: dict): + """ + Celery task for RL battery policy training. + + Process: + 1. Load environment and data + 2. Train Q-Learning agent + 3. Save policy + 4. Register in model registry + 5. Update training job status + """ +``` + +--- + +## Key Integration Points + +### ML Service Integration (app/services/ml_service.py) + +The ML service provides the bridge between API routes and ML models: + +```python +class MLService: + """Service for ML model management and inference.""" + + def list_models(self) -> List[ModelInfo]: + """List all available trained models.""" + + def get_model_metrics(self, model_id: str) -> Dict[str, float]: + """Get performance metrics for a model.""" + + def predict( + self, + model_id: str, + timestamp: datetime, + features: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """ + Run prediction with on-demand model loading. + + Supports: + - Price prediction models + - RL battery policy + """ + + def get_feature_importance(self, model_id: str) -> Dict[str, float]: + """Get feature importance for a model.""" +``` + +### WebSocket Events + +Real-time training progress updates: + +```python +# Event types +- "model_training_progress" # Training progress updates + +# Payload +{ + "model_id": str, + "progress": float, # 0.0 to 1.0 + "epoch": int, # Current epoch (optional) + "metrics": dict # Current metrics +} +``` + +--- + +## Data Flow + +### Training Pipeline + +``` +API Request (POST /api/v1/models/train) + ↓ +training_tasks.train_model_task() + ↓ +Celery Task (train_price_prediction / train_rl_battery) + ↓ +PricePredictionTrainer / BatteryRLTrainer + ↓ +Feature Engineering → Model Training → Evaluation + ↓ +Save Models → Register in Registry + ↓ +WebSocket Events (progress updates) + ↓ +Update Training Status +``` + +### Prediction Pipeline + +``` +API Request (POST /api/v1/models/predict) + ↓ +ml_service.predict() + ↓ +Load Model (on-demand) + ↓ +Feature Engineering + ↓ +Model.predict() / Policy.get_action() + ↓ +Return Prediction with Confidence +``` + +--- + +## Model Artifacts + +### Price Prediction Models + +Location: `models/price_prediction/` + +``` +model_1min.pkl # 1-minute horizon model +model_5min.pkl # 5-minute horizon model +model_15min.pkl # 15-minute horizon model +model_60min.pkl # 60-minute horizon model +training_results.json # Training metrics and metadata +``` + +### RL Battery Policy + +Location: `models/rl_battery/` + +``` +battery_policy.pkl # Trained Q-Learning policy +training_results.json # Training metrics +``` + +### Model Registry + +Location: `models/registry.json` + +```json +{ + "models": { + "price_prediction_15m": { + "type": "price_prediction", + "versions": ["v20260211_134500", "v20260210_100000"], + "latest": "v20260211_134500" + }, + "battery_policy": { + "type": "rl_battery", + "versions": ["v20260211_140000"], + "latest": "v20260211_140000" + } + } +} +```