from pathlib import Path
from typing import Dict, Optional

from app.ml.rl_battery.agent import QLearningAgent
from app.ml.rl_battery.environment import BatteryEnvironment
from app.utils.logger import get_logger

logger = get_logger(__name__)


class BatteryPolicy:
    """Inference wrapper around a trained Q-learning battery policy.

    Loads a pickled :class:`QLearningAgent` from ``policy_path`` (if the
    file exists) and maps battery/market state to one of three discrete
    actions: ``charge``, ``hold``, ``discharge``.  When no policy file is
    found, :meth:`get_action` degrades gracefully to ``hold``.
    """

    # Index -> action-name mapping; order must match the agent's action space.
    ACTIONS = ("charge", "hold", "discharge")

    def __init__(self, policy_path: str = "models/rl_battery"):
        """Initialize the policy and eagerly try to load it from disk.

        Args:
            policy_path: Directory expected to contain ``battery_policy.pkl``.
        """
        self.policy_path = policy_path
        # Both stay None until a policy file is successfully loaded.
        self.agent: Optional[QLearningAgent] = None
        self.env: Optional[BatteryEnvironment] = None
        self._load_policy()

    def _load_policy(self) -> None:
        """Load the pickled agent from ``policy_path`` if the file exists.

        Leaves ``agent``/``env`` as ``None`` when the file is absent, which
        makes :meth:`get_action` fall back to a neutral "hold" response.
        """
        filepath = Path(self.policy_path) / "battery_policy.pkl"
        if filepath.exists():
            self.agent = QLearningAgent.load(filepath)
            self.env = BatteryEnvironment()
            logger.info("Loaded policy from %s", filepath)
        else:
            logger.warning(
                "No policy file at %s; get_action() will return 'hold'", filepath
            )

    def get_action(
        self,
        charge_level: float,
        current_price: float,
        price_forecast_1m: float = 0,
        price_forecast_5m: float = 0,
        price_forecast_15m: float = 0,
        hour: int = 0,
    ) -> Dict:
        """Return the policy's recommended action for the given state.

        Args:
            charge_level: Current battery state of charge.
            current_price: Current market price.
            price_forecast_1m: 1-minute price forecast.
            price_forecast_5m: 5-minute price forecast.
            price_forecast_15m: 15-minute price forecast.
            hour: Hour of day; converted to the environment's minute-based
                ``time_step`` (``hour * 60``).

        Returns:
            Dict with ``action`` ("charge" | "hold" | "discharge"),
            ``q_values`` (list of three floats), and ``confidence``
            (float clamped to [0, 1]).

        NOTE(review): the three ``price_forecast_*`` arguments are accepted
        but never written into the environment before ``_get_state()`` is
        called — confirm whether ``BatteryEnvironment`` is expected to read
        them; they are kept in the signature for backward compatibility.
        """
        # No trained policy loaded -> neutral, zero-confidence fallback.
        if self.agent is None:
            return {
                "action": "hold",
                "q_values": [0.0, 0.0, 0.0],
                "confidence": 0.0,
            }

        # Push the caller-supplied state into the environment so its
        # state encoding (_get_state) reflects the current conditions.
        self.env.charge_level = charge_level
        self.env.current_price = current_price
        self.env.time_step = hour * 60

        state = self.env._get_state()
        action_idx = self.agent.get_action(state, training=False)
        action_name = self.ACTIONS[action_idx]

        state_idx = self.agent._discretize_state(state)
        q_values = (
            self.agent.q_table[state_idx].tolist()
            if self.agent.q_table is not None
            else [0.0, 0.0, 0.0]
        )

        # Confidence = Q-value spread normalized by the magnitude of the
        # best Q-value.  Using abs() in the denominator fixes the original
        # formula, which divided by (max_q + 0.001) and blew up / went
        # negative whenever all Q-values were <= 0.  Clamp to [0, 1].
        max_q = max(q_values) if q_values else 0.0
        min_q = min(q_values) if q_values else 0.0
        confidence = (max_q - min_q) / (abs(max_q) + 0.001) if q_values else 0.0
        confidence = max(0.0, min(confidence, 1.0))

        return {
            "action": action_name,
            "q_values": q_values,
            "confidence": confidence,
        }


__all__ = ["BatteryPolicy"]