Files
energy-trade/backend/app/ml/rl_battery/agent.py
kbt-devops fe76bc7629 Add FastAPI backend for energy trading system
Implements FastAPI backend with ML model support for energy trading,
including price prediction models and RL-based battery trading policy.
Features dashboard, trading, backtest, and settings API routes with
WebSocket support for real-time updates.
2026-02-12 00:59:26 +07:00

89 lines
2.6 KiB
Python

from typing import Dict, Optional
import numpy as np
import pickle
from app.utils.logger import get_logger
logger = get_logger(__name__)
class QLearningAgent:
    """Tabular Q-learning agent for the battery trading policy.

    Continuous state vectors are discretized per-dimension into
    ``state_bins`` buckets (states are assumed normalized to [0, 1] per
    dimension — TODO confirm with the calling environment) and flattened
    into a single Q-table row index. Actions are integer indices in
    ``range(action_space)``.
    """

    def __init__(
        self,
        state_bins: int = 10,
        action_space: int = 3,
        learning_rate: float = 0.1,
        discount_factor: float = 0.95,
        epsilon: float = 1.0,
        epsilon_decay: float = 0.995,
        epsilon_min: float = 0.05,
    ):
        self.state_bins = state_bins
        self.action_space = action_space
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        # Current exploration rate; decayed multiplicatively by decay_epsilon().
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        # Allocated lazily by initialize_q_table() once the observation
        # dimensionality is known.
        self.q_table: Optional[np.ndarray] = None
        self.policy_id = "battery_policy"

    def initialize_q_table(self, observation_space: int) -> None:
        """Allocate a zeroed Q-table of shape (state_bins**obs_dims, actions)."""
        self.q_table = np.zeros(
            (self.state_bins ** observation_space, self.action_space)
        )

    def _discretize_state(self, state: np.ndarray) -> int:
        """Map a continuous state vector to a flat Q-table row index.

        Each dimension is bucketed into [0, state_bins) and the bucket
        indices are combined with a mixed-radix (base ``state_bins``)
        encoding.
        """
        discretized = (state * self.state_bins).astype(int)
        discretized = np.clip(discretized, 0, self.state_bins - 1)
        index = 0
        multiplier = 1
        for val in discretized:
            # int() keeps the accumulator a plain Python int rather than
            # letting numpy integer types leak out of this method.
            index += int(val) * multiplier
            multiplier *= self.state_bins
        return index

    def get_action(self, state: np.ndarray, training: bool = True) -> int:
        """Return an action index, epsilon-greedy while training.

        Falls back to action 1 (middle of a 3-action space, presumably
        "hold" — verify against the environment) when the Q-table has
        not been initialized.
        """
        state_idx = self._discretize_state(state)
        if training and np.random.random() < self.epsilon:
            # Explore. Cast: np.random.randint returns np.int64, but the
            # declared contract (and JSON serialization in the API layer)
            # needs a plain int.
            return int(np.random.randint(self.action_space))
        if self.q_table is None:
            return 1
        # np.argmax also returns a numpy scalar; cast for the same reason.
        return int(np.argmax(self.q_table[state_idx]))

    def update(
        self,
        state: np.ndarray,
        action: int,
        reward: float,
        next_state: np.ndarray,
        done: bool,
    ) -> None:
        """Apply one Q-learning TD update for a (s, a, r, s', done) step.

        No-op if the Q-table has not been initialized.
        """
        if self.q_table is None:
            return
        state_idx = self._discretize_state(state)
        next_state_idx = self._discretize_state(next_state)
        current_q = self.q_table[state_idx, action]
        if done:
            # Terminal transition: no bootstrap from the next state.
            target = reward
        else:
            next_q = np.max(self.q_table[next_state_idx])
            target = reward + self.discount_factor * next_q
        self.q_table[state_idx, action] += self.learning_rate * (target - current_q)

    def decay_epsilon(self) -> None:
        """Multiplicatively decay epsilon, floored at epsilon_min."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def save(self, filepath: str) -> None:
        """Pickle the whole agent (including Q-table) to `filepath`."""
        with open(filepath, "wb") as f:
            pickle.dump(self, f)
        logger.info(f"Saved Q-learning policy to {filepath}")

    @classmethod
    def load(cls, filepath: str) -> "QLearningAgent":
        """Load a pickled agent from `filepath`.

        SECURITY NOTE: pickle.load executes arbitrary code from the file;
        only ever load policy files produced by this application from a
        trusted location.
        """
        with open(filepath, "rb") as f:
            return pickle.load(f)
# Explicit public API of this module.
__all__ = ["QLearningAgent"]