Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
edadaltocg committed Jun 19, 2024
1 parent 362f697 commit 46bba8a
Show file tree
Hide file tree
Showing 19 changed files with 44 additions and 4 deletions.
1 change: 1 addition & 0 deletions examples/list_resources.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Example on how to list all resources available in `detectors` package."""

import detectors

if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions scripts/parse_arxiv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Requirements
- feedparser installed: pip install feedparser
"""

import argparse
import json
import logging
Expand Down
1 change: 1 addition & 0 deletions scripts/push_model_to_hf_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
- Jinja2 installed
- Git LFS installed
"""

import argparse
import json
import logging
Expand Down
30 changes: 26 additions & 4 deletions src/detectors/aggregations/combine_and_conquer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,15 @@ def p_value_fn(test_statistic: np.ndarray, X: np.ndarray, w=None):
y_ecdf = np.concatenate([np.arange(1, X.shape[0] + 1).reshape(-1, 1) / X.shape[0]] * X.shape[1], axis=1)
if w is not None:
y_ecdf = y_ecdf * w.reshape(1, -1)
return np.concatenate(list(map(lambda xx: np.interp(*xx).reshape(-1, 1), zip(test_statistic.T, X.T, y_ecdf.T))), 1)
return np.concatenate(
list(
map(
lambda xx: np.interp(*xx).reshape(-1, 1),
zip(test_statistic.T, X.T, y_ecdf.T),
)
),
1,
)


def fisher_method(p_values: np.ndarray):
Expand Down Expand Up @@ -153,7 +161,10 @@ def simes_tau_method(p_values: np.ndarray):
Returns:
np.ndarray (n,): combined p-values
"""
tau = np.min(np.sort(p_values, axis=1) / np.arange(1, p_values.shape[1] + 1) * p_values.shape[1], 1)
tau = np.min(
np.sort(p_values, axis=1) / np.arange(1, p_values.shape[1] + 1) * p_values.shape[1],
1,
)
return tau


Expand All @@ -169,6 +180,7 @@ def geometric_mean_tau_method(p_values: np.ndarray):
tau = np.prod(p_values, axis=1) ** (1 / p_values.shape[1])
return tau


def rho(p_values):
k = p_values.shape[1]
phi = stats.norm.ppf(p_values)
Expand All @@ -178,7 +190,8 @@ def rho(p_values):
def hartung(p_values, r):
k = p_values.shape[1]
t = stats.norm.ppf(p_values)
return np.sum(t, axis=1) / np.sqrt((1 - r) * k + r * k**2
return np.sum(t, axis=1) / np.sqrt((1 - r) * k + r * k**2)


def get_combine_p_values_fn(method_name: str):
method_name = method_name.lower()
Expand All @@ -202,4 +215,13 @@ def get_combine_p_values_fn(method_name: str):
raise NotImplementedError(f"method {method_name} not implemented")


ensemble_names = ["fisher", "stouffer", "tippet", "wilkinson", "edgington", "pearson", "simes", "geometric_mean"]
ensemble_names = [
"fisher",
"stouffer",
"tippet",
"wilkinson",
"edgington",
"pearson",
"simes",
"geometric_mean",
]
1 change: 1 addition & 0 deletions src/detectors/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
- `CHECKPOINTS_DIR`: The directory where the checkpoints are stored.
- `RESULTS_DIR`: The directory where the results are stored.
"""

import os

HOME = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
Expand Down
1 change: 1 addition & 0 deletions src/detectors/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Datasets module.
"""

import logging
import os
from enum import Enum
Expand Down
1 change: 1 addition & 0 deletions src/detectors/data/imagenetlt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""From https://github.com/KaihuaTang/Long-Tailed-Recognition.pytorch/blob/master/classification/data/dataloader.py"""

import os

import numpy as np
Expand Down
1 change: 1 addition & 0 deletions src/detectors/eval.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Module containing evaluation metrics.
"""

from typing import Dict, Union

import numpy as np
Expand Down
1 change: 1 addition & 0 deletions src/detectors/methods/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Detection methods.
"""

import logging
import types
from enum import Enum
Expand Down
1 change: 1 addition & 0 deletions src/detectors/methods/templates.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Generalized detection methods templates.
"""

import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
Expand Down
1 change: 1 addition & 0 deletions src/detectors/models/densenet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Densenet models for CIFAR10, CIFAR100 and SVHN datasets."""

import timm
import timm.models
import torch
Expand Down
1 change: 1 addition & 0 deletions src/detectors/models/dino.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
https://github.com/facebookresearch/dino/
"""

import math
from functools import partial

Expand Down
1 change: 1 addition & 0 deletions src/detectors/models/resnet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""ResNet models for CIFAR10, CIFAR100, and SVHN datasets."""

import logging

import timm
Expand Down
1 change: 1 addition & 0 deletions src/detectors/models/vgg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""VGG models for CIFAR10, CIFAR100 and SVHN datasets."""

import timm
import timm.models
import torch
Expand Down
1 change: 1 addition & 0 deletions src/detectors/models/vit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Finetuned ViT models for CIFAR10, CIFAR100, and SVHN datasets."""

import timm
import timm.models
import torch
Expand Down
1 change: 1 addition & 0 deletions src/detectors/pipelines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Pipeline module.
"""

from enum import Enum
from typing import Any, List, Optional, Tuple

Expand Down
1 change: 1 addition & 0 deletions src/detectors/pipelines/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Base abstract pipeline class."""

import logging
from typing import Any, Dict

Expand Down
1 change: 1 addition & 0 deletions src/detectors/pipelines/ood.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
OOD Pipelines.
"""

import logging
import time
from typing import Any, Callable, Dict, List, Literal, Tuple, Union
Expand Down
1 change: 1 addition & 0 deletions src/detectors/trainer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Trainer for classification models."""

import json
import logging
import os
Expand Down

0 comments on commit 46bba8a

Please sign in to comment.