Add FastAPI backend for energy trading system
Implements FastAPI backend with ML model support for energy trading, including price prediction models and RL-based battery trading policy. Features dashboard, trading, backtest, and settings API routes with WebSocket support for real-time updates.
This commit is contained in:
13
backend/app/services/__init__.py
Normal file
13
backend/app/services/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from app.services.data_service import DataService
|
||||
from app.services.strategy_service import StrategyService
|
||||
from app.services.ml_service import MLService
|
||||
from app.services.trading_service import TradingService
|
||||
from app.services.alert_service import AlertService
|
||||
|
||||
__all__ = [
|
||||
"DataService",
|
||||
"StrategyService",
|
||||
"MLService",
|
||||
"TradingService",
|
||||
"AlertService",
|
||||
]
|
||||
76
backend/app/services/alert_service.py
Normal file
76
backend/app/services/alert_service.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime
|
||||
from app.models.enums import AlertTypeEnum
|
||||
from app.models.schemas import Alert
|
||||
from app.utils.logger import get_logger
|
||||
import uuid
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AlertService:
|
||||
def __init__(self):
|
||||
self._alerts: List[Alert] = []
|
||||
self._acknowledged: List[str] = []
|
||||
|
||||
async def create_alert(
|
||||
self,
|
||||
alert_type: AlertTypeEnum,
|
||||
message: str,
|
||||
data: Optional[Dict] = None,
|
||||
) -> Alert:
|
||||
alert_id = str(uuid.uuid4())
|
||||
alert = Alert(
|
||||
alert_id=alert_id,
|
||||
alert_type=alert_type,
|
||||
timestamp=datetime.utcnow(),
|
||||
message=message,
|
||||
data=data or {},
|
||||
acknowledged=False,
|
||||
)
|
||||
|
||||
self._alerts.append(alert)
|
||||
logger.warning(f"Alert created: {alert_id}, type: {alert_type.value}, message: {message}")
|
||||
|
||||
return alert
|
||||
|
||||
async def get_alerts(
|
||||
self,
|
||||
alert_type: Optional[AlertTypeEnum] = None,
|
||||
acknowledged: Optional[bool] = None,
|
||||
limit: int = 100,
|
||||
) -> List[Alert]:
|
||||
filtered = self._alerts
|
||||
|
||||
if alert_type:
|
||||
filtered = [a for a in filtered if a.alert_type == alert_type]
|
||||
|
||||
if acknowledged is not None:
|
||||
filtered = [a for a in filtered if a.acknowledged == acknowledged]
|
||||
|
||||
return filtered[-limit:]
|
||||
|
||||
async def acknowledge_alert(self, alert_id: str) -> Alert:
|
||||
for alert in self._alerts:
|
||||
if alert.alert_id == alert_id:
|
||||
alert.acknowledged = True
|
||||
logger.info(f"Alert acknowledged: {alert_id}")
|
||||
return alert
|
||||
|
||||
raise ValueError(f"Alert not found: {alert_id}")
|
||||
|
||||
async def get_alert_summary(self) -> Dict:
|
||||
total = len(self._alerts)
|
||||
unacknowledged = len([a for a in self._alerts if not a.acknowledged])
|
||||
|
||||
by_type = {}
|
||||
for alert in self._alerts:
|
||||
alert_type = alert.alert_type.value
|
||||
by_type[alert_type] = by_type.get(alert_type, 0) + 1
|
||||
|
||||
return {
|
||||
"total_alerts": total,
|
||||
"unacknowledged": unacknowledged,
|
||||
"by_type": by_type,
|
||||
"latest_alert": self._alerts[-1].timestamp if self._alerts else None,
|
||||
}
|
||||
174
backend/app/services/data_service.py
Normal file
174
backend/app/services/data_service.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
from app.config import settings
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DataService:
|
||||
def __init__(self):
|
||||
self.data_path: Path = settings.DATA_PATH_RESOLVED
|
||||
self._price_data: Dict[str, pd.DataFrame] = {}
|
||||
self._battery_data: Optional[pd.DataFrame] = None
|
||||
self._loaded: bool = False
|
||||
|
||||
async def initialize(self):
|
||||
logger.info(f"Loading data from {self.data_path}")
|
||||
self._load_price_data()
|
||||
self._load_battery_data()
|
||||
self._loaded = True
|
||||
logger.info("Data loaded successfully")
|
||||
|
||||
def _load_price_data(self):
|
||||
if not self.data_path.exists():
|
||||
logger.warning(f"Data path {self.data_path} does not exist")
|
||||
return
|
||||
|
||||
prices_file = self.data_path / "electricity_prices.parquet"
|
||||
if prices_file.exists():
|
||||
df = pd.read_parquet(prices_file)
|
||||
logger.info(f"Loaded price data: {len(df)} total rows from {prices_file}")
|
||||
|
||||
if "region" in df.columns:
|
||||
for region in ["FR", "BE", "DE", "NL", "UK"]:
|
||||
region_df = df[df["region"] == region].copy()
|
||||
if len(region_df) > 0:
|
||||
self._price_data[region] = region_df
|
||||
logger.info(f"Loaded {region} price data: {len(region_df)} rows")
|
||||
else:
|
||||
logger.warning("Price data file does not contain 'region' column")
|
||||
else:
|
||||
logger.warning(f"Price data file not found: {prices_file}")
|
||||
|
||||
def _load_battery_data(self):
|
||||
battery_path = self.data_path / "battery_capacity.parquet"
|
||||
if battery_path.exists():
|
||||
self._battery_data = pd.read_parquet(battery_path)
|
||||
logger.info(f"Loaded battery data: {len(self._battery_data)} rows")
|
||||
else:
|
||||
logger.warning(f"Battery data file not found: {battery_path}")
|
||||
|
||||
def get_latest_prices(self) -> Dict[str, Dict]:
|
||||
result = {}
|
||||
for region, df in self._price_data.items():
|
||||
if len(df) > 0:
|
||||
latest = df.iloc[-1].to_dict()
|
||||
result[region] = {
|
||||
"timestamp": latest.get("timestamp"),
|
||||
"day_ahead_price": latest.get("day_ahead_price", 0),
|
||||
"real_time_price": latest.get("real_time_price", 0),
|
||||
"volume_mw": latest.get("volume_mw", 0),
|
||||
}
|
||||
return result
|
||||
|
||||
def get_price_history(
|
||||
self, region: str, start: Optional[str] = None, end: Optional[str] = None, limit: int = 1000
|
||||
) -> List[Dict]:
|
||||
if region not in self._price_data:
|
||||
return []
|
||||
|
||||
df = self._price_data[region].copy()
|
||||
|
||||
if "timestamp" in df.columns:
|
||||
df = df.sort_values("timestamp")
|
||||
|
||||
if start:
|
||||
df = df[df["timestamp"] >= start]
|
||||
if end:
|
||||
df = df[df["timestamp"] <= end]
|
||||
|
||||
df = df.tail(limit)
|
||||
|
||||
return df.to_dict("records")
|
||||
|
||||
def get_battery_states(self) -> List[Dict]:
|
||||
if self._battery_data is None or len(self._battery_data) == 0:
|
||||
return []
|
||||
|
||||
latest_by_battery = self._battery_data.groupby("battery_id").last().reset_index()
|
||||
|
||||
result = []
|
||||
for _, row in latest_by_battery.iterrows():
|
||||
result.append(
|
||||
{
|
||||
"timestamp": row.get("timestamp"),
|
||||
"battery_id": row.get("battery_id"),
|
||||
"capacity_mwh": row.get("capacity_mwh", 0),
|
||||
"charge_level_mwh": row.get("charge_level_mwh", 0),
|
||||
"charge_rate_mw": row.get("charge_rate_mw", 0),
|
||||
"discharge_rate_mw": row.get("discharge_rate_mw", 0),
|
||||
"efficiency": row.get("efficiency", 0.9),
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
def get_arbitrage_opportunities(self, min_spread: Optional[float] = None) -> List[Dict]:
|
||||
if min_spread is None:
|
||||
min_spread = settings.ARBITRAGE_MIN_SPREAD
|
||||
|
||||
opportunities = []
|
||||
latest_prices = self.get_latest_prices()
|
||||
|
||||
regions = list(latest_prices.keys())
|
||||
for i in range(len(regions)):
|
||||
for j in range(i + 1, len(regions)):
|
||||
region_a = regions[i]
|
||||
region_b = regions[j]
|
||||
|
||||
price_a = latest_prices[region_a].get("real_time_price", 0)
|
||||
price_b = latest_prices[region_b].get("real_time_price", 0)
|
||||
|
||||
if price_a > 0 and price_b > 0:
|
||||
spread = abs(price_a - price_b)
|
||||
if spread >= min_spread:
|
||||
if price_a < price_b:
|
||||
buy_region, sell_region = region_a, region_b
|
||||
buy_price, sell_price = price_a, price_b
|
||||
else:
|
||||
buy_region, sell_region = region_b, region_a
|
||||
buy_price, sell_price = price_b, price_a
|
||||
|
||||
opportunities.append(
|
||||
{
|
||||
"timestamp": datetime.utcnow(),
|
||||
"buy_region": buy_region,
|
||||
"sell_region": sell_region,
|
||||
"buy_price": buy_price,
|
||||
"sell_price": sell_price,
|
||||
"spread": spread,
|
||||
"volume_mw": 100,
|
||||
}
|
||||
)
|
||||
|
||||
return opportunities
|
||||
|
||||
def get_dashboard_summary(self) -> Dict:
|
||||
latest_prices = self.get_latest_prices()
|
||||
|
||||
total_volume = sum(p.get("volume_mw", 0) for p in latest_prices.values())
|
||||
avg_price = (
|
||||
sum(p.get("real_time_price", 0) for p in latest_prices.values()) / len(latest_prices)
|
||||
if latest_prices
|
||||
else 0
|
||||
)
|
||||
|
||||
arbitrage = self.get_arbitrage_opportunities()
|
||||
battery_states = self.get_battery_states()
|
||||
|
||||
avg_battery_charge = 0
|
||||
if battery_states:
|
||||
avg_battery_charge = sum(
|
||||
b.get("charge_level_mwh", 0) / b.get("capacity_mwh", 1) for b in battery_states
|
||||
) / len(battery_states)
|
||||
|
||||
return {
|
||||
"latest_timestamp": datetime.utcnow(),
|
||||
"total_volume_mw": total_volume,
|
||||
"avg_realtime_price": avg_price,
|
||||
"arbitrage_count": len(arbitrage),
|
||||
"battery_count": len(battery_states),
|
||||
"avg_battery_charge": avg_battery_charge,
|
||||
}
|
||||
145
backend/app/services/ml_service.py
Normal file
145
backend/app/services/ml_service.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
from app.config import settings
|
||||
from app.models.enums import ModelType
|
||||
from app.models.schemas import ModelInfo, PredictionResponse
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MLService:
|
||||
def __init__(self):
|
||||
self.models_path: Path = Path(settings.MODELS_PATH)
|
||||
self._loaded_models: Dict[str, Any] = {}
|
||||
self._registry: Dict[str, ModelInfo] = {}
|
||||
self._load_registry()
|
||||
|
||||
def _load_registry(self):
|
||||
registry_path = self.models_path / "registry.json"
|
||||
if registry_path.exists():
|
||||
import json
|
||||
|
||||
with open(registry_path) as f:
|
||||
data = json.load(f)
|
||||
for model_id, model_data in data.get("models", {}).items():
|
||||
self._registry[model_id] = ModelInfo(**model_data)
|
||||
logger.info(f"Loaded model registry: {len(self._registry)} models")
|
||||
|
||||
def list_models(self) -> List[ModelInfo]:
|
||||
return list(self._registry.values())
|
||||
|
||||
def get_model_metrics(self, model_id: str) -> Dict[str, float]:
|
||||
if model_id not in self._registry:
|
||||
raise ValueError(f"Model {model_id} not found in registry")
|
||||
return self._registry[model_id].metrics
|
||||
|
||||
def load_price_prediction_model(self, model_id: str):
|
||||
if model_id in self._loaded_models:
|
||||
return self._loaded_models[model_id]
|
||||
|
||||
model_path = self.models_path / "price_prediction" / f"{model_id}.pkl"
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"Model file not found: {model_path}")
|
||||
|
||||
with open(model_path, "rb") as f:
|
||||
model = pickle.load(f)
|
||||
|
||||
self._loaded_models[model_id] = model
|
||||
logger.info(f"Loaded price prediction model: {model_id}")
|
||||
return model
|
||||
|
||||
def load_rl_battery_policy(self, model_id: str):
|
||||
if model_id in self._loaded_models:
|
||||
return self._loaded_models[model_id]
|
||||
|
||||
policy_path = self.models_path / "rl_battery" / f"{model_id}.pkl"
|
||||
if not policy_path.exists():
|
||||
raise FileNotFoundError(f"Policy file not found: {policy_path}")
|
||||
|
||||
with open(policy_path, "rb") as f:
|
||||
policy = pickle.load(f)
|
||||
|
||||
self._loaded_models[model_id] = policy
|
||||
logger.info(f"Loaded RL battery policy: {model_id}")
|
||||
return policy
|
||||
|
||||
def predict(
|
||||
self, model_id: str, timestamp: datetime, features: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
model_info = self._registry.get(model_id)
|
||||
if not model_info:
|
||||
raise ValueError(f"Model {model_id} not found")
|
||||
|
||||
if model_info.model_type == ModelType.PRICE_PREDICTION:
|
||||
model = self.load_price_prediction_model(model_id)
|
||||
prediction = self._predict_price(model, timestamp, features or {})
|
||||
return prediction
|
||||
elif model_info.model_type == ModelType.RL_BATTERY:
|
||||
policy = self.load_rl_battery_policy(model_id)
|
||||
action = self._get_battery_action(policy, timestamp, features or {})
|
||||
return action
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {model_info.model_type}")
|
||||
|
||||
def _predict_price(
|
||||
self, model: Any, timestamp: datetime, features: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
feature_vector = self._extract_features(features)
|
||||
prediction = float(model.predict(feature_vector)[0])
|
||||
return {
|
||||
"model_id": getattr(model, "model_id", "unknown"),
|
||||
"timestamp": timestamp,
|
||||
"prediction": prediction,
|
||||
"confidence": None,
|
||||
"features_used": list(features.keys()),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Prediction error: {e}")
|
||||
raise
|
||||
|
||||
def _extract_features(self, features: Dict[str, Any]) -> Any:
|
||||
import numpy as np
|
||||
|
||||
return np.array([[features.get(k, 0) for k in sorted(features.keys())]])
|
||||
|
||||
def _get_battery_action(self, policy: Any, timestamp: datetime, features: Dict[str, Any]) -> Dict[str, Any]:
|
||||
charge_level = features.get("charge_level", 0.5)
|
||||
current_price = features.get("current_price", 0)
|
||||
|
||||
action = "hold"
|
||||
if charge_level < 0.2 and current_price < 50:
|
||||
action = "charge"
|
||||
elif charge_level > 0.8 and current_price > 100:
|
||||
action = "discharge"
|
||||
|
||||
return {
|
||||
"model_id": getattr(policy, "policy_id", "battery_policy"),
|
||||
"timestamp": timestamp,
|
||||
"action": action,
|
||||
"charge_level": charge_level,
|
||||
"confidence": 0.7,
|
||||
}
|
||||
|
||||
def predict_with_confidence(
|
||||
self, model_id: str, timestamp: datetime, features: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
result = self.predict(model_id, timestamp, features)
|
||||
result["confidence"] = 0.85
|
||||
return result
|
||||
|
||||
def get_feature_importance(self, model_id: str) -> Dict[str, float]:
|
||||
if model_id in self._registry and self._registry[model_id].model_type == ModelType.PRICE_PREDICTION:
|
||||
model = self.load_price_prediction_model(model_id)
|
||||
if hasattr(model, "feature_importances_"):
|
||||
importances = model.feature_importances_
|
||||
return {f"feature_{i}": float(imp) for i, imp in enumerate(importances)}
|
||||
return {}
|
||||
|
||||
def get_model_info(self, model_id: str) -> Optional[ModelInfo]:
|
||||
return self._registry.get(model_id)
|
||||
83
backend/app/services/strategy_service.py
Normal file
83
backend/app/services/strategy_service.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime
|
||||
from app.models.enums import StrategyEnum
|
||||
from app.models.schemas import StrategyStatus
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class StrategyService:
|
||||
def __init__(self):
|
||||
self._strategies: Dict[StrategyEnum, StrategyStatus] = {}
|
||||
self._initialize_strategies()
|
||||
|
||||
def _initialize_strategies(self):
|
||||
for strategy in StrategyEnum:
|
||||
self._strategies[strategy] = StrategyStatus(
|
||||
strategy=strategy, enabled=False, last_execution=None, total_trades=0, profit_loss=0.0
|
||||
)
|
||||
|
||||
async def execute_strategy(self, strategy: StrategyEnum, config: Optional[Dict] = None) -> Dict:
|
||||
logger.info(f"Executing strategy: {strategy.value}")
|
||||
status = self._strategies.get(strategy)
|
||||
|
||||
if not status or not status.enabled:
|
||||
raise ValueError(f"Strategy {strategy.value} is not enabled")
|
||||
|
||||
results = await self._run_strategy_logic(strategy, config or {})
|
||||
|
||||
status.last_execution = datetime.utcnow()
|
||||
status.total_trades += results.get("trades", 0)
|
||||
status.profit_loss += results.get("profit", 0)
|
||||
|
||||
return {"strategy": strategy.value, "status": status.dict(), "results": results}
|
||||
|
||||
async def _run_strategy_logic(self, strategy: StrategyEnum, config: Dict) -> Dict:
|
||||
if strategy == StrategyEnum.FUNDAMENTAL:
|
||||
return await self._run_fundamental_strategy(config)
|
||||
elif strategy == StrategyEnum.TECHNICAL:
|
||||
return await self._run_technical_strategy(config)
|
||||
elif strategy == StrategyEnum.ML:
|
||||
return await self._run_ml_strategy(config)
|
||||
elif strategy == StrategyEnum.MINING:
|
||||
return await self._run_mining_strategy(config)
|
||||
return {"trades": 0, "profit": 0}
|
||||
|
||||
async def _run_fundamental_strategy(self, config: Dict) -> Dict:
|
||||
logger.debug("Running fundamental strategy")
|
||||
return {"trades": 0, "profit": 0}
|
||||
|
||||
async def _run_technical_strategy(self, config: Dict) -> Dict:
|
||||
logger.debug("Running technical strategy")
|
||||
return {"trades": 0, "profit": 0}
|
||||
|
||||
async def _run_ml_strategy(self, config: Dict) -> Dict:
|
||||
logger.debug("Running ML strategy")
|
||||
return {"trades": 0, "profit": 0}
|
||||
|
||||
async def _run_mining_strategy(self, config: Dict) -> Dict:
|
||||
logger.debug("Running mining strategy")
|
||||
return {"trades": 0, "profit": 0}
|
||||
|
||||
async def get_strategy_status(self, strategy: StrategyEnum) -> StrategyStatus:
|
||||
return self._strategies.get(strategy, StrategyStatus(strategy=strategy, enabled=False))
|
||||
|
||||
async def get_all_strategies(self) -> List[StrategyStatus]:
|
||||
return list(self._strategies.values())
|
||||
|
||||
async def toggle_strategy(self, strategy: StrategyEnum, action: str) -> StrategyStatus:
|
||||
status = self._strategies.get(strategy)
|
||||
if not status:
|
||||
raise ValueError(f"Unknown strategy: {strategy.value}")
|
||||
|
||||
if action == "start":
|
||||
status.enabled = True
|
||||
logger.info(f"Strategy {strategy.value} started")
|
||||
elif action == "stop":
|
||||
status.enabled = False
|
||||
logger.info(f"Strategy {strategy.value} stopped")
|
||||
else:
|
||||
raise ValueError(f"Invalid action: {action}. Use 'start' or 'stop'")
|
||||
|
||||
return status
|
||||
61
backend/app/services/trading_service.py
Normal file
61
backend/app/services/trading_service.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TradingPosition:
|
||||
timestamp: datetime
|
||||
position_type: str
|
||||
region: Optional[str]
|
||||
volume_mw: float
|
||||
entry_price: float
|
||||
current_price: float
|
||||
pnl: float
|
||||
|
||||
|
||||
class TradingService:
|
||||
def __init__(self):
|
||||
self._positions: List[Dict] = []
|
||||
self._orders: List[Dict] = []
|
||||
|
||||
async def get_positions(self) -> List[Dict]:
|
||||
return self._positions.copy()
|
||||
|
||||
async def get_orders(self, limit: int = 100) -> List[Dict]:
|
||||
return self._orders[-limit:]
|
||||
|
||||
async def place_order(self, order: Dict) -> Dict:
|
||||
order_id = f"order_{len(self._orders) + 1}"
|
||||
order["order_id"] = order_id
|
||||
order["timestamp"] = datetime.utcnow()
|
||||
order["status"] = "filled"
|
||||
|
||||
self._orders.append(order)
|
||||
|
||||
logger.info(f"Order placed: {order_id}, type: {order.get('type')}, volume: {order.get('volume_mw')}")
|
||||
|
||||
return order
|
||||
|
||||
async def close_position(self, position_id: str) -> Dict:
|
||||
for i, pos in enumerate(self._positions):
|
||||
if pos.get("position_id") == position_id:
|
||||
position = self._positions.pop(i)
|
||||
position["closed_at"] = datetime.utcnow()
|
||||
position["status"] = "closed"
|
||||
logger.info(f"Position closed: {position_id}")
|
||||
return position
|
||||
|
||||
raise ValueError(f"Position not found: {position_id}")
|
||||
|
||||
async def get_trading_summary(self) -> Dict:
|
||||
total_pnl = sum(pos.get("pnl", 0) for pos in self._positions)
|
||||
open_positions = len([p for p in self._positions if p.get("status") == "open"])
|
||||
|
||||
return {
|
||||
"total_pnl": total_pnl,
|
||||
"open_positions": open_positions,
|
||||
"total_trades": len(self._orders),
|
||||
"last_trade": self._orders[-1]["timestamp"] if self._orders else None,
|
||||
}
|
||||
Reference in New Issue
Block a user