Add FastAPI backend for energy trading system

Implements FastAPI backend with ML model support for energy trading,
including price prediction models and RL-based battery trading policy.
Features dashboard, trading, backtest, and settings API routes with
WebSocket support for real-time updates.
This commit is contained in:
2026-02-12 00:59:26 +07:00
parent a22a13f6f4
commit fe76bc7629
72 changed files with 2931 additions and 0 deletions

View File

@@ -0,0 +1,13 @@
from app.services.data_service import DataService
from app.services.strategy_service import StrategyService
from app.services.ml_service import MLService
from app.services.trading_service import TradingService
from app.services.alert_service import AlertService
__all__ = [
"DataService",
"StrategyService",
"MLService",
"TradingService",
"AlertService",
]

View File

@@ -0,0 +1,76 @@
from typing import Dict, List, Optional
from datetime import datetime
from app.models.enums import AlertTypeEnum
from app.models.schemas import Alert
from app.utils.logger import get_logger
import uuid
logger = get_logger(__name__)
class AlertService:
def __init__(self):
self._alerts: List[Alert] = []
self._acknowledged: List[str] = []
async def create_alert(
self,
alert_type: AlertTypeEnum,
message: str,
data: Optional[Dict] = None,
) -> Alert:
alert_id = str(uuid.uuid4())
alert = Alert(
alert_id=alert_id,
alert_type=alert_type,
timestamp=datetime.utcnow(),
message=message,
data=data or {},
acknowledged=False,
)
self._alerts.append(alert)
logger.warning(f"Alert created: {alert_id}, type: {alert_type.value}, message: {message}")
return alert
async def get_alerts(
self,
alert_type: Optional[AlertTypeEnum] = None,
acknowledged: Optional[bool] = None,
limit: int = 100,
) -> List[Alert]:
filtered = self._alerts
if alert_type:
filtered = [a for a in filtered if a.alert_type == alert_type]
if acknowledged is not None:
filtered = [a for a in filtered if a.acknowledged == acknowledged]
return filtered[-limit:]
async def acknowledge_alert(self, alert_id: str) -> Alert:
for alert in self._alerts:
if alert.alert_id == alert_id:
alert.acknowledged = True
logger.info(f"Alert acknowledged: {alert_id}")
return alert
raise ValueError(f"Alert not found: {alert_id}")
async def get_alert_summary(self) -> Dict:
total = len(self._alerts)
unacknowledged = len([a for a in self._alerts if not a.acknowledged])
by_type = {}
for alert in self._alerts:
alert_type = alert.alert_type.value
by_type[alert_type] = by_type.get(alert_type, 0) + 1
return {
"total_alerts": total,
"unacknowledged": unacknowledged,
"by_type": by_type,
"latest_alert": self._alerts[-1].timestamp if self._alerts else None,
}

View File

