|
1 |
| -#!/usr/bin/env python3 |
2 |
| -# -*- coding: utf-8 -*- |
3 |
| -# ────────────────────────────────────────────────────────────────────────────── |
4 |
| -from contextlib import ExitStack |
5 |
| - |
6 |
| -import torch |
7 |
| -from torch import Tensor |
8 |
| - |
9 |
| -from .impl import call_to_impl_cdist_topk |
10 |
| - |
11 |
| -# ────────────────────────────────────────────────────────────────────────────── |
12 |
| -__all__ = ["mle_id", "mle_id_avg"] |
13 |
| - |
14 |
| - |
15 |
| -# ────────────────────────────────────────────────────────────────────────────── |
16 |
| -def mle_id( |
17 |
| - x: Tensor, |
18 |
| - nneigh: int = 2, |
19 |
| - twonn_fix: bool = False, |
20 |
| - differentiable: bool = False, |
21 |
| - impl: str = "torch", |
22 |
| -) -> Tensor: |
23 |
| - |
24 |
| - with ExitStack() as stack: |
25 |
| - stack.enter_context(torch.no_grad()) if not differentiable else None |
26 |
| - |
27 |
| - ks: Tensor = call_to_impl_cdist_topk[impl](x, nneigh, False)[:, 1:] |
28 |
| - |
29 |
| - if twonn_fix and nneigh == 2: |
30 |
| - return -2 * ks.size(0) / torch.log(torch.div(*torch.unbind(ks, 1))).sum() |
31 |
| - |
32 |
| - return (2 * (nneigh - 1) / torch.log(ks[:, -1].view(-1, 1) / ks).sum(1)).mean() |
33 |
| - |
34 |
| - |
35 |
| -# ────────────────────────────────────────────────────────────────────────────── |
36 |
| - |
37 |
| - |
38 |
| -def mle_id_avg( |
39 |
| - x: Tensor, |
40 |
| - nneigh_min: int = 2, |
41 |
| - nneigh_max: int = 10, |
42 |
| - twonn_fix: bool = False, |
43 |
| - differentiable: bool = False, |
44 |
| - impl: str = "torch", |
45 |
| -) -> Tensor: |
46 |
| - |
47 |
| - with ExitStack() as stack: |
48 |
| - stack.enter_context(torch.no_grad()) if not differentiable else None |
49 |
| - |
50 |
| - ks: Tensor = call_to_impl_cdist_topk[impl](x, nneigh_max, False)[:, 1:] |
51 |
| - runs = [ |
52 |
| - ( |
53 |
| - 2 |
54 |
| - * (nneigh_max - 1 - i) |
55 |
| - / torch.log( |
56 |
| - ks[:, -1 - i].view(-1, 1) / (ks[:, :-i] if i != 0 else ks) |
57 |
| - ).sum(1) |
58 |
| - ).mean() |
59 |
| - for i in range(nneigh_max - nneigh_min + (not twonn_fix)) |
60 |
| - ] |
61 |
| - |
62 |
| - if twonn_fix and nneigh_min == 2: |
63 |
| - runs.append( |
64 |
| - -2 |
65 |
| - * ks.size(0) |
66 |
| - / torch.log(torch.div(*torch.unbind(ks[:, 0:2], 1))).sum() |
67 |
| - ) |
68 |
| - |
69 |
| - return torch.stack(runs).nanmean() |
| 1 | +#!/usr/bin/env python3 |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | +# ────────────────────────────────────────────────────────────────────────────── |
| 4 | +from contextlib import ExitStack |
| 5 | + |
| 6 | +import torch |
| 7 | +from torch import Tensor |
| 8 | + |
| 9 | +from .impl import call_to_impl_cdist_topk |
| 10 | + |
| 11 | +# ────────────────────────────────────────────────────────────────────────────── |
| 12 | +__all__ = ["mle_id", "mle_id_avg"] |
| 13 | + |
| 14 | + |
| 15 | +# ────────────────────────────────────────────────────────────────────────────── |
def mle_id(
    x: Tensor,
    nneigh: int = 2,
    twonn_fix: bool = False,
    differentiable: bool = False,
    impl: str = "torch",
) -> Tensor:
    """Maximum-likelihood (Levina–Bickel-style) intrinsic-dimension estimate of ``x``.

    Parameters
    ----------
    x : Tensor
        Data points, one per row — assumed ``(n_samples, n_features)``; confirm
        against the selected ``impl`` backend.
    nneigh : int
        Number of nearest neighbours ``k`` used by the estimator.
    twonn_fix : bool
        If True and ``nneigh == 2``, use the pooled two-neighbour estimator
        instead of averaging per-point estimates.
    differentiable : bool
        If False (default), the whole computation runs under ``torch.no_grad()``.
    impl : str
        Key into ``call_to_impl_cdist_topk`` selecting the distance/top-k backend.

    Returns
    -------
    Tensor
        Scalar tensor holding the estimated intrinsic dimension.
    """
    with ExitStack() as stack:
        # Disable autograd unless the caller explicitly asked for gradients.
        if not differentiable:
            stack.enter_context(torch.no_grad())

        # Distances to the `nneigh` nearest neighbours; column 0 is the
        # self-distance and is dropped.
        ks: Tensor = call_to_impl_cdist_topk[impl](x, nneigh, False)[:, 1:]

        if twonn_fix and nneigh == 2:
            # Pooled estimator over all points: -2 * N / sum(log(r1 / r2)).
            return -2 * ks.size(0) / torch.log(torch.div(*torch.unbind(ks, 1))).sum()

        # Per-point estimate d_i = 2 * (k - 1) / sum_j log(r_k / r_j), averaged.
        # NOTE(review): the factor 2 suggests the backend returns squared
        # distances (log r^2 = 2 log r) — confirm against the `.impl` module.
        return (2 * (nneigh - 1) / torch.log(ks[:, -1].view(-1, 1) / ks).sum(1)).mean()
