diff --git a/refactron/cli.py b/refactron/cli.py index 92b6a9d..42ffe7a 100644 --- a/refactron/cli.py +++ b/refactron/cli.py @@ -145,6 +145,43 @@ def _print_refactor_messages(summary: dict, preview: bool) -> None: console.print("\n[green]✅ Refactoring completed! Don't forget to test your code.[/green]") +def _print_operation_with_risk(operation) -> None: + """Print refactoring operation with color-coded risk visualization.""" + from rich.panel import Panel + + from refactron.core.risk_assessment import RiskAssessor + + # Determine risk level and color using centralized logic + assessor = RiskAssessor() + risk_score = operation.risk_score + risk_color = assessor.get_risk_display_color(risk_score) + risk_label = assessor.get_risk_display_label(risk_score) + + # Build content + content = f"[bold]{operation.operation_type}[/bold]\n" + content += f"Location: {operation.file_path}:{operation.line_number}\n" + content += f"Risk: [{risk_color}]{risk_score:.2f} - {risk_label}[/{risk_color}]\n" + content += f"\n{operation.description}" + + # Add risk factors if available + if "risk_factors" in operation.metadata: + risk_factors = operation.metadata["risk_factors"] + content += "\n\n[bold]Risk Breakdown:[/bold]\n" + content += f" • Impact Scope: {risk_factors.get('impact_scope', 0):.2f}\n" + content += f" • Change Type: {risk_factors.get('change_type_risk', 0):.2f}\n" + content += f" • Test Coverage: {risk_factors.get('test_coverage_risk', 0):.2f}\n" + content += f" • Dependencies: {risk_factors.get('dependency_risk', 0):.2f}\n" + content += f" • Complexity: {risk_factors.get('complexity_risk', 0):.2f}" + + if risk_factors.get("test_file_exists"): + content += "\n\n✓ Test file exists" + else: + content += "\n\n⚠ No test file found - higher risk" + + panel = Panel(content, border_style=risk_color, expand=False) + console.print(panel) + + @click.group() @click.version_option(version="1.0.0") def main() -> None: diff --git a/refactron/core/refactor_result.py 
b/refactron/core/refactor_result.py index 08c8bd4..36b26bb 100644 --- a/refactron/core/refactor_result.py +++ b/refactron/core/refactor_result.py @@ -39,7 +39,7 @@ def operations_by_type(self, operation_type: str) -> List[RefactoringOperation]: return [op for op in self.operations if op.operation_type == operation_type] def show_diff(self) -> str: - """Show a diff of all operations.""" + """Show a diff of all operations with detailed risk assessment.""" lines = [] lines.append("=" * 80) lines.append("REFACTORING PREVIEW") @@ -54,7 +54,39 @@ def show_diff(self) -> str: lines.append("-" * 80) lines.append(f"Operation {i}: {op.operation_type}") lines.append(f"Location: {op.file_path}:{op.line_number}") - lines.append(f"Risk Score: {op.risk_score:.2f}") + + # Show risk with visual indicator + risk_icon = self._get_risk_icon(op.risk_score) + lines.append(f"Risk Score: {op.risk_score:.2f} {risk_icon}") + + # Show detailed risk factors if available + if "risk_factors" in op.metadata: + lines.append("") + lines.append(" Risk Breakdown:") + risk_factors = op.metadata["risk_factors"] + lines.append(f" • Impact Scope: {risk_factors.get('impact_scope', 0):.2f}") + lines.append( + f" • Change Type Risk: {risk_factors.get('change_type_risk', 0):.2f}" + ) + lines.append( + f" • Test Coverage Risk: {risk_factors.get('test_coverage_risk', 0):.2f}" + ) + lines.append(f" • Dependency Risk: {risk_factors.get('dependency_risk', 0):.2f}") + lines.append(f" • Complexity Risk: {risk_factors.get('complexity_risk', 0):.2f}") + + # Show affected components + if risk_factors.get("affected_functions"): + lines.append( + f" • Affected Functions: {', '.join(risk_factors['affected_functions'])}" + ) + + # Show test coverage status + if risk_factors.get("test_file_exists"): + lines.append(" • Test Coverage: ✓ Test file exists") + else: + lines.append(" • Test Coverage: ⚠ No test file found") + + lines.append("") lines.append(f"Description: {op.description}") if op.reasoning: @@ -74,6 +106,13 @@ 
def show_diff(self) -> str: lines.append("=" * 80) return "\n".join(lines) + def _get_risk_icon(self, risk_score: float) -> str: + """Get visual indicator for risk level.""" + from refactron.core.risk_assessment import RiskAssessor + + assessor = RiskAssessor() + return assessor.get_risk_display_label(risk_score) + def apply(self) -> bool: """Apply the refactoring operations (placeholder).""" # This would actually apply the changes to files diff --git a/refactron/core/risk_assessment.py b/refactron/core/risk_assessment.py new file mode 100644 index 0000000..bf7b598 --- /dev/null +++ b/refactron/core/risk_assessment.py @@ -0,0 +1,519 @@ +"""Advanced risk assessment for refactoring operations.""" + +import ast +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import Dict, List, Optional, Tuple + + +class ChangeType(Enum): + """Type of refactoring change.""" + + RENAMING = "renaming" # Simple renaming - lowest risk + EXTRACTION = "extraction" # Extracting code to new location + RESTRUCTURING = "restructuring" # Major structural changes + API_CHANGE = "api_change" # Changes to function signatures + LOGIC_CHANGE = "logic_change" # Changes to control flow or logic + + +class RiskLevel(Enum): + """Risk level categories.""" + + SAFE = "safe" # 0.0 - 0.3 + LOW = "low" # 0.3 - 0.5 + MODERATE = "moderate" # 0.5 - 0.7 + HIGH = "high" # 0.7 - 0.9 + CRITICAL = "critical" # 0.9 - 1.0 + + +@dataclass +class RiskFactors: + """Individual risk factors for a refactoring operation.""" + + impact_scope: float = 0.0 # 0.0-1.0: How many parts of code affected + change_type_risk: float = 0.0 # 0.0-1.0: Risk based on change type + test_coverage_risk: float = 0.0 # 0.0-1.0: Risk from low test coverage + dependency_risk: float = 0.0 # 0.0-1.0: Risk from breaking dependencies + complexity_risk: float = 0.0 # 0.0-1.0: Risk from code complexity + + # Metadata for detailed analysis + affected_functions: List[str] = field(default_factory=list) 
+ affected_files: List[Path] = field(default_factory=list) + dependencies: List[str] = field(default_factory=list) + has_tests: bool = False + test_file_exists: bool = False + + def to_dict(self) -> Dict: + """Convert risk factors to dictionary.""" + return { + "impact_scope": round(self.impact_scope, 3), + "change_type_risk": round(self.change_type_risk, 3), + "test_coverage_risk": round(self.test_coverage_risk, 3), + "dependency_risk": round(self.dependency_risk, 3), + "complexity_risk": round(self.complexity_risk, 3), + "affected_functions": self.affected_functions, + "affected_files": [str(f) for f in self.affected_files], + "dependencies": self.dependencies, + "has_tests": self.has_tests, + "test_file_exists": self.test_file_exists, + } + + +class RiskAssessor: + """Advanced risk assessment for refactoring operations.""" + + # Change type risk weights + CHANGE_TYPE_WEIGHTS = { + ChangeType.RENAMING: 0.1, + ChangeType.EXTRACTION: 0.3, + ChangeType.RESTRUCTURING: 0.6, + ChangeType.API_CHANGE: 0.7, + ChangeType.LOGIC_CHANGE: 0.8, + } + + # Impact scope weights + LINE_IMPACT_WEIGHT = 0.3 + FUNCTION_IMPACT_WEIGHT = 0.7 + + # Test coverage thresholds and risk values + TEST_COVERAGE_GOOD_THRESHOLD = 0.5 # Test size >= 50% of source + TEST_COVERAGE_MODERATE_THRESHOLD = 0.2 # Test size >= 20% of source + TEST_COVERAGE_GOOD_RISK = 0.1 + TEST_COVERAGE_MODERATE_RISK = 0.3 + TEST_COVERAGE_MINIMAL_RISK = 0.6 + TEST_COVERAGE_NONE_RISK = 0.8 + TEST_COVERAGE_UNKNOWN_RISK = 0.5 + + # Dependency risk thresholds and values + DEPENDENCY_THRESHOLD_LOW = 5 + DEPENDENCY_THRESHOLD_MEDIUM = 15 + DEPENDENCY_THRESHOLD_HIGH = 30 + DEPENDENCY_RISK_LOW = 0.1 + DEPENDENCY_RISK_MEDIUM = 0.3 + DEPENDENCY_RISK_HIGH = 0.5 + DEPENDENCY_RISK_VERY_HIGH = 0.7 + DEPENDENCY_CALL_WEIGHT = 0.1 # Weight for function calls vs imports + + # Complexity risk constants + MAX_COMPLEXITY_THRESHOLD = 50.0 # Control flow statements for very complex code + + # Risk level thresholds + RISK_THRESHOLD_SAFE = 0.3 + 
RISK_THRESHOLD_LOW = 0.5 + RISK_THRESHOLD_MODERATE = 0.7 + RISK_THRESHOLD_HIGH = 0.9 + + # Risk level display labels + RISK_DISPLAY_LABELS = { + RiskLevel.SAFE: "✓ SAFE", + RiskLevel.LOW: "⚡ LOW", + RiskLevel.MODERATE: "⚠ MODERATE", + RiskLevel.HIGH: "❌ HIGH", + RiskLevel.CRITICAL: "🔴 CRITICAL", + } + + # Risk level colors for CLI display + RISK_DISPLAY_COLORS = { + RiskLevel.SAFE: "green", + RiskLevel.LOW: "blue", + RiskLevel.MODERATE: "yellow", + RiskLevel.HIGH: "red", + RiskLevel.CRITICAL: "bright_red", + } + + def __init__(self, project_root: Optional[Path] = None): + """Initialize risk assessor. + + Args: + project_root: Root directory of the project for dependency analysis + """ + self.project_root = project_root or Path.cwd() + + def calculate_risk_score( + self, + file_path: Path, + source_code: str, + change_type: ChangeType, + affected_lines: Optional[List[int]] = None, + operation_description: str = "", + ) -> Tuple[float, RiskFactors]: + """Calculate comprehensive risk score for a refactoring operation. 
+ + Args: + file_path: Path to file being refactored + source_code: Current source code + change_type: Type of change being made + affected_lines: Lines of code being changed + operation_description: Description of the operation + + Returns: + Tuple of (overall_risk_score, detailed_risk_factors) + """ + risk_factors = RiskFactors() + + # Calculate individual risk factors + risk_factors.impact_scope = self._calculate_impact_scope( + file_path, source_code, affected_lines + ) + risk_factors.change_type_risk = self.CHANGE_TYPE_WEIGHTS.get(change_type, 0.5) + risk_factors.test_coverage_risk = self._calculate_test_coverage_risk(file_path) + risk_factors.dependency_risk = self._calculate_dependency_risk( + file_path, source_code, affected_lines + ) + risk_factors.complexity_risk = self._calculate_complexity_risk(source_code, affected_lines) + + # Populate test file metadata + risk_factors.test_file_exists = self._find_test_file(file_path) is not None + + # Calculate weighted overall risk score + overall_risk = self._calculate_weighted_risk(risk_factors) + + return overall_risk, risk_factors + + def _calculate_impact_scope( + self, + file_path: Path, + source_code: str, + affected_lines: Optional[List[int]] = None, + ) -> float: + """Calculate how much of the code is affected by the change. 
+ + Returns: + 0.0-1.0 where 1.0 means high impact + """ + try: + tree = ast.parse(source_code) + total_lines = len(source_code.split("\n")) + + if affected_lines: + # Calculate percentage of file affected + affected_percentage = len(affected_lines) / max(total_lines, 1) + + # Count affected functions/classes + affected_funcs = self._count_affected_functions(tree, affected_lines) + total_funcs = self._count_total_functions(tree) + + if total_funcs > 0: + func_percentage = len(affected_funcs) / total_funcs + # Weight function impact higher than line impact + impact = (affected_percentage * self.LINE_IMPACT_WEIGHT) + ( + func_percentage * self.FUNCTION_IMPACT_WEIGHT + ) + else: + impact = affected_percentage + + return min(impact, 1.0) + else: + # If no specific lines provided, assume moderate impact + return 0.5 + + except SyntaxError: + return 0.5 # Unknown, assume moderate risk + + def _count_affected_functions(self, tree: ast.AST, affected_lines: List[int]) -> List[str]: + """Count and list functions affected by the change.""" + affected = [] + + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + # Check if function overlaps with affected lines + func_start = node.lineno + func_end = getattr(node, "end_lineno", func_start + 10) + + if any(func_start <= line <= func_end for line in affected_lines): + affected.append(node.name) + + return affected + + def _count_total_functions(self, tree: ast.AST) -> int: + """Count total number of functions in the module.""" + count = 0 + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + count += 1 + return count + + def _calculate_test_coverage_risk(self, file_path: Path) -> float: + """Calculate risk based on test coverage. 
+ + Returns: + 0.0-1.0 where 1.0 means high risk (no tests) + """ + # Check if corresponding test file exists + test_file = self._find_test_file(file_path) + + if test_file and test_file.exists(): + # Test file exists - lower risk + # Check if tests are comprehensive by looking at file size + try: + test_size = test_file.stat().st_size + source_size = file_path.stat().st_size + + if test_size >= source_size * self.TEST_COVERAGE_GOOD_THRESHOLD: + # Good test coverage + return self.TEST_COVERAGE_GOOD_RISK + elif test_size >= source_size * self.TEST_COVERAGE_MODERATE_THRESHOLD: + # Moderate test coverage + return self.TEST_COVERAGE_MODERATE_RISK + else: + # Minimal tests + return self.TEST_COVERAGE_MINIMAL_RISK + except Exception: + return self.TEST_COVERAGE_UNKNOWN_RISK + else: + # No test file found - higher risk + return self.TEST_COVERAGE_NONE_RISK + + def _find_test_file(self, file_path: Path) -> Optional[Path]: + """Find corresponding test file for a source file.""" + # Common test patterns + file_name = file_path.stem + file_dir = file_path.parent + + # Look for test files in common locations + test_patterns = [ + file_dir / f"test_{file_name}.py", + file_dir / f"{file_name}_test.py", + file_dir.parent / "tests" / f"test_{file_name}.py", + self.project_root / "tests" / f"test_{file_name}.py", + ] + + for test_path in test_patterns: + if test_path.exists(): + return test_path + + return None + + def _calculate_dependency_risk( + self, + file_path: Path, + source_code: str, + affected_lines: Optional[List[int]] = None, + ) -> float: + """Calculate risk based on dependencies that might break. 
+ + Returns: + 0.0-1.0 where 1.0 means high risk (many dependencies) + """ + try: + tree = ast.parse(source_code) + + # Count imports (external dependencies) + import_count = 0 + for node in ast.walk(tree): + if isinstance(node, (ast.Import, ast.ImportFrom)): + import_count += 1 + + # Count function calls (internal dependencies) + call_count = 0 + for node in ast.walk(tree): + if isinstance(node, ast.Call): + call_count += 1 + + # Calculate dependency risk based on counts + # More dependencies = higher risk + total_deps = import_count + (call_count * self.DEPENDENCY_CALL_WEIGHT) + + if total_deps < self.DEPENDENCY_THRESHOLD_LOW: + return self.DEPENDENCY_RISK_LOW + elif total_deps < self.DEPENDENCY_THRESHOLD_MEDIUM: + return self.DEPENDENCY_RISK_MEDIUM + elif total_deps < self.DEPENDENCY_THRESHOLD_HIGH: + return self.DEPENDENCY_RISK_HIGH + else: + return self.DEPENDENCY_RISK_VERY_HIGH + + except SyntaxError: + return 0.5 # Unknown, assume moderate risk + + def _calculate_complexity_risk( + self, + source_code: str, + affected_lines: Optional[List[int]] = None, + ) -> float: + """Calculate risk based on code complexity. + + Returns: + 0.0-1.0 where 1.0 means high risk (very complex code) + """ + try: + tree = ast.parse(source_code) + + # Count complexity indicators + complexity_score = 0 + + for node in ast.walk(tree): + # Control flow increases complexity + if isinstance(node, (ast.If, ast.For, ast.While, ast.With)): + complexity_score += 1 + elif isinstance(node, (ast.Try, ast.ExceptHandler)): + complexity_score += 2 + elif isinstance(node, ast.Lambda): + complexity_score += 1 + + # Normalize to 0-1 range + normalized = min(complexity_score / self.MAX_COMPLEXITY_THRESHOLD, 1.0) + + return normalized + + except SyntaxError: + return 0.5 # Unknown, assume moderate risk + + def _calculate_weighted_risk(self, risk_factors: RiskFactors) -> float: + """Calculate overall risk score using weighted factors. 
+ + Weights: + - Change type: 30% (most important) + - Test coverage: 25% + - Dependency: 20% + - Impact scope: 15% + - Complexity: 10% + """ + weights = { + "change_type": 0.30, + "test_coverage": 0.25, + "dependency": 0.20, + "impact_scope": 0.15, + "complexity": 0.10, + } + + overall_risk = ( + risk_factors.change_type_risk * weights["change_type"] + + risk_factors.test_coverage_risk * weights["test_coverage"] + + risk_factors.dependency_risk * weights["dependency"] + + risk_factors.impact_scope * weights["impact_scope"] + + risk_factors.complexity_risk * weights["complexity"] + ) + + return round(overall_risk, 3) + + def get_risk_level(self, risk_score: float) -> RiskLevel: + """Get risk level category from score.""" + if risk_score < self.RISK_THRESHOLD_SAFE: + return RiskLevel.SAFE + elif risk_score < self.RISK_THRESHOLD_LOW: + return RiskLevel.LOW + elif risk_score < self.RISK_THRESHOLD_MODERATE: + return RiskLevel.MODERATE + elif risk_score < self.RISK_THRESHOLD_HIGH: + return RiskLevel.HIGH + else: + return RiskLevel.CRITICAL + + def get_risk_display_label(self, risk_score: float) -> str: + """Get display label for risk score. + + Args: + risk_score: Risk score between 0.0 and 1.0 + + Returns: + Display label with icon (e.g., "✓ SAFE", "⚠ MODERATE") + """ + risk_level = self.get_risk_level(risk_score) + return self.RISK_DISPLAY_LABELS[risk_level] + + def get_risk_display_color(self, risk_score: float) -> str: + """Get display color for risk score (for CLI). + + Args: + risk_score: Risk score between 0.0 and 1.0 + + Returns: + Color name for CLI display (e.g., "green", "yellow") + """ + risk_level = self.get_risk_level(risk_score) + return self.RISK_DISPLAY_COLORS[risk_level] + + def analyze_dependency_impact( + self, file_path: Path, function_name: Optional[str] = None + ) -> Dict[str, List[str]]: + """Analyze what breaks if this refactoring is applied. 
+ + Args: + file_path: File being refactored + function_name: Specific function being changed (optional) + + Returns: + Dictionary with potential breakage points + """ + impact = { + "importing_files": [], + "calling_functions": [], + "dependent_tests": [], + } + + # Find files that import this module + if self.project_root.exists(): + # NOTE(review): removed a duplicated outer rglob loop here; it re-ran + # the whole scan once per project file, duplicating every impact entry + # and resetting the scan counters on each pass. + + # Limit the number of files scanned to avoid excessive I/O on large projects. + max_files = getattr(self, "max_dependency_scan_files", 1000) + scanned_files = 0 + module_name = file_path.stem + + for py_file in self.project_root.rglob("*.py"): + if scanned_files >= max_files: + break + + if py_file == file_path: + continue + + scanned_files += 1 + + try: + content = py_file.read_text(encoding="utf-8", errors="ignore") + except (OSError, UnicodeDecodeError): + # Skip files that can't be read + continue + + try: + tree = ast.parse(content) + except SyntaxError: + # Skip files that can't be parsed + continue + + imports_module = False + calls_function = False + + for node in ast.walk(tree): + # Detect imports of the target module + if isinstance(node, ast.Import): + for alias in node.names: + if alias.name.split(".")[-1] == module_name: + imports_module = True + break + elif isinstance(node, ast.ImportFrom): + if node.module and node.module.split(".")[-1] == module_name: + imports_module = True + + # Detect calls to the target function, if provided + if function_name and isinstance(node, ast.Call): + func = node.func + func_name: Optional[str] = None + if isinstance(func, ast.Name): + func_name = func.id + elif isinstance(func, ast.Attribute): + func_name = func.attr + + if func_name == function_name: + calls_function = True + + # Stop walking early if we've found everything we care about + if imports_module and (not function_name or calls_function): + break + + if imports_module: + 
impact["importing_files"].append(str(py_file.relative_to(self.project_root))) + + if calls_function: + impact["calling_functions"].append(str(py_file.relative_to(self.project_root))) + + # Find dependent test files + test_file = self._find_test_file(file_path) + if test_file: + impact["dependent_tests"].append(str(test_file)) + + return impact diff --git a/refactron/refactorers/add_docstring_refactorer.py b/refactron/refactorers/add_docstring_refactorer.py index 3ec1341..0f0f99f 100644 --- a/refactron/refactorers/add_docstring_refactorer.py +++ b/refactron/refactorers/add_docstring_refactorer.py @@ -5,12 +5,20 @@ from typing import List, Union from refactron.core.models import RefactoringOperation +from refactron.core.risk_assessment import ChangeType, RiskAssessor from refactron.refactorers.base_refactorer import BaseRefactorer class AddDocstringRefactorer(BaseRefactorer): """Suggests adding docstrings to undocumented functions and classes.""" + # Maximum risk score for docstring additions (just documentation, very safe) + MAX_DOCSTRING_RISK = 0.1 + + def __init__(self, config): + super().__init__(config) + self.risk_assessor = RiskAssessor() + @property def operation_type(self) -> str: return "add_docstring" @@ -67,6 +75,24 @@ def _create_docstring_addition( # Generate version with docstring new_code = self._generate_with_docstring(node, lines) + # Calculate advanced risk score - adding docstrings is very safe + affected_lines = ( + list(range(node.lineno, node.end_lineno + 1)) + if hasattr(node, "end_lineno") + else [node.lineno] + ) + risk_score, risk_factors = self.risk_assessor.calculate_risk_score( + file_path=file_path, + source_code="\n".join(lines), + change_type=ChangeType.RENAMING, # Safest type, just documentation + affected_lines=affected_lines, + operation_description=f"Add docstring to '{node.name}'", + ) + + # Override to ensure very low risk for docstrings + risk_score = min(risk_score, self.MAX_DOCSTRING_RISK) + risk_factors.affected_functions = 
[node.name] + return RefactoringOperation( operation_type=self.operation_type, file_path=file_path, @@ -74,12 +100,16 @@ def _create_docstring_addition( description=f"Add docstring to {entity_type.lower()} '{node.name}'", old_code=old_code, new_code=new_code, - risk_score=0.0, # Very safe - only adding documentation + risk_score=risk_score, reasoning=f"Adding a docstring improves code documentation and helps other " f"developers understand what this {entity_type.lower()} does. " f"Good docstrings include a brief description, parameters (Args), " f"and return values (Returns).", - metadata={"entity_type": entity_type, "entity_name": node.name}, + metadata={ + "entity_type": entity_type, + "entity_name": node.name, + "risk_factors": risk_factors.to_dict(), + }, ) def _generate_with_docstring( diff --git a/refactron/refactorers/magic_number_refactorer.py b/refactron/refactorers/magic_number_refactorer.py index 85f1966..f57ef19 100644 --- a/refactron/refactorers/magic_number_refactorer.py +++ b/refactron/refactorers/magic_number_refactorer.py @@ -5,12 +5,17 @@ from typing import Dict, List, Tuple, Union from refactron.core.models import RefactoringOperation +from refactron.core.risk_assessment import ChangeType, RiskAssessor from refactron.refactorers.base_refactorer import BaseRefactorer class MagicNumberRefactorer(BaseRefactorer): """Suggests extracting magic numbers into named constants.""" + def __init__(self, config): + super().__init__(config) + self.risk_assessor = RiskAssessor() + @property def operation_type(self) -> str: return "extract_constant" @@ -120,6 +125,23 @@ def _create_refactoring( new_code = "\n".join(constant_defs) + "\n\n" + new_func_code + # Calculate advanced risk score + affected_lines = ( + list(range(func_node.lineno, func_node.end_lineno + 1)) + if hasattr(func_node, "end_lineno") + else [func_node.lineno] + ) + risk_score, risk_factors = self.risk_assessor.calculate_risk_score( + file_path=file_path, + source_code="\n".join(lines), + 
change_type=ChangeType.EXTRACTION, + affected_lines=affected_lines, + operation_description=f"Extract magic numbers to constants in '{func_name}'", + ) + + # Store affected functions in risk factors + risk_factors.affected_functions = [func_name] + return RefactoringOperation( operation_type=self.operation_type, file_path=file_path, @@ -127,11 +149,15 @@ def _create_refactoring( description=f"Extract magic numbers to named constants in '{func_name}'", old_code=old_code, new_code=new_code, - risk_score=0.1, # Very safe refactoring + risk_score=risk_score, reasoning=f"Extracting {len(constants)} magic numbers to named constants " f"improves code readability and maintainability. " f"Constants can be reused and their meaning is clear.", - metadata={"constants": list(constants.items()), "function_name": func_name}, + metadata={ + "constants": list(constants.items()), + "function_name": func_name, + "risk_factors": risk_factors.to_dict(), + }, ) def _generate_constant_name(self, value: float) -> str: diff --git a/refactron/refactorers/reduce_parameters_refactorer.py b/refactron/refactorers/reduce_parameters_refactorer.py index 75e6c0d..28513fa 100644 --- a/refactron/refactorers/reduce_parameters_refactorer.py +++ b/refactron/refactorers/reduce_parameters_refactorer.py @@ -5,12 +5,17 @@ from typing import List, Union from refactron.core.models import RefactoringOperation +from refactron.core.risk_assessment import ChangeType, RiskAssessor from refactron.refactorers.base_refactorer import BaseRefactorer class ReduceParametersRefactorer(BaseRefactorer): """Suggests using configuration objects for functions with many parameters.""" + def __init__(self, config): + super().__init__(config) + self.risk_assessor = RiskAssessor() + @property def operation_type(self) -> str: return "reduce_parameters" @@ -65,6 +70,23 @@ def _create_parameter_reduction( # Generate refactored version with config object new_code = self._generate_with_config_object(func_node, old_code) + # Calculate 
advanced risk score - API changes are higher risk + affected_lines = ( + list(range(func_node.lineno, func_node.end_lineno + 1)) + if hasattr(func_node, "end_lineno") + else [func_node.lineno] + ) + risk_score, risk_factors = self.risk_assessor.calculate_risk_score( + file_path=file_path, + source_code="\n".join(lines), + change_type=ChangeType.API_CHANGE, # This is an API change + affected_lines=affected_lines, + operation_description=f"Replace parameters with config object in '{func_node.name}'", + ) + + # Store affected functions in risk factors + risk_factors.affected_functions = [func_node.name] + return RefactoringOperation( operation_type=self.operation_type, file_path=file_path, @@ -75,7 +97,7 @@ def _create_parameter_reduction( ), old_code=old_code, new_code=new_code, - risk_score=0.4, # Moderate risk - API change + risk_score=risk_score, reasoning=( f"This function has {param_count} parameters (limit: " f"{self.config.max_parameters}). " @@ -87,6 +109,7 @@ def _create_parameter_reduction( "parameter_count": param_count, "function_name": func_node.name, "parameters": [arg.arg for arg in func_node.args.args], + "risk_factors": risk_factors.to_dict(), }, ) diff --git a/refactron/refactorers/simplify_conditionals_refactorer.py b/refactron/refactorers/simplify_conditionals_refactorer.py index 627c04d..78caa7d 100644 --- a/refactron/refactorers/simplify_conditionals_refactorer.py +++ b/refactron/refactorers/simplify_conditionals_refactorer.py @@ -5,12 +5,17 @@ from typing import List, Union from refactron.core.models import RefactoringOperation +from refactron.core.risk_assessment import ChangeType, RiskAssessor from refactron.refactorers.base_refactorer import BaseRefactorer class SimplifyConditionalsRefactorer(BaseRefactorer): """Suggests simplifying deeply nested conditionals.""" + def __init__(self, config): + super().__init__(config) + self.risk_assessor = RiskAssessor() + @property def operation_type(self) -> str: return "simplify_conditionals" @@ -75,6 
+80,23 @@ def _create_simplification( # Generate simplified version using early returns new_code = self._generate_simplified_version(func_node, old_code) + # Calculate advanced risk score - logic changes are moderate to high risk + affected_lines = ( + list(range(func_node.lineno, func_node.end_lineno + 1)) + if hasattr(func_node, "end_lineno") + else [func_node.lineno] + ) + risk_score, risk_factors = self.risk_assessor.calculate_risk_score( + file_path=file_path, + source_code="\n".join(lines), + change_type=ChangeType.LOGIC_CHANGE, # This changes control flow + affected_lines=affected_lines, + operation_description=f"Simplify conditionals in '{func_node.name}'", + ) + + # Store affected functions in risk factors + risk_factors.affected_functions = [func_node.name] + return RefactoringOperation( operation_type=self.operation_type, file_path=file_path, @@ -84,14 +106,18 @@ def _create_simplification( ), old_code=old_code, new_code=new_code, - risk_score=0.3, # Moderate risk - changes control flow + risk_score=risk_score, reasoning=( f"This function has {depth} levels of nesting. " f"Using early returns (guard clauses) reduces nesting and " f"improves readability. Each condition is checked upfront, " f"making the logic easier to follow." 
), - metadata={"original_depth": depth, "function_name": func_node.name}, + metadata={ + "original_depth": depth, + "function_name": func_node.name, + "risk_factors": risk_factors.to_dict(), + }, ) def _generate_simplified_version( diff --git a/tests/test_refactorers.py b/tests/test_refactorers.py index 3ecf975..bf12176 100644 --- a/tests/test_refactorers.py +++ b/tests/test_refactorers.py @@ -37,9 +37,14 @@ def calculate_discount(price): # Should suggest extracting constants op = operations[0] assert "constant" in op.description.lower() - assert op.risk_score < 0.3 # Should be safe + # Risk score is now calculated using advanced assessment + # Without tests, it can be moderate risk + assert op.risk_score < 0.7 # Should be safe to moderate assert "THRESHOLD" in op.new_code or "DISCOUNT" in op.new_code + # Check that risk factors are included + assert "risk_factors" in op.metadata + def test_ignores_common_numbers(self): config = RefactronConfig() refactorer = MagicNumberRefactorer(config) @@ -167,7 +172,11 @@ def calculate_total(price, tax): op = operations[0] assert "docstring" in op.description.lower() - assert op.risk_score == 0.0 # Perfectly safe + # With advanced risk assessment, docstrings are still very safe (capped at 0.1) + assert op.risk_score <= 0.1 # Very safe - just documentation + + # Check that risk factors are included + assert "risk_factors" in op.metadata assert "'''" in op.new_code assert "Args:" in op.new_code assert "Returns:" in op.new_code diff --git a/tests/test_risk_assessment.py b/tests/test_risk_assessment.py new file mode 100644 index 0000000..b7c44b1 --- /dev/null +++ b/tests/test_risk_assessment.py @@ -0,0 +1,422 @@ +"""Tests for advanced risk assessment module.""" + +import ast +import tempfile +from pathlib import Path + +from refactron.core.risk_assessment import ( + ChangeType, + RiskAssessor, + RiskFactors, + RiskLevel, +) + + +class TestRiskFactors: + """Test RiskFactors dataclass.""" + + def 
test_risk_factors_initialization(self): + """Test creating risk factors with default values.""" + factors = RiskFactors() + assert factors.impact_scope == 0.0 + assert factors.change_type_risk == 0.0 + assert factors.test_coverage_risk == 0.0 + assert factors.dependency_risk == 0.0 + assert factors.complexity_risk == 0.0 + assert factors.affected_functions == [] + assert factors.affected_files == [] + assert factors.has_tests is False + + def test_risk_factors_to_dict(self): + """Test converting risk factors to dictionary.""" + factors = RiskFactors( + impact_scope=0.5, + change_type_risk=0.3, + test_coverage_risk=0.8, + dependency_risk=0.2, + complexity_risk=0.4, + affected_functions=["test_func"], + has_tests=True, + ) + + result = factors.to_dict() + assert result["impact_scope"] == 0.5 + assert result["change_type_risk"] == 0.3 + assert result["test_coverage_risk"] == 0.8 + assert result["dependency_risk"] == 0.2 + assert result["complexity_risk"] == 0.4 + assert result["affected_functions"] == ["test_func"] + assert result["has_tests"] is True + + +class TestRiskAssessor: + """Test RiskAssessor functionality.""" + + def test_risk_assessor_initialization(self): + """Test creating a risk assessor.""" + assessor = RiskAssessor() + assert assessor.project_root is not None + + def test_calculate_risk_score_basic(self): + """Test basic risk score calculation.""" + assessor = RiskAssessor() + + code = """ +def simple_function(): + return 42 +""" + + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write(code) + temp_file = Path(f.name) + + try: + risk_score, risk_factors = assessor.calculate_risk_score( + file_path=temp_file, + source_code=code, + change_type=ChangeType.RENAMING, + affected_lines=[1, 2], + ) + + assert 0.0 <= risk_score <= 1.0 + assert isinstance(risk_factors, RiskFactors) + assert risk_factors.change_type_risk == 0.1 # RENAMING weight + finally: + temp_file.unlink() + + def test_calculate_risk_score_api_change(self): 
+ """Test risk score for API changes.""" + assessor = RiskAssessor() + + code = """ +def api_function(param1, param2, param3): + return param1 + param2 + param3 +""" + + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write(code) + temp_file = Path(f.name) + + try: + risk_score, risk_factors = assessor.calculate_risk_score( + file_path=temp_file, + source_code=code, + change_type=ChangeType.API_CHANGE, + affected_lines=[1, 2], + ) + + assert risk_score > 0.3 # API changes are higher risk + assert risk_factors.change_type_risk == 0.7 # API_CHANGE weight + finally: + temp_file.unlink() + + def test_impact_scope_calculation(self): + """Test impact scope calculation.""" + assessor = RiskAssessor() + + code = """ +def func1(): + pass + +def func2(): + pass + +def func3(): + pass +""" + + # Test affecting one function out of three + impact = assessor._calculate_impact_scope( + file_path=Path("test.py"), + source_code=code, + affected_lines=[1, 2, 3], + ) + + assert 0.0 <= impact <= 1.0 + # Should have moderate impact (1 of 3 functions) + assert impact < 0.8 + + def test_test_coverage_risk_no_tests(self): + """Test test coverage risk when no test file exists.""" + with tempfile.TemporaryDirectory() as tmpdir: + tmppath = Path(tmpdir) + source_file = tmppath / "module.py" + source_file.write_text("def test(): pass") + + assessor = RiskAssessor(project_root=tmppath) + risk = assessor._calculate_test_coverage_risk(source_file) + + # Should be high risk with no test file + assert risk >= 0.5 + + def test_test_coverage_risk_with_tests(self): + """Test test coverage risk when test file exists.""" + with tempfile.TemporaryDirectory() as tmpdir: + tmppath = Path(tmpdir) + + # Create source file + source_file = tmppath / "module.py" + source_file.write_text("def func(): pass") + + # Create test file + test_file = tmppath / "test_module.py" + test_file.write_text("def test_func(): assert True") + + assessor = RiskAssessor(project_root=tmppath) + 
risk = assessor._calculate_test_coverage_risk(source_file) + + # Should be lower risk with test file + assert risk < 0.8 + + def test_dependency_risk_calculation(self): + """Test dependency risk calculation.""" + assessor = RiskAssessor() + + # Code with few dependencies + simple_code = """ +def simple(): + return 1 + 1 +""" + + risk_low = assessor._calculate_dependency_risk( + file_path=Path("test.py"), + source_code=simple_code, + affected_lines=[1], + ) + + # Code with many dependencies + complex_code = """ +import os +import sys +import json +import requests +from pathlib import Path + +def complex(): + os.path.exists('.') + sys.exit(0) + json.loads('{}') + requests.get('url') + Path('.').exists() +""" + + risk_high = assessor._calculate_dependency_risk( + file_path=Path("test.py"), + source_code=complex_code, + affected_lines=[1], + ) + + # More dependencies should mean higher risk + assert risk_high > risk_low + + def test_complexity_risk_calculation(self): + """Test complexity risk calculation.""" + assessor = RiskAssessor() + + # Simple code + simple_code = """ +def simple(): + return 42 +""" + + risk_low = assessor._calculate_complexity_risk(simple_code, affected_lines=None) + + # Complex code with many control structures + complex_code = """ +def complex(): + if x: + for i in range(10): + while y: + try: + if z: + with open('f') as f: + if a: + for j in range(5): + pass + except Exception: + pass +""" + + risk_high = assessor._calculate_complexity_risk(complex_code, affected_lines=None) + + # More complexity should mean higher risk + assert risk_high > risk_low + + def test_weighted_risk_calculation(self): + """Test weighted overall risk calculation.""" + assessor = RiskAssessor() + + # Create risk factors with known values + factors = RiskFactors( + impact_scope=0.5, + change_type_risk=0.7, # 30% weight + test_coverage_risk=0.8, # 25% weight + dependency_risk=0.3, # 20% weight + complexity_risk=0.4, # 10% weight + ) + + risk = 
assessor._calculate_weighted_risk(factors) + + # Should be weighted average + expected = (0.7 * 0.30) + (0.8 * 0.25) + (0.3 * 0.20) + (0.5 * 0.15) + (0.4 * 0.10) + assert abs(risk - expected) < 0.01 + + def test_get_risk_level(self): + """Test risk level categorization.""" + assessor = RiskAssessor() + + assert assessor.get_risk_level(0.1) == RiskLevel.SAFE + assert assessor.get_risk_level(0.4) == RiskLevel.LOW + assert assessor.get_risk_level(0.6) == RiskLevel.MODERATE + assert assessor.get_risk_level(0.8) == RiskLevel.HIGH + assert assessor.get_risk_level(0.95) == RiskLevel.CRITICAL + + def test_change_type_weights(self): + """Test that change type weights are properly defined.""" + assert RiskAssessor.CHANGE_TYPE_WEIGHTS[ChangeType.RENAMING] == 0.1 + assert RiskAssessor.CHANGE_TYPE_WEIGHTS[ChangeType.EXTRACTION] == 0.3 + assert RiskAssessor.CHANGE_TYPE_WEIGHTS[ChangeType.RESTRUCTURING] == 0.6 + assert RiskAssessor.CHANGE_TYPE_WEIGHTS[ChangeType.API_CHANGE] == 0.7 + assert RiskAssessor.CHANGE_TYPE_WEIGHTS[ChangeType.LOGIC_CHANGE] == 0.8 + + def test_count_affected_functions(self): + """Test counting functions affected by changes.""" + assessor = RiskAssessor() + + code = """ +def func1(): + pass + +def func2(): + pass + +def func3(): + pass +""" + + tree = ast.parse(code) + + # Lines affecting only func1 (lines 1-2) + affected = assessor._count_affected_functions(tree, [1, 2]) + assert len(affected) == 1 + assert "func1" in affected + + # Lines affecting func1 and func2 (lines 1-5) + affected = assessor._count_affected_functions(tree, [1, 2, 3, 4, 5]) + assert len(affected) == 2 + + def test_count_total_functions(self): + """Test counting total functions in module.""" + assessor = RiskAssessor() + + code = """ +def func1(): + pass + +def func2(): + pass + +class MyClass: + def method1(self): + pass +""" + + tree = ast.parse(code) + + count = assessor._count_total_functions(tree) + assert count == 3 # func1, func2, method1 + + def 
test_analyze_dependency_impact(self): + """Test analyzing dependency impact.""" + with tempfile.TemporaryDirectory() as tmpdir: + tmppath = Path(tmpdir) + + # Create source file + source_file = tmppath / "module.py" + source_file.write_text("def my_function(): pass") + + # Create file that imports and calls the function + dependent_file = tmppath / "dependent.py" + dependent_file.write_text("from module import my_function\nresult = my_function()") + + # Create test file + test_file = tmppath / "test_module.py" + test_file.write_text("from module import my_function\nmy_function()") + + assessor = RiskAssessor(project_root=tmppath) + impact = assessor.analyze_dependency_impact(source_file, "my_function") + + # Should find dependent files in specific categories based on test setup + # We created dependent.py which imports the module + assert len(impact["importing_files"]) > 0, "Should find importing files" + # We created test_module.py which should be detected as a test + assert len(impact["dependent_tests"]) > 0, "Should find dependent test files" + # Calling functions should find the dependent file that uses my_function + assert len(impact["calling_functions"]) > 0, "Should find calling functions" + + def test_find_test_file(self): + """Test finding corresponding test files.""" + with tempfile.TemporaryDirectory() as tmpdir: + tmppath = Path(tmpdir) + + # Create source file + source_file = tmppath / "module.py" + source_file.write_text("pass") + + # Create test file with standard naming + test_file = tmppath / "test_module.py" + test_file.write_text("pass") + + assessor = RiskAssessor(project_root=tmppath) + found = assessor._find_test_file(source_file) + + assert found is not None + assert found.exists() + assert "test" in found.name + + def test_syntax_error_handling(self): + """Test that syntax errors are handled gracefully.""" + assessor = RiskAssessor() + + invalid_code = "def broken function(:" + + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", 
delete=False) as f: + f.write(invalid_code) + temp_file = Path(f.name) + + try: + # Should not raise exception + risk_score, risk_factors = assessor.calculate_risk_score( + file_path=temp_file, + source_code=invalid_code, + change_type=ChangeType.RENAMING, + ) + + # Should return moderate risk for unknown code + assert 0.0 <= risk_score <= 1.0 + finally: + temp_file.unlink() + + +class TestChangeType: + """Test ChangeType enum.""" + + def test_change_type_values(self): + """Test that change types have correct values.""" + assert ChangeType.RENAMING.value == "renaming" + assert ChangeType.EXTRACTION.value == "extraction" + assert ChangeType.RESTRUCTURING.value == "restructuring" + assert ChangeType.API_CHANGE.value == "api_change" + assert ChangeType.LOGIC_CHANGE.value == "logic_change" + + +class TestRiskLevel: + """Test RiskLevel enum.""" + + def test_risk_level_values(self): + """Test that risk levels have correct values.""" + assert RiskLevel.SAFE.value == "safe" + assert RiskLevel.LOW.value == "low" + assert RiskLevel.MODERATE.value == "moderate" + assert RiskLevel.HIGH.value == "high" + assert RiskLevel.CRITICAL.value == "critical"