diff --git a/cosmos_transfer1/auxiliary/depth_anything/model/depth_anything.py b/cosmos_transfer1/auxiliary/depth_anything/model/depth_anything.py index 072b0fbf..d3933b7b 100644 --- a/cosmos_transfer1/auxiliary/depth_anything/model/depth_anything.py +++ b/cosmos_transfer1/auxiliary/depth_anything/model/depth_anything.py @@ -34,13 +34,14 @@ def __init__(self): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Load image processor and model with half precision log.info(f"Loading Depth Anything model - {DEPTH_ANYTHING_MODEL_CHECKPOINT}...") + checkpoint = os.getenv("CHECKPOINT_DIR", "checkpoints") self.image_processor = AutoImageProcessor.from_pretrained( - DEPTH_ANYTHING_MODEL_CHECKPOINT, + os.path.join(checkpoint, DEPTH_ANYTHING_MODEL_CHECKPOINT), torch_dtype=torch.float16, trust_remote_code=True, ) self.model = AutoModelForDepthEstimation.from_pretrained( - DEPTH_ANYTHING_MODEL_CHECKPOINT, + os.path.join(checkpoint, DEPTH_ANYTHING_MODEL_CHECKPOINT), torch_dtype=torch.float16, trust_remote_code=True, ).to(self.device)