Implements FastAPI backend with ML model support for energy trading, including price prediction models and RL-based battery trading policy. Features dashboard, trading, backtest, and settings API routes with WebSocket support for real-time updates.
87 lines
2.8 KiB
Python
87 lines
2.8 KiB
Python
from typing import Dict, Optional
|
|
import pandas as pd
|
|
import numpy as np
|
|
from app.ml.price_prediction.model import PricePredictionModel
|
|
from app.ml.price_prediction.trainer import PricePredictionTrainer
|
|
from app.utils.logger import get_logger
|
|
|
|
# Module-level logger, obtained from the project's get_logger helper using this module's name.
logger = get_logger(__name__)
|
|
|
|
|
|
class PricePredictor:
    """Serve price predictions from per-horizon models loaded from disk.

    Loaded models are kept in ``self.models``, keyed by the integer horizon
    that callers pass to the ``predict*`` methods.
    """

    # Columns of the feature frame that are never passed to a model as inputs.
    _NON_FEATURE_COLUMNS = ("timestamp", "region", "day_ahead_price", "real_time_price")

    def __init__(self, models_dir: str = "models/price_prediction"):
        """Remember ``models_dir`` and immediately load the models stored there.

        Args:
            models_dir: Directory handed to ``PricePredictionTrainer.load_models``.
        """
        self.models_dir = models_dir
        self.models: Dict[int, PricePredictionModel] = {}
        self._load_models()

    def _load_models(self):
        """Replace ``self.models`` with the trainer's result for ``models_dir``.

        Logs how many models were loaded.
        """
        self.models = PricePredictionTrainer.load_models(self.models_dir)
        logger.info(f"Loaded {len(self.models)} prediction models")

    def predict(
        self, current_data: pd.DataFrame, horizon: int = 15, region: Optional[str] = None
    ) -> float:
        """Predict a single price from the most recent row of ``current_data``.

        Args:
            current_data: Raw data passed to ``build_price_features``.
            horizon: Key of the model in ``self.models`` to use.
            region: When given (and the feature frame has a ``region``
                column), only rows for this region are considered.

        Returns:
            The first element of the model's prediction, as a ``float``.

        Raises:
            ValueError: If no model is loaded for ``horizon``, or if no
                feature rows remain to predict from.
        """
        if horizon not in self.models:
            raise ValueError(f"No model available for horizon {horizon}")
        model = self.models[horizon]

        # Kept as a call-time import, as in the original code.
        # NOTE(review): presumably this avoids an import cycle - confirm
        # before hoisting it to module level.
        from app.ml.features import build_price_features

        df_features = build_price_features(current_data)
        feature_cols = [
            col for col in df_features.columns if col not in self._NON_FEATURE_COLUMNS
        ]

        if region and "region" in df_features.columns:
            df_features = df_features[df_features["region"] == region]

        # Without this guard an empty frame (no input rows, or a region that
        # matches nothing) reaches the model as a zero-row array and fails
        # with an unrelated-looking error.
        if len(df_features) == 0:
            region_note = f" and region {region!r}" if region else ""
            raise ValueError(
                f"No feature rows available for horizon {horizon}{region_note}"
            )

        # iloc[-1:] keeps the result two-dimensional (one row).
        latest_row = df_features.iloc[-1:][feature_cols]
        prediction = model.predict(latest_row.values)
        return float(prediction[0])

    def predict_all_horizons(
        self, current_data: pd.DataFrame, region: Optional[str] = None
    ) -> Dict[int, Optional[float]]:
        """Run ``predict`` for every loaded horizon, in ascending horizon order.

        This is best-effort: a failure for one horizon is logged and recorded
        as ``None`` so the remaining horizons are still attempted.

        Returns:
            Mapping of horizon to its prediction, or ``None`` where the
            prediction failed.
        """
        predictions: Dict[int, Optional[float]] = {}
        for horizon in sorted(self.models):
            try:
                predictions[horizon] = self.predict(current_data, horizon, region)
            except Exception as e:
                logger.error(f"Failed to predict for horizon {horizon}: {e}")
                predictions[horizon] = None
        return predictions

    def predict_with_confidence(
        self, current_data: pd.DataFrame, horizon: int = 15, region: Optional[str] = None
    ) -> Dict:
        """Return a prediction together with a fixed +/-5% band around it.

        The band is a simple multiplier of the point prediction, not a
        statistically derived interval.

        Returns:
            Dict with keys ``prediction``, ``confidence_lower``,
            ``confidence_upper`` and ``horizon``.

        Raises:
            ValueError: Propagated from ``predict``.
        """
        prediction = self.predict(current_data, horizon, region)

        # For a negative prediction, 0.95 * p is greater than 1.05 * p, so the
        # two candidates must be ordered to keep lower <= upper. Non-negative
        # predictions get exactly the same values as before.
        lower, upper = sorted((prediction * 0.95, prediction * 1.05))

        return {
            "prediction": prediction,
            "confidence_lower": lower,
            "confidence_upper": upper,
            "horizon": horizon,
        }

    def get_feature_importance(self, horizon: int) -> pd.DataFrame:
        """Return the model's feature importances, highest first.

        Args:
            horizon: Key of the model in ``self.models`` to inspect.

        Returns:
            DataFrame with columns ``feature`` and ``importance``, built from
            the model's ``feature_names`` and ``feature_importances_`` and
            sorted by ``importance`` in descending order.

        Raises:
            ValueError: If no model is loaded for ``horizon``.
        """
        if horizon not in self.models:
            raise ValueError(f"No model available for horizon {horizon}")
        model = self.models[horizon]

        importances = model.feature_importances_
        feature_names = model.feature_names

        return pd.DataFrame(
            {
                "feature": feature_names,
                "importance": importances,
            }
        ).sort_values("importance", ascending=False)
|
|
|
|
|
|
# Public API of this module: only PricePredictor is exported by a star-import.
__all__ = ["PricePredictor"]
|