# Source code for verskyt.interventions.analysis

"""
Analysis tools for TNN interventions.

Provides counterfactual analysis and impact assessment capabilities
for understanding how interventions affect model behavior.
"""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F

from .manager import InterventionManager


@dataclass
class ImpactMetrics:
    """Numeric summary of how an intervention changed model behavior.

    The first four fields are required and describe the change at the
    output level. The remaining fields are optional extras that default to
    "not computed".

    Attributes:
        output_distance: L2 distance between the original and the modified
            outputs.
        output_correlation: Correlation between the original and the
            modified outputs.
        prediction_change_rate: Fraction of samples whose predicted class
            changed.
        confidence_change: Average change in prediction confidence.
        feature_activation_change: Change in feature activations, or
            ``None`` when not computed.
        similarity_score_change: Change in similarity scores, or ``None``
            when not computed.
        effect_size: Standardized effect size (Cohen's d or similar);
            defaults to ``0.0``.
        significance: p-value of a statistical test, or ``None`` when no
            test was performed.
    """

    # --- Output-level metrics (required) ---
    output_distance: float
    output_correlation: float
    prediction_change_rate: float
    confidence_change: float

    # --- Feature-level metrics (only when applicable) ---
    feature_activation_change: Optional[torch.Tensor] = None
    similarity_score_change: Optional[torch.Tensor] = None

    # --- Statistical metrics ---
    effect_size: float = 0.0
    significance: Optional[float] = None
@dataclass
class CounterfactualResult:
    """Record of one counterfactual found by intervening on a model.

    Pairs the model's behavior before the intervention with its behavior
    afterwards, together with a description of what was changed and a few
    scalar measures of the change. All fields are required.

    Attributes:
        original_input: Input sample before the intervention.
        original_output: Model output for the original input.
        original_prediction: Predicted class for the original input.
        modified_input: Input after the intervention (may be unchanged).
        modified_output: Model output after the intervention.
        modified_prediction: Predicted class after the intervention.
        intervention_description: Human-readable description of the
            intervention.
        success: Whether the intervention achieved the desired outcome.
        input_perturbation_norm: L2 norm of the input perturbation.
        output_change_norm: L2 norm of the output change.
        confidence_change: Change in prediction confidence.
    """

    # --- State before the intervention ---
    original_input: torch.Tensor
    original_output: torch.Tensor
    original_prediction: int

    # --- State after the intervention ---
    modified_input: torch.Tensor
    modified_output: torch.Tensor
    modified_prediction: int

    # --- What was done and whether it worked ---
    intervention_description: str
    success: bool

    # --- Scalar measures of the change ---
    input_perturbation_norm: float
    output_change_norm: float
    confidence_change: float
