diff --git a/.gitignore b/.gitignore index c82f146..87c7f1e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,8 @@ data/ data2/ __pycache__ -model_checkpoints/ \ No newline at end of file +model_checkpoints/ +output/ +*.mp4 +*.png +*.txt \ No newline at end of file diff --git a/data/extract_frames.py b/data/extract_frames.py index 3051281..3974139 100644 --- a/data/extract_frames.py +++ b/data/extract_frames.py @@ -10,6 +10,10 @@ import datetime import argparse +import numpy as np +from PIL import Image +import skvideo.io + def extract_frames(video_path): frames = [] diff --git a/result.txt b/result.txt new file mode 100644 index 0000000..7173886 --- /dev/null +++ b/result.txt @@ -0,0 +1,37 @@ +9 9 Mating_501.mp4 +205 217 Mating_502.mp4 +205 217 Mating_503.mp4 +392 483 Mating_504.mp4 +205 217 Mating_505.mp4 +243 308 Mating_506.mp4 +0 49 Sniffing_506.mp4 +15 56 Sniffing_507.mp4 +0 19 Sniffing_508.mp4 +0 6 Sniffing_509.mp4 +0 459 Sniffing_511.mp4 +6 36 Sniffing_512.mp4 +0 28 Sniffing_513.mp4 +28 62 Sniffing_514.mp4 +12 61 Sniffing_515.mp4 +16 67 Sniffing_516.mp4 +20 30 Sniffing_517.mp4 +5 48 Sniffing_518.mp4 +48 292 Sniffing_519.mp4 +2 89 Sniffing_520.mp4 +0 24 Sniffing_521.mp4 +18 91 Sniffing_522.mp4 +1 107 Sniffing_523.mp4 +0 94 Sniffing_524.mp4 +12 140 Sniffing_525.mp4 +1 78 Sniffing_526.mp4 +0 32 Sniffing_527.mp4 +11 122 Sniffing_528.mp4 +27 63 Sniffing_529.mp4 +30 45 Sniffing_530.mp4 +20 26 Sniffing_531.mp4 +22 82 Sniffing_532.mp4 +19 27 Sniffing_533.mp4 +0 115 Sniffing_534.mp4 +4 122 Sniffing_535.mp4 +18 58 Sniffing_536.mp4 +0 30 Sniffing_537.mp4 diff --git a/run_testOnVideo.sh b/run_testOnVideo.sh new file mode 100644 index 0000000..fba3936 --- /dev/null +++ b/run_testOnVideo.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# Define the input file and the checkpoint model path +input_file="data/ucfTrainTestlist/testlist01.txt" +checkpoint_model="model_checkpoints/ConvLSTM_195.pth" +video_path_prefix="data/UCF-101/" + +# Check if the file exists +if [ -f "$input_file" ]; then + # Read and process each video path in the input file + while IFS= read -r video_path; do + # Remove the '\r' character from the video path + video_path=$(echo "$video_path" | tr -d '\r') + # Add the video path prefix + full_video_path="${video_path_prefix}${video_path}" + # Execute the Python script with the video path and checkpoint model + python3 test_on_video.py --video_path "$full_video_path" --checkpoint_model "$checkpoint_model" + done < "$input_file" +else + echo "Error: File '$input_file' does not exist." +fi \ No newline at end of file diff --git a/static_graph.py b/static_graph.py new file mode 100644 index 0000000..92118ef --- /dev/null +++ b/static_graph.py @@ -0,0 +1,49 @@ +import matplotlib.pyplot as plt + +# Read the data from the text file +with open("result.txt", "r") as file: + data = [line.strip().split() for line in file.readlines()] + +# # Separate the data into durations, video names, and their indices +# durations = [int(item[0]) for item in data] +# video_names = [item[2] for item in data] +# video_indices = range(len(video_names)) + +# # Create a bar graph using matplotlib +# plt.bar(video_indices, durations) +# plt.xticks(video_indices, video_names, rotation=45, ha="right") +# plt.xlabel("Video Names") +# plt.ylabel("Duration") +# plt.title("Duration of Videos") +# plt.tight_layout() + +# # Save the graph to a file and display it +# plt.savefig("output_graph.png") +# plt.show() + + + +import matplotlib.pyplot as plt + +# Data +files = [item[2] for item in data] +actual = [int(item[0]) for item in data] +expected = [int(item[1]) for item in data] + +# Calculate accuracy +accuracy = [a / e * 100 for a, e in zip(actual, expected)] + +# Plot the bar graph +fig, ax = plt.subplots() +ax.bar(files, accuracy) + +# Set labels and title +ax.set_ylabel('Accuracy (%)') +ax.set_xlabel('Files') +ax.set_title('Accuracy for Mating Files') + +# Rotate x-axis labels for better readability +plt.xticks(rotation=45) + +# Show the graph +plt.show() \ No newline at end of file diff --git a/test_on_video.py b/test_on_video.py index 3f99147..c28cfd0 100644 --- a/test_on_video.py +++ b/test_on_video.py @@ -8,6 +8,10 @@ from torchvision.utils import make_grid from PIL import Image, ImageDraw import skvideo.io +import numpy as np +import cv2 + + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -38,13 +42,25 @@ labels = sorted(list(set(os.listdir(opt.dataset_path)))) # Define model and load model checkpoint - #model = ConvLSTM(input_shape=input_shape, num_classes=len(labels), latent_dim=opt.latent_dim) - model = ConvLSTM(num_classes=len(labels), latent_dim=opt.latent_dim) + ### model = ConvLSTM(input_shape=input_shape, num_classes=len(labels), latent_dim=opt.latent_dim) + model = ConvLSTM( + num_classes=len(labels), + latent_dim=opt.latent_dim, + lstm_layers=1, + hidden_dim=1024, + bidirectional=True, + attention=True, + ) model.to(device) - #model.load_state_dict(torch.load(opt.checkpoint_model)) - model.load_state_dict(torch.load(opt.checkpoint_model), strict=False) + model.load_state_dict(torch.load(opt.checkpoint_model)) model.eval() + ### labels statistics + num_true_labels = 0 + input_file = opt.video_path.split('/') + true_labels = input_file[2] + #print(type(opt.video_path), opt.video_path) + # Extract predictions output_frames = [] for frame in tqdm.tqdm(extract_frames(opt.video_path), desc="Processing frames"): @@ -62,8 +78,48 @@ output_frames += [frame] - # Create video from frames - writer = skvideo.io.FFmpegWriter("output.gif") - for frame in tqdm.tqdm(output_frames, desc="Writing to video"): - writer.writeFrame(np.array(frame)) - writer.close() + ### Count correct predictions and save the result in text file + if true_labels == predicted_label: + num_true_labels += 1 + # Save the result to a text file + output_file = "result.txt" + with open(output_file, "a") as file: + result = str(num_true_labels) + " " + str(len(output_frames)) + " " + str(input_file[3]) + file.write(result + "\n") + + # Print a confirmation message + print(f"result of '{input_file[3]}' has been saved to {output_file}.") + ### print(num_true_labels, len(output_frames)) + # # Create video from frames + # writer = skvideo.io.FFmpegWriter("output.gif") + # for frame in tqdm.tqdm(output_frames, desc="Writing to video"): + # writer.writeFrame(np.array(frame)) + # writer.close() + + # Create output folder if it does not exist + output_folder = "output" + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + +# Save output_frames as images + for i, frame in enumerate(output_frames): + frame_np = np.array(frame) + image = Image.fromarray(frame_np) + image.save(os.path.join(output_folder, f"frame_{i}.jpg")) + + # # Set up VideoWriter object + # video_width, video_height = 640, 480 + # fourcc = cv2.VideoWriter_fourcc(*"mp4v") + # video_writer = cv2.VideoWriter("output.mp4", fourcc, 25.0, (video_width, video_height)) + + # for i, frame in enumerate(output_frames): + # # Convert PIL Image to numpy array + # frame = np.array(frame) + # # Resize frame to video size + # frame = cv2.resize(frame, (video_width, video_height)) + # # Write frame to video file + # video_writer.write(frame) + + # # Release VideoWriter object + # video_writer.release() \ No newline at end of file