Skip to content

Commit

Permalink
fix wilor_home
Browse files Browse the repository at this point in the history
  • Loading branch information
annie-liyunqi committed Nov 19, 2024
1 parent 33c982e commit fc0f804
Showing 1 changed file with 30 additions and 16 deletions.
46 changes: 30 additions & 16 deletions _wilor_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
import numpy as np
from dataclasses import dataclass
from torch import Tensor
from typing import Any, Generator, Literal, TypedDict

from wilor.models import load_wilor
from wilor.utils import recursive_to
from wilor.datasets.vitdet_dataset import ViTDetDataset
from wilor.utils.renderer import Renderer, cam_crop_to_full
from typing import Literal, TypedDict
from ultralytics import YOLO
from jaxtyping import Float, Int

import sys

LIGHT_PURPLE=(0.25098039, 0.274117647, 0.65882353)


Expand Down Expand Up @@ -67,16 +65,32 @@ class _RawHamerOutputs:

class WiLoRHelper:
def __init__(self, wilor_home: str="./"):
model, model_cfg = load_wilor(wilor_home, checkpoint_path = './pretrained_models/wilor_final.ckpt' ,
cfg_path= './pretrained_models/model_config.yaml')
detector = YOLO(os.path.join(wilor_home, 'pretrained_models/detector.pt'))
# Setup the renderer
self.renderer = Renderer(model_cfg, faces=model.mano.faces)

self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self._model = model.to(self.device)
self._model_cfg = model_cfg
self._detector = detector.to(self.device)
sys.path.append(wilor_home)
global load_wilor
global recursive_to
global ViTDetDataset
global Renderer
global cam_crop_to_full

from wilor.models import load_wilor
from wilor.utils import recursive_to
from wilor.datasets.vitdet_dataset import ViTDetDataset
from wilor.utils.renderer import Renderer, cam_crop_to_full
checkpoint_path = os.path.join(wilor_home, 'pretrained_models/wilor_final.ckpt')
cfg_path = os.path.join(wilor_home, 'pretrained_models/model_config.yaml')
original_dir = os.getcwd()
os.chdir(wilor_home)
model, model_cfg = load_wilor(checkpoint_path = checkpoint_path ,
cfg_path= cfg_path)
os.chdir(original_dir)
detector = YOLO(os.path.join(wilor_home, 'pretrained_models/detector.pt'))
# Setup the renderer
self.renderer = Renderer(model_cfg, faces=model.mano.faces)

self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self._model = model.to(self.device)
self._model_cfg = model_cfg
self._detector = detector.to(self.device)

def look_for_hands(
self,
Expand Down

0 comments on commit fc0f804

Please sign in to comment.