"""
XOR benchmark suite for Tversky Neural Networks.
Reproduces XOR experiments from "Tversky Neural Networks: Psychologically
Plausible Deep Learning with Differentiable Tversky Similarity"
(Doumbouya et al., 2025).
Provides both fast development benchmarks and full paper replication capabilities.
"""
import time
from dataclasses import dataclass, field
from itertools import product
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.optim as optim
from verskyt.core.similarity import DifferenceReduction, IntersectionReduction
from verskyt.layers.projection import TverskyProjectionLayer
[docs]
@dataclass
class XORConfig:
"""Configuration for XOR benchmark experiments."""
intersection_methods: List[Union[str, IntersectionReduction]] = field(
default_factory=lambda: ["product", "mean", "max", "gmean"]
)
difference_methods: List[Union[str, DifferenceReduction]] = field(
default_factory=lambda: ["substractmatch", "ignorematch"]
)
normalization: List[bool] = field(default_factory=lambda: [False])
feature_counts: List[int] = field(default_factory=lambda: [1, 4, 16, 32])
prototype_init: List[str] = field(default_factory=lambda: ["uniform", "normal"])
feature_init: List[str] = field(default_factory=lambda: ["uniform", "normal"])
random_seeds: List[int] = field(default_factory=lambda: [0, 1, 2])
epochs: int = 1000
learning_rate: float = 0.1
convergence_threshold: float = 1.0 # 100% accuracy
@property
def total_runs(self) -> int:
"""Calculate total number of experimental runs."""
return (
len(self.intersection_methods)
* len(self.difference_methods)
* len(self.normalization)
* len(self.feature_counts)
* len(self.prototype_init)
* len(self.feature_init)
* len(self.random_seeds)
)
# Fast benchmark for development (24 runs, ~15 seconds)
# Uses xavier_uniform for reliable convergence during development
FAST_BENCHMARK_CONFIG = XORConfig(
intersection_methods=["product", "mean", "max", "gmean"],
difference_methods=["substractmatch", "ignorematch"],
normalization=[False],
feature_counts=[4, 16], # Focus on good feature counts
prototype_init=["xavier_uniform"], # Use working initialization
feature_init=["xavier_uniform"],
random_seeds=[0, 1, 2],
)
# Full paper replication (12,960 runs, ~2.2 hours)
# NOTE: Paper's "uniform" initialization may differ from PyTorch's implementation
# This may result in lower convergence rates than paper reports
FULL_PAPER_CONFIG = XORConfig(
intersection_methods=["min", "max", "product", "mean", "gmean", "softmin"],
difference_methods=["ignorematch", "substractmatch"],
normalization=[False, True],
feature_counts=[1, 2, 4, 8, 16, 32],
prototype_init=["uniform", "normal", "orthogonal"],
feature_init=["uniform", "normal", "orthogonal"],
random_seeds=list(range(9)),
)
[docs]
@dataclass
class XORResult:
"""Results from a single XOR training run."""
intersection_method: str
difference_method: str
normalize: bool
feature_count: int
prototype_init: str
feature_init: str
seed: int
final_loss: float
final_accuracy: float
converged: bool
training_time: float
# Optional detailed tracking
loss_history: Optional[List[float]] = None
accuracy_history: Optional[List[float]] = None
[docs]
class XORBenchmark:
"""XOR benchmark runner for Tversky Neural Networks."""
[docs]
def __init__(self, config: XORConfig):
self.config = config
self.results: List[XORResult] = []
# XOR dataset (matching working implementation)
self.xor_inputs = torch.tensor(
[
[0.0, 0.0],
[0.0, 1.0],
[1.0, 0.0],
[1.0, 1.0],
]
)
self.xor_targets = torch.tensor([0, 1, 1, 0])
[docs]
def run_single_experiment(
self,
intersection_method: str,
difference_method: str,
normalize: bool,
feature_count: int,
prototype_init: str,
feature_init: str,
seed: int,
track_history: bool = False,
) -> XORResult:
"""Run a single XOR training experiment."""
# Set random seed for reproducibility
torch.manual_seed(seed)
# Map initialization methods to what layer supports
layer_proto_init = (
prototype_init if prototype_init != "orthogonal" else "xavier_uniform"
)
layer_feat_init = (
feature_init if feature_init != "orthogonal" else "xavier_uniform"
)
# Create TverskyProjectionLayer
layer = TverskyProjectionLayer(
in_features=2,
num_prototypes=2, # XOR has 2 output classes
num_features=feature_count,
alpha=0.5,
beta=0.5,
learnable_ab=True,
theta=1e-7,
intersection_reduction=intersection_method,
difference_reduction=difference_method,
feature_init=layer_feat_init,
prototype_init=layer_proto_init,
)
# Apply orthogonal initialization if requested (paper requirement)
if prototype_init == "orthogonal":
torch.nn.init.orthogonal_(layer.prototypes.data)
if feature_init == "orthogonal":
torch.nn.init.orthogonal_(layer.feature_bank.data)
optimizer = optim.Adam(layer.parameters(), lr=self.config.learning_rate)
# Training history tracking
loss_history = [] if track_history else None
accuracy_history = [] if track_history else None
start_time = time.time()
for epoch in range(self.config.epochs):
optimizer.zero_grad()
# Forward pass
outputs = layer(self.xor_inputs)
loss = F.cross_entropy(outputs, self.xor_targets)
# Backward pass
loss.backward()
optimizer.step()
# Track history if requested
if track_history:
with torch.no_grad():
predicted = torch.argmax(outputs, dim=1)
accuracy = (predicted == self.xor_targets).float().mean().item()
loss_history.append(loss.item())
accuracy_history.append(accuracy)
training_time = time.time() - start_time
# Final evaluation
with torch.no_grad():
final_outputs = layer(self.xor_inputs)
final_loss = F.cross_entropy(final_outputs, self.xor_targets).item()
predicted = torch.argmax(final_outputs, dim=1)
final_accuracy = (predicted == self.xor_targets).float().mean().item()
converged = final_accuracy >= self.config.convergence_threshold
return XORResult(
intersection_method=intersection_method,
difference_method=difference_method,
normalize=normalize,
feature_count=feature_count,
prototype_init=prototype_init,
feature_init=feature_init,
seed=seed,
final_loss=final_loss,
final_accuracy=final_accuracy,
converged=converged,
training_time=training_time,
loss_history=loss_history,
accuracy_history=accuracy_history,
)
[docs]
def run_benchmark(
self, verbose: bool = True, track_history: bool = False
) -> List[XORResult]:
"""Run complete benchmark suite."""
total_runs = self.config.total_runs
if verbose:
print(f"Starting XOR benchmark: {total_runs} total runs")
print(f"Estimated runtime: {total_runs * 0.6:.1f} seconds")
results = []
start_time = time.time()
# Generate all parameter combinations
combinations = product(
self.config.intersection_methods,
self.config.difference_methods,
self.config.normalization,
self.config.feature_counts,
self.config.prototype_init,
self.config.feature_init,
self.config.random_seeds,
)
for i, (
int_method,
diff_method,
normalize,
n_features,
proto_init,
feat_init,
seed,
) in enumerate(combinations):
if verbose and (i + 1) % 50 == 0:
elapsed = time.time() - start_time
eta = elapsed / (i + 1) * (total_runs - i - 1)
print(
f" Progress: {i+1}/{total_runs} ({100*(i+1)/total_runs:.1f}%) "
f"ETA: {eta:.1f}s"
)
try:
result = self.run_single_experiment(
intersection_method=int_method,
difference_method=diff_method,
normalize=normalize,
feature_count=n_features,
prototype_init=proto_init,
feature_init=feat_init,
seed=seed,
track_history=track_history,
)
results.append(result)
except Exception as e:
if verbose:
print(f" Warning: Run {i+1} failed: {e}")
# Create failed result
failed_result = XORResult(
intersection_method=int_method,
difference_method=diff_method,
normalize=normalize,
feature_count=n_features,
prototype_init=proto_init,
feature_init=feat_init,
seed=seed,
final_loss=float("nan"),
final_accuracy=0.5, # Random guessing
converged=False,
training_time=0.0,
)
results.append(failed_result)
total_time = time.time() - start_time
if verbose:
convergence_rate = sum(r.converged for r in results) / len(results)
print(f"Benchmark complete: {total_time:.1f}s")
print(f"Overall convergence rate: {convergence_rate:.2%}")
self.results = results
return results
[docs]
def analyze_results(self) -> Dict[str, float]:
"""Analyze benchmark results and compute convergence rates."""
if not self.results:
raise ValueError("No results to analyze. Run benchmark first.")
analysis = {}
# Overall convergence rate
total_converged = sum(r.converged for r in self.results)
analysis["overall_convergence_rate"] = total_converged / len(self.results)
# By intersection method
for method in self.config.intersection_methods:
method_results = [
r for r in self.results if r.intersection_method == method
]
if method_results:
converged = sum(r.converged for r in method_results)
analysis[f"convergence_rate_{method}"] = converged / len(method_results)
# By difference method
for method in self.config.difference_methods:
method_results = [r for r in self.results if r.difference_method == method]
if method_results:
converged = sum(r.converged for r in method_results)
analysis[f"convergence_rate_{method}"] = converged / len(method_results)
# By method combination (key paper finding)
method_combos = set(
(r.intersection_method, r.difference_method) for r in self.results
)
for int_method, diff_method in method_combos:
combo_results = [
r
for r in self.results
if r.intersection_method == int_method
and r.difference_method == diff_method
]
if combo_results:
converged = sum(r.converged for r in combo_results)
rate = converged / len(combo_results)
analysis[f"convergence_rate_{int_method}_{diff_method}"] = rate
return analysis
[docs]
def run_fast_xor_benchmark(
verbose: bool = True,
) -> Tuple[List[XORResult], Dict[str, float]]:
"""Run fast XOR benchmark for development (96 runs, ~60 seconds)."""
benchmark = XORBenchmark(FAST_BENCHMARK_CONFIG)
results = benchmark.run_benchmark(verbose=verbose)
analysis = benchmark.analyze_results()
if verbose:
print("\n=== Fast XOR Benchmark Results ===")
print(f"Total runs: {len(results)}")
print(f"Overall convergence: {analysis['overall_convergence_rate']:.2%}")
# Show key method combinations
key_combos = [
("product", "substractmatch"),
("mean", "substractmatch"),
("max", "ignorematch"),
("gmean", "ignorematch"),
]
print("\nKey method combinations:")
for int_method, diff_method in key_combos:
key = f"convergence_rate_{int_method}_{diff_method}"
if key in analysis:
print(f" {int_method} + {diff_method}: {analysis[key]:.2%}")
return results, analysis
[docs]
def run_full_xor_replication(
verbose: bool = True,
) -> Tuple[List[XORResult], Dict[str, float]]:
"""Run full paper replication (12,960 runs, ~2.2 hours)."""
if verbose:
print("⚠️ WARNING: Full replication will take ~2.2 hours")
print("Use run_fast_xor_benchmark() for development")
benchmark = XORBenchmark(FULL_PAPER_CONFIG)
results = benchmark.run_benchmark(verbose=verbose)
analysis = benchmark.analyze_results()
if verbose:
print("\n=== Full XOR Replication Results ===")
print(f"Total runs: {len(results)}")
print(f"Overall convergence: {analysis['overall_convergence_rate']:.2%}")
# Paper validation targets
paper_targets = {
("product", "substractmatch"): 0.53,
("mean", "substractmatch"): 0.51,
("max", "ignorematch"): 0.47,
("gmean", "ignorematch"): 0.00, # Should fail
}
print("\nPaper validation (expected vs actual):")
for (int_method, diff_method), expected in paper_targets.items():
key = f"convergence_rate_{int_method}_{diff_method}"
if key in analysis:
actual = analysis[key]
diff = abs(actual - expected)
status = "✅" if diff < 0.05 else "❌"
print(
f" {int_method} + {diff_method}: "
f"{expected:.2%} vs {actual:.2%} {status}"
)
return results, analysis