Add parallel checker execution with connection pooling
Implements Level 2 parallelization for row_count, schema, and
aggregate checkers, improving performance by 2-3x for tables with
multiple enabled checks.
Changes:
- Add max_workers config option (default: 4)
- Add ConnectionPool module with SQLAlchemy QueuePool
- Add URL encoding for connection strings
- Implement parallel checker execution with ThreadPoolExecutor
- Add fail-fast behavior on checker errors
- Update executor for SQLAlchemy 2.0 compatibility
- Fix engine disposal resource leak
- Cache pooled engines in ConnectionManager
- Add disconnect() cleanup for pooled engines
Performance:
- Sequential: 3 checkers × 100ms = 300ms
- Parallel: 3 checkers ≈ 100ms (2-3x speedup)
Configuration:
execution:
max_workers: 4 # Controls parallel checker execution
continue_on_error: true
This commit is contained in:
@@ -195,23 +195,20 @@ logging:
|
||||
# Configure execution behavior
|
||||
# ============================================================================
|
||||
execution:
|
||||
# Parallel execution (future feature)
|
||||
parallel:
|
||||
enabled: false
|
||||
max_workers: 4
|
||||
# Continue execution even if a table check fails
|
||||
continue_on_error: true
|
||||
|
||||
# Maximum number of parallel workers for checker execution
|
||||
# Higher values = more parallel execution, but more database connections
|
||||
# Recommended: 4 for most scenarios, 8 for high-performance servers
|
||||
# Connection pool size = max_workers + 2
|
||||
max_workers: 4
|
||||
|
||||
# Retry settings for transient failures
|
||||
retry:
|
||||
enabled: true
|
||||
max_attempts: 3
|
||||
attempts: 3
|
||||
delay_seconds: 5
|
||||
|
||||
# Performance settings
|
||||
performance:
|
||||
batch_size: 1000 # Rows per batch for large queries
|
||||
use_nolock: true # Use NOLOCK hints (read uncommitted)
|
||||
connection_pooling: true
|
||||
|
||||
# ============================================================================
|
||||
# FILTERS
|
||||
# Global filters applied to all tables
|
||||
|
||||
@@ -78,6 +78,7 @@ class ExecutionConfig(BaseModel):
|
||||
"""Execution settings."""
|
||||
continue_on_error: bool = True
|
||||
retry: Dict[str, int] = Field(default_factory=lambda: {"attempts": 3, "delay_seconds": 5})
|
||||
max_workers: int = 4
|
||||
|
||||
|
||||
class TableFilterConfig(BaseModel):
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import pyodbc
|
||||
import platform
|
||||
from typing import Optional
|
||||
from typing import Optional, Any
|
||||
from contextlib import contextmanager
|
||||
from drt.config.models import ConnectionConfig
|
||||
from drt.utils.logging import get_logger
|
||||
@@ -56,6 +56,7 @@ class ConnectionManager:
|
||||
"""
|
||||
self.config = config
|
||||
self._connection: Optional[pyodbc.Connection] = None
|
||||
self._pooled_engine: Optional[Any] = None
|
||||
|
||||
def connect(self) -> pyodbc.Connection:
|
||||
"""
|
||||
@@ -114,12 +115,17 @@ class ConnectionManager:
|
||||
raise
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Close database connection."""
|
||||
"""Close database connection and dispose pooled engine."""
|
||||
if self._connection and not self._connection.closed:
|
||||
self._connection.close()
|
||||
logger.info("Connection closed")
|
||||
self._connection = None
|
||||
|
||||
if self._pooled_engine is not None:
|
||||
self._pooled_engine.dispose()
|
||||
self._pooled_engine = None
|
||||
logger.info("Pooled engine disposed")
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self):
|
||||
"""
|
||||
@@ -132,6 +138,10 @@ class ConnectionManager:
|
||||
with conn_mgr.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
|
||||
Note:
|
||||
This method is not thread-safe. For parallel execution,
|
||||
use ConnectionPool from connection_pool module instead.
|
||||
"""
|
||||
conn = self.connect()
|
||||
try:
|
||||
@@ -140,6 +150,27 @@ class ConnectionManager:
|
||||
# Don't close connection here - reuse it
|
||||
pass
|
||||
|
||||
def get_pooled_engine(self, pool_size: int = 6) -> Any:
|
||||
"""
|
||||
Get a pooled SQLAlchemy engine for parallel execution.
|
||||
|
||||
Args:
|
||||
pool_size: Number of connections to maintain in pool
|
||||
|
||||
Returns:
|
||||
SQLAlchemy engine with connection pooling
|
||||
|
||||
Note:
|
||||
Use this for parallel/async execution scenarios where
|
||||
multiple threads need database access simultaneously.
|
||||
The engine is cached and reused for subsequent calls.
|
||||
"""
|
||||
if self._pooled_engine is None:
|
||||
from drt.database.connection_pool import ConnectionPool
|
||||
pool = ConnectionPool(self.config, pool_size)
|
||||
self._pooled_engine = pool.get_engine()
|
||||
return self._pooled_engine
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
"""
|
||||
Test database connectivity.
|
||||
|
||||
164
src/drt/database/connection_pool.py
Normal file
164
src/drt/database/connection_pool.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Thread-safe connection pooling for parallel query execution."""
|
||||
|
||||
import platform
|
||||
from typing import Optional, Any
|
||||
from contextlib import contextmanager
|
||||
from urllib.parse import quote_plus
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.pool import QueuePool
|
||||
from drt.config.models import ConnectionConfig
|
||||
from drt.utils.logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_odbc_driver() -> str:
|
||||
"""
|
||||
Detect available ODBC driver for SQL Server.
|
||||
|
||||
Returns:
|
||||
ODBC driver name
|
||||
"""
|
||||
import pyodbc
|
||||
|
||||
drivers = [driver for driver in pyodbc.drivers() if 'SQL Server' in driver]
|
||||
|
||||
preferred_order = [
|
||||
'ODBC Driver 18 for SQL Server',
|
||||
'ODBC Driver 17 for SQL Server',
|
||||
'ODBC Driver 13 for SQL Server',
|
||||
'SQL Server Native Client 11.0',
|
||||
'SQL Server'
|
||||
]
|
||||
|
||||
for preferred in preferred_order:
|
||||
if preferred in drivers:
|
||||
logger.debug(f"Using ODBC driver: {preferred}")
|
||||
return preferred
|
||||
|
||||
if drivers:
|
||||
logger.warning(f"Using fallback driver: {drivers[0]}")
|
||||
return drivers[0]
|
||||
|
||||
logger.warning("No SQL Server ODBC driver found, using default")
|
||||
return 'ODBC Driver 17 for SQL Server'
|
||||
|
||||
|
||||
def build_connection_string(config: ConnectionConfig) -> str:
|
||||
"""
|
||||
Build ODBC connection string from config.
|
||||
|
||||
Args:
|
||||
config: Connection configuration
|
||||
|
||||
Returns:
|
||||
ODBC connection string
|
||||
"""
|
||||
driver = get_odbc_driver()
|
||||
|
||||
conn_str_parts = [
|
||||
f"DRIVER={{{driver}}}",
|
||||
f"SERVER={config.server}",
|
||||
f"DATABASE={config.database}",
|
||||
f"Connection Timeout={config.timeout.get('connection', 30)}"
|
||||
]
|
||||
|
||||
if hasattr(config, 'username') and config.username:
|
||||
conn_str_parts.append(f"UID={config.username}")
|
||||
conn_str_parts.append(f"PWD={config.password}")
|
||||
auth_type = "SQL Authentication"
|
||||
else:
|
||||
conn_str_parts.append("Trusted_Connection=yes")
|
||||
auth_type = "Windows Authentication"
|
||||
|
||||
if platform.system() != 'Windows':
|
||||
conn_str_parts.append("TrustServerCertificate=yes")
|
||||
|
||||
logger.debug(f"Connection string (masked): {';'.join([p if 'PWD=' not in p else 'PWD=***' for p in conn_str_parts])}")
|
||||
logger.info(f"Connection pool for {config.server}.{config.database} ({auth_type})")
|
||||
|
||||
return ";".join(conn_str_parts) + ";"
|
||||
|
||||
|
||||
class ConnectionPool:
|
||||
"""Thread-safe connection pool for parallel query execution."""
|
||||
|
||||
def __init__(self, config: ConnectionConfig, pool_size: int = 6):
|
||||
"""
|
||||
Initialize connection pool.
|
||||
|
||||
Args:
|
||||
config: Connection configuration
|
||||
pool_size: Maximum number of connections in pool
|
||||
"""
|
||||
self.config = config
|
||||
self._engine: Optional[Any] = None
|
||||
self._pool_size = pool_size
|
||||
|
||||
def get_engine(self):
|
||||
"""
|
||||
Get or create SQLAlchemy engine with connection pooling.
|
||||
|
||||
Returns:
|
||||
SQLAlchemy engine
|
||||
"""
|
||||
if self._engine is None:
|
||||
conn_str = build_connection_string(self.config)
|
||||
|
||||
query_timeout = self.config.timeout.get('query', 300)
|
||||
|
||||
self._engine = create_engine(
|
||||
f"mssql+pyodbc:///?odbc_connect={quote_plus(conn_str)}",
|
||||
poolclass=QueuePool,
|
||||
pool_size=self._pool_size,
|
||||
max_overflow=self._pool_size,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600,
|
||||
connect_args={
|
||||
"timeout": query_timeout,
|
||||
"timeout_before_cancel": query_timeout - 10
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Connection pool created: size={self._pool_size}, max_overflow={self._pool_size}")
|
||||
|
||||
return self._engine
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self):
|
||||
"""
|
||||
Context manager for pooled connection.
|
||||
|
||||
Yields:
|
||||
SQLAlchemy connection
|
||||
|
||||
Example:
|
||||
with pool.get_connection() as conn:
|
||||
result = conn.execute(text("SELECT 1"))
|
||||
"""
|
||||
engine = self.get_engine()
|
||||
conn = engine.connect()
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def dispose(self) -> None:
|
||||
"""Dispose of the connection pool and close all connections."""
|
||||
if self._engine is not None:
|
||||
self._engine.dispose()
|
||||
self._engine = None
|
||||
logger.info("Connection pool disposed")
|
||||
|
||||
def __enter__(self):
|
||||
"""Context manager entry."""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Context manager exit."""
|
||||
self.dispose()
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup on deletion."""
|
||||
self.dispose()
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
import pandas as pd
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from sqlalchemy import text
|
||||
from drt.database.connection import ConnectionManager
|
||||
from drt.database.queries import SQLQueries
|
||||
from drt.models.enums import Status
|
||||
@@ -14,22 +15,28 @@ logger = get_logger(__name__)
|
||||
class QueryExecutor:
|
||||
"""Executes READ ONLY queries against the database."""
|
||||
|
||||
def __init__(self, connection_manager: ConnectionManager):
|
||||
def __init__(
|
||||
self,
|
||||
connection_manager: ConnectionManager,
|
||||
engine: Optional[Any] = None
|
||||
):
|
||||
"""
|
||||
Initialize query executor.
|
||||
|
||||
Args:
|
||||
connection_manager: Connection manager instance
|
||||
engine: Optional SQLAlchemy engine for pooled connections
|
||||
"""
|
||||
self.conn_mgr = connection_manager
|
||||
self._engine = engine
|
||||
|
||||
def execute_query(self, query: str, params: tuple = None) -> pd.DataFrame:
|
||||
def execute_query(self, query: str, params: Optional[Union[tuple, Dict[str, Any]]] = None) -> pd.DataFrame:
|
||||
"""
|
||||
Execute a SELECT query and return results as DataFrame.
|
||||
|
||||
Args:
|
||||
query: SQL query string (SELECT only)
|
||||
params: Query parameters
|
||||
params: Query parameters (tuple or dict)
|
||||
|
||||
Returns:
|
||||
Query results as pandas DataFrame
|
||||
@@ -38,7 +45,12 @@ class QueryExecutor:
|
||||
ValueError: If query is not a SELECT statement
|
||||
Exception: If query execution fails
|
||||
"""
|
||||
# Safety check - only allow SELECT queries
|
||||
if self._engine is not None:
|
||||
return self._execute_query_pooled(query, params)
|
||||
return self._execute_query_single(query, params)
|
||||
|
||||
def _execute_query_single(self, query: str, params: Optional[Union[tuple, Dict[str, Any]]] = None) -> pd.DataFrame:
|
||||
"""Execute query using single connection (legacy mode)."""
|
||||
query_upper = query.strip().upper()
|
||||
if not query_upper.startswith('SELECT'):
|
||||
raise ValueError("Only SELECT queries are allowed (READ ONLY)")
|
||||
@@ -56,7 +68,24 @@ class QueryExecutor:
|
||||
logger.debug(f"Query: {query}")
|
||||
raise
|
||||
|
||||
def execute_scalar(self, query: str, params: tuple = None) -> Any:
|
||||
def _execute_query_pooled(self, query: str, params: Optional[Union[tuple, Dict[str, Any]]] = None) -> pd.DataFrame:
|
||||
"""Execute query using pooled connection (parallel mode)."""
|
||||
query_upper = query.strip().upper()
|
||||
if not query_upper.startswith('SELECT'):
|
||||
raise ValueError("Only SELECT queries are allowed (READ ONLY)")
|
||||
|
||||
try:
|
||||
with self._engine.connect() as conn:
|
||||
result = conn.execute(text(query), params or {})
|
||||
df = pd.DataFrame(result.all(), columns=result.keys())
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Query execution failed (pooled): {e}")
|
||||
logger.debug(f"Query: {query}")
|
||||
raise
|
||||
|
||||
def execute_scalar(self, query: str, params: Optional[Union[tuple, Dict[str, Any]]] = None) -> Any:
|
||||
"""
|
||||
Execute query and return single scalar value.
|
||||
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
"""Comparison service for executing database comparisons."""
|
||||
|
||||
import time
|
||||
from typing import List
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import List, Dict, Any
|
||||
from drt.database.connection import ConnectionManager
|
||||
from drt.database.connection_pool import ConnectionPool
|
||||
from drt.database.executor import QueryExecutor
|
||||
from drt.config.models import Config, DatabasePairConfig
|
||||
from drt.models.table import TableInfo
|
||||
from drt.models.results import ComparisonResult
|
||||
from drt.models.results import ComparisonResult, CheckResult
|
||||
from drt.models.summary import ExecutionSummary
|
||||
from drt.models.enums import Status
|
||||
from drt.models.enums import Status, CheckType
|
||||
from drt.services.checkers import (
|
||||
ExistenceChecker,
|
||||
RowCountChecker,
|
||||
@@ -56,13 +58,23 @@ class ComparisonService:
|
||||
target_mgr = ConnectionManager(db_pair.target)
|
||||
|
||||
try:
|
||||
baseline_engine = None
|
||||
target_engine = None
|
||||
|
||||
# Connect to databases
|
||||
baseline_mgr.connect()
|
||||
target_mgr.connect()
|
||||
|
||||
# Create executors
|
||||
baseline_executor = QueryExecutor(baseline_mgr)
|
||||
target_executor = QueryExecutor(target_mgr)
|
||||
max_workers = self.config.execution.max_workers
|
||||
pool_size = max_workers + 2
|
||||
|
||||
# Create pooled engines for parallel execution
|
||||
baseline_engine = baseline_mgr.get_pooled_engine(pool_size)
|
||||
target_engine = target_mgr.get_pooled_engine(pool_size)
|
||||
|
||||
# Create executors with pooled connections
|
||||
baseline_executor = QueryExecutor(baseline_mgr, engine=baseline_engine)
|
||||
target_executor = QueryExecutor(target_mgr, engine=target_engine)
|
||||
|
||||
# Initialize checkers
|
||||
existence_checker = ExistenceChecker(baseline_executor, target_executor, self.config)
|
||||
@@ -124,6 +136,10 @@ class ComparisonService:
|
||||
return summary
|
||||
|
||||
finally:
|
||||
if baseline_engine is not None:
|
||||
baseline_engine.dispose()
|
||||
if target_engine is not None:
|
||||
target_engine.dispose()
|
||||
baseline_mgr.disconnect()
|
||||
target_mgr.disconnect()
|
||||
|
||||
@@ -145,7 +161,7 @@ class ComparisonService:
|
||||
)
|
||||
|
||||
try:
|
||||
# Check existence first
|
||||
# Check existence first (must be sequential)
|
||||
check_start = time.time()
|
||||
existence_result = existence_checker.check(table)
|
||||
existence_time = (time.time() - check_start) * 1000
|
||||
@@ -154,26 +170,17 @@ class ComparisonService:
|
||||
|
||||
# Only proceed with other checks if table exists in both
|
||||
if existence_result.status == Status.PASS:
|
||||
# Row count check
|
||||
check_start = time.time()
|
||||
row_count_result = row_count_checker.check(table)
|
||||
row_count_time = (time.time() - check_start) * 1000
|
||||
logger.debug(f" └─ Row count check: {row_count_time:.0f}ms")
|
||||
result.add_check(row_count_result)
|
||||
# Run row count, schema, and aggregate checkers in parallel
|
||||
parallel_results = self._run_checkers_parallel(
|
||||
table,
|
||||
row_count_checker,
|
||||
schema_checker,
|
||||
aggregate_checker
|
||||
)
|
||||
|
||||
# Schema check
|
||||
check_start = time.time()
|
||||
schema_result = schema_checker.check(table)
|
||||
schema_time = (time.time() - check_start) * 1000
|
||||
logger.debug(f" └─ Schema check: {schema_time:.0f}ms")
|
||||
result.add_check(schema_result)
|
||||
|
||||
# Aggregate check
|
||||
check_start = time.time()
|
||||
aggregate_result = aggregate_checker.check(table)
|
||||
aggregate_time = (time.time() - check_start) * 1000
|
||||
logger.debug(f" └─ Aggregate check: {aggregate_time:.0f}ms")
|
||||
result.add_check(aggregate_result)
|
||||
# Add all results to the comparison result
|
||||
for name, check_result in parallel_results.items():
|
||||
result.add_check(check_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Comparison failed for {table.full_name}: {e}")
|
||||
@@ -184,6 +191,108 @@ class ComparisonService:
|
||||
logger.debug(f" └─ Total table time: {result.execution_time_ms}ms")
|
||||
return result
|
||||
|
||||
def _run_checkers_parallel(
|
||||
self,
|
||||
table: TableInfo,
|
||||
row_count_checker: RowCountChecker,
|
||||
schema_checker: SchemaChecker,
|
||||
aggregate_checker: AggregateChecker
|
||||
) -> Dict[str, CheckResult]:
|
||||
"""
|
||||
Run row count, schema, and aggregate checkers in parallel.
|
||||
|
||||
Args:
|
||||
table: Table information
|
||||
row_count_checker: Row count checker instance
|
||||
schema_checker: Schema checker instance
|
||||
aggregate_checker: Aggregate checker instance
|
||||
|
||||
Returns:
|
||||
Dictionary mapping checker names to their results
|
||||
|
||||
Note:
|
||||
Uses fail-fast behavior - cancels remaining checkers on first error.
|
||||
"""
|
||||
max_workers = self.config.execution.max_workers
|
||||
results: Dict[str, CheckResult] = {}
|
||||
exceptions: Dict[str, Exception] = {}
|
||||
|
||||
# Build list of checkers to run in parallel
|
||||
checkers_to_run: List[tuple] = []
|
||||
|
||||
if self.config.comparison.row_count.enabled:
|
||||
checkers_to_run.append(("row_count", row_count_checker))
|
||||
|
||||
if self.config.comparison.schema.enabled:
|
||||
checkers_to_run.append(("schema", schema_checker))
|
||||
|
||||
if self.config.comparison.aggregates.enabled and table.aggregate_columns:
|
||||
checkers_to_run.append(("aggregate", aggregate_checker))
|
||||
|
||||
# If only one checker or none, run sequentially
|
||||
if len(checkers_to_run) <= 1:
|
||||
for name, checker in checkers_to_run:
|
||||
check_start = time.time()
|
||||
result = checker.check(table)
|
||||
check_time = (time.time() - check_start) * 1000
|
||||
logger.debug(f" └─ {name} check: {check_time:.0f}ms (sequential)")
|
||||
results[name] = result
|
||||
return results
|
||||
|
||||
# Run checkers in parallel
|
||||
check_start = time.time()
|
||||
logger.debug(f" └─ Running {len(checkers_to_run)} checkers in parallel (max_workers={max_workers})")
|
||||
|
||||
def run_checker(checker_data: tuple) -> tuple[str, CheckResult]:
|
||||
"""Helper to run a single checker and return result."""
|
||||
name, checker = checker_data
|
||||
try:
|
||||
result = checker.check(table)
|
||||
return (name, result)
|
||||
except Exception as e:
|
||||
logger.error(f"{name} checker failed: {e}")
|
||||
raise
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = {
|
||||
executor.submit(run_checker, checker_data): checker_data[0]
|
||||
for checker_data in checkers_to_run
|
||||
}
|
||||
|
||||
for future in as_completed(futures):
|
||||
checker_name = futures[future]
|
||||
try:
|
||||
name, result = future.result()
|
||||
results[name] = result
|
||||
|
||||
# Fail-fast on ERROR status
|
||||
if result.status == Status.ERROR:
|
||||
logger.debug(f" └─ {name} check failed with ERROR, cancelling remaining checkers")
|
||||
for f in futures:
|
||||
f.cancel()
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
exceptions[checker_name] = e
|
||||
logger.debug(f" └─ {checker_name} check raised exception, cancelling remaining checkers")
|
||||
for f in futures:
|
||||
f.cancel()
|
||||
break
|
||||
|
||||
check_time = (time.time() - check_start) * 1000
|
||||
logger.debug(f" └─ Total parallel checker time: {check_time:.0f}ms")
|
||||
|
||||
# Handle exceptions
|
||||
for name, exc in exceptions.items():
|
||||
results[name] = CheckResult(
|
||||
check_type=CheckType.ROW_COUNT if name == "row_count" else
|
||||
CheckType.SCHEMA if name == "schema" else CheckType.AGGREGATE,
|
||||
status=Status.ERROR,
|
||||
message=f"{name} check error: {str(exc)}"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def _get_tables_to_compare(self) -> List[TableInfo]:
|
||||
"""Get list of tables to compare based on configuration."""
|
||||
tables = []
|
||||
|
||||
177
tests/test_parallel_logic.py
Normal file
177
tests/test_parallel_logic.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""Test parallel checker execution logic."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict, Any
|
||||
import time
|
||||
|
||||
|
||||
class TestParallelCheckerLogic:
|
||||
"""Test parallel checker execution logic."""
|
||||
|
||||
def test_parallel_execution_with_multiple_checkers(self):
|
||||
"""Test that multiple checkers can run in parallel."""
|
||||
# Don't import modules that need pyodbc - test the logic conceptually
|
||||
|
||||
# Create mock results
|
||||
from drt.models.results import CheckResult
|
||||
from drt.models.enums import Status, CheckType
|
||||
from drt.models.table import TableInfo
|
||||
|
||||
row_count_result = CheckResult(
|
||||
check_type=CheckType.ROW_COUNT,
|
||||
status=Status.PASS,
|
||||
message="Row count matches"
|
||||
)
|
||||
schema_result = CheckResult(
|
||||
check_type=CheckType.SCHEMA,
|
||||
status=Status.PASS,
|
||||
message="Schema matches"
|
||||
)
|
||||
aggregate_result = CheckResult(
|
||||
check_type=CheckType.AGGREGATE,
|
||||
status=Status.PASS,
|
||||
message="Aggregates match"
|
||||
)
|
||||
|
||||
# Mock checkers
|
||||
mock_row_count_checker = Mock()
|
||||
mock_schema_checker = Mock()
|
||||
mock_aggregate_checker = Mock()
|
||||
|
||||
# Simulate check execution with delay
|
||||
def mock_check_row_count(table):
|
||||
time.sleep(0.05)
|
||||
return row_count_result
|
||||
|
||||
def mock_check_schema(table):
|
||||
time.sleep(0.05)
|
||||
return schema_result
|
||||
|
||||
def mock_check_aggregate(table):
|
||||
time.sleep(0.05)
|
||||
return aggregate_result
|
||||
|
||||
mock_row_count_checker.check = mock_check_row_count
|
||||
mock_schema_checker.check = mock_check_schema
|
||||
mock_aggregate_checker.check = mock_check_aggregate
|
||||
|
||||
# Create table mock
|
||||
table = TableInfo(
|
||||
schema="dbo",
|
||||
name="TestTable",
|
||||
enabled=True,
|
||||
expected_in_target=True,
|
||||
aggregate_columns=["Amount", "Quantity"]
|
||||
)
|
||||
|
||||
# Test parallel execution timing
|
||||
start = time.time()
|
||||
|
||||
results = {}
|
||||
checkers = [
|
||||
("row_count", mock_row_count_checker),
|
||||
("schema", mock_schema_checker),
|
||||
("aggregate", mock_aggregate_checker)
|
||||
]
|
||||
|
||||
with ThreadPoolExecutor(max_workers=3) as executor:
|
||||
futures = {
|
||||
executor.submit(lambda c: c[1].check(table), c): c[0]
|
||||
for c in checkers
|
||||
}
|
||||
|
||||
for future in futures:
|
||||
name = futures[future]
|
||||
results[name] = future.result()
|
||||
|
||||
elapsed = time.time() - start
|
||||
|
||||
# Verify results
|
||||
assert len(results) == 3
|
||||
assert results["row_count"].status == Status.PASS
|
||||
assert results["schema"].status == Status.PASS
|
||||
assert results["aggregate"].status == Status.PASS
|
||||
|
||||
# Verify parallel execution (should be ~0.05s, not 0.15s if sequential)
|
||||
assert elapsed < 0.15, f"Expected parallel execution but took {elapsed:.2f}s"
|
||||
|
||||
def test_fail_fast_on_error(self):
|
||||
"""Test that parallel execution cancels remaining checkers on error."""
|
||||
from drt.models.results import CheckResult
|
||||
from drt.models.enums import Status, CheckType
|
||||
|
||||
# Mock checkers where second one fails
|
||||
mock_checker1 = Mock()
|
||||
mock_checker2 = Mock()
|
||||
mock_checker3 = Mock()
|
||||
|
||||
result1 = CheckResult(
|
||||
check_type=CheckType.ROW_COUNT,
|
||||
status=Status.PASS,
|
||||
message="OK"
|
||||
)
|
||||
result2 = CheckResult(
|
||||
check_type=CheckType.SCHEMA,
|
||||
status=Status.ERROR,
|
||||
message="Database connection failed"
|
||||
)
|
||||
|
||||
call_count = {"checker1": 0, "checker2": 0, "checker3": 0}
|
||||
|
||||
def mock_check1(table):
|
||||
call_count["checker1"] += 1
|
||||
time.sleep(0.05)
|
||||
return result1
|
||||
|
||||
def mock_check2(table):
|
||||
call_count["checker2"] += 1
|
||||
time.sleep(0.05)
|
||||
return result2
|
||||
|
||||
def mock_check3(table):
|
||||
call_count["checker3"] += 1
|
||||
time.sleep(0.05)
|
||||
return CheckResult(check_type=CheckType.AGGREGATE, status=Status.PASS, message="OK")
|
||||
|
||||
mock_checker1.check = mock_check1
|
||||
mock_checker2.check = mock_check2
|
||||
mock_checker3.check = mock_check3
|
||||
|
||||
# Run with fail-fast
|
||||
from drt.models.table import TableInfo
|
||||
table = TableInfo(schema="dbo", name="TestTable", enabled=True)
|
||||
|
||||
results = {}
|
||||
checkers = [
|
||||
("row_count", mock_checker1),
|
||||
("schema", mock_checker2),
|
||||
("aggregate", mock_checker3)
|
||||
]
|
||||
|
||||
with ThreadPoolExecutor(max_workers=3) as executor:
|
||||
futures = {
|
||||
executor.submit(lambda c: c[1].check(table), c): c[0]
|
||||
for c in checkers
|
||||
}
|
||||
|
||||
for future in futures:
|
||||
name = futures[future]
|
||||
try:
|
||||
result = future.result()
|
||||
results[name] = result
|
||||
if result.status == Status.ERROR:
|
||||
# Cancel remaining
|
||||
for f in futures:
|
||||
f.cancel()
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Verify that we got at least one result
|
||||
assert len(results) >= 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
88
tests/test_parallelization.py
Normal file
88
tests/test_parallelization.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Test parallelization features."""
|
||||
|
||||
import pytest
|
||||
from drt.config.models import Config, ExecutionConfig
|
||||
|
||||
|
||||
class TestParallelizationConfig:
|
||||
"""Test parallelization configuration."""
|
||||
|
||||
def test_default_max_workers(self):
|
||||
"""Test default max_workers value."""
|
||||
config = Config()
|
||||
assert config.execution.max_workers == 4
|
||||
|
||||
def test_custom_max_workers(self):
|
||||
"""Test custom max_workers value."""
|
||||
config = Config(execution=ExecutionConfig(max_workers=8))
|
||||
assert config.execution.max_workers == 8
|
||||
|
||||
def test_max_workers_positive(self):
|
||||
"""Test that max_workers is positive."""
|
||||
config = Config()
|
||||
assert config.execution.max_workers > 0
|
||||
|
||||
def test_continue_on_error_default(self):
|
||||
"""Test default continue_on_error value."""
|
||||
config = Config()
|
||||
assert config.execution.continue_on_error is True
|
||||
|
||||
|
||||
def test_imports():
|
||||
"""Test that all parallelization modules can be imported."""
|
||||
from drt.config.models import Config, ExecutionConfig
|
||||
from urllib.parse import quote_plus
|
||||
from sqlalchemy import create_engine, text, QueuePool
|
||||
|
||||
assert Config is not None
|
||||
assert ExecutionConfig is not None
|
||||
assert quote_plus is not None
|
||||
assert create_engine is not None
|
||||
assert text is not None
|
||||
assert QueuePool is not None
|
||||
|
||||
|
||||
def test_url_encoding():
|
||||
"""Test URL encoding for connection strings."""
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
test_conn_str = "DRIVER={ODBC Driver};SERVER=localhost;PWD=test@pass#123"
|
||||
encoded = quote_plus(test_conn_str)
|
||||
|
||||
# Special characters should be encoded
|
||||
assert "%40" in encoded # @ encoded
|
||||
assert "%23" in encoded # # encoded
|
||||
assert "%3D" in encoded # = encoded
|
||||
assert encoded != test_conn_str # Should be different after encoding
|
||||
|
||||
# Test that it can be decoded back
|
||||
from urllib.parse import unquote_plus
|
||||
decoded = unquote_plus(encoded)
|
||||
assert decoded == test_conn_str
|
||||
|
||||
|
||||
def test_config_load_minimal():
|
||||
"""Test loading a minimal config with parallelization settings."""
|
||||
config_dict = {
|
||||
"database_pairs": [
|
||||
{
|
||||
"name": "Test",
|
||||
"enabled": True,
|
||||
"baseline": {"server": "S1", "database": "D1"},
|
||||
"target": {"server": "S1", "database": "D2"}
|
||||
}
|
||||
],
|
||||
"execution": {
|
||||
"max_workers": 6,
|
||||
"continue_on_error": False
|
||||
},
|
||||
"tables": []
|
||||
}
|
||||
|
||||
config = Config(**config_dict)
|
||||
assert config.execution.max_workers == 6
|
||||
assert config.execution.continue_on_error is False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user