"""Background task that trains the requested ML model and reports the outcome."""

import asyncio
import uuid
from datetime import datetime, timezone
from typing import Dict

from app.models.schemas import TrainingRequest, TrainingStatusEnum
from app.utils.logger import get_logger

logger = get_logger(__name__)

# Number of episodes used to train the RL battery agent unless the caller overrides it.
DEFAULT_RL_EPISODES = 500


def _completed_result(training_id: str, results) -> Dict:
    """Build the payload returned for a successfully completed training run."""
    return {
        "training_id": training_id,
        "status": TrainingStatusEnum.COMPLETED,
        "results": results,
        # Naive UTC ISO-8601 string: the same format datetime.utcnow().isoformat()
        # produced, without the Python 3.12 deprecation of utcnow().
        "completed_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
    }


def _train_price_prediction(request: TrainingRequest):
    """Train and save the price-prediction models; return the trainer's results."""
    # Imported lazily so the ML stack is only loaded when this model type is trained.
    from app.ml.price_prediction.trainer import PricePredictionTrainer

    trainer = PricePredictionTrainer()
    # A falsy horizon (None, or 0) passes horizons=None to train_all.
    horizons = [request.horizon] if request.horizon else None
    results = trainer.train_all(horizons=horizons)
    trainer.save_models()
    return results


def _train_rl_battery(n_episodes: int):
    """Train and save the battery RL agent for ``n_episodes``; return its results."""
    # Imported lazily so the RL stack is only loaded when this model type is trained.
    from app.ml.rl_battery.trainer import BatteryRLTrainer

    trainer = BatteryRLTrainer()
    results = trainer.train(n_episodes=n_episodes)
    trainer.save()
    return results


def _run_training(request: TrainingRequest, rl_episodes: int):
    """Synchronously dispatch to the trainer for ``request.model_type``.

    Raises:
        ValueError: if the model type is not supported.
    """
    model_type = request.model_type.value
    if model_type == "price_prediction":
        return _train_price_prediction(request)
    if model_type == "rl_battery":
        return _train_rl_battery(rl_episodes)
    raise ValueError(f"Unknown model type: {request.model_type}")


async def train_model_task(
    training_id: str,
    request: TrainingRequest,
    *,
    rl_episodes: int = DEFAULT_RL_EPISODES,
) -> Dict:
    """Train the model described by ``request`` and return a completion payload.

    The trainers are synchronous, so the work runs in the event loop's default
    executor. Other coroutines are not blocked while training is in progress.

    Args:
        training_id: Identifier echoed back in the result payload.
        request: Training request; ``model_type.value`` selects the trainer and
            ``horizon`` (price prediction only) limits which horizons are trained.
        rl_episodes: Episodes used when ``model_type`` is ``"rl_battery"``.

    Returns:
        Dict with ``training_id``, ``status`` (COMPLETED), ``results`` and a
        naive-UTC ISO ``completed_at`` timestamp.

    Raises:
        ValueError: for an unsupported model type.
        Exception: any error from the trainer is logged with its traceback and re-raised.
    """
    logger.info(f"Training model: {request.model_type.value}, horizon: {request.horizon}")
    loop = asyncio.get_running_loop()
    try:
        results = await loop.run_in_executor(None, _run_training, request, rl_episodes)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error(f"...{e}") discarded.
        logger.exception(f"Training failed: {e}")
        raise
    return _completed_result(training_id, results)


__all__ = ["train_model_task"]