Implements a FastAPI backend with ML model support for energy trading, including price-prediction models and an RL-based battery-trading policy. Features dashboard, trading, backtest, and settings API routes, plus WebSocket support for real-time updates.
100 lines
3.0 KiB
Python
from typing import Dict, List, Optional
|
|
from pathlib import Path
|
|
import json
|
|
from datetime import datetime
|
|
from app.utils.logger import get_logger
|
|
|
|
# Module-level logger, created once via the project's get_logger helper and
# named after this module.
logger = get_logger(__name__)
|
|
|
|
|
|
class ModelRegistry:
    """Persistent, JSON-file-backed registry of trained model versions.

    In memory (and on disk) the registry maps each ``model_id`` to a record
    shaped like::

        {
            "type": "<model_type>",
            "versions": [
                {"version": ..., "filepath": ..., "timestamp": ..., "metadata": {...}},
                ...
            ],
            "latest": "<version of the most recent registration>",
        }

    The file at ``registry_path`` is read once on construction and rewritten
    at the end of every ``register_model`` call.
    """

    def __init__(self, registry_path: str = "models/registry.json"):
        """Bind the registry to *registry_path* and load any existing contents.

        Args:
            registry_path: Location of the JSON registry file. It need not
                exist yet; it and its parent directories are created on the
                first save.
        """
        self.registry_path = Path(registry_path)
        # model_id -> {"type": ..., "versions": [...], "latest": ...}
        self._registry: Dict[str, Dict] = {}
        self._load()

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as a naive ISO-8601 string.

        This is exactly the format ``datetime.utcnow().isoformat()`` produced
        (no ``+00:00`` offset), so new timestamps stay consistent with those
        already stored on disk; ``utcnow()`` is deprecated since Python 3.12.
        """
        from datetime import timezone  # local import keeps this class self-contained

        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    def _load(self) -> None:
        """Populate the in-memory registry from ``registry_path`` if it exists.

        Raises:
            json.JSONDecodeError: If the file exists but is not valid JSON.
        """
        if self.registry_path.exists():
            # JSON text is UTF-8; don't depend on the platform's default encoding.
            with open(self.registry_path, encoding="utf-8") as f:
                self._registry = json.load(f)
            logger.info(f"Loaded registry from {self.registry_path}")

    def _save(self) -> None:
        """Write the in-memory registry to ``registry_path``.

        The JSON is serialised into a sibling ``.tmp`` file which is then moved
        over the real file, so an exception or process crash while writing
        leaves the previous registry intact instead of a truncated file that
        ``_load`` could not parse.
        """
        self.registry_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.registry_path.with_name(self.registry_path.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            # default=str: metadata values json can't encode are stored as strings.
            json.dump(self._registry, f, indent=2, default=str)
        tmp_path.replace(self.registry_path)
        logger.info(f"Saved registry to {self.registry_path}")

    def _find_version(self, model_id: str, version: Optional[str]) -> Optional[Dict]:
        """Return the stored entry for *version* of *model_id*, or ``None``.

        Entries are scanned newest-first so that, should a registry file
        contain duplicate version labels (older writers appended them), the
        most recent registration wins.
        """
        if model_id not in self._registry:
            return None
        for version_info in reversed(self._registry[model_id].get("versions", [])):
            if version_info.get("version") == version:
                return version_info
        return None

    def register_model(
        self,
        model_type: str,
        model_id: str,
        version: str,
        filepath: str,
        metadata: Optional[Dict] = None,
    ) -> None:
        """Record a model version, mark it as latest, and persist the registry.

        Registering a ``version`` that already exists for ``model_id`` replaces
        the earlier entry rather than storing a duplicate, so lookups return
        the newly registered filepath and metadata.

        Args:
            model_type: Kind of model; stored only when ``model_id`` is new.
            model_id: Identifier that versions are grouped under.
            version: Version label for this artifact.
            filepath: Path to the saved model artifact.
            metadata: Optional extra information stored with the version.
        """
        timestamp = self._utc_timestamp()

        model_info = self._registry.setdefault(
            model_id, {"type": model_type, "versions": []}
        )
        versions = model_info.setdefault("versions", [])

        if any(v.get("version") == version for v in versions):
            logger.warning(
                f"Model {model_id} version {version} already registered; replacing it"
            )
            # Drop the stale entry in place so the re-registration is the one found.
            versions[:] = [v for v in versions if v.get("version") != version]

        versions.append({
            "version": version,
            "filepath": filepath,
            "timestamp": timestamp,
            "metadata": metadata or {},
        })
        model_info["latest"] = version

        self._save()
        logger.info(f"Registered model {model_id} version {version}")

    def get_latest_version(self, model_id: str) -> Optional[Dict]:
        """Return the entry of the version marked ``latest`` for *model_id*.

        Returns:
            The stored version dict, or ``None`` if the model is unknown, has
            no (or an empty) ``latest`` marker, or that version is missing.
        """
        if model_id not in self._registry:
            return None
        latest_version = self._registry[model_id].get("latest")
        if not latest_version:
            return None
        return self._find_version(model_id, latest_version)

    def list_models(self) -> List[Dict]:
        """Summarise every registered model.

        Returns:
            One dict per model with ``model_id``, ``type``, ``latest_version``,
            ``total_versions`` and ``latest_info`` (the latest version entry or
            ``None``).
        """
        return [
            {
                "model_id": model_id,
                "type": model_info.get("type"),
                "latest_version": model_info.get("latest"),
                "total_versions": len(model_info.get("versions", [])),
                "latest_info": self.get_latest_version(model_id),
            }
            for model_id, model_info in self._registry.items()
        ]

    def get_model(self, model_id: str, version: Optional[str] = None) -> Optional[Dict]:
        """Return the entry for a specific *version* of *model_id*.

        Args:
            model_id: Identifier of the model.
            version: Version label; when ``None`` the ``latest`` version is used.

        Returns:
            The stored version dict, or ``None`` if not found.
        """
        if model_id not in self._registry:
            return None
        if version is None:
            version = self._registry[model_id].get("latest")
        return self._find_version(model_id, version)
|
|
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = ["ModelRegistry"]
|