"""Custom initializers for Tversky neural network parameters.Based on empirical findings from the paper's experiments."""fromtypingimportOptionalimporttorchimporttorch.nnasnn
def uniform_init(tensor: torch.Tensor, a: float = -1.0, b: float = 1.0):
    """Fill ``tensor`` in place with values drawn uniformly between ``a`` and ``b``.

    Paper finding: uniform initialization led to a higher convergence
    probability on XOR than normal or orthogonal initialization.

    Args:
        tensor: Tensor (or ``nn.Parameter``) overwritten in place.
        a: Lower bound of the sampling range.
        b: Upper bound of the sampling range.

    Returns:
        The same ``tensor`` object that was passed in.
    """
    low, high = a, b
    # The write is done without autograd tracking so that leaf parameters
    # with ``requires_grad=True`` can be modified in place.
    with torch.no_grad():
        tensor.uniform_(low, high)
    return tensor
def initialize_for_xor(
    prototypes: torch.Tensor,
    features: torch.Tensor,
    seed: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Write a fixed, known-working XOR configuration into the given tensors.

    Based on the paper's Figure 1 showing a working XOR configuration:
      - 2 prototypes: p0 = {}, p1 = {f0, f1}
      - 2 features positioned to separate the XOR classes

    Both tensors are modified in place (under ``torch.no_grad``) and also
    returned. The row values written are 2-element vectors, so the trailing
    dimension of both tensors must be 2.

    Args:
        prototypes: Prototype tensor to initialize, expected shape [2, 2].
            Only rows 0 and 1 are written, and only when it has at least two
            rows; otherwise it is left untouched.
        features: Feature tensor to initialize, shape [num_features, 2].
            With two or more rows, rows 0 and 1 get fixed vectors and any
            further rows are drawn uniformly from [-0.5, 0.5]. With a single
            row, that row is set to [1, 1].
        seed: If given, passed to ``torch.manual_seed`` (global RNG) before
            any random draw, for reproducibility.

    Returns:
        Tuple of (initialized_prototypes, initialized_features), the same
        tensor objects that were passed in.
    """
    if seed is not None:
        torch.manual_seed(seed)

    n_features = features.shape[0]
    n_prototypes = prototypes.shape[0]

    with torch.no_grad():
        if n_features < 2:
            # Single-feature case (the original notes the paper shows this
            # can work).
            features[0] = torch.tensor([1.0, 1.0])
        else:
            # Fixed feature vectors from the paper's Figure 1:
            # f0 leans toward the [1, 0] corner, f1 toward [0, 1].
            features[0] = torch.tensor([1.0, -0.5])
            features[1] = torch.tensor([-0.5, 1.0])
            if n_features > 2:
                # Any extra features get small random values.
                features[2:].uniform_(-0.5, 0.5)

        if n_prototypes >= 2:
            # p0 stands for class 0 (inputs [0,0] and [1,1]): near the origin.
            prototypes[0] = torch.tensor([0.1, 0.1])
            # p1 stands for class 1 (inputs [0,1] and [1,0]): away from origin.
            prototypes[1] = torch.tensor([0.5, 0.5])

    return prototypes, features
def smart_init(layer: nn.Module, method: str = "xavier_uniform", **kwargs):
    """
    Smart initialization for Tversky layers based on paper findings.

    For ``TverskyProjectionLayer`` / ``TverskySimilarityLayer`` instances, the
    ``prototypes`` tensor (when the layer has one) and ``feature_bank`` are
    initialized. Any other module falls back to ``torch.nn.init`` applied to
    ``layer.weight``.

    Args:
        layer: TverskyProjectionLayer or TverskySimilarityLayer (other modules
            must expose a ``weight`` tensor for "uniform"/"xavier_uniform").
        method: Initialization method: "uniform", "xavier_uniform" or "xor".
            "xor" only affects Tversky layers that have ``prototypes``; for
            any other layer it leaves the parameters untouched.
        **kwargs: Additional arguments forwarded to the initializer for
            "uniform" and "xavier_uniform". For "xor", only ``seed`` is used.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
            Previously an unknown name was silently ignored, leaving the
            layer uninitialized.
    """
    supported = ("uniform", "xavier_uniform", "xor")
    # Validate before importing project layers so a typo fails fast and clearly.
    if method not in supported:
        raise ValueError(
            f"Unknown initialization method {method!r}; expected one of {supported}"
        )

    from verskyt.layers import TverskyProjectionLayer, TverskySimilarityLayer

    if isinstance(layer, (TverskyProjectionLayer, TverskySimilarityLayer)):
        has_prototypes = hasattr(layer, "prototypes")
        if method == "uniform":
            if has_prototypes:
                uniform_init(layer.prototypes, **kwargs)
            uniform_init(layer.feature_bank, **kwargs)
        elif method == "xavier_uniform":
            # NOTE(review): xavier_uniform_init is not defined in the visible
            # part of this module; confirm it exists, otherwise this branch
            # (the default method) raises NameError.
            if has_prototypes:
                xavier_uniform_init(layer.prototypes, **kwargs)
            xavier_uniform_init(layer.feature_bank, **kwargs)
        elif has_prototypes:
            # method == "xor": special initialization for the XOR task.
            # Forward the caller's seed, which was previously dropped.
            initialize_for_xor(
                layer.prototypes, layer.feature_bank, seed=kwargs.get("seed")
            )
    else:
        # Fall back to standard PyTorch initialization; "xor" has no
        # meaning for generic modules and leaves them untouched.
        if method == "uniform":
            nn.init.uniform_(layer.weight, **kwargs)
        elif method == "xavier_uniform":
            nn.init.xavier_uniform_(layer.weight, **kwargs)