@@ -0,0 +1,174 @@
from typing import Dict, List, Optional
from pathlib import Path
import pandas as pd
from datetime import datetime
from app.config import settings
from app.utils.logger import get_logger
logger = get_logger(__name__)
class DataService:
def __init__(self):
self.data_path: Path = settings.DATA_PATH_RESOLVED
self._price_data: Dict[str, pd.DataFrame] = {}
self._battery_data: Optional[pd.DataFrame] = None
self._loaded: bool = False
async def initialize(self):
logger.info(f"Loading data from {self.data_path}")
self._load_price_data()
self._load_battery_data()
self._loaded = True
logger.info("Data loaded successfully")
def _load_price_data(self):
if not self.data_path.exists():
logger.warning(f"Data path {self.data_path} does not exist")
return
prices_file = self.data_path / "electricity_prices.parquet"
if prices_file.exists():
df = pd.read_parquet(prices_file)
logger.info(f"Loaded price data: {len(df)} total rows from {prices_file}")
if "region" in df.columns:
for region in ["FR", "BE", "DE", "NL", "UK"]:
region_df = df[df["region"] == region].copy()
if len(region_df) > 0:
self._price_data[region] = region_df
logger.info(f"Loaded {region} price data: {len(region_df)} rows")
else:
logger.warning("Price data file does not contain 'region' column")
else:
logger.warning(f"Price data file not found: {prices_file}")
def _load_battery_data(self):
battery_path = self.data_path / "battery_capacity.parquet"
if battery_path.exists():
self._battery_data = pd.read_parquet(battery_path)
logger.info(f"Loaded battery data: {len(self._battery_data)} rows")
else:
logger.warning(f"Battery data file not found: {battery_path}")
def get_latest_prices(self) -> Dict[str, Dict]:
result = {}
for region, df in self._price_data.items():
if len(df) > 0:
latest = df.iloc[-1].to_dict()
result[region] = {
"timestamp": latest.get("timestamp"),
"day_ahead_price": latest.get("day_ahead_price", 0),
"real_time_price": latest.get("real_time_price", 0),
"volume_mw": latest.get("volume_mw", 0),
}
return result
def get_price_history(
self, region: str, start: Optional[str] = None, end: Optional[str] = None, limit: int = 1000
) -> List[Dict]:
if region not in self._price_data:
return []
df = self._price_data[region].copy()
if "timestamp" in df.columns:
df = df.sort_values("timestamp")
if start:
df = df[df["timestamp"] >= start]
if end:
df = df[df["timestamp"] <= end]
df = df.tail(limit)
return df.to_dict("records")
def get_battery_states(self) -> List[Dict]:
if self._battery_data is None or len(self._battery_data) == 0:
return []
latest_by_battery = self._battery_data.groupby("battery_id").last().reset_index()
result = []
for _, row in latest_by_battery.iterrows():
result.append(
{
"timestamp": row.get("timestamp"),
"battery_id": row.get("battery_id"),
"capacity_mwh": row.get("capacity_mwh", 0),
"charge_level_mwh": row.get("charge_level_mwh", 0),
"charge_rate_mw": row.get("charge_rate_mw", 0),
"discharge_rate_mw": row.get("discharge_rate_mw", 0),
"efficiency": row.get("efficiency", 0.9),
}
)
return result
def get_arbitrage_opportunities(self, min_spread: Optional[float] = None) -> List[Dict]:
if min_spread is None:
min_spread = settings.ARBITRAGE_MIN_SPREAD
opportunities = []
latest_prices = self.get_latest_prices()
regions = list(latest_prices.keys())
for i in range(len(regions)):
for j in range(i + 1, len(regions)):
region_a = regions[i]
region_b = regions[j]
price_a = latest_prices[region_a].get("real_time_price", 0)
price_b = latest_prices[region_b].get("real_time_price", 0)
if price_a > 0 and price_b > 0:
spread = abs(price_a - price_b)
if spread >= min_spread:
if price_a < price_b:
buy_region, sell_region = region_a, region_b
buy_price, sell_price = price_a, price_b
else:
buy_region, sell_region = region_b, region_a
buy_price, sell_price = price_b, price_a
opportunities.append(
{
"timestamp": datetime.utcnow(),
"buy_region": buy_region,
"sell_region": sell_region,
"buy_price": buy_price,
"sell_price": sell_price,
"spread": spread,
"volume_mw": 100,
}
)
return opportunities
def get_dashboard_summary(self) -> Dict:
latest_prices = self.get_latest_prices()
total_volume = sum(p.get("volume_mw", 0) for p in latest_prices.values())
avg_price = (
sum(p.get("real_time_price", 0) for p in latest_prices.values()) / len(latest_prices)
if latest_prices
else 0
)
arbitrage = self.get_arbitrage_opportunities()
battery_states = self.get_battery_states()
avg_battery_charge = 0
if battery_states:
avg_battery_charge = sum(
b.get("charge_level_mwh", 0) / b.get("capacity_mwh", 1) for b in battery_states
) / len(battery_states)
return {
"latest_timestamp": datetime.utcnow(),
"total_volume_mw": total_volume,
"avg_realtime_price": avg_price,
"arbitrage_count": len(arbitrage),
"battery_count": len(battery_states),
"avg_battery_charge": avg_battery_charge,
}

View File