class ImpactAssessment:
    """Measure how temporary prototype/feature edits change model outputs.

    Each assessment runs the model on a batch, applies one modification
    through the ``InterventionManager`` (with ``track_intervention=False``),
    runs the model again and summarizes the difference as
    ``ImpactMetrics``. The modified vector is written back from a snapshot
    in a ``finally`` block, so it is restored even if the second forward
    pass or the metric computation raises.

    Note:
        Assessments call ``model.eval()`` and do not switch the model back
        to its previous training mode afterwards.
    """

    # Scales used by ``sensitivity_analysis`` when the caller passes none.
    _DEFAULT_PERTURBATION_SCALES = (0.01, 0.05, 0.1, 0.2, 0.5, 1.0)

    # Threshold below which a standard deviation is treated as zero.
    _EPS = 1e-8

    def __init__(self, intervention_manager: InterventionManager):
        """Initialize ImpactAssessment.

        Args:
            intervention_manager: Manager used to read and modify
                prototypes/features. Its ``model`` attribute is the model
                that gets evaluated.
        """
        self.manager = intervention_manager
        self.model = intervention_manager.model

    def assess_prototype_impact(
        self,
        layer_name: str,
        prototype_index: int,
        new_vector: torch.Tensor,
        test_inputs: torch.Tensor,
        test_targets: Optional[torch.Tensor] = None,
    ) -> ImpactMetrics:
        """Assess the impact of replacing one prototype vector.

        Args:
            layer_name: Name of the layer containing the prototype.
            prototype_index: Index of the prototype to modify.
            new_vector: Replacement prototype vector.
            test_inputs: Batch fed to the model before and after the
                modification. Outputs are reduced over ``dim=1``, so the
                model is expected to return one row per sample.
            test_targets: Accepted for API compatibility; currently not
                used by any metric.

        Returns:
            Metrics comparing the outputs before and after the change.

        Note:
            The prototype is restored to its saved value afterwards, also
            when an exception is raised during the assessment.
        """
        return self._assess_impact(
            self.manager.get_prototype,
            self.manager.modify_prototype,
            layer_name,
            prototype_index,
            new_vector,
            test_inputs,
        )

    def assess_feature_impact(
        self,
        layer_name: str,
        feature_index: int,
        new_vector: torch.Tensor,
        test_inputs: torch.Tensor,
        test_targets: Optional[torch.Tensor] = None,
    ) -> ImpactMetrics:
        """Assess the impact of replacing one feature vector.

        Args:
            layer_name: Name of the layer containing the feature.
            feature_index: Index of the feature to modify.
            new_vector: Replacement feature vector.
            test_inputs: Batch fed to the model before and after the
                modification. Outputs are reduced over ``dim=1``, so the
                model is expected to return one row per sample.
            test_targets: Accepted for API compatibility; currently not
                used by any metric.

        Returns:
            Metrics comparing the outputs before and after the change.

        Note:
            The feature is restored to its saved value afterwards, also
            when an exception is raised during the assessment.
        """
        return self._assess_impact(
            self.manager.get_feature,
            self.manager.modify_feature,
            layer_name,
            feature_index,
            new_vector,
            test_inputs,
        )

    def sensitivity_analysis(
        self,
        layer_name: str,
        parameter_type: str,  # 'prototype' or 'feature'
        parameter_index: int,
        test_inputs: torch.Tensor,
        perturbation_scales: Optional[List[float]] = None,
    ) -> Dict[float, ImpactMetrics]:
        """Assess impact under random perturbations of increasing scale.

        For every scale, Gaussian noise (``torch.randn_like``) multiplied
        by that scale is added to the current parameter vector and the
        result is assessed like any other modification.

        Args:
            layer_name: Name of the layer to analyze.
            parameter_type: ``'prototype'`` or ``'feature'``.
            parameter_index: Index of the parameter to perturb.
            test_inputs: Input data for evaluation.
            perturbation_scales: Scales to test. Defaults to
                ``[0.01, 0.05, 0.1, 0.2, 0.5, 1.0]``.

        Returns:
            Dictionary mapping each scale to its impact metrics. A scale
            listed more than once keeps only the last result.

        Raises:
            ValueError: If ``parameter_type`` is not recognized.
        """
        if perturbation_scales is None:
            perturbation_scales = list(self._DEFAULT_PERTURBATION_SCALES)

        if parameter_type == "prototype":
            get_param = self.manager.get_prototype
            assess_func = self.assess_prototype_impact
        elif parameter_type == "feature":
            get_param = self.manager.get_feature
            assess_func = self.assess_feature_impact
        else:
            raise ValueError(f"Unknown parameter_type: {parameter_type}")

        # Snapshot the base vector so every perturbation starts from the same
        # values even if ``.vector`` aliases the live parameter storage.
        base_vector = get_param(layer_name, parameter_index).vector.detach().clone()

        results: Dict[float, ImpactMetrics] = {}
        for scale in perturbation_scales:
            perturbed_vector = base_vector + torch.randn_like(base_vector) * scale
            results[scale] = assess_func(
                layer_name, parameter_index, perturbed_vector, test_inputs
            )
        return results

    def _assess_impact(
        self,
        get_param,
        modify_param,
        layer_name: str,
        index: int,
        new_vector: torch.Tensor,
        test_inputs: torch.Tensor,
    ) -> ImpactMetrics:
        """Run the before/after comparison shared by both parameter kinds.

        Args:
            get_param: Manager getter, called as ``(layer_name, index)``;
                its result must expose a ``vector`` tensor.
            modify_param: Manager setter, called as ``(layer_name, index,
                vector, track_intervention=False)``.
            layer_name: Layer that owns the parameter.
            index: Index of the parameter inside the layer.
            new_vector: Vector to install temporarily.
            test_inputs: Batch evaluated before and after the change.

        Returns:
            Metrics comparing the two sets of outputs.
        """
        self.model.eval()
        with torch.no_grad():
            original_outputs = self.model(test_inputs)

        # Clone: if ``.vector`` were a view of the live parameter, restoring
        # from it after the modification would silently be a no-op.
        saved_vector = get_param(layer_name, index).vector.detach().clone()

        try:
            modify_param(layer_name, index, new_vector, track_intervention=False)
            with torch.no_grad():
                modified_outputs = self.model(test_inputs)
            return self._compute_metrics(original_outputs, modified_outputs)
        finally:
            modify_param(layer_name, index, saved_vector, track_intervention=False)

    @staticmethod
    def _compute_metrics(
        original_outputs: torch.Tensor, modified_outputs: torch.Tensor
    ) -> ImpactMetrics:
        """Summarize the difference between two output batches."""
        eps = ImpactAssessment._EPS

        original_predictions = torch.argmax(original_outputs, dim=1)
        modified_predictions = torch.argmax(modified_outputs, dim=1)
        # Confidence = highest softmax probability per sample.
        original_confidences = F.softmax(original_outputs, dim=1).max(dim=1)[0]
        modified_confidences = F.softmax(modified_outputs, dim=1).max(dim=1)[0]

        delta = modified_outputs - original_outputs
        output_distance = torch.norm(delta).item()

        # corrcoef is undefined for (near-)constant inputs; fall back to an
        # exact-match indicator in that case. NaN standard deviations (e.g.
        # a single element) also take the fallback, since NaN > eps is False.
        orig_flat = original_outputs.flatten()
        mod_flat = modified_outputs.flatten()
        if torch.std(orig_flat) > eps and torch.std(mod_flat) > eps:
            corrcoef = torch.corrcoef(torch.stack([orig_flat, mod_flat]))
            output_correlation = corrcoef[0, 1].item()
        else:
            output_correlation = 1.0 if torch.allclose(orig_flat, mod_flat) else 0.0

        prediction_change_rate = (
            (original_predictions != modified_predictions).float().mean().item()
        )
        confidence_change = (modified_confidences - original_confidences).mean().item()

        # Effect size: mean output shift divided by the pooled std.
        pooled_std = torch.sqrt(
            (torch.var(original_outputs) + torch.var(modified_outputs)) / 2
        )
        if pooled_std > eps:
            effect_size = torch.mean(delta).item() / pooled_std.item()
        else:
            effect_size = 0.0

        return ImpactMetrics(
            output_distance=output_distance,
            output_correlation=output_correlation,
            prediction_change_rate=prediction_change_rate,
            confidence_change=confidence_change,
            effect_size=effect_size,
        )
