Implements a FastAPI backend with ML model support for energy trading, including price-prediction models and an RL-based battery trading policy. Features dashboard, trading, backtest, and settings API routes, with WebSocket support for real-time updates.
72 lines
2.1 KiB
Python
72 lines
2.1 KiB
Python
from typing import List
|
|
from fastapi import APIRouter, HTTPException
|
|
from datetime import datetime
|
|
from app.models.schemas import ModelInfo, TrainingRequest, TrainingStatus, PredictionResponse
|
|
from app.services import MLService
|
|
import uuid
|
|
|
|
# Router for the ML-model endpoints; every decorator below registers on it.
router = APIRouter()

# A single MLService instance shared by all endpoints in this module.
ml_service = MLService()

# In-memory registry of training jobs: training_id -> TrainingStatus.
# NOTE(review): this is a plain module-level dict, so it is process-local
# (not shared between worker processes) and lost on restart; nothing in this
# file ever removes entries from it.
_training_store: dict = dict()
|
|
|
|
|
|
@router.get("", response_model=List[ModelInfo])
async def list_models():
    """Return every model reported by ``ml_service.list_models()``.

    FastAPI validates the result against ``List[ModelInfo]`` before sending it.
    """
    models = ml_service.list_models()
    return models
|
|
|
|
|
|
@router.post("/train")
async def train_model(request: TrainingRequest):
    """Register a new training job and return its initial status.

    A ``TrainingStatus`` in state ``"pending"`` with zero progress is stored in
    ``_training_store`` under a fresh id of the form ``training_<8 hex chars>``.

    NOTE(review): ``request`` is not read here and nothing in this handler
    starts training, so the job stays ``"pending"`` unless something else
    updates the store -- confirm where training is actually launched.

    Returns:
        ``{"training_id": <id>, "status": <TrainingStatus>}``.
    """
    job_id = "training_" + uuid.uuid4().hex[:8]

    initial_status = TrainingStatus(
        training_id=job_id,
        status="pending",
        progress=0.0,
        started_at=datetime.utcnow(),  # naive datetime in UTC
    )
    _training_store[job_id] = initial_status

    return {"training_id": job_id, "status": initial_status}
|
|
|
|
|
|
@router.get("/{training_id}/status", response_model=TrainingStatus)
async def get_training_status(training_id: str):
    """Return the stored ``TrainingStatus`` for *training_id*.

    Raises:
        HTTPException: 404 when *training_id* is not present in
            ``_training_store``.
    """
    if training_id in _training_store:
        return _training_store[training_id]
    raise HTTPException(status_code=404, detail=f"Training job {training_id} not found")
|
|
|
|
|
|
@router.get("/{model_id}/metrics")
async def get_model_metrics(model_id: str):
    """Return the metrics ``ml_service`` reports for *model_id*.

    Returns:
        ``{"model_id": model_id, "metrics": <result of
        ml_service.get_model_metrics>}``.

    Raises:
        HTTPException: 404 when ``ml_service.get_model_metrics`` raises
            ``ValueError`` (presumably an unknown model id -- confirm).
    """
    # Keep the try body to the one call that can raise the mapped error.
    try:
        metrics = ml_service.get_model_metrics(model_id)
    except ValueError as e:
        # Chain the cause (ruff B904) so server-side tracebacks keep the origin.
        raise HTTPException(status_code=404, detail=str(e)) from e
    return {"model_id": model_id, "metrics": metrics}
|
|
|
|
|
|
def _prediction_failed(model_id: str, exc: Exception) -> HTTPException:
    """Log an unexpected prediction failure and build a generic 500 error.

    The traceback goes to the server log; the client only receives a generic
    message, so internal exception text is not exposed over the API.
    """
    import logging  # local import keeps this block self-contained

    logging.getLogger(__name__).error(
        "Prediction with model %s failed", model_id, exc_info=exc
    )
    return HTTPException(
        status_code=500, detail="Prediction failed due to an internal error"
    )


@router.post("/predict", response_model=PredictionResponse)
async def predict(
    model_id: str,
    timestamp: datetime,
    features: dict = None,
):
    """Run ``ml_service.predict`` and wrap its output in a ``PredictionResponse``.

    Args:
        model_id: Identifier of the model to query.
        timestamp: Point in time to predict for.
        features: Optional feature mapping passed straight through to the
            service; ``None`` when omitted.

    Returns:
        A ``PredictionResponse`` built from the dict the service returns.

    Raises:
        HTTPException: 404 when the service raises ``ValueError``; 500 for any
            other failure, including a service result that does not fit
            ``PredictionResponse``.
    """
    # NOTE(review): ml_service.predict is called synchronously inside an async
    # endpoint, so a slow prediction blocks the event loop -- consider running
    # it in a threadpool if inference is expensive.
    try:
        result = ml_service.predict(model_id, timestamp, features)
    except ValueError as e:
        # Mapped to 404, presumably for an unknown model_id -- confirm with
        # MLService which ValueErrors it raises.
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        raise _prediction_failed(model_id, e) from e

    # Built outside the try above on purpose: pydantic's ValidationError
    # subclasses ValueError, so a malformed service result would otherwise be
    # reported to the client as a 404 "not found".
    try:
        return PredictionResponse(**result)
    except Exception as e:
        raise _prediction_failed(model_id, e) from e
|
|
|
|
|
|
@router.get("/{model_id}/feature-importance")
async def get_feature_importance(model_id: str):
    """Return the feature importance ``ml_service`` reports for *model_id*.

    Returns:
        ``{"model_id": model_id, "feature_importance": <result of
        ml_service.get_feature_importance>}``.

    Raises:
        HTTPException: 404 when ``ml_service.get_feature_importance`` raises
            ``ValueError`` (presumably an unknown model id -- confirm).
    """
    # Keep the try body to the one call that can raise the mapped error.
    try:
        importance = ml_service.get_feature_importance(model_id)
    except ValueError as e:
        # Chain the cause (ruff B904) so server-side tracebacks keep the origin.
        raise HTTPException(status_code=404, detail=str(e)) from e
    return {"model_id": model_id, "feature_importance": importance}
|