Implements FastAPI backend with ML model support for energy trading, including price prediction models and RL-based battery trading policy. Features dashboard, trading, backtest, and settings API routes with WebSocket support for real-time updates.
89 lines · 2.6 KiB · Python
from typing import Dict, Optional
|
|
import numpy as np
|
|
import pickle
|
|
from app.utils.logger import get_logger
|
|
|
|
# Module-level logger, namespaced to this module's import path.
logger = get_logger(__name__)
|
|
|
|
|
|
class QLearningAgent:
    """Tabular epsilon-greedy Q-learning agent for a discretized battery policy.

    States are continuous vectors assumed to be normalized to [0, 1] per
    dimension (TODO confirm against the environment); each dimension is
    discretized into ``state_bins`` buckets and the bucket tuple is flattened
    into a single row index of the Q-table.
    """

    # Fallback action returned by get_action() when no Q-table has been
    # initialized yet (presumably "hold" in a buy/hold/sell action space —
    # verify against the trading environment).
    _DEFAULT_ACTION = 1

    def __init__(
        self,
        state_bins: int = 10,
        action_space: int = 3,
        learning_rate: float = 0.1,
        discount_factor: float = 0.95,
        epsilon: float = 1.0,
        epsilon_decay: float = 0.995,
        epsilon_min: float = 0.05,
    ):
        """Store hyperparameters; the Q-table is created lazily.

        Args:
            state_bins: Number of discretization buckets per state dimension.
            action_space: Number of discrete actions.
            learning_rate: Step size (alpha) for the Q-value update.
            discount_factor: Future-reward discount (gamma).
            epsilon: Initial exploration probability.
            epsilon_decay: Multiplicative decay applied by decay_epsilon().
            epsilon_min: Floor below which epsilon never decays.
        """
        self.state_bins = state_bins
        self.action_space = action_space
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min

        # Created by initialize_q_table(); None means "untrained".
        self.q_table: Optional[np.ndarray] = None
        self.policy_id = "battery_policy"

    def initialize_q_table(self, observation_space: int):
        """Allocate a zeroed Q-table of shape (state_bins**obs_dims, actions).

        Args:
            observation_space: Number of dimensions in the state vector.
        """
        self.q_table = np.zeros((self.state_bins ** observation_space, self.action_space))

    def _discretize_state(self, state: np.ndarray) -> int:
        """Map a normalized state vector to a flat Q-table row index.

        Each dimension is binned into [0, state_bins) and the bins are
        combined with mixed-radix (base ``state_bins``) positional encoding.
        """
        discretized = (state * self.state_bins).astype(int)
        # Clip so values at/above 1.0 (or below 0.0) still land in a valid bin.
        discretized = np.clip(discretized, 0, self.state_bins - 1)

        index = 0
        multiplier = 1
        for val in discretized:
            index += val * multiplier
            multiplier *= self.state_bins

        # int() so callers get a plain Python int, not a numpy scalar.
        return int(index)

    def get_action(self, state: np.ndarray, training: bool = True) -> int:
        """Choose an action epsilon-greedily (greedy when training=False).

        Returns the default action if the Q-table is uninitialized.
        """
        state_idx = self._discretize_state(state)

        # Explore with probability epsilon, but only during training.
        if training and np.random.random() < self.epsilon:
            return np.random.randint(self.action_space)

        if self.q_table is None:
            return self._DEFAULT_ACTION

        # int() keeps the annotated return type (argmax yields np.int64).
        return int(np.argmax(self.q_table[state_idx]))

    def update(self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool):
        """Apply one temporal-difference Q-learning update; no-op if untrained.

        Args:
            state: State the action was taken in.
            action: Action index taken.
            reward: Reward received for the transition.
            next_state: Resulting state.
            done: True if the episode terminated (no bootstrap term).
        """
        if self.q_table is None:
            return

        state_idx = self._discretize_state(state)
        next_state_idx = self._discretize_state(next_state)

        current_q = self.q_table[state_idx, action]

        # Terminal states have no future value to bootstrap from.
        if done:
            target = reward
        else:
            next_q = np.max(self.q_table[next_state_idx])
            target = reward + self.discount_factor * next_q

        self.q_table[state_idx, action] += self.learning_rate * (target - current_q)

    def decay_epsilon(self):
        """Multiplicatively decay epsilon, clamped at epsilon_min."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def save(self, filepath: str):
        """Pickle the whole agent (hyperparameters + Q-table) to *filepath*."""
        with open(filepath, "wb") as f:
            pickle.dump(self, f)
        logger.info(f"Saved Q-learning policy to {filepath}")

    @classmethod
    def load(cls, filepath: str):
        """Load a previously pickled agent from *filepath*.

        SECURITY: pickle.load executes arbitrary code from the file — only
        load policy files from trusted sources.
        """
        with open(filepath, "rb") as f:
            return pickle.load(f)
|
|
|
|
|
|
# Public API of this module: only the agent class is exported.
__all__ = ["QLearningAgent"]
|