Files
energy-trade/backend/app/ml/price_prediction/trainer.py
kbt-devops fe76bc7629 Add FastAPI backend for energy trading system
Implements FastAPI backend with ML model support for energy trading,
including price prediction models and RL-based battery trading policy.
Features dashboard, trading, backtest, and settings API routes with
WebSocket support for real-time updates.
2026-02-12 00:59:26 +07:00

143 lines
4.6 KiB
Python

from typing import List, Dict, Tuple, Optional
from pathlib import Path
import pandas as pd
from app.ml.price_prediction.model import PricePredictionModel
from app.utils.logger import get_logger
# Module-level logger obtained from the project's logging helper, named after this module.
logger = get_logger(__name__)
class PricePredictionTrainer:
    """Train, evaluate, save and reload one price model per forecast horizon.

    Typical flow: ``load_data`` -> ``prepare_data`` -> ``train_for_horizon``
    (or ``train_all`` for several horizons) -> ``save_models``. Trained models
    are kept in ``self.models`` keyed by horizon.
    """

    # Regions whose ``<region>_processed.parquet`` files ``load_data`` reads.
    REGIONS: Tuple[str, ...] = ("FR", "BE", "DE", "NL", "UK")
    # Horizons used by ``train_all`` / ``load_models`` when none are given.
    DEFAULT_HORIZONS: Tuple[int, ...] = (1, 5, 15, 60)
    # Columns excluded from the feature matrix (identifiers and raw price columns).
    _NON_FEATURE_COLUMNS = frozenset(
        {"timestamp", "region", "day_ahead_price", "real_time_price"}
    )

    def __init__(self, config=None):
        """Create an empty trainer.

        Args:
            config: Optional configuration object. It is stored on the instance
                but not read by any method of this class.
        """
        self.config = config
        self.data: Optional[pd.DataFrame] = None
        self.models: Dict[int, PricePredictionModel] = {}

    def load_data(self, data_path: Optional[str] = None) -> Optional[pd.DataFrame]:
        """Load and concatenate the processed per-region parquet files.

        Args:
            data_path: Directory holding ``<region>_processed.parquet`` files
                (``~`` is expanded). Defaults to
                ``~/energy-test-data/data/processed``.

        Returns:
            The concatenated frame, with an added ``region`` column; it is also
            stored in ``self.data``. Missing region files are skipped. If no
            file is found, ``self.data`` is left unchanged and returned, so the
            result is ``None`` on a fresh trainer.
        """
        if data_path is None:
            data_path = "~/energy-test-data/data/processed"
        path = Path(data_path).expanduser()
        dfs = []
        for region in self.REGIONS:
            file_path = path / f"{region.lower()}_processed.parquet"
            if file_path.exists():
                df = pd.read_parquet(file_path)
                df["region"] = region
                dfs.append(df)
        if dfs:
            self.data = pd.concat(dfs, ignore_index=True)
            logger.info(f"Loaded data: {len(self.data)} rows")
        else:
            logger.warning(f"No processed parquet files found in {path}")
        return self.data

    def prepare_data(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
        """Build model features and select the feature columns.

        Features come from ``build_price_features``. Any row that contains a
        NaN in any column is then dropped.

        Returns:
            ``(df_features, feature_cols)``. ``feature_cols`` is every column
            except ``timestamp``, ``region``, ``day_ahead_price`` and
            ``real_time_price``, in column order.
        """
        from app.ml.features import build_price_features

        df_features = build_price_features(df).dropna()
        feature_cols = [
            col for col in df_features.columns if col not in self._NON_FEATURE_COLUMNS
        ]
        return df_features, feature_cols

    @staticmethod
    def _features_and_target(
        frame: pd.DataFrame, feature_cols: List[str], horizon: int
    ) -> Tuple[pd.DataFrame, pd.Series]:
        """Return ``(X, y)`` where ``y`` is ``real_time_price`` ``horizon`` rows ahead.

        The shift is done per ``region`` when that column is present. The
        frames passed here are sorted by timestamp across regions, so a plain
        shift would label a row with another region's price. Rows whose
        future price lies outside ``frame`` are dropped.
        """
        if "region" in frame.columns:
            target = frame.groupby("region", sort=False)["real_time_price"].shift(-horizon)
        else:
            # NOTE(review): without a region column the rows cannot be separated
            # per market; confirm build_price_features keeps "region".
            target = frame["real_time_price"].shift(-horizon)
        target = target.dropna()
        features = frame[feature_cols].loc[target.index]
        return features, target

    def train_for_horizon(
        self, df_features: pd.DataFrame, feature_cols: List[str], horizon: int
    ) -> Dict:
        """Train and validate a model for a single horizon.

        The data is sorted by ``timestamp`` and split chronologically: the
        first 70% of rows are used for training and the next 15% for
        validation. The final 15% is not used here.

        NOTE(review): ``horizon`` is applied as a shift of ``horizon`` rows per
        region while being logged as minutes. That only matches if the data
        has one row per minute per region — TODO confirm the sampling interval.

        Args:
            df_features: Output frame of ``prepare_data``.
            feature_cols: Feature column names from ``prepare_data``.
            horizon: Number of rows ahead to predict.

        Returns:
            Dict with keys ``horizon``, ``mae``, ``rmse``, ``r2`` (validation
            metrics), ``n_train`` and ``n_val``.

        Raises:
            ValueError: If the training or validation split is empty after the
                target shift.
        """
        logger.info(f"Training model for {horizon} minute horizon")
        # Stable sort so rows sharing a timestamp keep a deterministic order.
        df_features = df_features.sort_values("timestamp", kind="mergesort")
        n_total = len(df_features)
        train_end = int(n_total * 0.70)
        val_end = int(n_total * 0.85)  # end index of validation, not its size
        train_data = df_features.iloc[:train_end]
        val_data = df_features.iloc[train_end:val_end]

        X_train, y_train = self._features_and_target(train_data, feature_cols, horizon)
        X_val, y_val = self._features_and_target(val_data, feature_cols, horizon)
        if X_train.empty or X_val.empty:
            raise ValueError(
                f"Not enough data for horizon {horizon}: "
                f"{len(X_train)} training rows, {len(X_val)} validation rows"
            )

        model = PricePredictionModel(horizon=horizon)
        model.fit(X_train, y_train)
        val_preds = model.predict(X_val)

        from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

        mae = mean_absolute_error(y_val, val_preds)
        # ``squared=False`` was deprecated in scikit-learn 1.4 and removed in
        # 1.6; the square root of the MSE works on every version.
        rmse = mean_squared_error(y_val, val_preds) ** 0.5
        r2 = r2_score(y_val, val_preds)

        self.models[horizon] = model
        results = {
            "horizon": horizon,
            "mae": mae,
            "rmse": rmse,
            "r2": r2,
            "n_train": len(X_train),
            "n_val": len(X_val),
        }
        logger.info(f"Training complete for {horizon}m: MAE={mae:.2f}, RMSE={rmse:.2f}, R2={r2:.3f}")
        return results

    def train_all(self, horizons: Optional[List[int]] = None) -> Dict:
        """Train one model per horizon, loading data first if needed.

        Failures are handled per horizon. A failing horizon is logged and
        recorded as ``{"error": message}`` so the remaining horizons still
        train.

        Args:
            horizons: Horizons to train; defaults to ``DEFAULT_HORIZONS``.

        Returns:
            Mapping of horizon to its ``train_for_horizon`` result or error dict.

        Raises:
            FileNotFoundError: If no data is loaded and ``load_data`` finds none.
        """
        if horizons is None:
            horizons = list(self.DEFAULT_HORIZONS)
        if self.data is None:
            self.load_data()
        if self.data is None:
            raise FileNotFoundError(
                "No processed price data found; call load_data() with a valid data_path"
            )
        df_features, feature_cols = self.prepare_data(self.data)
        all_results = {}
        for horizon in horizons:
            try:
                all_results[horizon] = self.train_for_horizon(df_features, feature_cols, horizon)
            except Exception as e:  # best-effort: keep training the other horizons
                logger.error(f"Failed to train for horizon {horizon}: {e}")
                all_results[horizon] = {"error": str(e)}
        return all_results

    def save_models(self, output_dir: str = "models/price_prediction") -> None:
        """Save every trained model as ``<output_dir>/model_<horizon>min.pkl``.

        The directory is created if needed. Serialization is delegated to
        ``PricePredictionModel.save``.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        for horizon, model in self.models.items():
            filepath = output_path / f"model_{horizon}min.pkl"
            model.save(filepath)
            logger.info(f"Saved model for {horizon}m to {filepath}")

    @classmethod
    def load_models(cls, models_dir: str = "models/price_prediction", horizons: Optional[List[int]] = None) -> Dict[int, PricePredictionModel]:
        """Load saved models for the requested horizons.

        Horizons without a ``model_<horizon>min.pkl`` file are skipped.
        NOTE(review): the ``.pkl`` extension suggests pickle. Only load model
        files from trusted locations — verify ``PricePredictionModel.load``.

        Args:
            models_dir: Directory written by ``save_models``.
            horizons: Horizons to look for; defaults to ``DEFAULT_HORIZONS``.

        Returns:
            Mapping of horizon to its loaded model.
        """
        models = {}
        path = Path(models_dir)
        if horizons is None:
            horizons = list(cls.DEFAULT_HORIZONS)
        for horizon in horizons:
            filepath = path / f"model_{horizon}min.pkl"
            if filepath.exists():
                model = PricePredictionModel.load(filepath)
                models[horizon] = model
                logger.info(f"Loaded model for {horizon}m")
        return models
# Names exported by ``from ... import *``: only the trainer class is public.
__all__ = ["PricePredictionTrainer"]