"""Command-line entry point for training the energy-trading ML models.

Subcommands:
    price  Train price prediction models for one or more horizons.
    rl     Train the RL battery policy.
"""

import argparse
import sys
from typing import Optional, Sequence

from app.ml.price_prediction.trainer import PricePredictionTrainer
from app.ml.rl_battery.trainer import BatteryRLTrainer
from app.utils.logger import get_logger, setup_logger

setup_logger()
logger = get_logger(__name__)


def _positive_int(value: str) -> int:
    """Parse *value* as a strictly positive integer (argparse ``type=`` hook).

    Raises:
        argparse.ArgumentTypeError: if *value* is not an integer or is <= 0,
            so argparse reports a clean usage error instead of training with
            a nonsensical horizon/episode count.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number <= 0:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser with the ``price`` and ``rl`` subcommands."""
    parser = argparse.ArgumentParser(description="Energy Trading ML Training CLI")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    price_parser = subparsers.add_parser("price", help="Train price prediction models")
    price_parser.add_argument(
        "--horizons",
        nargs="+",
        type=_positive_int,
        default=[1, 5, 15, 60],
        help="Prediction horizons in minutes",
    )
    price_parser.add_argument(
        "--output",
        type=str,
        default="models/price_prediction",
        help="Output directory",
    )

    rl_parser = subparsers.add_parser("rl", help="Train RL battery policy")
    rl_parser.add_argument(
        "--episodes",
        type=_positive_int,
        default=1000,
        help="Number of training episodes",
    )
    rl_parser.add_argument("--region", type=str, default="FR", help="Region to train for")
    rl_parser.add_argument(
        "--output",
        type=str,
        default="models/rl_battery",
        help="Output directory",
    )
    return parser


def _train_price(args: argparse.Namespace) -> None:
    """Train, save and report price prediction models for ``args.horizons``."""
    logger.info(f"Training price prediction models for horizons: {args.horizons}")
    trainer = PricePredictionTrainer()
    results = trainer.train_all(horizons=args.horizons)
    trainer.save_models(output_dir=args.output)

    logger.info("Training complete!")
    # train_all reports per-horizon failures via an "error" entry in that
    # horizon's result; surface them instead of silently dropping them.
    for horizon, result in results.items():
        if "error" in result:
            logger.warning(f" {horizon}m: training failed: {result['error']}")
        else:
            logger.info(
                f" {horizon}m: MAE={result['mae']:.2f}, "
                f"RMSE={result['rmse']:.2f}, R2={result['r2']:.3f}"
            )


def _train_rl(args: argparse.Namespace) -> None:
    """Train, save and report the RL battery policy."""
    logger.info(f"Training RL battery policy for {args.episodes} episodes")
    trainer = BatteryRLTrainer()
    results = trainer.train(n_episodes=args.episodes, region=args.region)
    trainer.save(output_dir=args.output)

    logger.info("Training complete!")
    logger.info(f" Final avg reward: {results['final_avg_reward']:.2f}")
    logger.info(f" Final epsilon: {results['final_epsilon']:.3f}")


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Parse CLI arguments and run the selected training command.

    Args:
        argv: Argument list to parse (excluding the program name). ``None``
            means ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        Process exit status: 0 when a command ran, 1 when no command was given
        (help is printed in that case).
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    if args.command == "price":
        _train_price(args)
    elif args.command == "rl":
        _train_rl(args)
    else:
        parser.print_help()
        # Non-zero so callers/CI can tell that nothing was trained.
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())