Implements FastAPI backend with ML model support for energy trading, including price prediction models and RL-based battery trading policy. Features dashboard, trading, backtest, and settings API routes with WebSocket support for real-time updates.
50 lines, 2.2 KiB, Python
import argparse
|
|
from app.ml.price_prediction.trainer import PricePredictionTrainer
|
|
from app.ml.rl_battery.trainer import BatteryRLTrainer
|
|
from app.utils.logger import get_logger, setup_logger
|
|
|
|
# Configure logging at import time, before the module-level logger below is
# obtained. NOTE(review): what setup_logger() actually installs (handlers,
# level, format) is defined in app.utils.logger — confirm there.
setup_logger()

# Module logger that main() uses to report training progress and metrics.
logger = get_logger(__name__)


def _positive_int(value: str) -> int:
    """Parse ``value`` as a strictly positive integer (argparse ``type=`` hook).

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not an integer or is <= 0.
            argparse reports this as a usage error (exit status 2) instead of
            handing a zero/negative horizon or episode count to a trainer.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}"
        ) from None
    if number <= 0:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}"
        )
    return number


def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser with the ``price`` and ``rl`` training subcommands."""
    parser = argparse.ArgumentParser(description="Energy Trading ML Training CLI")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    price_parser = subparsers.add_parser("price", help="Train price prediction models")
    price_parser.add_argument(
        "--horizons",
        nargs="+",
        type=_positive_int,
        default=[1, 5, 15, 60],
        help="Prediction horizons in minutes",
    )
    price_parser.add_argument(
        "--output",
        type=str,
        default="models/price_prediction",
        help="Output directory",
    )

    rl_parser = subparsers.add_parser("rl", help="Train RL battery policy")
    rl_parser.add_argument(
        "--episodes",
        type=_positive_int,
        default=1000,
        help="Number of training episodes",
    )
    rl_parser.add_argument(
        "--region", type=str, default="FR", help="Region to train for"
    )
    rl_parser.add_argument(
        "--output", type=str, default="models/rl_battery", help="Output directory"
    )
    return parser


def _run_price(args: argparse.Namespace) -> int:
    """Train, save and report the price-prediction models.

    Returns:
        0 if every horizon trained, 1 if any horizon's result carried an
        ``"error"`` entry.
    """
    logger.info(f"Training price prediction models for horizons: {args.horizons}")
    trainer = PricePredictionTrainer()
    results = trainer.train_all(horizons=args.horizons)
    trainer.save_models(output_dir=args.output)
    logger.info("Training complete!")

    failed = []
    for horizon, result in results.items():
        if "error" in result:
            # Surface failed horizons instead of skipping them silently.
            failed.append(horizon)
            logger.error(f" {horizon}m: training failed: {result['error']}")
            continue
        logger.info(
            f" {horizon}m: MAE={result['mae']:.2f}, "
            f"RMSE={result['rmse']:.2f}, R2={result['r2']:.3f}"
        )

    if failed:
        logger.warning(f"{len(failed)} horizon(s) failed: {failed}")
        return 1
    return 0


def _run_rl(args: argparse.Namespace) -> int:
    """Train, save and report the RL battery policy; return exit status 0."""
    logger.info(f"Training RL battery policy for {args.episodes} episodes")
    trainer = BatteryRLTrainer()
    results = trainer.train(n_episodes=args.episodes, region=args.region)
    trainer.save(output_dir=args.output)
    logger.info("Training complete!")
    logger.info(f" Final avg reward: {results['final_avg_reward']:.2f}")
    logger.info(f" Final epsilon: {results['final_epsilon']:.3f}")
    return 0


def main(argv=None) -> int:
    """Entry point for the energy-trading ML training CLI.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse parse ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 if any price horizon failed,
        2 when no subcommand was given (help is printed).

    Raises:
        SystemExit: raised by argparse for invalid arguments or ``--help``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    if args.command == "price":
        return _run_price(args)
    if args.command == "rl":
        return _run_rl(args)

    # Subcommands are optional here, so a bare invocation lands in this branch:
    # show usage and report a usage error rather than a silent success.
    parser.print_help()
    return 2


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status
    # (None exits 0; an int exits with that code).
    raise SystemExit(main())