Add FastAPI backend for energy trading system

Implements FastAPI backend with ML model support for energy trading,
including price prediction models and RL-based battery trading policy.
Features dashboard, trading, backtest, and settings API routes with
WebSocket support for real-time updates.
This commit is contained in:
2026-02-12 00:59:26 +07:00
parent a22a13f6f4
commit fe76bc7629
72 changed files with 2931 additions and 0 deletions

24
backend/.env.example Normal file
View File

@@ -0,0 +1,24 @@
APP_NAME=Energy Trading API
APP_VERSION=1.0.0
DEBUG=true
HOST=0.0.0.0
PORT=8000
DATA_PATH=~/energy-test-data/data/processed
CORS_ORIGINS=http://localhost:3000,http://localhost:5173
WS_HEARTBEAT_INTERVAL=30
CELERY_BROKER_URL=redis://localhost:6379/0
CELERY_RESULT_BACKEND=redis://localhost:6379/0
MODELS_PATH=models
RESULTS_PATH=results
BATTERY_MIN_RESERVE=0.10
BATTERY_MAX_CHARGE=0.90
ARBITRAGE_MIN_SPREAD=5.0
MINING_MARGIN_THRESHOLD=5.0

46
backend/.gitignore vendored Normal file
View File

@@ -0,0 +1,46 @@
.env
.venv
venv/
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
.pytest_cache/
.coverage
htmlcov/
*.cover
.hypothesis/
.tox/
.mypy_cache/
.dmypy.json
dmypy.json
*.log
logs/
server.log
*.pid
.vscode/
.idea/
*.swp
*.swo
*~
.DS_Store
Thumbs.db
*.bak
*.tmp
*.orig

12
backend/Dockerfile Normal file
View File

@@ -0,0 +1,12 @@
# Lightweight Python 3.10 image for the FastAPI backend.
FROM python:3.10-slim
WORKDIR /app
# Install dependencies first so Docker layer caching survives code-only changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
EXPOSE 8000
# Serve the FastAPI app with uvicorn on all interfaces (matches PORT default).
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

54
backend/README.md Normal file
View File

@@ -0,0 +1,54 @@
# Energy Trading Backend
FastAPI backend for the energy trading system with ML model support.
## Setup
```bash
cd backend
pip install -r requirements.txt
cp .env.example .env
```
## Running
```bash
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
```
## API Documentation
- Swagger UI: http://localhost:8000/docs
- ReDoc: http://localhost:8000/redoc
## Project Structure
```
backend/
├── app/
│ ├── api/ # API routes and WebSocket
│ ├── services/ # Business logic services
│ ├── tasks/ # Background tasks
│ ├── ml/ # ML models and training
│ ├── models/ # Pydantic models
│ └── utils/ # Utilities
├── models/ # Trained models
├── results/ # Backtest results
└── tests/ # Tests
```
## Training ML Models
```bash
# Train price prediction models
python -m app.ml.training.cli price --horizons 1 5 15 60
# Train RL battery policy
python -m app.ml.training.cli rl --episodes 1000
```
## Running Tests
```bash
pytest
```

7
backend/app/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
"""Package root for the energy trading backend.

Re-exports the settings singleton and a package-level logger.
"""
from fastapi import FastAPI
from app.config import settings
from app.utils.logger import get_logger

logger = get_logger(__name__)

# NOTE(review): "app" is listed here but never imported or defined in this
# module (the FastAPI instance lives in app.main) -- `from app import *`
# would raise AttributeError on it. Confirm intent.
__all__ = ["app", "settings", "logger"]

View File

@@ -0,0 +1,3 @@
from app.api.routes import dashboard, backtest, models, trading, settings
__all__ = ["dashboard", "backtest", "models", "trading", "settings"]

View File

@@ -0,0 +1,3 @@
from app.api.routes import dashboard, backtest, models, trading, settings
__all__ = ["dashboard", "backtest", "models", "trading", "settings"]

View File

@@ -0,0 +1,108 @@
from typing import List, Optional
from fastapi import APIRouter, BackgroundTasks, HTTPException
from datetime import datetime
from app.models.schemas import (
BacktestConfig,
BacktestMetrics,
BacktestStatusEnum,
Trade,
)
from app.utils.logger import get_logger
import uuid
logger = get_logger(__name__)
router = APIRouter()
_backtest_store: dict = {}
_backtest_results: dict = {}
# NOTE: FastAPI matches routes in registration order, so the static
# "/history" path must be registered BEFORE the dynamic "/{backtest_id}"
# route; in the original order GET /history was captured by the dynamic
# matcher and always returned 404 ("Backtest history not found").
@router.post("/start")
async def start_backtest(config: BacktestConfig, name: Optional[str] = None):
    """Register a new backtest run in memory and mark it RUNNING.

    NOTE(review): nothing is actually executed here (BackgroundTasks is
    imported but unused) -- presumably a worker hook is wired up later.
    """
    backtest_id = str(uuid.uuid4())
    now = datetime.utcnow()  # single snapshot so created_at == started_at
    _backtest_store[backtest_id] = {
        "backtest_id": backtest_id,
        "name": name or f"Backtest {backtest_id[:8]}",
        "status": BacktestStatusEnum.RUNNING,
        "config": config.dict(),
        "created_at": now,
        "started_at": now,
        "completed_at": None,
        "error_message": None,
    }
    logger.info(f"Backtest started: {backtest_id}")
    return {
        "backtest_id": backtest_id,
        "status": BacktestStatusEnum.RUNNING,
        "name": _backtest_store[backtest_id]["name"],
    }


@router.get("/history")
async def get_backtest_history():
    """List every known backtest record (fix: registered before the
    dynamic "/{backtest_id}" route so this static path is reachable)."""
    backtests = [
        {
            "backtest_id": backtest_id,
            "name": backtest["name"],
            "status": backtest["status"],
            "created_at": backtest["created_at"],
            "completed_at": backtest["completed_at"],
        }
        for backtest_id, backtest in _backtest_store.items()
    ]
    return {"backtests": backtests, "total": len(backtests)}


@router.get("/{backtest_id}")
async def get_backtest_status(backtest_id: str):
    """Return the stored record for one backtest; results only once COMPLETED."""
    if backtest_id not in _backtest_store:
        raise HTTPException(status_code=404, detail=f"Backtest {backtest_id} not found")
    backtest = _backtest_store[backtest_id]
    result = _backtest_results.get(backtest_id)
    return {
        "status": backtest["status"],
        "name": backtest["name"],
        "created_at": backtest["created_at"],
        "started_at": backtest["started_at"],
        "completed_at": backtest["completed_at"],
        "error_message": backtest["error_message"],
        "results": result if backtest["status"] == BacktestStatusEnum.COMPLETED else None,
    }


@router.get("/{backtest_id}/results")
async def get_backtest_results(backtest_id: str):
    """Return raw result payload for a backtest; 404 if no results stored."""
    if backtest_id not in _backtest_results:
        raise HTTPException(status_code=404, detail=f"Results for backtest {backtest_id} not found")
    return _backtest_results[backtest_id]


@router.get("/{backtest_id}/trades")
async def get_backtest_trades(backtest_id: str, limit: int = 100):
    """Return up to `limit` most recent trades for a backtest."""
    if backtest_id not in _backtest_store:
        raise HTTPException(status_code=404, detail=f"Backtest {backtest_id} not found")
    trades = _backtest_results.get(backtest_id, {}).get("trades", [])
    return {"backtest_id": backtest_id, "trades": trades[-limit:], "total": len(trades)}


@router.delete("/{backtest_id}")
async def delete_backtest(backtest_id: str):
    """Remove a backtest record and any stored results."""
    if backtest_id not in _backtest_store:
        raise HTTPException(status_code=404, detail=f"Backtest {backtest_id} not found")
    del _backtest_store[backtest_id]
    if backtest_id in _backtest_results:
        del _backtest_results[backtest_id]
    logger.info(f"Backtest deleted: {backtest_id}")
    return {"message": f"Backtest {backtest_id} deleted"}

View File

@@ -0,0 +1,48 @@
from typing import List
from fastapi import APIRouter, Depends, Query
from app.models.schemas import DashboardSummary, ArbitrageOpportunity, PriceData, BatteryState
from app.services import DataService
router = APIRouter()
# Module-level singleton backing every dashboard endpoint.
data_service = DataService()


# NOTE(review): APIRouter.on_event is deprecated in recent FastAPI in favor
# of lifespan handlers; it still fires when this router is included in the app.
@router.on_event("startup")
async def startup():
    """Initialize the shared data service before serving requests."""
    await data_service.initialize()


@router.get("/summary", response_model=DashboardSummary)
async def get_summary():
    """Return the aggregate dashboard summary, validated through the schema."""
    summary = data_service.get_dashboard_summary()
    return DashboardSummary(**summary)


@router.get("/prices")
async def get_latest_prices():
    """Latest price snapshot per region."""
    return {"regions": data_service.get_latest_prices()}


@router.get("/prices/history")
async def get_price_history(
    region: str = Query(..., description="Region code (FR, BE, DE, NL, UK)"),
    start: str = Query(None, description="Start date (YYYY-MM-DD)"),
    end: str = Query(None, description="End date (YYYY-MM-DD)"),
    limit: int = Query(1000, description="Maximum number of records"),
):
    """Historical prices for one region, optionally bounded by date range."""
    data = data_service.get_price_history(region, start, end, limit)
    return {"region": region, "data": data}


@router.get("/battery")
async def get_battery_states():
    """Current state of every tracked battery."""
    batteries = data_service.get_battery_states()
    return {"batteries": batteries}


@router.get("/arbitrage")
async def get_arbitrage_opportunities(
    min_spread: float = Query(None, description="Minimum spread in EUR/MWh")
):
    """Cross-region arbitrage opportunities, optionally filtered by spread."""
    opportunities = data_service.get_arbitrage_opportunities(min_spread)
    return {"opportunities": opportunities, "count": len(opportunities)}

View File

@@ -0,0 +1,71 @@
from typing import List
from fastapi import APIRouter, HTTPException
from datetime import datetime
from app.models.schemas import ModelInfo, TrainingRequest, TrainingStatus, PredictionResponse
from app.services import MLService
import uuid
router = APIRouter()
ml_service = MLService()
# In-memory registry of training jobs, keyed by training_id.
# NOTE(review): jobs are created below but never executed or progressed --
# presumably a background worker is wired up elsewhere/later.
_training_store: dict = {}


@router.get("", response_model=List[ModelInfo])
async def list_models():
    """List all models known to the ML service."""
    return ml_service.list_models()


@router.post("/train")
async def train_model(request: TrainingRequest):
    """Create a training-job record in "pending" state.

    NOTE(review): `request` is accepted but unused and no task is enqueued;
    also the returned "status" field carries the whole TrainingStatus
    object rather than just a status string -- confirm intended payload.
    """
    training_id = f"training_{uuid.uuid4().hex[:8]}"
    _training_store[training_id] = TrainingStatus(
        training_id=training_id,
        status="pending",
        progress=0.0,
        started_at=datetime.utcnow(),
    )
    return {"training_id": training_id, "status": _training_store[training_id]}


@router.get("/{training_id}/status", response_model=TrainingStatus)
async def get_training_status(training_id: str):
    """Return the stored status of one training job; 404 if unknown."""
    if training_id not in _training_store:
        raise HTTPException(status_code=404, detail=f"Training job {training_id} not found")
    return _training_store[training_id]


@router.get("/{model_id}/metrics")
async def get_model_metrics(model_id: str):
    """Evaluation metrics for a trained model (404 on unknown model)."""
    try:
        metrics = ml_service.get_model_metrics(model_id)
        return {"model_id": model_id, "metrics": metrics}
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))


@router.post("/predict", response_model=PredictionResponse)
async def predict(
    model_id: str,
    timestamp: datetime,
    features: dict = None,
):
    """Run a single prediction; ValueError maps to 404, anything else to 500."""
    try:
        result = ml_service.predict(model_id, timestamp, features)
        return PredictionResponse(**result)
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/{model_id}/feature-importance")
async def get_feature_importance(model_id: str):
    """Per-feature importance of a trained model (404 on unknown model)."""
    try:
        importance = ml_service.get_feature_importance(model_id)
        return {"model_id": model_id, "feature_importance": importance}
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))

View File

@@ -0,0 +1,26 @@
from fastapi import APIRouter
from app.config import settings
from app.models.schemas import AppSettings
router = APIRouter()


@router.get("", response_model=AppSettings)
async def get_settings():
    """Expose the tunable trading thresholds from the live Settings object."""
    return AppSettings(
        battery_min_reserve=settings.BATTERY_MIN_RESERVE,
        battery_max_charge=settings.BATTERY_MAX_CHARGE,
        arbitrage_min_spread=settings.ARBITRAGE_MIN_SPREAD,
        mining_margin_threshold=settings.MINING_MARGIN_THRESHOLD,
    )


@router.post("")
async def update_settings(settings_update: dict):
    """Mutate matching Settings attributes in place (keys are uppercased).

    NOTE(review): values are applied without validation or type coercion,
    and ANY uppercase Settings attribute can be overwritten at runtime;
    changes are process-local and lost on restart. Consider a typed request
    model restricted to the tunable fields.
    """
    updated_fields = []
    for key, value in settings_update.items():
        if hasattr(settings, key.upper()):
            setattr(settings, key.upper(), value)
            updated_fields.append(key)
    return {"message": "Settings updated", "updated_fields": updated_fields}

View File

@@ -0,0 +1,27 @@
from typing import List
from fastapi import APIRouter, HTTPException
from app.models.enums import StrategyEnum
from app.models.schemas import StrategyStatus
from app.services import StrategyService
router = APIRouter()
strategy_service = StrategyService()

