from typing import List, Dict, Tuple, Optional
from pathlib import Path

import pandas as pd

from app.ml.price_prediction.model import PricePredictionModel
from app.utils.logger import get_logger

logger = get_logger(__name__)


class PricePredictionTrainer:
    """Train, persist and reload one ``PricePredictionModel`` per forecast horizon.

    Trained models are kept in ``self.models`` keyed by horizon; the raw
    training frame loaded by :meth:`load_data` is kept in ``self.data``.
    """

    # Region codes whose "<region>_processed.parquet" files load_data looks for.
    REGIONS: Tuple[str, ...] = ("FR", "BE", "DE", "NL", "UK")
    # Horizons used by train_all / load_models when the caller passes none.
    DEFAULT_HORIZONS: Tuple[int, ...] = (1, 5, 15, 60)
    # Directory used by load_data when no data_path is given.
    DEFAULT_DATA_PATH = "~/energy-test-data/data/processed"
    # Chronological split: rows [0, 70%) train, [70%, 85%) validation.
    # The final 15% is never touched here (presumably a held-out test set).
    TRAIN_FRACTION = 0.70
    VAL_END_FRACTION = 0.85
    # Columns that are identifiers or prices and therefore never used as features.
    NON_FEATURE_COLUMNS: Tuple[str, ...] = (
        "timestamp",
        "region",
        "day_ahead_price",
        "real_time_price",
    )

    def __init__(self, config=None):
        """Create an empty trainer.

        Args:
            config: Opaque configuration object; stored on the instance but
                not read by any method of this class.
        """
        self.config = config
        self.data: Optional[pd.DataFrame] = None
        self.models: Dict[int, PricePredictionModel] = {}

    def load_data(self, data_path: Optional[str] = None) -> Optional[pd.DataFrame]:
        """Load and concatenate the per-region processed parquet files.

        For every code in ``REGIONS`` the file
        ``<data_path>/<region lower-case>_processed.parquet`` is read if it
        exists and tagged with a ``region`` column. Missing files are skipped.

        Args:
            data_path: Directory holding the parquet files (``~`` is
                expanded). Defaults to ``DEFAULT_DATA_PATH``.

        Returns:
            The concatenated frame, which is also stored in ``self.data``.
            When no file is found ``self.data`` is left untouched and
            returned as-is, so the result is ``None`` on a fresh trainer.
        """
        if data_path is None:
            data_path = self.DEFAULT_DATA_PATH
        base_dir = Path(data_path).expanduser()

        frames = []
        for region in self.REGIONS:
            file_path = base_dir / f"{region.lower()}_processed.parquet"
            if not file_path.exists():
                continue
            region_df = pd.read_parquet(file_path)
            region_df["region"] = region
            frames.append(region_df)

        if frames:
            self.data = pd.concat(frames, ignore_index=True)
            logger.info(f"Loaded data: {len(self.data)} rows")
        else:
            # Make the empty result visible instead of silently returning None.
            logger.warning(f"No processed parquet files found under {base_dir}")
        return self.data

    def prepare_data(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
        """Build model features and pick the feature columns.

        Args:
            df: Raw frame, passed unchanged to ``build_price_features``.

        Returns:
            ``(df_features, feature_cols)``: the feature frame with every
            row containing a NaN removed, and the names of all its columns
            except those in ``NON_FEATURE_COLUMNS``.
        """
        # Imported lazily, as in the original code (keeps module import light).
        from app.ml.features import build_price_features

        df_features = build_price_features(df).dropna()
        feature_cols = [
            col for col in df_features.columns if col not in self.NON_FEATURE_COLUMNS
        ]
        return df_features, feature_cols

    @staticmethod
    def _features_and_target(
        frame: pd.DataFrame, feature_cols: List[str], horizon: int
    ) -> Tuple[pd.DataFrame, pd.Series]:
        """Return aligned ``(X, y)`` where ``y`` is the price ``horizon`` rows ahead.

        When a ``region`` column is present the shift is done inside each
        region, so a row's target is never another region's price. Rows
        whose target falls outside the frame are dropped from both X and y.
        """
        if "region" in frame.columns:
            target = frame.groupby("region", sort=False)["real_time_price"].shift(-horizon)
        else:
            target = frame["real_time_price"].shift(-horizon)
        target = target.dropna()
        return frame.loc[target.index, feature_cols], target

    def train_for_horizon(
        self, df_features: pd.DataFrame, feature_cols: List[str], horizon: int
    ) -> Dict:
        """Fit and validate a model that predicts ``real_time_price`` ahead.

        The frame is sorted by ``timestamp`` and split chronologically
        (``TRAIN_FRACTION`` / ``VAL_END_FRACTION``). The fitted model is
        stored in ``self.models[horizon]``.

        Args:
            df_features: Feature frame containing ``timestamp``,
                ``real_time_price`` and every column in ``feature_cols``.
            feature_cols: Names of the model input columns.
            horizon: Number of rows to look ahead for the target.
                NOTE(review): logged as minutes, which assumes one row per
                minute (per region) — confirm against the data frequency.

        Returns:
            Dict with keys ``horizon``, ``mae``, ``rmse``, ``r2``,
            ``n_train`` and ``n_val``.

        Raises:
            ValueError: If the train or validation split has no usable rows.
        """
        logger.info(f"Training model for {horizon} minute horizon")

        # Stable sort: rows sharing a timestamp keep their incoming order,
        # which makes the chronological split deterministic.
        ordered = df_features.sort_values("timestamp", kind="stable")
        n_total = len(ordered)
        train_end = int(n_total * self.TRAIN_FRACTION)
        val_end = int(n_total * self.VAL_END_FRACTION)

        # Targets are built inside each split so no training target is taken
        # from the validation period.
        X_train, y_train = self._features_and_target(
            ordered.iloc[:train_end], feature_cols, horizon
        )
        X_val, y_val = self._features_and_target(
            ordered.iloc[train_end:val_end], feature_cols, horizon
        )

        if X_train.empty or X_val.empty:
            raise ValueError(
                f"Not enough data for horizon {horizon}: "
                f"{len(X_train)} training rows, {len(X_val)} validation rows"
            )

        model = PricePredictionModel(horizon=horizon)
        model.fit(X_train, y_train)
        val_preds = model.predict(X_val)

        from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

        mae = mean_absolute_error(y_val, val_preds)
        # Take the root explicitly: the ``squared=False`` keyword is
        # deprecated and later removed in newer scikit-learn releases.
        rmse = mean_squared_error(y_val, val_preds) ** 0.5
        r2 = r2_score(y_val, val_preds)

        self.models[horizon] = model

        results = {
            "horizon": horizon,
            "mae": mae,
            "rmse": rmse,
            "r2": r2,
            "n_train": len(X_train),
            "n_val": len(X_val),
        }
        logger.info(
            f"Training complete for {horizon}m: MAE={mae:.2f}, RMSE={rmse:.2f}, R2={r2:.3f}"
        )
        return results

    def train_all(self, horizons: Optional[List[int]] = None) -> Dict:
        """Train one model per horizon and collect the validation metrics.

        Loads data via :meth:`load_data` (default path) when ``self.data``
        is still ``None``. A failure for one horizon is logged and recorded;
        the remaining horizons are still trained.

        Args:
            horizons: Horizons to train. Defaults to ``DEFAULT_HORIZONS``.

        Returns:
            Dict mapping each horizon to its :meth:`train_for_horizon`
            result, or to ``{"error": <message>}`` if training raised.

        Raises:
            ValueError: If no data is loaded and none could be found.
        """
        if horizons is None:
            horizons = list(self.DEFAULT_HORIZONS)

        if self.data is None:
            self.load_data()
        if self.data is None:
            # Fail here with a clear message rather than inside feature building.
            raise ValueError(
                "No training data available: load_data() found no processed parquet files"
            )

        df_features, feature_cols = self.prepare_data(self.data)

        all_results = {}
        for horizon in horizons:
            try:
                all_results[horizon] = self.train_for_horizon(
                    df_features, feature_cols, horizon
                )
            except Exception as e:
                # Deliberately broad: one failing horizon must not abort the rest.
                logger.error(f"Failed to train for horizon {horizon}: {e}")
                all_results[horizon] = {"error": str(e)}
        return all_results

    def save_models(self, output_dir: str = "models/price_prediction") -> None:
        """Save every trained model to ``<output_dir>/model_<horizon>min.pkl``.

        Args:
            output_dir: Target directory; created (with parents) if missing.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        for horizon, model in self.models.items():
            filepath = output_path / f"model_{horizon}min.pkl"
            model.save(filepath)
            logger.info(f"Saved model for {horizon}m to {filepath}")

    @classmethod
    def load_models(
        cls,
        models_dir: str = "models/price_prediction",
        horizons: Optional[List[int]] = None,
    ) -> Dict[int, PricePredictionModel]:
        """Load previously saved models from ``models_dir``.

        Args:
            models_dir: Directory containing ``model_<horizon>min.pkl`` files.
            horizons: Horizons to look for. Defaults to ``DEFAULT_HORIZONS``.

        Returns:
            Dict mapping horizon to the loaded model. Horizons whose file
            does not exist are omitted without error.
        """
        if horizons is None:
            horizons = list(cls.DEFAULT_HORIZONS)

        base_dir = Path(models_dir)
        models = {}
        for horizon in horizons:
            filepath = base_dir / f"model_{horizon}min.pkl"
            if not filepath.exists():
                continue
            # NOTE(review): the .pkl suffix suggests pickle-based storage —
            # confirm; if so, only load from directories you trust.
            models[horizon] = PricePredictionModel.load(filepath)
            logger.info(f"Loaded model for {horizon}m")
        return models


__all__ = ["PricePredictionTrainer"]