Add FastAPI backend for energy trading system
Implements FastAPI backend with ML model support for energy trading, including price prediction models and RL-based battery trading policy. Features dashboard, trading, backtest, and settings API routes with WebSocket support for real-time updates.
This commit is contained in:
24
backend/.env.example
Normal file
24
backend/.env.example
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
# --- Application identity ---
APP_NAME=Energy Trading API
APP_VERSION=1.0.0
DEBUG=true

# --- HTTP server bind ---
HOST=0.0.0.0
PORT=8000

# --- Location of processed market data ---
DATA_PATH=~/energy-test-data/data/processed

# --- Allowed browser origins (comma-separated) ---
CORS_ORIGINS=http://localhost:3000,http://localhost:5173

# --- WebSocket heartbeat interval, seconds ---
WS_HEARTBEAT_INTERVAL=30

# --- Celery / Redis ---
CELERY_BROKER_URL=redis://localhost:6379/0
CELERY_RESULT_BACKEND=redis://localhost:6379/0

# --- Artifact directories ---
MODELS_PATH=models
RESULTS_PATH=results

# --- Battery charge limits ---
BATTERY_MIN_RESERVE=0.10
BATTERY_MAX_CHARGE=0.90

# --- Strategy thresholds (arbitrage spread is in EUR/MWh) ---
ARBITRAGE_MIN_SPREAD=5.0
MINING_MARGIN_THRESHOLD=5.0
|
||||||
46
backend/.gitignore
vendored
Normal file
46
backend/.gitignore
vendored
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
# Local environment & virtualenvs
.env
.venv
venv/

# Python bytecode / compiled extensions
__pycache__/
*.py[cod]
*$py.class
*.so
.Python

# Build & packaging artifacts
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Test / coverage / type-check caches
.pytest_cache/
.coverage
htmlcov/
*.cover
.hypothesis/
.tox/
.mypy_cache/
.dmypy.json
dmypy.json

# Logs & runtime files
*.log
logs/
server.log
*.pid

# Editor / OS clutter
.vscode/
.idea/
*.swp
*.swo
*~
.DS_Store
Thumbs.db

# Backups & temp files
*.bak
*.tmp
*.orig
|
||||||
12
backend/Dockerfile
Normal file
12
backend/Dockerfile
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
# Slim Python base image for the FastAPI backend.
FROM python:3.10-slim

WORKDIR /app

# Install dependencies first so Docker layer caching survives code edits.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application source.
COPY . .

EXPOSE 8000

