Skip to content

Commit ade10d7

Browse files
committed
add IDS SPI + config YAML
1 parent efd6206 commit ade10d7

File tree

6 files changed

+168
-1
lines changed

6 files changed

+168
-1
lines changed

pyspi/config.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,3 +1391,15 @@
13911391
- orth: True
13921392
log: True
13931393
absolute: True
1394+
1395+
# Interdependence score
1396+
InterDependenceScore:
1397+
labels:
1398+
- unsigned
1399+
- undirected
1400+
- nonlinear
1401+
dependencies:
1402+
configs: # default params
1403+
- terms: 6
1404+
pnorm: 'max'
1405+
bandwidth: 0.5

pyspi/fast_config.yaml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,3 +1291,16 @@
12911291
- orth: True
12921292
log: True
12931293
absolute: True
1294+
1295+
# Interdependence score
1296+
InterDependenceScore:
1297+
labels:
1298+
- unsigned
1299+
- undirected
1300+
- nonlinear
1301+
dependencies:
1302+
configs: # default params
1303+
- terms: 6
1304+
pnorm: 'max'
1305+
bandwidth: 0.5
1306+

pyspi/lib/ids/__init__.py

Whitespace-only changes.

pyspi/lib/ids/dependence.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from .numpy_dependence import compute_IDS_numpy
2+
3+
def compute_IDS(X, Y=None, num_terms=6, p_norm='max',
                p_val=False, num_tests=100, bandwidth_term=1/2):
    """Compute interdependence scores (IDS) by delegating to the NumPy backend.

    Every argument is forwarded unchanged to ``compute_IDS_numpy``; this
    function adds no logic of its own.

    Adapted from: https://github.com/aradha/interdependence_scores

    Parameters:
        X: 2-D array of shape (observations, variables).
        Y: Optional second 2-D array; when given, scores are computed
            between the columns of X and the columns of Y, otherwise between
            all pairs of columns of X.
        num_terms: Number of terms of the Taylor series approximation.
        p_norm: 'max' for IDS-max, or 1 / 2 for IDS-1 / IDS-2.
        p_val: Whether to also compute p-values using permutation tests.
        num_tests: Number of permutation tests, used when p_val is True.
        bandwidth_term: Constant term in the Gaussian kernel.

    Returns:
        The IDS matrix, or the tuple (IDS matrix, p-value matrix) when
        p_val is True.
    """
    forwarded = dict(
        Y=Y,
        num_terms=num_terms,
        p_norm=p_norm,
        p_val=p_val,
        num_tests=num_tests,
        bandwidth_term=bandwidth_term,
    )
    return compute_IDS_numpy(X, **forwarded)

pyspi/lib/ids/numpy_dependence.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import numpy as np
2+
import math
3+
import sys
4+
from tqdm import tqdm
5+
6+
# Fixed seed so the permutation tests below draw a reproducible sequence.
# NOTE(review): this seeds NumPy's *global* RNG as a side effect of importing
# this module, which also affects any other code using np.random.
SEED = 1717
np.random.seed(SEED)

# Machine epsilon for floats; added to the standard-deviation denominators in
# compute_IDS_numpy to avoid dividing by exactly zero.
EPSILON = sys.float_info.epsilon
10+
11+
def transform(y, num_terms=6, bandwidth_term=1/2):
    """Expand ``y`` into Gaussian-damped monomial features.

    For every order ``i`` in ``0 .. num_terms - 1`` the element-wise feature
    ``exp(-bandwidth_term * y**2) * y**i / sqrt(i!)`` is evaluated.  The
    per-order arrays are concatenated along the last axis in increasing
    order, so the last axis of the result is ``num_terms`` times as long as
    that of ``y``.
    """
    damping = np.exp(-bandwidth_term * y**2)
    features = [
        damping * y**order / math.sqrt(math.factorial(order) * 1.)
        for order in range(num_terms)
    ]
    return np.concatenate(features, axis=-1)
19+
20+
def center(X):
    """Subtract each column's mean (taken over axis 0) from that column."""
    column_means = np.mean(X, axis=0, keepdims=True)
    return X - column_means
22+
23+
24+
def compute_p_val(C, X, Y=None, num_terms=6, p_norm='max', n_tests=100, bandwidth_term=1/2):
    """Estimate permutation p-values for an already computed IDS matrix.

    On each of ``n_tests`` rounds every column of X (and of Y, when given)
    is shuffled independently, the IDS matrix of the shuffled data is
    recomputed, and the entries where that null score strictly exceeds the
    observed score ``C`` are counted.

    Parameters:
        C: Observed IDS matrix to compare the null scores against.
        X: 2-D array of shape (observations, variables).
        Y: Optional second 2-D array; shuffled and passed on when not None.
        num_terms, p_norm, bandwidth_term: Forwarded to compute_IDS_numpy.
        n_tests: Number of permutation rounds.

    Returns:
        Array with the fraction of rounds in which the null score exceeded
        the observed score, one value per entry of ``C``.
    """
    exceed_count = 0

    n_obs, n_x = X.shape
    for _ in tqdm(range(n_tests)):

        # Argsort of i.i.d. Gaussian noise along axis 0 yields an independent
        # random permutation of the rows for every column.
        x_order = np.argsort(np.random.normal(size=(n_obs, n_x)), axis=0)
        X_shuffled = X[x_order, np.arange(n_x)[None, :]]

        if Y is None:
            Y_shuffled = None
        else:
            n_obs, n_y = Y.shape
            y_order = np.argsort(np.random.normal(size=(n_obs, n_y)), axis=0)
            Y_shuffled = Y[y_order, np.arange(n_y)[None, :]]

        null_scores = compute_IDS_numpy(X_shuffled, Y=Y_shuffled, num_terms=num_terms,
                                        p_norm=p_norm, bandwidth_term=bandwidth_term)

        exceed_count += np.where(null_scores > C, 1, 0)

    return exceed_count / n_tests