# Actions accepted by the strategy toggle endpoint.
_ALLOWED_ACTIONS = ("start", "stop")


@router.get("/strategies", response_model=List[StrategyStatus])
async def get_strategies():
    """Return the current status of every trading strategy."""
    return await strategy_service.get_all_strategies()


@router.post("/strategies")
async def toggle_strategy(strategy: StrategyEnum, action: str):
    """Start or stop a single strategy; any other action is rejected."""
    if action not in _ALLOWED_ACTIONS:
        raise HTTPException(status_code=400, detail="Action must be 'start' or 'stop'")
    result = await strategy_service.toggle_strategy(strategy, action)
    return {"status": result}


@router.get("/positions")
async def get_positions():
    """Open positions; currently a static empty placeholder."""
    return {"positions": [], "total": 0}

View File

@@ -0,0 +1,62 @@
from fastapi import WebSocket
from typing import List
from app.utils.logger import get_logger
logger = get_logger(__name__)
class ConnectionManager:
    """Tracks active WebSocket clients and fans out event broadcasts."""

    def __init__(self):
        # All currently accepted client sockets.
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the handshake and start tracking the connection."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info(f"WebSocket connected. Total connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Stop tracking a connection (safe to call twice)."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
        logger.info(f"WebSocket disconnected. Total connections: {len(self.active_connections)}")

    async def broadcast(self, event_type: str, data: dict):
        """Send an event envelope to every client, pruning dead connections.

        Fix: the original always emitted ``"timestamp": None`` as a
        placeholder; the envelope now carries an ISO-8601 UTC timestamp.
        """
        from datetime import datetime, timezone  # local: module has no datetime import

        message = {
            "event_type": event_type,
            "data": data,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        disconnected = []
        for connection in self.active_connections:
            try:
                await connection.send_json(message)
            except Exception:
                # Client went away mid-send; drop it after the loop so we
                # do not mutate the list while iterating it.
                disconnected.append(connection)
        for conn in disconnected:
            self.disconnect(conn)

    async def broadcast_price_update(self, region: str, price_data: dict):
        """Broadcast a per-region price tick."""
        await self.broadcast("price_update", {"region": region, "price_data": price_data})

    async def broadcast_battery_update(self, battery_id: str, battery_state: dict):
        """Broadcast a battery state change."""
        await self.broadcast("battery_update", {"battery_id": battery_id, "battery_state": battery_state})

    async def broadcast_trade(self, trade: dict):
        """Broadcast an executed trade."""
        await self.broadcast("trade_executed", trade)

    async def broadcast_alert(self, alert: dict):
        """Broadcast a triggered alert."""
        await self.broadcast("alert_triggered", alert)

    async def broadcast_backtest_progress(self, backtest_id: str, progress: float, status: str):
        """Broadcast backtest progress updates."""
        await self.broadcast(
            "backtest_progress",
            {"backtest_id": backtest_id, "progress": progress, "status": status},
        )

    async def broadcast_model_training_progress(
        self, model_id: str, progress: float, epoch: int = None, metrics: dict = None
    ):
        """Broadcast model-training progress updates."""
        await self.broadcast(
            "model_training_progress",
            {"model_id": model_id, "progress": progress, "epoch": epoch, "metrics": metrics},
        )


# Process-wide singleton used by the API layer.
manager = ConnectionManager()

48
backend/app/config.py Normal file
View File

@@ -0,0 +1,48 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from pathlib import Path
from typing import List, Union
class Settings(BaseSettings):
    """Application configuration, loaded from the environment / .env file."""

    APP_NAME: str = "Energy Trading API"
    APP_VERSION: str = "1.0.0"
    DEBUG: bool = True
    HOST: str = "0.0.0.0"
    PORT: int = 8000
    # May contain "~"; use the DATA_PATH_RESOLVED property for a usable Path.
    DATA_PATH: str = "~/energy-test-data/data/processed"
    # Accepts either a CSV string (env var form) or a list; see cors_origins_list.
    CORS_ORIGINS: Union[str, List[str]] = [
        "http://localhost:3000",
        "http://localhost:5173",
    ]
    WS_HEARTBEAT_INTERVAL: int = 30
    CELERY_BROKER_URL: str = "redis://localhost:6379/0"
    CELERY_RESULT_BACKEND: str = "redis://localhost:6379/0"
    MODELS_PATH: str = "models"
    RESULTS_PATH: str = "results"
    BATTERY_MIN_RESERVE: float = 0.10
    BATTERY_MAX_CHARGE: float = 0.90
    ARBITRAGE_MIN_SPREAD: float = 5.0
    MINING_MARGIN_THRESHOLD: float = 5.0
    ML_PREDICTION_HORIZONS: List[int] = [1, 5, 15, 60]
    ML_FEATURE_LAGS: List[int] = [1, 5, 10, 15, 30, 60]
    model_config = SettingsConfigDict(env_file=".env", case_sensitive=True, env_ignore_empty=True)

    @property
    def DATA_PATH_RESOLVED(self) -> Path:
        """DATA_PATH with "~" expanded.

        Fix: this used to be a class-level field computed once, at class
        definition time, from the *default* DATA_PATH -- so overriding
        DATA_PATH via the environment left the resolved path pointing at
        the default. As a property it always tracks the effective value.
        """
        return Path(self.DATA_PATH).expanduser()

    @property
    def cors_origins_list(self) -> List[str]:
        """CORS_ORIGINS normalized to a list (env vars arrive as CSV strings)."""
        if isinstance(self.CORS_ORIGINS, str):
            return [origin.strip() for origin in self.CORS_ORIGINS.split(",")]
        return self.CORS_ORIGINS if isinstance(self.CORS_ORIGINS, list) else []


settings = Settings()

View File

@@ -0,0 +1,25 @@
from app.core.constants import (
DEFAULT_BATTERY_CAPACITY_MWH,
DEFAULT_CHARGE_RATE_MW,
DEFAULT_DISCHARGE_RATE_MW,
DEFAULT_EFFICIENCY,
WS_HEARTBEAT_INTERVAL,
PRICE_DATA_REGIONS,
DEFAULT_DATA_LIMIT,
BACKTEST_RESULT_TIMEOUT,
TRAINING_RESULT_TIMEOUT,
MAX_BACKTEST_TRADES,
)
__all__ = [
"DEFAULT_BATTERY_CAPACITY_MWH",
"DEFAULT_CHARGE_RATE_MW",
"DEFAULT_DISCHARGE_RATE_MW",
"DEFAULT_EFFICIENCY",
"WS_HEARTBEAT_INTERVAL",
"PRICE_DATA_REGIONS",
"DEFAULT_DATA_LIMIT",
"BACKTEST_RESULT_TIMEOUT",
"TRAINING_RESULT_TIMEOUT",
"MAX_BACKTEST_TRADES",
]

View File

@@ -0,0 +1,19 @@
from datetime import timedelta

# Default physical parameters for a simulated battery.
DEFAULT_BATTERY_CAPACITY_MWH = 100.0  # energy capacity, MWh
DEFAULT_CHARGE_RATE_MW = 50.0         # max charging power, MW
DEFAULT_DISCHARGE_RATE_MW = 50.0      # max discharging power, MW
DEFAULT_EFFICIENCY = 0.90             # assumed efficiency factor -- TODO confirm round-trip vs one-way
# WebSocket keep-alive period in seconds (two names, same value --
# presumably DEFAULT_HEARTBEAT_INTERVAL is a legacy alias; confirm usage).
DEFAULT_HEARTBEAT_INTERVAL = 30
WS_HEARTBEAT_INTERVAL = 30
# Market regions with price data available.
PRICE_DATA_REGIONS = ["FR", "BE", "DE", "NL", "UK"]
DEFAULT_DATA_LIMIT = 1000             # default row cap for data queries
# Presumably retention windows for cached backtest/training results --
# verify against the code that consumes them.
BACKTEST_RESULT_TIMEOUT = timedelta(hours=1)
TRAINING_RESULT_TIMEOUT = timedelta(hours=24)
MAX_BACKTEST_TRADES = 10000           # hard cap on trades kept per backtest

68
backend/app/main.py Normal file
View File

@@ -0,0 +1,68 @@
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from app.config import settings
from app.utils.logger import setup_logger, get_logger
from app.api.routes import dashboard, backtest, models, trading, settings as settings_routes
from app.api.websocket import manager
setup_logger()
logger = get_logger(__name__)

# Application object; interactive docs served at /docs and /redoc.
app = FastAPI(
    title=settings.APP_NAME,
    version=settings.APP_VERSION,
    docs_url="/docs",
    redoc_url="/redoc",
)

# Allow the configured frontend origins to call the API from a browser.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount the versioned REST routers.
app.include_router(dashboard.router, prefix="/api/v1/dashboard", tags=["dashboard"])
app.include_router(backtest.router, prefix="/api/v1/backtest", tags=["backtest"])
app.include_router(models.router, prefix="/api/v1/models", tags=["models"])
app.include_router(trading.router, prefix="/api/v1/trading", tags=["trading"])
app.include_router(settings_routes.router, prefix="/api/v1/settings", tags=["settings"])


# NOTE(review): on_event is deprecated in recent FastAPI in favor of the
# lifespan context manager; it still works, consider migrating.
@app.on_event("startup")
async def startup_event():
    """Log a startup banner and the resolved data directory."""
    logger.info(f"Starting {settings.APP_NAME} v{settings.APP_VERSION}")
    logger.info(f"Data path: {settings.DATA_PATH_RESOLVED}")


@app.on_event("shutdown")
async def shutdown_event():
    """Log shutdown."""
    logger.info("Shutting down application")


@app.get("/health")
async def health_check():
    """Liveness probe: reports service status and version."""
    return {"status": "healthy", "version": settings.APP_VERSION}


@app.websocket("/ws/real-time")
async def websocket_endpoint(websocket: WebSocket):
    """Real-time channel: any text a client sends is re-broadcast to all
    connected clients as a "message" event."""
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            await manager.broadcast("message", {"text": data})
    except WebSocketDisconnect:
        manager.disconnect(websocket)


if __name__ == "__main__":
    # Dev entry point; the Docker image runs uvicorn directly via CMD.
    import uvicorn

    uvicorn.run(
        "app.main:app",
        host=settings.HOST,
        port=settings.PORT,
        reload=settings.DEBUG,
    )

View File

@@ -0,0 +1,6 @@
from app.ml.features import (
build_price_features,
build_battery_features,
)
__all__ = ["build_price_features", "build_battery_features"]

View File

@@ -0,0 +1,3 @@
from app.ml.evaluation import ModelEvaluator, BacktestEvaluator
__all__ = ["ModelEvaluator", "BacktestEvaluator"]

View File

@@ -0,0 +1,3 @@
from app.ml.evaluation.metrics import BacktestEvaluator
__all__ = ["BacktestEvaluator"]

View File

@@ -0,0 +1,77 @@
from typing import Dict, List
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
class ModelEvaluator:
    """Static helpers for regression accuracy and trading risk metrics."""

    @staticmethod
    def calculate_metrics(y_true, y_pred) -> Dict[str, float]:
        """Return MAE / RMSE / MAPE (%) / R^2 for a prediction vector.

        Implemented with plain NumPy (numerically equivalent to the sklearn
        calls) so the code no longer relies on
        ``mean_squared_error(..., squared=False)``, which was deprecated in
        scikit-learn 1.4 and removed in 1.6.

        Fix: MAPE now masks zero actuals. The original divided by y_true,
        producing ``inf`` (not NaN) on zero prices, so its NaN guard never
        fired and an infinite MAPE was returned.
        """
        y_true = np.asarray(y_true, dtype=float)
        y_pred = np.asarray(y_pred, dtype=float)
        err = y_true - y_pred
        mae = float(np.mean(np.abs(err)))
        rmse = float(np.sqrt(np.mean(err ** 2)))
        nonzero = y_true != 0
        if nonzero.any():
            mape = float(np.mean(np.abs(err[nonzero] / y_true[nonzero])) * 100)
        else:
            mape = 0.0
        ss_res = float(np.sum(err ** 2))
        ss_tot = float(np.sum((y_true - y_true.mean()) ** 2))
        if ss_tot > 0:
            r2 = 1.0 - ss_res / ss_tot
        else:
            # Constant target: perfect prediction scores 1, anything else 0.
            r2 = 1.0 if ss_res == 0 else 0.0
        return {"mae": mae, "rmse": rmse, "mape": mape, "r2": float(r2)}

    @staticmethod
    def calculate_sharpe_ratio(returns: np.ndarray, risk_free_rate: float = 0.0) -> float:
        """Mean excess return over its standard deviation; 0.0 when
        undefined (empty input or zero variance)."""
        if len(returns) == 0 or np.std(returns) == 0:
            return 0.0
        excess_returns = returns - risk_free_rate
        return float(np.mean(excess_returns) / np.std(excess_returns))

    @staticmethod
    def calculate_max_drawdown(values: np.ndarray) -> float:
        """Largest drop of the cumulative-sum equity curve from its running
        peak (always <= 0); 0.0 for empty input."""
        if len(values) == 0:
            return 0.0
        cumulative = np.cumsum(values)
        running_max = np.maximum.accumulate(cumulative)
        drawdown = cumulative - running_max
        return float(drawdown.min())
class BacktestEvaluator:
    """Accumulates executed trades and summarizes portfolio performance."""

    def __init__(self):
        # Chronological list of trade dicts (each may carry a "revenue" key).
        self.trades: List[Dict] = []

    def add_trade(self, trade: Dict):
        """Record one executed trade."""
        self.trades.append(trade)

    def evaluate(self) -> Dict[str, float]:
        """Summarize total revenue, win rate, Sharpe ratio and max drawdown.

        Returns an all-zero summary when no trades have been recorded.
        """
        if not self.trades:
            return {
                "total_revenue": 0.0,
                "total_trades": 0,
                "win_rate": 0.0,
                "sharpe_ratio": 0.0,
                "max_drawdown": 0.0,
            }
        revenues = [t.get("revenue", 0) for t in self.trades]
        wins = sum(1 for r in revenues if r > 0)
        returns = np.array(revenues)
        return {
            "total_revenue": sum(revenues),
            "total_trades": len(self.trades),
            "win_rate": wins / len(self.trades),
            "sharpe_ratio": ModelEvaluator.calculate_sharpe_ratio(returns),
            "max_drawdown": ModelEvaluator.calculate_max_drawdown(returns),
        }

    def reset(self):
        """Drop all recorded trades."""
        self.trades = []
__all__ = ["ModelEvaluator", "BacktestEvaluator"]

View File

@@ -0,0 +1,3 @@
from app.ml.evaluation.metrics import ModelEvaluator, BacktestEvaluator
__all__ = ["ModelEvaluator", "BacktestEvaluator"]

View File

@@ -0,0 +1,53 @@
from app.ml.features.lag_features import add_lag_features
from app.ml.features.rolling_features import add_rolling_features
from app.ml.features.time_features import add_time_features
from app.ml.features.regional_features import add_regional_features
from app.ml.features.battery_features import add_battery_features
from typing import List, Optional
import pandas as pd
def build_price_features(
    df: pd.DataFrame,
    price_col: str = "real_time_price",
    lags: Optional[List[int]] = None,
    windows: Optional[List[int]] = None,
    regions: Optional[List[str]] = None,
    include_time: bool = True,
    include_regional: bool = True,
) -> pd.DataFrame:
    """Assemble the full feature frame for the price-prediction models.

    Applies, in order: lag and rolling-window features of *price_col* (only
    when the column exists), calendar/cyclic time features (when a
    "timestamp" column exists), and regional price-spread features (when
    *regions* is provided). The input frame is never modified.
    """
    lag_steps = [1, 5, 10, 15, 30, 60] if lags is None else lags
    window_sizes = [5, 10, 15, 30, 60] if windows is None else windows
    features = df.copy()
    if price_col in features.columns:
        features = add_lag_features(features, price_col, lag_steps)
        features = add_rolling_features(features, price_col, window_sizes)
    if include_time and "timestamp" in features.columns:
        features = add_time_features(features)
    if include_regional and regions:
        features = add_regional_features(features, regions)
    return features
def build_battery_features(
    df: pd.DataFrame,
    price_df: pd.DataFrame,
    battery_col: str = "charge_level_mwh",
    capacity_col: str = "capacity_mwh",
    timestamp_col: str = "timestamp",
    battery_id_col: str = "battery_id",
) -> pd.DataFrame:
    """Build battery-state features by delegating to ``add_battery_features``.

    Works on a copy, so the caller's frame is untouched.
    """
    return add_battery_features(
        df.copy(), price_df, battery_col, capacity_col, timestamp_col, battery_id_col
    )
__all__ = ["build_price_features", "build_battery_features"]

View File

@@ -0,0 +1,35 @@
import pandas as pd
def add_battery_features(
df: pd.DataFrame,
price_df: pd.DataFrame,
battery_col: str = "charge_level_mwh",
capacity_col: str = "capacity_mwh",
timestamp_col: str = "timestamp",
battery_id_col: str = "battery_id",
) -> pd.DataFrame:
result = df.copy()
if battery_col in result.columns and capacity_col in result.columns:
result["charge_level_pct"] = result[battery_col] / result[capacity_col]
result["discharge_potential_mwh"] = result[battery_col] * result.get("efficiency", 0.9)
result["charge_capacity_mwh"] = result[capacity_col] - result[battery_col]
if price_df is not None and "real_time_price" in price_df.columns and timestamp_col in result.columns:
merged = result.merge(
price_df[[timestamp_col, "real_time_price"]],
on=timestamp_col,
how="left",
suffixes=("", "_market")
)
if "real_time_price_market" in merged.columns:
result["market_price"] = merged["real_time_price_market"]
result["charge_cost_potential"] = result["charge_capacity_mwh"] * result["market_price"]
result["discharge_revenue_potential"] = result["discharge_potential_mwh"] * result["market_price"]
return result
__all__ = ["add_battery_features"]

View File

@@ -0,0 +1,14 @@
from typing import List
import pandas as pd
def add_lag_features(df: pd.DataFrame, col: str, lags: List[int]) -> pd.DataFrame:
    """Return a copy of *df* with one shifted column per lag step.

    Each new column ``{col}_lag_{k}`` holds the value of *col* from k rows
    earlier; the first k rows of it are NaN. The input frame is untouched.
    """
    out = df.copy()
    for step in lags:
        out[f"{col}_lag_{step}"] = out[col].shift(step)
    return out
__all__ = ["add_lag_features"]

View File

@@ -0,0 +1,18 @@
from typing import List
import pandas as pd
def add_regional_features(df: pd.DataFrame, regions: List[str]) -> pd.DataFrame:
    """Return a copy of *df* with price-vs-regional-average spread columns.

    For every region in *regions*, ``price_diff_{region}`` is each row's
    real_time_price minus that region's mean price over the whole frame
    (0 when the region has no rows). When the "region" or
    "real_time_price" column is missing, a plain copy is returned.
    """
    out = df.copy()
    if "region" not in out.columns or "real_time_price" not in out.columns:
        return out
    region_means = out.groupby("region")["real_time_price"].mean()
    for reg in regions:
        out[f"price_diff_{reg}"] = out["real_time_price"] - region_means.get(reg, 0)
    return out
__all__ = ["add_regional_features"]

View File

@@ -0,0 +1,17 @@
from typing import List
import pandas as pd
def add_rolling_features(df: pd.DataFrame, col: str, windows: List[int]) -> pd.DataFrame:
    """Return a copy of *df* with rolling mean/std/min/max columns of *col*.

    For each window size w, adds ``{col}_rolling_{stat}_{w}`` columns; the
    first w-1 rows of each are NaN. The input frame is untouched.
    """
    out = df.copy()
    for w in windows:
        roll = out[col].rolling(window=w)
        out[f"{col}_rolling_mean_{w}"] = roll.mean()
        out[f"{col}_rolling_std_{w}"] = roll.std()
        out[f"{col}_rolling_min_{w}"] = roll.min()
        out[f"{col}_rolling_max_{w}"] = roll.max()
    return out
__all__ = ["add_rolling_features"]

View File

@@ -0,0 +1,35 @@
import pandas as pd
def add_time_features(df: pd.DataFrame, timestamp_col: str = "timestamp") -> pd.DataFrame:
    """Return a copy of *df* with calendar and cyclic time columns.

    Adds hour / day_of_week / day_of_month / month plus sin/cos encodings
    of hour (period 24) and day_of_week (period 7). Returns an unmodified
    copy when *timestamp_col* is absent.
    """
    out = df.copy()
    if timestamp_col not in out.columns:
        return out
    stamps = pd.to_datetime(out[timestamp_col])
    out[timestamp_col] = stamps
    out["hour"] = stamps.dt.hour
    out["day_of_week"] = stamps.dt.dayofweek
    out["day_of_month"] = stamps.dt.day
    out["month"] = stamps.dt.month
    out["hour_sin"] = _sin_encode(out["hour"], 24)
    out["hour_cos"] = _cos_encode(out["hour"], 24)
    out["day_sin"] = _sin_encode(out["day_of_week"], 7)
    out["day_cos"] = _cos_encode(out["day_of_week"], 7)
    return out


def _sin_encode(values, period):
    """Cyclic sine encoding of *values* over the given period."""
    import numpy as np
    return np.sin(2 * np.pi * values / period)


def _cos_encode(values, period):
    """Cyclic cosine encoding of *values* over the given period."""
    import numpy as np
    return np.cos(2 * np.pi * values / period)
__all__ = ["add_time_features"]

View File

@@ -0,0 +1,3 @@
from app.ml.model_management import ModelRegistry
__all__ = ["ModelRegistry"]

View File

@@ -0,0 +1,99 @@
from typing import Dict, List, Optional
from pathlib import Path
import json
from datetime import datetime
from app.utils.logger import get_logger
logger = get_logger(__name__)
class ModelRegistry:
    """JSON-file-backed registry of trained model versions.

    The registry maps model_id -> {"type": ..., "latest": version,
    "versions": [{"version", "filepath", "timestamp", "metadata"}, ...]}
    and is persisted to *registry_path* after every registration.
    """

    def __init__(self, registry_path: str = "models/registry.json"):
        self.registry_path = Path(registry_path)
        self._registry: Dict[str, Dict] = {}
        self._load()

    def _load(self):
        """Populate the in-memory registry from disk, if the file exists."""
        if self.registry_path.exists():
            with open(self.registry_path) as f:
                self._registry = json.load(f)
            logger.info(f"Loaded registry from {self.registry_path}")

    def _save(self):
        """Persist the registry, creating parent directories as needed."""
        self.registry_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.registry_path, "w") as f:
            json.dump(self._registry, f, indent=2, default=str)
        logger.info(f"Saved registry to {self.registry_path}")

    def register_model(
        self,
        model_type: str,
        model_id: str,
        version: str,
        filepath: str,
        metadata: Optional[Dict] = None,
    ) -> None:
        """Append a version record for *model_id*, mark it latest, and save."""
        timestamp = datetime.utcnow().isoformat()
        if model_id not in self._registry:
            self._registry[model_id] = {
                "type": model_type,
                "versions": [],
            }
        self._registry[model_id]["versions"].append({
            "version": version,
            "filepath": filepath,
            "timestamp": timestamp,
            "metadata": metadata or {},
        })
        self._registry[model_id]["latest"] = version
        self._save()
        logger.info(f"Registered model {model_id} version {version}")

    def get_latest_version(self, model_id: str) -> Optional[Dict]:
        """Return the version record currently marked latest, or None.

        Fix: delegates to get_model(), which previously duplicated this
        lookup logic line for line.
        """
        return self.get_model(model_id)

    def list_models(self) -> List[Dict]:
        """Return a summary row for every registered model."""
        models = []
        for model_id, model_info in self._registry.items():
            latest = self.get_latest_version(model_id)
            models.append({
                "model_id": model_id,
                "type": model_info.get("type"),
                "latest_version": model_info.get("latest"),
                "total_versions": len(model_info.get("versions", [])),
                "latest_info": latest,
            })
        return models

    def get_model(self, model_id: str, version: Optional[str] = None) -> Optional[Dict]:
        """Return a specific version record (latest when *version* is None)."""
        if model_id not in self._registry:
            return None
        if version is None:
            version = self._registry[model_id].get("latest")
        for version_info in self._registry[model_id]["versions"]:
            if version_info["version"] == version:
                return version_info
        return None


