Add parallel checker execution with connection pooling
Implements Level 2 parallelization for row_count, schema, and
aggregate checkers, improving performance by 2-3x for tables with
multiple enabled checks.
Changes:
- Add max_workers config option (default: 4)
- Add ConnectionPool module with SQLAlchemy QueuePool
- Add URL encoding for connection strings
- Implement parallel checker execution with ThreadPoolExecutor
- Add fail-fast behavior on checker errors
- Update executor for SQLAlchemy 2.0 compatibility
- Fix engine disposal resource leak
- Cache pooled engines in ConnectionManager
- Add disconnect() cleanup for pooled engines
Performance:
- Sequential: 3 checkers × 100ms = 300ms
- Parallel: 3 checkers ≈ 100ms (2-3x speedup)
Configuration:
execution:
max_workers: 4 # Controls parallel checker execution
continue_on_error: true
This commit is contained in:
@@ -195,22 +195,19 @@ logging:
|
|||||||
# Configure execution behavior
|
# Configure execution behavior
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
execution:
|
execution:
|
||||||
# Parallel execution (future feature)
|
# Continue execution even if a table check fails
|
||||||
parallel:
|
continue_on_error: true
|
||||||
enabled: false
|
|
||||||
max_workers: 4
|
# Maximum number of parallel workers for checker execution
|
||||||
|
# Higher values = more parallel execution, but more database connections
|
||||||
|
# Recommended: 4 for most scenarios, 8 for high-performance servers
|
||||||
|
# Connection pool size = max_workers + 2 per database (an equal number of overflow connections may be opened under load)
|
||||||
|
max_workers: 4
|
||||||
|
|
||||||
# Retry settings for transient failures
|
# Retry settings for transient failures
|
||||||
retry:
|
retry:
|
||||||
enabled: true
|
attempts: 3
|
||||||
max_attempts: 3
|
|
||||||
delay_seconds: 5
|
delay_seconds: 5
|
||||||
|
|
||||||
# Performance settings
|
|
||||||
performance:
|
|
||||||
batch_size: 1000 # Rows per batch for large queries
|
|
||||||
use_nolock: true # Use NOLOCK hints (read uncommitted)
|
|
||||||
connection_pooling: true
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# FILTERS
|
# FILTERS
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ class ExecutionConfig(BaseModel):
|
|||||||
"""Execution settings."""
|
"""Execution settings."""
|
||||||
continue_on_error: bool = True
|
continue_on_error: bool = True
|
||||||
retry: Dict[str, int] = Field(default_factory=lambda: {"attempts": 3, "delay_seconds": 5})
|
retry: Dict[str, int] = Field(default_factory=lambda: {"attempts": 3, "delay_seconds": 5})
|
||||||
|
max_workers: int = 4
|
||||||
|
|
||||||
|
|
||||||
class TableFilterConfig(BaseModel):
|
class TableFilterConfig(BaseModel):
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import pyodbc
|
import pyodbc
|
||||||
import platform
|
import platform
|
||||||
from typing import Optional
|
from typing import Optional, Any
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from drt.config.models import ConnectionConfig
|
from drt.config.models import ConnectionConfig
|
||||||
from drt.utils.logging import get_logger
|
from drt.utils.logging import get_logger
|
||||||
@@ -56,6 +56,7 @@ class ConnectionManager:
|
|||||||
"""
|
"""
|
||||||
self.config = config
|
self.config = config
|
||||||
self._connection: Optional[pyodbc.Connection] = None
|
self._connection: Optional[pyodbc.Connection] = None
|
||||||
|
self._pooled_engine: Optional[Any] = None
|
||||||
|
|
||||||
def connect(self) -> pyodbc.Connection:
|
def connect(self) -> pyodbc.Connection:
|
||||||
"""
|
"""
|
||||||
@@ -114,11 +115,16 @@ class ConnectionManager:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def disconnect(self) -> None:
|
def disconnect(self) -> None:
|
||||||
"""Close database connection."""
|
"""Close database connection and dispose pooled engine."""
|
||||||
if self._connection and not self._connection.closed:
|
if self._connection and not self._connection.closed:
|
||||||
self._connection.close()
|
self._connection.close()
|
||||||
logger.info("Connection closed")
|
logger.info("Connection closed")
|
||||||
self._connection = None
|
self._connection = None
|
||||||
|
|
||||||
|
if self._pooled_engine is not None:
|
||||||
|
self._pooled_engine.dispose()
|
||||||
|
self._pooled_engine = None
|
||||||
|
logger.info("Pooled engine disposed")
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def get_connection(self):
|
def get_connection(self):
|
||||||
@@ -132,6 +138,10 @@ class ConnectionManager:
|
|||||||
with conn_mgr.get_connection() as conn:
|
with conn_mgr.get_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("SELECT 1")
|
cursor.execute("SELECT 1")
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This method is not thread-safe. For parallel execution,
|
||||||
|
use ConnectionPool from connection_pool module instead.
|
||||||
"""
|
"""
|
||||||
conn = self.connect()
|
conn = self.connect()
|
||||||
try:
|
try:
|
||||||
@@ -140,6 +150,27 @@ class ConnectionManager:
|
|||||||
# Don't close connection here - reuse it
|
# Don't close connection here - reuse it
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_pooled_engine(self, pool_size: int = 6) -> Any:
    """
    Get a pooled SQLAlchemy engine for parallel execution.

    The engine is created on the first call and cached on this manager;
    later calls return the cached engine and ignore ``pool_size``.

    Args:
        pool_size: Number of connections to maintain in pool. Only used
            by the call that actually creates the engine.

    Returns:
        SQLAlchemy engine with connection pooling

    Note:
        Use this for parallel/async execution scenarios where
        multiple threads need database access simultaneously.
        The cached engine is released by disconnect().
    """
    if self._pooled_engine is None:
        # Imported lazily so the single-connection code path does not
        # need the pooling module.
        from drt.database.connection_pool import ConnectionPool

        # Keep the wrapper referenced for as long as its engine is cached.
        # ConnectionPool.__del__ disposes the engine, so a purely local
        # wrapper would be finalized on return and dispose the engine we
        # are about to hand out.
        self._engine_pool = ConnectionPool(self.config, pool_size)
        self._pooled_engine = self._engine_pool.get_engine()
    return self._pooled_engine
