from typing import Dict, Optional

import numpy as np
import pickle

from app.utils.logger import get_logger

logger = get_logger(__name__)


class QLearningAgent:
    """Tabular Q-learning agent with epsilon-greedy action selection.

    States are discretized into ``state_bins`` bins per dimension and
    flattened into a single row index of a 2-D Q-table whose columns are
    the discrete actions.
    """

    def __init__(
        self,
        state_bins: int = 10,
        action_space: int = 3,
        learning_rate: float = 0.1,
        discount_factor: float = 0.95,
        epsilon: float = 1.0,
        epsilon_decay: float = 0.995,
        epsilon_min: float = 0.05,
    ):
        """Store the hyper-parameters; the Q-table itself is created later.

        Args:
            state_bins: Number of bins per state dimension.
            action_space: Number of discrete actions (Q-table columns).
            learning_rate: Step size applied to the TD error in ``update``.
            discount_factor: Weight of the best next-state value in the
                TD target.
            epsilon: Initial probability of picking a random action while
                training.
            epsilon_decay: Multiplicative factor applied by
                ``decay_epsilon``.
            epsilon_min: Lower bound that ``decay_epsilon`` never goes below.
        """
        self.state_bins = state_bins
        self.action_space = action_space
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        # Stays None until initialize_q_table() is called.
        self.q_table: Optional[np.ndarray] = None
        self.policy_id = "battery_policy"

    def initialize_q_table(self, observation_space: int) -> None:
        """Allocate a zero-filled Q-table.

        The table has ``state_bins ** observation_space`` rows and
        ``action_space`` columns, so its size grows exponentially with the
        number of state dimensions.

        Args:
            observation_space: Number of state dimensions.
        """
        self.q_table = np.zeros(
            (self.state_bins ** observation_space, self.action_space)
        )

    def _discretize_state(self, state: np.ndarray) -> int:
        """Map a state vector to a flat Q-table row index.

        Each component is scaled by ``state_bins``, truncated to an integer
        and clipped into ``[0, state_bins - 1]``. The resulting digits are
        then combined as a base-``state_bins`` number with the first
        component as the least-significant digit.

        Args:
            state: 1-D state vector.
                NOTE(review): the scaling presumably expects components
                normalized to [0, 1); anything outside lands in the first
                or last bin because of the clip. Confirm against callers.

        Returns:
            The row index as a builtin ``int``.
        """
        # asarray lets plain sequences through; ndarrays pass unchanged.
        discretized = (np.asarray(state) * self.state_bins).astype(int)
        discretized = np.clip(discretized, 0, self.state_bins - 1)

        index = 0
        multiplier = 1
        for val in discretized:
            # int() keeps the accumulator a Python int, matching the
            # annotated return type instead of leaking a NumPy scalar.
            index += int(val) * multiplier
            multiplier *= self.state_bins
        return index

    def get_action(self, state: np.ndarray, training: bool = True) -> int:
        """Choose an action for ``state`` using an epsilon-greedy rule.

        Args:
            state: Current state vector.
            training: When true, a uniformly random action is returned with
                probability ``epsilon``.

        Returns:
            A random action when exploring, otherwise the arg-max action of
            the Q-table row for ``state``. If the Q-table has not been
            initialized the greedy branch returns the constant ``1``.
        """
        state_idx = self._discretize_state(state)

        if training and np.random.random() < self.epsilon:
            return int(np.random.randint(self.action_space))

        if self.q_table is None:
            # NOTE(review): the meaning of action 1 depends on the
            # environment's action encoding - confirm it is a safe default.
            return 1

        # np.argmax returns a NumPy integer; convert to honour `-> int`.
        return int(np.argmax(self.q_table[state_idx]))

    def update(
        self,
        state: np.ndarray,
        action: int,
        reward: float,
        next_state: np.ndarray,
        done: bool,
    ) -> None:
        """Apply one Q-learning update for an observed transition.

        Does nothing when the Q-table has not been initialized.

        Args:
            state: State in which ``action`` was taken.
            action: Index of the action taken.
            reward: Reward received for the transition.
            next_state: State reached after the action.
            done: True when the transition ended the episode; the target
                then excludes the bootstrapped next-state value.
        """
        if self.q_table is None:
            return

        state_idx = self._discretize_state(state)
        next_state_idx = self._discretize_state(next_state)
        current_q = self.q_table[state_idx, action]

        if done:
            target = reward
        else:
            next_q = np.max(self.q_table[next_state_idx])
            target = reward + self.discount_factor * next_q

        self.q_table[state_idx, action] += self.learning_rate * (target - current_q)

    def decay_epsilon(self) -> None:
        """Multiply ``epsilon`` by ``epsilon_decay``, floored at ``epsilon_min``."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def save(self, filepath: str) -> None:
        """Pickle this agent (hyper-parameters and Q-table) to ``filepath``.

        Args:
            filepath: Destination path; an existing file is overwritten.
        """
        with open(filepath, "wb") as f:
            pickle.dump(self, f)
        logger.info(f"Saved Q-learning policy to {filepath}")

    @classmethod
    def load(cls, filepath: str) -> "QLearningAgent":
        """Unpickle and return the object stored at ``filepath``.

        The unpickled object is returned as-is; it is not checked to be an
        instance of ``cls``.

        Warning:
            ``pickle.load`` can execute arbitrary code embedded in the file.
            Only load files from trusted sources.

        Args:
            filepath: Path to a file previously written by ``save``.
        """
        # SECURITY: pickle is unsafe on untrusted input; see docstring.
        with open(filepath, "rb") as f:
            return pickle.load(f)


__all__ = ["QLearningAgent"]