__all__ = ["ModelRegistry"]

View File

@@ -0,0 +1,3 @@
from app.ml.price_prediction import PricePredictor, PricePredictionTrainer
__all__ = ["PricePredictor", "PricePredictionTrainer"]

View File

@@ -0,0 +1,52 @@
import pickle
from typing import Optional
import xgboost as xgb
import numpy as np
class PricePredictionModel:
    """XGBoost regressor wrapper for a single forecast horizon (minutes)."""

    def __init__(self, horizon: int, model_id: Optional[str] = None):
        self.horizon = horizon
        self.model_id = model_id or f"price_prediction_{horizon}m"
        self.model: Optional[xgb.XGBRegressor] = None
        # Column names captured at fit time, in training order.
        self.feature_names = []

    def fit(self, X, y):
        """Train a fresh XGBRegressor on (X, y), recording feature names."""
        self.model = xgb.XGBRegressor(
            n_estimators=200,
            max_depth=6,
            learning_rate=0.1,
            subsample=0.8,
            colsample_bytree=0.8,
            random_state=42,
        )
        self.feature_names = (
            [f"feature_{i}" for i in range(X.shape[1])]
            if isinstance(X, np.ndarray)
            else list(X.columns)
        )
        self.model.fit(X, y)

    def predict(self, X):
        """Predict targets for X; raises if fit() has not been called."""
        if self.model is None:
            raise ValueError("Model not trained")
        return self.model.predict(X)

    def save(self, filepath: str):
        """Pickle the whole wrapper (model + metadata) to *filepath*."""
        with open(filepath, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, filepath: str):
        """Unpickle a wrapper previously written by save().

        NOTE(review): pickle executes arbitrary code on load -- only load
        model files from trusted locations.
        """
        with open(filepath, "rb") as f:
            return pickle.load(f)

    @property
    def feature_importances_(self):
        """Underlying booster's feature importances; raises if untrained."""
        if self.model is None:
            raise ValueError("Model not trained")
        return self.model.feature_importances_
__all__ = ["PricePredictionModel"]

View File

