Only Blue CSE vertex on Cat Image #3457

Open
kingsj0405 opened this issue Sep 7, 2021 · 17 comments
Labels
densepose issues specific to densepose

Comments

kingsj0405 commented Sep 7, 2021

Instructions To Reproduce the 🐛 Bug:

  1. Full runnable code or full changes you made:
    No Change
  2. What exact command you run:
python apply_net.py show \
    configs/cse/densepose_rcnn_R_50_FPN_soft_animals_I0_finetune_m2m_16k.yaml \
    https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_soft_animals_finetune_maskonly_24k/267687159/model_final_354e61.pkl \
    image_examples \
    dp_vertex,bbox \
    -v \
    --output outputs/cse_animal2_test.png
  3. Full logs or other relevant observations:
    Only blue vertices on the animal
    [attached result images: cse_animal2_test 0005, 0006, 0002, 0004]

  4. Please simplify the steps as much as possible so they do not require additional resources to
    run, such as a private dataset.

Run the above command with the following input examples
image_examples.zip

Expected behavior:

Rainbow CSE vertex on the cat
[attached example image: csv_example]

Environment:

----------------------  --------------------------------------------------------------------------------------
sys.platform            linux
Python                  3.9.6 (default, Aug 18 2021, 19:38:01) [GCC 7.5.0]
numpy                   1.20.3
detectron2              0.5 @/host/media/sejongyang/AnimalFaceReenactment/3dmm_stylegan2/detectron2/detectron2
Compiler                GCC 7.5
CUDA compiler           CUDA 10.0
detectron2 arch flags   6.1
DETECTRON2_ENV_MODULE   <not set>
PyTorch                 1.9.0 @/root/miniconda3/envs/detectron2/lib/python3.9/site-packages/torch
PyTorch debug build     False
GPU available           Yes
GPU 0,1                 GeForce GTX 1080 Ti (arch=6.1)
CUDA_HOME               /usr/local/cuda
Pillow                  8.3.1
torchvision             0.10.0 @/root/miniconda3/envs/detectron2/lib/python3.9/site-packages/torchvision
torchvision arch flags  3.5, 5.0, 6.0, 7.0, 7.5
fvcore                  0.1.5.post20210825
iopath                  0.1.9
cv2                     4.5.3
----------------------  --------------------------------------------------------------------------------------
PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) oneAPI Math Kernel Library Version 2021.3-Product Build 20210617 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.1.2 (Git Hash 98be7e8afa711dc9b66c8ff3504129cb82013cdb)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 10.2
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_37,code=compute_37
  - CuDNN 7.6.5
  - Magma 2.5.2
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=10.2, CUDNN_VERSION=7.6.5, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.9.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, 
kingsj0405 changed the title from "Please read & provide the following" to "Only Blue CSE vertex on Cat Image" on Sep 7, 2021
ppwwyyxx added the densepose label (issues specific to densepose) on Sep 8, 2021
kingsj0405 (Author) commented Sep 26, 2021

It turns out that every mesh file provided by DensePose has all-zero vertices.

This Python code loads the cat mesh used for Continuous Surface Embeddings:

import pickle

from detectron2.utils.file_io import PathManager

with PathManager.open("https://dl.fbaipublicfiles.com/densepose/meshes/cat_7466.pkl", "rb") as hFile:
    data = pickle.load(hFile)
    print(data)

But every vertex has zero values:

{'vertices': array([[ 0.,  0.,  0.],
       [ 0.,  0.,  0.],
       [ 0.,  0.,  0.],
       ...,
       [-0.,  0.,  0.],
       [-0.,  0.,  0.],
       [-0.,  0.,  0.]]), 'faces': array([[0, 0, 0],
       [0, 0, 0],
       [0, 0, 0],
       ...,
       [0, 0, 0],
       [0, 0, 0],
       [0, 0, 0]])}
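
A quick way to verify this for all the registered meshes (my own rough sketch, not part of the repo; the mesh names are taken from the builtin.py catalog quoted in my next comment):

# Sketch: download each mesh pickle from the public URL and report whether
# its vertices are all zeros. Names come from builtin.py (see below).
import pickle

import numpy as np
from detectron2.utils.file_io import PathManager

MESHES_DIR = "https://dl.fbaipublicfiles.com/densepose/meshes/"
NAMES = [
    "smpl_27554", "chimp_5029", "cat_5001", "cat_7466", "sheep_5004",
    "zebra_5002", "horse_5004", "giraffe_5002", "elephant_5002",
    "dog_5002", "dog_7466", "cow_5002", "bear_4936",
]

for name in NAMES:
    with PathManager.open(MESHES_DIR + name + ".pkl", "rb") as hFile:
        data = pickle.load(hFile)
    vertices = np.asarray(data["vertices"])
    print(name, "all vertices zero:", bool(np.all(vertices == 0)))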

Is there a proper URL for each animal mesh?

Or should load_mesh_data in DensePose/densepose/structures/mesh.py be updated?

It is called from get_xyz_vertex_embedding in DensePose/densepose/vis/densepose_outputs_vertex.py.

kingsj0405 (Author) commented Sep 27, 2021

(+) My built-in mesh catalog at DensePose/build/lib/densepose/data/meshes/builtin.py has the following content:

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from .catalog import MeshInfo, register_meshes

DENSEPOSE_MESHES_DIR = "https://dl.fbaipublicfiles.com/densepose/meshes/"

MESHES = [
    MeshInfo(
        name="smpl_27554",
        data="smpl_27554.pkl",
        geodists="geodists/geodists_smpl_27554.pkl",
        symmetry="symmetry/symmetry_smpl_27554.pkl",
        texcoords="texcoords/texcoords_smpl_27554.pkl",
    ),
    MeshInfo(
        name="chimp_5029",
        data="chimp_5029.pkl",
        geodists="geodists/geodists_chimp_5029.pkl",
        symmetry="symmetry/symmetry_chimp_5029.pkl",
        texcoords="texcoords/texcoords_chimp_5029.pkl",
    ),
    MeshInfo(
        name="cat_5001",
        data="cat_5001.pkl",
        geodists="geodists/geodists_cat_5001.pkl",
        symmetry="symmetry/symmetry_cat_5001.pkl",
        texcoords="texcoords/texcoords_cat_5001.pkl",
    ),
    MeshInfo(
        name="cat_7466",
        data="cat_7466.pkl",
        geodists="geodists/geodists_cat_7466.pkl",
        symmetry="symmetry/symmetry_cat_7466.pkl",
        texcoords="texcoords/texcoords_cat_7466.pkl",
    ),
    MeshInfo(
        name="sheep_5004",
        data="sheep_5004.pkl",
        geodists="geodists/geodists_sheep_5004.pkl",
        symmetry="symmetry/symmetry_sheep_5004.pkl",
        texcoords="texcoords/texcoords_sheep_5004.pkl",
    ),
    MeshInfo(
        name="zebra_5002",
        data="zebra_5002.pkl",
        geodists="geodists/geodists_zebra_5002.pkl",
        symmetry="symmetry/symmetry_zebra_5002.pkl",
        texcoords="texcoords/texcoords_zebra_5002.pkl",
    ),
    MeshInfo(
        name="horse_5004",
        data="horse_5004.pkl",
        geodists="geodists/geodists_horse_5004.pkl",
        symmetry="symmetry/symmetry_horse_5004.pkl",
        texcoords="texcoords/texcoords_zebra_5002.pkl",
    ),
    MeshInfo(
        name="giraffe_5002",
        data="giraffe_5002.pkl",
        geodists="geodists/geodists_giraffe_5002.pkl",
        symmetry="symmetry/symmetry_giraffe_5002.pkl",
        texcoords="texcoords/texcoords_giraffe_5002.pkl",
    ),
    MeshInfo(
        name="elephant_5002",
        data="elephant_5002.pkl",
        geodists="geodists/geodists_elephant_5002.pkl",
        symmetry="symmetry/symmetry_elephant_5002.pkl",
        texcoords="texcoords/texcoords_elephant_5002.pkl",
    ),
    MeshInfo(
        name="dog_5002",
        data="dog_5002.pkl",
        geodists="geodists/geodists_dog_5002.pkl",
        symmetry="symmetry/symmetry_dog_5002.pkl",
        texcoords="texcoords/texcoords_dog_5002.pkl",
    ),
    MeshInfo(
        name="dog_7466",
        data="dog_7466.pkl",
        geodists="geodists/geodists_dog_7466.pkl",
        symmetry="symmetry/symmetry_dog_7466.pkl",
        texcoords="texcoords/texcoords_dog_7466.pkl",
    ),
    MeshInfo(
        name="cow_5002",
        data="cow_5002.pkl",
        geodists="geodists/geodists_cow_5002.pkl",
        symmetry="symmetry/symmetry_cow_5002.pkl",
        texcoords="texcoords/texcoords_cow_5002.pkl",
    ),
    MeshInfo(
        name="bear_4936",
        data="bear_4936.pkl",
        geodists="geodists/geodists_bear_4936.pkl",
        symmetry="symmetry/symmetry_bear_4936.pkl",
        texcoords="texcoords/texcoords_bear_4936.pkl",
    ),
]

register_meshes(MESHES, DENSEPOSE_MESHES_DIR)
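
In principle, if valid mesh pickle files were available locally, the same helpers could register them from a local directory instead of the fbaipublicfiles URL. A hypothetical sketch (the local directory and files are placeholders; I have not tested overriding an already-registered name):

# Hypothetical: register a locally stored cat mesh instead of the remote one.
# LOCAL_MESHES_DIR and the files under it are placeholders.
from densepose.data.meshes.catalog import MeshInfo, register_meshes

LOCAL_MESHES_DIR = "/path/to/local/meshes/"

LOCAL_MESHES = [
    MeshInfo(
        name="cat_7466",
        data="cat_7466.pkl",
        geodists="geodists/geodists_cat_7466.pkl",
        symmetry="symmetry/symmetry_cat_7466.pkl",
        texcoords="texcoords/texcoords_cat_7466.pkl",
    ),
]

register_meshes(LOCAL_MESHES, LOCAL_MESHES_DIR)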

kingsj0405 (Author) commented:

@vkhalidov @MarcSzafraniec
I found you in the history of DensePose.
Any kind of tip to solve this problem would be helpful.
Thank you :)

MarcSzafraniec (Contributor) commented:

Hi! For the moment, we do not release the mesh data that we use. We might release the cat and dog meshes soon, though. I'll comment here if that's the case.

kingsj0405 (Author) commented:

@MarcSzafraniec Thank you :)

yasaminjafarian commented:

Is this issue solved? I still cannot get the rainbow CSE vertex visualization, just the blue mask.

kingsj0405 (Author) commented:

@yasaminjafarian
The reason for the blue mask is the zero-valued mesh vertices.
So for now I visualize using the embedding values directly.

import cv2
import numpy as np
from torch.nn import functional as F

from densepose.vis.densepose_outputs_vertex import DensePoseOutputsVertexVisualizer

# Maps the first three channels of the CSE embedding directly to colors,
# bypassing the mesh vertex data entirely. Note that the parent __init__ is
# deliberately not called.
class DensePoseCSE3Visualizer(DensePoseOutputsVertexVisualizer):
    def __init__(
        self,
        cfg,
        inplace=True,
        cmap=cv2.COLORMAP_JET,
        alpha=0.7,
        device="cuda",
        **kwargs,
    ):
        self.inplace = inplace
        self.alpha = alpha
        self.device = device
    
    def _get_cse3(self, E, S, h, w):
        # Resize the embedding (E) and coarse segmentation (S) to the bbox size.
        embedding_resized = F.interpolate(E, size=(h, w), mode="bilinear", align_corners=False)[0]
        coarse_segm_resized = F.interpolate(S, size=(h, w), mode="bilinear", align_corners=False)[0]
        mask = coarse_segm_resized.argmax(0) > 0

        # Take the first three embedding channels and move them to the last axis (h, w, 3).
        cse3 = embedding_resized[:3, :, :].permute(1, 2, 0)
        return cse3, mask
    
    def _add_cse3_vis(self, image_bgr, mask, cse3, bbox_xywh):
        if self.inplace:
            image_target_bgr = image_bgr
        else:
            image_target_bgr = image_bgr * 0
        x, y, w, h = [int(v) for v in bbox_xywh]
        mask_bg = np.tile((mask == 0)[:, :, np.newaxis], [1, 1, 3])
        # Normalize the three embedding channels to [0, 255] for display.
        cse3 = cse3 - cse3.min()
        cse3 = cse3 / cse3.max() * 255.0
        # Keep the original pixels in the background, then alpha-blend the rest into the bbox.
        cse3[mask_bg] = image_target_bgr[y : y + h, x : x + w, :][mask_bg]
        image_target_bgr[y : y + h, x : x + w, :] = (
            image_target_bgr[y : y + h, x : x + w, :] * (1.0 - self.alpha) + cse3 * self.alpha
        )
        return image_target_bgr.astype(np.uint8)


    def visualize(
        self,
        image_bgr,
        outputs_boxes_xywh_classes,
    ):

        if outputs_boxes_xywh_classes[0] is None:
            return image_bgr
        
        S, E, N, bboxes_xywh, _ = self.extract_and_check_outputs_and_boxes(
            outputs_boxes_xywh_classes
        )

        for n in range(N):
            x, y, w, h = bboxes_xywh[n].int().tolist()
            cse3, mask = self._get_cse3(
                E[[n]],
                S[[n]],
                h,
                w,
            )
            cse3 = cse3.cpu().numpy()
            mask = mask.cpu().numpy().astype(np.uint8)

            image_bgr = self._add_cse3_vis(image_bgr, mask, cse3, [x, y, w, h])
        
        return image_bgr

With this visualizer, you can get the following results.
NOTE: this is not an optimal or well-designed visualizer.

[three result images attached]

yasaminjafarian commented Nov 13, 2021

Thanks a lot @kingsj0405 for your response. I was just wondering how you call DensePoseCSE3Visualizer.visualize() here?

I save the data predicted by apply_net.py to a pickle file and load it:

!python apply_net.py dump \
    configs/cse/Base-DensePose-RCNN-FPN-Human.yaml \
    https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_s1x/251155172/model_final_c4ea5f.pkl \
    "0001_img.png" \
    --output 'outputs/results_human_cse.pkl' \
    -v 
import pickle

with open('outputs/results_human_cse.pkl', 'rb') as f:
    data = pickle.load(f)

Then should I call the visualizer like this (bg is the background image, h×w×3)? This gives me an error:

image_bgr = DensePoseCSE3Visualizer.visualize(bg, data)

Error: TypeError: visualize() missing 1 required positional argument: 'outputs_boxes_xywh_classes'

kingsj0405 (Author) commented Nov 14, 2021

@yasaminjafarian I added DensePoseCSE3Visualizer to VISUALIZERS, which is used by the show action.

...
@register_action
class ShowAction(InferenceAction):
    """
    Show action that visualizes selected entries on an image
    """

    COMMAND: ClassVar[str] = "show"
    VISUALIZERS: ClassVar[Dict[str, object]] = {
        "dp_contour": DensePoseResultsContourVisualizer,
        "dp_segm": DensePoseResultsFineSegmentationVisualizer,
        "dp_u": DensePoseResultsUVisualizer,
        "dp_v": DensePoseResultsVVisualizer,
        "dp_iuv_texture": DensePoseResultsVisualizerWithTexture,
        "dp_cse_texture": DensePoseOutputsTextureVisualizer,
        "dp_vertex": DensePoseOutputsVertexVisualizer,
        "dp_cse3": DensePoseCSE3Visualizer,
        "bbox": ScoredBoundingBoxVisualizer,
    }
...

You need to run a command like this:

!python apply_net.py show configs/densepose_rcnn_R_50_FPN_s1x.yaml densepose_rcnn_R_50_FPN_s1x.pkl image.jpg dp_cse3,bbox --output image_densepose_contour.png
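
As a side note, the TypeError above comes from calling visualize() on the class itself; it is an instance method. If you prefer calling it from Python directly, the shape of it would be something like this untested sketch (cfg and outputs_boxes_xywh_classes are placeholders for whatever the show action in apply_net.py builds and passes in):

# Untested sketch: build the visualizer instance first, then call visualize().
# `cfg` and `outputs_boxes_xywh_classes` are placeholders; see how the show
# action in apply_net.py constructs them.
visualizer = DensePoseCSE3Visualizer(cfg)
image_vis = visualizer.visualize(bg, outputs_boxes_xywh_classes)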

yasaminjafarian commented:

@kingsj0405 Thank you very much for the code. I can run it now.
Also, I was wondering which config and weights you use for the cats?

When I use this one, it shows a rainbow coloring:

!python apply_net_cse.py show \
    configs/cse/densepose_rcnn_R_50_FPN_soft_animals_CA_finetune_4k.yaml \
    https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_soft_animals_CA_finetune_4k/253498611/model_final_6d69b7.pkl \
    cat.jpg \
    dp_cse3,bbox \
    --output image_densepose_contour_cat3.png

[attached result image: image_densepose_contour_cat3 0001]

But when I use this one or this one, the results are very pinkish:

!python apply_net_cse.py show \
    configs/cse/densepose_rcnn_R_50_FPN_soft_animals_I0_finetune_m2m_16k.yaml \
    https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_soft_animals_finetune_maskonly_24k/267687159/model_final_354e61.pkl \
    cat.jpg \
    dp_cse3,bbox \
    --output image_densepose_contour_cat.png

[attached result image: image_densepose_contour_cat 0001]

kingsj0405 (Author) commented:

@yasaminjafarian
Because I tried multiple times in a messy environment, I'm not sure...
In my experience, any CSE configuration and pretrained model works for the cat category.
Maybe the pinkish one has its proper embedding in other CSE channels.
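
One quick way to check that (my own untested tweak) is to slice a different range of embedding channels in _get_cse3, for example:

# Untested tweak to DensePoseCSE3Visualizer._get_cse3: visualize embedding
# channels 3-5 instead of the first three.
cse3 = embedding_resized[3:6, :, :].permute(1, 2, 0)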

yasaminjafarian commented:

@kingsj0405 I see. Great, thank you very much.

JonnyScream commented Feb 8, 2022

Hi! For the moment, we do not release the mesh data that we use. We might release the cat and dog meshes soon, though. I'll comment here if that's the case.
@MarcSzafraniec
Hi Marc,
Is it possible to release the SMPL human model's valid data ("smpl_27554.pkl") that was used for the CSE paradigm (see reference below)?

Reference:
[attached image]

Thank you.

runa91 commented Aug 4, 2022

@MarcSzafraniec is there any update w.r.t. the release of the cat and dog meshes?

MarcSzafraniec (Contributor) commented:

Hello! Sorry, but we discussed with the legal team and we do not have the rights to distribute the meshes at the moment.

runa91 commented Aug 4, 2022

Could you map it to SMAL?

sunflower110 commented:

Hello friends, is there a way to get the meshes published, or is there a way to create the meshes ourselves?
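
For reference, the mesh pickle format shown earlier in the thread appears to be just a dict with 'vertices' (N×3 floats) and 'faces' (M×3 vertex indices), so a custom mesh could presumably be saved in that shape (my own untested sketch with placeholder data; note the catalog also expects separate geodists/symmetry/texcoords files):

# Untested sketch: write a mesh pickle in the {'vertices', 'faces'} format seen
# earlier in this thread. my_vertices and my_faces are placeholders for real data.
import pickle

import numpy as np

my_vertices = np.zeros((100, 3), dtype=np.float64)  # placeholder (N, 3) vertex positions
my_faces = np.zeros((100, 3), dtype=np.int64)       # placeholder (M, 3) triangle indices

with open("my_mesh.pkl", "wb") as f:
    pickle.dump({"vertices": my_vertices, "faces": my_faces}, f)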