| 33 | + |
| 34 | + |
| 35 | +# ────────────────────────────────────────────────────────────────────────────── |
| 36 | + |
| 37 | + |
def mle_id_avg(
    x: Tensor,
    nneigh_min: int = 2,
    nneigh_max: int = 10,
    twonn_fix: bool = False,
    differentiable: bool = False,
    impl: str = "torch",
) -> Tensor:
    """MLE intrinsic dimension averaged over neighbourhood sizes.

    Computes one estimate per ``k`` in ``nneigh_min..nneigh_max`` from a single
    top-``nneigh_max`` neighbour query, then returns their ``nanmean``.

    Parameters
    ----------
    x : Tensor
        Data points, one per row — assumed ``(n_samples, n_features)``; confirm
        against the selected ``impl`` backend.
    nneigh_min, nneigh_max : int
        Inclusive range of neighbourhood sizes ``k`` to average over.
    twonn_fix : bool
        If True and ``nneigh_min == 2``, the ``k == 2`` run uses the pooled
        two-neighbour estimator instead of the generic formula.
    differentiable : bool
        If False (default), the whole computation runs under ``torch.no_grad()``.
    impl : str
        Key into ``call_to_impl_cdist_topk`` selecting the distance/top-k backend.

    Returns
    -------
    Tensor
        Scalar tensor: mean of the per-``k`` estimates, ignoring NaNs.
    """
    with ExitStack() as stack:
        # Disable autograd unless the caller explicitly asked for gradients.
        if not differentiable:
            stack.enter_context(torch.no_grad())

        # When the TwoNN correction applies to the smallest run (k == 2), that
        # run is computed separately below rather than with the generic formula.
        twonn_sep: bool = twonn_fix and nneigh_min == 2

        # Neighbour distances for the largest k; drop the self-distance column.
        ks: Tensor = call_to_impl_cdist_topk[impl](x, nneigh_max, False)[:, 1:]

        # One estimate per k: i == 0 uses every column (k == nneigh_max); larger
        # i truncates to the first nneigh_max - i neighbours.
        runs = [
            (
                2
                * (nneigh_max - 1 - i)
                / torch.log(
                    ks[:, -1 - i].view(-1, 1) / (ks if i == 0 else ks[:, :-i])
                ).sum(1)
            ).mean()
            for i in range(nneigh_max - nneigh_min + (not twonn_sep))
        ]

        if twonn_sep:
            # Pooled two-neighbour estimator replaces the generic k == 2 run.
            runs.append(
                -2
                * ks.size(0)
                / torch.log(torch.div(*torch.unbind(ks[:, :2], 1))).sum()
            )

        # nanmean guards against runs whose log-sum degenerated (e.g. duplicate
        # points producing zero distance ratios).
        return torch.stack(runs).nanmean()
0 commit comments