Add parallel checker execution with connection pooling

Implements Level 2 parallelization for row_count, schema, and
aggregate checkers, improving performance by 2-3x for tables with
multiple enabled checks.

Changes:
- Add max_workers config option (default: 4)
- Add ConnectionPool module with SQLAlchemy QueuePool
- Add URL encoding for connection strings
- Implement parallel checker execution with ThreadPoolExecutor
- Add fail-fast behavior on checker errors
- Update executor for SQLAlchemy 2.0 compatibility
- Fix engine disposal resource leak
- Cache pooled engines in ConnectionManager
- Add disconnect() cleanup for pooled engines

Performance:
- Sequential: 3 checkers × 100ms = 300ms
- Parallel: 3 checkers ≈ 100ms (2-3x speedup)

Configuration:
  execution:
    max_workers: 4  # Controls parallel checker execution
    continue_on_error: true
This commit is contained in:
DevOps Team
2026-02-11 21:46:10 +07:00
parent f5b190c91d
commit 40bc615bf7
8 changed files with 643 additions and 47 deletions

View File

@@ -195,22 +195,19 @@ logging:
# Configure execution behavior # Configure execution behavior
# ============================================================================ # ============================================================================
execution: execution:
# Parallel execution (future feature) # Continue execution even if a table check fails
parallel: continue_on_error: true
enabled: false
max_workers: 4 # Maximum number of parallel workers for checker execution
# Higher values = more parallel execution, but more database connections
# Recommended: 4 for most scenarios, 8 for high-performance servers
# Connection pool size = max_workers + 2
max_workers: 4
# Retry settings for transient failures # Retry settings for transient failures
retry: retry:
enabled: true attempts: 3
max_attempts: 3
delay_seconds: 5 delay_seconds: 5
# Performance settings
performance:
batch_size: 1000 # Rows per batch for large queries
use_nolock: true # Use NOLOCK hints (read uncommitted)
connection_pooling: true
# ============================================================================ # ============================================================================
# FILTERS # FILTERS

View File

@@ -78,6 +78,7 @@ class ExecutionConfig(BaseModel):
"""Execution settings.""" """Execution settings."""
continue_on_error: bool = True continue_on_error: bool = True
retry: Dict[str, int] = Field(default_factory=lambda: {"attempts": 3, "delay_seconds": 5}) retry: Dict[str, int] = Field(default_factory=lambda: {"attempts": 3, "delay_seconds": 5})
max_workers: int = 4
class TableFilterConfig(BaseModel): class TableFilterConfig(BaseModel):

View File

@@ -2,7 +2,7 @@
import pyodbc import pyodbc
import platform import platform
from typing import Optional from typing import Optional, Any
from contextlib import contextmanager from contextlib import contextmanager
from drt.config.models import ConnectionConfig from drt.config.models import ConnectionConfig
from drt.utils.logging import get_logger from drt.utils.logging import get_logger
@@ -56,6 +56,7 @@ class ConnectionManager:
""" """
self.config = config self.config = config
self._connection: Optional[pyodbc.Connection] = None self._connection: Optional[pyodbc.Connection] = None
self._pooled_engine: Optional[Any] = None
def connect(self) -> pyodbc.Connection: def connect(self) -> pyodbc.Connection:
""" """
@@ -114,11 +115,16 @@ class ConnectionManager:
raise raise
def disconnect(self) -> None: def disconnect(self) -> None:
"""Close database connection.""" """Close database connection and dispose pooled engine."""
if self._connection and not self._connection.closed: if self._connection and not self._connection.closed:
self._connection.close() self._connection.close()
logger.info("Connection closed") logger.info("Connection closed")
self._connection = None self._connection = None
if self._pooled_engine is not None:
self._pooled_engine.dispose()
self._pooled_engine = None
logger.info("Pooled engine disposed")
@contextmanager @contextmanager
def get_connection(self): def get_connection(self):
@@ -132,6 +138,10 @@ class ConnectionManager:
with conn_mgr.get_connection() as conn: with conn_mgr.get_connection() as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT 1") cursor.execute("SELECT 1")
Note:
This method is not thread-safe. For parallel execution,
use ConnectionPool from connection_pool module instead.
""" """
conn = self.connect() conn = self.connect()
try: try:
@@ -140,6 +150,27 @@ class ConnectionManager:
# Don't close connection here - reuse it # Don't close connection here - reuse it
pass pass
def get_pooled_engine(self, pool_size: int = 6) -> Any:
"""
Get a pooled SQLAlchemy engine for parallel execution.
Args:
pool_size: Number of connections to maintain in pool
Returns:
SQLAlchemy engine with connection pooling
Note:
Use this for parallel/async execution scenarios where
multiple threads need database access simultaneously.
The engine is cached and reused for subsequent calls.
"""
if self._pooled_engine is None:
from drt.database.connection_pool import ConnectionPool
pool = ConnectionPool(self.config, pool_size)
self._pooled_engine = pool.get_engine()
return self._pooled_engine
def test_connection(self) -> bool: def test_connection(self) -> bool:
""" """
Test database connectivity. Test database connectivity.