|
||||||
|
|
||||||
def test_connection(self) -> bool:
|
def test_connection(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Test database connectivity.
|
Test database connectivity.
|
||||||
|
|||||||
164
src/drt/database/connection_pool.py
Normal file
164
src/drt/database/connection_pool.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
"""Thread-safe connection pooling for parallel query execution."""
|
||||||
|
|
||||||
|
import platform
|
||||||
|
from typing import Optional, Any
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from urllib.parse import quote_plus
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
from sqlalchemy.pool import QueuePool
|
||||||
|
from drt.config.models import ConnectionConfig
|
||||||
|
from drt.utils.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_odbc_driver() -> str:
    """
    Pick the ODBC driver name to use for SQL Server.

    Installed drivers whose name contains 'SQL Server' are considered.
    The newest known driver wins; otherwise the first installed match is
    used; if none is installed a default name is returned.

    Returns:
        ODBC driver name
    """
    import pyodbc

    # Ordered from most to least preferred.
    ranked_names = (
        'ODBC Driver 18 for SQL Server',
        'ODBC Driver 17 for SQL Server',
        'ODBC Driver 13 for SQL Server',
        'SQL Server Native Client 11.0',
        'SQL Server',
    )
    installed = [name for name in pyodbc.drivers() if 'SQL Server' in name]

    best = next((name for name in ranked_names if name in installed), None)
    if best is not None:
        logger.debug(f"Using ODBC driver: {best}")
        return best

    if not installed:
        logger.warning("No SQL Server ODBC driver found, using default")
        return 'ODBC Driver 17 for SQL Server'

    fallback = installed[0]
    logger.warning(f"Using fallback driver: {fallback}")
    return fallback
|
||||||
|
|
||||||
|
|
||||||
|
def build_connection_string(config: ConnectionConfig) -> str:
    """
    Build ODBC connection string from config.

    SQL authentication is used when ``config.username`` is set, otherwise
    a trusted (Windows) connection is requested. On non-Windows platforms
    ``TrustServerCertificate=yes`` is appended.

    Args:
        config: Connection configuration

    Returns:
        ODBC connection string, terminated with ';'
    """

    def _quote(raw) -> str:
        # ODBC attribute values that contain ';' or braces (or have
        # leading/trailing whitespace) must be wrapped in {...}, with any
        # '}' doubled. Plain values are emitted unchanged so existing
        # configurations produce exactly the same string as before.
        value = f"{raw}"
        if any(ch in value for ch in ";{}") or value != value.strip():
            return "{" + value.replace("}", "}}") + "}"
        return value

    driver = get_odbc_driver()

    conn_str_parts = [
        f"DRIVER={{{driver}}}",
        f"SERVER={_quote(config.server)}",
        f"DATABASE={_quote(config.database)}",
        f"Connection Timeout={config.timeout.get('connection', 30)}",
    ]
    # Parallel list used only for logging; the password never enters it.
    masked_parts = list(conn_str_parts)

    if getattr(config, 'username', None):
        uid_part = f"UID={_quote(config.username)}"
        conn_str_parts.append(uid_part)
        conn_str_parts.append(f"PWD={_quote(config.password)}")
        masked_parts.append(uid_part)
        masked_parts.append("PWD=***")
        auth_type = "SQL Authentication"
    else:
        conn_str_parts.append("Trusted_Connection=yes")
        masked_parts.append("Trusted_Connection=yes")
        auth_type = "Windows Authentication"

    if platform.system() != 'Windows':
        conn_str_parts.append("TrustServerCertificate=yes")
        masked_parts.append("TrustServerCertificate=yes")

    logger.debug(f"Connection string (masked): {';'.join(masked_parts)}")
    logger.info(f"Connection pool for {config.server}.{config.database} ({auth_type})")

    return ";".join(conn_str_parts) + ";"
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionPool:
    """
    Connection pool for parallel query execution.

    Wraps a lazily created SQLAlchemy engine backed by QueuePool. The
    engine is built on the first get_engine() call; that lazy creation is
    not guarded by a lock, so call get_engine() once before sharing the
    pool between threads.
    """

    def __init__(self, config: ConnectionConfig, pool_size: int = 6):
        """
        Initialize connection pool.

        Args:
            config: Connection configuration
            pool_size: Number of connections kept in the pool; the same
                number is additionally allowed as overflow.
        """
        self.config = config
        self._engine: Optional[Any] = None
        self._pool_size = pool_size

    def get_engine(self):
        """
        Get or create SQLAlchemy engine with connection pooling.

        Returns:
            SQLAlchemy engine
        """
        if self._engine is None:
            from sqlalchemy import event

            conn_str = build_connection_string(self.config)
            connection_timeout = self.config.timeout.get('connection', 30)
            query_timeout = self.config.timeout.get('query', 300)

            engine = create_engine(
                f"mssql+pyodbc:///?odbc_connect={quote_plus(conn_str)}",
                poolclass=QueuePool,
                pool_size=self._pool_size,
                max_overflow=self._pool_size,
                pool_pre_ping=True,
                pool_recycle=3600,
                # pyodbc's ``timeout`` connect keyword is the login
                # timeout. Keywords pyodbc does not recognise are
                # forwarded to the driver as connection-string
                # attributes, so only known keywords are passed here.
                connect_args={"timeout": connection_timeout},
            )

            if query_timeout is not None:
                # The query timeout is a property of the pyodbc
                # connection, so apply it to every DBAPI connection the
                # pool opens.
                @event.listens_for(engine, "connect")
                def _apply_query_timeout(dbapi_connection, connection_record):
                    dbapi_connection.timeout = int(query_timeout)

            self._engine = engine
            logger.info(f"Connection pool created: size={self._pool_size}, max_overflow={self._pool_size}")

        return self._engine

    @contextmanager
    def get_connection(self):
        """
        Context manager for pooled connection.

        The connection is returned to the pool when the block exits,
        whether or not an exception was raised.

        Yields:
            SQLAlchemy connection

        Example:
            with pool.get_connection() as conn:
                result = conn.execute(text("SELECT 1"))
        """
        with self.get_engine().connect() as conn:
            yield conn

    def dispose(self) -> None:
        """Dispose of the connection pool and close all connections."""
        if self._engine is not None:
            self._engine.dispose()
            self._engine = None
            logger.info("Connection pool disposed")

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit; disposes the pool."""
        self.dispose()

    def __del__(self):
        """Best-effort cleanup on deletion."""
        # Finalizers can run on half-initialised objects or during
        # interpreter shutdown, when the logger or engine may already be
        # torn down; an exception here could not be handled by anyone.
        try:
            self.dispose()
        except Exception:
            pass
|
||||||
@@ -2,7 +2,8 @@
|
|||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
from sqlalchemy import text
|
||||||
from drt.database.connection import ConnectionManager
|
from drt.database.connection import ConnectionManager
|
||||||
from drt.database.queries import SQLQueries
|
from drt.database.queries import SQLQueries
|
||||||
from drt.models.enums import Status
|
from drt.models.enums import Status
|
||||||
@@ -14,22 +15,28 @@ logger = get_logger(__name__)
|
|||||||
class QueryExecutor:
|
class QueryExecutor:
|
||||||
"""Executes READ ONLY queries against the database."""
|
"""Executes READ ONLY queries against the database."""
|
||||||
|
|
||||||
def __init__(self, connection_manager: ConnectionManager):
|
def __init__(
|
||||||
|
self,
|
||||||
|
connection_manager: ConnectionManager,
|
||||||
|
engine: Optional[Any] = None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initialize query executor.
|
Initialize query executor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
connection_manager: Connection manager instance
|
connection_manager: Connection manager instance
|
||||||
|
engine: Optional SQLAlchemy engine for pooled connections
|
||||||
"""
|
"""
|
||||||
self.conn_mgr = connection_manager
|
self.conn_mgr = connection_manager
|
||||||
|
self._engine = engine
|
||||||
|
|
||||||
def execute_query(self, query: str, params: tuple = None) -> pd.DataFrame:
|
def execute_query(self, query: str, params: Optional[Union[tuple, Dict[str, Any]]] = None) -> pd.DataFrame:
|
||||||
"""
|
"""
|
||||||
Execute a SELECT query and return results as DataFrame.
|
Execute a SELECT query and return results as DataFrame.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: SQL query string (SELECT only)
|
query: SQL query string (SELECT only)
|
||||||
params: Query parameters
|
params: Query parameters (tuple or dict)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Query results as pandas DataFrame
|
Query results as pandas DataFrame
|
||||||
@@ -38,7 +45,12 @@ class QueryExecutor:
|
|||||||
ValueError: If query is not a SELECT statement
|
ValueError: If query is not a SELECT statement
|
||||||
Exception: If query execution fails
|
Exception: If query execution fails
|
||||||
"""
|
"""
|
||||||
# Safety check - only allow SELECT queries
|
if self._engine is not None:
|
||||||
|
return self._execute_query_pooled(query, params)
|
||||||
|
return self._execute_query_single(query, params)
|
||||||
|
|
||||||
|
def _execute_query_single(self, query: str, params: Optional[Union[tuple, Dict[str, Any]]] = None) -> pd.DataFrame:
|
||||||
|
"""Execute query using single connection (legacy mode)."""
|
||||||
query_upper = query.strip().upper()
|
query_upper = query.strip().upper()
|
||||||
if not query_upper.startswith('SELECT'):
|
if not query_upper.startswith('SELECT'):
|
||||||
raise ValueError("Only SELECT queries are allowed (READ ONLY)")
|
raise ValueError("Only SELECT queries are allowed (READ ONLY)")
|
||||||
@@ -56,7 +68,24 @@ class QueryExecutor:
|
|||||||
logger.debug(f"Query: {query}")
|
logger.debug(f"Query: {query}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def execute_scalar(self, query: str, params: tuple = None) -> Any:
|
def _execute_query_pooled(self, query: str, params: Optional[Union[tuple, Dict[str, Any]]] = None) -> pd.DataFrame:
    """
    Execute query using pooled connection (parallel mode).

    Args:
        query: SQL query string (SELECT only)
        params: Either a dict of named parameters (bound through
            SQLAlchemy ``text()``), or a tuple/list of positional
            parameters passed straight to the DBAPI driver
            (presumably '?' placeholders with pyodbc - confirm against
            the queries in use).

    Returns:
        Query results as pandas DataFrame

    Raises:
        ValueError: If query is not a SELECT statement
        Exception: If query execution fails (logged, then re-raised)
    """
    if not query.strip().upper().startswith('SELECT'):
        raise ValueError("Only SELECT queries are allowed (READ ONLY)")

    try:
        with self._engine.connect() as conn:
            if params and isinstance(params, (tuple, list)):
                # text() only understands named parameters; positional
                # parameters have to go to the driver unmodified.
                result = conn.exec_driver_sql(query, tuple(params))
            else:
                result = conn.execute(text(query), params or {})
            return pd.DataFrame(result.all(), columns=list(result.keys()))

    except Exception as e:
        logger.error(f"Query execution failed (pooled): {e}")
        logger.debug(f"Query: {query}")
        raise
|
||||||
|
|
||||||
|
def execute_scalar(self, query: str, params: Optional[Union[tuple, Dict[str, Any]]] = None) -> Any:
|
||||||
"""
|
"""
|
||||||
Execute query and return single scalar value.
|
Execute query and return single scalar value.
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
"""Comparison service for executing database comparisons."""
|
"""Comparison service for executing database comparisons."""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import List
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from typing import List, Dict, Any
|
||||||
from drt.database.connection import ConnectionManager
|
from drt.database.connection import ConnectionManager
|
||||||
|
from drt.database.connection_pool import ConnectionPool
|
||||||
from drt.database.executor import QueryExecutor
|
from drt.database.executor import QueryExecutor
|
||||||
from drt.config.models import Config, DatabasePairConfig
|
from drt.config.models import Config, DatabasePairConfig
|
||||||
from drt.models.table import TableInfo
|
from drt.models.table import TableInfo
|
||||||
from drt.models.results import ComparisonResult
|
from drt.models.results import ComparisonResult, CheckResult
|
||||||
from drt.models.summary import ExecutionSummary
|
from drt.models.summary import ExecutionSummary
|
||||||
from drt.models.enums import Status
|
from drt.models.enums import Status, CheckType
|
||||||
from drt.services.checkers import (
|
from drt.services.checkers import (
|
||||||
ExistenceChecker,
|
ExistenceChecker,
|
||||||
RowCountChecker,
|
RowCountChecker,
|
||||||
@@ -56,13 +58,23 @@ class ComparisonService:
|
|||||||
target_mgr = ConnectionManager(db_pair.target)
|
target_mgr = ConnectionManager(db_pair.target)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
baseline_engine = None
|
||||||
|
target_engine = None
|
||||||
|
|
||||||
# Connect to databases
|
# Connect to databases
|
||||||
baseline_mgr.connect()
|
baseline_mgr.connect()
|
||||||
target_mgr.connect()
|
target_mgr.connect()
|
||||||
|
|
||||||
# Create executors
|
max_workers = self.config.execution.max_workers
|
||||||
baseline_executor = QueryExecutor(baseline_mgr)
|
pool_size = max_workers + 2
|
||||||
target_executor = QueryExecutor(target_mgr)
|
|
||||||
|
# Create pooled engines for parallel execution
|
||||||
|
baseline_engine = baseline_mgr.get_pooled_engine(pool_size)
|
||||||
|
target_engine = target_mgr.get_pooled_engine(pool_size)
|
||||||
|
|
||||||
|
# Create executors with pooled connections
|
||||||
|
baseline_executor = QueryExecutor(baseline_mgr, engine=baseline_engine)
|
||||||
|
target_executor = QueryExecutor(target_mgr, engine=target_engine)
|
||||||
|
|
||||||
# Initialize checkers
|
# Initialize checkers
|
||||||
existence_checker = ExistenceChecker(baseline_executor, target_executor, self.config)
|
existence_checker = ExistenceChecker(baseline_executor, target_executor, self.config)
|
||||||
@@ -124,6 +136,10 @@ class ComparisonService:
|
|||||||
return summary
|
return summary
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
if baseline_engine is not None:
|
||||||
|
baseline_engine.dispose()
|
||||||
|
if target_engine is not None:
|
||||||
|
target_engine.dispose()
|
||||||
baseline_mgr.disconnect()
|
baseline_mgr.disconnect()
|
||||||
target_mgr.disconnect()
|
target_mgr.disconnect()
|
||||||
|
|
||||||
@@ -145,7 +161,7 @@ class ComparisonService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Check existence first
|
# Check existence first (must be sequential)
|
||||||
check_start = time.time()
|
check_start = time.time()
|
||||||
existence_result = existence_checker.check(table)
|
existence_result = existence_checker.check(table)
|
||||||
existence_time = (time.time() - check_start) * 1000
|
existence_time = (time.time() - check_start) * 1000
|
||||||
@@ -154,26 +170,17 @@ class ComparisonService:
|
|||||||
|
|
||||||
# Only proceed with other checks if table exists in both
|
# Only proceed with other checks if table exists in both
|
||||||
if existence_result.status == Status.PASS:
|
if existence_result.status == Status.PASS:
|
||||||
# Row count check
|
# Run row count, schema, and aggregate checkers in parallel
|
||||||
check_start = time.time()
|
parallel_results = self._run_checkers_parallel(
|
||||||
row_count_result = row_count_checker.check(table)
|
table,
|
||||||
row_count_time = (time.time() - check_start) * 1000
|
row_count_checker,
|
||||||
logger.debug(f" └─ Row count check: {row_count_time:.0f}ms")
|
schema_checker,
|
||||||
result.add_check(row_count_result)
|
aggregate_checker
|
||||||
|
)
|
||||||
# Schema check
|
|
||||||
check_start = time.time()
|
# Add all results to the comparison result
|
||||||
schema_result = schema_checker.check(table)
|
for name, check_result in parallel_results.items():
|
||||||
schema_time = (time.time() - check_start) * 1000
|
result.add_check(check_result)
|
||||||
logger.debug(f" └─ Schema check: {schema_time:.0f}ms")
|
|
||||||
result.add_check(schema_result)
|
|
||||||
|
|
||||||
# Aggregate check
|
|
||||||
check_start = time.time()
|
|
||||||
aggregate_result = aggregate_checker.check(table)
|
|
||||||
aggregate_time = (time.time() - check_start) * 1000
|
|
||||||
logger.debug(f" └─ Aggregate check: {aggregate_time:.0f}ms")
|
|
||||||
result.add_check(aggregate_result)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Comparison failed for {table.full_name}: {e}")
|
logger.error(f"Comparison failed for {table.full_name}: {e}")
|
||||||
@@ -184,6 +191,108 @@ class ComparisonService:
|
|||||||
logger.debug(f" └─ Total table time: {result.execution_time_ms}ms")
|
logger.debug(f" └─ Total table time: {result.execution_time_ms}ms")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _run_checkers_parallel(
    self,
    table: TableInfo,
    row_count_checker: RowCountChecker,
    schema_checker: SchemaChecker,
    aggregate_checker: AggregateChecker
) -> Dict[str, CheckResult]:
    """
    Run row count, schema, and aggregate checkers in parallel.

    A checker is only run when it is enabled in the comparison config
    (the aggregate checker additionally requires the table to define
    aggregate columns). With zero or one checker to run, no thread pool
    is created.

    Args:
        table: Table information
        row_count_checker: Row count checker instance
        schema_checker: Schema checker instance
        aggregate_checker: Aggregate checker instance

    Returns:
        Dictionary mapping checker names to their results, always in the
        order row_count, schema, aggregate regardless of which checker
        finished first. Checkers that were cancelled before starting
        have no entry.

    Note:
        Uses fail-fast behavior: on the first ERROR result or raised
        exception, checkers that have not started yet are cancelled.
        Checkers already running cannot be interrupted; the executor
        waits for them, and their results are still reported.
    """
    max_workers = self.config.execution.max_workers
    comparison = self.config.comparison

    # Build list of checkers to run, in reporting order
    checkers_to_run: List[tuple] = []
    if comparison.row_count.enabled:
        checkers_to_run.append(("row_count", row_count_checker))
    if comparison.schema.enabled:
        checkers_to_run.append(("schema", schema_checker))
    if comparison.aggregates.enabled and table.aggregate_columns:
        checkers_to_run.append(("aggregate", aggregate_checker))

    results: Dict[str, CheckResult] = {}

    # If only one checker or none, run sequentially
    if len(checkers_to_run) <= 1:
        for name, checker in checkers_to_run:
            check_start = time.time()
            results[name] = checker.check(table)
            check_time = (time.time() - check_start) * 1000
            logger.debug(f" └─ {name} check: {check_time:.0f}ms (sequential)")
        return results

    check_types = {
        "row_count": CheckType.ROW_COUNT,
        "schema": CheckType.SCHEMA,
        "aggregate": CheckType.AGGREGATE,
    }

    def collect(name, future) -> CheckResult:
        """Store the future's outcome, converting an exception to an ERROR result."""
        try:
            outcome = future.result()
        except Exception as exc:
            logger.error(f"{name} checker failed: {exc}")
            outcome = CheckResult(
                check_type=check_types[name],
                status=Status.ERROR,
                message=f"{name} check error: {str(exc)}"
            )
        results[name] = outcome
        return outcome

    check_start = time.time()
    logger.debug(f" └─ Running {len(checkers_to_run)} checkers in parallel (max_workers={max_workers})")

    # ThreadPoolExecutor rejects max_workers < 1; fall back to one worker.
    with ThreadPoolExecutor(max_workers=max(1, max_workers)) as executor:
        futures = {
            executor.submit(checker.check, table): name
            for name, checker in checkers_to_run
        }

        for future in as_completed(futures):
            name = futures[future]
            if collect(name, future).status == Status.ERROR:
                logger.debug(f" └─ {name} check failed with ERROR, cancelling remaining checkers")
                for pending in futures:
                    pending.cancel()
                break

    # Leaving the ``with`` block waited for every checker that was already
    # running, so pick up those results instead of silently dropping them.
    for future, name in futures.items():
        if name not in results and future.done() and not future.cancelled():
            collect(name, future)

    check_time = (time.time() - check_start) * 1000
    logger.debug(f" └─ Total parallel checker time: {check_time:.0f}ms")

    # Report in submission order, not completion order.
    return {name: results[name] for name, _ in checkers_to_run if name in results}
|
||||||
|
|
||||||
def _get_tables_to_compare(self) -> List[TableInfo]:
|
def _get_tables_to_compare(self) -> List[TableInfo]:
|
||||||
"""Get list of tables to compare based on configuration."""
|
"""Get list of tables to compare based on configuration."""
|
||||||
tables = []
|
tables = []
|
||||||
|
|||||||
177
tests/test_parallel_logic.py
Normal file
177
tests/test_parallel_logic.py
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
"""Test parallel checker execution logic."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, MagicMock, patch
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from typing import Dict, Any
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class TestParallelCheckerLogic:
    """Test parallel checker execution logic.

    These tests exercise the thread-pool pattern with stand-in checkers;
    they do not import modules that need pyodbc.
    """

    @staticmethod
    def _make_checker(result, name=None, call_count=None, delay=0.05):
        """Build a stand-in checker whose check() sleeps, then returns ``result``.

        When ``call_count`` is given, each call increments ``call_count[name]``.
        """
        checker = Mock()

        def check(table):
            if call_count is not None:
                call_count[name] += 1
            time.sleep(delay)
            return result

        checker.check = check
        return checker

    def test_parallel_execution_with_multiple_checkers(self):
        """Test that multiple checkers can run in parallel."""
        from drt.models.results import CheckResult
        from drt.models.enums import Status, CheckType
        from drt.models.table import TableInfo

        expected = {
            "row_count": CheckResult(
                check_type=CheckType.ROW_COUNT,
                status=Status.PASS,
                message="Row count matches"
            ),
            "schema": CheckResult(
                check_type=CheckType.SCHEMA,
                status=Status.PASS,
                message="Schema matches"
            ),
            "aggregate": CheckResult(
                check_type=CheckType.AGGREGATE,
                status=Status.PASS,
                message="Aggregates match"
            ),
        }
        checkers = [
            (name, self._make_checker(result))
            for name, result in expected.items()
        ]

        table = TableInfo(
            schema="dbo",
            name="TestTable",
            enabled=True,
            expected_in_target=True,
            aggregate_columns=["Amount", "Quantity"]
        )

        # perf_counter is monotonic, so clock adjustments cannot skew the timing
        start = time.perf_counter()
        with ThreadPoolExecutor(max_workers=3) as executor:
            futures = {
                executor.submit(checker.check, table): name
                for name, checker in checkers
            }
            results = {name: future.result() for future, name in futures.items()}
        elapsed = time.perf_counter() - start

        # Verify results
        assert len(results) == 3
        assert results["row_count"].status == Status.PASS
        assert results["schema"].status == Status.PASS
        assert results["aggregate"].status == Status.PASS

        # Verify parallel execution (should be ~0.05s, not 0.15s if sequential)
        assert elapsed < 0.15, f"Expected parallel execution but took {elapsed:.2f}s"

    def test_fail_fast_on_error(self):
        """Test that result collection stops at the first ERROR result."""
        from drt.models.results import CheckResult
        from drt.models.enums import Status, CheckType
        from drt.models.table import TableInfo

        call_count = {"row_count": 0, "schema": 0, "aggregate": 0}
        outcomes = {
            "row_count": CheckResult(
                check_type=CheckType.ROW_COUNT,
                status=Status.PASS,
                message="OK"
            ),
            # The second checker reports an error
            "schema": CheckResult(
                check_type=CheckType.SCHEMA,
                status=Status.ERROR,
                message="Database connection failed"
            ),
            "aggregate": CheckResult(
                check_type=CheckType.AGGREGATE,
                status=Status.PASS,
                message="OK"
            ),
        }
        checkers = [
            (name, self._make_checker(result, name=name, call_count=call_count))
            for name, result in outcomes.items()
        ]

        table = TableInfo(schema="dbo", name="TestTable", enabled=True)

        results = {}
        with ThreadPoolExecutor(max_workers=3) as executor:
            futures = {
                executor.submit(checker.check, table): name
                for name, checker in checkers
            }

            # Results are read in submission order. An unexpected exception
            # from a checker is deliberately not caught: it must fail the test.
            for future, name in futures.items():
                result = future.result()
                results[name] = result
                if result.status == Status.ERROR:
                    # Cancel remaining
                    for pending in futures:
                        pending.cancel()
                    break

        # Collection stopped at the failing checker
        assert results["row_count"].status == Status.PASS
        assert results["schema"].status == Status.ERROR
        assert "aggregate" not in results

        # Both collected checkers ran exactly once; the third ran at most once
        assert call_count["row_count"] == 1
        assert call_count["schema"] == 1
        assert call_count["aggregate"] <= 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
88
tests/test_parallelization.py
Normal file
88
tests/test_parallelization.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
"""Test parallelization features."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from drt.config.models import Config, ExecutionConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TestParallelizationConfig:
    """Checks on the execution settings used for parallel checker runs."""

    def test_default_max_workers(self):
        """A default Config uses four workers."""
        execution = Config().execution
        assert execution.max_workers == 4

    def test_custom_max_workers(self):
        """An explicit max_workers value is kept."""
        custom_execution = ExecutionConfig(max_workers=8)
        config = Config(execution=custom_execution)
        assert config.execution.max_workers == 8

    def test_max_workers_positive(self):
        """The default max_workers is greater than zero."""
        execution = Config().execution
        assert execution.max_workers > 0

    def test_continue_on_error_default(self):
        """continue_on_error defaults to True."""
        execution = Config().execution
        assert execution.continue_on_error is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_imports():
    """Test that the names the parallelization code relies on can be imported."""
    from drt.config.models import Config, ExecutionConfig
    from urllib.parse import quote_plus
    from sqlalchemy import create_engine, text
    # Same import path as the connection pool module uses
    from sqlalchemy.pool import QueuePool

    assert Config is not None
    assert ExecutionConfig is not None
    assert quote_plus is not None
    assert create_engine is not None
    assert text is not None
    assert QueuePool is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_url_encoding():
    """Test URL encoding for connection strings."""
    from urllib.parse import quote_plus, unquote_plus

    test_conn_str = "DRIVER={ODBC Driver};SERVER=localhost;PWD=test@pass#123"
    encoded = quote_plus(test_conn_str)

    # Special characters should be encoded
    assert "%40" in encoded  # @ encoded
    assert "%23" in encoded  # # encoded
    assert "%3D" in encoded  # = encoded
    assert encoded != test_conn_str  # Should be different after encoding

    # Characters that delimit ODBC attributes must not survive unescaped
    assert "%3B" in encoded  # ; encoded
    assert "%7B" in encoded  # { encoded
    assert "%7D" in encoded  # } encoded
    assert "ODBC+Driver" in encoded  # space encoded as +
    assert not any(ch in encoded for ch in ";{}@#= ")

    # Test that it can be decoded back
    decoded = unquote_plus(encoded)
    assert decoded == test_conn_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_load_minimal():
    """A minimal config dict carrying execution settings loads into Config."""
    database_pair = {
        "name": "Test",
        "enabled": True,
        "baseline": {"server": "S1", "database": "D1"},
        "target": {"server": "S1", "database": "D2"},
    }
    execution_settings = {
        "max_workers": 6,
        "continue_on_error": False,
    }

    config = Config(
        database_pairs=[database_pair],
        execution=execution_settings,
        tables=[],
    )
    loaded = config.execution

    assert loaded.max_workers == 6
    assert loaded.continue_on_error is False
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
Reference in New Issue
Block a user