Implements FastAPI backend with ML model support for energy trading, including price prediction models and RL-based battery trading policy. Features dashboard, trading, backtest, and settings API routes with WebSocket support for real-time updates.
66 lines
2.0 KiB
Python
66 lines
2.0 KiB
Python
from typing import Dict
|
|
from app.ml.rl_battery.agent import QLearningAgent
|
|
from app.ml.rl_battery.environment import BatteryEnvironment
|
|
from app.utils.logger import get_logger
|
|
|
|
# Module-level logger, obtained from the project's logging helper and named after this module.
logger = get_logger(__name__)
|
|
|
|
|
|
class BatteryPolicy:
    """Inference wrapper around a persisted Q-learning battery trading policy.

    On construction the policy is loaded from
    ``<policy_path>/battery_policy.pkl`` if that file exists. If it does not,
    the instance remains usable and :meth:`get_action` returns a neutral
    ``"hold"`` recommendation.
    """

    # Action names, indexed by the integer action the agent returns.
    _ACTIONS = ("charge", "hold", "discharge")

    # Keeps the confidence denominator non-zero when every Q-value is 0.
    _CONFIDENCE_EPS = 0.001

    def __init__(self, policy_path: str = "models/rl_battery"):
        """Remember ``policy_path`` and try to load the stored policy from it.

        Args:
            policy_path: Directory expected to contain ``battery_policy.pkl``.
        """
        self.policy_path = policy_path
        # Both stay None until a policy file has been loaded successfully.
        self.agent: "QLearningAgent | None" = None
        self.env: "BatteryEnvironment | None" = None
        self._load_policy()

    def _load_policy(self):
        """Load the agent and create its environment if the policy file exists.

        When the file is missing, ``self.agent`` and ``self.env`` are left
        untouched and a warning is logged so the fallback is visible.
        """
        from pathlib import Path

        filepath = Path(self.policy_path) / "battery_policy.pkl"
        if not filepath.exists():
            logger.warning(
                f"No policy file at {filepath}; get_action will return 'hold'"
            )
            return

        # NOTE(review): the ".pkl" suffix suggests QLearningAgent.load
        # unpickles the file - only point this at trusted paths; confirm.
        self.agent = QLearningAgent.load(filepath)
        self.env = BatteryEnvironment()
        logger.info(f"Loaded policy from {filepath}")

    @classmethod
    def _confidence(cls, q_values) -> float:
        """Return a score in ``[0, 1]`` for how decisive ``q_values`` are.

        The score is the spread between the best and worst Q-value, divided
        by the largest absolute Q-value (plus a small epsilon). For
        non-negative Q-values this equals ``(max - min) / (max + 0.001)``.
        Using the magnitude keeps the result non-negative and the
        denominator non-zero when Q-values are negative.
        """
        if not q_values:
            return 0.0
        best = max(q_values)
        worst = min(q_values)
        scale = max(abs(q) for q in q_values) + cls._CONFIDENCE_EPS
        return max(0.0, min((best - worst) / scale, 1.0))

    def get_action(
        self,
        charge_level: float,
        current_price: float,
        price_forecast_1m: float = 0,
        price_forecast_5m: float = 0,
        price_forecast_15m: float = 0,
        hour: int = 0,
    ) -> Dict:
        """Recommend a battery action for the given market snapshot.

        Args:
            charge_level: Written to the environment's ``charge_level``.
            current_price: Written to the environment's ``current_price``.
            price_forecast_1m: Accepted for interface compatibility; not
                used by this method.
            price_forecast_5m: Accepted for interface compatibility; not
                used by this method.
            price_forecast_15m: Accepted for interface compatibility; not
                used by this method.
            hour: Converted to the environment's ``time_step`` as
                ``hour * 60``.

        Returns:
            A dict with ``"action"`` (``"charge"``, ``"hold"`` or
            ``"discharge"``), ``"q_values"`` (the Q-table row for the current
            state as a list) and ``"confidence"`` (a float in ``[0, 1]``).
            If no policy is loaded the result is ``"hold"`` with zero
            Q-values and zero confidence.
        """
        if self.agent is None or self.env is None:
            return {
                "action": "hold",
                "q_values": [0.0, 0.0, 0.0],
                "confidence": 0.0,
            }

        # Push the caller's snapshot into the environment so that the state
        # it builds reflects the current situation.
        self.env.charge_level = charge_level
        self.env.current_price = current_price
        self.env.time_step = hour * 60

        state = self.env._get_state()
        # training=False is passed so the agent is asked for its inference-time choice.
        action_idx = self.agent.get_action(state, training=False)

        if self.agent.q_table is not None:
            state_idx = self.agent._discretize_state(state)
            q_values = self.agent.q_table[state_idx].tolist()
        else:
            q_values = [0.0, 0.0, 0.0]

        return {
            "action": self._ACTIONS[action_idx],
            "q_values": q_values,
            "confidence": self._confidence(q_values),
        }
|
|
|
|
|
|
# Public API of this module: only BatteryPolicy is exported by a star-import.
__all__ = ["BatteryPolicy"]
|