Files
energy-trade/backend/app/services/ml_service.py
kbt-devops fe76bc7629 Add FastAPI backend for energy trading system
Implements FastAPI backend with ML model support for energy trading,
including price prediction models and RL-based battery trading policy.
Features dashboard, trading, backtest, and settings API routes with
WebSocket support for real-time updates.
2026-02-12 00:59:26 +07:00

146 lines
5.5 KiB
Python

from typing import Dict, List, Optional, Any
from datetime import datetime
from pathlib import Path
import pickle
from app.config import settings
from app.models.enums import ModelType
from app.models.schemas import ModelInfo, PredictionResponse
from app.utils.logger import get_logger

# Module-level logger used by MLService for registry, model-load and prediction messages.
logger = get_logger(__name__)


class MLService:
    """Registry-backed loader and inference facade for the energy-trading models.

    On construction the service reads ``registry.json`` from
    ``settings.MODELS_PATH`` (if present) into a map of ``ModelInfo`` records.
    Model artifacts are pickle files under ``<MODELS_PATH>/price_prediction/``
    and ``<MODELS_PATH>/rl_battery/``. They are unpickled lazily on first use
    and cached in memory per ``model_id``; this class never evicts them.
    """

    def __init__(self):
        self.models_path: Path = Path(settings.MODELS_PATH)
        # Cache of unpickled artifacts keyed by model_id. Price models and
        # battery policies share this one namespace, so a model_id must be
        # unique across both kinds.
        self._loaded_models: Dict[str, Any] = {}
        self._registry: Dict[str, ModelInfo] = {}
        self._load_registry()

    def _load_registry(self) -> None:
        """Populate ``self._registry`` from ``<models_path>/registry.json``.

        A missing registry file is not an error: the registry stays empty and
        a warning is logged. Each entry under the top-level ``"models"`` key
        is passed as keyword arguments to ``ModelInfo``.
        """
        import json

        registry_path = self.models_path / "registry.json"
        if not registry_path.exists():
            logger.warning(
                f"Model registry not found at {registry_path}; no models registered"
            )
            return
        # JSON is UTF-8 by specification; don't depend on the platform locale.
        with open(registry_path, encoding="utf-8") as f:
            data = json.load(f)
        for model_id, model_data in data.get("models", {}).items():
            self._registry[model_id] = ModelInfo(**model_data)
        logger.info(f"Loaded model registry: {len(self._registry)} models")

    def list_models(self) -> List[ModelInfo]:
        """Return every registered model, in registry insertion order."""
        return list(self._registry.values())

    def get_model_metrics(self, model_id: str) -> Dict[str, float]:
        """Return the ``metrics`` recorded in the registry for *model_id*.

        Raises:
            ValueError: if *model_id* is not in the registry.
        """
        model_info = self._registry.get(model_id)
        if model_info is None:
            raise ValueError(f"Model {model_id} not found in registry")
        return model_info.metrics

    def _load_pickled(
        self, subdir: str, model_id: str, missing_label: str, loaded_label: str
    ) -> Any:
        """Return the cached artifact for *model_id*, unpickling it on first use.

        Args:
            subdir: Directory under ``models_path`` holding ``<model_id>.pkl``.
            model_id: Artifact identifier, used both as cache key and file stem.
            missing_label: Word used in the "file not found" error message.
            loaded_label: Description used in the "Loaded ..." log message.

        Raises:
            FileNotFoundError: if the artifact file does not exist.
        """
        if model_id in self._loaded_models:
            return self._loaded_models[model_id]
        artifact_path = self.models_path / subdir / f"{model_id}.pkl"
        if not artifact_path.exists():
            raise FileNotFoundError(f"{missing_label} file not found: {artifact_path}")
        # SECURITY NOTE(review): pickle.load can execute arbitrary code stored in
        # the file. This is only safe while MODELS_PATH is writable solely by
        # trusted deployers and model_id (interpolated into the path unchecked)
        # cannot be attacker-controlled. Confirm both for every caller.
        with open(artifact_path, "rb") as f:
            artifact = pickle.load(f)
        self._loaded_models[model_id] = artifact
        logger.info(f"Loaded {loaded_label}: {model_id}")
        return artifact

    def load_price_prediction_model(self, model_id: str) -> Any:
        """Load (or fetch from cache) ``price_prediction/<model_id>.pkl``.

        Raises:
            FileNotFoundError: if the model file does not exist.
        """
        return self._load_pickled(
            "price_prediction", model_id, "Model", "price prediction model"
        )

    def load_rl_battery_policy(self, model_id: str) -> Any:
        """Load (or fetch from cache) ``rl_battery/<model_id>.pkl``.

        Raises:
            FileNotFoundError: if the policy file does not exist.
        """
        return self._load_pickled("rl_battery", model_id, "Policy", "RL battery policy")

    def predict(
        self, model_id: str, timestamp: datetime, features: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Run the registered model *model_id* and return its result dict.

        Dispatches on the registry's ``model_type``:

        - ``ModelType.PRICE_PREDICTION``: keys ``model_id``, ``timestamp``,
          ``prediction``, ``confidence`` (None) and ``features_used``.
        - ``ModelType.RL_BATTERY``: keys ``model_id``, ``timestamp``,
          ``action``, ``charge_level`` and ``confidence``.

        ``model_id`` in the result is always the requested id.

        Raises:
            ValueError: if *model_id* is unregistered or its type is unsupported.
            FileNotFoundError: if the model artifact is missing on disk.
        """
        model_info = self._registry.get(model_id)
        if model_info is None:
            raise ValueError(f"Model {model_id} not found")
        features = features or {}
        if model_info.model_type == ModelType.PRICE_PREDICTION:
            model = self.load_price_prediction_model(model_id)
            return self._predict_price(model, timestamp, features, model_id=model_id)
        if model_info.model_type == ModelType.RL_BATTERY:
            policy = self.load_rl_battery_policy(model_id)
            return self._get_battery_action(policy, timestamp, features, model_id=model_id)
        raise ValueError(f"Unsupported model type: {model_info.model_type}")

    def _predict_price(
        self,
        model: Any,
        timestamp: datetime,
        features: Dict[str, Any],
        model_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Call ``model.predict`` on *features* and return the first output as float.

        *model_id* is echoed in the result. When omitted, the result falls back
        to the model's own ``model_id`` attribute, or ``"unknown"``.
        Prediction errors are logged with traceback and re-raised.
        """
        try:
            feature_vector = self._extract_features(features)
            prediction = float(model.predict(feature_vector)[0])
        except Exception as e:
            logger.exception(f"Prediction error: {e}")
            raise
        return {
            "model_id": model_id if model_id is not None else getattr(model, "model_id", "unknown"),
            "timestamp": timestamp,
            "prediction": prediction,
            "confidence": None,
            "features_used": list(features.keys()),
        }

    def _extract_features(self, features: Dict[str, Any]) -> Any:
        """Return a single-row 2-D numpy array of *features* values.

        Columns are ordered by sorted feature name.
        NOTE(review): the column order comes from whatever keys the caller
        supplied, not from the model's training schema. A missing or extra
        key silently shifts columns. Confirm callers always send the full,
        consistently named feature set.
        """
        import numpy as np

        return np.array([[features[name] for name in sorted(features)]])

    def _get_battery_action(
        self,
        policy: Any,
        timestamp: datetime,
        features: Dict[str, Any],
        model_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Choose ``"charge"``, ``"discharge"`` or ``"hold"`` for the battery.

        NOTE(review): this is a fixed threshold rule; *policy* is not consulted
        except as a fallback source of the reported id. It charges when
        ``charge_level < 0.2`` and ``current_price < 50``, and discharges when
        ``charge_level > 0.8`` and ``current_price > 100``. Missing features
        default to 0.5 and 0 respectively. The confidence is a constant 0.7.
        """
        charge_level = features.get("charge_level", 0.5)
        current_price = features.get("current_price", 0)
        if charge_level < 0.2 and current_price < 50:
            action = "charge"
        elif charge_level > 0.8 and current_price > 100:
            action = "discharge"
        else:
            action = "hold"
        return {
            "model_id": model_id if model_id is not None else getattr(policy, "policy_id", "battery_policy"),
            "timestamp": timestamp,
            "action": action,
            "charge_level": charge_level,
            "confidence": 0.7,
        }

    def predict_with_confidence(
        self, model_id: str, timestamp: datetime, features: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Return ``predict(...)`` with ``confidence`` overwritten to 0.85.

        NOTE(review): 0.85 is a hard-coded placeholder, not a value derived from
        the model; it also replaces the battery action's 0.7.
        """
        result = self.predict(model_id, timestamp, features)
        result["confidence"] = 0.85
        return result

    def get_feature_importance(self, model_id: str) -> Dict[str, float]:
        """Return ``{"feature_<i>": importance}`` for a registered price model.

        Returns an empty dict when *model_id* is unregistered, is not a price
        prediction model, or the loaded model has no ``feature_importances_``.
        Index ``i`` is the model's i-th input column; presumably this matches
        the sorted-name order used by ``_extract_features`` (confirm).
        """
        model_info = self._registry.get(model_id)
        if model_info is None or model_info.model_type != ModelType.PRICE_PREDICTION:
            return {}
        model = self.load_price_prediction_model(model_id)
        importances = getattr(model, "feature_importances_", None)
        if importances is None:
            return {}
        return {f"feature_{i}": float(imp) for i, imp in enumerate(importances)}

    def get_model_info(self, model_id: str) -> Optional[ModelInfo]:
        """Return the registry record for *model_id*, or None if unregistered."""
        return self._registry.get(model_id)