@@ -6,11 +6,11 @@
 
 import torch
 from einops import rearrange
-from helios.data.constants import Modality
-from helios.nn.flexihelios import Encoder, TokensAndMasks
-from helios.train.masking import MaskedHeliosSample, MaskValue
 from olmo_core.config import Config
 from olmo_core.distributed.checkpoint import load_model_and_optim_state
+from olmoearth_pretrain.data.constants import Modality
+from olmoearth_pretrain.nn.flexihelios import Encoder, TokensAndMasks
+from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
 from upath import UPath
 
 from rslp.log_utils import get_logger
@@ -53,7 +53,7 @@ def __init__(
         selector: an optional sequence of attribute names or list indices to select
             the sub-module that should be applied on the input images.
         forward_kwargs: additional arguments to pass to forward pass besides the
-            MaskedHeliosSample.
+            MaskedOlmoEarthSample.
         random_initialization: whether to skip loading the checkpoint so the
             weights are randomly initialized. In this case, the checkpoint is only
             used to define the model architecture.
@@ -148,7 +148,7 @@ def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
         timestamps[:, :, 2] = 2024  # year
         kwargs["timestamps"] = timestamps
 
-        sample = MaskedHeliosSample(**kwargs)
+        sample = MaskedOlmoEarthSample(**kwargs)
 
         # Decide context based on self.autocast_dtype.
         if self.autocast_dtype is None:
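
The forward hunk above fills a timestamps tensor whose last channel is the year. Purely for orientation, a minimal sketch of a compatible tensor follows; the (batch, timesteps, 3) layout is implied by the indexing in the hunk, while the meaning of channels 0 and 1 (day and month below) is an assumption, not something the diff shows:

    import torch

    # Sketch only: timestamps[:, :, 2] = 2024 in the hunk implies a
    # (batch, timesteps, 3) tensor whose channel 2 holds the year.
    batch_size, num_timesteps = 4, 12
    timestamps = torch.zeros(batch_size, num_timesteps, 3, dtype=torch.int32)
    timestamps[:, :, 0] = 15                           # assumed: day of month
    timestamps[:, :, 1] = torch.arange(num_timesteps)  # assumed: month index
    timestamps[:, :, 2] = 2024                         # year, as in the hunk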
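
The hunk ends at the autocast decision itself, so the branch body is not part of this diff. A hedged sketch of how such a branch is commonly completed with torch.autocast; the helper name autocast_context and the device_type="cuda" default are assumptions, not the repository's code:

    import contextlib

    import torch


    def autocast_context(
        autocast_dtype: torch.dtype | None,
    ) -> contextlib.AbstractContextManager:
        """Pick the context manager the encoder forward pass runs under.

        Sketch under stated assumptions: the diff stops at the dtype
        check, so this mirrors common torch.amp usage rather than the
        actual implementation.
        """
        if autocast_dtype is None:
            # No mixed precision configured; run the forward pass unmodified.
            return contextlib.nullcontext()
        # device_type="cuda" is assumed; the real code may derive it from
        # the input tensors instead.
        return torch.autocast(device_type="cuda", dtype=autocast_dtype)

With such a helper, the tail of forward() would read roughly "with autocast_context(self.autocast_dtype): output = self.model(sample, **self.forward_kwargs)", where self.model is likewise a hypothetical attribute name.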