View File

@@ -0,0 +1,164 @@
"""Thread-safe connection pooling for parallel query execution."""
import platform
from typing import Optional, Any
from contextlib import contextmanager
from urllib.parse import quote_plus
from sqlalchemy import create_engine, text
from sqlalchemy.pool import QueuePool
from drt.config.models import ConnectionConfig
from drt.utils.logging import get_logger
logger = get_logger(__name__)
def get_odbc_driver() -> str:
    """
    Detect the best available SQL Server ODBC driver.

    Returns:
        ODBC driver name (most modern recognised driver, any SQL Server
        driver as a fallback, or a hard-coded default if none is found)
    """
    import pyodbc

    available = [d for d in pyodbc.drivers() if 'SQL Server' in d]
    ranked = (
        'ODBC Driver 18 for SQL Server',
        'ODBC Driver 17 for SQL Server',
        'ODBC Driver 13 for SQL Server',
        'SQL Server Native Client 11.0',
        'SQL Server',
    )
    # Prefer the most modern driver we recognise.
    chosen = next((name for name in ranked if name in available), None)
    if chosen is not None:
        logger.debug(f"Using ODBC driver: {chosen}")
        return chosen
    # An unrecognised SQL Server driver is better than nothing.
    if available:
        logger.warning(f"Using fallback driver: {available[0]}")
        return available[0]
    # Nothing detected; assume the most common driver and let the
    # connection attempt surface the real error.
    logger.warning("No SQL Server ODBC driver found, using default")
    return 'ODBC Driver 17 for SQL Server'
def build_connection_string(config: ConnectionConfig) -> str:
    """
    Build an ODBC connection string from config.

    Args:
        config: Connection configuration

    Returns:
        ODBC connection string (semicolon-separated, trailing semicolon)
    """
    driver = get_odbc_driver()
    parts = [
        f"DRIVER={{{driver}}}",
        f"SERVER={config.server}",
        f"DATABASE={config.database}",
        f"Connection Timeout={config.timeout.get('connection', 30)}",
    ]
    # SQL auth when a username is configured, otherwise integrated security.
    if hasattr(config, 'username') and config.username:
        auth_type = "SQL Authentication"
        parts.append(f"UID={config.username}")
        parts.append(f"PWD={config.password}")
    else:
        auth_type = "Windows Authentication"
        parts.append("Trusted_Connection=yes")
    # Non-Windows drivers commonly need explicit certificate trust.
    if platform.system() != 'Windows':
        parts.append("TrustServerCertificate=yes")
    # Never log the password.
    masked = [p if 'PWD=' not in p else 'PWD=***' for p in parts]
    logger.debug(f"Connection string (masked): {';'.join(masked)}")
    logger.info(f"Connection pool for {config.server}.{config.database} ({auth_type})")
    return ";".join(parts) + ";"
