Skip to content

Commit

Permalink
make threshold and percentage an argument
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif committed Sep 20, 2024
1 parent 277056e commit e34a600
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions examples/anomaly_detection_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,11 @@ def main(args):
# calculate the surprisal scores for the context target
scores = -distr.log_prob(context_target / scale)

# get the top 10% of the scores for each time series of the batch
# get the args.top_score_percentage of the scores for each time series of the batch
top_scores = torch.topk(
scores, k=int(scores.shape[1] * 0.1), dim=1
scores,
k=int(scores.shape[1] * args.top_score_percentage),
dim=1,
)
# get top scores [B, 10% of context_length]
top_scores = top_scores.values
Expand All @@ -155,7 +157,7 @@ def main(args):
# mask out the score where is_anomaly is True
score = torch.where(is_anomaly, gpd.loc + 1, score)
is_anomaly = torch.where(
is_anomaly, False, gpd.cdf(score) < 0.05
is_anomaly, False, gpd.cdf(score) < args.anomaly_threshold
)
batch_anomalies.append(is_anomaly)

Expand Down Expand Up @@ -217,11 +219,23 @@ def main(args):
"--context_length", type=int, default=None, help="Context length"
)
parser.add_argument(
"--max_epochs", type=int, default=10, help="Maximum number of epochs"
"--max_epochs", type=int, default=30, help="Maximum number of epochs"
)
parser.add_argument(
"--batch_size", type=int, default=32, help="Batch size"
)
parser.add_argument(
"--anomaly_threshold",
type=float,
default=0.05,
help="Threshold for anomaly detection",
)
parser.add_argument(
"--top_score_percentage",
type=float,
default=0.1,
help="Percentage of top scores to consider for GPD fitting",
)

args = parser.parse_args()

Expand Down

0 comments on commit e34a600

Please sign in to comment.