class CounterfactualAnalyzer:
    """Search for parameter interventions that flip a model prediction.

    For a single input, each prototype (or feature) vector of one layer is
    optimized in turn with Adam to raise the log-probability of a target
    class. An index counts as a counterfactual as soon as the model
    predicts the target class. Every vector is written back to its saved
    value before the next one is tried, also when an exception is raised.

    Note:
        The methods call ``model.eval()`` and do not switch the model back
        to its previous training mode. The optimized vectors themselves
        are not part of the returned results.
    """

    def __init__(self, intervention_manager: InterventionManager):
        """Initialize CounterfactualAnalyzer.

        Args:
            intervention_manager: Manager used for parameter access and
                modification. Its ``model`` attribute is the model that
                gets evaluated.
        """
        self.manager = intervention_manager
        self.model = intervention_manager.model

    def find_prototype_counterfactuals(
        self,
        input_sample: torch.Tensor,
        target_class: int,
        layer_name: str,
        max_iterations: int = 100,
        learning_rate: float = 0.01,
    ) -> List[CounterfactualResult]:
        """Find counterfactual examples by modifying prototypes.

        Args:
            input_sample: Single (unbatched) input sample; a batch
                dimension is added before it is passed to the model.
            target_class: Desired output class.
            layer_name: Layer whose prototypes are modified.
            max_iterations: Maximum optimization iterations per prototype.
            learning_rate: Learning rate for the Adam optimizer.

        Returns:
            One result per prototype for which the target class was
            reached. Empty if the model already predicts ``target_class``.

        Raises:
            ValueError: If the layer is unknown or has no prototypes.
        """
        return self._find_counterfactuals(
            input_sample,
            target_class,
            layer_name,
            max_iterations,
            learning_rate,
            bank_attr="prototypes",
            missing_message=f"Layer '{layer_name}' has no prototypes",
            kind="prototype",
            get_param=self.manager.get_prototype,
            modify_param=self.manager.modify_prototype,
        )

    def find_feature_counterfactuals(
        self,
        input_sample: torch.Tensor,
        target_class: int,
        layer_name: str,
        max_iterations: int = 100,
        learning_rate: float = 0.01,
    ) -> List[CounterfactualResult]:
        """Find counterfactual examples by modifying features.

        Args:
            input_sample: Single (unbatched) input sample; a batch
                dimension is added before it is passed to the model.
            target_class: Desired output class.
            layer_name: Layer whose feature bank is modified.
            max_iterations: Maximum optimization iterations per feature.
            learning_rate: Learning rate for the Adam optimizer.

        Returns:
            One result per feature for which the target class was reached.
            Empty if the model already predicts ``target_class``.

        Raises:
            ValueError: If the layer is unknown or has no feature bank.
        """
        return self._find_counterfactuals(
            input_sample,
            target_class,
            layer_name,
            max_iterations,
            learning_rate,
            bank_attr="feature_bank",
            missing_message=f"Layer '{layer_name}' has no feature bank",
            kind="feature",
            get_param=self.manager.get_feature,
            modify_param=self.manager.modify_feature,
        )

    def analyze_decision_boundary(
        self,
        input_samples: torch.Tensor,
        layer_name: str,
        num_perturbations: int = 10,
        perturbation_scale: float = 0.1,
        max_prototypes: int = 3,
    ) -> Dict[str, Any]:
        """Measure how often random prototype noise flips predictions.

        Args:
            input_samples: Batch of input samples (typically near the
                decision boundary).
            layer_name: Layer to analyze.
            num_perturbations: Number of random perturbations per prototype.
            perturbation_scale: Multiplier applied to the Gaussian noise
                added to each prototype.
            max_prototypes: Upper bound on the number of prototypes
                analyzed (the first ones in the layer), for efficiency.

        Returns:
            Dictionary with keys ``"layer_name"``, ``"num_samples"``,
            ``"boundary_stability"`` (maps ``"prototype_<i>"`` to the
            fraction of sample/perturbation pairs whose prediction did not
            change, or NaN when there were no such pairs) and
            ``"intervention_effects"`` (currently always an empty list).
        """
        results: Dict[str, Any] = {
            "layer_name": layer_name,
            "num_samples": len(input_samples),
            "boundary_stability": {},
            "intervention_effects": [],
        }

        self.model.eval()
        with torch.no_grad():
            original_predictions = torch.argmax(self.model(input_samples), dim=1)

        layer = self.manager._tnn_layers[layer_name]
        if not hasattr(layer, "prototypes"):
            return results

        total_trials = num_perturbations * len(input_samples)
        for proto_idx in range(min(layer.prototypes.shape[0], max_prototypes)):
            # Snapshot so both the noise base and the restore are independent
            # of whatever the manager does to the live parameter.
            base_vector = (
                self.manager.get_prototype(layer_name, proto_idx)
                .vector.detach()
                .clone()
            )
            boundary_changes = 0
            try:
                for _ in range(num_perturbations):
                    perturbation = torch.randn_like(base_vector) * perturbation_scale
                    self.manager.modify_prototype(
                        layer_name,
                        proto_idx,
                        base_vector + perturbation,
                        track_intervention=False,
                    )
                    with torch.no_grad():
                        new_predictions = torch.argmax(
                            self.model(input_samples), dim=1
                        )
                    boundary_changes += (
                        (original_predictions != new_predictions).sum().item()
                    )
            finally:
                self.manager.modify_prototype(
                    layer_name, proto_idx, base_vector, track_intervention=False
                )

            # With no trials the ratio is undefined; report NaN rather than
            # dividing by zero.
            if total_trials > 0:
                stability = 1.0 - (boundary_changes / total_trials)
            else:
                stability = float("nan")
            results["boundary_stability"][f"prototype_{proto_idx}"] = stability

        return results

    def _find_counterfactuals(
        self,
        input_sample: torch.Tensor,
        target_class: int,
        layer_name: str,
        max_iterations: int,
        learning_rate: float,
        *,
        bank_attr: str,
        missing_message: str,
        kind: str,
        get_param,
        modify_param,
    ) -> List[CounterfactualResult]:
        """Shared search loop for prototype and feature counterfactuals.

        Args:
            input_sample: Single (unbatched) input sample.
            target_class: Desired output class.
            layer_name: Layer whose vectors are modified.
            max_iterations: Maximum optimization iterations per vector.
            learning_rate: Learning rate for the Adam optimizer.
            bank_attr: Name of the layer attribute holding the vectors.
            missing_message: Error text used when the layer lacks
                ``bank_attr``.
            kind: Word used in the result description
                (``"prototype"``/``"feature"``).
            get_param: Manager getter, called as ``(layer_name, index)``;
                its result must expose a ``vector`` tensor.
            modify_param: Manager setter used to restore each vector.

        Returns:
            List of successful counterfactual results.

        Raises:
            ValueError: If the layer is unknown or lacks ``bank_attr``.
        """
        if layer_name not in self.manager.layer_names:
            raise ValueError(f"Layer '{layer_name}' not found")
        layer = self.manager._tnn_layers[layer_name]
        if not hasattr(layer, bank_attr):
            raise ValueError(missing_message)
        bank = getattr(layer, bank_attr)

        batch = input_sample.unsqueeze(0)
        self.model.eval()
        with torch.no_grad():
            original_output = self.model(batch)
        original_prediction = torch.argmax(original_output, dim=1).item()
        if original_prediction == target_class:
            # Already the desired class; nothing to search for.
            return []
        original_confidence = F.softmax(original_output, dim=1).max().item()

        counterfactuals: List[CounterfactualResult] = []
        for index in range(bank.shape[0]):
            saved_vector = get_param(layer_name, index).vector.detach().clone()
            # The gradient is read from the layer's own tensor, so it must be
            # tracked during the search even if it is normally frozen.
            bank_was_tracked = bank.requires_grad
            bank.requires_grad_(True)
            try:
                output = self._optimize_entry(
                    bank,
                    index,
                    saved_vector,
                    batch,
                    target_class,
                    max_iterations,
                    learning_rate,
                )
            finally:
                # Reset the flag first so the manager sees the tensor in its
                # usual state when writing the saved vector back.
                bank.requires_grad_(bank_was_tracked)
                modify_param(
                    layer_name, index, saved_vector, track_intervention=False
                )

            if output is None:
                continue
            counterfactuals.append(
                CounterfactualResult(
                    original_input=input_sample.clone(),
                    original_output=original_output.clone(),
                    original_prediction=original_prediction,
                    modified_input=input_sample.clone(),  # Input didn't change
                    modified_output=output,
                    modified_prediction=target_class,
                    intervention_description=(
                        f"Modified {kind} {index} in layer {layer_name}"
                    ),
                    success=True,
                    input_perturbation_norm=0.0,  # No input perturbation
                    output_change_norm=torch.norm(output - original_output).item(),
                    confidence_change=F.softmax(output, dim=1).max().item()
                    - original_confidence,
                )
            )
        return counterfactuals

    def _optimize_entry(
        self,
        bank: torch.Tensor,
        index: int,
        start_vector: torch.Tensor,
        batch: torch.Tensor,
        target_class: int,
        max_iterations: int,
        learning_rate: float,
    ) -> Optional[torch.Tensor]:
        """Optimize ``bank[index]`` until the model predicts ``target_class``.

        The caller is responsible for restoring ``bank[index]`` afterwards.

        Returns:
            The detached model output at the first iterate that yields
            ``target_class``, or ``None`` if no iterate does.
        """
        candidate = start_vector.clone().requires_grad_(True)
        optimizer = torch.optim.Adam([candidate], lr=learning_rate)

        for _ in range(max_iterations):
            # Install the current candidate in the layer.
            with torch.no_grad():
                bank[index] = candidate

            output = self.model(batch)
            if torch.argmax(output, dim=1).item() == target_class:
                # Detach so the stored result does not keep the graph alive.
                return output.detach().clone()

            # Loss: want to maximize probability of target class.
            loss = -F.log_softmax(output, dim=1)[0, target_class]

            # The candidate was copied into the layer under no_grad, so it is
            # not part of the graph. Differentiate w.r.t. the layer tensor
            # and pass the matching row to the optimizer; this also avoids
            # accumulating .grad on unrelated model parameters.
            (bank_grad,) = torch.autograd.grad(loss, bank, allow_unused=True)
            if bank_grad is None:
                # The output does not depend on this tensor; nothing to tune.
                return None
            candidate.grad = (
                bank_grad[index]
                .detach()
                .to(device=candidate.device, dtype=candidate.dtype)
            )
            optimizer.step()

        return None