Add FastAPI backend for energy trading system
Implements FastAPI backend with ML model support for energy trading, including price prediction models and RL-based battery trading policy. Features dashboard, trading, backtest, and settings API routes with WebSocket support for real-time updates.
This commit is contained in:
3
backend/app/api/__init__.py
Normal file
3
backend/app/api/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from app.api.routes import dashboard, backtest, models, trading, settings
|
||||
|
||||
__all__ = ["dashboard", "backtest", "models", "trading", "settings"]
|
||||
3
backend/app/api/routes/__init__.py
Normal file
3
backend/app/api/routes/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from app.api.routes import dashboard, backtest, models, trading, settings
|
||||
|
||||
__all__ = ["dashboard", "backtest", "models", "trading", "settings"]
|
||||
108
backend/app/api/routes/backtest.py
Normal file
108
backend/app/api/routes/backtest.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||
from datetime import datetime
|
||||
from app.models.schemas import (
|
||||
BacktestConfig,
|
||||
BacktestMetrics,
|
||||
BacktestStatusEnum,
|
||||
Trade,
|
||||
)
|
||||
from app.utils.logger import get_logger
|
||||
import uuid
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_backtest_store: dict = {}
|
||||
_backtest_results: dict = {}
|
||||
|
||||
|
||||
@router.post("/start")
|
||||
async def start_backtest(config: BacktestConfig, name: Optional[str] = None):
|
||||
backtest_id = str(uuid.uuid4())
|
||||
|
||||
_backtest_store[backtest_id] = {
|
||||
"backtest_id": backtest_id,
|
||||
"name": name or f"Backtest {backtest_id[:8]}",
|
||||
"status": BacktestStatusEnum.RUNNING,
|
||||
"config": config.dict(),
|
||||
"created_at": datetime.utcnow(),
|
||||
"started_at": datetime.utcnow(),
|
||||
"completed_at": None,
|
||||
"error_message": None,
|
||||
}
|
||||
|
||||
logger.info(f"Backtest started: {backtest_id}")
|
||||
|
||||
return {
|
||||
"backtest_id": backtest_id,
|
||||
"status": BacktestStatusEnum.RUNNING,
|
||||
"name": _backtest_store[backtest_id]["name"],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{backtest_id}")
|
||||
async def get_backtest_status(backtest_id: str):
|
||||
if backtest_id not in _backtest_store:
|
||||
raise HTTPException(status_code=404, detail=f"Backtest {backtest_id} not found")
|
||||
|
||||
backtest = _backtest_store[backtest_id]
|
||||
result = _backtest_results.get(backtest_id)
|
||||
|
||||
return {
|
||||
"status": backtest["status"],
|
||||
"name": backtest["name"],
|
||||
"created_at": backtest["created_at"],
|
||||
"started_at": backtest["started_at"],
|
||||
"completed_at": backtest["completed_at"],
|
||||
"error_message": backtest["error_message"],
|
||||
"results": result if backtest["status"] == BacktestStatusEnum.COMPLETED else None,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{backtest_id}/results")
|
||||
async def get_backtest_results(backtest_id: str):
|
||||
if backtest_id not in _backtest_results:
|
||||
raise HTTPException(status_code=404, detail=f"Results for backtest {backtest_id} not found")
|
||||
|
||||
return _backtest_results[backtest_id]
|
||||
|
||||
|
||||
@router.get("/{backtest_id}/trades")
|
||||
async def get_backtest_trades(backtest_id: str, limit: int = 100):
|
||||
if backtest_id not in _backtest_store:
|
||||
raise HTTPException(status_code=404, detail=f"Backtest {backtest_id} not found")
|
||||
|
||||
trades = _backtest_results.get(backtest_id, {}).get("trades", [])
|
||||
return {"backtest_id": backtest_id, "trades": trades[-limit:], "total": len(trades)}
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_backtest_history():
|
||||
backtests = []
|
||||
for backtest_id, backtest in _backtest_store.items():
|
||||
backtests.append(
|
||||
{
|
||||
"backtest_id": backtest_id,
|
||||
"name": backtest["name"],
|
||||
"status": backtest["status"],
|
||||
"created_at": backtest["created_at"],
|
||||
"completed_at": backtest["completed_at"],
|
||||
}
|
||||
)
|
||||
|
||||
return {"backtests": backtests, "total": len(backtests)}
|
||||
|
||||
|
||||
@router.delete("/{backtest_id}")
|
||||
async def delete_backtest(backtest_id: str):
|
||||
if backtest_id not in _backtest_store:
|
||||
raise HTTPException(status_code=404, detail=f"Backtest {backtest_id} not found")
|
||||
|
||||
del _backtest_store[backtest_id]
|
||||
if backtest_id in _backtest_results:
|
||||
del _backtest_results[backtest_id]
|
||||
|
||||
logger.info(f"Backtest deleted: {backtest_id}")
|
||||
return {"message": f"Backtest {backtest_id} deleted"}
|
||||
48
backend/app/api/routes/dashboard.py
Normal file
48
backend/app/api/routes/dashboard.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from app.models.schemas import DashboardSummary, ArbitrageOpportunity, PriceData, BatteryState
|
||||
from app.services import DataService
|
||||
|
||||
router = APIRouter()
|
||||
data_service = DataService()
|
||||
|
||||
|
||||
@router.on_event("startup")
|
||||
async def startup():
|
||||
await data_service.initialize()
|
||||
|
||||
|
||||
@router.get("/summary", response_model=DashboardSummary)
|
||||
async def get_summary():
|
||||
summary = data_service.get_dashboard_summary()
|
||||
return DashboardSummary(**summary)
|
||||
|
||||
|
||||
@router.get("/prices")
|
||||
async def get_latest_prices():
|
||||
return {"regions": data_service.get_latest_prices()}
|
||||
|
||||
|
||||
@router.get("/prices/history")
|
||||
async def get_price_history(
|
||||
region: str = Query(..., description="Region code (FR, BE, DE, NL, UK)"),
|
||||
start: str = Query(None, description="Start date (YYYY-MM-DD)"),
|
||||
end: str = Query(None, description="End date (YYYY-MM-DD)"),
|
||||
limit: int = Query(1000, description="Maximum number of records"),
|
||||
):
|
||||
data = data_service.get_price_history(region, start, end, limit)
|
||||
return {"region": region, "data": data}
|
||||
|
||||
|
||||
@router.get("/battery")
|
||||
async def get_battery_states():
|
||||
batteries = data_service.get_battery_states()
|
||||
return {"batteries": batteries}
|
||||
|
||||
|
||||
@router.get("/arbitrage")
|
||||
async def get_arbitrage_opportunities(
|
||||
min_spread: float = Query(None, description="Minimum spread in EUR/MWh")
|
||||
):
|
||||
opportunities = data_service.get_arbitrage_opportunities(min_spread)
|
||||
return {"opportunities": opportunities, "count": len(opportunities)}
|
||||
71
backend/app/api/routes/models.py
Normal file
71
backend/app/api/routes/models.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from typing import List
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from datetime import datetime
|
||||
from app.models.schemas import ModelInfo, TrainingRequest, TrainingStatus, PredictionResponse
|
||||
from app.services import MLService
|
||||
import uuid
|
||||
|
||||
router = APIRouter()
|
||||
ml_service = MLService()
|
||||
|
||||
_training_store: dict = {}
|
||||
|
||||
|
||||
@router.get("", response_model=List[ModelInfo])
|
||||
async def list_models():
|
||||
return ml_service.list_models()
|
||||
|
||||
|
||||
@router.post("/train")
|
||||
async def train_model(request: TrainingRequest):
|
||||
training_id = f"training_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
_training_store[training_id] = TrainingStatus(
|
||||
training_id=training_id,
|
||||
status="pending",
|
||||
progress=0.0,
|
||||
started_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
return {"training_id": training_id, "status": _training_store[training_id]}
|
||||
|
||||
|
||||
@router.get("/{training_id}/status", response_model=TrainingStatus)
|
||||
async def get_training_status(training_id: str):
|
||||
if training_id not in _training_store:
|
||||
raise HTTPException(status_code=404, detail=f"Training job {training_id} not found")
|
||||
|
||||
return _training_store[training_id]
|
||||
|
||||
|
||||
@router.get("/{model_id}/metrics")
|
||||
async def get_model_metrics(model_id: str):
|
||||
try:
|
||||
metrics = ml_service.get_model_metrics(model_id)
|
||||
return {"model_id": model_id, "metrics": metrics}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/predict", response_model=PredictionResponse)
|
||||
async def predict(
|
||||
model_id: str,
|
||||
timestamp: datetime,
|
||||
features: dict = None,
|
||||
):
|
||||
try:
|
||||
result = ml_service.predict(model_id, timestamp, features)
|
||||
return PredictionResponse(**result)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{model_id}/feature-importance")
|
||||
async def get_feature_importance(model_id: str):
|
||||
try:
|
||||
importance = ml_service.get_feature_importance(model_id)
|
||||
return {"model_id": model_id, "feature_importance": importance}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
26
backend/app/api/routes/settings.py
Normal file
26
backend/app/api/routes/settings.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from fastapi import APIRouter
|
||||
from app.config import settings
|
||||
from app.models.schemas import AppSettings
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=AppSettings)
|
||||
async def get_settings():
|
||||
return AppSettings(
|
||||
battery_min_reserve=settings.BATTERY_MIN_RESERVE,
|
||||
battery_max_charge=settings.BATTERY_MAX_CHARGE,
|
||||
arbitrage_min_spread=settings.ARBITRAGE_MIN_SPREAD,
|
||||
mining_margin_threshold=settings.MINING_MARGIN_THRESHOLD,
|
||||
)
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def update_settings(settings_update: dict):
|
||||
updated_fields = []
|
||||
for key, value in settings_update.items():
|
||||
if hasattr(settings, key.upper()):
|
||||
setattr(settings, key.upper(), value)
|
||||
updated_fields.append(key)
|
||||
|
||||
return {"message": "Settings updated", "updated_fields": updated_fields}
|
||||
27
backend/app/api/routes/trading.py
Normal file
27
backend/app/api/routes/trading.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from typing import List
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from app.models.enums import StrategyEnum
|
||||
from app.models.schemas import StrategyStatus
|
||||
from app.services import StrategyService
|
||||
|
||||
router = APIRouter()
|
||||
strategy_service = StrategyService()
|
||||
|
||||
|
||||
@router.get("/strategies", response_model=List[StrategyStatus])
|
||||
async def get_strategies():
|
||||
return await strategy_service.get_all_strategies()
|
||||
|
||||
|
||||
@router.post("/strategies")
|
||||
async def toggle_strategy(strategy: StrategyEnum, action: str):
|
||||
if action not in ["start", "stop"]:
|
||||
raise HTTPException(status_code=400, detail="Action must be 'start' or 'stop'")
|
||||
|
||||
status = await strategy_service.toggle_strategy(strategy, action)
|
||||
return {"status": status}
|
||||
|
||||
|
||||
@router.get("/positions")
|
||||
async def get_positions():
|
||||
return {"positions": [], "total": 0}
|
||||
62
backend/app/api/websocket.py
Normal file
62
backend/app/api/websocket.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from fastapi import WebSocket
|
||||
from typing import List
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
self.active_connections: List[WebSocket] = []
|
||||
|
||||
async def connect(self, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
self.active_connections.append(websocket)
|
||||
logger.info(f"WebSocket connected. Total connections: {len(self.active_connections)}")
|
||||
|
||||
def disconnect(self, websocket: WebSocket):
|
||||
if websocket in self.active_connections:
|
||||
self.active_connections.remove(websocket)
|
||||
logger.info(f"WebSocket disconnected. Total connections: {len(self.active_connections)}")
|
||||
|
||||
async def broadcast(self, event_type: str, data: dict):
|
||||
message = {"event_type": event_type, "data": data, "timestamp": None}
|
||||
disconnected = []
|
||||
|
||||
for connection in self.active_connections:
|
||||
try:
|
||||
await connection.send_json(message)
|
||||
except Exception:
|
||||
disconnected.append(connection)
|
||||
|
||||
for conn in disconnected:
|
||||
self.disconnect(conn)
|
||||
|
||||
async def broadcast_price_update(self, region: str, price_data: dict):
|
||||
await self.broadcast("price_update", {"region": region, "price_data": price_data})
|
||||
|
||||
async def broadcast_battery_update(self, battery_id: str, battery_state: dict):
|
||||
await self.broadcast("battery_update", {"battery_id": battery_id, "battery_state": battery_state})
|
||||
|
||||
async def broadcast_trade(self, trade: dict):
|
||||
await self.broadcast("trade_executed", trade)
|
||||
|
||||
async def broadcast_alert(self, alert: dict):
|
||||
await self.broadcast("alert_triggered", alert)
|
||||
|
||||
async def broadcast_backtest_progress(self, backtest_id: str, progress: float, status: str):
|
||||
await self.broadcast(
|
||||
"backtest_progress",
|
||||
{"backtest_id": backtest_id, "progress": progress, "status": status},
|
||||
)
|
||||
|
||||
async def broadcast_model_training_progress(
|
||||
self, model_id: str, progress: float, epoch: int = None, metrics: dict = None
|
||||
):
|
||||
await self.broadcast(
|
||||
"model_training_progress",
|
||||
{"model_id": model_id, "progress": progress, "epoch": epoch, "metrics": metrics},
|
||||
)
|
||||
|
||||
|
||||
manager = ConnectionManager()
|
||||
Reference in New Issue
Block a user