# Implements a FastAPI backend with ML model support for energy trading, including
# price-prediction models and an RL-based battery-trading policy. Features
# dashboard, trading, backtest, and settings API routes with WebSocket support
# for real-time updates.
# (53 lines, 1.4 KiB, Python)
import pickle
|
|
from typing import Optional
|
|
import xgboost as xgb
|
|
import numpy as np
|
|
|
|
|
|
class PricePredictionModel:
    """XGBoost regressor wrapper for price prediction at a fixed horizon.

    Keeps the forecast horizon, an identifier, the trained
    ``xgb.XGBRegressor`` and the feature names seen during :meth:`fit`.

    Args:
        horizon: Forecast horizon. Presumably minutes, given the ``"m"``
            suffix of the default ``model_id`` -- TODO confirm with callers.
        model_id: Identifier for this model. Defaults to
            ``"price_prediction_{horizon}m"`` when omitted or empty.
        params: Optional XGBoost hyperparameter overrides. They are merged
            on top of ``DEFAULT_PARAMS`` when :meth:`fit` builds the regressor.
    """

    # Hyperparameters used for every fit unless overridden via ``params``.
    DEFAULT_PARAMS = {
        "n_estimators": 200,
        "max_depth": 6,
        "learning_rate": 0.1,
        "subsample": 0.8,
        "colsample_bytree": 0.8,
        "random_state": 42,
    }

    def __init__(
        self,
        horizon: int,
        model_id: Optional[str] = None,
        params: Optional[dict] = None,
    ):
        self.horizon = horizon
        self.model_id = model_id or f"price_prediction_{horizon}m"
        # String annotation so constructing the wrapper never evaluates ``xgb``.
        self.model: Optional["xgb.XGBRegressor"] = None
        self.feature_names: list = []
        # Copy so later mutation of the caller's dict cannot alter this model.
        self.params = dict(params) if params else {}

    def _xgb_params(self) -> dict:
        """Return ``DEFAULT_PARAMS`` merged with this instance's overrides."""
        # getattr: instances pickled by earlier versions of this class have no
        # ``params`` attribute; treat that as "no overrides".
        overrides = getattr(self, "params", None) or {}
        return {**self.DEFAULT_PARAMS, **overrides}

    @staticmethod
    def _infer_feature_names(X) -> list:
        """Return DataFrame-like column names, or ``feature_{i}`` placeholders.

        Raises:
            ValueError: If ``X`` has no ``columns`` attribute and is not 2-D.
        """
        columns = getattr(X, "columns", None)
        if columns is not None:
            return list(columns)
        # np.shape uses ``X.shape`` when present (ndarray, sparse matrices) and
        # falls back to converting sequences such as lists of lists.
        shape = np.shape(X)
        if len(shape) != 2:
            raise ValueError(f"X must be 2-dimensional, got shape {shape}")
        return [f"feature_{i}" for i in range(shape[1])]

    def _require_model(self):
        """Return the trained regressor, or raise if :meth:`fit` has not succeeded.

        Raises:
            ValueError: If no model has been trained.
        """
        if self.model is None:
            raise ValueError("Model not trained")
        return self.model

    def fit(self, X, y):
        """Train a fresh XGBoost regressor on ``(X, y)``.

        Args:
            X: 2-D feature matrix. Column names are taken from ``X.columns``
                when available; otherwise ``feature_0 ... feature_{k-1}`` are used.
            y: Regression targets passed straight to ``XGBRegressor.fit``.

        Returns:
            ``self``, to allow chaining (e.g. ``model.fit(X, y).predict(X2)``).

        Raises:
            ValueError: If ``X`` is not DataFrame-like and not 2-dimensional.
        """
        feature_names = self._infer_feature_names(X)
        regressor = xgb.XGBRegressor(**self._xgb_params())
        regressor.fit(X, y)
        # Commit state only after a successful fit. A failure then leaves any
        # previously trained model (or the untrained state) untouched.
        self.model = regressor
        self.feature_names = feature_names
        return self

    def predict(self, X):
        """Return the trained regressor's predictions for ``X``.

        Raises:
            ValueError: If the model has not been trained.
        """
        return self._require_model().predict(X)

    def save(self, filepath: str):
        """Pickle this whole wrapper (regressor, horizon, id, feature names) to ``filepath``."""
        with open(filepath, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, filepath: str):
        """Load a model previously written by :meth:`save`.

        Security: ``pickle.load`` can execute arbitrary code embedded in the
        file. Only load files from trusted sources.

        Raises:
            TypeError: If the file does not contain an instance of ``cls``.
        """
        with open(filepath, "rb") as f:
            obj = pickle.load(f)
        if not isinstance(obj, cls):
            raise TypeError(
                f"Expected {cls.__name__} in {filepath!r}, got {type(obj).__name__}"
            )
        return obj

    @property
    def feature_importances_(self):
        """Feature importances reported by the trained regressor.

        Raises:
            ValueError: If the model has not been trained.
        """
        return self._require_model().feature_importances_
|
|
|
|
|
|
# Public API of this module: the names exported by a star-import.
__all__ = [
    "PricePredictionModel",
]
|