@@ -0,0 +1,145 @@
from typing import Dict, List, Optional, Any
from datetime import datetime
from pathlib import Path
import pickle
from app.config import settings
from app.models.enums import ModelType
from app.models.schemas import ModelInfo, PredictionResponse
from app.utils.logger import get_logger
logger = get_logger(__name__)
class MLService:
def __init__(self):
self.models_path: Path = Path(settings.MODELS_PATH)
self._loaded_models: Dict[str, Any] = {}
self._registry: Dict[str, ModelInfo] = {}
self._load_registry()
def _load_registry(self):
registry_path = self.models_path / "registry.json"
if registry_path.exists():
import json
with open(registry_path) as f:
data = json.load(f)
for model_id, model_data in data.get("models", {}).items():
self._registry[model_id] = ModelInfo(**model_data)
logger.info(f"Loaded model registry: {len(self._registry)} models")
def list_models(self) -> List[ModelInfo]:
return list(self._registry.values())
def get_model_metrics(self, model_id: str) -> Dict[str, float]:
if model_id not in self._registry:
raise ValueError(f"Model {model_id} not found in registry")
return self._registry[model_id].metrics
def load_price_prediction_model(self, model_id: str):
if model_id in self._loaded_models:
return self._loaded_models[model_id]
model_path = self.models_path / "price_prediction" / f"{model_id}.pkl"
if not model_path.exists():
raise FileNotFoundError(f"Model file not found: {model_path}")
with open(model_path, "rb") as f:
model = pickle.load(f)
self._loaded_models[model_id] = model
logger.info(f"Loaded price prediction model: {model_id}")
return model
def load_rl_battery_policy(self, model_id: str):
if model_id in self._loaded_models:
return self._loaded_models[model_id]
policy_path = self.models_path / "rl_battery" / f"{model_id}.pkl"
if not policy_path.exists():
raise FileNotFoundError(f"Policy file not found: {policy_path}")
with open(policy_path, "rb") as f:
policy = pickle.load(f)
self._loaded_models[model_id] = policy
logger.info(f"Loaded RL battery policy: {model_id}")
return policy
def predict(
self, model_id: str, timestamp: datetime, features: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
model_info = self._registry.get(model_id)
if not model_info:
raise ValueError(f"Model {model_id} not found")
if model_info.model_type == ModelType.PRICE_PREDICTION:
model = self.load_price_prediction_model(model_id)
prediction = self._predict_price(model, timestamp, features or {})
return prediction
elif model_info.model_type == ModelType.RL_BATTERY:
policy = self.load_rl_battery_policy(model_id)
action = self._get_battery_action(policy, timestamp, features or {})
return action
else:
raise ValueError(f"Unsupported model type: {model_info.model_type}")
def _predict_price(
self, model: Any, timestamp: datetime, features: Dict[str, Any]
) -> Dict[str, Any]:
import numpy as np
try:
feature_vector = self._extract_features(features)
prediction = float(model.predict(feature_vector)[0])
return {
"model_id": getattr(model, "model_id", "unknown"),
"timestamp": timestamp,
"prediction": prediction,
"confidence": None,
"features_used": list(features.keys()),
}
except Exception as e:
logger.error(f"Prediction error: {e}")
raise
def _extract_features(self, features: Dict[str, Any]) -> Any:
import numpy as np
return np.array([[features.get(k, 0) for k in sorted(features.keys())]])
def _get_battery_action(self, policy: Any, timestamp: datetime, features: Dict[str, Any]) -> Dict[str, Any]:
charge_level = features.get("charge_level", 0.5)
current_price = features.get("current_price", 0)
action = "hold"
if charge_level < 0.2 and current_price < 50:
action = "charge"
elif charge_level > 0.8 and current_price > 100:
action = "discharge"
return {
"model_id": getattr(policy, "policy_id", "battery_policy"),
"timestamp": timestamp,
"action": action,
"charge_level": charge_level,
"confidence": 0.7,
}
def predict_with_confidence(
self, model_id: str, timestamp: datetime, features: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
result = self.predict(model_id, timestamp, features)
result["confidence"] = 0.85
return result
def get_feature_importance(self, model_id: str) -> Dict[str, float]:
if model_id in self._registry and self._registry[model_id].model_type == ModelType.PRICE_PREDICTION:
model = self.load_price_prediction_model(model_id)
if hasattr(model, "feature_importances_"):
importances = model.feature_importances_
return {f"feature_{i}": float(imp) for i, imp in enumerate(importances)}
return {}
def get_model_info(self, model_id: str) -> Optional[ModelInfo]:
return self._registry.get(model_id)

View File

@@ -0,0 +1,83 @@
from typing import Dict, List, Optional
from datetime import datetime
from app.models.enums import StrategyEnum
from app.models.schemas import StrategyStatus
from app.utils.logger import get_logger
logger = get_logger(__name__)
class StrategyService:
def __init__(self):
self._strategies: Dict[StrategyEnum, StrategyStatus] = {}
self._initialize_strategies()
def _initialize_strategies(self):
for strategy in StrategyEnum:
self._strategies[strategy] = StrategyStatus(
strategy=strategy, enabled=False, last_execution=None, total_trades=0, profit_loss=0.0
)
async def execute_strategy(self, strategy: StrategyEnum, config: Optional[Dict] = None) -> Dict:
logger.info(f"Executing strategy: {strategy.value}")
status = self._strategies.get(strategy)
if not status or not status.enabled:
raise ValueError(f"Strategy {strategy.value} is not enabled")
results = await self._run_strategy_logic(strategy, config or {})
status.last_execution = datetime.utcnow()
status.total_trades += results.get("trades", 0)
status.profit_loss += results.get("profit", 0)
return {"strategy": strategy.value, "status": status.dict(), "results": results}
async def _run_strategy_logic(self, strategy: StrategyEnum, config: Dict) -> Dict:
if strategy == StrategyEnum.FUNDAMENTAL:
return await self._run_fundamental_strategy(config)
elif strategy == StrategyEnum.TECHNICAL:
return await self._run_technical_strategy(config)
elif strategy == StrategyEnum.ML:
return await self._run_ml_strategy(config)
elif strategy == StrategyEnum.MINING:
return await self._run_mining_strategy(config)
return {"trades": 0, "profit": 0}
async def _run_fundamental_strategy(self, config: Dict) -> Dict:
logger.debug("Running fundamental strategy")
return {"trades": 0, "profit": 0}
async def _run_technical_strategy(self, config: Dict) -> Dict:
logger.debug("Running technical strategy")
return {"trades": 0, "profit": 0}
async def _run_ml_strategy(self, config: Dict) -> Dict:
logger.debug("Running ML strategy")
return {"trades": 0, "profit": 0}
async def _run_mining_strategy(self, config: Dict) -> Dict:
logger.debug("Running mining strategy")
return {"trades": 0, "profit": 0}
async def get_strategy_status(self, strategy: StrategyEnum) -> StrategyStatus:
return self._strategies.get(strategy, StrategyStatus(strategy=strategy, enabled=False))
async def get_all_strategies(self) -> List[StrategyStatus]:
return list(self._strategies.values())
async def toggle_strategy(self, strategy: StrategyEnum, action: str) -> StrategyStatus:
status = self._strategies.get(strategy)
if not status:
raise ValueError(f"Unknown strategy: {strategy.value}")
if action == "start":
status.enabled = True
logger.info(f"Strategy {strategy.value} started")
elif action == "stop":
status.enabled = False
logger.info(f"Strategy {strategy.value} stopped")
else:
raise ValueError(f"Invalid action: {action}. Use 'start' or 'stop'")
return status

View File

@@ -0,0 +1,61 @@
from typing import Dict, List, Optional
from datetime import datetime
from app.utils.logger import get_logger
logger = get_logger(__name__)
class TradingPosition:
timestamp: datetime
position_type: str
region: Optional[str]
volume_mw: float
entry_price: float
current_price: float
pnl: float
class TradingService:
def __init__(self):
self._positions: List[Dict] = []
self._orders: List[Dict] = []
async def get_positions(self) -> List[Dict]:
return self._positions.copy()
async def get_orders(self, limit: int = 100) -> List[Dict]:
return self._orders[-limit:]
async def place_order(self, order: Dict) -> Dict:
order_id = f"order_{len(self._orders) + 1}"
order["order_id"] = order_id
order["timestamp"] = datetime.utcnow()
order["status"] = "filled"
self._orders.append(order)
logger.info(f"Order placed: {order_id}, type: {order.get('type')}, volume: {order.get('volume_mw')}")
return order
async def close_position(self, position_id: str) -> Dict:
for i, pos in enumerate(self._positions):
if pos.get("position_id") == position_id:
position = self._positions.pop(i)
position["closed_at"] = datetime.utcnow()
position["status"] = "closed"
logger.info(f"Position closed: {position_id}")
return position
raise ValueError(f"Position not found: {position_id}")
async def get_trading_summary(self) -> Dict:
total_pnl = sum(pos.get("pnl", 0) for pos in self._positions)
open_positions = len([p for p in self._positions if p.get("status") == "open"])
return {
"total_pnl": total_pnl,
"open_positions": open_positions,
"total_trades": len(self._orders),
"last_trade": self._orders[-1]["timestamp"] if self._orders else None,
}