"""
Intervention Manager for Tversky Neural Networks.
Provides high-level APIs for inspecting and modifying TNN models,
enabling interpretability and counterfactual analysis.
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import torch
import torch.nn as nn
from verskyt.layers.projection import TverskyProjectionLayer, TverskySimilarityLayer
[docs]
@dataclass
class PrototypeInfo:
"""Information about a prototype in a TNN layer.
Contains metadata and vector data for a single prototype, enabling
inspection and modification of learned prototype representations.
Attributes:
layer_name (str): Name of the layer containing this prototype.
prototype_index (int): Index of the prototype within the layer.
vector (torch.Tensor): The prototype vector data.
layer_ref (Union[TverskyProjectionLayer, TverskySimilarityLayer]):
Reference to the layer object.
"""
layer_name: str
prototype_index: int
vector: torch.Tensor
layer_ref: Union[TverskyProjectionLayer, TverskySimilarityLayer]
@property
def shape(self) -> torch.Size:
"""Get the shape of the prototype vector.
Returns:
torch.Size: Shape of the prototype vector, typically [in_features].
"""
return self.vector.shape
@property
def norm(self) -> float:
"""Get the L2 norm of the prototype vector.
Returns:
float: L2 norm of the prototype vector, useful for comparing
prototype magnitudes and analyzing learned representations.
"""
return torch.norm(self.vector).item()
[docs]
@dataclass
class FeatureInfo:
"""Information about a feature in a TNN layer.
Contains metadata and vector data for a single feature, enabling
inspection and modification of learned feature representations.
Attributes:
layer_name (str): Name of the layer containing this feature.
feature_index (int): Index of the feature within the layer's feature bank.
vector (torch.Tensor): The feature vector data.
layer_ref (Union[TverskyProjectionLayer, TverskySimilarityLayer]):
Reference to the layer object.
"""
layer_name: str
feature_index: int
vector: torch.Tensor
layer_ref: Union[TverskyProjectionLayer, TverskySimilarityLayer]
@property
def shape(self) -> torch.Size:
"""Get the shape of the feature vector.
Returns:
torch.Size: Shape of the feature vector, typically [in_features].
"""
return self.vector.shape
@property
def norm(self) -> float:
"""Get the L2 norm of the feature vector.
Returns:
float: L2 norm of the feature vector, useful for comparing
feature magnitudes and analyzing learned representations.
"""
return torch.norm(self.vector).item()
[docs]
class InterventionManager:
"""Manager for interventions on Tversky Neural Networks.
Provides a unified API for inspecting, modifying, and analyzing TNN models
to enable interpretability research and counterfactual analysis. Supports
tracking of interventions and restoration of original model states.
This class serves as the central hub for TNN interpretability, offering:
- Comprehensive prototype and feature discovery across all layers
- Safe parameter modification with automatic state tracking
- Integration with impact assessment and grounding frameworks
- Batch operations for systematic intervention studies
Note:
The manager automatically discovers TNN layers (TverskyProjectionLayer
and TverskySimilarityLayer) within the provided model and maintains
original parameter states for restoration.
"""
[docs]
def __init__(self, model: nn.Module, model_name: str = "TNN_Model"):
"""Initialize InterventionManager for a TNN model.
Automatically discovers all TNN layers within the model and captures
the original parameter state for later restoration.
Args:
model (nn.Module): PyTorch model containing TverskyProjectionLayer
or TverskySimilarityLayer instances.
model_name (str, optional): Human-readable name for the model.
Defaults to "TNN_Model".
Note:
The manager will only operate on TverskyProjectionLayer and
TverskySimilarityLayer instances found within the model.
"""
self.model = model
self.model_name = model_name
self._tnn_layers = self._discover_tnn_layers()
# Track original state for impact assessment
self._original_state = self._capture_model_state()
self._intervention_history: List[Dict[str, Any]] = []
def _discover_tnn_layers(
self,
) -> Dict[str, Union[TverskyProjectionLayer, TverskySimilarityLayer]]:
"""Discover all TNN layers in the model.
Recursively searches through all modules in the model to find
TverskyProjectionLayer and TverskySimilarityLayer instances.
Returns:
Dict[str, Union[TverskyProjectionLayer, TverskySimilarityLayer]]:
Dictionary mapping layer names to layer objects.
"""
tnn_layers = {}
for name, module in self.model.named_modules():
if isinstance(module, (TverskyProjectionLayer, TverskySimilarityLayer)):
tnn_layers[name] = module
return tnn_layers
def _capture_model_state(self) -> Dict[str, torch.Tensor]:
"""Capture current state of all TNN layer parameters.
Creates deep copies of all prototype and feature bank tensors
to enable restoration after interventions.
Returns:
Dict[str, torch.Tensor]: Dictionary mapping parameter names
to cloned tensor data.
"""
state = {}
for layer_name, layer in self._tnn_layers.items():
if hasattr(layer, "prototypes"):
state[f"{layer_name}.prototypes"] = layer.prototypes.data.clone()
if hasattr(layer, "feature_bank"):
state[f"{layer_name}.feature_bank"] = layer.feature_bank.data.clone()
if hasattr(layer, "alpha"):
state[f"{layer_name}.alpha"] = layer.alpha.data.clone()
if hasattr(layer, "beta"):
state[f"{layer_name}.beta"] = layer.beta.data.clone()
return state
@property
def num_layers(self) -> int:
"""Get the number of TNN layers in the model.
Returns:
int: Total count of TverskyProjectionLayer and TverskySimilarityLayer
instances found in the model.
"""
return len(self._tnn_layers)
@property
def layer_names(self) -> List[str]:
"""Get names of all TNN layers in the model.
Returns:
List[str]: List of layer names that can be used with other
manager methods for layer-specific operations.
"""
return list(self._tnn_layers.keys())
[docs]
def get_layer_info(self, layer_name: str) -> Dict[str, Any]:
"""Get comprehensive information about a TNN layer.
Provides detailed metadata about layer configuration, parameter shapes,
and capabilities for inspection and intervention planning.
Args:
layer_name (str): Name of the layer to inspect. Must be one of
the names returned by the layer_names property.
Returns:
Dict[str, Any]: Dictionary containing layer metadata including:
- layer_name: Name of the layer
- layer_type: Class name of the layer
- in_features: Input feature dimension
- num_prototypes: Number of prototypes (if applicable)
- num_features: Number of features (if applicable)
- learnable_ab: Whether alpha/beta are learnable (if applicable)
Raises:
ValueError: If layer_name is not found in the model.
"""
if layer_name not in self._tnn_layers:
raise ValueError(
f"Layer '{layer_name}' not found. Available: {self.layer_names}"
)
layer = self._tnn_layers[layer_name]
info = {
"layer_name": layer_name,
"layer_type": type(layer).__name__,
"in_features": layer.in_features,
}
# Add layer-specific information
if isinstance(layer, TverskyProjectionLayer):
info.update(
{
"num_prototypes": layer.num_prototypes,
"num_features": layer.num_features,
"has_bias": layer.bias is not None,
"shared_features": getattr(layer, "shared_features", False),
}
)
elif isinstance(layer, TverskySimilarityLayer):
info.update(
{
"num_features": layer.num_features,
"use_contrast_form": layer.use_contrast_form,
}
)
# Add parameter information
if hasattr(layer, "alpha"):
info["alpha"] = layer.alpha.item()
if hasattr(layer, "beta"):
info["beta"] = layer.beta.item()
if hasattr(layer, "theta"):
if isinstance(layer.theta, torch.Tensor):
info["theta"] = layer.theta.item()
else:
info["theta"] = layer.theta
# Add reduction methods
info["intersection_reduction"] = str(layer.intersection_reduction)
info["difference_reduction"] = str(layer.difference_reduction)
return info
[docs]
def list_prototypes(self, layer_name: Optional[str] = None) -> List[PrototypeInfo]:
"""List all prototypes in the model or specific layer.
Discovers and returns metadata for all prototype vectors across
TNN layers, enabling systematic inspection and analysis.
Args:
layer_name (Optional[str], optional): If specified, only return
prototypes from this layer. If None, returns prototypes from
all layers. Defaults to None.
Returns:
List[PrototypeInfo]: List of PrototypeInfo objects containing
prototype vectors and metadata. Each object provides access
to the prototype vector, layer reference, and computed properties.
Note:
Only layers with 'prototypes' attribute (typically TverskyProjectionLayer)
will contribute to the returned list.
"""
prototypes = []
layers_to_check = [layer_name] if layer_name else self.layer_names
for name in layers_to_check:
if name not in self._tnn_layers:
continue
layer = self._tnn_layers[name]
if hasattr(layer, "prototypes"):
for i in range(layer.prototypes.shape[0]):
prototypes.append(
PrototypeInfo(
layer_name=name,
prototype_index=i,
vector=layer.get_prototype(i),
layer_ref=layer,
)
)
return prototypes
[docs]
def list_features(self, layer_name: Optional[str] = None) -> List[FeatureInfo]:
"""List all features in the model or specific layer.
Discovers and returns metadata for all feature vectors across
TNN layers, enabling systematic inspection and analysis of the
learned feature representations.
Args:
layer_name (Optional[str], optional): If specified, only return
features from this layer. If None, returns features from
all layers. Defaults to None.
Returns:
List[FeatureInfo]: List of FeatureInfo objects containing
feature vectors and metadata. Each object provides access
to the feature vector, layer reference, and computed properties.
Note:
Only layers with 'feature_bank' attribute will contribute to
the returned list. This typically includes both TverskyProjectionLayer
and TverskySimilarityLayer instances.
"""
features = []
layers_to_check = [layer_name] if layer_name else self.layer_names
for name in layers_to_check:
if name not in self._tnn_layers:
continue
layer = self._tnn_layers[name]
if hasattr(layer, "feature_bank"):
for i in range(layer.feature_bank.shape[0]):
features.append(
FeatureInfo(
layer_name=name,
feature_index=i,
vector=layer.get_feature(i),
layer_ref=layer,
)
)
return features
[docs]
def get_prototype(self, layer_name: str, prototype_index: int) -> PrototypeInfo:
"""Get specific prototype information.
Retrieves detailed information about a single prototype vector,
including its current values and layer context.
Args:
layer_name (str): Name of the layer containing the prototype.
Must be one of the names returned by layer_names.
prototype_index (int): Index of the prototype within the layer.
Must be in range [0, num_prototypes).
Returns:
PrototypeInfo: Object containing the prototype vector, metadata,
and layer reference for further operations.
Raises:
ValueError: If layer_name is not found or layer has no prototypes.
IndexError: If prototype_index is out of bounds.
"""
if layer_name not in self._tnn_layers:
raise ValueError(f"Layer '{layer_name}' not found")
layer = self._tnn_layers[layer_name]
if not hasattr(layer, "prototypes"):
raise ValueError(f"Layer '{layer_name}' has no prototypes")
if prototype_index >= layer.prototypes.shape[0]:
raise ValueError(
f"Prototype index {prototype_index} out of range for "
f"layer '{layer_name}'"
)
return PrototypeInfo(
layer_name=layer_name,
prototype_index=prototype_index,
vector=layer.get_prototype(prototype_index),
layer_ref=layer,
)
[docs]
def get_feature(self, layer_name: str, feature_index: int) -> FeatureInfo:
"""Get specific feature information.
Retrieves detailed information about a single feature vector,
including its current values and layer context.
Args:
layer_name (str): Name of the layer containing the feature.
Must be one of the names returned by layer_names.
feature_index (int): Index of the feature within the layer's
feature bank. Must be in range [0, num_features).
Returns:
FeatureInfo: Object containing the feature vector, metadata,
and layer reference for further operations.
Raises:
ValueError: If layer_name is not found or layer has no feature bank.
IndexError: If feature_index is out of bounds.
"""
if layer_name not in self._tnn_layers:
raise ValueError(f"Layer '{layer_name}' not found")
layer = self._tnn_layers[layer_name]
if not hasattr(layer, "feature_bank"):
raise ValueError(f"Layer '{layer_name}' has no feature bank")
if feature_index >= layer.feature_bank.shape[0]:
raise ValueError(
f"Feature index {feature_index} out of range for "
f"layer '{layer_name}'"
)
return FeatureInfo(
layer_name=layer_name,
feature_index=feature_index,
vector=layer.get_feature(feature_index),
layer_ref=layer,
)
[docs]
def modify_prototype(
self,
layer_name: str,
prototype_index: int,
new_vector: torch.Tensor,
track_intervention: bool = True,
) -> PrototypeInfo:
"""Modify a prototype vector in a TNN layer.
Safely modifies a prototype vector with automatic validation and
optional intervention tracking for impact assessment and restoration.
Args:
layer_name (str): Name of the layer containing the prototype.
Must be one of the names returned by layer_names.
prototype_index (int): Index of the prototype to modify.
Must be in range [0, num_prototypes).
new_vector (torch.Tensor): New prototype vector to set.
Must match the shape of the existing prototype.
track_intervention (bool, optional): Whether to record this
intervention in the history for impact assessment.
Defaults to True.
Returns:
PrototypeInfo: Updated PrototypeInfo object reflecting the
new prototype vector state.
Raises:
ValueError: If layer_name is not found, layer has no prototypes,
or new_vector shape doesn't match expected dimensions.
IndexError: If prototype_index is out of bounds.
Note:
When track_intervention=True, the original vector is stored
for potential restoration via reset_to_original().
"""
if layer_name not in self._tnn_layers:
raise ValueError(f"Layer '{layer_name}' not found")
layer = self._tnn_layers[layer_name]
if not hasattr(layer, "prototypes"):
raise ValueError(f"Layer '{layer_name}' has no prototypes")
# Validate dimensions
expected_shape = layer.prototypes[prototype_index].shape
if new_vector.shape != expected_shape:
raise ValueError(
f"New vector shape {new_vector.shape} doesn't match "
f"expected {expected_shape}"
)
# Store original for tracking
if track_intervention:
original_vector = layer.get_prototype(prototype_index)
intervention_record = {
"type": "prototype_modification",
"layer_name": layer_name,
"prototype_index": prototype_index,
"original_vector": original_vector.clone(),
"new_vector": new_vector.clone(),
"timestamp": torch.tensor(
len(self._intervention_history), dtype=torch.long
),
}
self._intervention_history.append(intervention_record)
# Apply modification
layer.set_prototype(prototype_index, new_vector)
return self.get_prototype(layer_name, prototype_index)
[docs]
def modify_feature(
self,
layer_name: str,
feature_index: int,
new_vector: torch.Tensor,
track_intervention: bool = True,
) -> FeatureInfo:
"""Modify a feature vector in a TNN layer.
Safely modifies a feature vector with automatic validation and
optional intervention tracking for impact assessment and restoration.
Args:
layer_name (str): Name of the layer containing the feature.
Must be one of the names returned by layer_names.
feature_index (int): Index of the feature to modify within
the layer's feature bank. Must be in range [0, num_features).
new_vector (torch.Tensor): New feature vector to set.
Must match the shape of the existing feature.
track_intervention (bool, optional): Whether to record this
intervention in the history for impact assessment.
Defaults to True.
Returns:
FeatureInfo: Updated FeatureInfo object reflecting the
new feature vector state.
Raises:
ValueError: If layer_name is not found, layer has no feature bank,
or new_vector shape doesn't match expected dimensions.
IndexError: If feature_index is out of bounds.
Note:
When track_intervention=True, the original vector is stored
for potential restoration via reset_to_original().
"""
if layer_name not in self._tnn_layers:
raise ValueError(f"Layer '{layer_name}' not found")
layer = self._tnn_layers[layer_name]
if not hasattr(layer, "feature_bank"):
raise ValueError(f"Layer '{layer_name}' has no feature bank")
# Validate dimensions
expected_shape = layer.feature_bank[feature_index].shape
if new_vector.shape != expected_shape:
raise ValueError(
f"New vector shape {new_vector.shape} doesn't match "
f"expected {expected_shape}"
)
# Store original for tracking
if track_intervention:
original_vector = layer.get_feature(feature_index)
intervention_record = {
"type": "feature_modification",
"layer_name": layer_name,
"feature_index": feature_index,
"original_vector": original_vector.clone(),
"new_vector": new_vector.clone(),
"timestamp": torch.tensor(
len(self._intervention_history), dtype=torch.long
),
}
self._intervention_history.append(intervention_record)
# Apply modification
layer.set_feature(feature_index, new_vector)
return self.get_feature(layer_name, feature_index)
[docs]
def reset_to_original(self) -> None:
"""Reset all TNN layers to their original state.
Restores all prototype vectors, feature vectors, and learnable
parameters (alpha, beta) to their values at manager initialization.
Also clears the intervention history.
Note:
This operation cannot be undone. All modifications made through
modify_prototype() and modify_feature() will be reverted to the
original model state.
"""
for param_name, original_value in self._original_state.items():
layer_name, param_type = param_name.rsplit(".", 1)
layer = self._tnn_layers[layer_name]
if param_type == "prototypes" and hasattr(layer, "prototypes"):
layer.prototypes.data.copy_(original_value)
elif param_type == "feature_bank" and hasattr(layer, "feature_bank"):
layer.feature_bank.data.copy_(original_value)
elif param_type == "alpha" and hasattr(layer, "alpha"):
layer.alpha.data.copy_(original_value)
elif param_type == "beta" and hasattr(layer, "beta"):
layer.beta.data.copy_(original_value)
# Clear intervention history
self._intervention_history.clear()
[docs]
def get_intervention_history(self) -> List[Dict[str, Any]]:
"""Get history of all interventions performed.
Returns a copy of the intervention history containing detailed
records of all modifications made through this manager.
Returns:
List[Dict[str, Any]]: List of intervention records, each containing:
- type: Type of intervention ('prototype_modification'
or 'feature_modification')
- layer_name: Name of the affected layer
- index: Index of the modified parameter
- original_vector: Original parameter vector (cloned)
- new_vector: New parameter vector (cloned)
- timestamp: Sequential intervention number
"""
return self._intervention_history.copy()
[docs]
def summary(self) -> str:
"""Get a summary of the model and available interventions.
Provides a comprehensive overview of the model structure,
TNN layers, and intervention capabilities for inspection.
Returns:
str: Multi-line summary string containing:
- Model name and layer count
- Detailed information for each TNN layer
- Number of prototypes and features per layer
- Parameter values (alpha, beta, theta)
- Total intervention count
"""
lines = [
f"Intervention Manager for: {self.model_name}",
f"TNN Layers: {self.num_layers}",
"",
"Layer Details:",
]
for layer_name in self.layer_names:
info = self.get_layer_info(layer_name)
lines.append(f" {layer_name}: {info['layer_type']}")
if "num_prototypes" in info:
lines.append(f" Prototypes: {info['num_prototypes']}")
if "num_features" in info:
lines.append(f" Features: {info['num_features']}")
lines.append(
f" α={info.get('alpha', 'N/A'):.3f}, β={info.get('beta', 'N/A'):.3f}"
)
lines.extend(
[
"",
f"Interventions Performed: {len(self._intervention_history)}",
"Available Operations: inspect, modify, analyze, reset",
]
)
return "\n".join(lines)