
Commit 41a328d

Fixes
1 parent 4f2a885 commit 41a328d

3 files changed: +73 -71 lines changed


README.md (+1 -1)

@@ -6,4 +6,4 @@ Fast, GPU-friendly, differentiable computation of Intrinsic Dimension via Maximum
 
 ### References
 - [E. Levina, P. Bickel; "Maximum Likelihood Estimation of Intrinsic Dimension", Advances in Neural Information Processing Systems, 2004](https://papers.nips.cc/paper_files/paper/2004/hash/74934548253bcab8490ebd74afed7031-Abstract.html)
-- [E. Facco, M. d'Errico, A. Rodriguez, A. Laio; "Estimating the intrinsic dimension of datasets by a minimal neighborhood information", Mature Scientific Reports, 2017](https://www.nature.com/articles/s41598-017-11873-y)
+- [E. Facco, M. d'Errico, A. Rodriguez, A. Laio; "Estimating the intrinsic dimension of datasets by a minimal neighborhood information", Nature Scientific Reports, 2017](https://www.nature.com/articles/s41598-017-11873-y)
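For orientation, these two references define the estimators that fastwonn/mle.py below implements in batched form. In standard notation (a sketch of the textbook formulas, not code from this repository), with T_j(x) the distance from x to its j-th nearest neighbour and mu_i = T_2(x_i) / T_1(x_i), the Levina-Bickel per-point estimate and the maximum-likelihood form of the TwoNN estimate read:

\hat{m}_k(x) = \left[ \frac{1}{k-1} \sum_{j=1}^{k-1} \log \frac{T_k(x)}{T_j(x)} \right]^{-1},
\qquad
\hat{d}_{\mathrm{TwoNN}} = \frac{N}{\sum_{i=1}^{N} \log \mu_i}.

The code below averages such per-point (and per-k) estimates over the dataset; its exact normalization may differ from these textbook forms.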

fastwonn/mle.py (+71 -69)

@@ -1,69 +1,71 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# ──────────────────────────────────────────────────────────────────────────────
-from contextlib import ExitStack
-
-import torch
-from torch import Tensor
-
-from .impl import call_to_impl_cdist_topk
-
-# ──────────────────────────────────────────────────────────────────────────────
-__all__ = ["mle_id", "mle_id_avg"]
-
-
-# ──────────────────────────────────────────────────────────────────────────────
-def mle_id(
-    x: Tensor,
-    nneigh: int = 2,
-    twonn_fix: bool = False,
-    differentiable: bool = False,
-    impl: str = "torch",
-) -> Tensor:
-
-    with ExitStack() as stack:
-        stack.enter_context(torch.no_grad()) if not differentiable else None
-
-        ks: Tensor = call_to_impl_cdist_topk[impl](x, nneigh, False)[:, 1:]
-
-        if twonn_fix and nneigh == 2:
-            return -2 * ks.size(0) / torch.log(torch.div(*torch.unbind(ks, 1))).sum()
-
-        return (2 * (nneigh - 1) / torch.log(ks[:, -1].view(-1, 1) / ks).sum(1)).mean()
-
-
-# ──────────────────────────────────────────────────────────────────────────────
-
-
-def mle_id_avg(
-    x: Tensor,
-    nneigh_min: int = 2,
-    nneigh_max: int = 10,
-    twonn_fix: bool = False,
-    differentiable: bool = False,
-    impl: str = "torch",
-) -> Tensor:
-
-    with ExitStack() as stack:
-        stack.enter_context(torch.no_grad()) if not differentiable else None
-
-        ks: Tensor = call_to_impl_cdist_topk[impl](x, nneigh_max, False)[:, 1:]
-        runs = [
-            (
-                2
-                * (nneigh_max - 1 - i)
-                / torch.log(
-                    ks[:, -1 - i].view(-1, 1) / (ks[:, :-i] if i != 0 else ks)
-                ).sum(1)
-            ).mean()
-            for i in range(nneigh_max - nneigh_min + (not twonn_fix))
-        ]
-
-        if twonn_fix and nneigh_min == 2:
-            runs.append(
-                -2
-                * ks.size(0)
-                / torch.log(torch.div(*torch.unbind(ks[:, 0:2], 1))).sum()
-            )
-
-        return torch.stack(runs).nanmean()
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# ──────────────────────────────────────────────────────────────────────────────
+from contextlib import ExitStack
+
+import torch
+from torch import Tensor
+
+from .impl import call_to_impl_cdist_topk
+
+# ──────────────────────────────────────────────────────────────────────────────
+__all__ = ["mle_id", "mle_id_avg"]
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+def mle_id(
+    x: Tensor,
+    nneigh: int = 2,
+    twonn_fix: bool = False,
+    differentiable: bool = False,
+    impl: str = "torch",
+) -> Tensor:
+
+    with ExitStack() as stack:
+        stack.enter_context(torch.no_grad()) if not differentiable else None
+
+        ks: Tensor = call_to_impl_cdist_topk[impl](x, nneigh, False)[:, 1:]
+
+        if twonn_fix and nneigh == 2:
+            return -2 * ks.size(0) / torch.log(torch.div(*torch.unbind(ks, 1))).sum()
+
+        return (2 * (nneigh - 1) / torch.log(ks[:, -1].view(-1, 1) / ks).sum(1)).mean()
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+
+
+def mle_id_avg(
+    x: Tensor,
+    nneigh_min: int = 2,
+    nneigh_max: int = 10,
+    twonn_fix: bool = False,
+    differentiable: bool = False,
+    impl: str = "torch",
+) -> Tensor:
+
+    with ExitStack() as stack:
+        stack.enter_context(torch.no_grad()) if not differentiable else None
+
+        twonn_sep: bool = twonn_fix and nneigh_min == 2
+
+        ks: Tensor = call_to_impl_cdist_topk[impl](x, nneigh_max, False)[:, 1:]
+        runs = [
+            (
+                2
+                * (nneigh_max - 1 - i)
+                / torch.log(
+                    ks[:, -1 - i].view(-1, 1) / (ks[:, :-i] if i != 0 else ks)
+                ).sum(1)
+            ).mean()
+            for i in range(nneigh_max - nneigh_min + (not twonn_sep))
+        ]
+
+        if twonn_sep:
+            runs.append(
+                -2
+                * ks.size(0)
+                / torch.log(torch.div(*torch.unbind(ks[:, 0:2], 1))).sum()
+            )
+
+        return torch.stack(runs).nanmean()
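The substantive change in mle_id_avg is the new twonn_sep flag: the per-k list comprehension now drops its extra term (via not twonn_sep) only when the separate TwoNN-style estimate is actually appended afterwards, i.e. when twonn_fix is set and nneigh_min == 2. Previously the range was shortened whenever twonn_fix was set, silently losing one k-value when nneigh_min != 2.

A minimal usage sketch (hypothetical, not part of this commit; it assumes the functions are imported from fastwonn.mle as defined above and that the default impl="torch" backend is available):

import torch

from fastwonn.mle import mle_id, mle_id_avg

# Toy data: a 2-D Gaussian cloud linearly embedded in a 16-D ambient space,
# so the estimated intrinsic dimension should come out close to 2.
x = torch.randn(2048, 2) @ torch.randn(2, 16)

d_single = mle_id(x, nneigh=10)                     # single-k MLE estimate
d_avg = mle_id_avg(x, nneigh_min=2, nneigh_max=10)  # averaged over k = 2..10

print(float(d_single), float(d_avg))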

setup.py (+1 -1)

@@ -22,7 +22,7 @@ def read(fname):
 
 setup(
     name=PACKAGENAME,
-    version="0.0.7",
+    version="0.0.8",
     author="Emanuele Ballarin",
     author_email="[email protected]",
     url="https://github.com/emaballarin/fastwonn",

0 commit comments