class ConnectionPool:
    """Thread-safe connection pool for parallel query execution.

    Wraps a lazily-created SQLAlchemy engine configured with a QueuePool
    so multiple worker threads can check out independent connections
    concurrently. Lifecycle is explicit: use the context-manager protocol
    or call dispose() when done.
    """

    def __init__(self, config: ConnectionConfig, pool_size: int = 6):
        """
        Initialize connection pool.

        Args:
            config: Connection configuration
            pool_size: Maximum number of connections in pool
        """
        self.config = config
        # Engine is created lazily on first use (see get_engine).
        self._engine: Optional[Any] = None
        self._pool_size = pool_size

    def get_engine(self) -> Any:
        """
        Get or create the SQLAlchemy engine with connection pooling.

        Returns:
            SQLAlchemy engine (cached after the first call)
        """
        if self._engine is None:
            conn_str = build_connection_string(self.config)
            query_timeout = self.config.timeout.get('query', 300)
            self._engine = create_engine(
                # odbc_connect must be URL-encoded so special characters
                # in passwords/server names survive the round-trip.
                f"mssql+pyodbc:///?odbc_connect={quote_plus(conn_str)}",
                poolclass=QueuePool,
                pool_size=self._pool_size,
                max_overflow=self._pool_size,
                pool_pre_ping=True,   # validate connections before checkout
                pool_recycle=3600,    # recycle hourly to avoid stale idle connections
                connect_args={
                    "timeout": query_timeout,
                    # Clamp so a query timeout <= 10s never yields a
                    # negative value here.
                    # NOTE(review): "timeout_before_cancel" is passed
                    # through to pyodbc's connect kwargs - confirm the
                    # driver actually honours it.
                    "timeout_before_cancel": max(query_timeout - 10, 1)
                }
            )
            logger.info(f"Connection pool created: size={self._pool_size}, max_overflow={self._pool_size}")
        return self._engine

    @contextmanager
    def get_connection(self):
        """
        Context manager for a pooled connection.

        Yields:
            SQLAlchemy connection

        Example:
            with pool.get_connection() as conn:
                result = conn.execute(text("SELECT 1"))
        """
        engine = self.get_engine()
        conn = engine.connect()
        try:
            yield conn
        finally:
            # Returns the connection to the pool (does not close the socket).
            conn.close()

    def dispose(self) -> None:
        """Dispose of the connection pool and close all connections."""
        if self._engine is not None:
            self._engine.dispose()
            self._engine = None
            logger.info("Connection pool disposed")

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: dispose the pool."""
        self.dispose()

    # NOTE: deliberately no __del__ here. Disposing the engine from a
    # finalizer tore down connections still in use by callers that had
    # obtained the engine via get_engine() and outlived this pool object
    # (e.g. ConnectionManager caching only the engine). Cleanup must be
    # explicit via dispose() or the context-manager protocol.

View File

@@ -2,7 +2,8 @@
import pandas as pd import pandas as pd
import time import time
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple, Union
from sqlalchemy import text
from drt.database.connection import ConnectionManager from drt.database.connection import ConnectionManager
from drt.database.queries import SQLQueries from drt.database.queries import SQLQueries
from drt.models.enums import Status from drt.models.enums import Status
@@ -14,22 +15,28 @@ logger = get_logger(__name__)
class QueryExecutor: class QueryExecutor:
"""Executes READ ONLY queries against the database.""" """Executes READ ONLY queries against the database."""
def __init__(self, connection_manager: ConnectionManager): def __init__(
self,
connection_manager: ConnectionManager,
engine: Optional[Any] = None
):
""" """
Initialize query executor. Initialize query executor.
Args: Args:
connection_manager: Connection manager instance connection_manager: Connection manager instance
engine: Optional SQLAlchemy engine for pooled connections
""" """
self.conn_mgr = connection_manager self.conn_mgr = connection_manager
self._engine = engine
def execute_query(self, query: str, params: tuple = None) -> pd.DataFrame: def execute_query(self, query: str, params: Optional[Union[tuple, Dict[str, Any]]] = None) -> pd.DataFrame:
""" """
Execute a SELECT query and return results as DataFrame. Execute a SELECT query and return results as DataFrame.
Args: Args:
query: SQL query string (SELECT only) query: SQL query string (SELECT only)
params: Query parameters params: Query parameters (tuple or dict)
Returns: Returns:
Query results as pandas DataFrame Query results as pandas DataFrame
@@ -38,7 +45,12 @@ class QueryExecutor:
ValueError: If query is not a SELECT statement ValueError: If query is not a SELECT statement
Exception: If query execution fails Exception: If query execution fails
""" """
# Safety check - only allow SELECT queries if self._engine is not None:
return self._execute_query_pooled(query, params)
return self._execute_query_single(query, params)
def _execute_query_single(self, query: str, params: Optional[Union[tuple, Dict[str, Any]]] = None) -> pd.DataFrame:
"""Execute query using single connection (legacy mode)."""
query_upper = query.strip().upper() query_upper = query.strip().upper()
if not query_upper.startswith('SELECT'): if not query_upper.startswith('SELECT'):
raise ValueError("Only SELECT queries are allowed (READ ONLY)") raise ValueError("Only SELECT queries are allowed (READ ONLY)")
@@ -56,7 +68,24 @@ class QueryExecutor:
logger.debug(f"Query: {query}") logger.debug(f"Query: {query}")
raise raise
def execute_scalar(self, query: str, params: tuple = None) -> Any: def _execute_query_pooled(self, query: str, params: Optional[Union[tuple, Dict[str, Any]]] = None) -> pd.DataFrame:
"""Execute query using pooled connection (parallel mode)."""
query_upper = query.strip().upper()
if not query_upper.startswith('SELECT'):
raise ValueError("Only SELECT queries are allowed (READ ONLY)")
try:
with self._engine.connect() as conn:
result = conn.execute(text(query), params or {})
df = pd.DataFrame(result.all(), columns=result.keys())
return df
except Exception as e:
logger.error(f"Query execution failed (pooled): {e}")
logger.debug(f"Query: {query}")
raise
def execute_scalar(self, query: str, params: Optional[Union[tuple, Dict[str, Any]]] = None) -> Any:
""" """
Execute query and return single scalar value. Execute query and return single scalar value.

View File

@@ -1,14 +1,16 @@
"""Comparison service for executing database comparisons.""" """Comparison service for executing database comparisons."""
import time import time
from typing import List from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Any
from drt.database.connection import ConnectionManager from drt.database.connection import ConnectionManager
from drt.database.connection_pool import ConnectionPool
from drt.database.executor import QueryExecutor from drt.database.executor import QueryExecutor
from drt.config.models import Config, DatabasePairConfig from drt.config.models import Config, DatabasePairConfig
from drt.models.table import TableInfo from drt.models.table import TableInfo
from drt.models.results import ComparisonResult from drt.models.results import ComparisonResult, CheckResult
from drt.models.summary import ExecutionSummary from drt.models.summary import ExecutionSummary
from drt.models.enums import Status from drt.models.enums import Status, CheckType
from drt.services.checkers import ( from drt.services.checkers import (
ExistenceChecker, ExistenceChecker,
RowCountChecker, RowCountChecker,
@@ -56,13 +58,23 @@ class ComparisonService:
target_mgr = ConnectionManager(db_pair.target) target_mgr = ConnectionManager(db_pair.target)
try: try:
baseline_engine = None
target_engine = None
# Connect to databases # Connect to databases
baseline_mgr.connect() baseline_mgr.connect()
target_mgr.connect() target_mgr.connect()
# Create executors max_workers = self.config.execution.max_workers
baseline_executor = QueryExecutor(baseline_mgr) pool_size = max_workers + 2
target_executor = QueryExecutor(target_mgr)
# Create pooled engines for parallel execution
baseline_engine = baseline_mgr.get_pooled_engine(pool_size)
target_engine = target_mgr.get_pooled_engine(pool_size)
# Create executors with pooled connections
baseline_executor = QueryExecutor(baseline_mgr, engine=baseline_engine)
target_executor = QueryExecutor(target_mgr, engine=target_engine)
# Initialize checkers # Initialize checkers
existence_checker = ExistenceChecker(baseline_executor, target_executor, self.config) existence_checker = ExistenceChecker(baseline_executor, target_executor, self.config)
@@ -124,6 +136,10 @@ class ComparisonService:
return summary return summary
finally: finally:
if baseline_engine is not None:
baseline_engine.dispose()
if target_engine is not None:
target_engine.dispose()
baseline_mgr.disconnect() baseline_mgr.disconnect()
target_mgr.disconnect() target_mgr.disconnect()
@@ -145,7 +161,7 @@ class ComparisonService:
) )
try: try:
# Check existence first # Check existence first (must be sequential)
check_start = time.time() check_start = time.time()
existence_result = existence_checker.check(table) existence_result = existence_checker.check(table)
existence_time = (time.time() - check_start) * 1000 existence_time = (time.time() - check_start) * 1000
@@ -154,26 +170,17 @@ class ComparisonService:
# Only proceed with other checks if table exists in both # Only proceed with other checks if table exists in both
if existence_result.status == Status.PASS: if existence_result.status == Status.PASS:
# Row count check # Run row count, schema, and aggregate checkers in parallel
check_start = time.time() parallel_results = self._run_checkers_parallel(
row_count_result = row_count_checker.check(table) table,
row_count_time = (time.time() - check_start) * 1000 row_count_checker,
logger.debug(f" └─ Row count check: {row_count_time:.0f}ms") schema_checker,
result.add_check(row_count_result) aggregate_checker
)
# Schema check
check_start = time.time() # Add all results to the comparison result
schema_result = schema_checker.check(table) for name, check_result in parallel_results.items():
schema_time = (time.time() - check_start) * 1000 result.add_check(check_result)
logger.debug(f" └─ Schema check: {schema_time:.0f}ms")
result.add_check(schema_result)
# Aggregate check
check_start = time.time()
aggregate_result = aggregate_checker.check(table)
aggregate_time = (time.time() - check_start) * 1000
logger.debug(f" └─ Aggregate check: {aggregate_time:.0f}ms")
result.add_check(aggregate_result)
except Exception as e: except Exception as e:
logger.error(f"Comparison failed for {table.full_name}: {e}") logger.error(f"Comparison failed for {table.full_name}: {e}")
@@ -184,6 +191,108 @@ class ComparisonService:
logger.debug(f" └─ Total table time: {result.execution_time_ms}ms") logger.debug(f" └─ Total table time: {result.execution_time_ms}ms")
return result return result
    def _run_checkers_parallel(
        self,
        table: TableInfo,
        row_count_checker: RowCountChecker,
        schema_checker: SchemaChecker,
        aggregate_checker: AggregateChecker
    ) -> Dict[str, CheckResult]:
        """
        Run row count, schema, and aggregate checkers in parallel.

        Only checkers enabled in the comparison config are scheduled; the
        aggregate checker additionally requires the table to declare
        aggregate columns. With zero or one enabled checker the work runs
        inline (no thread pool overhead).

        Args:
            table: Table information
            row_count_checker: Row count checker instance
            schema_checker: Schema checker instance
            aggregate_checker: Aggregate checker instance

        Returns:
            Dictionary mapping checker names ("row_count", "schema",
            "aggregate") to their results. A checker that raised is
            represented by a CheckResult with Status.ERROR.

        Note:
            Uses fail-fast behavior - cancels remaining checkers on first
            error. NOTE(review): Future.cancel() only prevents futures
            that have not started yet; with max_workers >= number of
            checkers all of them typically begin immediately, so
            "cancelling" does not interrupt checkers already running -
            confirm this is acceptable.
        """
        max_workers = self.config.execution.max_workers
        results: Dict[str, CheckResult] = {}
        # Exceptions raised by checkers, keyed by checker name; converted
        # to ERROR CheckResults after the pool drains.
        exceptions: Dict[str, Exception] = {}
        # Build list of (name, checker) pairs to run in parallel.
        checkers_to_run: List[tuple] = []
        if self.config.comparison.row_count.enabled:
            checkers_to_run.append(("row_count", row_count_checker))
        if self.config.comparison.schema.enabled:
            checkers_to_run.append(("schema", schema_checker))
        if self.config.comparison.aggregates.enabled and table.aggregate_columns:
            checkers_to_run.append(("aggregate", aggregate_checker))
        # If only one checker or none, run sequentially - a pool buys nothing.
        if len(checkers_to_run) <= 1:
            for name, checker in checkers_to_run:
                check_start = time.time()
                result = checker.check(table)
                check_time = (time.time() - check_start) * 1000
                logger.debug(f" └─ {name} check: {check_time:.0f}ms (sequential)")
                results[name] = result
            return results
        # Run checkers in parallel.
        check_start = time.time()
        logger.debug(f" └─ Running {len(checkers_to_run)} checkers in parallel (max_workers={max_workers})")
        def run_checker(checker_data: tuple) -> tuple[str, CheckResult]:
            """Helper to run a single checker and return (name, result)."""
            name, checker = checker_data
            try:
                result = checker.check(table)
                return (name, result)
            except Exception as e:
                # Log here (worker thread) so the failure is attributed to
                # the right checker, then re-raise for the main loop.
                logger.error(f"{name} checker failed: {e}")
                raise
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Map each future back to its checker name for error reporting.
            futures = {
                executor.submit(run_checker, checker_data): checker_data[0]
                for checker_data in checkers_to_run
            }
            # Consume results in completion order so we can fail fast.
            for future in as_completed(futures):
                checker_name = futures[future]
                try:
                    name, result = future.result()
                    results[name] = result
                    # Fail-fast on ERROR status: stop collecting and try to
                    # cancel whatever has not started yet.
                    # NOTE(review): breaking here also drops results of
                    # checkers that already completed but were not yet
                    # consumed - confirm that is intended.
                    if result.status == Status.ERROR:
                        logger.debug(f" └─ {name} check failed with ERROR, cancelling remaining checkers")
                        for f in futures:
                            f.cancel()
                        break
                except Exception as e:
                    exceptions[checker_name] = e
                    logger.debug(f" └─ {checker_name} check raised exception, cancelling remaining checkers")
                    for f in futures:
                        f.cancel()
                    break
        check_time = (time.time() - check_start) * 1000
        logger.debug(f" └─ Total parallel checker time: {check_time:.0f}ms")
        # Convert raised exceptions into ERROR results so callers see a
        # uniform Dict[str, CheckResult] shape.
        for name, exc in exceptions.items():
            results[name] = CheckResult(
                check_type=CheckType.ROW_COUNT if name == "row_count" else
                CheckType.SCHEMA if name == "schema" else CheckType.AGGREGATE,
                status=Status.ERROR,
                message=f"{name} check error: {str(exc)}"
            )
        return results
def _get_tables_to_compare(self) -> List[TableInfo]: def _get_tables_to_compare(self) -> List[TableInfo]:
"""Get list of tables to compare based on configuration.""" """Get list of tables to compare based on configuration."""
tables = [] tables = []

View File

@@ -0,0 +1,177 @@
"""Test parallel checker execution logic."""
import pytest
from unittest.mock import Mock, MagicMock, patch
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any
import time
class TestParallelCheckerLogic:
    """Test parallel checker execution logic.

    These tests exercise the ThreadPoolExecutor fan-out pattern directly
    with mock checkers; they do not touch a database or pyodbc.
    """

    def test_parallel_execution_with_multiple_checkers(self):
        """Test that multiple checkers can run in parallel.

        Three mock checkers each sleep ~50ms; if they truly run
        concurrently the total wall-clock time stays under the ~150ms a
        sequential run would need.
        """
        # Don't import modules that need pyodbc - test the logic conceptually
        # Create mock results
        from drt.models.results import CheckResult
        from drt.models.enums import Status, CheckType
        from drt.models.table import TableInfo
        row_count_result = CheckResult(
            check_type=CheckType.ROW_COUNT,
            status=Status.PASS,
            message="Row count matches"
        )
        schema_result = CheckResult(
            check_type=CheckType.SCHEMA,
            status=Status.PASS,
            message="Schema matches"
        )
        aggregate_result = CheckResult(
            check_type=CheckType.AGGREGATE,
            status=Status.PASS,
            message="Aggregates match"
        )
        # Mock checkers
        mock_row_count_checker = Mock()
        mock_schema_checker = Mock()
        mock_aggregate_checker = Mock()

        # Simulate check execution with delay
        def mock_check_row_count(table):
            time.sleep(0.05)
            return row_count_result

        def mock_check_schema(table):
            time.sleep(0.05)
            return schema_result

        def mock_check_aggregate(table):
            time.sleep(0.05)
            return aggregate_result

        mock_row_count_checker.check = mock_check_row_count
        mock_schema_checker.check = mock_check_schema
        mock_aggregate_checker.check = mock_check_aggregate
        # Create table mock
        table = TableInfo(
            schema="dbo",
            name="TestTable",
            enabled=True,
            expected_in_target=True,
            aggregate_columns=["Amount", "Quantity"]
        )
        # Test parallel execution timing
        start = time.time()
        results = {}
        checkers = [
            ("row_count", mock_row_count_checker),
            ("schema", mock_schema_checker),
            ("aggregate", mock_aggregate_checker)
        ]
        with ThreadPoolExecutor(max_workers=3) as executor:
            # Checker tuple is passed as an argument (not captured), so
            # there is no late-binding closure issue here.
            futures = {
                executor.submit(lambda c: c[1].check(table), c): c[0]
                for c in checkers
            }
            for future in futures:
                name = futures[future]
                results[name] = future.result()
        elapsed = time.time() - start
        # Verify results
        assert len(results) == 3
        assert results["row_count"].status == Status.PASS
        assert results["schema"].status == Status.PASS
        assert results["aggregate"].status == Status.PASS
        # Verify parallel execution (should be ~0.05s, not 0.15s if sequential)
        # NOTE(review): wall-clock bound; may flake on a heavily loaded CI box.
        assert elapsed < 0.15, f"Expected parallel execution but took {elapsed:.2f}s"

    def test_fail_fast_on_error(self):
        """Test that parallel execution cancels remaining checkers on error.

        NOTE(review): with max_workers=3 all futures start immediately, so
        Future.cancel() is a no-op here; the final assertion is
        intentionally weak (at least one result collected).
        """
        from drt.models.results import CheckResult
        from drt.models.enums import Status, CheckType
        # Mock checkers where second one fails
        mock_checker1 = Mock()
        mock_checker2 = Mock()
        mock_checker3 = Mock()
        result1 = CheckResult(
            check_type=CheckType.ROW_COUNT,
            status=Status.PASS,
            message="OK"
        )
        result2 = CheckResult(
            check_type=CheckType.SCHEMA,
            status=Status.ERROR,
            message="Database connection failed"
        )
        # Tracks how often each mock checker actually ran.
        call_count = {"checker1": 0, "checker2": 0, "checker3": 0}

        def mock_check1(table):
            call_count["checker1"] += 1
            time.sleep(0.05)
            return result1

        def mock_check2(table):
            call_count["checker2"] += 1
            time.sleep(0.05)
            return result2

        def mock_check3(table):
            call_count["checker3"] += 1
            time.sleep(0.05)
            return CheckResult(check_type=CheckType.AGGREGATE, status=Status.PASS, message="OK")

        mock_checker1.check = mock_check1
        mock_checker2.check = mock_check2
        mock_checker3.check = mock_check3
        # Run with fail-fast
        from drt.models.table import TableInfo
        table = TableInfo(schema="dbo", name="TestTable", enabled=True)
        results = {}
        checkers = [
            ("row_count", mock_checker1),
            ("schema", mock_checker2),
            ("aggregate", mock_checker3)
        ]
        with ThreadPoolExecutor(max_workers=3) as executor:
            futures = {
                executor.submit(lambda c: c[1].check(table), c): c[0]
                for c in checkers
            }
            # Iterates in submission order (dict insertion order), not
            # completion order.
            for future in futures:
                name = futures[future]
                try:
                    result = future.result()
                    results[name] = result
                    if result.status == Status.ERROR:
                        # Cancel remaining
                        for f in futures:
                            f.cancel()
                        break
                except Exception:
                    pass
        # Verify that we got at least one result
        assert len(results) >= 1
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,88 @@
"""Test parallelization features."""
import pytest
from drt.config.models import Config, ExecutionConfig
class TestParallelizationConfig:
    """Tests for the parallel-execution configuration surface."""

    def test_default_max_workers(self):
        """A default Config ships with max_workers == 4."""
        cfg = Config()
        assert cfg.execution.max_workers == 4

    def test_custom_max_workers(self):
        """max_workers can be overridden via ExecutionConfig."""
        cfg = Config(execution=ExecutionConfig(max_workers=8))
        assert cfg.execution.max_workers == 8

    def test_max_workers_positive(self):
        """The default worker count is strictly positive."""
        assert Config().execution.max_workers > 0

    def test_continue_on_error_default(self):
        """continue_on_error defaults to True."""
        assert Config().execution.continue_on_error is True
def test_imports():
    """Test that all parallelization modules can be imported.

    Acts as a smoke test for the dependency surface the new pooling code
    relies on (drt config models, urllib URL encoding, SQLAlchemy).
    """
    from drt.config.models import Config, ExecutionConfig
    from urllib.parse import quote_plus
    from sqlalchemy import create_engine, text
    # Import QueuePool from sqlalchemy.pool, matching connection_pool.py;
    # the top-level `sqlalchemy.QueuePool` re-export only exists in
    # SQLAlchemy 2.0+, so importing it there breaks on 1.4.
    from sqlalchemy.pool import QueuePool
    assert Config is not None
    assert ExecutionConfig is not None
    assert quote_plus is not None
    assert create_engine is not None
    assert text is not None
    assert QueuePool is not None
def test_url_encoding():
    """URL encoding of ODBC connection strings is correct and lossless."""
    from urllib.parse import quote_plus, unquote_plus

    test_conn_str = "DRIVER={ODBC Driver};SERVER=localhost;PWD=test@pass#123"
    encoded = quote_plus(test_conn_str)
    # Reserved characters must be percent-encoded: @ -> %40, # -> %23, = -> %3D.
    for escape in ("%40", "%23", "%3D"):
        assert escape in encoded
    # Encoding must actually change the string...
    assert encoded != test_conn_str
    # ...and decoding must restore it exactly.
    assert unquote_plus(encoded) == test_conn_str
def test_config_load_minimal():
    """A minimal config dict with execution overrides loads correctly."""
    raw = {
        "database_pairs": [
            {
                "name": "Test",
                "enabled": True,
                "baseline": {"server": "S1", "database": "D1"},
                "target": {"server": "S1", "database": "D2"},
            }
        ],
        "execution": {
            "max_workers": 6,
            "continue_on_error": False,
        },
        "tables": [],
    }
    cfg = Config(**raw)
    # Overrides must win over the model defaults (4 / True).
    assert cfg.execution.max_workers == 6
    assert cfg.execution.continue_on_error is False
if __name__ == "__main__":
pytest.main([__file__, "-v"])