@@ -0,0 +1,86 @@
from typing import Dict, Optional
import pandas as pd
import numpy as np
from app.ml.price_prediction.model import PricePredictionModel
from app.ml.price_prediction.trainer import PricePredictionTrainer
from app.utils.logger import get_logger
logger = get_logger(__name__)
class PricePredictor:
    """Serves price forecasts from per-horizon models loaded off disk."""

    def __init__(self, models_dir: str = "models/price_prediction"):
        self.models_dir = models_dir
        # horizon (minutes) -> trained PricePredictionModel
        self.models: Dict[int, PricePredictionModel] = {}
        self._load_models()

    def _load_models(self):
        # NOTE(review): relies on PricePredictionTrainer.load_models, which is
        # not visible here -- confirm it returns {horizon: model}.
        self.models = PricePredictionTrainer.load_models(self.models_dir)
        logger.info(f"Loaded {len(self.models)} prediction models")

    def predict(
        self, current_data: pd.DataFrame, horizon: int = 15, region: Optional[str] = None
    ) -> float:
        """Predict the price *horizon* minutes ahead from the newest row.

        Raises ValueError when no model exists for *horizon*.
        NOTE(review): feature_cols is derived from the feature frame here and
        .values strips column names, so the model is assumed to have been
        trained on the same columns in the same order -- confirm. Also, if the
        region filter leaves an empty frame, iloc[-1:] yields an empty input.
        """
        if horizon not in self.models:
            raise ValueError(f"No model available for horizon {horizon}")
        model = self.models[horizon]
        # Local import avoids a circular import at module load time --
        # presumably; verify against app.ml.features.
        from app.ml.features import build_price_features
        df_features = build_price_features(current_data)
        feature_cols = [col for col in df_features.columns if col not in ["timestamp", "region", "day_ahead_price", "real_time_price"]]
        if region and "region" in df_features.columns:
            df_features = df_features[df_features["region"] == region]
        # Most recent row only (kept as a 1-row frame).
        latest_row = df_features.iloc[-1:][feature_cols]
        prediction = model.predict(latest_row.values)
        return float(prediction[0])

    def predict_all_horizons(self, current_data: pd.DataFrame, region: Optional[str] = None) -> Dict[int, float]:
        """Predict for every loaded horizon; failed horizons map to None."""
        predictions = {}
        for horizon in sorted(self.models.keys()):
            try:
                pred = self.predict(current_data, horizon, region)
                predictions[horizon] = pred
            except Exception as e:
                logger.error(f"Failed to predict for horizon {horizon}: {e}")
                predictions[horizon] = None
        return predictions

    def predict_with_confidence(
        self, current_data: pd.DataFrame, horizon: int = 15, region: Optional[str] = None
    ) -> Dict:
        """Point forecast with a fixed +/-5% band.

        NOTE(review): the band is a hard-coded multiplier, not a statistical
        confidence interval.
        """
        prediction = self.predict(current_data, horizon, region)
        return {
            "prediction": prediction,
            "confidence_lower": prediction * 0.95,
            "confidence_upper": prediction * 1.05,
            "horizon": horizon,
        }

    def get_feature_importance(self, horizon: int) -> pd.DataFrame:
        """Feature importances of one horizon's model, highest first."""
        if horizon not in self.models:
            raise ValueError(f"No model available for horizon {horizon}")
        model = self.models[horizon]
        importances = model.feature_importances_
        feature_names = model.feature_names
        df = pd.DataFrame({
            "feature": feature_names,
            "importance": importances,
        }).sort_values("importance", ascending=False)
        return df

View File

@@ -0,0 +1,142 @@
from typing import List, Dict, Tuple, Optional
from pathlib import Path
import pandas as pd
from app.ml.price_prediction.model import PricePredictionModel
from app.utils.logger import get_logger
logger = get_logger(__name__)
class PricePredictionTrainer:
    """Trains per-horizon price-prediction models from regional parquet data."""

    def __init__(self, config=None):
        # Optional training configuration; not consumed in the visible code.
        self.config = config
        # Concatenated training frame populated by load_data().
        self.data: Optional[pd.DataFrame] = None
        # horizon (minutes) -> trained model, populated during training.
        self.models: Dict[int, PricePredictionModel] = {}
