
Connectome-constrained model utilities

connectome

ConnectomeODERNN

Bases: _ConnectomeRNNMixIn, SparseODERNN

Source code in src/bioplnn/models/connectome.py
class ConnectomeODERNN(_ConnectomeRNNMixIn, SparseODERNN):
    __doc__ = f"""Continuous-time ConnectomeRNN using ODE solver integration.

    A continuous-time version of the ConnectomeRNN that simulates neural
    dynamics using an ODE solver and computes the gradient with respect to the
    parameters for efficient training.

    Key features:
        - ODE solver integration: Built on top of `SparseODERNN`, allowing for
          the efficient computation of the gradients of the hidden states with
          respect to the parameters. In particular, the
          `torchode.AutoDiffAdjoint` solver gives the backward pass O(1)
          computational complexity with respect to the number of simulated
          timesteps.
        - {key_features_docstring}

    Attributes:
        {attributes_docstring}

    Example:
        >>> connectome = create_sparse_topographic_connectome((10, 10), 0.1, 10, True)
        >>> rnn = ConnectomeRNN(10, 10, connectome)
        >>> odenn = ConnectomeODERNN(10, 10, connectome)
    """

    def update_fn(
        self, t: torch.Tensor, h: torch.Tensor, args: Mapping[str, Any]
    ) -> torch.Tensor:
        """ODE function for neural dynamics with neuron type differentiation.

        Args:
            t: Current time point.
            h: Current hidden state.
            args: Additional arguments containing:
                x: Input sequence tensor
                start_time: Integration start time
                end_time: Integration end time

        Returns:
            Rate of change of hidden state (dh/dt)
        """
        h = h.t()

        x = args["x"]
        start_time = args["start_time"]
        end_time = args["end_time"]
        batch_size = x.shape[-1]

        # Get index corresponding to time t
        idx = self._index_from_time(t, x, start_time, end_time)

        # Get input at time t
        batch_indices = torch.arange(batch_size, device=idx.device)
        x_t = x[idx, :, batch_indices].t()

        # Apply sign mask to hidden state
        h_signed = h * self.neuron_sign_mask

        # Compute new hidden state
        h_new = self.ih(x_t) + self.hh(h_signed)

        if self.neuron_nonlinearity_mode == "one":
            h_new = self.neuron_nonlinearity(h_new)
        else:
            for i in range(self.num_neuron_types):
                h_new[self.neuron_type_indices[i]] = self.neuron_nonlinearity[  # type: ignore
                    i
                ](h_new[self.neuron_type_indices[i]])

        # Rectify hidden state to ensure it's non-negative
        # Note: The user is still expected to provide a nonlinearity that
        #   ensures non-negativity, e.g. sigmoid.
        h_new = F.relu(h_new)

        # Compute rate of change of hidden state
        if self.neuron_tau_mode == "per_neuron":
            dhdt = (h_new - h) / self.tau
        else:
            dhdt = (h_new - h) / self.tau[self.neuron_type_mask, :]

        return dhdt.t()

    def forward(
        self,
        x,
        num_evals: int = 2,
        start_time: float = 0.0,
        end_time: float = 1.0,
        neuron_state0: Optional[torch.Tensor] = None,
        neuron_state_init_fn: Optional[Union[str, TensorInitFnType]] = None,
    ):
        """Forward pass of the ConnectomeODERNN layer.

        Wraps the `SparseODERNN.forward` method to change nomenclature.

        See `SparseODERNN.forward` for more details.
        """
        self._clamp_tau()

        return super().forward(
            x=x,
            num_evals=num_evals,
            start_time=start_time,
            end_time=end_time,
            h0=neuron_state0,
            hidden_init_fn=neuron_state_init_fn,
        )

forward(x, num_evals=2, start_time=0.0, end_time=1.0, neuron_state0=None, neuron_state_init_fn=None)

Forward pass of the ConnectomeODERNN layer.

Wraps the SparseODERNN.forward method to change nomenclature.

See SparseODERNN.forward for more details.
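
For orientation, a minimal usage sketch. The connectome construction mirrors the class docstring's example; the input layout (num_steps, input_size, batch_size) is an assumption inferred from how update_fn indexes x, and the return convention follows SparseODERNN.forward, documented elsewhere:

>>> import torch
>>> connectome = create_sparse_topographic_connectome((10, 10), 0.1, 10, True)
>>> odernn = ConnectomeODERNN(10, 10, connectome)
>>> x = torch.rand(5, 10, 4)  # assumed (num_steps, input_size, batch_size) layout
>>> out = odernn(x, num_evals=20, start_time=0.0, end_time=1.0)

A larger num_evals requests more solver evaluation points over the same [start_time, end_time] window; per the class docstring, the torchode.AutoDiffAdjoint solver keeps the backward pass O(1) in the number of simulated timesteps.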

Source code in src/bioplnn/models/connectome.py
def forward(
    self,
    x,
    num_evals: int = 2,
    start_time: float = 0.0,
    end_time: float = 1.0,
    neuron_state0: Optional[torch.Tensor] = None,
    neuron_state_init_fn: Optional[Union[str, TensorInitFnType]] = None,
):
    """Forward pass of the ConnectomeODERNN layer.

    Wraps the `SparseODERNN.forward` method to change nomenclature.

    See `SparseODERNN.forward` for more details.
    """
    self._clamp_tau()

    return super().forward(
        x=x,
        num_evals=num_evals,
        start_time=start_time,
        end_time=end_time,
        h0=neuron_state0,
        hidden_init_fn=neuron_state_init_fn,
    )

update_fn(t, h, args)

ODE function for neural dynamics with neuron type differentiation.

Parameters:

    t (Tensor, required): Current time point.
    h (Tensor, required): Current hidden state.
    args (Mapping[str, Any], required): Additional arguments containing:
        x: Input sequence tensor.
        start_time: Integration start time.
        end_time: Integration end time.

Returns:

    Tensor: Rate of change of the hidden state (dh/dt).
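
Stripping away neuron types and signs, update_fn implements the leaky rate equation dh/dt = (f(W_ih x_t + W_hh h_signed) - h) / tau. A self-contained sketch, with hypothetical dense stand-ins W_ih and W_hh for the sparse ih/hh layers:

>>> import torch
>>> tau = torch.full((10, 1), 2.0)    # per-neuron time constants
>>> h = torch.rand(10, 4)             # (num_neurons, batch_size)
>>> x_t = torch.rand(10, 4)           # input at time t
>>> W_ih = torch.rand(10, 10) * 0.1   # hypothetical stand-in for self.ih
>>> W_hh = torch.rand(10, 10) * 0.1   # hypothetical stand-in for self.hh
>>> h_new = torch.relu(torch.sigmoid(W_ih @ x_t + W_hh @ h))
>>> dhdt = (h_new - h) / tau          # relax toward h_new at rate 1/tau

With tau = 2, each neuron closes half the gap to its nonlinearly transformed drive per unit time; larger tau means slower, more persistent neurons.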

Source code in src/bioplnn/models/connectome.py
def update_fn(
    self, t: torch.Tensor, h: torch.Tensor, args: Mapping[str, Any]
) -> torch.Tensor:
    """ODE function for neural dynamics with neuron type differentiation.

    Args:
        t: Current time point.
        h: Current hidden state.
        args: Additional arguments containing:
            x: Input sequence tensor
            start_time: Integration start time
            end_time: Integration end time

    Returns:
        Rate of change of hidden state (dh/dt)
    """
    h = h.t()

    x = args["x"]
    start_time = args["start_time"]
    end_time = args["end_time"]
    batch_size = x.shape[-1]

    # Get index corresponding to time t
    idx = self._index_from_time(t, x, start_time, end_time)

    # Get input at time t
    batch_indices = torch.arange(batch_size, device=idx.device)
    x_t = x[idx, :, batch_indices].t()

    # Apply sign mask to hidden state
    h_signed = h * self.neuron_sign_mask

    # Compute new hidden state
    h_new = self.ih(x_t) + self.hh(h_signed)

    if self.neuron_nonlinearity_mode == "one":
        h_new = self.neuron_nonlinearity(h_new)
    else:
        for i in range(self.num_neuron_types):
            h_new[self.neuron_type_indices[i]] = self.neuron_nonlinearity[  # type: ignore
                i
            ](h_new[self.neuron_type_indices[i]])

    # Rectify hidden state to ensure it's non-negative
    # Note: The user is still expected to provide a nonlinearity that
    #   ensures non-negativity, e.g. sigmoid.
    h_new = F.relu(h_new)

    # Compute rate of change of hidden state
    if self.neuron_tau_mode == "per_neuron":
        dhdt = (h_new - h) / self.tau
    else:
        dhdt = (h_new - h) / self.tau[self.neuron_type_mask, :]

    return dhdt.t()
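
Because the input sequence is discrete while the ODE solver queries update_fn at arbitrary continuous times, _index_from_time must map t back to a sequence index. Its implementation lives elsewhere in the module; a plausible sketch under the assumption of evenly spaced inputs over [start_time, end_time]:

>>> import torch
>>> def index_from_time(t, num_steps, start_time, end_time):  # hypothetical stand-in
...     frac = (t - start_time) / (end_time - start_time)
...     return (frac * num_steps).long().clamp(0, num_steps - 1)
>>> index_from_time(torch.tensor([0.0, 0.5, 1.0]), 5, 0.0, 1.0)
tensor([0, 2, 4])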

ConnectomeRNN

Bases: _ConnectomeRNNMixIn, SparseRNN

Source code in src/bioplnn/models/connectome.py
class ConnectomeRNN(_ConnectomeRNNMixIn, SparseRNN):
    __doc__ = f"""Connectome Recurrent Neural Network.

    An RNN that leverages the `SparseRNN` class to efficiently simulate a
    sparsely-connected network of neurons with biologically-inspired dynamics.

    Key features:
        {key_features_docstring}

    Attributes:
        {attributes_docstring}
    """

    def update_fn(self, x_t: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        """update hidden state for one timestep.

        Args:
            x_t (torch.Tensor): Input at current timestep.
            h (torch.Tensor): Hidden state at previous timestep.

        Returns:
            torch.Tensor: Updated hidden state.
        """

        # Apply sign mask to hidden state
        h_signed = h * self.neuron_sign_mask

        # Compute new hidden state
        h_new = self.ih(x_t) + self.hh(h_signed)

        if self.neuron_nonlinearity_mode == "one":
            h_new = self.neuron_nonlinearity(h_new)
        else:
            for i in range(self.num_neuron_types):
                h_new[self.neuron_type_indices[i]] = self.neuron_nonlinearity[  # type: ignore
                    i
                ](h_new[self.neuron_type_indices[i]])

        # Rectify hidden state to ensure it's non-negative
        # Note: The user is still expected to provide a nonlinearity that
        #   ensures non-negativity, e.g. sigmoid.
        h_new = F.relu(h_new)

        # Leaky integration: blend new and previous hidden states according to tau
        if self.neuron_tau_mode == "per_neuron":
            tau_inv = 1 / self.tau
            return tau_inv * h_new + (1 - tau_inv) * h
        else:
            tau_inv = 1 / self.tau[self.neuron_type_mask, :]
            return tau_inv * h_new + (1 - tau_inv) * h

    def forward(
        self,
        x,
        num_steps: Optional[int] = None,
        neuron_state0: Optional[torch.Tensor] = None,
        neuron_state_init_fn: Optional[Union[str, TensorInitFnType]] = None,
    ):
        """Forward pass of the ConnectomeODERNN layer.

        Wraps the `SparseODERNN.forward` method to change nomenclature.

        See `SparseODERNN.forward` for more details.
        """
        self._clamp_tau()

        return super().forward(
            x=x,
            num_steps=num_steps,
            h0=neuron_state0,
            hidden_init_fn=neuron_state_init_fn,
        )

forward(x, num_steps=None, neuron_state0=None, neuron_state_init_fn=None)

Forward pass of the ConnectomeRNN layer.

Wraps the SparseRNN.forward method to change nomenclature.

See SparseRNN.forward for more details.
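
A minimal usage sketch, mirroring the class docstring's example; as above, the input layout is an assumption inferred from the update_fn indexing in the ODE variant, and the return convention follows SparseRNN.forward:

>>> import torch
>>> connectome = create_sparse_topographic_connectome((10, 10), 0.1, 10, True)
>>> rnn = ConnectomeRNN(10, 10, connectome)
>>> x = torch.rand(5, 10, 4)  # assumed (num_steps, input_size, batch_size) layout
>>> out = rnn(x, num_steps=5)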

Source code in src/bioplnn/models/connectome.py
def forward(
    self,
    x,
    num_steps: Optional[int] = None,
    neuron_state0: Optional[torch.Tensor] = None,
    neuron_state_init_fn: Optional[Union[str, TensorInitFnType]] = None,
):
    """Forward pass of the ConnectomeODERNN layer.

    Wraps the `SparseODERNN.forward` method to change nomenclature.

    See `SparseODERNN.forward` for more details.
    """
    self._clamp_tau()

    return super().forward(
        x=x,
        num_steps=num_steps,
        h0=neuron_state0,
        hidden_init_fn=neuron_state_init_fn,
    )

update_fn(x_t, h)

Update the hidden state for one timestep.

Parameters:

    x_t (Tensor, required): Input at current timestep.
    h (Tensor, required): Hidden state at previous timestep.

Returns:

    Tensor: Updated hidden state.
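
The discrete update is exactly a forward-Euler step (dt = 1) of the ODE used by ConnectomeODERNN: h' = h + (h_new - h) / tau = (1/tau) * h_new + (1 - 1/tau) * h. A tiny numeric check:

>>> import torch
>>> tau, h, h_new = torch.tensor([[2.0]]), torch.tensor([[0.0]]), torch.tensor([[1.0]])
>>> (1 / tau) * h_new + (1 - 1 / tau) * h  # one step moves halfway toward h_new
tensor([[0.5000]])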

Source code in src/bioplnn/models/connectome.py
def update_fn(self, x_t: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
    """update hidden state for one timestep.

    Args:
        x_t (torch.Tensor): Input at current timestep.
        h (torch.Tensor): Hidden state at previous timestep.

    Returns:
        torch.Tensor: Updated hidden state.
    """

    # Apply sign mask to hidden state
    h_signed = h * self.neuron_sign_mask

    # Compute new hidden state
    h_new = self.ih(x_t) + self.hh(h_signed)

    if self.neuron_nonlinearity_mode == "one":
        h_new = self.neuron_nonlinearity(h_new)
    else:
        for i in range(self.num_neuron_types):
            h_new[self.neuron_type_indices[i]] = self.neuron_nonlinearity[  # type: ignore
                i
            ](h_new[self.neuron_type_indices[i]])

    # Rectify hidden state to ensure it's non-negative
    # Note: The user is still expected to provide a nonlinearity that
    #   ensures non-negativity, e.g. sigmoid.
    h_new = F.relu(h_new)

    # Leaky integration: blend new and previous hidden states according to tau
    if self.neuron_tau_mode == "per_neuron":
        tau_inv = 1 / self.tau
        return tau_inv * h_new + (1 - tau_inv) * h
    else:
        tau_inv = 1 / self.tau[self.neuron_type_mask, :]
        return tau_inv * h_new + (1 - tau_inv) * h