from typing import Dict, List, Optional, Any
from datetime import datetime
from pathlib import Path
import json
import pickle

from app.config import settings
from app.models.enums import ModelType
from app.models.schemas import ModelInfo, PredictionResponse
from app.utils.logger import get_logger

logger = get_logger(__name__)


class MLService:
    """Look up registered models, load their pickled artifacts and run inference.

    Model metadata is read from ``<MODELS_PATH>/registry.json``. Artifacts are
    read lazily from ``<MODELS_PATH>/<subdir>/<model_id>.pkl`` on first use and
    cached in memory, keyed by ``model_id``, for the lifetime of the instance.
    """

    def __init__(self):
        self.models_path: Path = Path(settings.MODELS_PATH)
        # Unpickled artifacts keyed by model_id. Price models and battery
        # policies share this cache; registry ids are unique keys of one dict.
        self._loaded_models: Dict[str, Any] = {}
        self._registry: Dict[str, ModelInfo] = {}
        self._load_registry()

    def _load_registry(self) -> None:
        """Populate ``self._registry`` from ``registry.json`` under ``models_path``.

        Each entry of the file's top-level ``"models"`` object is passed as
        keyword arguments to ``ModelInfo``. If the file does not exist, the
        registry stays empty and a warning is logged.
        """
        registry_path = self.models_path / "registry.json"
        if not registry_path.exists():
            logger.warning(f"Model registry not found: {registry_path}")
            return
        with open(registry_path, encoding="utf-8") as f:
            data = json.load(f)
        for model_id, model_data in data.get("models", {}).items():
            self._registry[model_id] = ModelInfo(**model_data)
        logger.info(f"Loaded model registry: {len(self._registry)} models")

    def list_models(self) -> List[ModelInfo]:
        """Return metadata for every registered model."""
        return list(self._registry.values())

    def get_model_metrics(self, model_id: str) -> Dict[str, float]:
        """Return the ``metrics`` recorded in the registry for ``model_id``.

        Raises:
            ValueError: If ``model_id`` is not in the registry.
        """
        if model_id not in self._registry:
            raise ValueError(f"Model {model_id} not found in registry")
        return self._registry[model_id].metrics

    def _load_pickled_artifact(
        self, subdir: str, model_id: str, missing_label: str, loaded_label: str
    ) -> Any:
        """Return the artifact ``<models_path>/<subdir>/<model_id>.pkl``, cached.

        Args:
            subdir: Directory under ``models_path`` that holds the artifact.
            model_id: File stem of the artifact and its cache key.
            missing_label: Prefix of the ``FileNotFoundError`` message.
            loaded_label: Prefix of the log line emitted after a fresh load.

        Raises:
            ValueError: If ``model_id`` is not a plain file name (for example it
                contains a path separator or is absolute), which would let the
                joined path escape ``subdir``.
            FileNotFoundError: If the artifact file does not exist.
        """
        if model_id in self._loaded_models:
            return self._loaded_models[model_id]

        file_name = f"{model_id}.pkl"
        # Joining "../x.pkl" or "/abs/x.pkl" onto the directory would point
        # outside it, so only bare file names are accepted.
        if Path(file_name).name != file_name:
            raise ValueError(f"Invalid model id: {model_id!r}")

        artifact_path = self.models_path / subdir / file_name
        if not artifact_path.exists():
            raise FileNotFoundError(f"{missing_label}: {artifact_path}")

        # SECURITY: unpickling can execute arbitrary code. The models directory
        # must be trusted and not writable by untrusted parties.
        with open(artifact_path, "rb") as f:
            artifact = pickle.load(f)
        self._loaded_models[model_id] = artifact
        logger.info(f"{loaded_label}: {model_id}")
        return artifact

    def load_price_prediction_model(self, model_id: str):
        """Load (or fetch from cache) the price model ``price_prediction/<model_id>.pkl``.

        Raises:
            ValueError: If ``model_id`` is not a plain file name.
            FileNotFoundError: If the model file does not exist.
        """
        return self._load_pickled_artifact(
            "price_prediction",
            model_id,
            missing_label="Model file not found",
            loaded_label="Loaded price prediction model",
        )

    def load_rl_battery_policy(self, model_id: str):
        """Load (or fetch from cache) the battery policy ``rl_battery/<model_id>.pkl``.

        Raises:
            ValueError: If ``model_id`` is not a plain file name.
            FileNotFoundError: If the policy file does not exist.
        """
        return self._load_pickled_artifact(
            "rl_battery",
            model_id,
            missing_label="Policy file not found",
            loaded_label="Loaded RL battery policy",
        )

    def predict(
        self, model_id: str, timestamp: datetime, features: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Dispatch to the price model or battery policy registered as ``model_id``.

        Args:
            model_id: Registry id of the model to use; echoed in the result.
            timestamp: Echoed back in the result.
            features: Named inputs; ``None`` is treated as an empty dict.

        Returns:
            The dict built by ``_predict_price`` or ``_get_battery_action``.

        Raises:
            ValueError: If the model is unknown or its type is unsupported.
        """
        model_info = self._registry.get(model_id)
        if not model_info:
            raise ValueError(f"Model {model_id} not found")

        inputs = features or {}
        if model_info.model_type == ModelType.PRICE_PREDICTION:
            model = self.load_price_prediction_model(model_id)
            return self._predict_price(model, timestamp, inputs, model_id=model_id)
        if model_info.model_type == ModelType.RL_BATTERY:
            policy = self.load_rl_battery_policy(model_id)
            return self._get_battery_action(policy, timestamp, inputs, model_id=model_id)
        raise ValueError(f"Unsupported model type: {model_info.model_type}")

    def _predict_price(
        self,
        model: Any,
        timestamp: datetime,
        features: Dict[str, Any],
        model_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run a price model on ``features`` and package the result.

        Args:
            model: Unpickled estimator exposing ``predict``.
            timestamp: Echoed back in the result; not passed to the model.
            features: Named feature values; columns are ordered by sorted name.
            model_id: Id to report. When omitted, falls back to the model's
                ``model_id`` attribute, then ``"unknown"``.

        Returns:
            Dict with ``model_id``, ``timestamp``, ``prediction`` (the first
            output as ``float``), ``confidence`` (``None``) and
            ``features_used`` (feature names in model-input column order).
        """
        if model_id is None:
            model_id = getattr(model, "model_id", "unknown")
        try:
            feature_vector = self._extract_features(features)
            prediction = float(model.predict(feature_vector)[0])
        except Exception as e:
            logger.error(f"Prediction error: {e}")
            raise
        return {
            "model_id": model_id,
            "timestamp": timestamp,
            "prediction": prediction,
            "confidence": None,
            # Same ordering as the columns built by _extract_features.
            "features_used": sorted(features),
        }

    def _extract_features(self, features: Dict[str, Any]) -> Any:
        """Return a single-row 2-D NumPy array of ``features`` in sorted-name order.

        NOTE(review): nothing checks the key set against what the model was
        trained on; callers must supply exactly the expected feature names.
        """
        import numpy as np

        return np.array([[features[name] for name in sorted(features)]])

    def _get_battery_action(
        self,
        policy: Any,
        timestamp: datetime,
        features: Dict[str, Any],
        model_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Choose ``"charge"``, ``"discharge"`` or ``"hold"`` with a fixed threshold rule.

        NOTE(review): ``policy`` is not consulted for the decision (only for the
        id fallback); this is a placeholder heuristic. It charges when
        ``charge_level < 0.2`` and ``current_price < 50``, discharges when
        ``charge_level > 0.8`` and ``current_price > 100``, and otherwise holds.
        Missing inputs default to ``charge_level=0.5`` and ``current_price=0``.

        Args:
            model_id: Id to report. When omitted, falls back to the policy's
                ``policy_id`` attribute, then ``"battery_policy"``.
        """
        charge_level = features.get("charge_level", 0.5)
        current_price = features.get("current_price", 0)

        action = "hold"
        if charge_level < 0.2 and current_price < 50:
            action = "charge"
        elif charge_level > 0.8 and current_price > 100:
            action = "discharge"

        if model_id is None:
            model_id = getattr(policy, "policy_id", "battery_policy")
        return {
            "model_id": model_id,
            "timestamp": timestamp,
            "action": action,
            "charge_level": charge_level,
            "confidence": 0.7,
        }

    def predict_with_confidence(
        self, model_id: str, timestamp: datetime, features: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Call ``predict`` and set ``confidence`` to 0.85 on the result.

        NOTE(review): 0.85 is a fixed placeholder that overwrites whatever
        confidence the underlying prediction reported.
        """
        result = self.predict(model_id, timestamp, features)
        result["confidence"] = 0.85
        return result

    def get_feature_importance(self, model_id: str) -> Dict[str, float]:
        """Return ``{"feature_<i>": importance}`` for a registered price model.

        Returns an empty dict when the model is not registered, is not a price
        model, or exposes no ``feature_importances_``.
        """
        model_info = self._registry.get(model_id)
        if model_info is None or model_info.model_type != ModelType.PRICE_PREDICTION:
            return {}
        model = self.load_price_prediction_model(model_id)
        importances = getattr(model, "feature_importances_", None)
        if importances is None:
            return {}
        return {f"feature_{i}": float(imp) for i, imp in enumerate(importances)}

    def get_model_info(self, model_id: str) -> Optional[ModelInfo]:
        """Return registry metadata for ``model_id``, or ``None`` if unregistered."""
        return self._registry.get(model_id)