# Serve the app with uvicorn on all interfaces.
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
54
backend/README.md
Normal file
54
backend/README.md
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
# Energy Trading Backend
|
||||||
|
|
||||||
|
FastAPI backend for the energy trading system with ML model support.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd backend
|
||||||
|
pip install -r requirements.txt
|
||||||
|
cp .env.example .env
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||||
|
```
|
||||||
|
|
||||||
|
## API Documentation
|
||||||
|
|
||||||
|
- Swagger UI: http://localhost:8000/docs
|
||||||
|
- ReDoc: http://localhost:8000/redoc
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
backend/
|
||||||
|
├── app/
|
||||||
|
│ ├── api/ # API routes and WebSocket
|
||||||
|
│ ├── services/ # Business logic services
|
||||||
|
│ ├── tasks/ # Background tasks
|
||||||
|
│ ├── ml/ # ML models and training
|
||||||
|
│ ├── models/ # Pydantic models
|
||||||
|
│ └── utils/ # Utilities
|
||||||
|
├── models/ # Trained models
|
||||||
|
├── results/ # Backtest results
|
||||||
|
└── tests/ # Tests
|
||||||
|
```
|
||||||
|
|
||||||
|
## Training ML Models
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Train price prediction models
|
||||||
|
python -m app.ml.training.cli price --horizons 1 5 15 60
|
||||||
|
|
||||||
|
# Train RL battery policy
|
||||||
|
python -m app.ml.training.cli rl --episodes 1000
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running Tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pytest
|
||||||
|
```
|
||||||
7
backend/app/__init__.py
Normal file
7
backend/app/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from app.config import settings
from app.utils.logger import get_logger

# Package-level logger shared by modules that import from `app` directly.
logger = get_logger(__name__)

# Fix: the unused `FastAPI` import was dropped, and "app" was removed from
# __all__ — no `app` object is defined in this module, so `from app import *`
# would have raised AttributeError on it.
__all__ = ["settings", "logger"]
|
||||||
3
backend/app/api/__init__.py
Normal file
3
backend/app/api/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# Re-export the route modules so callers can write `from app.api import dashboard`.
from app.api.routes import dashboard, backtest, models, trading, settings

__all__ = ["dashboard", "backtest", "models", "trading", "settings"]
|
||||||
3
backend/app/api/routes/__init__.py
Normal file
3
backend/app/api/routes/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# Public surface of the routes package: one module per API area.
from app.api.routes import dashboard, backtest, models, trading, settings

__all__ = ["dashboard", "backtest", "models", "trading", "settings"]
|
||||||
108
backend/app/api/routes/backtest.py
Normal file
108
backend/app/api/routes/backtest.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||||
|
from datetime import datetime
|
||||||
|
from app.models.schemas import (
|
||||||
|
BacktestConfig,
|
||||||
|
BacktestMetrics,
|
||||||
|
BacktestStatusEnum,
|
||||||
|
Trade,
|
||||||
|
)
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
_backtest_store: dict = {}
|
||||||
|
_backtest_results: dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/start")
async def start_backtest(config: BacktestConfig, name: Optional[str] = None):
    """Register a new backtest run and return its identifier.

    NOTE(review): the record is created in RUNNING state, but nothing here
    schedules the actual backtest work — confirm a background worker picks
    these up.
    """
    backtest_id = str(uuid.uuid4())
    now = datetime.utcnow()

    _backtest_store[backtest_id] = {
        "backtest_id": backtest_id,
        "name": name or f"Backtest {backtest_id[:8]}",
        "status": BacktestStatusEnum.RUNNING,
        # model_dump() is the pydantic v2 API; .dict() is deprecated there
        # (the project already depends on pydantic_settings, i.e. v2).
        "config": config.model_dump(),
        "created_at": now,
        "started_at": now,
        "completed_at": None,
        "error_message": None,
    }

    logger.info(f"Backtest started: {backtest_id}")

    return {
        "backtest_id": backtest_id,
        "status": BacktestStatusEnum.RUNNING,
        "name": _backtest_store[backtest_id]["name"],
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{backtest_id}")
async def get_backtest_status(backtest_id: str):
    """Return the stored state of one backtest, plus its results if finished."""
    record = _backtest_store.get(backtest_id)
    if record is None:
        raise HTTPException(status_code=404, detail=f"Backtest {backtest_id} not found")

    completed = record["status"] == BacktestStatusEnum.COMPLETED
    payload = {
        key: record[key]
        for key in ("status", "name", "created_at", "started_at", "completed_at", "error_message")
    }
    payload["results"] = _backtest_results.get(backtest_id) if completed else None
    return payload
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{backtest_id}/results")
async def get_backtest_results(backtest_id: str):
    """Return the stored results for a backtest, or 404 if none exist."""
    try:
        return _backtest_results[backtest_id]
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Results for backtest {backtest_id} not found")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{backtest_id}/trades")
async def get_backtest_trades(backtest_id: str, limit: int = 100):
    """Return up to ``limit`` of the most recent trades for a backtest."""
    if backtest_id not in _backtest_store:
        raise HTTPException(status_code=404, detail=f"Backtest {backtest_id} not found")

    trades = _backtest_results.get(backtest_id, {}).get("trades", [])
    # Fix: trades[-limit:] with limit == 0 would return the WHOLE list
    # (slicing with -0 is slicing from 0), so non-positive limits are
    # handled explicitly.
    recent = trades[-limit:] if limit > 0 else []
    return {"backtest_id": backtest_id, "trades": recent, "total": len(trades)}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/history")
async def get_backtest_history():
    """List summary records for every stored backtest."""
    # NOTE(review): this route is registered AFTER the dynamic
    # "/{backtest_id}" route above, and FastAPI matches path operations in
    # registration order — so GET /history is captured by
    # get_backtest_status with backtest_id == "history" and this handler is
    # unreachable. Register it before the dynamic route to fix.
    backtests = []
    for backtest_id, backtest in _backtest_store.items():
        backtests.append(
            {
                "backtest_id": backtest_id,
                "name": backtest["name"],
                "status": backtest["status"],
                "created_at": backtest["created_at"],
                "completed_at": backtest["completed_at"],
            }
        )

    return {"backtests": backtests, "total": len(backtests)}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{backtest_id}")
async def delete_backtest(backtest_id: str):
    """Remove a backtest record (and its results, if any) from the stores."""
    if _backtest_store.pop(backtest_id, None) is None:
        raise HTTPException(status_code=404, detail=f"Backtest {backtest_id} not found")

    _backtest_results.pop(backtest_id, None)

    logger.info(f"Backtest deleted: {backtest_id}")
    return {"message": f"Backtest {backtest_id} deleted"}
|
||||||
48
backend/app/api/routes/dashboard.py
Normal file
48
backend/app/api/routes/dashboard.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
from typing import List
|
||||||
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
from app.models.schemas import DashboardSummary, ArbitrageOpportunity, PriceData, BatteryState
|
||||||
|
from app.services import DataService
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
data_service = DataService()
|
||||||
|
|
||||||
|
|
||||||
|
@router.on_event("startup")
async def startup():
    # NOTE(review): on_event is deprecated in recent FastAPI in favour of
    # lifespan handlers, and router-level startup hooks only fire when the
    # router is included in the app before startup — confirm this runs.
    await data_service.initialize()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/summary", response_model=DashboardSummary)
async def get_summary():
    """Build the dashboard summary payload from the data service."""
    return DashboardSummary(**data_service.get_dashboard_summary())
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/prices")
async def get_latest_prices():
    """Return the most recent price per region."""
    latest = data_service.get_latest_prices()
    return {"regions": latest}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/prices/history")
async def get_price_history(
    region: str = Query(..., description="Region code (FR, BE, DE, NL, UK)"),
    start: str = Query(None, description="Start date (YYYY-MM-DD)"),
    end: str = Query(None, description="End date (YYYY-MM-DD)"),
    limit: int = Query(1000, description="Maximum number of records"),
):
    """Return historical prices for one region, optionally date-bounded."""
    history = data_service.get_price_history(region, start, end, limit)
    return {"region": region, "data": history}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/battery")
async def get_battery_states():
    """Return the current state of every tracked battery."""
    return {"batteries": data_service.get_battery_states()}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/arbitrage")
async def get_arbitrage_opportunities(
    min_spread: float = Query(None, description="Minimum spread in EUR/MWh")
):
    """Return cross-region arbitrage opportunities above ``min_spread``."""
    found = data_service.get_arbitrage_opportunities(min_spread)
    return {"opportunities": found, "count": len(found)}
|
||||||
71
backend/app/api/routes/models.py
Normal file
71
backend/app/api/routes/models.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
from typing import List
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from datetime import datetime
|
||||||
|
from app.models.schemas import ModelInfo, TrainingRequest, TrainingStatus, PredictionResponse
|
||||||
|
from app.services import MLService
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
ml_service = MLService()
|
||||||
|
|
||||||
|
_training_store: dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=List[ModelInfo])
async def list_models():
    """List every trained model known to the ML service."""
    available = ml_service.list_models()
    return available
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/train")
async def train_model(request: TrainingRequest):
    """Register a training job and return its initial (pending) status.

    NOTE(review): nothing here enqueues the actual training work — confirm
    a background worker consumes these pending jobs.
    """
    training_id = f"training_{uuid.uuid4().hex[:8]}"
    status = TrainingStatus(
        training_id=training_id,
        status="pending",
        progress=0.0,
        started_at=datetime.utcnow(),
    )
    _training_store[training_id] = status

    return {"training_id": training_id, "status": status}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{training_id}/status", response_model=TrainingStatus)
async def get_training_status(training_id: str):
    """Return the status record for one training job, or 404."""
    status = _training_store.get(training_id)
    if status is None:
        raise HTTPException(status_code=404, detail=f"Training job {training_id} not found")
    return status
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{model_id}/metrics")
async def get_model_metrics(model_id: str):
    """Return evaluation metrics for a model; 404 if the model is unknown."""
    try:
        metrics = ml_service.get_model_metrics(model_id)
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
    return {"model_id": model_id, "metrics": metrics}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/predict", response_model=PredictionResponse)
async def predict(
    model_id: str,
    timestamp: datetime,
    features: dict = None,
):
    """Run a prediction with ``model_id`` at ``timestamp``.

    Unknown models map to 404; any other failure maps to 500.
    """
    try:
        # Response construction stays inside the try so validation failures
        # are surfaced the same way as prediction failures.
        prediction = PredictionResponse(**ml_service.predict(model_id, timestamp, features))
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return prediction
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{model_id}/feature-importance")
async def get_feature_importance(model_id: str):
    """Return per-feature importance scores; 404 for unknown models."""
    try:
        scores = ml_service.get_feature_importance(model_id)
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
    return {"model_id": model_id, "feature_importance": scores}
|
||||||
26
backend/app/api/routes/settings.py
Normal file
26
backend/app/api/routes/settings.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
from app.config import settings
|
||||||
|
from app.models.schemas import AppSettings
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=AppSettings)
async def get_settings():
    """Expose the tunable trading parameters from the app configuration."""
    current = settings
    return AppSettings(
        battery_min_reserve=current.BATTERY_MIN_RESERVE,
        battery_max_charge=current.BATTERY_MAX_CHARGE,
        arbitrage_min_spread=current.ARBITRAGE_MIN_SPREAD,
        mining_margin_threshold=current.MINING_MARGIN_THRESHOLD,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("")
async def update_settings(settings_update: dict):
    """Apply a partial settings update; returns which keys matched a setting."""
    # NOTE(review): this accepts ANY key and mutates the live Settings
    # object — e.g. CELERY_BROKER_URL could be overwritten via the API —
    # and changes are process-local, lost on restart. Consider whitelisting
    # the four tunables exposed by GET above.
    updated_fields = []
    for key, value in settings_update.items():
        if hasattr(settings, key.upper()):
            setattr(settings, key.upper(), value)
            updated_fields.append(key)

    return {"message": "Settings updated", "updated_fields": updated_fields}
|
||||||
27
backend/app/api/routes/trading.py
Normal file
27
backend/app/api/routes/trading.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from typing import List
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from app.models.enums import StrategyEnum
|
||||||
|
from app.models.schemas import StrategyStatus
|
||||||
|
from app.services import StrategyService
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
strategy_service = StrategyService()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/strategies", response_model=List[StrategyStatus])
async def get_strategies():
    """Return the status of every trading strategy."""
    statuses = await strategy_service.get_all_strategies()
    return statuses
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/strategies")
async def toggle_strategy(strategy: StrategyEnum, action: str):
    """Start or stop one strategy; any other action is rejected with 400."""
    if action not in ("start", "stop"):
        raise HTTPException(status_code=400, detail="Action must be 'start' or 'stop'")

    new_status = await strategy_service.toggle_strategy(strategy, action)
    return {"status": new_status}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/positions")
async def get_positions():
    """Placeholder: no live position tracking yet, so the list is empty."""
    open_positions: list = []
    return {"positions": open_positions, "total": len(open_positions)}
|
||||||
62
backend/app/api/websocket.py
Normal file
62
backend/app/api/websocket.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
from fastapi import WebSocket
|
||||||
|
from typing import List
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionManager:
    """Tracks open WebSocket clients and fans out JSON events to all of them."""

    def __init__(self):
        # Every currently-accepted client socket.
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the handshake and start tracking the client."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info(f"WebSocket connected. Total connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Stop tracking a client; safe to call twice for the same socket."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
            logger.info(f"WebSocket disconnected. Total connections: {len(self.active_connections)}")

    async def broadcast(self, event_type: str, data: dict):
        """Send an ``{event_type, data, timestamp}`` envelope to every client.

        NOTE(review): "timestamp" is always sent as null — presumably meant
        to carry the emit time; confirm and populate if so.
        """
        message = {"event_type": event_type, "data": data, "timestamp": None}
        disconnected = []

        for connection in self.active_connections:
            try:
                await connection.send_json(message)
            except Exception:
                # Any send failure is treated as a dead connection.
                disconnected.append(connection)

        # Prune after the loop so the list is not mutated while iterating.
        for conn in disconnected:
            self.disconnect(conn)

    async def broadcast_price_update(self, region: str, price_data: dict):
        """Emit a `price_update` event for one region."""
        await self.broadcast("price_update", {"region": region, "price_data": price_data})

    async def broadcast_battery_update(self, battery_id: str, battery_state: dict):
        """Emit a `battery_update` event for one battery."""
        await self.broadcast("battery_update", {"battery_id": battery_id, "battery_state": battery_state})

    async def broadcast_trade(self, trade: dict):
        """Emit a `trade_executed` event."""
        await self.broadcast("trade_executed", trade)

    async def broadcast_alert(self, alert: dict):
        """Emit an `alert_triggered` event."""
        await self.broadcast("alert_triggered", alert)

    async def broadcast_backtest_progress(self, backtest_id: str, progress: float, status: str):
        """Emit progress for a running backtest."""
        await self.broadcast(
            "backtest_progress",
            {"backtest_id": backtest_id, "progress": progress, "status": status},
        )

    async def broadcast_model_training_progress(
        self, model_id: str, progress: float, epoch: int = None, metrics: dict = None
    ):
        """Emit progress for a model-training job."""
        await self.broadcast(
            "model_training_progress",
            {"model_id": model_id, "progress": progress, "epoch": epoch, "metrics": metrics},
        )


# Module-level singleton shared by the API routes.
manager = ConnectionManager()
|
||||||
48
backend/app/config.py
Normal file
48
backend/app/config.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
    """Application configuration, overridable via environment / .env file."""

    APP_NAME: str = "Energy Trading API"
    APP_VERSION: str = "1.0.0"
    DEBUG: bool = True

    HOST: str = "0.0.0.0"
    PORT: int = 8000

    # Raw (possibly "~"-prefixed) path to the processed market data.
    DATA_PATH: str = "~/energy-test-data/data/processed"

    CORS_ORIGINS: Union[str, List[str]] = [
        "http://localhost:3000",
        "http://localhost:5173",
    ]

    WS_HEARTBEAT_INTERVAL: int = 30

    CELERY_BROKER_URL: str = "redis://localhost:6379/0"
    CELERY_RESULT_BACKEND: str = "redis://localhost:6379/0"

    MODELS_PATH: str = "models"
    RESULTS_PATH: str = "results"

    BATTERY_MIN_RESERVE: float = 0.10
    BATTERY_MAX_CHARGE: float = 0.90

    ARBITRAGE_MIN_SPREAD: float = 5.0
    MINING_MARGIN_THRESHOLD: float = 5.0

    ML_PREDICTION_HORIZONS: List[int] = [1, 5, 15, 60]
    ML_FEATURE_LAGS: List[int] = [1, 5, 10, 15, 30, 60]

    model_config = SettingsConfigDict(env_file=".env", case_sensitive=True, env_ignore_empty=True)

    @property
    def DATA_PATH_RESOLVED(self) -> Path:
        """Expanded form of DATA_PATH.

        Fix: the original evaluated ``Path(DATA_PATH).expanduser()`` at
        class-definition time, freezing the DEFAULT path even when
        DATA_PATH was overridden via the environment. Computing it per
        access honours the configured value.
        """
        return Path(self.DATA_PATH).expanduser()

    @property
    def cors_origins_list(self) -> List[str]:
        """CORS origins as a list, splitting the comma-separated string form."""
        if isinstance(self.CORS_ORIGINS, str):
            return [origin.strip() for origin in self.CORS_ORIGINS.split(",")]
        return self.CORS_ORIGINS if isinstance(self.CORS_ORIGINS, list) else []


settings = Settings()
|
||||||
25
backend/app/core/__init__.py
Normal file
25
backend/app/core/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
# Public surface of app.core: re-export the shared constants.
from app.core.constants import (
    DEFAULT_BATTERY_CAPACITY_MWH,
    DEFAULT_CHARGE_RATE_MW,
    DEFAULT_DISCHARGE_RATE_MW,
    DEFAULT_EFFICIENCY,
    WS_HEARTBEAT_INTERVAL,
    PRICE_DATA_REGIONS,
    DEFAULT_DATA_LIMIT,
    BACKTEST_RESULT_TIMEOUT,
    TRAINING_RESULT_TIMEOUT,
    MAX_BACKTEST_TRADES,
)

__all__ = [
    "DEFAULT_BATTERY_CAPACITY_MWH",
    "DEFAULT_CHARGE_RATE_MW",
    "DEFAULT_DISCHARGE_RATE_MW",
    "DEFAULT_EFFICIENCY",
    "WS_HEARTBEAT_INTERVAL",
    "PRICE_DATA_REGIONS",
    "DEFAULT_DATA_LIMIT",
    "BACKTEST_RESULT_TIMEOUT",
    "TRAINING_RESULT_TIMEOUT",
    "MAX_BACKTEST_TRADES",
]
|
||||||
19
backend/app/core/constants.py
Normal file
19
backend/app/core/constants.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from datetime import timedelta

# Default physical parameters for a battery.
DEFAULT_BATTERY_CAPACITY_MWH = 100.0
DEFAULT_CHARGE_RATE_MW = 50.0
DEFAULT_DISCHARGE_RATE_MW = 50.0
DEFAULT_EFFICIENCY = 0.90

# WebSocket heartbeat cadence, seconds.
# NOTE(review): DEFAULT_HEARTBEAT_INTERVAL and WS_HEARTBEAT_INTERVAL are
# duplicates — confirm both are needed.
DEFAULT_HEARTBEAT_INTERVAL = 30
WS_HEARTBEAT_INTERVAL = 30

# Regions with available price data.
PRICE_DATA_REGIONS = ["FR", "BE", "DE", "NL", "UK"]

# Default row cap for data queries.
DEFAULT_DATA_LIMIT = 1000

# Retention windows for cached backtest / training results.
BACKTEST_RESULT_TIMEOUT = timedelta(hours=1)

TRAINING_RESULT_TIMEOUT = timedelta(hours=24)

# Hard cap on trades retained per backtest.
MAX_BACKTEST_TRADES = 10000
|
||||||
68
backend/app/main.py
Normal file
68
backend/app/main.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from app.config import settings
from app.utils.logger import setup_logger, get_logger
from app.api.routes import dashboard, backtest, models, trading, settings as settings_routes
from app.api.websocket import manager

# Configure logging before anything else emits records.
setup_logger()
logger = get_logger(__name__)

# Application object with interactive docs at /docs and /redoc.
app = FastAPI(
    title=settings.APP_NAME,
    version=settings.APP_VERSION,
    docs_url="/docs",
    redoc_url="/redoc",
)

# Allow the configured frontend origins (credentials included) full access.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# All REST routes live under a versioned /api/v1 prefix.
app.include_router(dashboard.router, prefix="/api/v1/dashboard", tags=["dashboard"])
app.include_router(backtest.router, prefix="/api/v1/backtest", tags=["backtest"])
app.include_router(models.router, prefix="/api/v1/models", tags=["models"])
app.include_router(trading.router, prefix="/api/v1/trading", tags=["trading"])
app.include_router(settings_routes.router, prefix="/api/v1/settings", tags=["settings"])


# NOTE(review): on_event is deprecated in recent FastAPI; consider a
# lifespan context manager when upgrading.
@app.on_event("startup")
async def startup_event():
    """Log the startup banner and the resolved data directory."""
    logger.info(f"Starting {settings.APP_NAME} v{settings.APP_VERSION}")
    logger.info(f"Data path: {settings.DATA_PATH_RESOLVED}")


@app.on_event("shutdown")
async def shutdown_event():
    """Log a clean shutdown."""
    logger.info("Shutting down application")


@app.get("/health")
async def health_check():
    """Liveness probe for deployment tooling."""
    return {"status": "healthy", "version": settings.APP_VERSION}


@app.websocket("/ws/real-time")
async def websocket_endpoint(websocket: WebSocket):
    """Accept a client and rebroadcast every received message to ALL clients."""
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            # Echo the raw client text to every connected client.
            await manager.broadcast("message", {"text": data})
    except WebSocketDisconnect:
        manager.disconnect(websocket)


if __name__ == "__main__":
    import uvicorn

    # DEBUG doubles as the auto-reload switch for local development.
    uvicorn.run(
        "app.main:app",
        host=settings.HOST,
        port=settings.PORT,
        reload=settings.DEBUG,
    )
|
||||||
6
backend/app/ml/__init__.py
Normal file
6
backend/app/ml/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
# Re-export the feature-builder entry points as the ml package's public API.
from app.ml.features import (
    build_price_features,
    build_battery_features,
)

__all__ = ["build_price_features", "build_battery_features"]
|
||||||
3
backend/app/ml/evaluation/__init__.py
Normal file
3
backend/app/ml/evaluation/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# Fix: import from the concrete submodule. The original
# `from app.ml.evaluation import ...` imported this package from itself,
# which fails at package init (the names are not defined there yet).
from app.ml.evaluation.metrics import ModelEvaluator, BacktestEvaluator

__all__ = ["ModelEvaluator", "BacktestEvaluator"]
|
||||||
3
backend/app/ml/evaluation/backtest_evaluator.py
Normal file
3
backend/app/ml/evaluation/backtest_evaluator.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# Thin alias module: BacktestEvaluator is implemented in metrics.py and
# re-exported here under a dedicated module path.
from app.ml.evaluation.metrics import BacktestEvaluator

__all__ = ["BacktestEvaluator"]
|
||||||
77
backend/app/ml/evaluation/metrics.py
Normal file
77
backend/app/ml/evaluation/metrics.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
from typing import Dict, List
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
|
||||||
|
|
||||||
|
|
||||||
|
class ModelEvaluator:
|
||||||
|
@staticmethod
|
||||||
|
def calculate_metrics(y_true, y_pred) -> Dict[str, float]:
|
||||||
|
mae = mean_absolute_error(y_true, y_pred)
|
||||||
|
rmse = mean_squared_error(y_true, y_pred, squared=False)
|
||||||
|
mape = np.mean(np.abs((y_true - y_pred) / y_true)) * 100
|
||||||
|
r2 = r2_score(y_true, y_pred)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"mae": float(mae),
|
||||||
|
"rmse": float(rmse),
|
||||||
|
"mape": float(mape) if not np.isnan(mape) else 0.0,
|
||||||
|
"r2": float(r2),
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def calculate_sharpe_ratio(returns: np.ndarray, risk_free_rate: float = 0.0) -> float:
|
||||||
|
if len(returns) == 0 or np.std(returns) == 0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
excess_returns = returns - risk_free_rate
|
||||||
|
return float(np.mean(excess_returns) / np.std(excess_returns))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def calculate_max_drawdown(values: np.ndarray) -> float:
|
||||||
|
if len(values) == 0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
cumulative = np.cumsum(values)
|
||||||
|
running_max = np.maximum.accumulate(cumulative)
|
||||||
|
drawdown = (cumulative - running_max)
|
||||||
|
return float(drawdown.min())
|
||||||
|
|
||||||
|
|
||||||
|
class BacktestEvaluator:
|
||||||
|
def __init__(self):
|
||||||
|
self.trades: List[Dict] = []
|
||||||
|
|
||||||
|
def add_trade(self, trade: Dict):
|
||||||
|
self.trades.append(trade)
|
||||||
|
|
||||||
|
def evaluate(self) -> Dict[str, float]:
|
||||||
|
if not self.trades:
|
||||||
|
return {
|
||||||
|
"total_revenue": 0.0,
|
||||||
|
"total_trades": 0,
|
||||||
|
"win_rate": 0.0,
|
||||||
|
"sharpe_ratio": 0.0,
|
||||||
|
"max_drawdown": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
total_revenue = sum(t.get("revenue", 0) for t in self.trades)
|
||||||
|
winning_trades = sum(1 for t in self.trades if t.get("revenue", 0) > 0)
|
||||||
|
win_rate = winning_trades / len(self.trades) if self.trades else 0.0
|
||||||
|
|
||||||
|
returns = np.array([t.get("revenue", 0) for t in self.trades])
|
||||||
|
sharpe_ratio = ModelEvaluator.calculate_sharpe_ratio(returns)
|
||||||
|
max_drawdown = ModelEvaluator.calculate_max_drawdown(returns)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_revenue": total_revenue,
|
||||||
|
"total_trades": len(self.trades),
|
||||||
|
"win_rate": win_rate,
|
||||||
|
"sharpe_ratio": sharpe_ratio,
|
||||||
|
"max_drawdown": max_drawdown,
|
||||||
|
}
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.trades = []
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["ModelEvaluator", "BacktestEvaluator"]
|
||||||
3
backend/app/ml/evaluation/reports.py
Normal file
3
backend/app/ml/evaluation/reports.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# Alias module: the evaluators are implemented in metrics.py and re-exported
# here for report-oriented imports.
from app.ml.evaluation.metrics import ModelEvaluator, BacktestEvaluator

__all__ = ["ModelEvaluator", "BacktestEvaluator"]
|
||||||
53
backend/app/ml/features/__init__.py
Normal file
53
backend/app/ml/features/__init__.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
from app.ml.features.lag_features import add_lag_features
|
||||||
|
from app.ml.features.rolling_features import add_rolling_features
|
||||||
|
from app.ml.features.time_features import add_time_features
|
||||||
|
from app.ml.features.regional_features import add_regional_features
|
||||||
|
from app.ml.features.battery_features import add_battery_features
|
||||||
|
from typing import List, Optional
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
def build_price_features(
    df: pd.DataFrame,
    price_col: str = "real_time_price",
    lags: Optional[List[int]] = None,
    windows: Optional[List[int]] = None,
    regions: Optional[List[str]] = None,
    include_time: bool = True,
    include_regional: bool = True,
) -> pd.DataFrame:
    """Assemble lag / rolling / time / regional features on a price frame.

    Lag and rolling features are only added when ``price_col`` exists; time
    features require a "timestamp" column. Returns a new frame; the input
    is never mutated.
    """
    effective_lags = [1, 5, 10, 15, 30, 60] if lags is None else lags
    effective_windows = [5, 10, 15, 30, 60] if windows is None else windows

    features = df.copy()

    if price_col in features.columns:
        features = add_lag_features(features, price_col, effective_lags)
        features = add_rolling_features(features, price_col, effective_windows)

    if include_time and "timestamp" in features.columns:
        features = add_time_features(features)

    if include_regional and regions:
        features = add_regional_features(features, regions)

    return features
|
||||||
|
|
||||||
|
|
||||||
|
def build_battery_features(
    df: pd.DataFrame,
    price_df: pd.DataFrame,
    battery_col: str = "charge_level_mwh",
    capacity_col: str = "capacity_mwh",
    timestamp_col: str = "timestamp",
    battery_id_col: str = "battery_id",
) -> pd.DataFrame:
    """Derive battery-state features, joined against market prices.

    Thin wrapper delegating to ``add_battery_features``; the input frame is
    copied first and never mutated.
    """
    return add_battery_features(
        df.copy(), price_df, battery_col, capacity_col, timestamp_col, battery_id_col
    )


__all__ = ["build_price_features", "build_battery_features"]
|
||||||
35
backend/app/ml/features/battery_features.py
Normal file
35
backend/app/ml/features/battery_features.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
def add_battery_features(
|
||||||
|
df: pd.DataFrame,
|
||||||
|
price_df: pd.DataFrame,
|
||||||
|
battery_col: str = "charge_level_mwh",
|
||||||
|
capacity_col: str = "capacity_mwh",
|
||||||
|
timestamp_col: str = "timestamp",
|
||||||
|
battery_id_col: str = "battery_id",
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
result = df.copy()
|
||||||
|
|
||||||
|
if battery_col in result.columns and capacity_col in result.columns:
|
||||||
|
result["charge_level_pct"] = result[battery_col] / result[capacity_col]
|
||||||
|
result["discharge_potential_mwh"] = result[battery_col] * result.get("efficiency", 0.9)
|
||||||
|
result["charge_capacity_mwh"] = result[capacity_col] - result[battery_col]
|
||||||
|
|
||||||
|
if price_df is not None and "real_time_price" in price_df.columns and timestamp_col in result.columns:
|
||||||
|
merged = result.merge(
|
||||||
|
price_df[[timestamp_col, "real_time_price"]],
|
||||||
|
on=timestamp_col,
|
||||||
|
how="left",
|
||||||
|
suffixes=("", "_market")
|
||||||
|
)
|
||||||
|
|
||||||
|
if "real_time_price_market" in merged.columns:
|
||||||
|
result["market_price"] = merged["real_time_price_market"]
|
||||||
|
result["charge_cost_potential"] = result["charge_capacity_mwh"] * result["market_price"]
|
||||||
|
result["discharge_revenue_potential"] = result["discharge_potential_mwh"] * result["market_price"]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["add_battery_features"]
|
||||||
14
backend/app/ml/features/lag_features.py
Normal file
14
backend/app/ml/features/lag_features.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from typing import List
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
def add_lag_features(df: pd.DataFrame, col: str, lags: List[int]) -> pd.DataFrame:
|
||||||
|
result = df.copy()
|
||||||
|
|
||||||
|
for lag in lags:
|
||||||
|
result[f"{col}_lag_{lag}"] = result[col].shift(lag)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["add_lag_features"]
|
||||||
18
backend/app/ml/features/regional_features.py
Normal file
18
backend/app/ml/features/regional_features.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from typing import List
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
def add_regional_features(df: pd.DataFrame, regions: List[str]) -> pd.DataFrame:
|
||||||
|
result = df.copy()
|
||||||
|
|
||||||
|
if "region" in result.columns and "real_time_price" in result.columns:
|
||||||
|
avg_price_by_region = result.groupby("region")["real_time_price"].mean()
|
||||||
|
|
||||||
|
for region in regions:
|
||||||
|
region_avg = avg_price_by_region.get(region, 0)
|
||||||
|
result[f"price_diff_{region}"] = result["real_time_price"] - region_avg
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["add_regional_features"]
|
||||||
17
backend/app/ml/features/rolling_features.py
Normal file
17
backend/app/ml/features/rolling_features.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from typing import List
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
def add_rolling_features(df: pd.DataFrame, col: str, windows: List[int]) -> pd.DataFrame:
|
||||||
|
result = df.copy()
|
||||||
|
|
||||||
|
for window in windows:
|
||||||
|
result[f"{col}_rolling_mean_{window}"] = result[col].rolling(window=window).mean()
|
||||||
|
result[f"{col}_rolling_std_{window}"] = result[col].rolling(window=window).std()
|
||||||
|
result[f"{col}_rolling_min_{window}"] = result[col].rolling(window=window).min()
|
||||||
|
result[f"{col}_rolling_max_{window}"] = result[col].rolling(window=window).max()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["add_rolling_features"]
|
||||||
35
backend/app/ml/features/time_features.py
Normal file
35
backend/app/ml/features/time_features.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
def add_time_features(df: pd.DataFrame, timestamp_col: str = "timestamp") -> pd.DataFrame:
    """Add calendar and cyclical time features derived from *timestamp_col*.

    Adds hour, day_of_week (0=Monday), day_of_month, month, and sin/cos
    encodings of hour (period 24) and day_of_week (period 7). The
    timestamp column is converted in the returned copy via pd.to_datetime.
    Returns the copy unchanged when *timestamp_col* is missing.
    """
    # numpy is not imported at this module's top level, hence the local import
    import numpy as np

    result = df.copy()

    if timestamp_col not in result.columns:
        return result

    result[timestamp_col] = pd.to_datetime(result[timestamp_col])

    result["hour"] = result[timestamp_col].dt.hour
    result["day_of_week"] = result[timestamp_col].dt.dayofweek
    result["day_of_month"] = result[timestamp_col].dt.day
    result["month"] = result[timestamp_col].dt.month

    # Cyclical encodings so the ends of each cycle (23:00 vs 00:00,
    # Sunday vs Monday) stay close in feature space. Inlined instead of
    # calling the private _sin_encode/_cos_encode helpers so this function
    # is self-contained and avoids a per-call import in the helpers.
    result["hour_sin"] = np.sin(2 * np.pi * result["hour"] / 24)
    result["hour_cos"] = np.cos(2 * np.pi * result["hour"] / 24)
    result["day_sin"] = np.sin(2 * np.pi * result["day_of_week"] / 7)
    result["day_cos"] = np.cos(2 * np.pi * result["day_of_week"] / 7)

    return result
|
||||||
|
|
||||||
|
|
||||||
|
def _sin_encode(x, period):
    """Sine component of a cyclical encoding: sin(2*pi*x/period).

    Embeds periodic quantities (e.g. hour of day with period 24) so the
    two ends of the cycle map to nearby feature values.
    """
    # numpy is not imported at module level, hence the local import
    import numpy as np
    return np.sin(2 * np.pi * x / period)
|
||||||
|
|
||||||
|
|
||||||
|
def _cos_encode(x, period):
    """Cosine component of a cyclical encoding: cos(2*pi*x/period).

    Paired with _sin_encode; together they place each cycle position on
    the unit circle.
    """
    # numpy is not imported at module level, hence the local import
    import numpy as np
    return np.cos(2 * np.pi * x / period)


__all__ = ["add_time_features"]
|
||||||
3
backend/app/ml/model_management/__init__.py
Normal file
3
backend/app/ml/model_management/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""Public API for the model management package."""

# Import from the concrete submodule: the previous line imported from this
# package itself (a self-import), which fails at runtime with an
# AttributeError/ImportError for ModelRegistry.
from app.ml.model_management.registry import ModelRegistry

__all__ = ["ModelRegistry"]
|
||||||
99
backend/app/ml/model_management/registry.py
Normal file
99
backend/app/ml/model_management/registry.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
from typing import Dict, List, Optional
|
||||||
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRegistry:
    """JSON-file-backed registry of trained model versions.

    Maps a model id to its type and a list of version entries; each entry
    records the artifact filepath, a UTC timestamp, and free-form metadata.
    State is persisted to ``registry_path`` after every mutation.
    """

    def __init__(self, registry_path: str = "models/registry.json"):
        self.registry_path = Path(registry_path)
        # In-memory mirror of the JSON file: {model_id: {"type", "versions", "latest"}}
        self._registry: Dict[str, Dict] = {}
        self._load()

    def _load(self):
        """Populate the in-memory registry from disk, if the file exists."""
        if self.registry_path.exists():
            with open(self.registry_path, encoding="utf-8") as f:
                self._registry = json.load(f)
            logger.info(f"Loaded registry from {self.registry_path}")

    def _save(self):
        """Write the registry to disk, creating parent directories as needed."""
        self.registry_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.registry_path, "w", encoding="utf-8") as f:
            # default=str keeps non-JSON-native metadata values serializable
            json.dump(self._registry, f, indent=2, default=str)
        logger.info(f"Saved registry to {self.registry_path}")

    def register_model(
        self,
        model_type: str,
        model_id: str,
        version: str,
        filepath: str,
        metadata: Optional[Dict] = None,
    ) -> None:
        """Record a new version for *model_id*, mark it latest, and persist."""
        from datetime import timezone  # local import: module only imports datetime

        # Timezone-aware timestamp; datetime.utcnow() is deprecated (3.12+)
        # and produced a naive datetime.
        timestamp = datetime.now(timezone.utc).isoformat()

        if model_id not in self._registry:
            self._registry[model_id] = {
                "type": model_type,
                "versions": [],
            }

        self._registry[model_id]["versions"].append({
            "version": version,
            "filepath": filepath,
            "timestamp": timestamp,
            "metadata": metadata or {},
        })

        self._registry[model_id]["latest"] = version

        self._save()
        logger.info(f"Registered model {model_id} version {version}")

    def get_latest_version(self, model_id: str) -> Optional[Dict]:
        """Return the version entry marked latest, or None if unknown."""
        if model_id not in self._registry:
            return None

        latest_version = self._registry[model_id].get("latest")
        if not latest_version:
            return None

        for version_info in self._registry[model_id]["versions"]:
            if version_info["version"] == latest_version:
                return version_info

        return None

    def list_models(self) -> List[Dict]:
        """Return a summary dict for every registered model."""
        models = []
        for model_id, model_info in self._registry.items():
            models.append({
                "model_id": model_id,
                "type": model_info.get("type"),
                "latest_version": model_info.get("latest"),
                "total_versions": len(model_info.get("versions", [])),
                "latest_info": self.get_latest_version(model_id),
            })
        return models

    def get_model(self, model_id: str, version: Optional[str] = None) -> Optional[Dict]:
        """Return the entry for *version* (default: the latest), or None."""
        if model_id not in self._registry:
            return None

        if version is None:
            version = self._registry[model_id].get("latest")

        for version_info in self._registry[model_id]["versions"]:
            if version_info["version"] == version:
                return version_info

        return None


__all__ = ["ModelRegistry"]
|
||||||
3
backend/app/ml/price_prediction/__init__.py
Normal file
3
backend/app/ml/price_prediction/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""Public API for the price prediction package."""

# Import from the concrete submodules: the previous line imported from this
# package itself (a self-import), which fails at runtime.
from app.ml.price_prediction.predictor import PricePredictor
from app.ml.price_prediction.trainer import PricePredictionTrainer

__all__ = ["PricePredictor", "PricePredictionTrainer"]
|
||||||
52
backend/app/ml/price_prediction/model.py
Normal file
52
backend/app/ml/price_prediction/model.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import pickle
|
||||||
|
from typing import Optional
|
||||||
|
import xgboost as xgb
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class PricePredictionModel:
    """XGBoost regressor forecasting the real-time price *horizon* minutes ahead."""

    def __init__(self, horizon: int, model_id: Optional[str] = None):
        self.horizon = horizon
        self.model_id = model_id or f"price_prediction_{horizon}m"
        # Populated by fit(); predict()/feature_importances_ raise until then.
        self.model: Optional[xgb.XGBRegressor] = None
        self.feature_names = []

    def fit(self, X, y):
        """Train a fresh XGBRegressor on (X, y), recording feature names."""
        self.model = xgb.XGBRegressor(
            n_estimators=200,
            max_depth=6,
            learning_rate=0.1,
            subsample=0.8,
            colsample_bytree=0.8,
            random_state=42,
        )
        # Plain arrays carry no column names, so synthesize positional ones.
        self.feature_names = (
            [f"feature_{i}" for i in range(X.shape[1])]
            if isinstance(X, np.ndarray)
            else list(X.columns)
        )
        self.model.fit(X, y)

    def predict(self, X):
        """Return price predictions; raises ValueError if fit() was never called."""
        if self.model is None:
            raise ValueError("Model not trained")
        return self.model.predict(X)

    def save(self, filepath: str):
        """Pickle the whole wrapper (model plus metadata) to *filepath*."""
        with open(filepath, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, filepath: str):
        """Unpickle a wrapper saved with save(). Only load trusted files."""
        with open(filepath, "rb") as f:
            return pickle.load(f)

    @property
    def feature_importances_(self):
        """Expose the underlying booster's feature importances."""
        if self.model is None:
            raise ValueError("Model not trained")
        return self.model.feature_importances_


__all__ = ["PricePredictionModel"]
|
||||||
86
backend/app/ml/price_prediction/predictor.py
Normal file
86
backend/app/ml/price_prediction/predictor.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
from typing import Dict, Optional
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from app.ml.price_prediction.model import PricePredictionModel
|
||||||
|
from app.ml.price_prediction.trainer import PricePredictionTrainer
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PricePredictor:
    """Serves price forecasts from pre-trained per-horizon models."""

    def __init__(self, models_dir: str = "models/price_prediction"):
        self.models_dir = models_dir
        # horizon (minutes) -> trained model
        self.models: Dict[int, PricePredictionModel] = {}
        self._load_models()

    def _load_models(self):
        """Load every available horizon model from disk."""
        self.models = PricePredictionTrainer.load_models(self.models_dir)
        logger.info(f"Loaded {len(self.models)} prediction models")

    def predict(
        self, current_data: pd.DataFrame, horizon: int = 15, region: Optional[str] = None
    ) -> float:
        """Predict the price *horizon* minutes ahead from the latest feature row.

        Raises:
            ValueError: if no model exists for *horizon*, or if region
                filtering leaves no rows to predict from.
        """
        if horizon not in self.models:
            raise ValueError(f"No model available for horizon {horizon}")

        model = self.models[horizon]

        from app.ml.features import build_price_features

        df_features = build_price_features(current_data)

        # NOTE(review): assumes these columns line up with what the model was
        # trained on — confirm trainer and predictor drop-lists stay in sync.
        feature_cols = [col for col in df_features.columns if col not in ["timestamp", "region", "day_ahead_price", "real_time_price"]]

        if region and "region" in df_features.columns:
            df_features = df_features[df_features["region"] == region]

        if df_features.empty:
            # Previously this fell through to iloc[-1:] and surfaced as an
            # opaque IndexError when the region filter matched nothing.
            raise ValueError(f"No rows available to predict from (region={region!r})")

        latest_row = df_features.iloc[-1:][feature_cols]

        prediction = model.predict(latest_row.values)

        return float(prediction[0])

    def predict_all_horizons(self, current_data: pd.DataFrame, region: Optional[str] = None) -> Dict[int, Optional[float]]:
        """Predict for every loaded horizon; failed horizons map to None."""
        predictions: Dict[int, Optional[float]] = {}

        for horizon in sorted(self.models.keys()):
            try:
                predictions[horizon] = self.predict(current_data, horizon, region)
            except Exception as e:
                logger.error(f"Failed to predict for horizon {horizon}: {e}")
                predictions[horizon] = None

        return predictions

    def predict_with_confidence(
        self, current_data: pd.DataFrame, horizon: int = 15, region: Optional[str] = None
    ) -> Dict:
        """Predict with a fixed ±5% band.

        NOTE: the band is a placeholder heuristic, not a statistical interval.
        """
        prediction = self.predict(current_data, horizon, region)

        return {
            "prediction": prediction,
            "confidence_lower": prediction * 0.95,
            "confidence_upper": prediction * 1.05,
            "horizon": horizon,
        }

    def get_feature_importance(self, horizon: int) -> pd.DataFrame:
        """Return a (feature, importance) DataFrame sorted descending."""
        if horizon not in self.models:
            raise ValueError(f"No model available for horizon {horizon}")

        model = self.models[horizon]

        return pd.DataFrame({
            "feature": model.feature_names,
            "importance": model.feature_importances_,
        }).sort_values("importance", ascending=False)


__all__ = ["PricePredictor"]
|
||||||
142
backend/app/ml/price_prediction/trainer.py
Normal file
142
backend/app/ml/price_prediction/trainer.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
from typing import List, Dict, Tuple, Optional
|
||||||
|
from pathlib import Path
|
||||||
|
import pandas as pd
|
||||||
|
from app.ml.price_prediction.model import PricePredictionModel
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PricePredictionTrainer:
|
||||||
|
def __init__(self, config=None):
|
||||||
|
self.config = config
|
||||||
|
self.data: Optional[pd.DataFrame] = None
|
||||||
|
self.models: Dict[int, PricePredictionModel] = {}
|
||||||
|
|
||||||
|
def load_data(self, data_path: Optional[str] = None) -> pd.DataFrame:
|
||||||
|
if data_path is None:
|
||||||
|
data_path = "~/energy-test-data/data/processed"
|
||||||
|
|
||||||
|
path = Path(data_path).expanduser()
|
||||||
|
dfs = []
|
||||||
|
|
||||||
|
for region in ["FR", "BE", "DE", "NL", "UK"]:
|
||||||
|
file_path = path / f"{region.lower()}_processed.parquet"
|
||||||
|
if file_path.exists():
|
||||||
|
df = pd.read_parquet(file_path)
|
||||||
|
df["region"] = region
|
||||||
|
dfs.append(df)
|
||||||
|
|
||||||
|
if dfs:
|
||||||
|
self.data = pd.concat(dfs, ignore_index=True)
|
||||||
|
logger.info(f"Loaded data: {len(self.data)} rows")
|
||||||
|
|
||||||
|
return self.data
|
||||||
|
|
||||||
|
def prepare_data(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
|
||||||
|
from app.ml.features import build_price_features
|
||||||
|
|
||||||
|
df_features = build_price_features(df)
|
||||||
|
|
||||||
|
df_features = df_features.dropna()
|
||||||
|
|
||||||
|
feature_cols = [col for col in df_features.columns if col not in ["timestamp", "region", "day_ahead_price", "real_time_price"]]
|
||||||
|
|
||||||
|
return df_features, feature_cols
|
||||||
|
|
||||||
|
def train_for_horizon(
|
||||||
|
self, df_features: pd.DataFrame, feature_cols: List[str], horizon: int
|
||||||
|
) -> Dict:
|
||||||
|
logger.info(f"Training model for {horizon} minute horizon")
|
||||||
|
|
||||||
|
df_features = df_features.sort_values("timestamp")
|
||||||
|
|
||||||
|
n_total = len(df_features)
|
||||||
|
n_train = int(n_total * 0.70)
|
||||||
|
n_val = int(n_total * 0.85)
|
||||||
|
|
||||||
|
train_data = df_features.iloc[:n_train]
|
||||||
|
val_data = df_features.iloc[n_train:n_val]
|
||||||
|
|
||||||
|
X_train = train_data[feature_cols]
|
||||||
|
y_train = train_data["real_time_price"].shift(-horizon).dropna()
|
||||||
|
X_train = X_train.loc[y_train.index]
|
||||||
|
|
||||||
|
X_val = val_data[feature_cols]
|
||||||
|
y_val = val_data["real_time_price"].shift(-horizon).dropna()
|
||||||
|
X_val = X_val.loc[y_val.index]
|
||||||
|
|
||||||
|
model = PricePredictionModel(horizon=horizon)
|
||||||
|
model.fit(X_train, y_train)
|
||||||
|
|
||||||
|
val_preds = model.predict(X_val)
|
||||||
|
|
||||||
|
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
|
||||||
|
|
||||||
|
mae = mean_absolute_error(y_val, val_preds)
|
||||||
|
rmse = mean_squared_error(y_val, val_preds, squared=False)
|
||||||
|
r2 = r2_score(y_val, val_preds)
|
||||||
|
|
||||||
|
self.models[horizon] = model
|
||||||
|
|
||||||
|
results = {
|
||||||
|
"horizon": horizon,
|
||||||
|
"mae": mae,
|
||||||
|
"rmse": rmse,
|
||||||
|
"r2": r2,
|
||||||
|
"n_train": len(X_train),
|
||||||
|
"n_val": len(X_val),
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"Training complete for {horizon}m: MAE={mae:.2f}, RMSE={rmse:.2f}, R2={r2:.3f}")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def train_all(self, horizons: Optional[List[int]] = None) -> Dict:
|
||||||
|
if horizons is None:
|
||||||
|
horizons = [1, 5, 15, 60]
|
||||||
|
|
||||||
|
if self.data is None:
|
||||||
|
self.load_data()
|
||||||
|
|
||||||
|
df_features, feature_cols = self.prepare_data(self.data)
|
||||||
|
|
||||||
|
all_results = {}
|
||||||
|
for horizon in horizons:
|
||||||
|
try:
|
||||||
|
result = self.train_for_horizon(df_features, feature_cols, horizon)
|
||||||
|
all_results[horizon] = result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to train for horizon {horizon}: {e}")
|
||||||
|
all_results[horizon] = {"error": str(e)}
|
||||||
|
|
||||||
|
return all_results
|
||||||
|
|
||||||
|
def save_models(self, output_dir: str = "models/price_prediction") -> None:
|
||||||
|
output_path = Path(output_dir)
|
||||||
|
output_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
for horizon, model in self.models.items():
|
||||||
|
filepath = output_path / f"model_{horizon}min.pkl"
|
||||||
|
model.save(filepath)
|
||||||
|
logger.info(f"Saved model for {horizon}m to {filepath}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_models(cls, models_dir: str = "models/price_prediction", horizons: Optional[List[int]] = None) -> Dict[int, PricePredictionModel]:
|
||||||
|
models = {}
|
||||||
|
path = Path(models_dir)
|
||||||
|
|
||||||
|
if horizons is None:
|
||||||
|
horizons = [1, 5, 15, 60]
|
||||||
|
|
||||||
|
for horizon in horizons:
|
||||||
|
filepath = path / f"model_{horizon}min.pkl"
|
||||||
|
if filepath.exists():
|
||||||
|
model = PricePredictionModel.load(filepath)
|
||||||
|
models[horizon] = model
|
||||||
|
logger.info(f"Loaded model for {horizon}m")
|
||||||
|
|
||||||
|
return models
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["PricePredictionTrainer"]
|
||||||
3
backend/app/ml/rl_battery/__init__.py
Normal file
3
backend/app/ml/rl_battery/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""Public API for the RL battery package."""

# Import from the concrete submodules: the previous line imported from this
# package itself (a self-import), which fails at runtime.
from app.ml.rl_battery.policy import BatteryPolicy
from app.ml.rl_battery.trainer import BatteryRLTrainer

__all__ = ["BatteryPolicy", "BatteryRLTrainer"]
|
||||||
88
backend/app/ml/rl_battery/agent.py
Normal file
88
backend/app/ml/rl_battery/agent.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
from typing import Dict, Optional
|
||||||
|
import numpy as np
|
||||||
|
import pickle
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class QLearningAgent:
    """Tabular epsilon-greedy Q-learning agent over a discretized state space.

    States are vectors of features assumed to lie in [0, 1); each component
    is bucketed into ``state_bins`` bins and the buckets are flattened into
    a single row index of the Q-table.
    """

    def __init__(
        self,
        state_bins: int = 10,
        action_space: int = 3,
        learning_rate: float = 0.1,
        discount_factor: float = 0.95,
        epsilon: float = 1.0,
        epsilon_decay: float = 0.995,
        epsilon_min: float = 0.05,
    ):
        self.state_bins = state_bins
        self.action_space = action_space
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min

        # Allocated lazily by initialize_q_table(); updates are no-ops until then.
        self.q_table: Optional[np.ndarray] = None
        self.policy_id = "battery_policy"

    def initialize_q_table(self, observation_space: int):
        """Allocate a zeroed Q-table: state_bins**observation_space rows."""
        n_states = self.state_bins ** observation_space
        self.q_table = np.zeros((n_states, self.action_space))

    def _discretize_state(self, state: np.ndarray) -> int:
        """Map a vector of [0, 1) features to a single flat table index."""
        bins = np.clip((state * self.state_bins).astype(int), 0, self.state_bins - 1)
        # Mixed-radix flattening: component i contributes bins[i] * state_bins**i.
        index, weight = 0, 1
        for b in bins:
            index += b * weight
            weight *= self.state_bins
        return index

    def get_action(self, state: np.ndarray, training: bool = True) -> int:
        """Epsilon-greedy action selection; pure greedy when training=False."""
        state_idx = self._discretize_state(state)

        if training and np.random.random() < self.epsilon:
            return np.random.randint(self.action_space)

        if self.q_table is None:
            return 1  # default to the middle ("hold") action before the table exists

        return np.argmax(self.q_table[state_idx])

    def update(self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool):
        """One Q-learning TD update; no-op until the table is initialized."""
        if self.q_table is None:
            return

        s = self._discretize_state(state)
        s_next = self._discretize_state(next_state)

        # Terminal transitions bootstrap nothing from the next state.
        bootstrap = 0.0 if done else self.discount_factor * np.max(self.q_table[s_next])
        target = reward + bootstrap

        self.q_table[s, action] += self.learning_rate * (target - self.q_table[s, action])

    def decay_epsilon(self):
        """Anneal exploration, never dropping below epsilon_min."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def save(self, filepath: str):
        """Pickle the full agent (Q-table and hyperparameters)."""
        with open(filepath, "wb") as f:
            pickle.dump(self, f)
        logger.info(f"Saved Q-learning policy to {filepath}")

    @classmethod
    def load(cls, filepath: str):
        """Unpickle an agent saved with save(). Only load trusted files."""
        with open(filepath, "rb") as f:
            return pickle.load(f)


__all__ = ["QLearningAgent"]
|
||||||
87
backend/app/ml/rl_battery/environment.py
Normal file
87
backend/app/ml/rl_battery/environment.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
from typing import Dict, Tuple
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class BatteryEnvironment:
    """Minute-step battery arbitrage simulator with a random-walk price.

    Actions: 0 = charge, 1 = hold, 2 = discharge. Rewards are cash flows
    scaled by 1/1000 (revenue positive, cost negative). An episode lasts
    one simulated day (1440 steps). Charge level is kept within
    [min_reserve, max_charge] fractions of capacity.
    """

    def __init__(
        self,
        capacity: float = 100.0,
        charge_rate: float = 50.0,
        discharge_rate: float = 50.0,
        efficiency: float = 0.9,
        min_reserve: float = 0.1,
        max_charge: float = 0.9,
    ):
        self.capacity = capacity
        self.charge_rate = charge_rate
        self.discharge_rate = discharge_rate
        # Loss factor applied on both directions (charge stores eff*drawn,
        # discharge removes sold/eff from the battery).
        self.efficiency = efficiency
        self.min_reserve = min_reserve  # fraction of capacity kept as floor
        self.max_charge = max_charge    # fraction of capacity used as ceiling

        self.charge_level = capacity * 0.5
        self.current_price = 50.0
        self.time_step = 0

    def reset(self) -> np.ndarray:
        """Restart the episode at half charge, base price 50, t=0."""
        self.charge_level = self.capacity * 0.5
        self.current_price = 50.0
        self.time_step = 0
        return self._get_state()

    def _get_state(self) -> np.ndarray:
        """Observation: [charge fraction, price normalized to [0,1], time of day]."""
        charge_pct = self.charge_level / self.capacity
        price_norm = np.clip(self.current_price / 200.0, 0, 1)
        time_norm = (self.time_step % 1440) / 1440.0
        return np.array([charge_pct, price_norm, time_norm])

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
        """Apply *action*, advance the price walk, return (state, reward, done, info)."""
        old_price = self.current_price

        if action == 0:
            # Charge: draw up to charge_rate from the grid, bounded by headroom;
            # only efficiency * amount actually lands in the battery.
            headroom = self.capacity * self.max_charge - self.charge_level
            charge_amount = max(0.0, min(self.charge_rate, headroom))
            self.charge_level += charge_amount * self.efficiency
            reward = -charge_amount * old_price / 1000.0
        elif action == 2:
            # Discharge: bound the *sold* energy so that removing
            # discharge_amount / efficiency from the battery never dips below
            # the reserve floor. The previous bound ignored the efficiency
            # divisor, so the post-step clip silently re-created energy (the
            # agent was paid for energy the battery did not have).
            available = (self.charge_level - self.capacity * self.min_reserve) * self.efficiency
            discharge_amount = max(0.0, min(self.discharge_rate, available))
            revenue = discharge_amount * old_price
            self.charge_level -= discharge_amount / self.efficiency
            reward = revenue / 1000.0
        else:
            # Hold (action == 1) and any unknown action are no-ops.
            reward = 0.0

        # Safety clamp; with the bounds above this should already hold.
        self.charge_level = np.clip(self.charge_level, self.capacity * self.min_reserve, self.capacity * self.max_charge)

        # Gaussian random-walk price, clamped to a plausible band.
        self.current_price = old_price + np.random.randn() * 5
        self.current_price = np.clip(self.current_price, 0, 300)

        self.time_step += 1

        state = self._get_state()

        info = {
            "charge_level": self.charge_level,
            "price": self.current_price,
            "action": action,
        }

        done = self.time_step >= 1440

        return state, reward, done, info

    @property
    def action_space(self):
        """Number of discrete actions (charge / hold / discharge)."""
        return 3

    @property
    def observation_space(self):
        """Dimensionality of the observation vector."""
        return 3


__all__ = ["BatteryEnvironment"]
|
||||||
65
backend/app/ml/rl_battery/policy.py
Normal file
65
backend/app/ml/rl_battery/policy.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
from typing import Dict
|
||||||
|
from app.ml.rl_battery.agent import QLearningAgent
|
||||||
|
from app.ml.rl_battery.environment import BatteryEnvironment
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BatteryPolicy:
    """Inference wrapper around a trained Q-learning battery policy.

    Falls back to a neutral "hold" recommendation when no trained policy
    file is present under *policy_path*.
    """

    def __init__(self, policy_path: str = "models/rl_battery"):
        self.policy_path = policy_path
        self.agent: QLearningAgent = None
        self.env: BatteryEnvironment = None
        self._load_policy()

    def _load_policy(self):
        """Load the pickled agent if <policy_path>/battery_policy.pkl exists."""
        from pathlib import Path

        policy_file = Path(self.policy_path) / "battery_policy.pkl"
        if not policy_file.exists():
            return
        self.agent = QLearningAgent.load(policy_file)
        self.env = BatteryEnvironment()
        logger.info(f"Loaded policy from {policy_file}")

    def get_action(
        self,
        charge_level: float,
        current_price: float,
        price_forecast_1m: float = 0,
        price_forecast_5m: float = 0,
        price_forecast_15m: float = 0,
        hour: int = 0,
    ) -> Dict:
        """Recommend charge/hold/discharge for the given battery/market state.

        NOTE(review): the price_forecast_* arguments are accepted but not part
        of the agent's state representation — confirm whether they should be.
        """
        if self.agent is None:
            # No trained policy available: neutral recommendation.
            return {
                "action": "hold",
                "q_values": [0.0, 0.0, 0.0],
                "confidence": 0.0,
            }

        # Project the inputs into the environment's observation space.
        self.env.charge_level = charge_level
        self.env.current_price = current_price
        self.env.time_step = hour * 60
        observation = self.env._get_state()

        chosen = self.agent.get_action(observation, training=False)
        action_name = ("charge", "hold", "discharge")[chosen]

        table_row = self.agent._discretize_state(observation)
        if self.agent.q_table is not None:
            q_values = self.agent.q_table[table_row].tolist()
        else:
            q_values = [0.0, 0.0, 0.0]

        # Spread between best and worst Q-value, scaled by the best (+eps),
        # as a rough confidence proxy capped at 1.0.
        max_q = max(q_values) if q_values else 0.0
        confidence = (max_q - min(q_values)) / (max_q + 0.001) if q_values else 0.0

        return {
            "action": action_name,
            "q_values": q_values,
            "confidence": min(confidence, 1.0),
        }


__all__ = ["BatteryPolicy"]
|
||||||
95
backend/app/ml/rl_battery/trainer.py
Normal file
95
backend/app/ml/rl_battery/trainer.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
from typing import Dict
|
||||||
|
from app.ml.rl_battery.environment import BatteryEnvironment
|
||||||
|
from app.ml.rl_battery.agent import QLearningAgent
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BatteryRLTrainer:
    """Trains the tabular Q-learning battery policy in the simulated environment."""

    def __init__(self, config=None):
        # Hyperparameter overrides; _create_agent supplies defaults for
        # any keys that are missing.
        self.config = config or {}
        self.agent: QLearningAgent = None
        self.env: BatteryEnvironment = None

    def _create_agent(self) -> QLearningAgent:
        """Build a QLearningAgent from config values (with defaults)."""
        return QLearningAgent(
            state_bins=self.config.get("charge_level_bins", 10),
            action_space=3,
            learning_rate=self.config.get("learning_rate", 0.1),
            discount_factor=self.config.get("discount_factor", 0.95),
            epsilon=self.config.get("epsilon", 1.0),
            epsilon_decay=self.config.get("epsilon_decay", 0.995),
            epsilon_min=self.config.get("epsilon_min", 0.05),
        )

    def load_data(self):
        """Placeholder: the simulator generates its own prices, so nothing is loaded."""
        pass

    def train(self, n_episodes: int = 1000, region: str = "FR") -> Dict:
        """Run epsilon-greedy Q-learning for *n_episodes* day-long episodes.

        NOTE(review): *region* is currently unused — the environment simulates
        a synthetic price walk rather than regional data; confirm intent.

        Returns:
            Dict with per-episode rewards, the average reward over the last
            100 episodes, and the final exploration rate.
        """
        logger.info(f"Starting RL training for {n_episodes} episodes")

        self.env = BatteryEnvironment(
            capacity=100.0,
            charge_rate=50.0,
            discharge_rate=50.0,
            efficiency=0.9,
            min_reserve=0.1,
            max_charge=0.9,
        )

        self.agent = self._create_agent()
        self.agent.initialize_q_table(self.env.observation_space)

        episode_rewards = []

        for episode in range(n_episodes):
            state = self.env.reset()
            total_reward = 0
            steps = 0

            # One episode = run until the environment reports done.
            while True:
                action = self.agent.get_action(state, training=True)
                next_state, reward, done, info = self.env.step(action)

                self.agent.update(state, action, reward, next_state, done)

                total_reward += reward
                state = next_state
                steps += 1

                if done:
                    break

            episode_rewards.append(total_reward)
            # Anneal exploration once per episode.
            self.agent.decay_epsilon()

            # Progress log every 100 episodes.
            if (episode + 1) % 100 == 0:
                avg_reward = sum(episode_rewards[-100:]) / 100
                logger.info(f"Episode {episode + 1}/{n_episodes}, Avg Reward: {avg_reward:.2f}, Epsilon: {self.agent.epsilon:.3f}")

        final_avg_reward = sum(episode_rewards[-100:]) / 100

        results = {
            "n_episodes": n_episodes,
            "final_avg_reward": final_avg_reward,
            "episode_rewards": episode_rewards,
            "final_epsilon": self.agent.epsilon,
        }

        logger.info(f"Training complete. Final avg reward: {final_avg_reward:.2f}")

        return results

    def save(self, output_dir: str = "models/rl_battery") -> None:
        """Persist the trained agent to <output_dir>/battery_policy.pkl."""
        from pathlib import Path

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        filepath = output_path / "battery_policy.pkl"
        self.agent.save(filepath)
        logger.info(f"Saved trained policy to {filepath}")


__all__ = ["BatteryRLTrainer"]
|
||||||
3
backend/app/ml/training/__init__.py
Normal file
3
backend/app/ml/training/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# BUG FIX: this package previously did ``from app.ml.training import
# CLITrainer`` — importing a (nonexistent) name from the package itself,
# which raises ImportError as soon as the package is imported. The CLI
# module defines ``main``; re-export that instead.
from app.ml.training.cli import main

__all__ = ["main"]
|
||||||
49
backend/app/ml/training/cli.py
Normal file
49
backend/app/ml/training/cli.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import argparse
|
||||||
|
from app.ml.price_prediction.trainer import PricePredictionTrainer
|
||||||
|
from app.ml.rl_battery.trainer import BatteryRLTrainer
|
||||||
|
from app.utils.logger import get_logger, setup_logger
|
||||||
|
|
||||||
|
setup_logger()
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_price_training(args) -> None:
    """Handle the ``price`` subcommand: train and persist price models."""
    logger.info(f"Training price prediction models for horizons: {args.horizons}")
    trainer = PricePredictionTrainer()
    results = trainer.train_all(horizons=args.horizons)
    trainer.save_models(output_dir=args.output)
    logger.info("Training complete!")
    for horizon, result in results.items():
        if "error" not in result:
            logger.info(f"  {horizon}m: MAE={result['mae']:.2f}, RMSE={result['rmse']:.2f}, R2={result['r2']:.3f}")


def _run_rl_training(args) -> None:
    """Handle the ``rl`` subcommand: train and persist the battery policy."""
    logger.info(f"Training RL battery policy for {args.episodes} episodes")
    trainer = BatteryRLTrainer()
    results = trainer.train(n_episodes=args.episodes, region=args.region)
    trainer.save(output_dir=args.output)
    logger.info("Training complete!")
    logger.info(f"  Final avg reward: {results['final_avg_reward']:.2f}")
    logger.info(f"  Final epsilon: {results['final_epsilon']:.3f}")


def main():
    """CLI entry point: parse arguments and dispatch to a training handler."""
    parser = argparse.ArgumentParser(description="Energy Trading ML Training CLI")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    price_parser = subparsers.add_parser("price", help="Train price prediction models")
    price_parser.add_argument("--horizons", nargs="+", type=int, default=[1, 5, 15, 60], help="Prediction horizons in minutes")
    price_parser.add_argument("--output", type=str, default="models/price_prediction", help="Output directory")

    rl_parser = subparsers.add_parser("rl", help="Train RL battery policy")
    rl_parser.add_argument("--episodes", type=int, default=1000, help="Number of training episodes")
    rl_parser.add_argument("--region", type=str, default="FR", help="Region to train for")
    rl_parser.add_argument("--output", type=str, default="models/rl_battery", help="Output directory")

    args = parser.parse_args()

    if args.command == "price":
        _run_price_training(args)
    elif args.command == "rl":
        _run_rl_training(args)
    else:
        parser.print_help()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
3
backend/app/ml/utils/__init__.py
Normal file
3
backend/app/ml/utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# BUG FIX: previously ``from app.ml.utils import time_based_split, MLConfig``
# imported names from this package itself (they are defined in submodules),
# which raises ImportError on package import. Import from the defining
# submodules instead.
from app.ml.utils.data_split import time_based_split
from app.ml.utils.config import MLConfig

__all__ = ["time_based_split", "MLConfig"]
|
||||||
16
backend/app/ml/utils/config.py
Normal file
16
backend/app/ml/utils/config.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MLConfig:
    """Shared runtime configuration for ML training/inference."""

    enable_gpu: bool = False  # whether GPU acceleration may be used
    n_jobs: int = 4           # parallel worker count
    verbose: bool = True      # enable verbose training output

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "MLConfig":
        """Build an ``MLConfig`` from a dict, silently dropping unknown keys."""
        known_fields = cls.__annotations__
        kwargs = {key: value for key, value in config_dict.items() if key in known_fields}
        return cls(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["MLConfig"]
|
||||||
25
backend/app/ml/utils/data_split.py
Normal file
25
backend/app/ml/utils/data_split.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
import pandas as pd
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
def time_based_split(
    df: pd.DataFrame,
    timestamp_col: str = "timestamp",
    train_pct: float = 0.70,
    val_pct: float = 0.85,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Split a DataFrame chronologically into train/val/test partitions.

    ``train_pct`` and ``val_pct`` are CUMULATIVE fractions: rows in
    [0, train_pct) of the sorted frame go to train, [train_pct, val_pct)
    to validation, and [val_pct, 1.0] to test.

    Args:
        df: Input frame; need not be pre-sorted.
        timestamp_col: Column used for chronological ordering.
        train_pct: Cumulative fraction marking the end of the train split.
        val_pct: Cumulative fraction marking the end of the validation split.

    Returns:
        ``(train, val, test)`` DataFrames in chronological order.

    Raises:
        ValueError: If not ``0 <= train_pct <= val_pct <= 1`` (previously a
            misordered pair silently produced an empty validation set).
    """
    if not (0.0 <= train_pct <= val_pct <= 1.0):
        raise ValueError(
            f"Expected 0 <= train_pct <= val_pct <= 1, got train_pct={train_pct}, val_pct={val_pct}"
        )

    df_sorted = df.sort_values(timestamp_col)

    n_total = len(df_sorted)
    n_train = int(n_total * train_pct)
    n_val = int(n_total * val_pct)

    train = df_sorted.iloc[:n_train]
    val = df_sorted.iloc[n_train:n_val]
    test = df_sorted.iloc[n_val:]

    return train, val, test
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["time_based_split"]
|
||||||
4
backend/app/ml/utils/evaluation.py
Normal file
4
backend/app/ml/utils/evaluation.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from app.ml.utils.data_split import time_based_split
|
||||||
|
from app.ml.utils.config import MLConfig
|
||||||
|
|
||||||
|
__all__ = ["time_based_split", "MLConfig"]
|
||||||
52
backend/app/models/__init__.py
Normal file
52
backend/app/models/__init__.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
from app.models.schemas import (
|
||||||
|
PriceData,
|
||||||
|
BatteryState,
|
||||||
|
BacktestConfig,
|
||||||
|
BacktestMetrics,
|
||||||
|
TrainingRequest,
|
||||||
|
PredictionResponse,
|
||||||
|
ModelInfo,
|
||||||
|
TrainingStatus,
|
||||||
|
ArbitrageOpportunity,
|
||||||
|
DashboardSummary,
|
||||||
|
Trade,
|
||||||
|
StrategyStatus,
|
||||||
|
Alert,
|
||||||
|
AppSettings,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.models.enums import (
|
||||||
|
RegionEnum,
|
||||||
|
FuelTypeEnum,
|
||||||
|
StrategyEnum,
|
||||||
|
TradeTypeEnum,
|
||||||
|
BacktestStatusEnum,
|
||||||
|
ModelType,
|
||||||
|
AlertTypeEnum,
|
||||||
|
TrainingStatusEnum,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PriceData",
|
||||||
|
"BatteryState",
|
||||||
|
"BacktestConfig",
|
||||||
|
"BacktestMetrics",
|
||||||
|
"TrainingRequest",
|
||||||
|
"PredictionResponse",
|
||||||
|
"ModelInfo",
|
||||||
|
"TrainingStatus",
|
||||||
|
"ArbitrageOpportunity",
|
||||||
|
"DashboardSummary",
|
||||||
|
"Trade",
|
||||||
|
"StrategyStatus",
|
||||||
|
"Alert",
|
||||||
|
"AppSettings",
|
||||||
|
"RegionEnum",
|
||||||
|
"FuelTypeEnum",
|
||||||
|
"StrategyEnum",
|
||||||
|
"TradeTypeEnum",
|
||||||
|
"BacktestStatusEnum",
|
||||||
|
"ModelType",
|
||||||
|
"AlertTypeEnum",
|
||||||
|
"TrainingStatusEnum",
|
||||||
|
]
|
||||||
60
backend/app/models/enums.py
Normal file
60
backend/app/models/enums.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class RegionEnum(str, Enum):
    """European electricity market regions supported by the API."""
    FR = "FR"
    BE = "BE"
    DE = "DE"
    NL = "NL"
    UK = "UK"
|
||||||
|
|
||||||
|
|
||||||
|
class FuelTypeEnum(str, Enum):
    """Generation fuel/source types."""
    GAS = "gas"
    NUCLEAR = "nuclear"
    COAL = "coal"
    SOLAR = "solar"
    WIND = "wind"
    HYDRO = "hydro"
|
||||||
|
|
||||||
|
|
||||||
|
class StrategyEnum(str, Enum):
    """Trading strategy families handled by StrategyService."""
    FUNDAMENTAL = "fundamental"
    TECHNICAL = "technical"
    ML = "ml"
    MINING = "mining"
|
||||||
|
|
||||||
|
|
||||||
|
class TradeTypeEnum(str, Enum):
    """Kinds of trade/battery actions recorded in backtests."""
    BUY = "buy"
    SELL = "sell"
    CHARGE = "charge"
    DISCHARGE = "discharge"
|
||||||
|
|
||||||
|
|
||||||
|
class BacktestStatusEnum(str, Enum):
    """Lifecycle states of a backtest run."""
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelType(str, Enum):
    """Categories of trained ML artifacts served by MLService."""
    PRICE_PREDICTION = "price_prediction"
    RL_BATTERY = "rl_battery"
|
||||||
|
|
||||||
|
|
||||||
|
class AlertTypeEnum(str, Enum):
    """Categories of alerts emitted by AlertService."""
    PRICE_SPIKE = "price_spike"
    ARBITRAGE_OPPORTUNITY = "arbitrage_opportunity"
    BATTERY_LOW = "battery_low"
    BATTERY_FULL = "battery_full"
    STRATEGY_ERROR = "strategy_error"
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingStatusEnum(str, Enum):
    """Lifecycle states of a model-training job."""
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
|
||||||
145
backend/app/models/schemas.py
Normal file
145
backend/app/models/schemas.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from app.models.enums import (
|
||||||
|
RegionEnum,
|
||||||
|
StrategyEnum,
|
||||||
|
TradeTypeEnum,
|
||||||
|
BacktestStatusEnum,
|
||||||
|
ModelType,
|
||||||
|
AlertTypeEnum,
|
||||||
|
TrainingStatusEnum,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PriceData(BaseModel):
    """One market observation for a region: day-ahead vs real-time price.

    NOTE(review): price units are assumed EUR/MWh — confirm with the data
    pipeline.
    """
    timestamp: datetime
    region: RegionEnum
    day_ahead_price: float
    real_time_price: float
    volume_mw: float
|
||||||
|
|
||||||
|
|
||||||
|
class BatteryState(BaseModel):
    """Snapshot of a single battery's physical state at a point in time."""

    timestamp: datetime
    battery_id: str
    capacity_mwh: float
    charge_level_mwh: float
    charge_rate_mw: float
    discharge_rate_mw: float
    efficiency: float
    # FIX: a plain default is equivalent to (and clearer than) the previous
    # ``Field(default_factory=lambda: 0.0)`` — a factory is only needed for
    # mutable defaults.
    # NOTE(review): this field is never derived from charge_level_mwh /
    # capacity_mwh here — confirm callers populate it explicitly.
    charge_level_pct: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class BacktestConfig(BaseModel):
    """User-supplied parameters for a backtest run.

    ``None`` for the threshold fields means "use the application defaults".
    Dates are strings; format is not validated here — presumably ISO-8601
    (confirm against the routes that consume this model).
    """
    start_date: str
    end_date: str
    strategies: List[StrategyEnum] = Field(default_factory=list)
    use_ml: bool = True
    battery_min_reserve: Optional[float] = None
    battery_max_charge: Optional[float] = None
    arbitrage_min_spread: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
class BacktestMetrics(BaseModel):
    """Aggregate performance metrics produced by a completed backtest."""
    total_revenue: float
    arbitrage_profit: float
    battery_revenue: float
    mining_profit: float
    battery_utilization: float
    price_capture_rate: float
    win_rate: float
    sharpe_ratio: float
    max_drawdown: float
    total_trades: int
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingRequest(BaseModel):
    """Request body for launching a model-training job.

    NOTE(review): the ``model_type`` field name falls in pydantic v2's
    protected ``model_`` namespace and will emit a warning there — confirm
    which pydantic major version the project targets.
    """
    model_type: ModelType
    horizon: Optional[int] = None
    start_date: str
    end_date: str
    hyperparameters: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class PredictionResponse(BaseModel):
    """A single model prediction returned to API clients."""
    model_id: str
    timestamp: datetime
    prediction: float
    confidence: Optional[float] = None
    features_used: List[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInfo(BaseModel):
    """Registry metadata describing one trained model artifact."""
    model_id: str
    model_type: ModelType
    version: str
    created_at: datetime
    metrics: Dict[str, float] = Field(default_factory=dict)
    hyperparameters: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingStatus(BaseModel):
    """Progress/result record for an in-flight or finished training job."""
    training_id: str
    status: TrainingStatusEnum
    progress: float = 0.0  # fraction or percent — not established here; confirm producer
    current_epoch: Optional[int] = None
    total_epochs: Optional[int] = None
    metrics: Dict[str, float] = Field(default_factory=dict)
    error_message: Optional[str] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ArbitrageOpportunity(BaseModel):
    """A cross-region price spread worth trading (buy low, sell high)."""
    timestamp: datetime
    buy_region: RegionEnum
    sell_region: RegionEnum
    buy_price: float
    sell_price: float
    spread: float  # sell_price - buy_price (always non-negative as produced by DataService)
    volume_mw: float
|
||||||
|
|
||||||
|
|
||||||
|
class DashboardSummary(BaseModel):
    """Top-level numbers shown on the dashboard landing view."""
    latest_timestamp: datetime
    total_volume_mw: float
    avg_realtime_price: float
    arbitrage_count: int
    battery_count: int
    avg_battery_charge: float  # mean fraction of capacity across batteries
|
||||||
|
|
||||||
|
|
||||||
|
class Trade(BaseModel):
    """One executed trade or battery action within a backtest."""
    timestamp: datetime
    backtest_id: str
    trade_type: TradeTypeEnum
    region: Optional[RegionEnum] = None
    price: float
    volume_mw: float
    revenue: float
    battery_id: Optional[str] = None  # set only for charge/discharge actions, presumably — confirm
|
||||||
|
|
||||||
|
|
||||||
|
class StrategyStatus(BaseModel):
    """Live enable/disable state and running totals for one strategy."""
    strategy: StrategyEnum
    enabled: bool
    last_execution: Optional[datetime] = None
    total_trades: int = 0
    profit_loss: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class Alert(BaseModel):
    """A system alert with free-form payload and acknowledgement flag."""
    alert_id: str
    alert_type: AlertTypeEnum
    timestamp: datetime
    message: str
    data: Dict[str, Any] = Field(default_factory=dict)
    acknowledged: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class AppSettings(BaseModel):
    """Runtime-tunable application thresholds (mirrors the .env defaults)."""
    battery_min_reserve: float
    battery_max_charge: float
    arbitrage_min_spread: float
    mining_margin_threshold: float
|
||||||
13
backend/app/services/__init__.py
Normal file
13
backend/app/services/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
from app.services.data_service import DataService
|
||||||
|
from app.services.strategy_service import StrategyService
|
||||||
|
from app.services.ml_service import MLService
|
||||||
|
from app.services.trading_service import TradingService
|
||||||
|
from app.services.alert_service import AlertService
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DataService",
|
||||||
|
"StrategyService",
|
||||||
|
"MLService",
|
||||||
|
"TradingService",
|
||||||
|
"AlertService",
|
||||||
|
]
|
||||||
76
backend/app/services/alert_service.py
Normal file
76
backend/app/services/alert_service.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
from typing import Dict, List, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from app.models.enums import AlertTypeEnum
|
||||||
|
from app.models.schemas import Alert
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AlertService:
    """In-memory alert management: creation, filtering, acknowledgement and
    aggregate summaries. State lives only for the process lifetime."""

    def __init__(self):
        # Chronological list of every alert created so far.
        self._alerts: List[Alert] = []
        # Retained for interface compatibility; acknowledgement is tracked
        # on the Alert objects themselves.
        self._acknowledged: List[str] = []

    async def create_alert(
        self,
        alert_type: AlertTypeEnum,
        message: str,
        data: Optional[Dict] = None,
    ) -> Alert:
        """Create a new alert, store it, log it at WARNING level, return it."""
        alert_id = str(uuid.uuid4())
        new_alert = Alert(
            alert_id=alert_id,
            alert_type=alert_type,
            timestamp=datetime.utcnow(),
            message=message,
            data=data or {},
            acknowledged=False,
        )
        self._alerts.append(new_alert)
        logger.warning(f"Alert created: {alert_id}, type: {alert_type.value}, message: {message}")
        return new_alert

    async def get_alerts(
        self,
        alert_type: Optional[AlertTypeEnum] = None,
        acknowledged: Optional[bool] = None,
        limit: int = 100,
    ) -> List[Alert]:
        """Return up to ``limit`` most recent alerts matching the filters."""

        def _matches(candidate: Alert) -> bool:
            if alert_type and candidate.alert_type != alert_type:
                return False
            if acknowledged is not None and candidate.acknowledged != acknowledged:
                return False
            return True

        selected = [candidate for candidate in self._alerts if _matches(candidate)]
        return selected[-limit:]

    async def acknowledge_alert(self, alert_id: str) -> Alert:
        """Mark the alert with ``alert_id`` as acknowledged.

        Raises:
            ValueError: If no alert with that id exists.
        """
        match = next((a for a in self._alerts if a.alert_id == alert_id), None)
        if match is None:
            raise ValueError(f"Alert not found: {alert_id}")
        match.acknowledged = True
        logger.info(f"Alert acknowledged: {alert_id}")
        return match

    async def get_alert_summary(self) -> Dict:
        """Aggregate counts: total, unacknowledged, per-type, latest timestamp."""
        by_type: Dict[str, int] = {}
        for entry in self._alerts:
            type_key = entry.alert_type.value
            by_type[type_key] = by_type.get(type_key, 0) + 1

        return {
            "total_alerts": len(self._alerts),
            "unacknowledged": sum(1 for entry in self._alerts if not entry.acknowledged),
            "by_type": by_type,
            "latest_alert": self._alerts[-1].timestamp if self._alerts else None,
        }
|
||||||
174
backend/app/services/data_service.py
Normal file
174
backend/app/services/data_service.py
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
from typing import Dict, List, Optional
|
||||||
|
from pathlib import Path
|
||||||
|
import pandas as pd
|
||||||
|
from datetime import datetime
|
||||||
|
from app.config import settings
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DataService:
    """Loads market/battery data from parquet files and serves derived views
    (latest prices, history, arbitrage opportunities, dashboard summary).

    All data is held in memory after ``initialize``; nothing is persisted.
    """

    def __init__(self):
        # Root directory of the processed parquet files (from settings).
        self.data_path: Path = settings.DATA_PATH_RESOLVED
        # Per-region price DataFrames, keyed by region code ("FR", "BE", ...).
        self._price_data: Dict[str, pd.DataFrame] = {}
        # Latest battery telemetry, one or more rows per battery_id.
        self._battery_data: Optional[pd.DataFrame] = None
        self._loaded: bool = False

    async def initialize(self):
        """Load all data files; safe to call on startup. Missing files are
        logged as warnings rather than raised."""
        logger.info(f"Loading data from {self.data_path}")
        self._load_price_data()
        self._load_battery_data()
        self._loaded = True
        logger.info("Data loaded successfully")

    def _load_price_data(self):
        """Read electricity_prices.parquet and partition rows by region.

        Expects a ``region`` column; regions with no rows are omitted from
        ``self._price_data``.
        """
        if not self.data_path.exists():
            logger.warning(f"Data path {self.data_path} does not exist")
            return

        prices_file = self.data_path / "electricity_prices.parquet"
        if prices_file.exists():
            df = pd.read_parquet(prices_file)
            logger.info(f"Loaded price data: {len(df)} total rows from {prices_file}")

            if "region" in df.columns:
                # Only the five known regions are kept; other values ignored.
                for region in ["FR", "BE", "DE", "NL", "UK"]:
                    region_df = df[df["region"] == region].copy()
                    if len(region_df) > 0:
                        self._price_data[region] = region_df
                        logger.info(f"Loaded {region} price data: {len(region_df)} rows")
            else:
                logger.warning("Price data file does not contain 'region' column")
        else:
            logger.warning(f"Price data file not found: {prices_file}")

    def _load_battery_data(self):
        """Read battery_capacity.parquet if present."""
        battery_path = self.data_path / "battery_capacity.parquet"
        if battery_path.exists():
            self._battery_data = pd.read_parquet(battery_path)
            logger.info(f"Loaded battery data: {len(self._battery_data)} rows")
        else:
            logger.warning(f"Battery data file not found: {battery_path}")

    def get_latest_prices(self) -> Dict[str, Dict]:
        """Return the last row of each region's price frame as a small dict.

        NOTE(review): "latest" assumes each region frame is already in
        chronological order — confirm the parquet files are sorted.
        """
        result = {}
        for region, df in self._price_data.items():
            if len(df) > 0:
                latest = df.iloc[-1].to_dict()
                result[region] = {
                    "timestamp": latest.get("timestamp"),
                    "day_ahead_price": latest.get("day_ahead_price", 0),
                    "real_time_price": latest.get("real_time_price", 0),
                    "volume_mw": latest.get("volume_mw", 0),
                }
        return result

    def get_price_history(
        self, region: str, start: Optional[str] = None, end: Optional[str] = None, limit: int = 1000
    ) -> List[Dict]:
        """Return up to ``limit`` most recent price rows for ``region`` as
        dict records, optionally bounded by start/end (compared directly
        against the timestamp column — presumably ISO strings; confirm)."""
        if region not in self._price_data:
            return []

        df = self._price_data[region].copy()

        if "timestamp" in df.columns:
            df = df.sort_values("timestamp")

        if start:
            df = df[df["timestamp"] >= start]
        if end:
            df = df[df["timestamp"] <= end]

        # Keep only the newest `limit` rows after filtering.
        df = df.tail(limit)

        return df.to_dict("records")

    def get_battery_states(self) -> List[Dict]:
        """Return the most recent row per battery_id as plain dicts."""
        if self._battery_data is None or len(self._battery_data) == 0:
            return []

        # ``last()`` per group assumes rows are chronological within each
        # battery_id — confirm against the data producer.
        latest_by_battery = self._battery_data.groupby("battery_id").last().reset_index()

        result = []
        for _, row in latest_by_battery.iterrows():
            result.append(
                {
                    "timestamp": row.get("timestamp"),
                    "battery_id": row.get("battery_id"),
                    "capacity_mwh": row.get("capacity_mwh", 0),
                    "charge_level_mwh": row.get("charge_level_mwh", 0),
                    "charge_rate_mw": row.get("charge_rate_mw", 0),
                    "discharge_rate_mw": row.get("discharge_rate_mw", 0),
                    "efficiency": row.get("efficiency", 0.9),
                }
            )
        return result

    def get_arbitrage_opportunities(self, min_spread: Optional[float] = None) -> List[Dict]:
        """Find cross-region real-time price spreads >= ``min_spread``.

        Compares every unordered pair of regions from the latest prices.
        NOTE(review): volume_mw is hard-coded to 100 below — placeholder,
        not derived from actual market depth.
        """
        if min_spread is None:
            min_spread = settings.ARBITRAGE_MIN_SPREAD

        opportunities = []
        latest_prices = self.get_latest_prices()

        regions = list(latest_prices.keys())
        for i in range(len(regions)):
            for j in range(i + 1, len(regions)):
                region_a = regions[i]
                region_b = regions[j]

                price_a = latest_prices[region_a].get("real_time_price", 0)
                price_b = latest_prices[region_b].get("real_time_price", 0)

                # Zero/negative prices are skipped entirely.
                if price_a > 0 and price_b > 0:
                    spread = abs(price_a - price_b)
                    if spread >= min_spread:
                        # Buy in the cheaper region, sell in the dearer one.
                        if price_a < price_b:
                            buy_region, sell_region = region_a, region_b
                            buy_price, sell_price = price_a, price_b
                        else:
                            buy_region, sell_region = region_b, region_a
                            buy_price, sell_price = price_b, price_a

                        opportunities.append(
                            {
                                # naive UTC timestamp (datetime.utcnow) — see module note
                                "timestamp": datetime.utcnow(),
                                "buy_region": buy_region,
                                "sell_region": sell_region,
                                "buy_price": buy_price,
                                "sell_price": sell_price,
                                "spread": spread,
                                "volume_mw": 100,
                            }
                        )

        return opportunities

    def get_dashboard_summary(self) -> Dict:
        """Compute the headline numbers for the dashboard endpoint."""
        latest_prices = self.get_latest_prices()

        total_volume = sum(p.get("volume_mw", 0) for p in latest_prices.values())
        avg_price = (
            sum(p.get("real_time_price", 0) for p in latest_prices.values()) / len(latest_prices)
            if latest_prices
            else 0
        )

        arbitrage = self.get_arbitrage_opportunities()
        battery_states = self.get_battery_states()

        # Mean charge as a fraction of capacity; capacity defaults to 1 to
        # avoid division by zero for malformed rows.
        avg_battery_charge = 0
        if battery_states:
            avg_battery_charge = sum(
                b.get("charge_level_mwh", 0) / b.get("capacity_mwh", 1) for b in battery_states
            ) / len(battery_states)

        return {
            "latest_timestamp": datetime.utcnow(),
            "total_volume_mw": total_volume,
            "avg_realtime_price": avg_price,
            "arbitrage_count": len(arbitrage),
            "battery_count": len(battery_states),
            "avg_battery_charge": avg_battery_charge,
        }
|
||||||
145
backend/app/services/ml_service.py
Normal file
145
backend/app/services/ml_service.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
import pickle
|
||||||
|
from app.config import settings
|
||||||
|
from app.models.enums import ModelType
|
||||||
|
from app.models.schemas import ModelInfo, PredictionResponse
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MLService:
    """Loads and serves trained ML artifacts (price models, RL policies).

    Models are lazily loaded from pickles under ``settings.MODELS_PATH`` and
    memoized in ``self._loaded_models``. Registry metadata comes from
    ``registry.json``.

    SECURITY NOTE: ``pickle.load`` executes arbitrary code from the file —
    only load model files from trusted, locally produced locations.
    """

    def __init__(self):
        self.models_path: Path = Path(settings.MODELS_PATH)
        # Cache of unpickled model objects, keyed by model_id.
        self._loaded_models: Dict[str, Any] = {}
        # Metadata registry, keyed by model_id.
        self._registry: Dict[str, ModelInfo] = {}
        self._load_registry()

    def _load_registry(self):
        """Populate the registry from models/registry.json when present."""
        registry_path = self.models_path / "registry.json"
        if registry_path.exists():
            import json

            with open(registry_path) as f:
                data = json.load(f)
                for model_id, model_data in data.get("models", {}).items():
                    self._registry[model_id] = ModelInfo(**model_data)
            logger.info(f"Loaded model registry: {len(self._registry)} models")

    def list_models(self) -> List[ModelInfo]:
        """Return metadata for every registered model."""
        return list(self._registry.values())

    def get_model_metrics(self, model_id: str) -> Dict[str, float]:
        """Return stored evaluation metrics for ``model_id``.

        Raises:
            ValueError: If the model is not in the registry.
        """
        if model_id not in self._registry:
            raise ValueError(f"Model {model_id} not found in registry")
        return self._registry[model_id].metrics

    def load_price_prediction_model(self, model_id: str):
        """Load (and memoize) a pickled price-prediction model from disk.

        Raises:
            FileNotFoundError: If the pickle file does not exist.
        """
        if model_id in self._loaded_models:
            return self._loaded_models[model_id]

        model_path = self.models_path / "price_prediction" / f"{model_id}.pkl"
        if not model_path.exists():
            raise FileNotFoundError(f"Model file not found: {model_path}")

        with open(model_path, "rb") as f:
            model = pickle.load(f)

        self._loaded_models[model_id] = model
        logger.info(f"Loaded price prediction model: {model_id}")
        return model

    def load_rl_battery_policy(self, model_id: str):
        """Load (and memoize) a pickled RL battery policy from disk.

        Raises:
            FileNotFoundError: If the pickle file does not exist.
        """
        if model_id in self._loaded_models:
            return self._loaded_models[model_id]

        policy_path = self.models_path / "rl_battery" / f"{model_id}.pkl"
        if not policy_path.exists():
            raise FileNotFoundError(f"Policy file not found: {policy_path}")

        with open(policy_path, "rb") as f:
            policy = pickle.load(f)

        self._loaded_models[model_id] = policy
        logger.info(f"Loaded RL battery policy: {model_id}")
        return policy

    def predict(
        self, model_id: str, timestamp: datetime, features: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Dispatch a prediction request to the right model family.

        Raises:
            ValueError: If the model is unknown or of an unsupported type.
        """
        model_info = self._registry.get(model_id)
        if not model_info:
            raise ValueError(f"Model {model_id} not found")

        if model_info.model_type == ModelType.PRICE_PREDICTION:
            model = self.load_price_prediction_model(model_id)
            prediction = self._predict_price(model, timestamp, features or {})
            return prediction
        elif model_info.model_type == ModelType.RL_BATTERY:
            policy = self.load_rl_battery_policy(model_id)
            action = self._get_battery_action(policy, timestamp, features or {})
            return action
        else:
            raise ValueError(f"Unsupported model type: {model_info.model_type}")

    def _predict_price(
        self, model: Any, timestamp: datetime, features: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Run the price model on a feature vector built from ``features``."""
        import numpy as np  # NOTE(review): unused in this method — candidate for removal

        try:
            feature_vector = self._extract_features(features)
            prediction = float(model.predict(feature_vector)[0])
            return {
                "model_id": getattr(model, "model_id", "unknown"),
                "timestamp": timestamp,
                "prediction": prediction,
                "confidence": None,
                "features_used": list(features.keys()),
            }
        except Exception as e:
            logger.error(f"Prediction error: {e}")
            raise

    def _extract_features(self, features: Dict[str, Any]) -> Any:
        """Build a 2D (1, n) feature array; columns are the feature keys in
        sorted order. NOTE(review): the model must have been trained with the
        same key ordering — confirm against the training pipeline."""
        import numpy as np

        return np.array([[features.get(k, 0) for k in sorted(features.keys())]])

    def _get_battery_action(self, policy: Any, timestamp: datetime, features: Dict[str, Any]) -> Dict[str, Any]:
        """Choose a battery action.

        NOTE(review): the loaded ``policy`` object is not consulted — the
        action comes from a hard-coded threshold heuristic on charge level
        and price. Confirm whether this is an intentional fallback.
        """
        charge_level = features.get("charge_level", 0.5)
        current_price = features.get("current_price", 0)

        action = "hold"
        if charge_level < 0.2 and current_price < 50:
            action = "charge"
        elif charge_level > 0.8 and current_price > 100:
            action = "discharge"

        return {
            "model_id": getattr(policy, "policy_id", "battery_policy"),
            "timestamp": timestamp,
            "action": action,
            "charge_level": charge_level,
            "confidence": 0.7,  # fixed heuristic confidence, not model-derived
        }

    def predict_with_confidence(
        self, model_id: str, timestamp: datetime, features: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Like ``predict`` but with a confidence attached.

        NOTE(review): confidence is a hard-coded 0.85 placeholder, not a
        calibrated estimate.
        """
        result = self.predict(model_id, timestamp, features)
        result["confidence"] = 0.85
        return result

    def get_feature_importance(self, model_id: str) -> Dict[str, float]:
        """Return sklearn-style feature importances keyed ``feature_<i>``,
        or an empty dict when unavailable."""
        if model_id in self._registry and self._registry[model_id].model_type == ModelType.PRICE_PREDICTION:
            model = self.load_price_prediction_model(model_id)
            if hasattr(model, "feature_importances_"):
                importances = model.feature_importances_
                return {f"feature_{i}": float(imp) for i, imp in enumerate(importances)}
        return {}

    def get_model_info(self, model_id: str) -> Optional[ModelInfo]:
        """Return registry metadata for ``model_id``, or None if unknown."""
        return self._registry.get(model_id)
|
||||||
83
backend/app/services/strategy_service.py
Normal file
83
backend/app/services/strategy_service.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
from typing import Dict, List, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from app.models.enums import StrategyEnum
|
||||||
|
from app.models.schemas import StrategyStatus
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class StrategyService:
|
||||||
|
    def __init__(self):
        # Registry of per-strategy status objects, keyed by strategy enum.
        self._strategies: Dict[StrategyEnum, StrategyStatus] = {}
        self._initialize_strategies()
|
||||||
|
|
||||||
|
    def _initialize_strategies(self):
        """Register every known strategy in a disabled, zeroed state."""
        for strategy in StrategyEnum:
            self._strategies[strategy] = StrategyStatus(
                strategy=strategy, enabled=False, last_execution=None, total_trades=0, profit_loss=0.0
            )
|
||||||
|
|
||||||
|
    async def execute_strategy(self, strategy: StrategyEnum, config: Optional[Dict] = None) -> Dict:
        """Run an enabled strategy once and fold its results into the status.

        Raises:
            ValueError: If the strategy is unknown or not enabled.

        NOTE(review): ``status.dict()`` is the pydantic v1 API (deprecated in
        v2 in favour of ``model_dump()``) — confirm which pydantic major
        version the project pins.
        """
        logger.info(f"Executing strategy: {strategy.value}")
        status = self._strategies.get(strategy)

        if not status or not status.enabled:
            raise ValueError(f"Strategy {strategy.value} is not enabled")

        results = await self._run_strategy_logic(strategy, config or {})

        # Update running totals; timestamp is naive UTC (datetime.utcnow).
        status.last_execution = datetime.utcnow()
        status.total_trades += results.get("trades", 0)
        status.profit_loss += results.get("profit", 0)

        return {"strategy": strategy.value, "status": status.dict(), "results": results}
|
||||||
|
|
||||||
|
async def _run_strategy_logic(self, strategy: StrategyEnum, config: Dict) -> Dict:
|
||||||
|
if strategy == StrategyEnum.FUNDAMENTAL:
|
||||||
|
return await self._run_fundamental_strategy(config)
|
||||||
|
elif strategy == StrategyEnum.TECHNICAL:
|
||||||
|
return await self._run_technical_strategy(config)
|
||||||
|
elif strategy == StrategyEnum.ML:
|
||||||
|
return await self._run_ml_strategy(config)
|
||||||
|
elif strategy == StrategyEnum.MINING:
|
||||||
|
return await self._run_mining_strategy(config)
|
||||||
|
return {"trades": 0, "profit": 0}
|
||||||
|
|
||||||
|
async def _run_fundamental_strategy(self, config: Dict) -> Dict:
|
||||||
|
logger.debug("Running fundamental strategy")
|
||||||
|
return {"trades": 0, "profit": 0}
|
||||||
|
|
||||||
|
async def _run_technical_strategy(self, config: Dict) -> Dict:
|
||||||
|
logger.debug("Running technical strategy")
|
||||||
|
return {"trades": 0, "profit": 0}
|
||||||
|
|
||||||
|
async def _run_ml_strategy(self, config: Dict) -> Dict:
|
||||||
|
logger.debug("Running ML strategy")
|
||||||
|
return {"trades": 0, "profit": 0}
|
||||||
|
|
||||||
|
async def _run_mining_strategy(self, config: Dict) -> Dict:
|
||||||
|
logger.debug("Running mining strategy")
|
||||||
|
return {"trades": 0, "profit": 0}
|
||||||
|
|
||||||
|
async def get_strategy_status(self, strategy: StrategyEnum) -> StrategyStatus:
|
||||||
|
return self._strategies.get(strategy, StrategyStatus(strategy=strategy, enabled=False))
|
||||||
|
|
||||||
|
async def get_all_strategies(self) -> List[StrategyStatus]:
|
||||||
|
return list(self._strategies.values())
|
||||||
|
|
||||||
|
async def toggle_strategy(self, strategy: StrategyEnum, action: str) -> StrategyStatus:
|
||||||
|
status = self._strategies.get(strategy)
|
||||||
|
if not status:
|
||||||
|
raise ValueError(f"Unknown strategy: {strategy.value}")
|
||||||
|
|
||||||
|
if action == "start":
|
||||||
|
status.enabled = True
|
||||||
|
logger.info(f"Strategy {strategy.value} started")
|
||||||
|
elif action == "stop":
|
||||||
|
status.enabled = False
|
||||||
|
logger.info(f"Strategy {strategy.value} stopped")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid action: {action}. Use 'start' or 'stop'")
|
||||||
|
|
||||||
|
return status
|
||||||
61
backend/app/services/trading_service.py
Normal file
61
backend/app/services/trading_service.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
from typing import Dict, List, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TradingPosition:
    # NOTE(review): these are bare class-level annotations — no values are
    # stored on instances, and TradingService below passes plain dicts around
    # instead of this type. Presumably this was meant to be a @dataclass (or
    # a pydantic model); confirm intent before using it to hold data.
    timestamp: datetime          # when the position was opened
    position_type: str           # kind of position — TODO confirm vocabulary
    region: Optional[str]        # market region code, when region-specific
    volume_mw: float             # position size in MW
    entry_price: float           # price at open
    current_price: float         # latest marked price
    pnl: float                   # profit/loss of the position
|
||||||
|
|
||||||
|
|
||||||
|
class TradingService:
    """In-memory store of simulated trading positions and orders."""

    def __init__(self):
        # Both collections live only for the lifetime of this service.
        self._positions: List[Dict] = []
        self._orders: List[Dict] = []

    async def get_positions(self) -> List[Dict]:
        """Return a shallow copy of all tracked positions."""
        return self._positions.copy()

    async def get_orders(self, limit: int = 100) -> List[Dict]:
        """Return up to *limit* most recent orders, oldest first."""
        return self._orders[-limit:]

    async def place_order(self, order: Dict) -> Dict:
        """Record *order* as immediately filled and return it.

        NOTE: mutates the caller-supplied dict in place (adds order_id,
        timestamp, and status keys).
        """
        next_id = f"order_{len(self._orders) + 1}"
        order["order_id"] = next_id
        order["timestamp"] = datetime.utcnow()
        order["status"] = "filled"
        self._orders.append(order)

        logger.info(f"Order placed: {next_id}, type: {order.get('type')}, volume: {order.get('volume_mw')}")

        return order

    async def close_position(self, position_id: str) -> Dict:
        """Remove the matching position, mark it closed, and return it.

        Raises ValueError when no tracked position has *position_id*.
        """
        for index, candidate in enumerate(self._positions):
            if candidate.get("position_id") != position_id:
                continue
            closed = self._positions.pop(index)
            closed["closed_at"] = datetime.utcnow()
            closed["status"] = "closed"
            logger.info(f"Position closed: {position_id}")
            return closed

        raise ValueError(f"Position not found: {position_id}")

    async def get_trading_summary(self) -> Dict:
        """Aggregate P&L, open-position count, and trade statistics."""
        pnl_total = sum(entry.get("pnl", 0) for entry in self._positions)
        open_count = sum(1 for entry in self._positions if entry.get("status") == "open")
        latest = self._orders[-1]["timestamp"] if self._orders else None

        return {
            "total_pnl": pnl_total,
            "open_positions": open_count,
            "total_trades": len(self._orders),
            "last_trade": latest,
        }
|
||||||
4
backend/app/tasks/__init__.py
Normal file
4
backend/app/tasks/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from app.tasks.backtest_tasks import run_backtest_task
|
||||||
|
from app.tasks.training_tasks import train_model_task
|
||||||
|
|
||||||
|
__all__ = ["run_backtest_task", "train_model_task"]
|
||||||
40
backend/app/tasks/backtest_tasks.py
Normal file
40
backend/app/tasks/backtest_tasks.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
from typing import Dict
|
||||||
|
from datetime import datetime
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_backtest_task(backtest_id: str, config: Dict, name: str | None = None):
    """Run a backtest and return its results payload.

    NOTE(review): currently a stub — *config* and *name* are ignored, and the
    metrics below are hardcoded placeholders; wire up the real backtest
    engine before relying on these numbers.
    """
    logger.info(f"Running backtest task: {backtest_id}")

    try:
        # Placeholder payload shaped like a finished backtest.
        results = {
            "backtest_id": backtest_id,
            "status": "completed",
            "metrics": {
                "total_revenue": 10000.0,
                "arbitrage_profit": 5000.0,
                "battery_revenue": 3000.0,
                "mining_profit": 2000.0,
                "battery_utilization": 0.75,
                "price_capture_rate": 0.85,
                "win_rate": 0.65,
                "sharpe_ratio": 1.5,
                "max_drawdown": -500.0,
                "total_trades": 150,
            },
            "trades": [],
            "completed_at": datetime.utcnow().isoformat(),
        }

        logger.info(f"Backtest {backtest_id} completed")

        return results

    except Exception as e:
        # Log and re-raise so the caller (task runner) sees the failure.
        logger.error(f"Backtest {backtest_id} failed: {e}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["run_backtest_task"]
|
||||||
9
backend/app/tasks/monitoring_tasks.py
Normal file
9
backend/app/tasks/monitoring_tasks.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
async def monitoring_task():
    """Periodic monitoring hook.

    NOTE(review): currently a stub — it only emits a debug log line.
    """
    logger.debug("Running monitoring task")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["monitoring_task"]
|
||||||
50
backend/app/tasks/training_tasks.py
Normal file
50
backend/app/tasks/training_tasks.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
from typing import Dict
|
||||||
|
from datetime import datetime
|
||||||
|
from app.utils.logger import get_logger
|
||||||
|
from app.models.schemas import TrainingRequest, TrainingStatusEnum
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def train_model_task(training_id: str, request: TrainingRequest):
    """Train the model selected by *request* and return a completion payload.

    Supports "price_prediction" (XGBoost-style horizon models) and
    "rl_battery" (RL battery policy). Raises ValueError for any other model
    type; training errors are logged and re-raised.
    """
    logger.info(f"Training model: {request.model_type.value}, horizon: {request.horizon}")

    def _completed(results: Dict) -> Dict:
        # Shared success payload — previously duplicated per model type.
        return {
            "training_id": training_id,
            "status": TrainingStatusEnum.COMPLETED,
            "results": results,
            "completed_at": datetime.utcnow().isoformat(),
        }

    try:
        model_type = request.model_type.value

        if model_type == "price_prediction":
            # Imported lazily so the heavy ML stack loads only when needed.
            from app.ml.price_prediction.trainer import PricePredictionTrainer

            trainer = PricePredictionTrainer()
            results = trainer.train_all(horizons=[request.horizon] if request.horizon else None)
            trainer.save_models()
            return _completed(results)

        if model_type == "rl_battery":
            from app.ml.rl_battery.trainer import BatteryRLTrainer

            trainer = BatteryRLTrainer()
            results = trainer.train(n_episodes=500)
            trainer.save()
            return _completed(results)

        raise ValueError(f"Unknown model type: {request.model_type}")

    except Exception as e:
        # NOTE: this also catches the ValueError above, so an unknown model
        # type is logged as "Training failed" before propagating (matches
        # the original behavior).
        logger.error(f"Training failed: {e}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["train_model_task"]
|
||||||
5
backend/app/utils/__init__.py
Normal file
5
backend/app/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from app.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
__all__ = ["logger"]
|
||||||
36
backend/app/utils/helpers.py
Normal file
36
backend/app/utils/helpers.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
from datetime import datetime, timedelta, timezone
from typing import Optional

import pytz
|
||||||
|
|
||||||
|
|
||||||
|
def utcnow() -> datetime:
    """Return the current time as a timezone-aware UTC datetime.

    Uses the stdlib ``timezone.utc`` instead of ``pytz.UTC``: both yield an
    aware datetime with a zero UTC offset, and this removes the third-party
    dependency for this call.
    """
    return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_date(date_str: str) -> datetime:
    """Parse *date_str* against the known date/datetime formats.

    Accepts "YYYY-MM-DD", "YYYY-MM-DDTHH:MM:SS", and "YYYY-MM-DD HH:MM:SS".
    Raises ValueError when none of them match.
    """
    known_formats = ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S")
    for candidate in known_formats:
        try:
            return datetime.strptime(date_str, candidate)
        except ValueError:
            pass
    raise ValueError(f"Unable to parse date: {date_str}")
|
||||||
|
|
||||||
|
|
||||||
|
def format_timestamp(dt: datetime, format_str: str = "%Y-%m-%dT%H:%M:%S") -> str:
    """Render *dt* as text using *format_str* (ISO-like by default)."""
    rendered = dt.strftime(format_str)
    return rendered
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_time_delta(start: datetime, end: datetime) -> timedelta:
    """Return the duration elapsed from *start* to *end* (negative if end precedes start)."""
    elapsed = end - start
    return elapsed
|
||||||
|
|
||||||
|
|
||||||
|
def safe_divide(a: float, b: float, default: float = 0.0) -> float:
    """Divide *a* by *b*, returning *default* when *b* is zero."""
    return default if b == 0 else a / b
|
||||||
36
backend/app/utils/logger.py
Normal file
36
backend/app/utils/logger.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
import os
|
||||||
|
from loguru import logger
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(name: str):
    """Return the shared loguru logger bound with a ``name`` context field.

    loguru exposes a single global logger; ``bind`` attaches ``name`` to each
    record's ``extra`` rather than creating a separate per-module logger.
    """
    return logger.bind(name=name)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logger():
    """Configure the global loguru logger: colored stdout + rotating file sink.

    Removes loguru's default handler, then installs:
      * a colorized stdout sink at LOG_LEVEL from the environment (INFO default)
      * a DEBUG-level file sink at logs/app.log, rotated at 10 MB, retained
        7 days, compressed to zip
    """
    # Drop loguru's default stderr handler so we fully control the sinks.
    logger.remove()

    log_format = (
        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
        "<level>{message}</level>"
    )

    logger.add(
        sys.stdout,
        format=log_format,
        level=os.getenv("LOG_LEVEL", "INFO"),
        colorize=True,
    )

    # File sink is always DEBUG regardless of the console level.
    logger.add(
        "logs/app.log",
        format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} | {message}",
        level="DEBUG",
        rotation="10 MB",
        retention="7 days",
        compression="zip",
    )
|
||||||
57
backend/pyproject.toml
Normal file
57
backend/pyproject.toml
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=68.0", "wheel"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "energy-trading-backend"
|
||||||
|
version = "1.0.0"
|
||||||
|
description = "FastAPI backend for energy trading system"
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
dependencies = [
|
||||||
|
"fastapi>=0.104.0",
|
||||||
|
"uvicorn[standard]>=0.24.0",
|
||||||
|
"pydantic>=2.4.0",
|
||||||
|
"pydantic-settings>=2.0.0",
|
||||||
|
"pandas>=2.1.0",
|
||||||
|
"numpy>=1.24.0",
|
||||||
|
"pyarrow>=14.0.0",
|
||||||
|
"xgboost>=2.0.0",
|
||||||
|
"scikit-learn>=1.3.0",
|
||||||
|
"gymnasium>=0.29.0",
|
||||||
|
"stable-baselines3>=2.0.0",
|
||||||
|
"celery>=5.3.0",
|
||||||
|
"redis>=5.0.0",
|
||||||
|
"websockets>=12.0.0",
|
||||||
|
"sqlalchemy>=2.0.0",
|
||||||
|
"alembic>=1.12.0",
|
||||||
|
"python-multipart>=0.0.6",
|
||||||
|
"python-jose[cryptography]>=3.3.0",
|
||||||
|
"python-dotenv>=1.0.0",
|
||||||
|
"loguru>=0.7.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=7.4.0",
|
||||||
|
"pytest-asyncio>=0.21.0",
|
||||||
|
"httpx>=0.25.0",
|
||||||
|
"black>=23.0.0",
|
||||||
|
"ruff>=0.1.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
energy-train = "app.ml.training.cli:main"
|
||||||
|
|
||||||
|
[tool.black]
|
||||||
|
line-length = 100
|
||||||
|
target-version = ['py310']
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 100
|
||||||
|
select = ["E", "F", "I", "N", "W"]
|
||||||
|
ignore = []
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
testpaths = ["tests"]
|
||||||
|
python_files = ["test_*.py"]
|
||||||
23
backend/requirements.txt
Normal file
23
backend/requirements.txt
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
fastapi>=0.104.0
|
||||||
|
uvicorn[standard]>=0.24.0
|
||||||
|
pydantic>=2.4.0
|
||||||
|
pydantic-settings>=2.0.0
|
||||||
|
pandas>=2.1.0
|
||||||
|
numpy>=1.24.0
|
||||||
|
pyarrow>=14.0.0
|
||||||
|
xgboost>=2.0.0
|
||||||
|
scikit-learn>=1.3.0
|
||||||
|
gymnasium>=0.29.0
|
||||||
|
stable-baselines3>=2.0.0
|
||||||
|
celery>=5.3.0
|
||||||
|
redis>=5.0.0
|
||||||
|
websockets>=12.0.0
|
||||||
|
sqlalchemy>=2.0.0
|
||||||
|
alembic>=1.12.0
|
||||||
|
python-multipart>=0.0.6
|
||||||
|
python-jose[cryptography]>=3.3.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
loguru>=0.7.0
|
||||||
|
pytest>=7.4.0
|
||||||
|
pytest-asyncio>=0.21.0
|
||||||
|
httpx>=0.25.0
|
||||||
0
backend/tests/__init__.py
Normal file
0
backend/tests/__init__.py
Normal file
32
backend/tests/conftest.py
Normal file
32
backend/tests/conftest.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from app.main import app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def client():
    """Synchronous TestClient bound to the FastAPI app under test."""
    return TestClient(app)


@pytest.fixture
def sample_price_data():
    """Minimal market-price record for API tests."""
    return {
        "timestamp": "2024-01-01T00:00:00",
        "region": "FR",
        "day_ahead_price": 50.0,
        "real_time_price": 55.0,
        "volume_mw": 1000.0,
    }


@pytest.fixture
def sample_battery_state():
    """Minimal battery snapshot for API tests."""
    return {
        "timestamp": "2024-01-01T00:00:00",
        "battery_id": "battery_1",
        "capacity_mwh": 100.0,
        "charge_level_mwh": 50.0,
        "charge_rate_mw": 50.0,
        "discharge_rate_mw": 50.0,
        "efficiency": 0.9,
    }
|
||||||
3
backend/tests/test_api/__init__.py
Normal file
3
backend/tests/test_api/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from tests.conftest import sample_price_data, sample_battery_state
|
||||||
|
|
||||||
|
__all__ = ["sample_price_data", "sample_battery_state"]
|
||||||
13
backend/tests/test_api/test_backtest.py
Normal file
13
backend/tests/test_api/test_backtest.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def test_backtest_start():
    """TODO: implement — start a backtest via the API and assert acceptance."""
    pass


def test_backtest_status():
    """TODO: implement — poll backtest status and assert it is reported."""
    pass


def test_backtest_history():
    """TODO: implement — list past backtests and assert the run appears."""
    pass
|
||||||
46
backend/tests/test_api/test_dashboard.py
Normal file
46
backend/tests/test_api/test_dashboard.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
|
def test_health_check(client: TestClient):
    """/health returns 200 with status == 'healthy'."""
    response = client.get("/health")
    assert response.status_code == 200
    assert response.json()["status"] == "healthy"


def test_dashboard_summary(client: TestClient):
    """Dashboard summary endpoint responds 200."""
    response = client.get("/api/v1/dashboard/summary")
    assert response.status_code == 200


def test_latest_prices(client: TestClient):
    """Prices endpoint responds 200 with a 'regions' key."""
    response = client.get("/api/v1/dashboard/prices")
    assert response.status_code == 200
    assert "regions" in response.json()


def test_battery_states(client: TestClient):
    """Battery endpoint responds 200 with a 'batteries' key."""
    response = client.get("/api/v1/dashboard/battery")
    assert response.status_code == 200
    assert "batteries" in response.json()


def test_arbitrage_opportunities(client: TestClient):
    """Arbitrage endpoint responds 200 with an 'opportunities' key."""
    response = client.get("/api/v1/dashboard/arbitrage")
    assert response.status_code == 200
    assert "opportunities" in response.json()


def test_get_settings(client: TestClient):
    """Settings endpoint responds 200."""
    response = client.get("/api/v1/settings")
    assert response.status_code == 200


def test_list_models(client: TestClient):
    """Models listing endpoint responds 200."""
    response = client.get("/api/v1/models")
    assert response.status_code == 200


def test_get_strategies(client: TestClient):
    """Trading strategies endpoint responds 200."""
    response = client.get("/api/v1/trading/strategies")
    assert response.status_code == 200
|
||||||
9
backend/tests/test_api/test_models.py
Normal file
9
backend/tests/test_api/test_models.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_prediction():
    """TODO: implement — request a prediction and assert the response shape."""
    pass


def test_model_training():
    """TODO: implement — trigger training and assert the job is created."""
    pass
|
||||||
9
backend/tests/test_api/test_trading.py
Normal file
9
backend/tests/test_api/test_trading.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def test_strategy_toggle():
    """TODO: implement — toggle a strategy and assert its enabled state."""
    pass


def test_trading_positions():
    """TODO: implement — list positions and assert the response shape."""
    pass
|
||||||
10
backend/tests/test_services/__init__.py
Normal file
10
backend/tests/test_services/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# NOTE(review): pyproject sets python_files = ["test_*.py"], so pytest will
# not collect tests defined in this __init__.py — move these stubs into a
# test_*.py module for them to run.
def test_data_service():
    """TODO: implement DataService coverage."""
    pass


def test_ml_service():
    """TODO: implement MLService coverage."""
    pass


def test_strategy_service():
    """TODO: implement StrategyService coverage."""
    pass
|
||||||
2
backend/tests/test_websocket.py
Normal file
2
backend/tests/test_websocket.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
def test_websocket():
    """TODO: implement — connect to the WS endpoint and assert a heartbeat."""
    pass
|
||||||
Reference in New Issue
Block a user