Implements FastAPI backend with ML model support for energy trading, including price prediction models and RL-based battery trading policy. Features dashboard, trading, backtest, and settings API routes with WebSocket support for real-time updates.
146 lines · 5.5 KiB · Python
import json
import pickle
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from app.config import settings
from app.models.enums import ModelType
from app.models.schemas import ModelInfo, PredictionResponse
from app.utils.logger import get_logger

# Module-level logger from the project's logging helper, keyed by this
# module's import name.
logger = get_logger(__name__)


class MLService:
    """Load registered ML models and serve predictions from them.

    Model metadata is read from ``<settings.MODELS_PATH>/registry.json``.
    Artefacts are pickle files under ``price_prediction/`` (price models) and
    ``rl_battery/`` (battery policies). Each one is unpickled on first use and
    then served from an in-memory cache keyed by model id.
    """

    # Threshold rule applied by ``_get_battery_action``.
    # NOTE(review): a 0-1 charge scale is inferred from the 0.5 default, and
    # price units are whatever the caller supplies. Confirm against the data
    # source.
    _LOW_CHARGE = 0.2
    _HIGH_CHARGE = 0.8
    _CHEAP_PRICE = 50
    _EXPENSIVE_PRICE = 100

    # Fixed confidence values. Neither is derived from any model output.
    _RULE_CONFIDENCE = 0.7
    _DEFAULT_CONFIDENCE = 0.85

    def __init__(self):
        self.models_path: Path = Path(settings.MODELS_PATH)
        # Unpickled models/policies, keyed by registry model id.
        self._loaded_models: Dict[str, Any] = {}
        self._registry: Dict[str, ModelInfo] = {}
        self._load_registry()

    def _load_registry(self):
        """Populate the registry from ``registry.json`` when the file exists.

        Each value under the top-level ``"models"`` mapping is passed as
        keyword arguments to ``ModelInfo`` and stored under its key. A missing
        file leaves the registry empty and logs a warning.
        """
        registry_path = self.models_path / "registry.json"
        if not registry_path.exists():
            logger.warning(f"Model registry not found: {registry_path}")
            return

        with open(registry_path, encoding="utf-8") as f:
            data = json.load(f)

        for model_id, model_data in data.get("models", {}).items():
            self._registry[model_id] = ModelInfo(**model_data)

        logger.info(f"Loaded model registry: {len(self._registry)} models")

    def list_models(self) -> List[ModelInfo]:
        """Return every registered model's metadata."""
        return list(self._registry.values())

    def get_model_metrics(self, model_id: str) -> Dict[str, float]:
        """Return the ``metrics`` recorded in the registry for ``model_id``.

        Raises:
            ValueError: If ``model_id`` is not registered.
        """
        if model_id not in self._registry:
            raise ValueError(f"Model {model_id} not found in registry")
        return self._registry[model_id].metrics

    def _load_pickled(
        self, model_id: str, subdir: str, missing_label: str, loaded_label: str
    ) -> Any:
        """Return the cached object for ``model_id``, unpickling it on first use.

        Args:
            model_id: Cache key and pickle file stem.
            subdir: Directory under ``models_path`` that holds the pickle.
            missing_label: Prefix of the FileNotFoundError message.
            loaded_label: Prefix of the info log emitted after loading.

        Raises:
            FileNotFoundError: If ``<models_path>/<subdir>/<model_id>.pkl`` is absent.
        """
        if model_id in self._loaded_models:
            return self._loaded_models[model_id]

        # NOTE(review): model_id is interpolated into a filesystem path.
        # Callers should only pass registry ids, never raw user input.
        path = self.models_path / subdir / f"{model_id}.pkl"
        if not path.exists():
            raise FileNotFoundError(f"{missing_label}: {path}")

        # SECURITY: unpickling can execute arbitrary code stored in the file.
        # Only trusted artefacts may be placed under models_path.
        with open(path, "rb") as f:
            obj = pickle.load(f)

        self._loaded_models[model_id] = obj
        logger.info(f"{loaded_label}: {model_id}")
        return obj

    def load_price_prediction_model(self, model_id: str):
        """Load (or fetch from cache) ``price_prediction/<model_id>.pkl``.

        Raises:
            FileNotFoundError: If the pickle file does not exist.
        """
        return self._load_pickled(
            model_id,
            "price_prediction",
            "Model file not found",
            "Loaded price prediction model",
        )

    def load_rl_battery_policy(self, model_id: str):
        """Load (or fetch from cache) ``rl_battery/<model_id>.pkl``.

        Raises:
            FileNotFoundError: If the pickle file does not exist.
        """
        return self._load_pickled(
            model_id,
            "rl_battery",
            "Policy file not found",
            "Loaded RL battery policy",
        )

    def predict(
        self, model_id: str, timestamp: datetime, features: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Run the registered model ``model_id`` and return a result dict.

        Dispatches on the registry's ``model_type``. A price model returns a
        ``prediction``; a battery policy returns an ``action``. Both results
        carry ``model_id`` set to the id that was requested.

        Raises:
            ValueError: If the id is unknown or its model type is unsupported.
            FileNotFoundError: If the model's pickle file is missing.
        """
        model_info = self._registry.get(model_id)
        if model_info is None:
            raise ValueError(f"Model {model_id} not found")

        features = features or {}
        if model_info.model_type == ModelType.PRICE_PREDICTION:
            model = self.load_price_prediction_model(model_id)
            return self._predict_price(model, timestamp, features, model_id=model_id)
        if model_info.model_type == ModelType.RL_BATTERY:
            policy = self.load_rl_battery_policy(model_id)
            return self._get_battery_action(policy, timestamp, features, model_id=model_id)
        raise ValueError(f"Unsupported model type: {model_info.model_type}")

    @staticmethod
    def _feature_order(model: Any, features: Dict[str, Any]) -> List[str]:
        """Return the column order in which features are fed to ``model``.

        Uses the model's ``feature_names_in_`` when present (scikit-learn sets
        it when an estimator is fitted on named columns). Otherwise falls back
        to the caller's keys in sorted order.
        """
        names = getattr(model, "feature_names_in_", None)
        if names is not None:
            return [str(name) for name in names]
        return sorted(features)

    def _predict_price(
        self,
        model: Any,
        timestamp: datetime,
        features: Dict[str, Any],
        model_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Predict a single price with ``model.predict`` on one feature row.

        Args:
            model_id: Id reported in the result. When omitted, the model's own
                ``model_id`` attribute is used, or ``"unknown"``.

        Returns:
            A dict with ``model_id``, ``timestamp``, ``prediction`` (float),
            ``confidence`` (always None here) and ``features_used`` (the
            column order actually fed to the model).
        """
        try:
            feature_names = self._feature_order(model, features)
            feature_vector = self._extract_features(features, feature_names)
            prediction = float(model.predict(feature_vector)[0])
        except Exception as e:
            # Log with the traceback, then let the caller decide what to do.
            logger.exception(f"Prediction error: {e}")
            raise

        if model_id is None:
            model_id = getattr(model, "model_id", "unknown")
        return {
            "model_id": model_id,
            "timestamp": timestamp,
            "prediction": prediction,
            "confidence": None,
            "features_used": feature_names,
        }

    def _extract_features(
        self, features: Dict[str, Any], feature_names: Optional[List[str]] = None
    ) -> Any:
        """Build a single-row 2-D numpy array from ``features``.

        Columns follow ``feature_names`` (default: sorted keys of
        ``features``). A name absent from ``features`` contributes 0.
        """
        import numpy as np

        names = sorted(features) if feature_names is None else feature_names
        return np.array([[features.get(name, 0) for name in names]])

    def _get_battery_action(
        self,
        policy: Any,
        timestamp: datetime,
        features: Dict[str, Any],
        model_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Choose ``"charge"``, ``"discharge"`` or ``"hold"`` for the battery.

        NOTE: ``policy`` is read only for its optional ``policy_id`` attribute.
        The action comes from a fixed threshold rule on ``charge_level``
        (default 0.5) and ``current_price`` (default 0), not from the policy.

        Args:
            model_id: Id reported in the result. When omitted, the policy's
                ``policy_id`` attribute is used, or ``"battery_policy"``.
        """
        charge_level = features.get("charge_level", 0.5)
        current_price = features.get("current_price", 0)

        if charge_level < self._LOW_CHARGE and current_price < self._CHEAP_PRICE:
            action = "charge"
        elif charge_level > self._HIGH_CHARGE and current_price > self._EXPENSIVE_PRICE:
            action = "discharge"
        else:
            action = "hold"

        if model_id is None:
            model_id = getattr(policy, "policy_id", "battery_policy")
        return {
            "model_id": model_id,
            "timestamp": timestamp,
            "action": action,
            "charge_level": charge_level,
            "confidence": self._RULE_CONFIDENCE,
        }

    def predict_with_confidence(
        self, model_id: str, timestamp: datetime, features: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Like ``predict``, but guarantees a non-None ``confidence``.

        A confidence already supplied by the model path is kept. Otherwise the
        fixed placeholder ``_DEFAULT_CONFIDENCE`` (0.85) is filled in. That
        value is not computed from the model.
        """
        result = self.predict(model_id, timestamp, features)
        if result.get("confidence") is None:
            result["confidence"] = self._DEFAULT_CONFIDENCE
        return result

    def get_feature_importance(self, model_id: str) -> Dict[str, float]:
        """Return ``feature_importances_`` of a registered price model.

        Keys are the model's ``feature_names_in_`` when available and of
        matching length, otherwise ``feature_<index>``. Returns an empty dict
        for unknown ids, non-price models, or models without importances.
        """
        info = self._registry.get(model_id)
        if info is None or info.model_type != ModelType.PRICE_PREDICTION:
            return {}

        model = self.load_price_prediction_model(model_id)
        importances = getattr(model, "feature_importances_", None)
        if importances is None:
            return {}

        names = getattr(model, "feature_names_in_", None)
        if names is None or len(names) != len(importances):
            names = [f"feature_{i}" for i in range(len(importances))]
        return {str(name): float(imp) for name, imp in zip(names, importances)}

    def get_model_info(self, model_id: str) -> Optional[ModelInfo]:
        """Return the registry entry for ``model_id``, or None if unregistered."""
        return self._registry.get(model_id)
|