From fc0f804f42a0deb8848fc8e909517fc441b80e70 Mon Sep 17 00:00:00 2001 From: woodenbirds <1979309725@qq.com> Date: Tue, 19 Nov 2024 13:53:41 -0800 Subject: [PATCH] fix wilor_home --- _wilor_helper.py | 46 ++++++++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/_wilor_helper.py b/_wilor_helper.py index 2ae7143..f0dab6c 100644 --- a/_wilor_helper.py +++ b/_wilor_helper.py @@ -6,14 +6,12 @@ import numpy as np from dataclasses import dataclass from torch import Tensor -from typing import Any, Generator, Literal, TypedDict - -from wilor.models import load_wilor -from wilor.utils import recursive_to -from wilor.datasets.vitdet_dataset import ViTDetDataset -from wilor.utils.renderer import Renderer, cam_crop_to_full +from typing import Literal, TypedDict from ultralytics import YOLO from jaxtyping import Float, Int + +import sys + LIGHT_PURPLE=(0.25098039, 0.274117647, 0.65882353) @@ -67,16 +65,32 @@ class _RawHamerOutputs: class WiLoRHelper: def __init__(self, wilor_home: str="./"): - model, model_cfg = load_wilor(wilor_home, checkpoint_path = './pretrained_models/wilor_final.ckpt' , - cfg_path= './pretrained_models/model_config.yaml') - detector = YOLO(os.path.join(wilor_home, 'pretrained_models/detector.pt')) - # Setup the renderer - self.renderer = Renderer(model_cfg, faces=model.mano.faces) - - self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') - self._model = model.to(self.device) - self._model_cfg = model_cfg - self._detector = detector.to(self.device) + sys.path.append(wilor_home) + global load_wilor + global recursive_to + global ViTDetDataset + global Renderer + global cam_crop_to_full + + from wilor.models import load_wilor + from wilor.utils import recursive_to + from wilor.datasets.vitdet_dataset import ViTDetDataset + from wilor.utils.renderer import Renderer, cam_crop_to_full + checkpoint_path = os.path.join(wilor_home, 'pretrained_models/wilor_final.ckpt') + cfg_path = os.path.join(wilor_home, 'pretrained_models/model_config.yaml') + original_dir = os.getcwd() + os.chdir(wilor_home) + model, model_cfg = load_wilor(checkpoint_path = checkpoint_path , + cfg_path= cfg_path) + os.chdir(original_dir) + detector = YOLO(os.path.join(wilor_home, 'pretrained_models/detector.pt')) + # Setup the renderer + self.renderer = Renderer(model_cfg, faces=model.mano.faces) + + self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + self._model = model.to(self.device) + self._model_cfg = model_cfg + self._detector = detector.to(self.device) def look_for_hands( self,