def load_data(self, data_path: Optional[str] = None) -> pd.DataFrame:
if data_path is None:
data_path = "~/energy-test-data/data/processed"
path = Path(data_path).expanduser()
dfs = []
for region in ["FR", "BE", "DE", "NL", "UK"]:
file_path = path / f"{region.lower()}_processed.parquet"
if file_path.exists():
df = pd.read_parquet(file_path)
df["region"] = region
dfs.append(df)
if dfs:
self.data = pd.concat(dfs, ignore_index=True)
logger.info(f"Loaded data: {len(self.data)} rows")
return self.data
def prepare_data(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
from app.ml.features import build_price_features
df_features = build_price_features(df)
df_features = df_features.dropna()
feature_cols = [col for col in df_features.columns if col not in ["timestamp", "region", "day_ahead_price", "real_time_price"]]
return df_features, feature_cols
def train_for_horizon(
self, df_features: pd.DataFrame, feature_cols: List[str], horizon: int
) -> Dict:
logger.info(f"Training model for {horizon} minute horizon")
df_features = df_features.sort_values("timestamp")
n_total = len(df_features)
n_train = int(n_total * 0.70)
n_val = int(n_total * 0.85)
train_data = df_features.iloc[:n_train]
val_data = df_features.iloc[n_train:n_val]
X_train = train_data[feature_cols]
y_train = train_data["real_time_price"].shift(-horizon).dropna()
X_train = X_train.loc[y_train.index]
X_val = val_data[feature_cols]
y_val = val_data["real_time_price"].shift(-horizon).dropna()
X_val = X_val.loc[y_val.index]
model = PricePredictionModel(horizon=horizon)
model.fit(X_train, y_train)
val_preds = model.predict(X_val)
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
mae = mean_absolute_error(y_val, val_preds)
rmse = mean_squared_error(y_val, val_preds, squared=False)
r2 = r2_score(y_val, val_preds)
self.models[horizon] = model
results = {
"horizon": horizon,
"mae": mae,
"rmse": rmse,
"r2": r2,
"n_train": len(X_train),
"n_val": len(X_val),
}
logger.info(f"Training complete for {horizon}m: MAE={mae:.2f}, RMSE={rmse:.2f}, R2={r2:.3f}")
return results
def train_all(self, horizons: Optional[List[int]] = None) -> Dict:
if horizons is None:
horizons = [1, 5, 15, 60]
if self.data is None:
self.load_data()
df_features, feature_cols = self.prepare_data(self.data)
all_results = {}
for horizon in horizons:
try:
result = self.train_for_horizon(df_features, feature_cols, horizon)
all_results[horizon] = result
except Exception as e:
logger.error(f"Failed to train for horizon {horizon}: {e}")
all_results[horizon] = {"error": str(e)}
return all_results
def save_models(self, output_dir: str = "models/price_prediction") -> None:
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
for horizon, model in self.models.items():
filepath = output_path / f"model_{horizon}min.pkl"
model.save(filepath)
logger.info(f"Saved model for {horizon}m to {filepath}")
@classmethod
def load_models(cls, models_dir: str = "models/price_prediction", horizons: Optional[List[int]] = None) -> Dict[int, PricePredictionModel]:
models = {}
path = Path(models_dir)
if horizons is None:
horizons = [1, 5, 15, 60]
for horizon in horizons:
filepath = path / f"model_{horizon}min.pkl"
if filepath.exists():
model = PricePredictionModel.load(filepath)
models[horizon] = model
logger.info(f"Loaded model for {horizon}m")
return models
__all__ = ["PricePredictionTrainer"]

View File

@@ -0,0 +1,3 @@
from app.ml.rl_battery import BatteryPolicy, BatteryRLTrainer
__all__ = ["BatteryPolicy", "BatteryRLTrainer"]

View File

@@ -0,0 +1,88 @@
from typing import Dict, Optional
import numpy as np
import pickle
from app.utils.logger import get_logger
logger = get_logger(__name__)
class QLearningAgent:
    """Tabular epsilon-greedy Q-learning agent over a discretized state space.

    Continuous state vectors (components in [0, 1]) are discretized into
    ``state_bins`` buckets per dimension; the flattened bucket index selects
    a row of the Q-table.
    """

    def __init__(
        self,
        state_bins: int = 10,
        action_space: int = 3,
        learning_rate: float = 0.1,
        discount_factor: float = 0.95,
        epsilon: float = 1.0,
        epsilon_decay: float = 0.995,
        epsilon_min: float = 0.05,
    ):
        self.state_bins = state_bins
        self.action_space = action_space
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon              # current exploration rate
        self.epsilon_decay = epsilon_decay  # multiplicative decay per episode
        self.epsilon_min = epsilon_min      # floor for epsilon
        # Allocated lazily via initialize_q_table (dimensionality unknown here).
        self.q_table: Optional[np.ndarray] = None
        self.policy_id = "battery_policy"

    def initialize_q_table(self, observation_space: int):
        """Allocate a zeroed Q-table of shape (state_bins ** dims, actions)."""
        self.q_table = np.zeros((self.state_bins ** observation_space, self.action_space))

    def _discretize_state(self, state: np.ndarray) -> int:
        """Map a continuous state vector in [0, 1] to a flat Q-table row index."""
        discretized = (state * self.state_bins).astype(int)
        discretized = np.clip(discretized, 0, self.state_bins - 1)
        # Mixed-radix encoding: dimension i contributes val_i * state_bins**i.
        index = 0
        multiplier = 1
        for val in discretized:
            index += val * multiplier
            multiplier *= self.state_bins
        # Cast: val is a numpy integer, so index would otherwise be np.int64.
        return int(index)

    def get_action(self, state: np.ndarray, training: bool = True) -> int:
        """Pick an action: epsilon-greedy while training, greedy otherwise."""
        state_idx = self._discretize_state(state)
        if training and np.random.random() < self.epsilon:
            return int(np.random.randint(self.action_space))
        if self.q_table is None:
            # No table yet: fall back to the neutral action (index 1 = hold).
            return 1
        # np.argmax returns a numpy scalar; callers and the annotation expect
        # a plain Python int (e.g. for JSON serialization), so cast.
        return int(np.argmax(self.q_table[state_idx]))

    def update(self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool):
        """Standard one-step Q-learning update; no-op before initialization."""
        if self.q_table is None:
            return
        state_idx = self._discretize_state(state)
        next_state_idx = self._discretize_state(next_state)
        current_q = self.q_table[state_idx, action]
        if done:
            # Terminal transition: no bootstrapped future value.
            target = reward
        else:
            next_q = np.max(self.q_table[next_state_idx])
            target = reward + self.discount_factor * next_q
        self.q_table[state_idx, action] += self.learning_rate * (target - current_q)

    def decay_epsilon(self):
        """Multiplicatively anneal epsilon down to ``epsilon_min``."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def save(self, filepath: str):
        """Pickle the whole agent (Q-table and hyperparameters) to ``filepath``."""
        with open(filepath, "wb") as f:
            pickle.dump(self, f)
        logger.info(f"Saved Q-learning policy to {filepath}")

    @classmethod
    def load(cls, filepath: str):
        """Unpickle an agent saved by :meth:`save`.

        SECURITY: ``pickle.load`` executes arbitrary code — only load files
        from trusted locations.
        """
        with open(filepath, "rb") as f:
            return pickle.load(f)
__all__ = ["QLearningAgent"]

View File

@@ -0,0 +1,87 @@
from typing import Dict, Tuple
import numpy as np
class BatteryEnvironment:
    """Minimal battery-arbitrage simulation with a random-walk price.

    Actions: 0 = charge, 1 = hold, 2 = discharge. One step represents one
    minute; an episode is one simulated day (1440 steps). Rewards are in
    thousands of currency units (cost/revenue divided by 1000).
    """

    def __init__(
        self,
        capacity: float = 100.0,
        charge_rate: float = 50.0,
        discharge_rate: float = 50.0,
        efficiency: float = 0.9,
        min_reserve: float = 0.1,
        max_charge: float = 0.9,
    ):
        self.capacity = capacity
        self.charge_rate = charge_rate
        self.discharge_rate = discharge_rate
        self.efficiency = efficiency
        self.min_reserve = min_reserve
        self.max_charge = max_charge
        # Start half full at a 50-unit reference price.
        self.charge_level = capacity * 0.5
        self.current_price = 50.0
        self.time_step = 0

    def reset(self) -> np.ndarray:
        """Return to the initial half-charged state and emit the observation."""
        self.charge_level = self.capacity * 0.5
        self.current_price = 50.0
        self.time_step = 0
        return self._get_state()

    def _get_state(self) -> np.ndarray:
        """Observation: [charge fraction, price/200 clipped to [0,1], day fraction]."""
        soc = self.charge_level / self.capacity
        scaled_price = np.clip(self.current_price / 200.0, 0, 1)
        day_fraction = (self.time_step % 1440) / 1440.0
        return np.array([soc, scaled_price, day_fraction])

    def _apply_charge(self, price: float) -> float:
        """Charge as much as headroom allows; reward is the (negative) cost."""
        headroom = self.capacity * self.max_charge - self.charge_level
        energy_bought = min(self.charge_rate, headroom)
        # Round-trip losses: only efficiency * bought energy is stored.
        self.charge_level += energy_bought * self.efficiency
        return -energy_bought * price / 1000.0

    def _apply_discharge(self, price: float) -> float:
        """Discharge down to the reserve floor; reward is the sale revenue."""
        available = self.charge_level - self.capacity * self.min_reserve
        energy_sold = min(self.discharge_rate, available)
        sale_revenue = energy_sold * price
        self.charge_level -= energy_sold / self.efficiency
        return sale_revenue / 1000.0

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
        """Advance one minute; return (state, reward, done, info)."""
        price_before = self.current_price
        if action == 0:
            reward = self._apply_charge(price_before)
        elif action == 2:
            reward = self._apply_discharge(price_before)
        else:
            # Hold (action 1) and any unexpected action are both no-ops.
            reward = 0.0
        # Enforce the operational charge window.
        self.charge_level = np.clip(
            self.charge_level,
            self.capacity * self.min_reserve,
            self.capacity * self.max_charge,
        )
        # Gaussian random-walk price, clipped to a plausible range.
        self.current_price = np.clip(price_before + np.random.randn() * 5, 0, 300)
        self.time_step += 1
        info = {
            "charge_level": self.charge_level,
            "price": self.current_price,
            "action": action,
        }
        return self._get_state(), reward, self.time_step >= 1440, info

    @property
    def action_space(self):
        """Number of discrete actions."""
        return 3

    @property
    def observation_space(self):
        """Dimensionality of the observation vector."""
        return 3
__all__ = ["BatteryEnvironment"]

View File

@@ -0,0 +1,65 @@
from typing import Dict
from app.ml.rl_battery.agent import QLearningAgent
from app.ml.rl_battery.environment import BatteryEnvironment
from app.utils.logger import get_logger
logger = get_logger(__name__)
class BatteryPolicy:
    """Inference wrapper around a persisted Q-learning battery policy.

    If no policy file exists on disk, every request falls back to a neutral
    "hold" recommendation with zero confidence.
    """

    def __init__(self, policy_path: str = "models/rl_battery"):
        self.policy_path = policy_path
        self.agent: QLearningAgent = None
        self.env: BatteryEnvironment = None
        self._load_policy()

    def _load_policy(self):
        """Load the pickled agent if present; otherwise leave ``agent`` unset."""
        from pathlib import Path
        filepath = Path(self.policy_path) / "battery_policy.pkl"
        if filepath.exists():
            self.agent = QLearningAgent.load(filepath)
            self.env = BatteryEnvironment()
            logger.info(f"Loaded policy from {filepath}")

    def get_action(
        self,
        charge_level: float,
        current_price: float,
        price_forecast_1m: float = 0,
        price_forecast_5m: float = 0,
        price_forecast_15m: float = 0,
        hour: int = 0,
    ) -> Dict:
        """Recommend charge/hold/discharge for the given battery snapshot.

        NOTE(review): the ``price_forecast_*`` parameters are accepted but
        currently unused — the observation is built only from charge level,
        price and hour.
        """
        if self.agent is None:
            # No trained policy on disk -> neutral fallback.
            return {
                "action": "hold",
                "q_values": [0.0, 0.0, 0.0],
                "confidence": 0.0,
            }
        # Project the snapshot into the environment so it builds the same
        # observation vector the agent was trained on.
        self.env.charge_level = charge_level
        self.env.current_price = current_price
        self.env.time_step = hour * 60
        observation = self.env._get_state()
        choice = self.agent.get_action(observation, training=False)
        label = ["charge", "hold", "discharge"][choice]
        row = self.agent._discretize_state(observation)
        if self.agent.q_table is not None:
            q_values = self.agent.q_table[row].tolist()
        else:
            q_values = [0.0, 0.0, 0.0]
        best = max(q_values) if q_values else 0.0
        # Spread between best and worst Q-value, scaled by the best value.
        spread_ratio = (best - min(q_values)) / (best + 0.001) if q_values else 0.0
        return {
            "action": label,
            "q_values": q_values,
            "confidence": min(spread_ratio, 1.0),
        }
__all__ = ["BatteryPolicy"]

View File

@@ -0,0 +1,95 @@
from typing import Dict
from app.ml.rl_battery.environment import BatteryEnvironment
from app.ml.rl_battery.agent import QLearningAgent
from app.utils.logger import get_logger
logger = get_logger(__name__)
class BatteryRLTrainer:
    """Trains a tabular Q-learning policy on the simulated BatteryEnvironment."""

    def __init__(self, config=None):
        # Hyperparameter overrides; each key falls back to a default.
        self.config = config or {}
        self.agent: QLearningAgent = None
        self.env: BatteryEnvironment = None

    def _create_agent(self) -> QLearningAgent:
        """Build an agent from ``self.config`` with sensible defaults."""
        return QLearningAgent(
            state_bins=self.config.get("charge_level_bins", 10),
            action_space=3,
            learning_rate=self.config.get("learning_rate", 0.1),
            discount_factor=self.config.get("discount_factor", 0.95),
            epsilon=self.config.get("epsilon", 1.0),
            epsilon_decay=self.config.get("epsilon_decay", 0.995),
            epsilon_min=self.config.get("epsilon_min", 0.05),
        )

    def load_data(self):
        """Placeholder: training currently uses only the synthetic environment."""
        pass

    def train(self, n_episodes: int = 1000, region: str = "FR") -> Dict:
        """Run ``n_episodes`` of Q-learning and return training statistics.

        Args:
            n_episodes: number of simulated days to train on.
            region: accepted for API symmetry; the synthetic environment does
                not use it.

        Returns:
            Dict with episode rewards, the trailing-average reward and the
            final exploration rate.
        """
        logger.info(f"Starting RL training for {n_episodes} episodes")
        self.env = BatteryEnvironment(
            capacity=100.0,
            charge_rate=50.0,
            discharge_rate=50.0,
            efficiency=0.9,
            min_reserve=0.1,
            max_charge=0.9,
        )
        self.agent = self._create_agent()
        self.agent.initialize_q_table(self.env.observation_space)
        episode_rewards = []
        for episode in range(n_episodes):
            state = self.env.reset()
            total_reward = 0
            steps = 0
            while True:
                action = self.agent.get_action(state, training=True)
                next_state, reward, done, info = self.env.step(action)
                self.agent.update(state, action, reward, next_state, done)
                total_reward += reward
                state = next_state
                steps += 1
                if done:
                    break
            episode_rewards.append(total_reward)
            # Anneal exploration once per episode.
            self.agent.decay_epsilon()
            if (episode + 1) % 100 == 0:
                # Window is guaranteed full (>= 100 episodes) here.
                avg_reward = sum(episode_rewards[-100:]) / 100
                logger.info(f"Episode {episode + 1}/{n_episodes}, Avg Reward: {avg_reward:.2f}, Epsilon: {self.agent.epsilon:.3f}")
        # Average over the actual trailing window: dividing by a hard-coded
        # 100 under-reports the mean when n_episodes < 100 and would divide
        # a zero-length sum for n_episodes == 0.
        recent = episode_rewards[-100:]
        final_avg_reward = sum(recent) / len(recent) if recent else 0.0
        results = {
            "n_episodes": n_episodes,
            "final_avg_reward": final_avg_reward,
            "episode_rewards": episode_rewards,
            "final_epsilon": self.agent.epsilon,
        }
        logger.info(f"Training complete. Final avg reward: {final_avg_reward:.2f}")
        return results

    def save(self, output_dir: str = "models/rl_battery") -> None:
        """Persist the trained policy as ``battery_policy.pkl`` under ``output_dir``."""
        from pathlib import Path
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        filepath = output_path / "battery_policy.pkl"
        self.agent.save(filepath)
        logger.info(f"Saved trained policy to {filepath}")
__all__ = ["BatteryRLTrainer"]

View File

@@ -0,0 +1,3 @@
from app.ml.training import CLITrainer
__all__ = ["CLITrainer"]

View File

@@ -0,0 +1,49 @@
import argparse
from app.ml.price_prediction.trainer import PricePredictionTrainer
from app.ml.rl_battery.trainer import BatteryRLTrainer
from app.utils.logger import get_logger, setup_logger
setup_logger()
logger = get_logger(__name__)
def _build_parser() -> argparse.ArgumentParser:
    """Assemble the CLI with its ``price`` and ``rl`` sub-commands."""
    parser = argparse.ArgumentParser(description="Energy Trading ML Training CLI")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")
    price_parser = subparsers.add_parser("price", help="Train price prediction models")
    price_parser.add_argument("--horizons", nargs="+", type=int, default=[1, 5, 15, 60], help="Prediction horizons in minutes")
    price_parser.add_argument("--output", type=str, default="models/price_prediction", help="Output directory")
    rl_parser = subparsers.add_parser("rl", help="Train RL battery policy")
    rl_parser.add_argument("--episodes", type=int, default=1000, help="Number of training episodes")
    rl_parser.add_argument("--region", type=str, default="FR", help="Region to train for")
    rl_parser.add_argument("--output", type=str, default="models/rl_battery", help="Output directory")
    return parser


def _train_price(args) -> None:
    """Train, persist and summarize the supervised price-prediction models."""
    logger.info(f"Training price prediction models for horizons: {args.horizons}")
    trainer = PricePredictionTrainer()
    results = trainer.train_all(horizons=args.horizons)
    trainer.save_models(output_dir=args.output)
    logger.info("Training complete!")
    for horizon, result in results.items():
        if "error" not in result:
            logger.info(f" {horizon}m: MAE={result['mae']:.2f}, RMSE={result['rmse']:.2f}, R2={result['r2']:.3f}")


def _train_rl(args) -> None:
    """Train, persist and summarize the RL battery policy."""
    logger.info(f"Training RL battery policy for {args.episodes} episodes")
    trainer = BatteryRLTrainer()
    results = trainer.train(n_episodes=args.episodes, region=args.region)
    trainer.save(output_dir=args.output)
    logger.info("Training complete!")
    logger.info(f" Final avg reward: {results['final_avg_reward']:.2f}")
    logger.info(f" Final epsilon: {results['final_epsilon']:.3f}")


def main():
    """CLI entry point: dispatch to the selected training sub-command."""
    parser = _build_parser()
    args = parser.parse_args()
    handlers = {"price": _train_price, "rl": _train_rl}
    handler = handlers.get(args.command)
    if handler is None:
        # No (or unknown) sub-command: show usage instead of failing.
        parser.print_help()
    else:
        handler(args)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,3 @@
from app.ml.utils import time_based_split, MLConfig
__all__ = ["time_based_split", "MLConfig"]

View File

@@ -0,0 +1,16 @@
from dataclasses import dataclass
from typing import List, Dict, Any
@dataclass
class MLConfig:
    """Global ML runtime switches shared by the training pipelines."""

    enable_gpu: bool = False
    n_jobs: int = 4
    verbose: bool = True

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "MLConfig":
        """Build a config from a dict, silently dropping unknown keys."""
        known_fields = cls.__annotations__
        recognized = {key: value for key, value in config_dict.items() if key in known_fields}
        return cls(**recognized)
__all__ = ["MLConfig"]

View File

@@ -0,0 +1,25 @@
from typing import Tuple
import pandas as pd
from datetime import datetime
def time_based_split(
    df: pd.DataFrame,
    timestamp_col: str = "timestamp",
    train_pct: float = 0.70,
    val_pct: float = 0.85,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Chronologically split ``df`` into (train, validation, test) partitions.

    Rows are ordered by ``timestamp_col``; the first ``train_pct`` fraction
    becomes the training set, rows up to ``val_pct`` the validation set, and
    the remainder the test set. The percentages are cumulative cut points,
    not per-split sizes.
    """
    ordered = df.sort_values(timestamp_col)
    row_count = len(ordered)
    train_end = int(row_count * train_pct)
    val_end = int(row_count * val_pct)
    return (
        ordered.iloc[:train_end],
        ordered.iloc[train_end:val_end],
        ordered.iloc[val_end:],
    )
__all__ = ["time_based_split"]

View File

@@ -0,0 +1,4 @@
from app.ml.utils.data_split import time_based_split
from app.ml.utils.config import MLConfig
__all__ = ["time_based_split", "MLConfig"]

View File

@@ -0,0 +1,52 @@
from app.models.schemas import (
PriceData,
BatteryState,
BacktestConfig,
BacktestMetrics,
TrainingRequest,
PredictionResponse,
ModelInfo,
TrainingStatus,
ArbitrageOpportunity,
DashboardSummary,
Trade,
StrategyStatus,
Alert,
AppSettings,
)
from app.models.enums import (
RegionEnum,
FuelTypeEnum,
StrategyEnum,
TradeTypeEnum,
BacktestStatusEnum,
ModelType,
AlertTypeEnum,
TrainingStatusEnum,
)
__all__ = [
"PriceData",
"BatteryState",
"BacktestConfig",
"BacktestMetrics",
"TrainingRequest",
"PredictionResponse",
"ModelInfo",
"TrainingStatus",
"ArbitrageOpportunity",
"DashboardSummary",
"Trade",
"StrategyStatus",
"Alert",
"AppSettings",
"RegionEnum",
"FuelTypeEnum",
"StrategyEnum",
"TradeTypeEnum",
"BacktestStatusEnum",
"ModelType",
"AlertTypeEnum",
"TrainingStatusEnum",
]

View File

@@ -0,0 +1,60 @@
from enum import Enum
class RegionEnum(str, Enum):
    """Electricity-market regions supported by the data pipeline."""
    FR = "FR"
    BE = "BE"
    DE = "DE"
    NL = "NL"
    UK = "UK"
class FuelTypeEnum(str, Enum):
    """Generation fuel/technology categories."""
    GAS = "gas"
    NUCLEAR = "nuclear"
    COAL = "coal"
    SOLAR = "solar"
    WIND = "wind"
    HYDRO = "hydro"
class StrategyEnum(str, Enum):
    """Trading strategy families available to the system."""
    FUNDAMENTAL = "fundamental"
    TECHNICAL = "technical"
    ML = "ml"
    MINING = "mining"
class TradeTypeEnum(str, Enum):
    """Kinds of recorded trade actions, including battery operations."""
    BUY = "buy"
    SELL = "sell"
    CHARGE = "charge"
    DISCHARGE = "discharge"
class BacktestStatusEnum(str, Enum):
    """Lifecycle states of a backtest run."""
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class ModelType(str, Enum):
    """Categories of ML model artifacts managed by the backend."""
    PRICE_PREDICTION = "price_prediction"
    RL_BATTERY = "rl_battery"
class AlertTypeEnum(str, Enum):
    """Categories of alerts the system can raise."""
    PRICE_SPIKE = "price_spike"
    ARBITRAGE_OPPORTUNITY = "arbitrage_opportunity"
    BATTERY_LOW = "battery_low"
    BATTERY_FULL = "battery_full"
    STRATEGY_ERROR = "strategy_error"
class TrainingStatusEnum(str, Enum):
    """Lifecycle states of a training job."""
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"

View File

@@ -0,0 +1,145 @@
from datetime import datetime
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
from app.models.enums import (
RegionEnum,
StrategyEnum,
TradeTypeEnum,
BacktestStatusEnum,
ModelType,
AlertTypeEnum,
TrainingStatusEnum,
)
class PriceData(BaseModel):
    """One market observation: day-ahead vs real-time price for a region."""
    timestamp: datetime
    region: RegionEnum
    day_ahead_price: float
    real_time_price: float
    volume_mw: float
class BatteryState(BaseModel):
    """Snapshot of one battery's capacity, charge level and rate limits."""
    timestamp: datetime
    battery_id: str
    capacity_mwh: float
    charge_level_mwh: float
    charge_rate_mw: float
    discharge_rate_mw: float
    efficiency: float
    # Defaults to 0.0; NOT derived automatically from charge_level_mwh /
    # capacity_mwh — producers must populate it themselves.
    charge_level_pct: float = Field(default_factory=lambda: 0.0)
class BacktestConfig(BaseModel):
    """Parameters for launching a backtest run.

    Dates are plain strings; their expected format is not validated here.
    """
    start_date: str
    end_date: str
    strategies: List[StrategyEnum] = Field(default_factory=list)
    use_ml: bool = True
    # None presumably falls back to the application-level defaults —
    # confirm in the backtest runner.
    battery_min_reserve: Optional[float] = None
    battery_max_charge: Optional[float] = None
    arbitrage_min_spread: Optional[float] = None
class BacktestMetrics(BaseModel):
    """Aggregate performance metrics produced by a completed backtest."""
    total_revenue: float
    arbitrage_profit: float
    battery_revenue: float
    mining_profit: float
    battery_utilization: float
    price_capture_rate: float
    win_rate: float
    sharpe_ratio: float
    max_drawdown: float
    total_trades: int
class TrainingRequest(BaseModel):
    """Request body for starting a model training job."""
    model_type: ModelType
    # Presumably the forecast horizon in minutes (as used by the price
    # models) — confirm with the training endpoint.
    horizon: Optional[int] = None
    start_date: str
    end_date: str
    hyperparameters: Dict[str, Any] = Field(default_factory=dict)
class PredictionResponse(BaseModel):
    """A single model prediction, with an optional confidence score."""
    model_id: str
    timestamp: datetime
    prediction: float
    confidence: Optional[float] = None
    features_used: List[str] = Field(default_factory=list)
class ModelInfo(BaseModel):
    """Registry metadata describing a persisted model artifact."""
    model_id: str
    model_type: ModelType
    version: str
    created_at: datetime
    metrics: Dict[str, float] = Field(default_factory=dict)
    hyperparameters: Dict[str, Any] = Field(default_factory=dict)
class TrainingStatus(BaseModel):
    """Progress and outcome of an asynchronous training job."""
    training_id: str
    status: TrainingStatusEnum
    # Progress fraction; presumably 0.0-1.0 — confirm with the producer.
    progress: float = 0.0
    current_epoch: Optional[int] = None
    total_epochs: Optional[int] = None
    metrics: Dict[str, float] = Field(default_factory=dict)
    error_message: Optional[str] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
class ArbitrageOpportunity(BaseModel):
    """A cross-region price spread: buy in one region, sell in another."""
    timestamp: datetime
    buy_region: RegionEnum
    sell_region: RegionEnum
    buy_price: float
    sell_price: float
    spread: float
    volume_mw: float
class DashboardSummary(BaseModel):
    """Headline numbers shown on the dashboard view."""
    latest_timestamp: datetime
    total_volume_mw: float
    avg_realtime_price: float
    arbitrage_count: int
    battery_count: int
    avg_battery_charge: float
class Trade(BaseModel):
    """One executed (or simulated) trade belonging to a backtest run."""
    timestamp: datetime
    backtest_id: str
    trade_type: TradeTypeEnum
    region: Optional[RegionEnum] = None
    price: float
    volume_mw: float
    revenue: float
    # Set only for battery charge/discharge trades.
    battery_id: Optional[str] = None
class StrategyStatus(BaseModel):
    """Runtime state and cumulative results for one trading strategy."""
    strategy: StrategyEnum
    enabled: bool
    last_execution: Optional[datetime] = None
    total_trades: int = 0
    profit_loss: float = 0.0
class Alert(BaseModel):
    """A system alert raised by the backend."""
    alert_id: str
    alert_type: AlertTypeEnum
    timestamp: datetime
    message: str
    # Free-form payload specific to the alert type.
    data: Dict[str, Any] = Field(default_factory=dict)
    acknowledged: bool = False
class AppSettings(BaseModel):
    """User-tunable trading thresholds."""
    battery_min_reserve: float
    battery_max_charge: float
    arbitrage_min_spread: float
    mining_margin_threshold: float

View File

@@ -0,0 +1,13 @@
from app.services.data_service import DataService
from app.services.strategy_service import StrategyService
from app.services.ml_service import MLService
from app.services.trading_service import TradingService
from app.services.alert_service import AlertService
__all__ = [
"DataService",
"StrategyService",
"MLService",
"TradingService",
"AlertService",
]

View File

@@ -0,0 +1,76 @@
from typing import Dict, List, Optional
from datetime import datetime
from app.models.enums import AlertTypeEnum
from app.models.schemas import Alert
from app.utils.logger import get_logger
import uuid
logger = get_logger(__name__)
class AlertService:
    """In-memory alert store (state is lost on restart)."""

    def __init__(self):
        self._alerts: List[Alert] = []
        self._acknowledged: List[str] = []

    async def create_alert(
        self,
        alert_type: AlertTypeEnum,
        message: str,
        data: Optional[Dict] = None,
    ) -> Alert:
        """Record a new unacknowledged alert and return it."""
        alert_id = str(uuid.uuid4())
        new_alert = Alert(
            alert_id=alert_id,
            alert_type=alert_type,
            timestamp=datetime.utcnow(),
            message=message,
            data=data or {},
            acknowledged=False,
        )
        self._alerts.append(new_alert)
        logger.warning(f"Alert created: {alert_id}, type: {alert_type.value}, message: {message}")
        return new_alert

    async def get_alerts(
        self,
        alert_type: Optional[AlertTypeEnum] = None,
        acknowledged: Optional[bool] = None,
        limit: int = 100,
    ) -> List[Alert]:
        """Return up to ``limit`` most recent alerts matching the filters."""
        def matches(candidate) -> bool:
            if alert_type and candidate.alert_type != alert_type:
                return False
            if acknowledged is not None and candidate.acknowledged != acknowledged:
                return False
            return True

        matching = [a for a in self._alerts if matches(a)]
        return matching[-limit:]

    async def acknowledge_alert(self, alert_id: str) -> Alert:
        """Mark the alert with ``alert_id`` as acknowledged.

        Raises:
            ValueError: if no alert with that id exists.
        """
        found = next((a for a in self._alerts if a.alert_id == alert_id), None)
        if found is None:
            raise ValueError(f"Alert not found: {alert_id}")
        found.acknowledged = True
        logger.info(f"Alert acknowledged: {alert_id}")
        return found

    async def get_alert_summary(self) -> Dict:
        """Aggregate alert counts by type plus the newest alert timestamp."""
        type_counts: Dict[str, int] = {}
        pending = 0
        for entry in self._alerts:
            key = entry.alert_type.value
            type_counts[key] = type_counts.get(key, 0) + 1
            if not entry.acknowledged:
                pending += 1
        return {
            "total_alerts": len(self._alerts),
            "unacknowledged": pending,
            "by_type": type_counts,
            "latest_alert": self._alerts[-1].timestamp if self._alerts else None,
        }

View File

@@ -0,0 +1,174 @@
from typing import Dict, List, Optional
from pathlib import Path
import pandas as pd
from datetime import datetime
from app.config import settings
from app.utils.logger import get_logger
logger = get_logger(__name__)
class DataService:
    """Loads processed market parquet files and serves read-only views.

    All data is read once during ``initialize()`` into in-memory DataFrames;
    the service never writes back to disk.
    """

    def __init__(self):
        # Root directory containing the processed parquet files.
        self.data_path: Path = settings.DATA_PATH_RESOLVED
        # region code -> price history DataFrame.
        self._price_data: Dict[str, pd.DataFrame] = {}
        self._battery_data: Optional[pd.DataFrame] = None
        self._loaded: bool = False

    async def initialize(self):
        """Load price and battery parquet files into memory."""
        logger.info(f"Loading data from {self.data_path}")
        self._load_price_data()
        self._load_battery_data()
        self._loaded = True
        logger.info("Data loaded successfully")

    def _load_price_data(self):
        """Read electricity_prices.parquet and partition rows by region code."""
        if not self.data_path.exists():
            logger.warning(f"Data path {self.data_path} does not exist")
            return
        prices_file = self.data_path / "electricity_prices.parquet"
        if prices_file.exists():
            df = pd.read_parquet(prices_file)
            logger.info(f"Loaded price data: {len(df)} total rows from {prices_file}")
            if "region" in df.columns:
                # Only the known regions are kept; other codes are dropped.
                for region in ["FR", "BE", "DE", "NL", "UK"]:
                    region_df = df[df["region"] == region].copy()
                    if len(region_df) > 0:
                        self._price_data[region] = region_df
                        logger.info(f"Loaded {region} price data: {len(region_df)} rows")
            else:
                logger.warning("Price data file does not contain 'region' column")
        else:
            logger.warning(f"Price data file not found: {prices_file}")

    def _load_battery_data(self):
        """Read battery_capacity.parquet if present; warn otherwise."""
        battery_path = self.data_path / "battery_capacity.parquet"
        if battery_path.exists():
            self._battery_data = pd.read_parquet(battery_path)
            logger.info(f"Loaded battery data: {len(self._battery_data)} rows")
        else:
            logger.warning(f"Battery data file not found: {battery_path}")

    def get_latest_prices(self) -> Dict[str, Dict]:
        """Return each region's last stored row as a small dict.

        NOTE(review): "latest" is the last row as stored — this assumes the
        parquet files are already chronologically ordered; confirm upstream.
        """
        result = {}
        for region, df in self._price_data.items():
            if len(df) > 0:
                latest = df.iloc[-1].to_dict()
                result[region] = {
                    "timestamp": latest.get("timestamp"),
                    "day_ahead_price": latest.get("day_ahead_price", 0),
                    "real_time_price": latest.get("real_time_price", 0),
                    "volume_mw": latest.get("volume_mw", 0),
                }
        return result

    def get_price_history(
        self, region: str, start: Optional[str] = None, end: Optional[str] = None, limit: int = 1000
    ) -> List[Dict]:
        """Return up to ``limit`` most recent rows for ``region`` as dicts.

        ``start``/``end`` are compared directly against the timestamp column,
        so they must be of a type comparable with it (unknown region -> []).
        """
        if region not in self._price_data:
            return []
        df = self._price_data[region].copy()
        if "timestamp" in df.columns:
            df = df.sort_values("timestamp")
            if start:
                df = df[df["timestamp"] >= start]
            if end:
                df = df[df["timestamp"] <= end]
        df = df.tail(limit)
        return df.to_dict("records")

    def get_battery_states(self) -> List[Dict]:
        """Return the last stored row per battery_id as a list of dicts."""
        if self._battery_data is None or len(self._battery_data) == 0:
            return []
        # groupby().last() picks the final stored row for each battery.
        latest_by_battery = self._battery_data.groupby("battery_id").last().reset_index()
        result = []
        for _, row in latest_by_battery.iterrows():
            result.append(
                {
                    "timestamp": row.get("timestamp"),
                    "battery_id": row.get("battery_id"),
                    "capacity_mwh": row.get("capacity_mwh", 0),
                    "charge_level_mwh": row.get("charge_level_mwh", 0),
                    "charge_rate_mw": row.get("charge_rate_mw", 0),
                    "discharge_rate_mw": row.get("discharge_rate_mw", 0),
                    "efficiency": row.get("efficiency", 0.9),
                }
            )
        return result

    def get_arbitrage_opportunities(self, min_spread: Optional[float] = None) -> List[Dict]:
        """Find region pairs whose latest real-time prices differ by >= min_spread.

        Compares every unordered pair of regions once; rows with a
        non-positive price on either side are skipped. ``volume_mw`` is a
        hard-coded placeholder (100).
        """
        if min_spread is None:
            min_spread = settings.ARBITRAGE_MIN_SPREAD
        opportunities = []
        latest_prices = self.get_latest_prices()
        regions = list(latest_prices.keys())
        for i in range(len(regions)):
            for j in range(i + 1, len(regions)):
                region_a = regions[i]
                region_b = regions[j]
                price_a = latest_prices[region_a].get("real_time_price", 0)
                price_b = latest_prices[region_b].get("real_time_price", 0)
                if price_a > 0 and price_b > 0:
                    spread = abs(price_a - price_b)
                    if spread >= min_spread:
                        # Buy in the cheaper region, sell in the dearer one.
                        if price_a < price_b:
                            buy_region, sell_region = region_a, region_b
                            buy_price, sell_price = price_a, price_b
                        else:
                            buy_region, sell_region = region_b, region_a
                            buy_price, sell_price = price_b, price_a
                        opportunities.append(
                            {
                                "timestamp": datetime.utcnow(),
                                "buy_region": buy_region,
                                "sell_region": sell_region,
                                "buy_price": buy_price,
                                "sell_price": sell_price,
                                "spread": spread,
                                "volume_mw": 100,
                            }
                        )
        return opportunities

    def get_dashboard_summary(self) -> Dict:
        """Assemble the headline numbers for the dashboard endpoint.

        NOTE(review): the average battery charge divides by
        ``capacity_mwh`` (default 1 if the key is missing); a battery row
        with a stored capacity of 0 would still raise ZeroDivisionError.
        """
        latest_prices = self.get_latest_prices()
        total_volume = sum(p.get("volume_mw", 0) for p in latest_prices.values())
        avg_price = (
            sum(p.get("real_time_price", 0) for p in latest_prices.values()) / len(latest_prices)
            if latest_prices
            else 0
        )
        arbitrage = self.get_arbitrage_opportunities()
        battery_states = self.get_battery_states()
        avg_battery_charge = 0
        if battery_states:
            avg_battery_charge = sum(
                b.get("charge_level_mwh", 0) / b.get("capacity_mwh", 1) for b in battery_states
            ) / len(battery_states)
        return {
            "latest_timestamp": datetime.utcnow(),
            "total_volume_mw": total_volume,
            "avg_realtime_price": avg_price,
            "arbitrage_count": len(arbitrage),
            "battery_count": len(battery_states),
            "avg_battery_charge": avg_battery_charge,
        }

View File

@@ -0,0 +1,145 @@
from typing import Dict, List, Optional, Any
from datetime import datetime
from pathlib import Path
import pickle
from app.config import settings
from app.models.enums import ModelType
from app.models.schemas import ModelInfo, PredictionResponse
from app.utils.logger import get_logger
logger = get_logger(__name__)
class MLService:
    """Loads persisted model artifacts and serves predictions.

    Models are described by ``registry.json`` and lazily un-pickled on first
    use, then cached in ``self._loaded_models``.

    SECURITY: artifacts are loaded with ``pickle.load``, which can execute
    arbitrary code — only load files from trusted paths.
    """

    def __init__(self):
        self.models_path: Path = Path(settings.MODELS_PATH)
        # model_id -> deserialized model/policy object (lazy cache).
        self._loaded_models: Dict[str, Any] = {}
        # model_id -> registry metadata.
        self._registry: Dict[str, ModelInfo] = {}
        self._load_registry()

    def _load_registry(self):
        """Parse ``registry.json`` into ModelInfo records (no-op if absent)."""
        registry_path = self.models_path / "registry.json"
        if registry_path.exists():
            import json
            with open(registry_path) as f:
                data = json.load(f)
            for model_id, model_data in data.get("models", {}).items():
                self._registry[model_id] = ModelInfo(**model_data)
            logger.info(f"Loaded model registry: {len(self._registry)} models")

    def list_models(self) -> List[ModelInfo]:
        """Return metadata for every registered model."""
        return list(self._registry.values())

    def get_model_metrics(self, model_id: str) -> Dict[str, float]:
        """Return the registry metrics for ``model_id``.

        Raises:
            ValueError: if the model is not in the registry.
        """
        if model_id not in self._registry:
            raise ValueError(f"Model {model_id} not found in registry")
        return self._registry[model_id].metrics

    def load_price_prediction_model(self, model_id: str):
        """Un-pickle (and cache) a price-prediction model artifact.

        Raises:
            FileNotFoundError: if no ``<model_id>.pkl`` file exists.
        """
        if model_id in self._loaded_models:
            return self._loaded_models[model_id]
        model_path = self.models_path / "price_prediction" / f"{model_id}.pkl"
        if not model_path.exists():
            raise FileNotFoundError(f"Model file not found: {model_path}")
        with open(model_path, "rb") as f:
            model = pickle.load(f)
        self._loaded_models[model_id] = model
        logger.info(f"Loaded price prediction model: {model_id}")
        return model

    def load_rl_battery_policy(self, model_id: str):
        """Un-pickle (and cache) an RL battery policy artifact.

        Raises:
            FileNotFoundError: if no ``<model_id>.pkl`` file exists.
        """
        if model_id in self._loaded_models:
            return self._loaded_models[model_id]
        policy_path = self.models_path / "rl_battery" / f"{model_id}.pkl"
        if not policy_path.exists():
            raise FileNotFoundError(f"Policy file not found: {policy_path}")
        with open(policy_path, "rb") as f:
            policy = pickle.load(f)
        self._loaded_models[model_id] = policy
        logger.info(f"Loaded RL battery policy: {model_id}")
        return policy

    def predict(
        self, model_id: str, timestamp: datetime, features: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Dispatch a prediction to the right model type for ``model_id``.

        Raises:
            ValueError: for unknown models or unsupported model types.
        """
        model_info = self._registry.get(model_id)
        if not model_info:
            raise ValueError(f"Model {model_id} not found")
        if model_info.model_type == ModelType.PRICE_PREDICTION:
            model = self.load_price_prediction_model(model_id)
            prediction = self._predict_price(model, timestamp, features or {})
            return prediction
        elif model_info.model_type == ModelType.RL_BATTERY:
            policy = self.load_rl_battery_policy(model_id)
            action = self._get_battery_action(policy, timestamp, features or {})
            return action
        else:
            raise ValueError(f"Unsupported model type: {model_info.model_type}")

    def _predict_price(
        self, model: Any, timestamp: datetime, features: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Run a point prediction; errors are logged and re-raised."""
        import numpy as np
        try:
            feature_vector = self._extract_features(features)
            prediction = float(model.predict(feature_vector)[0])
            return {
                # Falls back to "unknown" if the artifact carries no model_id.
                "model_id": getattr(model, "model_id", "unknown"),
                "timestamp": timestamp,
                "prediction": prediction,
                "confidence": None,
                "features_used": list(features.keys()),
            }
        except Exception as e:
            logger.error(f"Prediction error: {e}")
            raise

    def _extract_features(self, features: Dict[str, Any]) -> Any:
        """Build a (1, n) array from feature values in sorted-key order.

        NOTE(review): column order is alphabetical by feature name — this
        must match the order the model was trained with; confirm.
        """
        import numpy as np
        return np.array([[features.get(k, 0) for k in sorted(features.keys())]])

    def _get_battery_action(self, policy: Any, timestamp: datetime, features: Dict[str, Any]) -> Dict[str, Any]:
        """Return a battery action recommendation.

        NOTE(review): the loaded ``policy`` object is only used for its
        ``policy_id``; the action itself comes from a hard-coded
        charge-level/price threshold heuristic, not from the policy.
        """
        charge_level = features.get("charge_level", 0.5)
        current_price = features.get("current_price", 0)
        action = "hold"
        if charge_level < 0.2 and current_price < 50:
            action = "charge"
        elif charge_level > 0.8 and current_price > 100:
            action = "discharge"
        return {
            "model_id": getattr(policy, "policy_id", "battery_policy"),
            "timestamp": timestamp,
            "action": action,
            "charge_level": charge_level,
            "confidence": 0.7,
        }

    def predict_with_confidence(
        self, model_id: str, timestamp: datetime, features: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Same as ``predict`` but with a fixed placeholder confidence of 0.85."""
        result = self.predict(model_id, timestamp, features)
        result["confidence"] = 0.85
        return result

    def get_feature_importance(self, model_id: str) -> Dict[str, float]:
        """Return positional feature importances for a price model, else {}.

        Feature names are synthesized as ``feature_<i>`` because the artifact
        is not guaranteed to carry real column names here.
        """
        if model_id in self._registry and self._registry[model_id].model_type == ModelType.PRICE_PREDICTION:
            model = self.load_price_prediction_model(model_id)
            if hasattr(model, "feature_importances_"):
                importances = model.feature_importances_
                return {f"feature_{i}": float(imp) for i, imp in enumerate(importances)}
        return {}

    def get_model_info(self, model_id: str) -> Optional[ModelInfo]:
        """Return registry metadata for ``model_id`` (None if unknown)."""
        return self._registry.get(model_id)

View File

@@ -0,0 +1,83 @@
from typing import Dict, List, Optional
from datetime import datetime
from app.models.enums import StrategyEnum
from app.models.schemas import StrategyStatus
from app.utils.logger import get_logger
logger = get_logger(__name__)
class StrategyService:
    """In-memory registry and executor for the trading strategies.

    Tracks per-strategy enablement, execution stats and P&L. The individual
    strategy runners below are stubs that currently report zero activity.
    """

    def __init__(self):
        # Map each StrategyEnum member to its mutable status record.
        self._strategies: Dict[StrategyEnum, StrategyStatus] = {}
        self._initialize_strategies()

    def _initialize_strategies(self):
        """Create a disabled, zeroed status record for every known strategy."""
        for strategy in StrategyEnum:
            self._strategies[strategy] = StrategyStatus(
                strategy=strategy, enabled=False, last_execution=None, total_trades=0, profit_loss=0.0
            )

    async def execute_strategy(self, strategy: StrategyEnum, config: Optional[Dict] = None) -> Dict:
        """Execute one enabled strategy and fold its results into its status.

        Raises:
            ValueError: if the strategy is unknown or currently disabled.
        """
        logger.info(f"Executing strategy: {strategy.value}")
        status = self._strategies.get(strategy)
        if not status or not status.enabled:
            raise ValueError(f"Strategy {strategy.value} is not enabled")
        results = await self._run_strategy_logic(strategy, config or {})
        status.last_execution = datetime.utcnow()
        status.total_trades += results.get("trades", 0)
        status.profit_loss += results.get("profit", 0)
        # Pydantic v2 (this project pins pydantic>=2.4): .dict() is a
        # deprecated v1 alias slated for removal -- use .model_dump().
        return {"strategy": strategy.value, "status": status.model_dump(), "results": results}

    async def _run_strategy_logic(self, strategy: StrategyEnum, config: Dict) -> Dict:
        """Dispatch to the per-strategy runner; unknown members are no-ops."""
        if strategy == StrategyEnum.FUNDAMENTAL:
            return await self._run_fundamental_strategy(config)
        elif strategy == StrategyEnum.TECHNICAL:
            return await self._run_technical_strategy(config)
        elif strategy == StrategyEnum.ML:
            return await self._run_ml_strategy(config)
        elif strategy == StrategyEnum.MINING:
            return await self._run_mining_strategy(config)
        return {"trades": 0, "profit": 0}

    async def _run_fundamental_strategy(self, config: Dict) -> Dict:
        """Stub fundamental-analysis runner; reports no activity."""
        logger.debug("Running fundamental strategy")
        return {"trades": 0, "profit": 0}

    async def _run_technical_strategy(self, config: Dict) -> Dict:
        """Stub technical-analysis runner; reports no activity."""
        logger.debug("Running technical strategy")
        return {"trades": 0, "profit": 0}

    async def _run_ml_strategy(self, config: Dict) -> Dict:
        """Stub ML-driven runner; reports no activity."""
        logger.debug("Running ML strategy")
        return {"trades": 0, "profit": 0}

    async def _run_mining_strategy(self, config: Dict) -> Dict:
        """Stub mining-margin runner; reports no activity."""
        logger.debug("Running mining strategy")
        return {"trades": 0, "profit": 0}

    async def get_strategy_status(self, strategy: StrategyEnum) -> StrategyStatus:
        """Return the tracked status, or a fresh disabled record if unknown."""
        return self._strategies.get(strategy, StrategyStatus(strategy=strategy, enabled=False))

    async def get_all_strategies(self) -> List[StrategyStatus]:
        """Return the status records for every known strategy."""
        return list(self._strategies.values())

    async def toggle_strategy(self, strategy: StrategyEnum, action: str) -> StrategyStatus:
        """Enable ('start') or disable ('stop') a strategy.

        Raises:
            ValueError: for an unknown strategy or an unsupported action.
        """
        status = self._strategies.get(strategy)
        if not status:
            raise ValueError(f"Unknown strategy: {strategy.value}")
        if action == "start":
            status.enabled = True
            logger.info(f"Strategy {strategy.value} started")
        elif action == "stop":
            status.enabled = False
            logger.info(f"Strategy {strategy.value} stopped")
        else:
            raise ValueError(f"Invalid action: {action}. Use 'start' or 'stop'")
        return status

View File

@@ -0,0 +1,61 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Optional

from app.utils.logger import get_logger
logger = get_logger(__name__)
@dataclass
class TradingPosition:
    """Snapshot of a single trading position.

    The original class only declared annotations without ``@dataclass`` (or an
    ``__init__``), so the attributes were never actually stored on instances;
    the decorator turns it into a real value object with the same field set.
    """

    timestamp: datetime
    position_type: str  # presumably "long"/"short" -- TODO confirm against callers
    region: Optional[str]
    volume_mw: float
    entry_price: float
    current_price: float
    pnl: float
class TradingService:
    """In-memory book-keeping for trading positions and orders.

    No real exchange connectivity: orders fill instantly and state lives only
    for the lifetime of the process.
    """

    def __init__(self):
        self._positions: List[Dict] = []
        self._orders: List[Dict] = []

    async def get_positions(self) -> List[Dict]:
        """Return a shallow snapshot of all tracked positions."""
        return list(self._positions)

    async def get_orders(self, limit: int = 100) -> List[Dict]:
        """Return up to *limit* of the most recent orders."""
        return self._orders[-limit:]

    async def place_order(self, order: Dict) -> Dict:
        """Record an order in place, stamping id/timestamp and filling it.

        Mutates and returns the caller's dict.
        """
        order_id = f"order_{len(self._orders) + 1}"
        order["order_id"] = order_id
        order["timestamp"] = datetime.utcnow()
        order["status"] = "filled"
        self._orders.append(order)
        logger.info(f"Order placed: {order_id}, type: {order.get('type')}, volume: {order.get('volume_mw')}")
        return order

    async def close_position(self, position_id: str) -> Dict:
        """Remove the matching position, mark it closed, and return it.

        Raises:
            ValueError: when no tracked position carries *position_id*.
        """
        for index, candidate in enumerate(self._positions):
            if candidate.get("position_id") != position_id:
                continue
            closed = self._positions.pop(index)
            closed["closed_at"] = datetime.utcnow()
            closed["status"] = "closed"
            logger.info(f"Position closed: {position_id}")
            return closed
        raise ValueError(f"Position not found: {position_id}")

    async def get_trading_summary(self) -> Dict:
        """Aggregate P&L and trade counts across the current book."""
        pnl_total = sum(entry.get("pnl", 0) for entry in self._positions)
        open_count = sum(1 for entry in self._positions if entry.get("status") == "open")
        return {
            "total_pnl": pnl_total,
            "open_positions": open_count,
            "total_trades": len(self._orders),
            "last_trade": self._orders[-1]["timestamp"] if self._orders else None,
        }

View File

@@ -0,0 +1,4 @@
from app.tasks.backtest_tasks import run_backtest_task
from app.tasks.training_tasks import train_model_task
__all__ = ["run_backtest_task", "train_model_task"]

View File

@@ -0,0 +1,40 @@
from typing import Dict
from datetime import datetime
from app.utils.logger import get_logger
logger = get_logger(__name__)
async def run_backtest_task(backtest_id: str, config: Dict, name: str | None = None):
    """Run a backtest and return its results payload.

    Args:
        backtest_id: Identifier callers use to correlate/poll results.
        config: Backtest configuration (unused by the current stub).
        name: Optional human-readable label (unused by the current stub).
            Annotated ``str | None`` -- the original ``name: str = None``
            was an implicit-Optional mismatch.

    Returns:
        A dict with status, placeholder metrics, trades and completion time.
    """
    logger.info(f"Running backtest task: {backtest_id}")
    # The original wrapped this literal construction in try/except, but a
    # dict literal cannot raise, so the handler was dead code and is removed.
    # NOTE(review): metrics below are hard-coded placeholders pending a real
    # backtest engine.
    results = {
        "backtest_id": backtest_id,
        "status": "completed",
        "metrics": {
            "total_revenue": 10000.0,
            "arbitrage_profit": 5000.0,
            "battery_revenue": 3000.0,
            "mining_profit": 2000.0,
            "battery_utilization": 0.75,
            "price_capture_rate": 0.85,
            "win_rate": 0.65,
            "sharpe_ratio": 1.5,
            "max_drawdown": -500.0,
            "total_trades": 150,
        },
        "trades": [],
        "completed_at": datetime.utcnow().isoformat(),
    }
    logger.info(f"Backtest {backtest_id} completed")
    return results
__all__ = ["run_backtest_task"]

View File

@@ -0,0 +1,9 @@
from app.utils.logger import get_logger
logger = get_logger(__name__)
async def monitoring_task():
    """Periodic monitoring hook; currently only logs that it ran.

    TODO(review): flesh out with real health/metric checks.
    """
    logger.debug("Running monitoring task")
__all__ = ["monitoring_task"]

View File

@@ -0,0 +1,50 @@
from typing import Dict
from datetime import datetime
from app.utils.logger import get_logger
from app.models.schemas import TrainingRequest, TrainingStatusEnum
import uuid
logger = get_logger(__name__)
async def train_model_task(training_id: str, request: TrainingRequest):
    """Train the model selected by *request* and return a completion payload.

    Dispatches on ``request.model_type.value``: "price_prediction" trains and
    saves the XGBoost-style price models, "rl_battery" trains and saves the
    battery RL policy. The previously duplicated result-dict construction in
    each branch is now built once after dispatch.

    Raises:
        ValueError: for an unrecognized model type (logged, then re-raised).
    """
    logger.info(f"Training model: {request.model_type.value}, horizon: {request.horizon}")
    try:
        model_type = request.model_type.value
        if model_type == "price_prediction":
            # Imported lazily so the heavy ML stack loads only when needed.
            from app.ml.price_prediction.trainer import PricePredictionTrainer
            trainer = PricePredictionTrainer()
            results = trainer.train_all(horizons=[request.horizon] if request.horizon else None)
            trainer.save_models()
        elif model_type == "rl_battery":
            from app.ml.rl_battery.trainer import BatteryRLTrainer
            trainer = BatteryRLTrainer()
            results = trainer.train(n_episodes=500)
            trainer.save()
        else:
            raise ValueError(f"Unknown model type: {request.model_type}")
        return {
            "training_id": training_id,
            "status": TrainingStatusEnum.COMPLETED,
            "results": results,
            "completed_at": datetime.utcnow().isoformat(),
        }
    except Exception as e:
        logger.error(f"Training failed: {e}")
        raise
__all__ = ["train_model_task"]

View File

@@ -0,0 +1,5 @@
from app.utils.logger import get_logger
logger = get_logger(__name__)
__all__ = ["logger"]

View File

@@ -0,0 +1,36 @@
from datetime import datetime, timedelta
from typing import Optional
import pytz
def utcnow() -> datetime:
    """Return the current moment as a timezone-aware UTC datetime."""
    # The stdlib UTC tzinfo is sufficient here; pytz adds nothing for plain
    # UTC and its localize-based API is deprecated in favour of zoneinfo.
    from datetime import timezone
    return datetime.now(timezone.utc)
def parse_date(date_str: str) -> datetime:
    """Parse *date_str* against the known formats, first match wins.

    Accepts plain dates plus 'T'- or space-separated datetimes.

    Raises:
        ValueError: when no known format matches.
    """
    known_formats = ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S")
    for fmt in known_formats:
        try:
            parsed = datetime.strptime(date_str, fmt)
        except ValueError:
            pass
        else:
            return parsed
    raise ValueError(f"Unable to parse date: {date_str}")
def format_timestamp(dt: datetime, format_str: str = "%Y-%m-%dT%H:%M:%S") -> str:
    """Render *dt* as a string; defaults to ISO-like 'YYYY-MM-DDTHH:MM:SS'."""
    # format() delegates to datetime.__format__, which applies strftime codes.
    return format(dt, format_str)
def calculate_time_delta(start: datetime, end: datetime) -> timedelta:
    """Return the signed duration from *start* to *end* (negative if end < start)."""
    elapsed = end - start
    return elapsed
def safe_divide(a: float, b: float, default: float = 0.0) -> float:
    """Divide a by b, returning *default* instead of raising when b == 0."""
    return default if b == 0 else a / b

View File

@@ -0,0 +1,36 @@
import os
from loguru import logger
import sys
def get_logger(name: str):
    """Return the shared loguru logger bound with *name* as extra context."""
    bound_logger = logger.bind(name=name)
    return bound_logger
def setup_logger():
    """Configure loguru: a colorized stdout sink plus a rotating file sink.

    Drops any pre-existing sinks first so repeated calls don't duplicate
    output. Console level comes from LOG_LEVEL (default INFO); the file sink
    always captures DEBUG, rotating at 10 MB with 7-day zipped retention.
    """
    logger.remove()
    console_format = (
        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
        "<level>{message}</level>"
    )
    file_format = (
        "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | "
        "{name}:{function}:{line} | {message}"
    )
    logger.add(
        sys.stdout,
        format=console_format,
        level=os.getenv("LOG_LEVEL", "INFO"),
        colorize=True,
    )
    logger.add(
        "logs/app.log",
        format=file_format,
        level="DEBUG",
        rotation="10 MB",
        retention="7 days",
        compression="zip",
    )

57
backend/pyproject.toml Normal file
View File

@@ -0,0 +1,57 @@
[build-system]
requires = ["setuptools>=68.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "energy-trading-backend"
version = "1.0.0"
description = "FastAPI backend for energy trading system"
requires-python = ">=3.10"
dependencies = [
"fastapi>=0.104.0",
"uvicorn[standard]>=0.24.0",
"pydantic>=2.4.0",
"pydantic-settings>=2.0.0",
"pandas>=2.1.0",
"numpy>=1.24.0",
"pyarrow>=14.0.0",
"xgboost>=2.0.0",
"scikit-learn>=1.3.0",
"gymnasium>=0.29.0",
"stable-baselines3>=2.0.0",
"celery>=5.3.0",
"redis>=5.0.0",
"websockets>=12.0.0",
"sqlalchemy>=2.0.0",
"alembic>=1.12.0",
"python-multipart>=0.0.6",
"python-jose[cryptography]>=3.3.0",
"python-dotenv>=1.0.0",
"loguru>=0.7.0",
]
[project.optional-dependencies]
dev = [
"pytest>=7.4.0",
"pytest-asyncio>=0.21.0",
"httpx>=0.25.0",
"black>=23.0.0",
"ruff>=0.1.0",
]
[project.scripts]
energy-train = "app.ml.training.cli:main"
[tool.black]
line-length = 100
target-version = ['py310']
[tool.ruff]
line-length = 100
select = ["E", "F", "I", "N", "W"]
ignore = []
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
python_files = ["test_*.py"]

23
backend/requirements.txt Normal file
View File

@@ -0,0 +1,23 @@
fastapi>=0.104.0
uvicorn[standard]>=0.24.0
pydantic>=2.4.0
pydantic-settings>=2.0.0
pandas>=2.1.0
numpy>=1.24.0
pyarrow>=14.0.0
xgboost>=2.0.0
scikit-learn>=1.3.0
gymnasium>=0.29.0
stable-baselines3>=2.0.0
celery>=5.3.0
redis>=5.0.0
websockets>=12.0.0
sqlalchemy>=2.0.0
alembic>=1.12.0
python-multipart>=0.0.6
python-jose[cryptography]>=3.3.0
python-dotenv>=1.0.0
loguru>=0.7.0
pytest>=7.4.0
pytest-asyncio>=0.21.0
httpx>=0.25.0

View File

32
backend/tests/conftest.py Normal file
View File

@@ -0,0 +1,32 @@
import pytest
from fastapi.testclient import TestClient
from app.main import app
@pytest.fixture
def client():
    """FastAPI test client bound to the application instance."""
    test_client = TestClient(app)
    return test_client
@pytest.fixture
def sample_price_data():
    """One sample market-price row for the FR region."""
    payload = dict(
        timestamp="2024-01-01T00:00:00",
        region="FR",
        day_ahead_price=50.0,
        real_time_price=55.0,
        volume_mw=1000.0,
    )
    return payload
@pytest.fixture
def sample_battery_state():
    """One sample battery state snapshot at 50% charge."""
    payload = dict(
        timestamp="2024-01-01T00:00:00",
        battery_id="battery_1",
        capacity_mwh=100.0,
        charge_level_mwh=50.0,
        charge_rate_mw=50.0,
        discharge_rate_mw=50.0,
        efficiency=0.9,
    )
    return payload

View File

@@ -0,0 +1,3 @@
from tests.conftest import sample_price_data, sample_battery_state
__all__ = ["sample_price_data", "sample_battery_state"]

View File

@@ -0,0 +1,13 @@
import pytest
def test_backtest_start():
    """Placeholder -- TODO(review): implement backtest start coverage."""
    pass
def test_backtest_status():
    """Placeholder -- TODO(review): implement backtest status coverage."""
    pass
def test_backtest_history():
    """Placeholder -- TODO(review): implement backtest history coverage."""
    pass

View File

@@ -0,0 +1,46 @@
import pytest
from fastapi.testclient import TestClient
def test_health_check(client: TestClient):
    """Health endpoint responds 200 with status 'healthy'."""
    response = client.get("/health")
    assert response.status_code == 200
    assert response.json()["status"] == "healthy"
def test_dashboard_summary(client: TestClient):
    """Dashboard summary endpoint is reachable."""
    response = client.get("/api/v1/dashboard/summary")
    assert response.status_code == 200
def test_latest_prices(client: TestClient):
    """Latest-prices payload contains a 'regions' key."""
    response = client.get("/api/v1/dashboard/prices")
    assert response.status_code == 200
    assert "regions" in response.json()
def test_battery_states(client: TestClient):
    """Battery-state payload contains a 'batteries' key."""
    response = client.get("/api/v1/dashboard/battery")
    assert response.status_code == 200
    assert "batteries" in response.json()
def test_arbitrage_opportunities(client: TestClient):
    """Arbitrage payload contains an 'opportunities' key."""
    response = client.get("/api/v1/dashboard/arbitrage")
    assert response.status_code == 200
    assert "opportunities" in response.json()
def test_get_settings(client: TestClient):
    """Settings endpoint is reachable."""
    response = client.get("/api/v1/settings")
    assert response.status_code == 200
def test_list_models(client: TestClient):
    """Model listing endpoint is reachable."""
    response = client.get("/api/v1/models")
    assert response.status_code == 200
def test_get_strategies(client: TestClient):
    """Trading strategies endpoint is reachable."""
    response = client.get("/api/v1/trading/strategies")
    assert response.status_code == 200

View File

@@ -0,0 +1,9 @@
import pytest
def test_model_prediction():
    """Placeholder -- TODO(review): implement model prediction coverage."""
    pass
def test_model_training():
    """Placeholder -- TODO(review): implement model training coverage."""
    pass

View File

@@ -0,0 +1,9 @@
import pytest
def test_strategy_toggle():
    """Placeholder -- TODO(review): implement strategy toggle coverage."""
    pass
def test_trading_positions():
    """Placeholder -- TODO(review): implement trading positions coverage."""
    pass

View File

@@ -0,0 +1,10 @@
def test_data_service():
    """Placeholder -- TODO(review): implement data service coverage."""
    pass
def test_ml_service():
    """Placeholder -- TODO(review): implement ML service coverage."""
    pass
def test_strategy_service():
    """Placeholder -- TODO(review): implement strategy service coverage."""
    pass

View File

@@ -0,0 +1,2 @@
def test_websocket():
    """Placeholder -- TODO(review): implement websocket coverage."""
    pass