53+
54+
55+
def compute_IDS_numpy(X, Y=None, num_terms=6, p_norm='max',
                      p_val=False, num_tests=100, bandwidth_term=1/2):
    """Compute interdependence scores between columns of X (or of X and Y).

    Each column is expanded with ``transform`` into ``num_terms`` features
    and centered; the absolute correlations between all feature pairs are
    then reduced over the two feature-order axes according to ``p_norm``.

    Parameters:
        X: 2-D array of shape (observations, dx).
        Y: Optional 2-D array of shape (observations, dy).  When None, the
            scores are computed between all pairs of columns of X.
        num_terms: Number of feature orders produced by ``transform``.
        p_norm: 'max' takes the maximum absolute correlation, 2 the root
            mean square, 1 the mean.
        p_val: When True, also run permutation tests via ``compute_p_val``.
        num_tests: Number of permutation rounds used when p_val is True.
        bandwidth_term: Constant passed to ``transform``.

    Returns:
        Score matrix of shape (dx, dy) (or (dx, dx) when Y is None); when
        p_val is True, the tuple (score matrix, p-value matrix).

    Raises:
        ValueError: If ``p_norm`` is not 'max', 1 or 2.
    """
    # Reject unsupported norms up front: previously such a value matched no
    # reduction branch and the unreduced 4-D correlation tensor was returned.
    if isinstance(p_norm, str):
        supported = p_norm == 'max'
    else:
        supported = p_norm in (1, 2)
    if not supported:
        raise ValueError(
            "p_norm must be 'max', 1 or 2, got {!r}".format(p_norm)
        )

    _, dx = X.shape
    X_t = transform(X, num_terms=num_terms, bandwidth_term=bandwidth_term)
    X_t = center(X_t)

    if Y is not None:
        _, dy = Y.shape
        Y_t = transform(Y, num_terms=num_terms, bandwidth_term=bandwidth_term)
        Y_t = center(Y_t)
        cov = X_t.T @ Y_t
        X_std = np.sqrt(np.sum(X_t**2, axis=0))
        Y_std = np.sqrt(np.sum(Y_t**2, axis=0))
        # EPSILON keeps the division finite for constant (zero-variance) features.
        correlations = cov / (X_std.reshape(-1, 1) + EPSILON)
        C = correlations / (Y_std.reshape(1, -1) + EPSILON)
        C = C.reshape(num_terms, dx, num_terms, dy)
    else:
        C = np.corrcoef(X_t.T)
        C = C.reshape(num_terms, dx, num_terms, dx)

    # Non-finite correlations (e.g. from constant features) count as no dependence.
    C = np.nan_to_num(C, nan=0, posinf=0, neginf=0)
    C = np.abs(C)

    # Reduce over the two feature-order axes (0 and 2 of the 4-D tensor).
    if p_norm == 'max':
        C = np.amax(C, axis=(0, 2))
    elif p_norm == 2:
        C = C**2
        C = np.mean(C, axis=0)
        C = np.mean(C, axis=1)
        C = np.sqrt(C)
    else:  # p_norm == 1, guaranteed by the validation above
        C = np.mean(C, axis=0)
        C = np.mean(C, axis=1)

    if p_val:
        p_vals = compute_p_val(C, X, Y=Y, num_terms=num_terms, p_norm=p_norm,
                               n_tests=num_tests, bandwidth_term=bandwidth_term)
        return C, p_vals
    return C

pyspi/statistics/misc.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sklearn.metrics import mean_squared_error
99
from sklearn import linear_model
1010
import mne.connectivity as mnec
11+
from pyspi.lib.ids.dependence import compute_IDS
1112

1213
from pyspi.base import (
1314
Directed,
@@ -147,7 +148,7 @@ def bivariate(self, data, i=None, j=None):
147148

148149

149150
class PowerEnvelopeCorrelation(Undirected, Unsigned):
150-
humanname = "Power envelope correlation"
151+
name = "Power envelope correlation"
151152
identifier = "pec"
152153
labels = ["unsigned", "misc", "undirected"]
153154

@@ -173,3 +174,28 @@ def multivariate(self, data):
173174
)
174175
np.fill_diagonal(adj, np.nan)
175176
return adj
177+
178+
class InterDependenceScore(Undirected, Unsigned):
    """Interdependence score (IDS) statistic computed over all process pairs.

    Parameters:
        terms: Number of Taylor series terms, passed as ``num_terms``.
        pnorm: 'max', 1 or 2; passed as ``p_norm``.
        bandwidth: Gaussian kernel constant, passed as ``bandwidth_term``.
    """

    name = "Interdependence score"
    identifier = "ids"
    labels = ["unsigned", "misc", "undirected", "nonlinear"]

    def __init__(self, terms=6, pnorm='max', bandwidth=0.5):
        self._num_terms = terms
        self._p_norm = pnorm
        self._bandwidth_term = bandwidth

    @parse_multivariate
    def multivariate(self, data):
        """Return the IDS matrix for every pair of processes in ``data``."""
        # compute_IDS expects (observations, processes): drop axis 2 of the
        # array returned by data.to_numpy() and transpose the remainder.
        observations = np.squeeze(data.to_numpy(), axis=2).T
        return compute_IDS(
            observations,
            num_terms=self._num_terms,
            p_norm=self._p_norm,
            bandwidth_term=self._bandwidth_term,
        )
201+

0 commit comments

Comments
 (0)