diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..b6c1da7 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,26 @@ +# Ignore the following files and directories when building the Docker image +*.pyc +__pycache__/ +*.ipynb_checkpoints +*.log +*.csv +*.tsv +*.h5 +*.pth +*.pt +*.zip +*.tar.gz +*.egg-info/ +dist/ +build/ +.env +venv/ +.env.local +*.DS_Store +*.egg +*.whl +*.pkl +*.json +*.yaml +*.yml +submodules/ \ No newline at end of file diff --git a/.gitignore b/.gitignore index a05a2b7..c38a86d 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,6 @@ dmypy.json # Pyre type checker .pyre/ learnableearthparser/fast_sampler/_sampler.c + +# data +data/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..aa2d20a --- /dev/null +++ b/Dockerfile @@ -0,0 +1,60 @@ +FROM nvidia/cuda:12.1.1-devel-ubuntu22.04 + +# Set the working directory +WORKDIR /EDGS + +# Install system dependencies first, including git, build-essential, and cmake +RUN apt-get update && apt-get install -y \ + git \ + wget \ + build-essential \ + cmake \ + ninja-build \ + libgl1-mesa-glx \ + libglib2.0-0 \ + && rm -rf /var/lib/apt/lists/* + +# Copy only essential files for cloning submodules first (e.g., .gitmodules) +# Or, if submodules are public, you might not need to copy anything specific for this step +# For simplicity, we'll copy everything, but this could be optimized +COPY . . + +# Initialize and update submodules +RUN git submodule init && git submodule update --recursive + +# Install Miniconda +RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh && \ + bash /tmp/miniconda.sh -b -p /opt/conda && \ + rm /tmp/miniconda.sh +ENV PATH="/opt/conda/bin:${PATH}" + +# Create the conda environment and install dependencies +# Accept Anaconda TOS before using conda +RUN conda init bash && \ + conda config --set always_yes yes --set changeps1 no && \ + conda config --add channels defaults && \ + conda config --set channel_priority strict && \ + conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main && \ + conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r +# Now you can safely create your environment +RUN conda create -y -n edgs python=3.10 pip && \ + conda clean -afy && \ + echo "source activate edgs" > ~/.bashrc + +# Set CUDA architectures to compile for +ENV TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;8.9;9.0+PTX" + +# Activate the environment and install Python dependencies +RUN /bin/bash -c "source activate edgs && \ + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 && \ + pip install -e ./submodules/gaussian-splatting/submodules/diff-gaussian-rasterization && \ + pip install -e ./submodules/gaussian-splatting/submodules/simple-knn && \ + pip install pycolmap wandb hydra-core tqdm torchmetrics lpips matplotlib rich plyfile imageio imageio-ffmpeg && \ + pip install -e ./submodules/RoMa && \ + pip install gradio plotly scikit-learn moviepy==2.1.1 ffmpeg open3d jupyterlab matplotlib" + +# Expose the port for Gradio +EXPOSE 7862 + +# Keep the container running in detached mode +CMD ["tail", "-f", "/dev/null"] \ No newline at end of file diff --git a/README.md b/README.md index 8c40427..e944857 100644 --- a/README.md +++ b/README.md @@ -69,45 +69,14 @@ Alternatively, check our [Colab notebook](https://colab.research.google.com/gith ## đŸ› ī¸ Installation -You can either run `install.sh` or manually install using the following: +You can 
run everything with Docker:
```bash
-git clone git@github.com:CompVis/EDGS.git --recursive
-cd EDGS
-git submodule update --init --recursive
-
-conda create -y -n edgs python=3.10 pip
-conda activate edgs
-
-# Set up path to your CUDA. In our experience similar versions like 12.2 also work well
-export CUDA_HOME=/usr/local/cuda-12.1
-export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
-export PATH=$CUDA_HOME/bin:$PATH
-
-conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia -y
-conda install nvidia/label/cuda-12.1.0::cuda-toolkit -y
-
-pip install -e submodules/gaussian-splatting/submodules/diff-gaussian-rasterization
-pip install -e submodules/gaussian-splatting/submodules/simple-knn
-
-# For COLMAP and pycolmap
-# Optionally install original colmap but probably pycolmap suffices
-# conda install conda-forge/label/colmap_dev::colmap
-pip install pycolmap
-
-
-pip install wandb hydra-core tqdm torchmetrics lpips matplotlib rich plyfile imageio imageio-ffmpeg
-conda install numpy=1.26.4 -y -c conda-forge --override-channels
-
-pip install -e submodules/RoMa
-conda install anaconda::jupyter --yes
-
-# Stuff necessary for gradio and visualizations
-pip install gradio
-pip install plotly scikit-learn moviepy==2.1.1 ffmpeg
-pip install open3d
+docker compose up -d
```
+Alternatively, you can install without Docker by running `script/install.sh`.
+

## đŸ“Ļ Data

@@ -118,6 +87,36 @@ We evaluated on the following datasets:

### Using Your Own Dataset

+#### Option A
+Use the Gradio demo.
+After running `docker compose up -d`, run:
+```
+docker compose exec edgs-app bash
+python script/gradio_demo.py --port 7862
+```
+
+#### Option B
+From the command line:
+```
+docker compose exec edgs-app bash
+python script/fit_model_to_scene_full.py --video_path <path_to_video> [--outputs_dir <path_to_outputs_dir>]
+```
+
+#### Option C
+Use JupyterLab.
+```
+docker compose exec edgs-app bash
+```
+Then, in the terminal inside the Docker container, start JupyterLab:
+```
+jupyter lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root --notebook-dir=notebooks
+```
+After JupyterLab starts, it will print URLs to the terminal. Look for a URL containing a token, like:
+ `http://127.0.0.1:8888/lab?token=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx`
+Open `http://localhost:8888` (or `http://127.0.0.1:8888`) in your host browser.
+When prompted for a "Password or token", paste the `xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx` token part from the printed URL into the field and log in. Alternatively, you can paste the full printed URL directly into your browser.
+
+#### Option D
You can use the same data format as the [3DGS project](https://github.com/graphdeco-inria/gaussian-splatting?tab=readme-ov-file#processing-your-own-scenes). Please follow their guide to prepare your scene.

Expected folder structure:
@@ -134,6 +133,11 @@ scene_folder
|---points3D.bin
```

+```
+docker compose exec edgs-app bash
+```
+Then run the training command described in the section below.
+
Nerf synthetic format is also acceptable.
You can also use functions provided in our code to convert a collection of images or a sinlge video into a desired format. However, this may requre tweaking and processing time can be large for large collection of images with little overlap.

@@ -144,7 +148,7 @@ You can also use functions provided in our code to convert a collection of image

To optimize on a single scene in COLMAP format use this code.
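A minimal end-to-end sketch of that command, as run through the Docker Compose setup above, is shown first; the scene name `my_scene`, its location under the mounted `./data` folder, and the output path are illustrative assumptions rather than anything defined by the repository. The hunk that follows the sketch simply updates the documented command to the new `script/` location.

```bash
# Hypothetical run: assumes the COLMAP scene was copied to ./data/my_scene on the host,
# which docker-compose.yml mounts at /EDGS/data inside the container.
docker compose exec edgs-app bash
python script/train.py \
    train.gs_epochs=30000 \
    train.no_densify=True \
    gs.dataset.source_path=/EDGS/data/my_scene \
    gs.dataset.model_path=/EDGS/outputs/my_scene_model
```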
```bash -python train.py \ +python script/train.py \ train.gs_epochs=30000 \ train.no_densify=True \ gs.dataset.source_path= \ diff --git a/configs/train.yaml b/configs/train.yaml index e40ec5b..6e776d0 100644 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -1,22 +1,22 @@ defaults: - gs: base - - _self_ + - _self_ seed: 228 wandb: - mode: "online" # "disabled" for no logging + mode: "disabled" # "online" or "disabled" entity: "3dcorrespondence" project: "Adv3DGS" group: null name: null tag: "debug" - + train: - gs_epochs: 0 # number of 3dgs iterations - reduce_opacity: True + gs_epochs: 1000 # number of 3dgs iterations + reduce_opacity: True no_densify: False # if True, the model will not be densified - max_lr: True + max_lr: True load: gs: null #path to 3dgs checkpoint @@ -28,11 +28,9 @@ verbose: true init_wC: use: True # use EDGS matches_per_ref: 15_000 # number of matches per reference - num_refs: 180 # number of reference images + num_refs: 18 # number of reference images nns_per_ref: 3 # number of nearest neighbors per reference scaling_factor: 0.001 proj_err_tolerance: 0.01 roma_model: "outdoors" # you can change this to "indoors" or "outdoors" - add_SfM_init : False - - + add_SfM_init: False diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..24975eb --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,25 @@ +services: + edgs-app: + build: . # Instructs Docker Compose to build using the Dockerfile in the current directory + image: edgs-app # This is the name of the image you built + ports: + - "7862:7862" # For Gradio, if you still use it + - "8888:8888" # Map port 8888 for JupyterLab + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all # Use all available GPUs + capabilities: [gpu] # Request GPU capabilities + environment: + - NVIDIA_VISIBLE_DEVICES=all + volumes: + - ./data:/EDGS/data # Example: map a local 'data' folder to '/EDGS/data' in the container + - ./outputs:/EDGS/outputs # Example: map a local 'output' folder + - ./script:/EDGS/script # Example: map a local 'scripts' folder + - ./source:/EDGS/source # Example: map a local 'sources' folder + - ./configs:/EDGS/configs # Example: map a local 'config' folder + - ./notebooks:/EDGS/notebooks # Map a local 'notebooks' folder for JupyterLab + stdin_open: true # Keep STDIN open for interactive processes + tty: true # Allocate a TTY \ No newline at end of file diff --git a/notebooks/fit_model_to_scene_full.ipynb b/notebooks/fit_model_to_scene_full.ipynb index 09e6323..1d0b6c6 100644 --- a/notebooks/fit_model_to_scene_full.ipynb +++ b/notebooks/fit_model_to_scene_full.ipynb @@ -67,10 +67,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "59c64632-e31a-4ead-98a5-4ab0f295e54d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "xFormers not available\n" + ] + } + ], "source": [ "import torch\n", "import numpy as np\n", @@ -96,10 +104,96 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "925adfa3-c311-44b6-a8c4-a31fb7426947", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "gs:\n", + " _target_: source.networks.Warper3DGS\n", + " verbose: true\n", + " viewpoint_stack: null\n", + " sh_degree: 3\n", + " opt:\n", + " iterations: 30000\n", + " position_lr_init: 0.00016\n", + " position_lr_final: 1.6e-06\n", + " position_lr_delay_mult: 0.01\n", + " position_lr_max_steps: 
30000\n", + " feature_lr: 0.0025\n", + " opacity_lr: 0.025\n", + " scaling_lr: 0.005\n", + " rotation_lr: 0.001\n", + " percent_dense: 0.01\n", + " lambda_dssim: 0.2\n", + " densification_interval: 100\n", + " opacity_reset_interval: 30000\n", + " densify_from_iter: 500\n", + " densify_until_iter: 15000\n", + " densify_grad_threshold: 0.0002\n", + " random_background: false\n", + " save_iterations:\n", + " - 3000\n", + " - 7000\n", + " - 15000\n", + " - 30000\n", + " batch_size: 64\n", + " exposure_lr_init: 0.01\n", + " exposure_lr_final: 0.0001\n", + " exposure_lr_delay_steps: 0\n", + " exposure_lr_delay_mult: 0.0\n", + " TRAIN_CAM_IDX_TO_LOG: 50\n", + " TEST_CAM_IDX_TO_LOG: 10\n", + " pipe:\n", + " convert_SHs_python: false\n", + " compute_cov3D_python: false\n", + " debug: false\n", + " antialiasing: false\n", + " dataset:\n", + " densify_until_iter: 15000\n", + " source_path: ''\n", + " model_path: ''\n", + " images: images\n", + " resolution: -1\n", + " white_background: false\n", + " data_device: cuda\n", + " eval: false\n", + " depths: ''\n", + " train_test_exp: false\n", + "seed: 228\n", + "wandb:\n", + " mode: online\n", + " entity: m-ogawa-sensyn\n", + " project: Adv3DGS\n", + " group: null\n", + " name: null\n", + " tag: debug\n", + "train:\n", + " gs_epochs: 0\n", + " reduce_opacity: true\n", + " no_densify: false\n", + " max_lr: true\n", + "load:\n", + " gs: null\n", + " gs_step: null\n", + "device: cuda:0\n", + "verbose: true\n", + "init_wC:\n", + " use: true\n", + " matches_per_ref: 15000\n", + " num_refs: 180\n", + " nns_per_ref: 3\n", + " scaling_factor: 0.001\n", + " proj_err_tolerance: 0.01\n", + " roma_model: outdoors\n", + " add_SfM_init: false\n", + "\n" + ] + } + ], "source": [ "with initialize(config_path=\"../configs\", version_base=\"1.1\"):\n", " cfg = compose(config_name=\"train\")\n", @@ -124,7 +218,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "07e4ca51", "metadata": {}, "outputs": [], @@ -177,10 +271,40 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "2056ee6f-dbb6-4ce8-86f0-5b4f9721d093", - "metadata": {}, - "outputs": [], + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output folder: ./scene_edgsed/\n" + ] + }, + { + "ename": "InstantiationException", + "evalue": "Error in call to target 'source.networks.Warper3DGS':\nAssertionError('Could not recognize scene type!')\nfull_key: gs", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32m/opt/conda/envs/edgs/lib/python3.10/site-packages/hydra/_internal/instantiate/_instantiate2.py:92\u001b[0m, in \u001b[0;36m_call_target\u001b[0;34m(_target_, _partial_, args, kwargs, full_key)\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 92\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_target_\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 93\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n", + "File 
\u001b[0;32m/EDGS/notebooks/../source/networks.py:26\u001b[0m, in \u001b[0;36mWarper3DGS.__init__\u001b[0;34m(self, sh_degree, opt, pipe, dataset, viewpoint_stack, verbose, do_train_test_split)\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpipe \u001b[38;5;241m=\u001b[39m pipe\n\u001b[0;32m---> 26\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscene \u001b[38;5;241m=\u001b[39m \u001b[43mScene\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgaussians\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshuffle\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 27\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m do_train_test_split:\n", + "File \u001b[0;32m/EDGS/notebooks/../submodules/gaussian-splatting/scene/__init__.py:49\u001b[0m, in \u001b[0;36mScene.__init__\u001b[0;34m(self, args, gaussians, load_iteration, shuffle, resolution_scales)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 49\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCould not recognize scene type!\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mloaded_iter:\n", + "\u001b[0;31mAssertionError\u001b[0m: Could not recognize scene type!", + "\nThe above exception was the direct cause of the following exception:\n", + "\u001b[0;31mInstantiationException\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[5], line 12\u001b[0m\n\u001b[1;32m 10\u001b[0m os\u001b[38;5;241m.\u001b[39mmakedirs(cfg\u001b[38;5;241m.\u001b[39mgs\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39mmodel_path, exist_ok\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 11\u001b[0m \u001b[38;5;66;03m# Init gs model\u001b[39;00m\n\u001b[0;32m---> 12\u001b[0m gs \u001b[38;5;241m=\u001b[39m \u001b[43mhydra\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mutils\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minstantiate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcfg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgs\u001b[49m\u001b[43m)\u001b[49m \n\u001b[1;32m 13\u001b[0m trainer \u001b[38;5;241m=\u001b[39m EDGSTrainer(GS\u001b[38;5;241m=\u001b[39mgs,\n\u001b[1;32m 14\u001b[0m training_config\u001b[38;5;241m=\u001b[39mcfg\u001b[38;5;241m.\u001b[39mgs\u001b[38;5;241m.\u001b[39mopt,\n\u001b[1;32m 15\u001b[0m device\u001b[38;5;241m=\u001b[39mcfg\u001b[38;5;241m.\u001b[39mdevice)\n", + "File \u001b[0;32m/opt/conda/envs/edgs/lib/python3.10/site-packages/hydra/_internal/instantiate/_instantiate2.py:226\u001b[0m, in \u001b[0;36minstantiate\u001b[0;34m(config, *args, **kwargs)\u001b[0m\n\u001b[1;32m 223\u001b[0m _convert_ \u001b[38;5;241m=\u001b[39m config\u001b[38;5;241m.\u001b[39mpop(_Keys\u001b[38;5;241m.\u001b[39mCONVERT, ConvertMode\u001b[38;5;241m.\u001b[39mNONE)\n\u001b[1;32m 224\u001b[0m _partial_ \u001b[38;5;241m=\u001b[39m config\u001b[38;5;241m.\u001b[39mpop(_Keys\u001b[38;5;241m.\u001b[39mPARTIAL, \u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m--> 226\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minstantiate_node\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 
227\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrecursive\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_recursive_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconvert\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_convert_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpartial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_partial_\u001b[49m\n\u001b[1;32m 228\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 229\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m OmegaConf\u001b[38;5;241m.\u001b[39mis_list(config):\n\u001b[1;32m 230\u001b[0m \u001b[38;5;66;03m# Finalize config (convert targets to strings, merge with kwargs)\u001b[39;00m\n\u001b[1;32m 231\u001b[0m config_copy \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(config)\n", + "File \u001b[0;32m/opt/conda/envs/edgs/lib/python3.10/site-packages/hydra/_internal/instantiate/_instantiate2.py:347\u001b[0m, in \u001b[0;36minstantiate_node\u001b[0;34m(node, convert, recursive, partial, *args)\u001b[0m\n\u001b[1;32m 342\u001b[0m value \u001b[38;5;241m=\u001b[39m instantiate_node(\n\u001b[1;32m 343\u001b[0m value, convert\u001b[38;5;241m=\u001b[39mconvert, recursive\u001b[38;5;241m=\u001b[39mrecursive\n\u001b[1;32m 344\u001b[0m )\n\u001b[1;32m 345\u001b[0m kwargs[key] \u001b[38;5;241m=\u001b[39m _convert_node(value, convert)\n\u001b[0;32m--> 347\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_call_target\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_target_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpartial\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfull_key\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 348\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 349\u001b[0m \u001b[38;5;66;03m# If ALL or PARTIAL non structured or OBJECT non structured,\u001b[39;00m\n\u001b[1;32m 350\u001b[0m \u001b[38;5;66;03m# instantiate in dict and resolve interpolations eagerly.\u001b[39;00m\n\u001b[1;32m 351\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m convert \u001b[38;5;241m==\u001b[39m ConvertMode\u001b[38;5;241m.\u001b[39mALL \u001b[38;5;129;01mor\u001b[39;00m (\n\u001b[1;32m 352\u001b[0m convert \u001b[38;5;129;01min\u001b[39;00m (ConvertMode\u001b[38;5;241m.\u001b[39mPARTIAL, ConvertMode\u001b[38;5;241m.\u001b[39mOBJECT)\n\u001b[1;32m 353\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m node\u001b[38;5;241m.\u001b[39m_metadata\u001b[38;5;241m.\u001b[39mobject_type \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28mdict\u001b[39m)\n\u001b[1;32m 354\u001b[0m ):\n", + "File \u001b[0;32m/opt/conda/envs/edgs/lib/python3.10/site-packages/hydra/_internal/instantiate/_instantiate2.py:97\u001b[0m, in \u001b[0;36m_call_target\u001b[0;34m(_target_, _partial_, args, kwargs, full_key)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m full_key:\n\u001b[1;32m 96\u001b[0m msg \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mfull_key: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfull_key\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 97\u001b[0m 
\u001b[38;5;28;01mraise\u001b[39;00m InstantiationException(msg) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01me\u001b[39;00m\n", + "\u001b[0;31mInstantiationException\u001b[0m: Error in call to target 'source.networks.Warper3DGS':\nAssertionError('Could not recognize scene type!')\nfull_key: gs" + ] + } + ], "source": [ "_ = wandb.init(entity=cfg.wandb.entity,\n", " project=cfg.wandb.project,\n", @@ -1781,7 +1905,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.21" + "version": "3.10.18" } }, "nbformat": 4, diff --git a/script/fit_model_to_scene_full.py b/script/fit_model_to_scene_full.py new file mode 100644 index 0000000..bffd51b --- /dev/null +++ b/script/fit_model_to_scene_full.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python +# coding: utf-8 + +# # EDGS: Eliminating Densification for Gaussian Splatting +# EDGS improves 3D Gaussian Splatting by removing the need for densification. It starts from a dense point cloud initialization based on 2D correspondences, leading to: +# - ⚡ Faster convergence (only 25% of training time) +# - 🌀 Higher rendering quality +# - 💡 No need for progressive densification + +# ## 2. Import libraries +import argparse +import logging +import os +import random +import sys + +import hydra +import numpy as np +import omegaconf +import torch +import wandb +from hydra import compose, initialize +from matplotlib import pyplot as plt +from omegaconf import OmegaConf + +# Add the project root directory to sys.path +# so that modules from 'source' can be imported. +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if project_root not in sys.path: + sys.path.insert(0, project_root) +# sys.path.append("../submodules/gaussian-splatting") +from source.trainer import EDGSTrainer +from source.utils_aux import set_seed +from source.utils_preprocess import ( + orchestrate_video_to_colmap_scene, # Use the refactored function +) + +# Initialize logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", +) + +# --- Add argument parsing --- +parser = argparse.ArgumentParser( + description="Fit EDGS model to a scene, optionally from a video." +) +parser.add_argument( + "--video_path", + type=str, + default=os.path.join( + project_root, "assets", "examples", "video_fruits.mp4" + ), # Use project_root + help="Path to the input video file.", +) +parser.add_argument( + "--outputs_dir", + type=str, + default=os.path.join(project_root, "outputs"), # Use project_root + help="Base directory where processed COLMAP scenes will be stored.", +) +args = parser.parse_args() +# --- End argument parsing --- + +with initialize(config_path="../configs", version_base="1.1"): + cfg = compose(config_name="train") + +SAME_WITH_GRADIO_DEMO = True +if SAME_WITH_GRADIO_DEMO: + cfg.gs.opt.opacity_reset_interval = 1_000_000 + cfg.train.reduce_opacity = True + cfg.train.no_densify = True + cfg.train.max_lr = True + cfg.train.gs_epochs = 1000 + + cfg.init_wC.use = False # Disable for fallback cases + cfg.init_wC.nns_per_ref = 1 + cfg.init_wC.add_SfM_init = False + cfg.init_wC.scaling_factor = 0.00077 * 2.0 + cfg.init_wC.num_refs = 2 # Reduce to minimum since COLMAP only found 2 cameras + cfg.init_wC.matches_per_ref = 20000 + +print(OmegaConf.to_yaml(cfg)) + + +# # 3. 
Init input parameters + +# ## 3.1 Optionally preprocess video +# process the input video +if os.path.exists(args.video_path): + print(f"Starting video processing for: {args.video_path}") + try: + # The first return value 'images_data' might not be directly used by the trainer + # if the Scene object loads everything from the COLMAP directory. + _, scene_dir = orchestrate_video_to_colmap_scene( + args.video_path, + cfg.init_wC.num_refs, # Assuming you added this arg + max_size=1024, # Or make it an arg + base_work_dir=args.outputs_dir, # Assuming you added this arg + ) + if scene_dir is None: + print(f"Failed to process video {args.video_path}. Exiting.") + sys.exit(1) + cfg.gs.dataset.source_path = scene_dir + cfg.gs.dataset.model_path = os.path.join(scene_dir, "models") + print(f"Set model_path to: {cfg.gs.dataset.model_path}") + os.makedirs(cfg.gs.dataset.model_path, exist_ok=True) + except Exception as e: + print(f"Error during video preprocessing: {e}") + sys.exit(1) + + +# # 4. Initilize model and logger +if cfg.wandb.mode != "disabled": + logging.info( + "wandb logging is enabled (mode={}). Results will be logged to wandb.".format( + cfg.wandb.mode + ) + ) + _ = wandb.init( + entity=cfg.wandb.entity, + project=cfg.wandb.project, + config=omegaconf.OmegaConf.to_container( + cfg, resolve=True, throw_on_missing=True + ), + name=cfg.wandb.name, + mode=cfg.wandb.mode, + ) +else: + logging.info( + "wandb logging is disabled (mode={}). Results will not be logged to wandb.".format( + cfg.wandb.mode + ) + ) +omegaconf.OmegaConf.resolve(cfg) +set_seed(cfg.seed) +# Init output folder +print("Output folder: {}".format(cfg.gs.dataset.model_path)) +os.makedirs(cfg.gs.dataset.model_path, exist_ok=True) +# Init gs model +gs = hydra.utils.instantiate(cfg.gs) +trainer = EDGSTrainer( + GS=gs, + training_config=cfg.gs.opt, + device=cfg.device, + log_wandb=(cfg.wandb.mode != "disabled"), +) + + +# # 5. Init with matchings +trainer.timer.start() +trainer.init_with_corr(cfg.init_wC) +trainer.timer.pause() + + +# ### Visualize a few initial viewpoints +with torch.no_grad(): + viewpoint_stack = trainer.GS.scene.getTrainCameras() + available_cams = trainer.GS.scene.getTrainCameras() + num_cams_to_viz = min(4, len(available_cams)) + viewpoint_cams_to_viz = random.sample(available_cams, num_cams_to_viz) + for idx, viewpoint_cam in enumerate(viewpoint_cams_to_viz): + render_pkg = trainer.GS(viewpoint_cam) + image = render_pkg["render"] + + image_np = image.clone().detach().cpu().numpy().transpose(1, 2, 0) + image_gt_np = ( + viewpoint_cam.original_image.clone() + .detach() + .cpu() + .numpy() + .transpose(1, 2, 0) + ) + + # Clip values to be in the range [0, 1] + image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8) + image_gt_np = np.clip(image_gt_np * 255, 0, 255).astype(np.uint8) + + fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6)) + ax[0].imshow(image_gt_np) + ax[0].axis("off") + ax[1].imshow(image_np) + ax[1].axis("off") + plt.tight_layout() + plt.savefig( + os.path.join( + cfg.gs.dataset.model_path, + f"viewpoint_{idx}_initial.png", + ) + ) + plt.show() + plt.close(fig) + + +# # 6.Optimize scene +# Optimize first briefly for 5k steps and visualize results. We also disable saving of pretrained models. 
Train function can be changed for any other method +trainer.saving_iterations = [] +# cfg.train.gs_epochs = 5_000 +trainer.train(cfg.train) + + +# ### Visualize same viewpoints +with torch.no_grad(): + for viewpoint_cam in viewpoint_cams_to_viz: + render_pkg = trainer.GS(viewpoint_cam) + image = render_pkg["render"] + + image_np = image.clone().detach().cpu().numpy().transpose(1, 2, 0) + image_gt_np = ( + viewpoint_cam.original_image.clone() + .detach() + .cpu() + .numpy() + .transpose(1, 2, 0) + ) + + # Clip values to be in the range [0, 1] + image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8) + image_gt_np = np.clip(image_gt_np * 255, 0, 255).astype(np.uint8) + + fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6)) + ax[0].imshow(image_gt_np) + ax[0].axis("off") + ax[1].imshow(image_np) + ax[1].axis("off") + plt.tight_layout() + plt.show() + + +# ### Save model +with torch.no_grad(): + trainer.save_model() + + +# # # 7. Continue training until we reach total 30K training steps +# cfg.train.gs_epochs = 25_000 +# trainer.train(cfg.train) + + +# # ### Visualize same viewpoints +# with torch.no_grad(): +# for viewpoint_cam in viewpoint_cams_to_viz: +# render_pkg = trainer.GS(viewpoint_cam) +# image = render_pkg["render"] + +# image_np = image.clone().detach().cpu().numpy().transpose(1, 2, 0) +# image_gt_np = ( +# viewpoint_cam.original_image.clone() +# .detach() +# .cpu() +# .numpy() +# .transpose(1, 2, 0) +# ) + +# # Clip values to be in the range [0, 1] +# image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8) +# image_gt_np = np.clip(image_gt_np * 255, 0, 255).astype(np.uint8) + +# fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6)) +# ax[0].imshow(image_gt_np) +# ax[0].axis("off") +# ax[1].imshow(image_np) +# ax[1].axis("off") +# plt.tight_layout() +# plt.show() + + +# ### Save model +# with torch.no_grad(): +# trainer.save_model() diff --git a/full_eval.py b/script/full_eval.py similarity index 100% rename from full_eval.py rename to script/full_eval.py diff --git a/gradio_demo.py b/script/gradio_demo.py similarity index 52% rename from gradio_demo.py rename to script/gradio_demo.py index 9c55e44..8684e9e 100644 --- a/gradio_demo.py +++ b/script/gradio_demo.py @@ -1,27 +1,41 @@ -import torch +import argparse +import contextlib +import io import os import shutil +import sys import tempfile -import argparse +import time + import gradio as gr -import sys -import io -from PIL import Image -import numpy as np -from source.utils_aux import set_seed -from source.utils_preprocess import read_video_frames, preprocess_frames, select_optimal_frames, save_frames_to_scene_dir, run_colmap_on_scene -from source.trainer import EDGSTrainer -from hydra import initialize, compose import hydra -import time -from source.visualization import generate_circular_camera_path, save_numpy_frames_as_mp4, generate_fully_smooth_cameras_with_tsp, put_text_on_image -import contextlib -import base64 +import numpy as np +import torch +from hydra import compose, initialize + +# Add the project root directory to sys.path +# so that modules from 'source' can be imported. 
+project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if project_root not in sys.path: + sys.path.insert(0, project_root) +from source.trainer import EDGSTrainer +from source.utils_aux import set_seed +from source.utils_preprocess import ( + orchestrate_video_to_colmap_scene, # Import the new/refactored function + process_input_for_colmap, + run_colmap_on_scene, +) +from source.visualization import ( + generate_circular_camera_path, + generate_fully_smooth_cameras_with_tsp, + put_text_on_image, + save_numpy_frames_as_mp4, +) # Init RoMA model: -sys.path.append('../submodules/RoMa') -from romatch import roma_outdoor, roma_indoor +sys.path.append("../submodules/RoMa") +from romatch import roma_indoor roma_model = roma_indoor(device="cuda:0") roma_model.upsample_preds = False @@ -33,6 +47,7 @@ trainer = None + class Tee(io.TextIOBase): def __init__(self, *streams): self.streams = streams @@ -46,6 +61,7 @@ def flush(self): for stream in self.streams: stream.flush() + def capture_logs(func, *args, **kwargs): log_capture_string = io.StringIO() tee = Tee(sys.__stdout__, log_capture_string) @@ -53,17 +69,20 @@ def capture_logs(func, *args, **kwargs): result = func(*args, **kwargs) return result, log_capture_string.getvalue() + # Training Pipeline -def run_training_pipeline(scene_dir, - num_ref_views=16, - num_corrs_per_view=20000, - num_steps=1_000, - mode_toggle="Ours (EDGS)"): - with initialize(config_path="./configs", version_base="1.1"): +def run_training_pipeline( + scene_dir, + num_ref_views=16, + num_corrs_per_view=20000, + num_steps=1_000, + mode_toggle="Ours (EDGS)", +): + with initialize(config_path="../configs", version_base="1.1"): cfg = compose(config_name="train") scene_name = os.path.basename(scene_dir) - model_output_dir = f"./outputs/{scene_name}_trained" + model_output_dir = f"../outputs/{scene_name}_trained" cfg.wandb.mode = "disabled" cfg.gs.dataset.model_path = model_output_dir @@ -72,8 +91,8 @@ def run_training_pipeline(scene_dir, cfg.gs.opt.TEST_CAM_IDX_TO_LOG = 12 cfg.train.gs_epochs = 30000 - - if mode_toggle=="Ours (EDGS)": + + if mode_toggle == "Ours (EDGS)": cfg.gs.opt.opacity_reset_interval = 1_000_000 cfg.train.reduce_opacity = True cfg.train.no_densify = True @@ -84,15 +103,20 @@ def run_training_pipeline(scene_dir, cfg.init_wC.nns_per_ref = 1 cfg.init_wC.num_refs = num_ref_views cfg.init_wC.add_SfM_init = False - cfg.init_wC.scaling_factor = 0.00077 * 2. 
- + cfg.init_wC.scaling_factor = 0.00077 * 2.0 + set_seed(cfg.seed) os.makedirs(cfg.gs.dataset.model_path, exist_ok=True) global trainer global MODEL_PATH generator3dgs = hydra.utils.instantiate(cfg.gs, do_train_test_split=False) - trainer = EDGSTrainer(GS=generator3dgs, training_config=cfg.gs.opt, device=cfg.device, log_wandb=cfg.wandb.mode != 'disabled') + trainer = EDGSTrainer( + GS=generator3dgs, + training_config=cfg.gs.opt, + device=cfg.device, + log_wandb=cfg.wandb.mode != "disabled", + ) # Disable evaluation and saving trainer.saving_iterations = [] @@ -102,13 +126,15 @@ def run_training_pipeline(scene_dir, trainer.timer.start() start_time = time.time() trainer.init_with_corr(cfg.init_wC, roma_model=roma_model) - time_for_init = time.time()-start_time + time_for_init = time.time() - start_time viewpoint_cams = trainer.GS.scene.getTrainCameras() - path_cameras = generate_fully_smooth_cameras_with_tsp(existing_cameras=viewpoint_cams, - n_selected=6, # 8 - n_points_per_segment=30, # 30 - closed=False) + path_cameras = generate_fully_smooth_cameras_with_tsp( + existing_cameras=viewpoint_cams, + n_selected=6, # 8 + n_points_per_segment=30, # 30 + closed=False, + ) path_cameras = path_cameras + path_cameras[::-1] path_renderings = [] @@ -122,13 +148,24 @@ def run_training_pipeline(scene_dir, image = render_pkg["render"] image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1) image_np = (image_np * 255).astype(np.uint8) - path_renderings.append(put_text_on_image(img=image_np, - text=f"Init stage.\nTime:{time_for_init:.3f}s. ")) - path_renderings = path_renderings + [put_text_on_image(img=image_np, text=f"Start fitting.\nTime:{time_for_init:.3f}s. ")]*30 - + path_renderings.append( + put_text_on_image( + img=image_np, text=f"Init stage.\nTime:{time_for_init:.3f}s. " + ) + ) + path_renderings = ( + path_renderings + + [ + put_text_on_image( + img=image_np, text=f"Start fitting.\nTime:{time_for_init:.3f}s. " + ) + ] + * 30 + ) + # Train and save visualizations during training. start_time = time.time() - for _ in range(int(num_steps//10)): + for _ in range(int(num_steps // 10)): with torch.no_grad(): viewpoint_cam = path_cameras[idx] idx = (idx + 1) % len(path_cameras) @@ -136,20 +173,27 @@ def run_training_pipeline(scene_dir, image = render_pkg["render"] image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1) image_np = (image_np * 255).astype(np.uint8) - path_renderings.append(put_text_on_image( - img=image_np, - text=f"Fitting stage.\nTime:{time_for_init + time.time()-start_time:.3f}s. ")) - + path_renderings.append( + put_text_on_image( + img=image_np, + text=f"Fitting stage.\nTime:{time_for_init + time.time() - start_time:.3f}s. ", + ) + ) + cfg.train.gs_epochs = 10 trainer.train(cfg.train) - print(f"Time elapsed: {(time_for_init + time.time()-start_time):.2f}s.") + print(f"Time elapsed: {(time_for_init + time.time() - start_time):.2f}s.") # if (cfg.init_wC.use == False) and (time_for_init + time.time()-start_time) > 60: # break final_time = time.time() - + # Add static frame. To highlight we're done - path_renderings += [put_text_on_image( - img=image_np, text=f"Done.\nTime:{time_for_init + final_time -start_time:.3f}s. ")]*30 + path_renderings += [ + put_text_on_image( + img=image_np, + text=f"Done.\nTime:{time_for_init + final_time - start_time:.3f}s. ", + ) + ] * 30 # Final rendering at the end. 
for _ in range(len(path_cameras)): with torch.no_grad(): @@ -159,37 +203,61 @@ def run_training_pipeline(scene_dir, image = render_pkg["render"] image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1) image_np = (image_np * 255).astype(np.uint8) - path_renderings.append(put_text_on_image(img=image_np, - text=f"Final result.\nTime:{time_for_init + final_time -start_time:.3f}s. ")) + path_renderings.append( + put_text_on_image( + img=image_np, + text=f"Final result.\nTime:{time_for_init + final_time - start_time:.3f}s. ", + ) + ) trainer.save_model() - final_video_path = os.path.join(STATIC_FILE_SERVING_FOLDER, f"{scene_name}_final.mp4") - save_numpy_frames_as_mp4(frames=path_renderings, output_path=final_video_path, fps=30, center_crop=0.85) + final_video_path = os.path.join( + STATIC_FILE_SERVING_FOLDER, f"{scene_name}_final.mp4" + ) + save_numpy_frames_as_mp4( + frames=path_renderings, output_path=final_video_path, fps=30, center_crop=0.85 + ) MODEL_PATH = cfg.gs.dataset.model_path - ply_path = os.path.join(cfg.gs.dataset.model_path, f"point_cloud/iteration_{trainer.gs_step}/point_cloud.ply") - shutil.copy(ply_path, os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply")) + original_ply_path = os.path.join( # Renamed for clarity + cfg.gs.dataset.model_path, + f"point_cloud/iteration_{trainer.gs_step}/point_cloud.ply", + ) + # This is the path to the copied file in an allowed directory + copied_ply_path_for_serving = os.path.join( + STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply" + ) + shutil.copy(original_ply_path, copied_ply_path_for_serving) + + return ( + final_video_path, + copied_ply_path_for_serving, + ) # Return the path to the copied .ply file - return final_video_path, ply_path # Gradio Interface def gradio_interface(input_path, num_ref_views, num_corrs, num_steps): - images, scene_dir = run_full_pipeline(input_path, num_ref_views, num_corrs, max_size=1024) - shutil.copytree(scene_dir, STATIC_FILE_SERVING_FOLDER+'/scene_colmaped', dirs_exist_ok=True) - (final_video_path, ply_path), log_output = capture_logs(run_training_pipeline, - scene_dir, - num_ref_views, - num_corrs, - num_steps) + images, scene_dir = run_full_pipeline( + input_path, num_ref_views, num_corrs, max_size=1024 + ) + shutil.copytree( + scene_dir, STATIC_FILE_SERVING_FOLDER + "/scene_colmaped", dirs_exist_ok=True + ) + (final_video_path, ply_path), log_output = capture_logs( + run_training_pipeline, scene_dir, num_ref_views, num_corrs, num_steps + ) images_rgb = [img[:, :, ::-1] for img in images] return images_rgb, final_video_path, scene_dir, ply_path, log_output + # Dummy Render Functions def render_all_views(scene_dir): viewpoint_cams = trainer.GS.scene.getTrainCameras() - path_cameras = generate_fully_smooth_cameras_with_tsp(existing_cameras=viewpoint_cams, - n_selected=8, - n_points_per_segment=60, - closed=False) + path_cameras = generate_fully_smooth_cameras_with_tsp( + existing_cameras=viewpoint_cams, + n_selected=8, + n_points_per_segment=60, + closed=False, + ) path_cameras = path_cameras + path_cameras[::-1] path_renderings = [] @@ -200,19 +268,21 @@ def render_all_views(scene_dir): image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1) image_np = (image_np * 255).astype(np.uint8) path_renderings.append(image_np) - save_numpy_frames_as_mp4(frames=path_renderings, - output_path=os.path.join(STATIC_FILE_SERVING_FOLDER, "render_all_views.mp4"), - fps=30, - center_crop=0.85) - + save_numpy_frames_as_mp4( + frames=path_renderings, + 
output_path=os.path.join(STATIC_FILE_SERVING_FOLDER, "render_all_views.mp4"), + fps=30, + center_crop=0.85, + ) + return os.path.join(STATIC_FILE_SERVING_FOLDER, "render_all_views.mp4") + def render_circular_path(scene_dir): viewpoint_cams = trainer.GS.scene.getTrainCameras() - path_cameras = generate_circular_camera_path(existing_cameras=viewpoint_cams, - N=240, - radius_scale=0.65, - d=0) + path_cameras = generate_circular_camera_path( + existing_cameras=viewpoint_cams, N=240, radius_scale=0.65, d=0 + ) path_renderings = [] with torch.no_grad(): @@ -222,78 +292,63 @@ def render_circular_path(scene_dir): image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1) image_np = (image_np * 255).astype(np.uint8) path_renderings.append(image_np) - save_numpy_frames_as_mp4(frames=path_renderings, - output_path=os.path.join(STATIC_FILE_SERVING_FOLDER, "render_circular_path.mp4"), - fps=30, - center_crop=0.85) - + save_numpy_frames_as_mp4( + frames=path_renderings, + output_path=os.path.join( + STATIC_FILE_SERVING_FOLDER, "render_circular_path.mp4" + ), + fps=30, + center_crop=0.85, + ) + return os.path.join(STATIC_FILE_SERVING_FOLDER, "render_circular_path.mp4") + # Download Functions def download_cameras(): path = os.path.join(MODEL_PATH, "cameras.json") return f"[đŸ“Ĩ Download Cameras.json](file={path})" + def download_model(): path = os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply") return f"[đŸ“Ĩ Download Pretrained Model (.ply)](file={path})" + # Full pipeline helpers def run_full_pipeline(input_path, num_ref_views, num_corrs, max_size=1024): tmpdirname = tempfile.mkdtemp() scene_dir = os.path.join(tmpdirname, "scene") os.makedirs(scene_dir, exist_ok=True) - selected_frames = process_input(input_path, num_ref_views, scene_dir, max_size) + selected_frames = process_input_for_colmap( + input_path, num_ref_views, scene_dir, max_size + ) run_colmap_on_scene(scene_dir) return selected_frames, scene_dir -# Preprocess Input -def process_input(input_path, num_ref_views, output_dir, max_size=1024): - if isinstance(input_path, (str, os.PathLike)): - if os.path.isdir(input_path): - frames = [] - for img_file in sorted(os.listdir(input_path)): - if img_file.lower().endswith(('jpg', 'jpeg', 'png')): - img = Image.open(os.path.join(output_dir, img_file)).convert('RGB') - img.thumbnail((1024, 1024)) - frames.append(np.array(img)) - else: - frames = read_video_frames(video_input=input_path, max_size=max_size) - else: - frames = read_video_frames(video_input=input_path, max_size=max_size) - - frames_scores = preprocess_frames(frames) - selected_frames_indices = select_optimal_frames(scores=frames_scores, - k=min(num_ref_views, len(frames))) - selected_frames = [frames[frame_idx] for frame_idx in selected_frames_indices] - - save_frames_to_scene_dir(frames=selected_frames, scene_dir=output_dir) - return selected_frames - -def preprocess_input(input_path, num_ref_views, max_size=1024): - tmpdirname = tempfile.mkdtemp() - scene_dir = os.path.join(tmpdirname, "scene") - os.makedirs(scene_dir, exist_ok=True) - selected_frames = process_input(input_path, num_ref_views, scene_dir, max_size) - run_colmap_on_scene(scene_dir) - return selected_frames, scene_dir def start_training(scene_dir, num_ref_views, num_corrs, num_steps): - return capture_logs(run_training_pipeline, scene_dir, num_ref_views, num_corrs, num_steps) - + return capture_logs( + run_training_pipeline, scene_dir, num_ref_views, num_corrs, num_steps + ) + # Gradio App with gr.Blocks() as demo: with gr.Row(): with 
gr.Column(scale=6): - gr.Markdown(""" + gr.Markdown( + """ ## 📄 EDGS: Eliminating Densification for Efficient Convergence of 3DGS 🔗 Project Page - """, elem_id="header") + """, + elem_id="header", + ) - gr.Markdown(""" + gr.Markdown( + """ ### đŸ› ī¸ How to Use This Demo 1. Upload a **front-facing video** or **a folder of images** of a **static** scene. @@ -306,37 +361,52 @@ def start_training(scene_dir, num_ref_views, num_corrs, num_steps): ✅ Best for scenes with small camera motion. ❗ For full 360° or large-scale scenes, we recommend the Colab version (see project page). - """, elem_id="quickstart") - + """, + elem_id="quickstart", + ) scene_dir_state = gr.State() ply_model_state = gr.State() with gr.Row(): with gr.Column(scale=2): - input_file = gr.File(label="Upload Video or Images", - file_types=[".mp4", ".avi", ".mov", ".png", ".jpg", ".jpeg"], - file_count="multiple") + input_file = gr.File( + label="Upload Video or Images", + file_types=[".mp4", ".avi", ".mov", ".png", ".jpg", ".jpeg"], + file_count="multiple", + ) gr.Examples( - examples = [ + examples=[ [["assets/examples/video_bakery.mp4"]], [["assets/examples/video_flowers.mp4"]], [["assets/examples/video_fruits.mp4"]], [["assets/examples/video_plant.mp4"]], [["assets/examples/video_salad.mp4"]], [["assets/examples/video_tram.mp4"]], - [["assets/examples/video_tulips.mp4"]] - ], + [["assets/examples/video_tulips.mp4"]], + ], inputs=[input_file], label="đŸŽžī¸ ALternatively, try an Example Video", - examples_per_page=4 + examples_per_page=4, + ) + ref_slider = gr.Slider( + 4, 32, value=16, step=1, label="Number of Reference Views" + ) + corr_slider = gr.Slider( + 5000, + 30000, + value=20000, + step=1000, + label="Correspondences per Reference View", + ) + fit_steps_slider = gr.Slider( + 100, 5000, value=400, step=100, label="Number of optimization steps" ) - ref_slider = gr.Slider(4, 32, value=16, step=1, label="Number of Reference Views") - corr_slider = gr.Slider(5000, 30000, value=20000, step=1000, label="Correspondences per Reference View") - fit_steps_slider = gr.Slider(100, 5000, value=400, step=100, label="Number of optimization steps") preprocess_button = gr.Button("📸 Preprocess Input") start_button = gr.Button("🚀 Start Reconstruction", interactive=False) - gallery = gr.Gallery(label="Selected Reference Views", columns=4, height=300) + gallery = gr.Gallery( + label="Selected Reference Views", columns=4, height=300 + ) with gr.Column(scale=3): gr.Markdown("### đŸ‹ī¸ Training Visualization") @@ -351,43 +421,104 @@ def start_training(scene_dir, num_ref_views, num_corrs, num_steps): gr.Markdown("### đŸ“Ļ Output Files") with gr.Row(height=50): with gr.Column(): - #gr.Markdown(value=f"[đŸ“Ĩ Download .ply](file/point_cloud_final.ply)") + # gr.Markdown(value=f"[đŸ“Ĩ Download .ply](file/point_cloud_final.ply)") download_cameras_button = gr.Button("đŸ“Ĩ Download Cameras.json") download_cameras_file = gr.File(label="📄 Cameras.json") with gr.Column(): - download_model_button = gr.Button("đŸ“Ĩ Download Pretrained Model (.ply)") + download_model_button = gr.Button( + "đŸ“Ĩ Download Pretrained Model (.ply)" + ) download_model_file = gr.File(label="📄 Pretrained Model (.ply)") log_output_box = gr.Textbox(label="đŸ–Ĩī¸ Log", lines=10, interactive=False) - def on_preprocess_click(input_file, num_ref_views): - images, scene_dir = preprocess_input(input_file, num_ref_views) - return gr.update(value=[x[...,::-1] for x in images]), scene_dir, gr.update(interactive=True) + def on_preprocess_click(input_file_obj, num_ref_views_val): + """ 
+ Handles the preprocess button click. + Calls the main preprocessing orchestrator and updates the UI. + """ + if input_file_obj is None: + # Handles case where no file is uploaded if input_file component is not required + # or if the user clears the selection. + # For gr.Examples, input_file_obj will be a list containing a list of paths. + # For direct upload with file_count="multiple", it's a list of file objects. + # For direct upload with file_count="single", it's a single file object. + # orchestrate_video_to_colmap_scene should be robust to these. + gr.Warning("Please upload a file or select an example.") + return None, None, gr.update(interactive=False) + + selected_bgr_frames, scene_dir = orchestrate_video_to_colmap_scene( + input_path=input_file_obj, # Pass the raw Gradio file object(s) + num_ref_views=num_ref_views_val, + max_size=1024, + base_work_dir="./gradio_processed_scenes", # Or configure as needed + ) + + if not scene_dir: # Indicates preprocessing failed + gr.Error("Preprocessing failed. Please check the logs or input file.") + return ( + None, + None, + gr.update(interactive=False), + ) # Keep gallery empty, scene_dir None, button disabled + + # Convert BGR numpy arrays to RGB for Gradio gallery. + # gr.Gallery can display a list of NumPy arrays (H, W, C) or PIL Images. + # Assuming selected_bgr_frames contains BGR NumPy arrays. + gallery_display_images = [] + if selected_bgr_frames: + gallery_display_images = [ + frame[..., ::-1] + for frame in selected_bgr_frames + if isinstance(frame, np.ndarray) + ] + + return ( + gr.update(value=gallery_display_images), + scene_dir, # Update the scene_dir_state + gr.update(interactive=True), # Enable the 'Start Reconstruction' button + ) def on_start_click(scene_dir, num_ref_views, num_corrs, num_steps): - (video_path, ply_path), logs = start_training(scene_dir, num_ref_views, num_corrs, num_steps) + (video_path, ply_path), logs = start_training( + scene_dir, num_ref_views, num_corrs, num_steps + ) return video_path, ply_path, logs preprocess_button.click( fn=on_preprocess_click, inputs=[input_file, ref_slider], - outputs=[gallery, scene_dir_state, start_button] + outputs=[gallery, scene_dir_state, start_button], ) start_button.click( fn=on_start_click, inputs=[scene_dir_state, ref_slider, corr_slider, fit_steps_slider], - outputs=[video_output, model3d_viewer, log_output_box] + outputs=[video_output, model3d_viewer, log_output_box], ) - render_all_views_button.click(fn=render_all_views, inputs=[scene_dir_state], outputs=[rendered_video_output]) - render_circular_path_button.click(fn=render_circular_path, inputs=[scene_dir_state], outputs=[rendered_video_output]) - - download_cameras_button.click(fn=lambda: os.path.join(MODEL_PATH, "cameras.json"), inputs=[], outputs=[download_cameras_file]) - download_model_button.click(fn=lambda: os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply"), inputs=[], outputs=[download_model_file]) + render_all_views_button.click( + fn=render_all_views, inputs=[scene_dir_state], outputs=[rendered_video_output] + ) + render_circular_path_button.click( + fn=render_circular_path, + inputs=[scene_dir_state], + outputs=[rendered_video_output], + ) + download_cameras_button.click( + fn=lambda: os.path.join(MODEL_PATH, "cameras.json"), + inputs=[], + outputs=[download_cameras_file], + ) + download_model_button.click( + fn=lambda: os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply"), + inputs=[], + outputs=[download_model_file], + ) - gr.Markdown(""" + gr.Markdown( + """ --- ### 📖 
Detailed Overview @@ -413,12 +544,22 @@ def on_start_click(scene_dir, num_ref_views, num_corrs, num_steps): --- Preloaded models coming soon. (TODO) - """, elem_id="details") + """, + elem_id="details", + ) if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Launch Gradio demo for EDGS preprocessing and 3D viewing.") - parser.add_argument("--port", type=int, default=7860, help="Port to launch the Gradio app on.") - parser.add_argument("--no_share", action='store_true', help="Disable Gradio sharing and assume local access (default: share=True)") + parser = argparse.ArgumentParser( + description="Launch Gradio demo for EDGS preprocessing and 3D viewing." + ) + parser.add_argument( + "--port", type=int, default=7860, help="Port to launch the Gradio app on." + ) + parser.add_argument( + "--no_share", + action="store_true", + help="Disable Gradio sharing and assume local access (default: share=True)", + ) args = parser.parse_args() demo.launch(server_name="0.0.0.0", server_port=args.port, share=not args.no_share) diff --git a/install.sh b/script/install.sh similarity index 100% rename from install.sh rename to script/install.sh diff --git a/metrics.py b/script/metrics.py similarity index 100% rename from metrics.py rename to script/metrics.py diff --git a/script/train.py b/script/train.py new file mode 100644 index 0000000..895c1e3 --- /dev/null +++ b/script/train.py @@ -0,0 +1,73 @@ +import os +import sys +from argparse import Namespace + +import hydra +import omegaconf +import wandb + +# Add the project root directory to sys.path +# so that modules from 'source' can be imported. +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +from source.trainer import EDGSTrainer +from source.utils_aux import set_seed + + +@hydra.main(config_path="../configs", config_name="train", version_base="1.2") +def main(cfg: omegaconf.DictConfig): + _ = wandb.init( + entity=cfg.wandb.entity, + project=cfg.wandb.project, + config=omegaconf.OmegaConf.to_container( + cfg, resolve=True, throw_on_missing=True + ), + tags=[cfg.wandb.tag], + name=cfg.wandb.name, + mode=cfg.wandb.mode, + ) + omegaconf.OmegaConf.resolve(cfg) + set_seed(cfg.seed) + + # Init output folder + print("Output folder: {}".format(cfg.gs.dataset.model_path)) + os.makedirs(cfg.gs.dataset.model_path, exist_ok=True) + with open(os.path.join(cfg.gs.dataset.model_path, "cfg_args"), "w") as cfg_log_f: + params = { + "sh_degree": 3, + "source_path": cfg.gs.dataset.source_path, + "model_path": cfg.gs.dataset.model_path, + "images": cfg.gs.dataset.images, + "depths": "", + "resolution": -1, + "_white_background": cfg.gs.dataset.white_background, + "train_test_exp": False, + "data_device": cfg.gs.dataset.data_device, + "eval": False, + "convert_SHs_python": False, + "compute_cov3D_python": False, + "debug": False, + "antialiasing": False, + } + cfg_log_f.write(str(Namespace(**params))) + + # Init both agents + gs = hydra.utils.instantiate(cfg.gs) + + # Init trainer and launch training + trainer = EDGSTrainer(GS=gs, training_config=cfg.gs.opt, device=cfg.device) + + trainer.load_checkpoints(cfg.load) + trainer.timer.start() + trainer.init_with_corr(cfg.init_wC) + trainer.train(cfg.train) + + # All done + wandb.finish() + print("\nTraining complete.") + + +if __name__ == "__main__": + main() diff --git a/source/corr_init.py b/source/corr_init.py index 09c94a2..bfd6985 100644 --- a/source/corr_init.py +++ b/source/corr_init.py @@ -787,6 
+787,11 @@ def init_gaussians_with_corr_fast(gaussians, scene, cfg, device, verbose=False, # Dummy first pass to initialize model with torch.no_grad(): + if len(viewpoint_stack) < 2: + print(f"âš ī¸ Warning: Only {len(viewpoint_stack)} viewpoints available. Need at least 2 for correspondence initialization.") + print("Skipping correspondence initialization - using SfM points only.") + return scene.train_cameras[3:], [], {} + viewpoint_cam1 = viewpoint_stack[0] viewpoint_cam2 = viewpoint_stack[1] imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0) diff --git a/source/utils_preprocess.py b/source/utils_preprocess.py index 7c6dab3..14da72b 100644 --- a/source/utils_preprocess.py +++ b/source/utils_preprocess.py @@ -1,41 +1,42 @@ # This file contains function for video or image collection preprocessing. # For video we do the preprocessing and select k sharpest frames. -# Afterwards scene is constructed +# Afterwards scene is constructed +import os +import time + import cv2 import numpy as np -from tqdm import tqdm import pycolmap -import os -import time -import tempfile -from moviepy import VideoFileClip from matplotlib import pyplot as plt +from moviepy import VideoFileClip from PIL import Image -import cv2 from tqdm import tqdm -WORKDIR = "../outputs/" - def get_rotation_moviepy(video_path): clip = VideoFileClip(video_path) rotation = 0 try: - displaymatrix = clip.reader.infos['inputs'][0]['streams'][2]['metadata'].get('displaymatrix', '') - if 'rotation of' in displaymatrix: - angle = float(displaymatrix.strip().split('rotation of')[-1].split('degrees')[0]) + displaymatrix = clip.reader.infos["inputs"][0]["streams"][2]["metadata"].get( + "displaymatrix", "" + ) + if "rotation of" in displaymatrix: + angle = float( + displaymatrix.strip().split("rotation of")[-1].split("degrees")[0] + ) rotation = int(angle) % 360 - + except Exception as e: print(f"No displaymatrix rotation found: {e}") clip.reader.close() - #if clip.audio: + # if clip.audio: # clip.audio.reader.close_proc() return rotation + def resize_max_side(frame, max_size): h, w = frame.shape[:2] scale = max_size / max(h, w) @@ -43,66 +44,138 @@ def resize_max_side(frame, max_size): frame = cv2.resize(frame, (int(w * scale), int(h * scale))) return frame -def read_video_frames(video_input, k=1, max_size=1024): + +def extract_video_frames_to_disk(video_input, output_dir, k=1, max_size=1024): """ - Extracts every k-th frame from a video or list of images, resizes to max size, and returns frames as list. + Extracts every k-th frame from a video using ffmpeg, saves to disk to avoid memory overflow. Parameters: video_input (str, file-like, or list): Path to video file, file-like object, or list of image files. + output_dir (str): Directory to save extracted frames. k (int): Interval for frame extraction (every k-th frame). max_size (int): Maximum size for width or height after resizing. Returns: - frames (list): List of resized frames (numpy arrays). + frame_paths (list): List of paths to extracted frame files. 
""" + import subprocess + import tempfile + import shutil + # Handle list of image files (not single video in a list) if isinstance(video_input, list): # If it's a single video in a list, treat it as video - if len(video_input) == 1 and video_input[0].name.endswith(('.mp4', '.avi', '.mov')): + if len(video_input) == 1 and video_input[0].name.endswith( + (".mp4", ".avi", ".mov") + ): video_input = video_input[0] # unwrap single video file else: - # Treat as list of images - frames = [] - for img_file in video_input: + # Treat as list of images - copy and resize them + frame_paths = [] + for idx, img_file in enumerate(video_input): img = Image.open(img_file.name).convert("RGB") - img.thumbnail((max_size, max_size)) - frames.append(np.array(img)[...,::-1]) - return frames + # Resize if necessary + width, height = img.size + if max(width, height) > max_size: + scale = max_size / max(width, height) + new_width = int(width * scale) + new_height = int(height * scale) + img = img.resize((new_width, new_height), Image.LANCZOS) + + output_path = os.path.join(output_dir, f"frame_{idx:08d}.jpg") + img.save(output_path, "JPEG", quality=95) + frame_paths.append(output_path) + return frame_paths # Handle file-like or path - if hasattr(video_input, 'name'): + if hasattr(video_input, "name"): video_path = video_input.name elif isinstance(video_input, (str, os.PathLike)): video_path = str(video_input) else: - raise ValueError("Unsupported video input type. Must be a filepath, file-like object, or list of images.") + raise ValueError( + "Unsupported video input type. Must be a filepath, file-like object, or list of images." + ) + # Create output directory + os.makedirs(output_dir, exist_ok=True) + # Use ffmpeg to extract frames + print(f"Extracting frames from video using ffmpeg...") + try: + # First, get video info to calculate frame interval + result = subprocess.run([ + 'ffprobe', '-v', 'quiet', '-count_frames', '-select_streams', 'v:0', + '-show_entries', 'stream=nb_frames', '-of', 'csv=p=0', video_path + ], capture_output=True, text=True, check=True) + + total_frames = int(result.stdout.strip()) + print(f"Total frames in video: {total_frames}") + + # Extract every k-th frame using ffmpeg with scaling + ffmpeg_cmd = [ + 'ffmpeg', '-i', video_path, '-y', + '-vf', f'select=not(mod(n\\,{k})),scale=w=min({max_size}\\,iw):h=min({max_size}\\,ih):force_original_aspect_ratio=decrease', + '-q:v', '2', # High quality + os.path.join(output_dir, 'frame_%08d.jpg') + ] + + subprocess.run(ffmpeg_cmd, check=True, capture_output=True) + + # Get list of extracted frame paths + frame_paths = sorted([ + os.path.join(output_dir, f) for f in os.listdir(output_dir) + if f.startswith('frame_') and f.endswith('.jpg') + ]) + + print(f"Extracted {len(frame_paths)} frames to {output_dir}") + return frame_paths + + except subprocess.CalledProcessError as e: + print(f"ffmpeg failed: {e}") + # Fallback to opencv if ffmpeg fails + return extract_video_frames_fallback(video_path, output_dir, k, max_size) + except FileNotFoundError: + print("ffmpeg not found, using opencv fallback") + return extract_video_frames_fallback(video_path, output_dir, k, max_size) + + +def extract_video_frames_fallback(video_path, output_dir, k=1, max_size=1024): + """ + Fallback method using opencv, but saves frames to disk instead of memory. 
+ """ cap = cv2.VideoCapture(video_path) if not cap.isOpened(): raise ValueError(f"Error: Could not open video {video_path}.") total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) frame_count = 0 - frames = [] + frame_paths = [] - with tqdm(total=total_frames // k, desc="Processing Video", unit="frame") as pbar: + os.makedirs(output_dir, exist_ok=True) + + with tqdm(total=total_frames // k, desc="Extracting Video Frames", unit="frame") as pbar: while True: ret, frame = cap.read() if not ret: break if frame_count % k == 0: - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + # Resize frame h, w = frame.shape[:2] scale = max(h, w) / max_size if scale > 1: frame = cv2.resize(frame, (int(w / scale), int(h / scale))) - frames.append(frame[...,[2,1,0]]) + + # Save frame to disk + frame_path = os.path.join(output_dir, f"frame_{len(frame_paths):08d}.jpg") + cv2.imwrite(frame_path, frame, [cv2.IMWRITE_JPEG_QUALITY, 95]) + frame_paths.append(frame_path) pbar.update(1) frame_count += 1 cap.release() - return frames + return frame_paths + def resize_max_side(frame, max_size): """ @@ -110,7 +183,7 @@ def resize_max_side(frame, max_size): """ height, width = frame.shape[:2] max_dim = max(height, width) - + if max_dim <= max_size: return frame # No need to resize @@ -118,41 +191,47 @@ def resize_max_side(frame, max_size): new_width = int(width * scale) new_height = int(height * scale) - resized_frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA) + resized_frame = cv2.resize( + frame, (new_width, new_height), interpolation=cv2.INTER_AREA + ) return resized_frame - def variance_of_laplacian(image): - # compute the Laplacian of the image and then return the focus - # measure, which is simply the variance of the Laplacian - return cv2.Laplacian(image, cv2.CV_64F).var() - -def process_all_frames(IMG_FOLDER = '/scratch/datasets/hq_data/night2_all_frames', - to_visualize=False, - save_images=True): + # compute the Laplacian of the image and then return the focus + # measure, which is simply the variance of the Laplacian + return cv2.Laplacian(image, cv2.CV_64F).var() + + +def process_all_frames( + IMG_FOLDER="/scratch/datasets/hq_data/night2_all_frames", + to_visualize=False, + save_images=True, +): dict_scores = {} - for idx, img_name in tqdm(enumerate(sorted([x for x in os.listdir(IMG_FOLDER) if '.png' in x]))): - - img = cv2.imread(os.path.join(IMG_FOLDER, img_name))#[250:, 100:] + for idx, img_name in tqdm( + enumerate(sorted([x for x in os.listdir(IMG_FOLDER) if ".png" in x])) + ): + img = cv2.imread(os.path.join(IMG_FOLDER, img_name)) # [250:, 100:] gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) - fm = variance_of_laplacian(gray) + \ - variance_of_laplacian(cv2.resize(gray, (0,0), fx=0.75, fy=0.75)) + \ - variance_of_laplacian(cv2.resize(gray, (0,0), fx=0.5, fy=0.5)) + \ - variance_of_laplacian(cv2.resize(gray, (0,0), fx=0.25, fy=0.25)) + fm = ( + variance_of_laplacian(gray) + + variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.75, fy=0.75)) + + variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.5, fy=0.5)) + + variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.25, fy=0.25)) + ) if to_visualize: plt.figure() plt.title(f"Laplacian score: {fm:.2f}") - plt.imshow(img[..., [2,1,0]]) + plt.imshow(img[..., [2, 1, 0]]) plt.show() - dict_scores[idx] = {"idx" : idx, - "img_name" : img_name, - "score" : fm} + dict_scores[idx] = {"idx": idx, "img_name": img_name, "score": fm} if save_images: dict_scores[idx]["img"] = img - + return dict_scores + def select_optimal_frames(scores, k): """ 
Selects a minimal subset of frames while ensuring no gaps exceed k. @@ -165,12 +244,14 @@ def select_optimal_frames(scores, k): list of int: Indices of selected frames. """ n = len(scores) - selected = [0, n-1] + selected = [0, n - 1] i = 0 # Start at the first frame while i < n: # Find the best frame to select within the next k frames - best_idx = max(range(i, min(i + k + 1, n)), key=lambda x: scores[x], default=None) + best_idx = max( + range(i, min(i + k + 1, n)), key=lambda x: scores[x], default=None + ) if best_idx is None: break # No more frames left @@ -187,9 +268,49 @@ def variance_of_laplacian(image): """ return cv2.Laplacian(image, cv2.CV_64F).var() + +def preprocess_frame_paths(frame_paths, verbose=False): + """ + Compute sharpness scores for a list of frame files using multi-scale Laplacian variance. + + Args: + frame_paths (list of str): List of paths to frame image files. + verbose (bool): If True, print scores. + + Returns: + list of float: Sharpness scores for each frame. + """ + scores = [] + + for idx, frame_path in enumerate(tqdm(frame_paths, desc="Scoring frames")): + # Load frame from disk + frame = cv2.imread(frame_path) + if frame is None: + print(f"Warning: Could not load frame {frame_path}") + scores.append(0.0) + continue + + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + + fm = ( + variance_of_laplacian(gray) + + variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.75, fy=0.75)) + + variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.5, fy=0.5)) + + variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.25, fy=0.25)) + ) + + if verbose: + print(f"Frame {idx} ({os.path.basename(frame_path)}): Sharpness Score = {fm:.2f}") + + scores.append(fm) + + return scores + + def preprocess_frames(frames, verbose=False): """ Compute sharpness scores for a list of frames using multi-scale Laplacian variance. + DEPRECATED: Use preprocess_frame_paths instead to avoid memory issues. Args: frames (list of np.ndarray): List of frames (BGR images). @@ -204,12 +325,12 @@ def preprocess_frames(frames, verbose=False): gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) fm = ( - variance_of_laplacian(gray) + - variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.75, fy=0.75)) + - variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.5, fy=0.5)) + - variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.25, fy=0.25)) + variance_of_laplacian(gray) + + variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.75, fy=0.75)) + + variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.5, fy=0.5)) + + variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.25, fy=0.25)) ) - + if verbose: print(f"Frame {idx}: Sharpness Score = {fm:.2f}") @@ -217,6 +338,7 @@ def preprocess_frames(frames, verbose=False): return scores + def select_optimal_frames(scores, k): """ Selects k frames by splitting into k segments and picking the sharpest frame from each. @@ -226,7 +348,7 @@ def select_optimal_frames(scores, k): k (int): Number of frames to select. Returns: - list of int: Indices of selected frames. + list of int: Indices of selected frames. 
""" n = len(scores) selected_indices = [] @@ -236,18 +358,42 @@ def select_optimal_frames(scores, k): start = i * segment_size end = (i + 1) * segment_size if i < k - 1 else n # Last chunk may be larger segment_scores = scores[start:end] - + if len(segment_scores) == 0: continue # Safety check if some segment is empty - + best_in_segment = start + np.argmax(segment_scores) selected_indices.append(best_in_segment) return sorted(selected_indices) + +def copy_selected_frames_to_scene_dir(selected_frame_paths, scene_dir): + """ + Copies selected frame files into the target scene directory under 'images/' subfolder. + + Args: + selected_frame_paths (list of str): List of paths to selected frame files. + scene_dir (str): Target path where 'images/' subfolder will be created. + """ + import shutil + + images_dir = os.path.join(scene_dir, "images") + os.makedirs(images_dir, exist_ok=True) + + for idx, frame_path in enumerate(selected_frame_paths): + filename = os.path.join( + images_dir, f"{idx:08d}.jpg" + ) # 00000000.jpg, 00000001.jpg, etc. + shutil.copy2(frame_path, filename) + + print(f"Copied {len(selected_frame_paths)} selected frames to {images_dir}") + + def save_frames_to_scene_dir(frames, scene_dir): """ Saves a list of frames into the target scene directory under 'images/' subfolder. + DEPRECATED: Use copy_selected_frames_to_scene_dir to avoid memory issues. Args: frames (list of np.ndarray): List of frames (BGR images) to save. @@ -257,19 +403,88 @@ def save_frames_to_scene_dir(frames, scene_dir): os.makedirs(images_dir, exist_ok=True) for idx, frame in enumerate(frames): - filename = os.path.join(images_dir, f"{idx:08d}.png") # 00000000.png, 00000001.png, etc. + filename = os.path.join( + images_dir, f"{idx:08d}.png" + ) # 00000000.png, 00000001.png, etc. cv2.imwrite(filename, frame) print(f"Saved {len(frames)} frames to {images_dir}") -def run_colmap_on_scene(scene_dir): +def create_fallback_reconstruction(image_dir, sparse_path): + """ + Create a minimal fallback reconstruction when COLMAP fails completely. + Creates a simple linear camera trajectory for the available images. 
+ """ + # No need to import colmap_loader - we'll create text files directly + + print("🔧 Creating fallback reconstruction with assumed camera positions...") + + # Get list of images + image_files = sorted([f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]) + + if len(image_files) < 1: + raise RuntimeError("No images found for fallback reconstruction") + + # Read first image to get dimensions + first_img_path = os.path.join(image_dir, image_files[0]) + from PIL import Image + img = Image.open(first_img_path) + width, height = img.size + + # Create minimal reconstruction directory + fallback_dir = os.path.join(sparse_path, "0") + os.makedirs(fallback_dir, exist_ok=True) + + # Create cameras.txt with simple pinhole model + cameras_txt = os.path.join(fallback_dir, "cameras.txt") + focal = max(width, height) # Simple focal length estimation + with open(cameras_txt, 'w') as f: + f.write("# Camera list with one line of data per camera:\n") + f.write("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n") + f.write(f"1 PINHOLE {width} {height} {focal} {focal} {width/2} {height/2}\n") + + # Create images.txt with linear trajectory + images_txt = os.path.join(fallback_dir, "images.txt") + with open(images_txt, 'w') as f: + f.write("# Image list with two lines of data per image:\n") + f.write("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n") + f.write("# POINTS2D[] as (X, Y, POINT3D_ID)\n") + + for i, img_file in enumerate(image_files): + # Simple linear trajectory along Z-axis + tx, ty, tz = 0.0, 0.0, -i * 0.5 + # Identity quaternion (no rotation) + qw, qx, qy, qz = 1.0, 0.0, 0.0, 0.0 + + f.write(f"{i+1} {qw} {qx} {qy} {qz} {tx} {ty} {tz} 1 {img_file}\n") + f.write("\n") # Empty line for points2D + + # Create minimal points3D.txt with a few dummy points + points_txt = os.path.join(fallback_dir, "points3D.txt") + with open(points_txt, 'w') as f: + f.write("# 3D point list with one line of data per point:\n") + f.write("# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n") + # Add some dummy 3D points for basic initialization + for i in range(10): + x, y, z = i * 0.1, 0.0, -1.0 # Simple grid of points + r, g, b = 128, 128, 128 # Gray color + error = 1.0 + f.write(f"{i+1} {x} {y} {z} {r} {g} {b} {error}\n") + + print(f"✅ Created fallback reconstruction with {len(image_files)} cameras at {fallback_dir}") + print("âš ī¸ Note: This is a basic reconstruction with assumed camera positions. Results may be limited.") + + +def run_colmap_on_scene(scene_dir, force_pinhole=True): """ Runs feature extraction, matching, and mapping on all images inside scene_dir/images using pycolmap. + Forces PINHOLE camera model to avoid distortion issues. Args: scene_dir (str): Path to scene directory containing 'images' folder. - + force_pinhole (bool): If True, forces PINHOLE camera model during reconstruction. 
+ TODO: if the function hasn't managed to match all the frames either increase image size, increase number of features or just remove those frames from the folder scene_dir/images """ @@ -280,55 +495,384 @@ def run_colmap_on_scene(scene_dir): database_path = os.path.join(scene_dir, "database.db") sparse_path = os.path.join(scene_dir, "sparse") image_dir = os.path.join(scene_dir, "images") - + # Make sure output directories exist os.makedirs(sparse_path, exist_ok=True) - # Step 1: Feature Extraction + # Step 1: Feature Extraction with more aggressive settings pycolmap.extract_features( database_path, image_dir, sift_options={ - "max_num_features": 512 * 2, - "max_image_size": 512 * 1, - } + "max_num_features": 8192, # Much higher feature count + "max_image_size": 1600, # Higher resolution + "first_octave": -1, # More detailed features + "num_octaves": 4, + "octave_resolution": 3, + "peak_threshold": 0.005, # More lenient peak detection + "edge_threshold": 20, # More lenient edge threshold + }, ) print(f"Finished feature extraction in {(time.time() - start_time):.2f}s.") - # Step 2: Feature Matching - pycolmap.match_exhaustive(database_path) + # Step 2: Feature Matching with correct API + sift_matching_options = pycolmap.SiftMatchingOptions() + sift_matching_options.max_ratio = 0.9 + sift_matching_options.max_distance = 0.8 + sift_matching_options.cross_check = True + + pycolmap.match_exhaustive(database_path, sift_options=sift_matching_options) print(f"Finished feature matching in {(time.time() - start_time):.2f}s.") - # Step 3: Mapping + # Step 3: Mapping with more lenient parameters for challenging videos pipeline_options = pycolmap.IncrementalPipelineOptions() - pipeline_options.min_num_matches = 15 + pipeline_options.min_num_matches = 8 # Lower minimum matches pipeline_options.multiple_models = True pipeline_options.max_num_models = 50 pipeline_options.max_model_overlap = 20 - pipeline_options.min_model_size = 10 + pipeline_options.min_model_size = 3 # Allow smaller models pipeline_options.extract_colors = True pipeline_options.num_threads = 8 - pipeline_options.mapper.init_min_num_inliers = 30 - pipeline_options.mapper.init_max_error = 8.0 - pipeline_options.mapper.init_min_tri_angle = 5.0 - - reconstruction = pycolmap.incremental_mapping( - database_path=database_path, - image_path=image_dir, - output_path=sparse_path, - options=pipeline_options, - ) - print(f"Finished incremental mapping in {(time.time() - start_time):.2f}s.") + + # More lenient mapper options + pipeline_options.mapper.init_min_num_inliers = 15 # Lower inlier threshold + pipeline_options.mapper.init_max_error = 12.0 # Higher error tolerance + pipeline_options.mapper.init_min_tri_angle = 2.0 # Lower triangulation angle + pipeline_options.mapper.abs_pose_min_num_inliers = 15 + pipeline_options.mapper.abs_pose_max_error = 12.0 + pipeline_options.mapper.filter_max_reproj_error = 8.0 + pipeline_options.mapper.filter_min_tri_angle = 1.5 + + # Note: force_pinhole will be applied after reconstruction - # Step 4: Post-process Cameras to SIMPLE_PINHOLE + try: + reconstruction = pycolmap.incremental_mapping( + database_path=database_path, + image_path=image_dir, + output_path=sparse_path, + options=pipeline_options, + ) + print(f"Finished incremental mapping in {(time.time() - start_time):.2f}s.") + except Exception as e: + print(f"âš ī¸ Initial reconstruction failed: {e}") + print("🔄 Trying with even more lenient settings...") + + # Try with ultra-lenient settings as fallback + pipeline_options.min_num_matches = 5 + 
pipeline_options.min_model_size = 2 + pipeline_options.mapper.init_min_num_inliers = 10 + pipeline_options.mapper.init_max_error = 20.0 + pipeline_options.mapper.init_min_tri_angle = 1.0 + + try: + reconstruction = pycolmap.incremental_mapping( + database_path=database_path, + image_path=image_dir, + output_path=sparse_path, + options=pipeline_options, + ) + print(f"✅ Fallback reconstruction succeeded in {(time.time() - start_time):.2f}s.") + except Exception as e2: + print(f"❌ Both reconstruction attempts failed: {e2}") + raise RuntimeError("COLMAP reconstruction failed. The video might have insufficient overlap or features.") + + # Step 4: Check if reconstruction was successful recon_path = os.path.join(sparse_path, "0") - reconstruction = pycolmap.Reconstruction(recon_path) - - for cam in reconstruction.cameras.values(): - cam.model = 'SIMPLE_PINHOLE' - cam.params = cam.params[:3] # Keep only [f, cx, cy] + if not os.path.exists(recon_path): + # Check for other reconstruction indices + reconstructions_found = [] + for i in range(10): # Check indices 0-9 + alt_path = os.path.join(sparse_path, str(i)) + if os.path.exists(alt_path) and any(os.path.exists(os.path.join(alt_path, f)) + for f in ["cameras.bin", "images.bin", "points3D.bin"]): + reconstructions_found.append(i) + + if reconstructions_found: + # Use the largest reconstruction + best_idx = max(reconstructions_found) + recon_path = os.path.join(sparse_path, str(best_idx)) + print(f"â„šī¸ Using reconstruction {best_idx} instead of 0") + + # Move to index 0 for compatibility + target_path = os.path.join(sparse_path, "0") + if not os.path.exists(target_path): + import shutil + shutil.move(recon_path, target_path) + recon_path = target_path + print(f"📁 Moved reconstruction to sparse/0/") + else: + print("❌ COLMAP reconstruction failed - creating minimal fallback reconstruction") + return create_fallback_reconstruction(image_dir, sparse_path) - reconstruction.write(recon_path) + # Step 5: Convert cameras to PINHOLE if needed + if os.path.exists(recon_path): + reconstruction = pycolmap.Reconstruction(recon_path) + + if len(reconstruction.cameras) == 0: + raise RuntimeError("❌ Reconstruction contains no cameras") + if len(reconstruction.images) == 0: + raise RuntimeError("❌ Reconstruction contains no images") + if len(reconstruction.points3D) == 0: + print("âš ī¸ Warning: Reconstruction contains no 3D points") + + for cam in reconstruction.cameras.values(): + if force_pinhole and cam.model != "PINHOLE": + print(f"Converting camera {cam.camera_id} from {cam.model} to PINHOLE") + cam.model = "PINHOLE" + # Ensure we have exactly 4 parameters [fx, fy, cx, cy] + if len(cam.params) >= 4: + cam.params = cam.params[:4] + elif len(cam.params) >= 3: + # Duplicate focal length if we only have 3 params + f, cx, cy = cam.params[:3] + cam.params = [f, f, cx, cy] + else: + # Default values if params are insufficient + focal = max(cam.width, cam.height) + cam.params = [focal, focal, cam.width/2, cam.height/2] + + reconstruction.write(recon_path) + print(f"✅ Saved reconstruction with PINHOLE cameras to {recon_path}") + print(f"📊 Reconstruction stats: {len(reconstruction.cameras)} cameras, {len(reconstruction.images)} images, {len(reconstruction.points3D)} points") print(f"Total pipeline time: {(time.time() - start_time):.2f}s.") + +def process_input_for_colmap(input_path, num_ref_views, output_dir, max_size=1024): + """ + Memory-efficient helper function to process video/images, select optimal frames, + and save them to the output_dir/images without 
loading all frames into memory. + """ + import tempfile + import shutil + + # Create temporary directory for extracted frames + temp_frames_dir = tempfile.mkdtemp(prefix="edgs_frames_") + + try: + if isinstance(input_path, (str, os.PathLike)): # If input_path is a path string + if os.path.isdir(input_path): # If it's a directory of images + print(f"Processing image directory: {input_path}") + # Copy and resize images to temp directory + frame_paths = [] + image_files = sorted([ + f for f in os.listdir(input_path) + if f.lower().endswith(("jpg", "jpeg", "png")) + ]) + + for idx, img_file in enumerate(image_files): + img = Image.open(os.path.join(input_path, img_file)).convert("RGB") + # Resize if necessary + width, height = img.size + if max(width, height) > max_size: + scale = max_size / max(width, height) + new_width = int(width * scale) + new_height = int(height * scale) + img = img.resize((new_width, new_height), Image.LANCZOS) + + output_path = os.path.join(temp_frames_dir, f"frame_{idx:08d}.jpg") + img.save(output_path, "JPEG", quality=95) + frame_paths.append(output_path) + + else: # If it's a single video file path + print(f"Processing video file: {input_path}") + frame_paths = extract_video_frames_to_disk( + video_input=input_path, + output_dir=temp_frames_dir, + max_size=max_size + ) + elif hasattr(input_path, "name"): # File-like object (e.g., from Gradio upload) + print(f"Processing uploaded video file: {input_path.name}") + frame_paths = extract_video_frames_to_disk( + video_input=input_path, + output_dir=temp_frames_dir, + max_size=max_size + ) + else: + raise ValueError(f"Unsupported input_path type: {type(input_path)}") + + if not frame_paths: + print("No frames extracted or read.") + return [] + + # Score frames without loading them all into memory + print(f"Scoring {len(frame_paths)} frames...") + frames_scores = preprocess_frame_paths(frame_paths) + + # Select optimal frames + selected_frames_indices = select_optimal_frames( + scores=frames_scores, k=min(num_ref_views, len(frame_paths)) + ) + + # Get paths to selected frames + selected_frame_paths = [frame_paths[idx] for idx in selected_frames_indices] + + print(f"Selected {len(selected_frame_paths)} optimal frames out of {len(frame_paths)}") + + # Copy selected frames to scene directory + copy_selected_frames_to_scene_dir(selected_frame_paths, output_dir) + + # Return empty list since we're not loading frames into memory anymore + # The actual frames are saved to disk in the scene directory + return [] + + finally: + # Clean up temporary directory + if os.path.exists(temp_frames_dir): + shutil.rmtree(temp_frames_dir) + print(f"Cleaned up temporary frame directory: {temp_frames_dir}") + + +def process_input_for_colmap_legacy(input_path, num_ref_views, output_dir, max_size=1024): + """ + DEPRECATED: Original memory-intensive version. + Helper function to read frames from video or image folder, select optimal ones, + and save them to the output_dir/images. + This is based on process_input from gradio_demo.py. + Renamed to avoid potential confusion if 'process_input' is too generic. 
+ """ + frames_to_save_in_scene_dir = [] + if isinstance(input_path, (str, os.PathLike)): # If input_path is a path string + if os.path.isdir(input_path): # If it's a directory of images + print(f"Processing image directory: {input_path}") + raw_frames = [] + image_files = sorted( + [ + f + for f in os.listdir(input_path) + if f.lower().endswith(("jpg", "jpeg", "png")) + ] + ) + for img_file in image_files: + img = Image.open(os.path.join(input_path, img_file)).convert("RGB") + # Resize if necessary, similar to video frames + width, height = img.size + if max(width, height) > max_size: + scale = max_size / max(width, height) + new_width = int(width * scale) + new_height = int(height * scale) + img = img.resize((new_width, new_height), Image.LANCZOS) + raw_frames.append(np.array(img)) + else: # If it's a single video file path + print(f"Processing video file: {input_path}") + raw_frames = read_video_frames(video_input=input_path, max_size=max_size) + elif hasattr( + input_path, "name" + ): # If input_path is a file-like object (e.g., from Gradio upload) + print(f"Processing uploaded video file: {input_path.name}") + raw_frames = read_video_frames(video_input=input_path.name, max_size=max_size) + else: + raise ValueError(f"Unsupported input_path type: {type(input_path)}") + + if not raw_frames: + print("No frames extracted or read.") + return [] + + frames_scores = preprocess_frames( + raw_frames + ) # Assuming preprocess_frames takes list of numpy arrays + selected_frames_indices = select_optimal_frames( + scores=frames_scores, k=min(num_ref_views, len(raw_frames)) + ) + frames_to_save_in_scene_dir = [ + raw_frames[frame_idx] for frame_idx in selected_frames_indices + ] + + # The 'output_dir' here is the scene_dir where 'images' subfolder will be created + save_frames_to_scene_dir(frames=frames_to_save_in_scene_dir, scene_dir=output_dir) + return frames_to_save_in_scene_dir # Returns the list of selected frame data (numpy arrays) + + +def orchestrate_video_to_colmap_scene( + input_path, + num_ref_views, + max_size=1024, + base_work_dir="../outputs/processed_scenes", +): + """ + Orchestrates the full video/image folder preprocessing pipeline: + 1. Creates a temporary scene directory. + 2. Reads frames, selects optimal ones, saves them. + 3. Runs COLMAP on the scene. + Args: + input_path (str or file-like): Path string, a Gradio file object, or a list (e.g., from gr.Examples). + num_ref_views (int): Number of reference views to select. + max_size (int): Maximum size for width or height after resizing. + base_work_dir (str): Base directory for temporary scene directories. + Returns: + the list of selected frame image data and the path to the COLMAP processed scene directory. + This is based on preprocess_input from gradio_demo.py. + """ + actual_input_path_str = None + input_name_part = "temp_scene" # Default + + if hasattr(input_path, "name") and isinstance( + input_path.name, str + ): # Gradio file object + actual_input_path_str = input_path.name + input_name_part = os.path.splitext(os.path.basename(input_path.name))[0] + elif isinstance(input_path, (str, os.PathLike)): # Direct path string + actual_input_path_str = str(input_path) + input_name_part = os.path.splitext(os.path.basename(input_path))[0] + elif ( + isinstance(input_path, list) and input_path + ): # Handle list: take the first item. + # gr.Examples often wraps the path in another list, e.g., [['path/to/example.mp4']] + # So, we might need to unwrap it. 
+ first_item_candidate = input_path[0] + if ( + isinstance(first_item_candidate, list) and first_item_candidate + ): # Check for nested list + first_item = first_item_candidate[0] + else: + first_item = first_item_candidate + + if hasattr(first_item, "name") and isinstance( + first_item.name, str + ): # Gradio file object in list + actual_input_path_str = first_item.name + input_name_part = os.path.splitext(os.path.basename(first_item.name))[0] + elif isinstance(first_item, (str, os.PathLike)): # Path string in list + actual_input_path_str = str(first_item) + input_name_part = os.path.splitext(os.path.basename(first_item))[0] + else: + print(f"Warning: Unsupported item type in input list: {type(first_item)}") + return [], None + else: + print(f"Error: Unsupported input_path type: {type(input_path)}") + return [], None + + if not actual_input_path_str: + print("Error: Could not determine a valid input file path.") + return [], None + + print(f"Orchestrating COLMAP scene from: {actual_input_path_str}") + + # Using a structured output directory instead of pure tempfile.mkdtemp for easier inspection + # scene_dir_parent = tempfile.mkdtemp() # Original approach + + # Ensure base_work_dir exists + os.makedirs(base_work_dir, exist_ok=True) + # Create a unique subdirectory within base_work_dir + timestamp = time.strftime("%Y%m%d-%H%M%S") + scene_dir = os.path.join(base_work_dir, f"{input_name_part}_{timestamp}") + + os.makedirs(scene_dir, exist_ok=True) + print(f"Created scene directory for COLMAP: {scene_dir}") + + # Process video/images to extract and select optimal frames + selected_frames_data = process_input_for_colmap( + actual_input_path_str, num_ref_views, scene_dir, max_size + ) + + # Check if images were saved to scene directory + images_dir = os.path.join(scene_dir, "images") + if not os.path.exists(images_dir) or not os.listdir(images_dir): + print(f"Frame processing failed for {input_path}. No images found in {images_dir}. 
Aborting COLMAP.") + return [], None + + # Run COLMAP with PINHOLE camera model enforced + run_colmap_on_scene(scene_dir, force_pinhole=True) # Force PINHOLE to avoid distortion + + print(f"COLMAP processing complete for {scene_dir}") + return selected_frames_data, scene_dir diff --git a/train.py b/train.py deleted file mode 100644 index 646409a..0000000 --- a/train.py +++ /dev/null @@ -1,63 +0,0 @@ -import os -from source.trainer import EDGSTrainer -from source.utils_aux import set_seed -import omegaconf -import wandb -import hydra -from argparse import Namespace -from omegaconf import OmegaConf - - -@hydra.main(config_path="configs", config_name="train", version_base="1.2") -def main(cfg: omegaconf.DictConfig): - _ = wandb.init(entity=cfg.wandb.entity, - project=cfg.wandb.project, - config=omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True), - tags=[cfg.wandb.tag], - name = cfg.wandb.name, - mode = cfg.wandb.mode) - omegaconf.OmegaConf.resolve(cfg) - set_seed(cfg.seed) - - # Init output folder - print("Output folder: {}".format(cfg.gs.dataset.model_path)) - os.makedirs(cfg.gs.dataset.model_path, exist_ok=True) - with open(os.path.join(cfg.gs.dataset.model_path, "cfg_args"), 'w') as cfg_log_f: - params = { - "sh_degree": 3, - "source_path": cfg.gs.dataset.source_path, - "model_path": cfg.gs.dataset.model_path, - "images": cfg.gs.dataset.images, - "depths": "", - "resolution": -1, - "_white_background": cfg.gs.dataset.white_background, - "train_test_exp": False, - "data_device": cfg.gs.dataset.data_device, - "eval": False, - "convert_SHs_python": False, - "compute_cov3D_python": False, - "debug": False, - "antialiasing": False - } - cfg_log_f.write(str(Namespace(**params))) - - # Init both agents - gs = hydra.utils.instantiate(cfg.gs) - - # Init trainer and launch training - trainer = EDGSTrainer(GS=gs, - training_config=cfg.gs.opt, - device=cfg.device) - - trainer.load_checkpoints(cfg.load) - trainer.timer.start() - trainer.init_with_corr(cfg.init_wC) - trainer.train(cfg.train) - - # All done - wandb.finish() - print("\nTraining complete.") - -if __name__ == "__main__": - main() -
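
The preprocessing rewrite above scores candidate frames with a multi-scale variance-of-Laplacian sharpness measure and then keeps the sharpest frame from each of k segments (see `preprocess_frame_paths` and `select_optimal_frames` in `source/utils_preprocess.py`). Below is a minimal, self-contained sketch of that selection step; the helper names `sharpness` and `pick_sharpest_per_segment` are illustrative and not part of this diff.

```python
# Sketch of the sharpness scoring + per-segment selection used in the diff.
import cv2
import numpy as np

def sharpness(gray):
    # Variance of the Laplacian at several scales; higher means sharper.
    return sum(
        cv2.Laplacian(cv2.resize(gray, (0, 0), fx=s, fy=s), cv2.CV_64F).var()
        for s in (1.0, 0.75, 0.5, 0.25)
    )

def pick_sharpest_per_segment(frame_paths, k):
    # Score every frame from disk, then take the best frame of each of k chunks.
    scores = []
    for path in frame_paths:
        img = cv2.imread(path)
        scores.append(0.0 if img is None else sharpness(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)))
    n = len(scores)
    segment = n // k
    picked = []
    for i in range(k):
        start = i * segment
        end = (i + 1) * segment if i < k - 1 else n  # last chunk may be larger
        if end > start:
            picked.append(start + int(np.argmax(scores[start:end])))
    return sorted(picked)
```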
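
For callers outside the Gradio demo, the new helpers compose into a single call: `orchestrate_video_to_colmap_scene` extracts and scores frames, copies the selected ones into `<scene_dir>/images`, and runs `run_colmap_on_scene(..., force_pinhole=True)`. A hedged usage sketch follows; only the function name and signature come from this diff, while the video path, view count, and output directory are example assumptions.

```python
# Hypothetical end-to-end use of the helpers added in source/utils_preprocess.py.
from source.utils_preprocess import orchestrate_video_to_colmap_scene

frames, scene_dir = orchestrate_video_to_colmap_scene(
    "my_video.mp4",                           # video file or image folder (assumed example)
    num_ref_views=16,                         # number of sharp frames to keep
    max_size=1024,                            # longest image side after resizing
    base_work_dir="outputs/processed_scenes", # where the COLMAP scene is created
)
if scene_dir is None:
    raise RuntimeError("Preprocessing or COLMAP reconstruction failed.")
print(f"COLMAP scene ready at {scene_dir}")
```

Assuming the Hydra keys keep the names used in `script/train.py`, the resulting directory can then be passed to the relocated entry point, e.g. `python script/train.py gs.dataset.source_path=<scene_dir>`.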
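
Because `run_colmap_on_scene` may fall back to ultra-lenient mapping, or to the text-format model written by `create_fallback_reconstruction`, it can be worth inspecting whatever landed in `sparse/0` before training. A small sketch, reusing the same `pycolmap.Reconstruction` API the diff already calls; the scene path is a placeholder, and text-format models should load through the same constructor.

```python
# Sanity-check the reconstruction produced under <scene_dir>/sparse/0.
import os
import pycolmap

scene_dir = "outputs/processed_scenes/my_video_20250101-120000"  # placeholder path
rec = pycolmap.Reconstruction(os.path.join(scene_dir, "sparse", "0"))
print(f"{len(rec.cameras)} cameras, {len(rec.images)} images, {len(rec.points3D)} 3D points")
for cam in rec.cameras.values():
    # After force_pinhole=True, cameras are expected to use the PINHOLE model.
    print(f"  camera {cam.camera_id}: model={cam.model}, params={list(cam.params)}")
```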