
Utilities for efficient sparse matrix operations

sparse

SparseLinear

Bases: Module

Sparse linear layer for efficient operations with sparse matrices.

This layer implements a sparse linear transformation, similar to nn.Linear, but operates on sparse matrices for memory efficiency.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `in_features` | `int` | Size of the input feature dimension. | *required* |
| `out_features` | `int` | Size of the output feature dimension. | *required* |
| `connectivity` | `Tensor` | Sparse connectivity matrix in COO format. | *required* |
| `feature_dim` | `int` | Dimension on which features reside (0 for rows, 1 for columns). | `-1` |
| `bias` | `bool` | If set to `False`, no bias term is added. | `True` |
| `requires_grad` | `bool` | Whether the weight and bias parameters require gradient updates. | `True` |
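A minimal usage sketch (the import path follows the source location shown below; the connectivity pattern and sizes are made up for illustration):

import torch
from bioplnn.models.sparse import SparseLinear

# Hypothetical 5x3 connectivity in COO format: row 0 holds output indices,
# row 1 holds input indices, so each column wires output i to input j.
indices = torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 0, 1]])
values = torch.ones(5)
connectivity = torch.sparse_coo_tensor(indices, values, size=(5, 3)).coalesce()

layer = SparseLinear(in_features=3, out_features=5, connectivity=connectivity)
x = torch.randn(8, 3)  # batch of 8, features on the last dim (feature_dim=-1)
y = layer(x)
print(y.shape)  # torch.Size([8, 5])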
Source code in src/bioplnn/models/sparse.py
class SparseLinear(nn.Module):
    """Sparse linear layer for efficient operations with sparse matrices.

    This layer implements a sparse linear transformation, similar to nn.Linear,
    but operates on sparse matrices for memory efficiency.

    Args:
        in_features: Size of the input feature dimension.
        out_features: Size of the output feature dimension.
        connectivity: Sparse connectivity matrix in COO format.
        feature_dim: Dimension on which features reside (0 for rows, 1 for columns).
        bias: If set to False, no bias term is added.
        requires_grad: Whether the weight and bias parameters require gradient updates.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        connectivity: torch.Tensor,
        feature_dim: int = -1,
        bias: bool = True,
        requires_grad: bool = True,
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.feature_dim = feature_dim

        # Validate connectivity format
        if connectivity.layout != torch.sparse_coo:
            raise ValueError("connectivity must be in COO format.")

        # Validate input and output sizes against connectivity
        if in_features != connectivity.shape[1]:
            raise ValueError(
                f"Input size ({in_features}) must be equal to the number of columns in connectivity ({connectivity.shape[1]})."
            )
        if out_features != connectivity.shape[0]:
            raise ValueError(
                f"Output size ({out_features}) must be equal to the number of rows in connectivity ({connectivity.shape[0]})."
            )

        # Create sparse matrix
        indices: torch.Tensor
        values: torch.Tensor
        indices, values = torch_sparse.coalesce(
            connectivity.indices().clone(),
            connectivity.values().clone(),
            self.out_features,
            self.in_features,
        )  # type: ignore

        self.indices = nn.Parameter(indices, requires_grad=False)
        self.values = nn.Parameter(values.float(), requires_grad=requires_grad)

        self.bias = (
            nn.Parameter(
                torch.zeros(self.out_features, 1), requires_grad=requires_grad
            )
            if bias
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Performs sparse linear transformation on the input tensor.

        Args:
            x: Input tensor of shape (H, *) if feature_dim is 0, otherwise (*, H).

        Returns:
            Output tensor after sparse linear transformation.
        """

        shape = list(x.shape)

        if self.feature_dim != 0:
            permutation = torch.arange(x.dim())
            permutation[self.feature_dim] = 0
            permutation[0] = self.feature_dim
            x = x.permute(*permutation)  # type: ignore

        # Remember the permuted shape so it can be restored after the sparse
        # matrix multiplication collapses the trailing dimensions.
        permuted_shape = list(x.shape)
        x = x.flatten(start_dim=1)

        x = torch_sparse.spmm(
            self.indices,
            self.values,
            self.out_features,
            self.in_features,
            x,
        )

        if self.bias is not None:
            x = x + self.bias

        if self.feature_dim != 0:
            # Restore the multi-dimensional shape before undoing the
            # permutation; otherwise the permute fails for inputs with more
            # than two dimensions.
            x = x.view(self.out_features, *permuted_shape[1:])
            x = x.permute(*permutation)  # type: ignore

        shape[self.feature_dim] = self.out_features
        x = x.reshape(*shape)

        return x

forward(x)

Performs sparse linear transformation on the input tensor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Input tensor of shape `(H, *)` if `feature_dim` is 0, otherwise `(*, H)`. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Output tensor after sparse linear transformation. |
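With `feature_dim=0` the features sit on the leading axis instead. A sketch reusing the hypothetical `connectivity` from the class-level example above:

x0 = torch.randn(3, 8)  # (in_features, batch)
layer0 = SparseLinear(3, 5, connectivity, feature_dim=0)
y0 = layer0(x0)
print(y0.shape)  # torch.Size([5, 8])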

Source code in src/bioplnn/models/sparse.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Performs sparse linear transformation on the input tensor.

    Args:
        x: Input tensor of shape (H, *) if feature_dim is 0, otherwise (*, H).

    Returns:
        Output tensor after sparse linear transformation.
    """

    shape = list(x.shape)

    if self.feature_dim != 0:
        permutation = torch.arange(x.dim())
        permutation[self.feature_dim] = 0
        permutation[0] = self.feature_dim
        x = x.permute(*permutation)  # type: ignore

    # Remember the permuted shape so it can be restored after the sparse
    # matrix multiplication collapses the trailing dimensions.
    permuted_shape = list(x.shape)
    x = x.flatten(start_dim=1)

    x = torch_sparse.spmm(
        self.indices,
        self.values,
        self.out_features,
        self.in_features,
        x,
    )

    if self.bias is not None:
        x = x + self.bias

    if self.feature_dim != 0:
        # Restore the multi-dimensional shape before undoing the
        # permutation; otherwise the permute fails for inputs with more
        # than two dimensions.
        x = x.view(self.out_features, *permuted_shape[1:])
        x = x.permute(*permutation)  # type: ignore

    shape[self.feature_dim] = self.out_features
    x = x.reshape(*shape)

    return x

SparseODERNN

Bases: SparseRNN

Sparse Ordinary Differential Equation Recurrent Neural Network.

A continuous-time version of SparseRNN that uses an ODE solver to simulate the dynamics of the network and simultaneously compute the parameter gradients (see torchode.AutoDiffAdjoint).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `compile_solver_kwargs` | `Optional[Mapping[str, Any]]` | Keyword arguments for `torch.compile` when compiling the ODE solver. | `None` |
| `compile_update_fn_kwargs` | `Optional[Mapping[str, Any]]` | Keyword arguments for `torch.compile` when compiling `update_fn`. | `None` |
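A minimal sketch of the documented interface, with hypothetical sizes and a random connectivity (torchode drives the integration internally):

import torch
from bioplnn.models.sparse import SparseODERNN

n = 100
# Hypothetical random recurrent connectivity with ~500 nonzero entries.
idx = torch.randint(0, n, (2, 500))
conn_hh = torch.sparse_coo_tensor(idx, torch.rand(500), size=(n, n)).coalesce()

# Without connectivity_ih/connectivity_ho, input_size must equal hidden_size
# and the hidden state doubles as the output (ih and ho are identities).
model = SparseODERNN(input_size=n, hidden_size=n, connectivity_hh=conn_hh)

x = torch.rand(8, n)  # a 2D input is held constant over the integration window
outs, hs, ts = model(x, num_evals=5, start_time=0.0, end_time=1.0)
# Expected with batch_first=True: outs (8, 5, 100), hs (8, 5, 100), ts (8, 5)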
Source code in src/bioplnn/models/sparse.py
class SparseODERNN(SparseRNN):
    """Sparse Ordinary Differential Equation Recurrent Neural Network.

    A continuous-time version of SparseRNN that uses an ODE solver to
    simulate the dynamics of the network and simultaneously compute the
    parameter gradients (see `torchode.AutoDiffAdjoint`).

    Args:
        compile_solver_kwargs: Keyword arguments for `torch.compile` when
            compiling the ODE solver.
        compile_update_fn_kwargs: Keyword arguments for `torch.compile` when
            compiling `update_fn`.
    """

    def __init__(
        self,
        *args,
        compile_solver_kwargs: Optional[Mapping[str, Any]] = None,
        compile_update_fn_kwargs: Optional[Mapping[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        # Compile update_fn
        if compile_update_fn_kwargs is not None:
            self.update_fn = torch.compile(
                self.update_fn, **compile_update_fn_kwargs
            )

        # Define ODE solver
        term = to.ODETerm(self.update_fn, with_args=True)  # type: ignore
        step_method = to.Dopri5(term=term)
        step_size_controller = to.IntegralController(
            atol=1e-6, rtol=1e-3, term=term
        )
        self.solver = to.AutoDiffAdjoint(step_method, step_size_controller)  # type: ignore

        # Compile solver
        if compile_solver_kwargs is not None:
            self.solver = torch.compile(self.solver, **compile_solver_kwargs)

    def _format_x(self, x: torch.Tensor):
        """Format the input tensor to match the expected shape.

        Args:
            x: Input tensor. If 2-dimensional, it is assumed to be of shape
                (batch_size, input_size). If 3-dimensional, it is assumed to be
                of shape (batch_size, sequence_length, input_size) if
                batch_first, else (sequence_length, batch_size, input_size).

        Returns:
            Formatted input tensor of shape (sequence_length, input_size,
            batch_size).

        Raises:
            ValueError: For invalid input dimensions.
        """

        if x.dim() == 2:
            x = x.t()
            x = x.unsqueeze(0)
        elif x.dim() == 3:
            if self.batch_first:
                x = x.permute(1, 2, 0)
            else:
                x = x.permute(0, 2, 1)
        else:
            raise ValueError(
                f"Input tensor must be 2D or 3D, but got {x.dim()} dimensions."
            )

        return x

    def _format_ts(self, ts: torch.Tensor) -> torch.Tensor:
        """Format the time points based on batch_first setting.

        Args:
            ts: Time points tensor.

        Returns:
            Formatted time points tensor.
        """
        if self.batch_first:
            return ts
        else:
            return ts.transpose(0, 1)

    def _index_from_time(
        self,
        t: torch.Tensor,
        x: torch.Tensor,
        start_time: float,
        end_time: float,
    ) -> torch.Tensor:
        """Calculate the index of the input tensor corresponding to the given time.

        Args:
            t: Current time point.
            x: Input tensor.
            start_time: Start time for simulation.
            end_time: End time for simulation.

        Returns:
            Index tensor for selecting the correct input.
        """
        idx = (t - start_time) / (end_time - start_time) * x.shape[0]
        idx[idx == x.shape[0]] = x.shape[0] - 1

        return idx.long()

    def update_fn(
        self, t: torch.Tensor, h: torch.Tensor, args: Mapping[str, Any]
    ) -> torch.Tensor:
        """ODE function for the SparseODERNN.

        Args:
            t: Current time point.
            h: Current hidden state.
            args: Additional arguments including input data.

        Returns:
            Rate of change of the hidden state.
        """
        h = h.t()
        x = args["x"]
        start_time = args["start_time"]
        end_time = args["end_time"]

        idx = self._index_from_time(t, x, start_time, end_time)

        h_new = self.nonlinearity(self.ih(x[idx]) + self.hh(h))

        dhdt = h_new - h

        return dhdt.t()

    def forward(
        self,
        x: torch.Tensor,
        num_evals: int = 2,
        start_time: float = 0.0,
        end_time: float = 1.0,
        h0: Optional[torch.Tensor] = None,
        hidden_init_fn: Optional[Union[str, TensorInitFnType]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass of the SparseODERNN layer.

        Solves the initial value problem for the ODE defined by update_fn.
        The gradients of the parameters are computed using the adjoint method
        (see `torchode.AutoDiffAdjoint`).

        Args:
            x: Input tensor.
            num_evals: Number of evaluations to return. The default of 2 means
                that the ODE will be evaluated at the start and end of the
                simulation and those values will be returned. Note that this
                does not mean the `update_fn` will be called `num_evals` times;
                it only affects the number of values returned, while the step
                size controller determines how many times the solver calls
                `update_fn`.
            start_time: Start time for simulation.
            end_time: End time for simulation.
            h0: Initial hidden state.
            hidden_init_fn: Initialization function.

        Returns:
            Outputs, hidden states, and time points.
        """

        # Ensure connectivity matrix is nonnegative
        self._clamp_connectivity()

        # Format input and initialize variables
        x = self._format_x(x)
        batch_size = x.shape[-1]
        device = x.device

        # Define evaluation time points
        if num_evals < 2:
            raise ValueError("num_evals must be greater than 1")
        t_eval = (
            torch.linspace(start_time, end_time, num_evals, device=device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )

        # Initialize hidden state
        if h0 is None:
            h0 = self.init_hidden(
                batch_size,
                init_fn=hidden_init_fn,
                device=device,
            )

        # Solve ODE
        problem = to.InitialValueProblem(y0=h0, t_eval=t_eval)  # type: ignore
        sol = self.solver.solve(
            problem,
            args={
                "x": x,
                "start_time": start_time,
                "end_time": end_time,
            },
        )
        hs = sol.ys.permute(1, 2, 0)

        # Project to output space
        outs = self.ho(hs.transpose(0, 1).flatten(1))
        outs = outs.view(self.output_size, num_evals, batch_size).transpose(
            0, 1
        )

        # Format outputs
        ts = self._format_ts(sol.ts)
        outs, hs = self._format_result(outs, hs)

        return outs, hs, ts

forward(x, num_evals=2, start_time=0.0, end_time=1.0, h0=None, hidden_init_fn=None)

Forward pass of the SparseODERNN layer.

Solves the initial value problem for the ODE defined by update_fn. The gradients of the parameters are computed using the adjoint method (see torchode.AutoDiffAdjoint).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Input tensor. | *required* |
| `num_evals` | `int` | Number of evaluations to return. The default of 2 means the ODE is evaluated at the start and end of the simulation and those values are returned. This does not mean `update_fn` is called `num_evals` times; it only affects the number of values returned, while the step size controller determines how many times the solver calls `update_fn`. | `2` |
| `start_time` | `float` | Start time for simulation. | `0.0` |
| `end_time` | `float` | End time for simulation. | `1.0` |
| `h0` | `Optional[Tensor]` | Initial hidden state. | `None` |
| `hidden_init_fn` | `Optional[Union[str, TensorInitFnType]]` | Initialization function. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `tuple[Tensor, Tensor, Tensor]` | Outputs, hidden states, and time points. |

Source code in src/bioplnn/models/sparse.py
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
def forward(
    self,
    x: torch.Tensor,
    num_evals: int = 2,
    start_time: float = 0.0,
    end_time: float = 1.0,
    h0: Optional[torch.Tensor] = None,
    hidden_init_fn: Optional[Union[str, TensorInitFnType]] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Forward pass of the SparseODERNN layer.

    Solves the initial value problem for the ODE defined by update_fn.
    The gradients of the parameters are computed using the adjoint method
    (see `torchode.AutoDiffAdjoint`).

    Args:
        x: Input tensor.
        num_evals: Number of evaluations to return. The default of 2 means
            that the ODE will be evaluated at the start and end of the
            simulation and those values will be returned. Note that this
            does not mean the `update_fn` will be called `num_evals` times;
            it only affects the number of values returned, while the step
            size controller determines how many times the solver calls
            `update_fn`.
        start_time: Start time for simulation.
        end_time: End time for simulation.
        h0: Initial hidden state.
        hidden_init_fn: Initialization function.

    Returns:
        Outputs, hidden states, and time points.
    """

    # Ensure connectivity matrix is nonnegative
    self._clamp_connectivity()

    # Format input and initialize variables
    x = self._format_x(x)
    batch_size = x.shape[-1]
    device = x.device

    # Define evaluation time points
    if num_evals < 2:
        raise ValueError("num_evals must be greater than 1")
    t_eval = (
        torch.linspace(start_time, end_time, num_evals, device=device)
        .unsqueeze(0)
        .expand(batch_size, -1)
    )

    # Initialize hidden state
    if h0 is None:
        h0 = self.init_hidden(
            batch_size,
            init_fn=hidden_init_fn,
            device=device,
        )

    # Solve ODE
    problem = to.InitialValueProblem(y0=h0, t_eval=t_eval)  # type: ignore
    sol = self.solver.solve(
        problem,
        args={
            "x": x,
            "start_time": start_time,
            "end_time": end_time,
        },
    )
    hs = sol.ys.permute(1, 2, 0)

    # Project to output space
    outs = self.ho(hs.transpose(0, 1).flatten(1))
    outs = outs.view(self.output_size, num_evals, batch_size).transpose(
        0, 1
    )

    # Format outputs
    ts = self._format_ts(sol.ts)
    outs, hs = self._format_result(outs, hs)

    return outs, hs, ts

update_fn(t, h, args)

ODE function for the SparseODERNN.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `t` | `Tensor` | Current time point. | *required* |
| `h` | `Tensor` | Current hidden state. | *required* |
| `args` | `Mapping[str, Any]` | Additional arguments including input data. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Rate of change of the hidden state. |
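In symbols, with $\sigma$ the configured nonlinearity and $W_{ih}$, $W_{hh}$ the (sparse) input and recurrent transformations, the vector field computed above is

$$\frac{dh}{dt} = \sigma\big(W_{ih}\, x(t) + W_{hh}\, h(t)\big) - h(t),$$

so the hidden state relaxes toward the instantaneous RNN update, with $x(t)$ the input slice selected by `_index_from_time` (bias terms omitted for brevity).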

Source code in src/bioplnn/models/sparse.py
def update_fn(
    self, t: torch.Tensor, h: torch.Tensor, args: Mapping[str, Any]
) -> torch.Tensor:
    """ODE function for the SparseODERNN.

    Args:
        t: Current time point.
        h: Current hidden state.
        args: Additional arguments including input data.

    Returns:
        Rate of change of the hidden state.
    """
    h = h.t()
    x = args["x"]
    start_time = args["start_time"]
    end_time = args["end_time"]

    idx = self._index_from_time(t, x, start_time, end_time)

    h_new = self.nonlinearity(self.ih(x[idx]) + self.hh(h))

    dhdt = h_new - h

    return dhdt.t()

SparseRNN

Bases: Module

Sparse Recurrent Neural Network (RNN) layer.

A sparse variant of the standard RNN that uses truly sparse linear transformations to compute the input-to-hidden and hidden-to-hidden transformations (and optionally the hidden-to-output transformations).

These sparse transformations are computed using the torch_sparse package and allow for efficient memory usage for large networks.

This allows the network weights to be trained directly, a departure from GANs, which typically use fixed sparse weights.
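A minimal usage sketch with a hypothetical random connectivity (sizes made up for illustration):

import torch
from bioplnn.models.sparse import SparseRNN

n = 100
# Hypothetical random hidden-to-hidden connectivity with ~500 nonzero entries.
idx = torch.randint(0, n, (2, 500))
conn_hh = torch.sparse_coo_tensor(idx, torch.rand(500), size=(n, n)).coalesce()

# Without connectivity_ih/connectivity_ho, input_size must equal hidden_size
# and the hidden state doubles as the output (ih and ho are identities).
rnn = SparseRNN(input_size=n, hidden_size=n, connectivity_hh=conn_hh)

x = torch.rand(8, 10, n)  # (batch, seq_len, input_size) with batch_first=True
outs, hs = rnn(x)
print(outs.shape, hs.shape)  # torch.Size([8, 10, 100]) for both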

Source code in src/bioplnn/models/sparse.py
class SparseRNN(nn.Module):
    """Sparse Recurrent Neural Network (RNN) layer.

    A sparse variant of the standard RNN that uses truly sparse linear
    transformations to compute the input-to-hidden and hidden-to-hidden
    transformations (and optionally the hidden-to-output transformations).

    These sparse transformations are computed using the `torch_sparse` package
    and allow for efficient memory usage for large networks.

    This allows the network weights to be trained directly, a departure
    from GANs, which typically use fixed sparse weights.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        connectivity_hh: Union[torch.Tensor, PathLike, str],
        output_size: Optional[int] = None,
        connectivity_ih: Optional[Union[torch.Tensor, PathLike, str]] = None,
        connectivity_ho: Optional[Union[torch.Tensor, PathLike, str]] = None,
        bias_hh: bool = True,
        bias_ih: bool = False,
        bias_ho: bool = True,
        use_dense_ih: bool = False,
        use_dense_ho: bool = False,
        train_hh: bool = True,
        train_ih: bool = True,
        train_ho: bool = True,
        default_hidden_init_fn: str = "zeros",
        nonlinearity: str = "Sigmoid",
        batch_first: bool = True,
    ):
        """Initialize the SparseRNN layer.

        Args:
            input_size: Size of the input features.
            hidden_size: Size of the hidden state.
            connectivity_hh: Connectivity matrix for hidden-to-hidden connections.
            output_size: Size of the output features.
            connectivity_ih: Connectivity matrix for input-to-hidden connections.
            connectivity_ho: Connectivity matrix for hidden-to-output connections.
            bias_hh: Whether to use bias in the hidden-to-hidden connections.
            bias_ih: Whether to use bias in the input-to-hidden connections.
            bias_ho: Whether to use bias in the hidden-to-output connections.
            use_dense_ih: Whether to use a dense linear layer for input-to-hidden connections.
            use_dense_ho: Whether to use a dense linear layer for hidden-to-output connections.
            train_hh: Whether to train the hidden-to-hidden connections.
            train_ih: Whether to train the input-to-hidden connections.
            train_ho: Whether to train the hidden-to-output connections.
            default_hidden_init_fn: Initialization mode for the hidden state.
            nonlinearity: Nonlinearity function.
            batch_first: Whether the input is in (batch_size, seq_len, input_size) format.
        """

        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = (
            output_size if output_size is not None else hidden_size
        )
        self.default_hidden_init_fn = default_hidden_init_fn
        self.nonlinearity = get_activation(nonlinearity)
        self.batch_first = batch_first

        connectivity_hh, connectivity_ih, connectivity_ho = (
            self._init_connectivity(
                connectivity_hh, connectivity_ih, connectivity_ho
            )
        )

        self.hh = SparseLinear(
            in_features=hidden_size,
            out_features=hidden_size,
            connectivity=connectivity_hh,
            feature_dim=0,
            bias=bias_hh,
            requires_grad=train_hh,
        )

        if connectivity_ih is not None:
            if use_dense_ih:
                raise ValueError(
                    "use_dense_ih must be False if connectivity_ih is provided"
                )
            self.ih = SparseLinear(
                in_features=input_size,
                out_features=hidden_size,
                connectivity=connectivity_ih,
                feature_dim=0,
                bias=bias_ih,
                requires_grad=train_ih,
            )
        elif use_dense_ih:
            warnings.warn(
                "connectivity_ih is not provided and use_dense_ih is True, "
                "using a dense linear layer for input-to-hidden connections. "
                "This may result in memory issues. If you are running out of "
                "memory, consider providing connectivity_ih or decreasing "
                "input_size."
            )
            if not train_ih:
                raise ValueError(
                    "train_ih must be True if connectivity_ih is not provided"
                )
            self.ih = nn.Linear(
                in_features=input_size,
                out_features=hidden_size,
                bias=bias_ih,
            )
        else:
            if input_size != hidden_size:
                raise ValueError(
                    "input_size must be equal to hidden_size if "
                    "connectivity_ih is not provided and use_dense_ih is False"
                )
            self.ih = nn.Identity()

        if connectivity_ho is not None:
            if use_dense_ho:
                raise ValueError(
                    "use_dense_ho must be False if connectivity_ho is provided"
                )
            if output_size is None:
                raise ValueError(
                    "output_size must be provided if and only if connectivity_ho is provided"
                )
            self.ho = SparseLinear(
                in_features=hidden_size,
                out_features=output_size,
                connectivity=connectivity_ho,
                feature_dim=0,
                bias=bias_ho,
                requires_grad=train_ho,
            )
        elif use_dense_ho:
            warnings.warn(
                "connectivity_ho is not provided and use_dense_ho is True, "
                "using a dense linear layer for hidden-to-output connections. "
                "This may result in memory issues. If you are running out of "
                "memory, consider providing connectivity_ho or decreasing "
                "hidden_size."
            )
            if output_size is None:
                raise ValueError(
                    "output_size must be provided if and only if use_dense_ho is True"
                )
            if not train_ho:
                raise ValueError(
                    "train_ho must be True if connectivity_ho is not provided"
                )
            self.ho = nn.Linear(
                in_features=hidden_size,
                out_features=output_size,
                bias=bias_ho,
            )
        else:
            if output_size is not None and output_size != hidden_size:
                raise ValueError(
                    "output_size should not be provided or should be equal to "
                    "hidden_size if connectivity_ho is not provided and "
                    "use_dense_ho is False"
                )
            self.ho = nn.Identity()

    def _init_connectivity(
        self,
        connectivity_hh: Union[torch.Tensor, PathLike, str],
        connectivity_ih: Optional[Union[torch.Tensor, PathLike, str]] = None,
        connectivity_ho: Optional[Union[torch.Tensor, PathLike, str]] = None,
    ) -> tuple[
        torch.Tensor, Union[torch.Tensor, None], Union[torch.Tensor, None]
    ]:
        """Initialize connectivity matrices.

        Args:
            connectivity_hh: Connectivity matrix for hidden-to-hidden connections or path to load it from.
            connectivity_ih: Connectivity matrix for input-to-hidden connections or path to load it from.
            connectivity_ho: Connectivity matrix for hidden-to-output connections or path to load it from.

        Returns:
            Tuple containing the hidden-to-hidden, input-to-hidden (or None),
            and hidden-to-output (or None) connectivity tensors.

        Raises:
            ValueError: If connectivity matrices are not in COO format or have
                invalid dimensions.
        """

        connectivity_hh_tensor = load_sparse_tensor(connectivity_hh)

        if connectivity_ih is not None:
            connectivity_ih_tensor = load_sparse_tensor(connectivity_ih)
        else:
            connectivity_ih_tensor = None

        if connectivity_ho is not None:
            connectivity_ho_tensor = load_sparse_tensor(connectivity_ho)
        else:
            connectivity_ho_tensor = None

        # Validate connectivity matrix dimensions
        if not (
            self.hidden_size
            == connectivity_hh_tensor.shape[0]
            == connectivity_hh_tensor.shape[1]
        ):
            raise ValueError(
                "connectivity_ih.shape[0], connectivity_hh.shape[0], and connectivity_hh.shape[1] must be equal"
            )

        if connectivity_ih_tensor is not None and (
            self.input_size != connectivity_ih_tensor.shape[1]
            or self.hidden_size != connectivity_ih_tensor.shape[0]
        ):
            raise ValueError(
                "connectivity_ih.shape[1] and input_size must be equal"
            )

        if connectivity_ho_tensor is not None and (
            self.hidden_size != connectivity_ho_tensor.shape[1]
            or self.output_size != connectivity_ho_tensor.shape[0]
        ):
            raise ValueError(
                "connectivity_ho.shape[0] and output_size must be equal"
            )

        return (
            connectivity_hh_tensor,
            connectivity_ih_tensor,
            connectivity_ho_tensor,
        )

    def init_hidden(
        self,
        batch_size: int,
        init_fn: Optional[Union[str, TensorInitFnType]] = None,
        device: Optional[Union[torch.device, str]] = None,
    ) -> torch.Tensor:
        """Initialize the hidden state.

        Args:
            batch_size: Batch size.
            init_fn: Initialization function.
            device: Device to allocate the hidden state on.

        Returns:
            The initialized hidden state of shape (batch_size, hidden_size).
        """

        if init_fn is None:
            init_fn = self.default_hidden_init_fn

        return init_tensor(
            init_fn, batch_size, self.hidden_size, device=device
        )

    def init_state(
        self,
        num_steps: int,
        batch_size: int,
        h0: Optional[torch.Tensor] = None,
        hidden_init_fn: Optional[Union[str, TensorInitFnType]] = None,
        device: Optional[Union[torch.device, str]] = None,
    ) -> list[Optional[torch.Tensor]]:
        """Initialize the internal state of the network.

        Args:
            num_steps: Number of time steps.
            batch_size: Batch size.
            h0: Initial hidden states.
            hidden_init_fn: Initialization function.
            device: Device to allocate tensors on.

        Returns:
            The initialized hidden states for each time step.
        """

        hs: list[Optional[torch.Tensor]] = [None] * num_steps
        if h0 is None:
            h0 = self.init_hidden(
                batch_size,
                init_fn=hidden_init_fn,
                device=device,
            )
        hs[-1] = h0.t()
        return hs

    def _format_x(self, x: torch.Tensor, num_steps: Optional[int] = None):
        """Format the input tensor to match the expected shape.

        Args:
            x: Input tensor.
            num_steps: Number of time steps.

        Returns:
            The formatted input tensor and corrected number of time steps.

        Raises:
            ValueError: For invalid input dimensions or step counts.
        """
        if x.dim() == 2:
            if num_steps is None or num_steps < 1:
                raise ValueError(
                    "If x is 2D, num_steps must be provided and greater than 0"
                )
            x = x.t()
            x = x.unsqueeze(0).expand((num_steps, -1, -1))
        elif x.dim() == 3:
            if self.batch_first:
                x = x.permute(1, 2, 0)
            else:
                x = x.permute(0, 2, 1)
            if num_steps is not None and num_steps != x.shape[0]:
                raise ValueError(
                    "If x is 3D and num_steps is provided, it must match the "
                    "sequence length."
                )
            num_steps = x.shape[0]
        else:
            raise ValueError(
                f"Input tensor must be 2D or 3D, but got {x.dim()} dimensions."
            )
        return x, num_steps

    def _format_result(
        self,
        outs: torch.Tensor,
        hs: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Format the hidden states and outputs for output.

        Args:
            hs: Hidden states for each time step.
            outs: Outputs for each time step.

        Returns:
            Formatted outputs and hidden states.
        """

        if self.batch_first:
            return outs.permute(2, 0, 1), hs.permute(2, 0, 1)
        else:
            return outs.permute(0, 2, 1), hs.permute(0, 2, 1)

    def _clamp_connectivity(self) -> None:
        """Ensure the connectivity matrix is nonnegative."""

        self.hh.values.data.clamp_(min=0.0)

    def update_fn(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        """Update function for the SparseRNN.

        Args:
            x: Input tensor at current timestep.
            h: Hidden state from previous timestep.

        Returns:
            Updated hidden state.
        """
        return self.nonlinearity(self.ih(x) + self.hh(h))

    def forward(
        self,
        x: torch.Tensor,
        num_steps: Optional[int] = None,
        h0: Optional[torch.Tensor] = None,
        hidden_init_fn: Optional[Union[str, TensorInitFnType]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the SparseRNN layer.

        Args:
            x: Input tensor.
            num_steps: Number of time steps.
            h0: Initial hidden state.
            hidden_init_fn: Initialization function.

        Returns:
            Outputs and hidden states.
        """

        # Ensure connectivity matrix is nonnegative
        self._clamp_connectivity()

        # Format input and initialize variables
        x, num_steps = self._format_x(x, num_steps)
        batch_size = x.shape[-1]
        device = x.device

        hs = self.init_state(
            num_steps,
            batch_size,
            h0=h0,
            hidden_init_fn=hidden_init_fn,
            device=device,
        )

        for t in range(num_steps):
            hs[t] = self.update_fn(x[t], hs[t - 1])  # type: ignore

        assert all(h is not None for h in hs)
        hs = torch.stack(hs)  # type: ignore

        assert hs.shape == (num_steps, self.hidden_size, batch_size)
        outs = self.ho(hs.transpose(0, 1).flatten(1))
        outs = outs.view(self.output_size, num_steps, batch_size).transpose(
            0, 1
        )

        return self._format_result(outs, hs)

__init__(input_size, hidden_size, connectivity_hh, output_size=None, connectivity_ih=None, connectivity_ho=None, bias_hh=True, bias_ih=False, bias_ho=True, use_dense_ih=False, use_dense_ho=False, train_hh=True, train_ih=True, train_ho=True, default_hidden_init_fn='zeros', nonlinearity='Sigmoid', batch_first=True)

Initialize the SparseRNN layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input_size` | `int` | Size of the input features. | *required* |
| `hidden_size` | `int` | Size of the hidden state. | *required* |
| `connectivity_hh` | `Union[Tensor, PathLike, str]` | Connectivity matrix for hidden-to-hidden connections. | *required* |
| `output_size` | `Optional[int]` | Size of the output features. | `None` |
| `connectivity_ih` | `Optional[Union[Tensor, PathLike, str]]` | Connectivity matrix for input-to-hidden connections. | `None` |
| `connectivity_ho` | `Optional[Union[Tensor, PathLike, str]]` | Connectivity matrix for hidden-to-output connections. | `None` |
| `bias_hh` | `bool` | Whether to use bias in the hidden-to-hidden connections. | `True` |
| `bias_ih` | `bool` | Whether to use bias in the input-to-hidden connections. | `False` |
| `bias_ho` | `bool` | Whether to use bias in the hidden-to-output connections. | `True` |
| `use_dense_ih` | `bool` | Whether to use a dense linear layer for input-to-hidden connections. | `False` |
| `use_dense_ho` | `bool` | Whether to use a dense linear layer for hidden-to-output connections. | `False` |
| `train_hh` | `bool` | Whether to train the hidden-to-hidden connections. | `True` |
| `train_ih` | `bool` | Whether to train the input-to-hidden connections. | `True` |
| `train_ho` | `bool` | Whether to train the hidden-to-output connections. | `True` |
| `default_hidden_init_fn` | `str` | Initialization mode for the hidden state. | `'zeros'` |
| `nonlinearity` | `str` | Nonlinearity function. | `'Sigmoid'` |
| `batch_first` | `bool` | Whether the input is in `(batch_size, seq_len, input_size)` format. | `True` |
Source code in src/bioplnn/models/sparse.py
def __init__(
    self,
    input_size: int,
    hidden_size: int,
    connectivity_hh: Union[torch.Tensor, PathLike, str],
    output_size: Optional[int] = None,
    connectivity_ih: Optional[Union[torch.Tensor, PathLike, str]] = None,
    connectivity_ho: Optional[Union[torch.Tensor, PathLike, str]] = None,
    bias_hh: bool = True,
    bias_ih: bool = False,
    bias_ho: bool = True,
    use_dense_ih: bool = False,
    use_dense_ho: bool = False,
    train_hh: bool = True,
    train_ih: bool = True,
    train_ho: bool = True,
    default_hidden_init_fn: str = "zeros",
    nonlinearity: str = "Sigmoid",
    batch_first: bool = True,
):
    """Initialize the SparseRNN layer.

    Args:
        input_size: Size of the input features.
        hidden_size: Size of the hidden state.
        connectivity_hh: Connectivity matrix for hidden-to-hidden connections.
        output_size: Size of the output features.
        connectivity_ih: Connectivity matrix for input-to-hidden connections.
        connectivity_ho: Connectivity matrix for hidden-to-output connections.
        bias_hh: Whether to use bias in the hidden-to-hidden connections.
        bias_ih: Whether to use bias in the input-to-hidden connections.
        bias_ho: Whether to use bias in the hidden-to-output connections.
        use_dense_ih: Whether to use a dense linear layer for input-to-hidden connections.
        use_dense_ho: Whether to use a dense linear layer for hidden-to-output connections.
        train_hh: Whether to train the hidden-to-hidden connections.
        train_ih: Whether to train the input-to-hidden connections.
        train_ho: Whether to train the hidden-to-output connections.
        default_hidden_init_fn: Initialization mode for the hidden state.
        nonlinearity: Nonlinearity function.
        batch_first: Whether the input is in (batch_size, seq_len, input_size) format.
    """

    super().__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.output_size = (
        output_size if output_size is not None else hidden_size
    )
    self.default_hidden_init_fn = default_hidden_init_fn
    self.nonlinearity = get_activation(nonlinearity)
    self.batch_first = batch_first

    connectivity_hh, connectivity_ih, connectivity_ho = (
        self._init_connectivity(
            connectivity_hh, connectivity_ih, connectivity_ho
        )
    )

    self.hh = SparseLinear(
        in_features=hidden_size,
        out_features=hidden_size,
        connectivity=connectivity_hh,
        feature_dim=0,
        bias=bias_hh,
        requires_grad=train_hh,
    )

    if connectivity_ih is not None:
        if use_dense_ih:
            raise ValueError(
                "use_dense_ih must be False if connectivity_ih is provided"
            )
        self.ih = SparseLinear(
            in_features=input_size,
            out_features=hidden_size,
            connectivity=connectivity_ih,
            feature_dim=0,
            bias=bias_ih,
            requires_grad=train_ih,
        )
    elif use_dense_ih:
        warnings.warn(
            "connectivity_ih is not provided and use_dense_ih is True, "
            "using a dense linear layer for input-to-hidden connections. "
            "This may result in memory issues. If you are running out of "
            "memory, consider providing connectivity_ih or decreasing "
            "input_size."
        )
        if not train_ih:
            raise ValueError(
                "train_ih must be True if connectivity_ih is not provided"
            )
        self.ih = nn.Linear(
            in_features=input_size,
            out_features=hidden_size,
            bias=bias_ih,
        )
    else:
        if input_size != hidden_size:
            raise ValueError(
                "input_size must be equal to hidden_size if "
                "connectivity_ih is not provided and use_dense_ih is False"
            )
        self.ih = nn.Identity()

    if connectivity_ho is not None:
        if use_dense_ho:
            raise ValueError(
                "use_dense_ho must be False if connectivity_ho is provided"
            )
        if output_size is None:
            raise ValueError(
                "output_size must be provided if and only if connectivity_ho is provided"
            )
        self.ho = SparseLinear(
            in_features=hidden_size,
            out_features=output_size,
            connectivity=connectivity_ho,
            feature_dim=0,
            bias=bias_ho,
            requires_grad=train_ho,
        )
    elif use_dense_ho:
        warnings.warn(
            "connectivity_ho is not provided and use_dense_ho is True, "
            "using a dense linear layer for hidden-to-output connections. "
            "This may result in memory issues. If you are running out of "
            "memory, consider providing connectivity_ho or decreasing "
            "hidden_size."
        )
        if output_size is None:
            raise ValueError(
                "output_size must be provided if and only if use_dense_ho is True"
            )
        if not train_ho:
            raise ValueError(
                "train_ho must be True if connectivity_ho is not provided"
            )
        self.ho = nn.Linear(
            in_features=hidden_size,
            out_features=output_size,
            bias=bias_ho,
        )
    else:
        if output_size is not None and output_size != hidden_size:
            raise ValueError(
                "output_size should not be provided or should be equal to "
                "hidden_size if connectivity_ho is not provided and "
                "use_dense_ho is False"
            )
        self.ho = nn.Identity()

forward(x, num_steps=None, h0=None, hidden_init_fn=None)

Forward pass of the SparseRNN layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Input tensor. | *required* |
| `num_steps` | `Optional[int]` | Number of time steps. | `None` |
| `h0` | `Optional[Tensor]` | Initial hidden state. | `None` |
| `hidden_init_fn` | `Optional[Union[str, TensorInitFnType]]` | Initialization function. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `tuple[Tensor, Tensor]` | Outputs and hidden states. |
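For a static input, a 2D tensor can be repeated over `num_steps` steps. A sketch reusing the hypothetical `rnn` from the class-level example above:

x_static = torch.rand(8, 100)           # (batch, input_size)
outs, hs = rnn(x_static, num_steps=10)  # the same input is fed at every step
print(outs.shape)  # torch.Size([8, 10, 100])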

Source code in src/bioplnn/models/sparse.py
def forward(
    self,
    x: torch.Tensor,
    num_steps: Optional[int] = None,
    h0: Optional[torch.Tensor] = None,
    hidden_init_fn: Optional[Union[str, TensorInitFnType]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Forward pass of the SparseRNN layer.

    Args:
        x: Input tensor.
        num_steps: Number of time steps.
        h0: Initial hidden state.
        hidden_init_fn: Initialization function.

    Returns:
        Outputs and hidden states.
    """

    # Ensure connectivity matrix is nonnegative
    self._clamp_connectivity()

    # Format input and initialize variables
    x, num_steps = self._format_x(x, num_steps)
    batch_size = x.shape[-1]
    device = x.device

    hs = self.init_state(
        num_steps,
        batch_size,
        h0=h0,
        hidden_init_fn=hidden_init_fn,
        device=device,
    )

    for t in range(num_steps):
        hs[t] = self.update_fn(x[t], hs[t - 1])  # type: ignore

    assert all(h is not None for h in hs)
    hs = torch.stack(hs)  # type: ignore

    assert hs.shape == (num_steps, self.hidden_size, batch_size)
    outs = self.ho(hs.transpose(0, 1).flatten(1))
    outs = outs.view(self.output_size, num_steps, batch_size).transpose(
        0, 1
    )

    return self._format_result(outs, hs)

init_hidden(batch_size, init_fn=None, device=None)

Initialize the hidden state.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch_size` | `int` | Batch size. | *required* |
| `init_fn` | `Optional[Union[str, TensorInitFnType]]` | Initialization function. | `None` |
| `device` | `Optional[Union[device, str]]` | Device to allocate the hidden state on. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The initialized hidden state of shape `(batch_size, hidden_size)`. |

Source code in src/bioplnn/models/sparse.py
def init_hidden(
    self,
    batch_size: int,
    init_fn: Optional[Union[str, TensorInitFnType]] = None,
    device: Optional[Union[torch.device, str]] = None,
) -> torch.Tensor:
    """Initialize the hidden state.

    Args:
        batch_size: Batch size.
        init_fn: Initialization function.
        device: Device to allocate the hidden state on.

    Returns:
        The initialized hidden state of shape (batch_size, hidden_size).
    """

    if init_fn is None:
        init_fn = self.default_hidden_init_fn

    return init_tensor(
        init_fn, batch_size, self.hidden_size, device=device
    )

init_state(num_steps, batch_size, h0=None, hidden_init_fn=None, device=None)

Initialize the internal state of the network.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `num_steps` | `int` | Number of time steps. | *required* |
| `batch_size` | `int` | Batch size. | *required* |
| `h0` | `Optional[Tensor]` | Initial hidden states. | `None` |
| `hidden_init_fn` | `Optional[Union[str, TensorInitFnType]]` | Initialization function. | `None` |
| `device` | `Optional[Union[device, str]]` | Device to allocate tensors on. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `list[Optional[Tensor]]` | The initialized hidden states for each time step. |

Source code in src/bioplnn/models/sparse.py
def init_state(
    self,
    num_steps: int,
    batch_size: int,
    h0: Optional[torch.Tensor] = None,
    hidden_init_fn: Optional[Union[str, TensorInitFnType]] = None,
    device: Optional[Union[torch.device, str]] = None,
) -> list[Optional[torch.Tensor]]:
    """Initialize the internal state of the network.

    Args:
        num_steps: Number of time steps.
        batch_size: Batch size.
        h0: Initial hidden states.
        hidden_init_fn: Initialization function.
        device: Device to allocate tensors on.

    Returns:
        The initialized hidden states for each time step.
    """

    hs: list[Optional[torch.Tensor]] = [None] * num_steps
    if h0 is None:
        h0 = self.init_hidden(
            batch_size,
            init_fn=hidden_init_fn,
            device=device,
        )
    hs[-1] = h0.t()
    return hs

update_fn(x, h)

Update function for the SparseRNN.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Input tensor at current timestep. | *required* |
| `h` | `Tensor` | Hidden state from previous timestep. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Updated hidden state. |
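In symbols, with $\sigma$ the configured nonlinearity, this is the standard RNN recurrence computed through the sparse transformations (bias terms omitted):

$$h_t = \sigma\big(W_{ih}\, x_t + W_{hh}\, h_{t-1}\big)$$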

Source code in src/bioplnn/models/sparse.py
def update_fn(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
    """Update function for the SparseRNN.

    Args:
        x: Input tensor at current timestep.
        h: Hidden state from previous timestep.

    Returns:
        Updated hidden state.
    """
    return self.nonlinearity(self.ih(x) + self.hh(h))