Files
energy-trade/backend/app/ml/training/cli.py
kbt-devops fe76bc7629 Add FastAPI backend for energy trading system
Implements FastAPI backend with ML model support for energy trading,
including price prediction models and RL-based battery trading policy.
Features dashboard, trading, backtest, and settings API routes with
WebSocket support for real-time updates.
2026-02-12 00:59:26 +07:00

50 lines
2.2 KiB
Python

import argparse
from app.ml.price_prediction.trainer import PricePredictionTrainer
from app.ml.rl_battery.trainer import BatteryRLTrainer
from app.utils.logger import get_logger, setup_logger
# Configure logging at import time, before the module-level logger is created.
# NOTE(review): because this runs on import, importing this module configures
# logging as a side effect; presumably intended since this is a CLI entry
# module — confirm before importing it from library code.
setup_logger()
logger = get_logger(__name__)
def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser with the ``price`` and ``rl`` subcommands."""
    parser = argparse.ArgumentParser(description="Energy Trading ML Training CLI")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    price_parser = subparsers.add_parser("price", help="Train price prediction models")
    price_parser.add_argument(
        "--horizons",
        nargs="+",
        type=int,
        default=[1, 5, 15, 60],
        help="Prediction horizons in minutes",
    )
    price_parser.add_argument(
        "--output", type=str, default="models/price_prediction", help="Output directory"
    )

    rl_parser = subparsers.add_parser("rl", help="Train RL battery policy")
    rl_parser.add_argument(
        "--episodes", type=int, default=1000, help="Number of training episodes"
    )
    rl_parser.add_argument("--region", type=str, default="FR", help="Region to train for")
    rl_parser.add_argument(
        "--output", type=str, default="models/rl_battery", help="Output directory"
    )
    return parser


def _train_price(args: argparse.Namespace) -> None:
    """Train price models for ``args.horizons``, save them, and log per-horizon results.

    ``train_all`` is expected to return a mapping of horizon -> result dict.
    Successful results are read for ``mae``/``rmse``/``r2``; results that
    carry an ``"error"`` key are reported as warnings.
    """
    logger.info(f"Training price prediction models for horizons: {args.horizons}")
    trainer = PricePredictionTrainer()
    results = trainer.train_all(horizons=args.horizons)
    trainer.save_models(output_dir=args.output)
    logger.info("Training complete!")
    for horizon, result in results.items():
        if "error" in result:
            # Previously these were skipped silently, hiding failed horizons
            # behind the "Training complete!" message.
            logger.warning(f" {horizon}m: training failed: {result['error']}")
            continue
        logger.info(
            f" {horizon}m: MAE={result['mae']:.2f}, "
            f"RMSE={result['rmse']:.2f}, R2={result['r2']:.3f}"
        )


def _train_rl(args: argparse.Namespace) -> None:
    """Train the RL battery policy, save it, and log the final summary metrics."""
    logger.info(f"Training RL battery policy for {args.episodes} episodes")
    trainer = BatteryRLTrainer()
    results = trainer.train(n_episodes=args.episodes, region=args.region)
    trainer.save(output_dir=args.output)
    logger.info("Training complete!")
    logger.info(f" Final avg reward: {results['final_avg_reward']:.2f}")
    logger.info(f" Final epsilon: {results['final_epsilon']:.3f}")


def main(argv: "list[str] | None" = None) -> None:
    """Parse command-line arguments and run the selected training job.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as before;
            passing a list allows programmatic and test invocation.

    When no subcommand is given, the help text is printed and nothing is trained.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    if args.command == "price":
        _train_price(args)
    elif args.command == "rl":
        _train_rl(args)
    else:
        parser.print_help()
if __name__ == "__main__":
    # Pass main()'s return value to the interpreter as the exit status
    # (a None return exits with status 0), so the shell sees the CLI's outcome.
    raise SystemExit(main())