Add FastAPI backend for energy trading system
Implements FastAPI backend with ML model support for energy trading, including price prediction models and RL-based battery trading policy. Features dashboard, trading, backtest, and settings API routes with WebSocket support for real-time updates.
This commit is contained in:
145
backend/app/services/ml_service.py
Normal file
145
backend/app/services/ml_service.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from __future__ import annotations

import json
import pickle
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from app.config import settings
from app.models.enums import ModelType
from app.models.schemas import ModelInfo, PredictionResponse
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MLService:
    """Service for loading and serving the ML models used by the trading API.

    Two model families are supported: price-prediction regressors and
    RL battery-dispatch policies. Model metadata is read from a JSON registry
    (``registry.json``) under ``settings.MODELS_PATH``; the pickled model
    objects themselves are loaded lazily on first use and cached in memory.
    """

    def __init__(self):
        # Root directory holding registry.json and per-family model subdirs.
        self.models_path: Path = Path(settings.MODELS_PATH)
        # Cache of already-unpickled model/policy objects, keyed by model_id.
        self._loaded_models: Dict[str, Any] = {}
        # Metadata for every registered model, keyed by model_id.
        self._registry: Dict[str, ModelInfo] = {}
        self._load_registry()

    def _load_registry(self):
        """Populate ``self._registry`` from ``registry.json``, if present.

        A missing registry file leaves the registry empty. An unreadable or
        corrupt file is logged and skipped so that constructing the service
        (and thus API startup) does not fail on a bad registry.
        """
        registry_path = self.models_path / "registry.json"
        if not registry_path.exists():
            return
        try:
            with open(registry_path) as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError.
            logger.error(f"Failed to load model registry {registry_path}: {e}")
            return
        for model_id, model_data in data.get("models", {}).items():
            self._registry[model_id] = ModelInfo(**model_data)
        logger.info(f"Loaded model registry: {len(self._registry)} models")

    def list_models(self) -> List[ModelInfo]:
        """Return metadata for all models in the registry."""
        return list(self._registry.values())

    def get_model_metrics(self, model_id: str) -> Dict[str, float]:
        """Return the stored evaluation metrics for ``model_id``.

        Raises:
            ValueError: if the model is not present in the registry.
        """
        if model_id not in self._registry:
            raise ValueError(f"Model {model_id} not found in registry")
        return self._registry[model_id].metrics

    def _load_pickle(self, subdir: str, model_id: str, not_found: str, loaded: str) -> Any:
        """Load (and cache) a pickled artifact from ``models_path/subdir``.

        Shared implementation for the price-prediction and RL-policy loaders.
        The success log is emitted only on a fresh load, not on cache hits.

        Args:
            subdir: Subdirectory under ``models_path`` for this model family.
            model_id: Registry id; also the pickle file's stem.
            not_found: Error-message prefix used when the file is missing.
            loaded: Log-message prefix emitted after a fresh load.

        Raises:
            FileNotFoundError: if ``<models_path>/<subdir>/<model_id>.pkl``
                does not exist.
        """
        if model_id in self._loaded_models:
            return self._loaded_models[model_id]

        artifact_path = self.models_path / subdir / f"{model_id}.pkl"
        if not artifact_path.exists():
            raise FileNotFoundError(f"{not_found}: {artifact_path}")

        # SECURITY: pickle.load executes arbitrary code on load. Only model
        # files from the trusted, locally managed MODELS_PATH may be loaded;
        # never point MODELS_PATH at user-supplied content.
        with open(artifact_path, "rb") as f:
            artifact = pickle.load(f)

        self._loaded_models[model_id] = artifact
        logger.info(f"{loaded}: {model_id}")
        return artifact

    def load_price_prediction_model(self, model_id: str):
        """Return the price-prediction model for ``model_id``, loading it if needed."""
        return self._load_pickle(
            "price_prediction",
            model_id,
            not_found="Model file not found",
            loaded="Loaded price prediction model",
        )

    def load_rl_battery_policy(self, model_id: str):
        """Return the RL battery policy for ``model_id``, loading it if needed."""
        return self._load_pickle(
            "rl_battery",
            model_id,
            not_found="Policy file not found",
            loaded="Loaded RL battery policy",
        )

    def predict(
        self, model_id: str, timestamp: datetime, features: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Dispatch a prediction request to the model identified by ``model_id``.

        Args:
            model_id: Registry id of the model to use.
            timestamp: Timestamp the prediction refers to (echoed in the result).
            features: Optional input features; defaults to an empty dict.

        Raises:
            ValueError: if the model is unknown or its type is unsupported.
            FileNotFoundError: if the model's pickle file is missing.
        """
        model_info = self._registry.get(model_id)
        if not model_info:
            raise ValueError(f"Model {model_id} not found")

        if model_info.model_type == ModelType.PRICE_PREDICTION:
            model = self.load_price_prediction_model(model_id)
            return self._predict_price(model, timestamp, features or {})
        if model_info.model_type == ModelType.RL_BATTERY:
            policy = self.load_rl_battery_policy(model_id)
            return self._get_battery_action(policy, timestamp, features or {})
        raise ValueError(f"Unsupported model type: {model_info.model_type}")

    def _predict_price(
        self, model: Any, timestamp: datetime, features: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Run a point price prediction and package the result as a dict."""
        try:
            feature_vector = self._extract_features(features)
            prediction = float(model.predict(feature_vector)[0])
        except Exception as e:
            # Log here, then re-raise so the API layer can translate the error.
            logger.error(f"Prediction error: {e}")
            raise
        return {
            "model_id": getattr(model, "model_id", "unknown"),
            "timestamp": timestamp,
            "prediction": prediction,
            "confidence": None,  # point estimate only; no interval available
            "features_used": list(features.keys()),
        }

    def _extract_features(self, features: Dict[str, Any]) -> Any:
        """Build a (1, n_features) numpy array from ``features``.

        Keys are sorted so the column order is deterministic regardless of
        dict insertion order. NOTE(review): the model must have been trained
        with the same sorted-key column layout — confirm against training code.
        """
        import numpy as np  # local import: numpy only needed on prediction paths

        return np.array([[features[key] for key in sorted(features)]])

    def _get_battery_action(
        self, policy: Any, timestamp: datetime, features: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Choose a battery action ("charge" / "discharge" / "hold").

        NOTE(review): ``policy`` is only read for its ``policy_id``; the
        action itself comes from a fixed threshold heuristic rather than the
        loaded policy — presumably a placeholder until the RL policy is wired
        in. Confirm before relying on this in production.
        """
        charge_level = features.get("charge_level", 0.5)
        current_price = features.get("current_price", 0)

        # Charge when nearly empty and power is cheap; discharge when nearly
        # full and power is expensive; otherwise hold.
        if charge_level < 0.2 and current_price < 50:
            action = "charge"
        elif charge_level > 0.8 and current_price > 100:
            action = "discharge"
        else:
            action = "hold"

        return {
            "model_id": getattr(policy, "policy_id", "battery_policy"),
            "timestamp": timestamp,
            "action": action,
            "charge_level": charge_level,
            "confidence": 0.7,  # fixed confidence for the heuristic
        }

    def predict_with_confidence(
        self, model_id: str, timestamp: datetime, features: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Run ``predict`` and attach a confidence score to the result.

        NOTE(review): the confidence is a hard-coded placeholder (0.85) and
        overwrites whatever confidence the underlying prediction reported.
        """
        result = self.predict(model_id, timestamp, features)
        result["confidence"] = 0.85
        return result

    def get_feature_importance(self, model_id: str) -> Dict[str, float]:
        """Return feature importances for a price-prediction model.

        Returns an empty dict for unknown models, non-price-prediction models,
        or models that do not expose ``feature_importances_`` (non-tree models).
        """
        info = self._registry.get(model_id)
        if info is None or info.model_type != ModelType.PRICE_PREDICTION:
            return {}
        model = self.load_price_prediction_model(model_id)
        importances = getattr(model, "feature_importances_", None)
        if importances is None:
            return {}
        # Positional names only: the original feature names are not stored.
        return {f"feature_{i}": float(imp) for i, imp in enumerate(importances)}

    def get_model_info(self, model_id: str) -> Optional[ModelInfo]:
        """Return registry metadata for ``model_id``, or None if unknown."""
        return self._registry.get(model_id)
|
||||
Reference in New Issue
Block a user