
Running on Pandas data set #2

Open
andrewczgithub opened this issue Jul 29, 2019 · 8 comments

@andrewczgithub

andrewczgithub commented Jul 29, 2019

Hi @YuliaRubanova
Really interested in the code and looking at building Jupyter notebooks with alternate data sets.
I was just wondering how best to use the code on, say, a pandas data set?
Can this be used on a multivariate series?
Many thanks and best,
Andrew

@anandijain

anandijain commented Jul 31, 2019

I would like to echo what Andrew said.
I can't seem to find a clean way to use my own dataset with the provided code.

From my digging, it looks like I need to create a file for my dataset, similar to person_activity.py.
I think the main source of my confusion is the sheer number of arguments and references to other files, which has made it difficult to track down what code I need to write and how I need to adapt it to work with latent_ode.

If I can figure out how to make a simple model that runs on pandas, I might try to make a pull request.
I think the way I would try to format it is using fewer dictionaries and more canonical ML/NN types:

A simple classification framework:

import torch
import torch.nn as nn

model = LatentODE()  # hypothetical model class
criterion = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters())
for x, y in data_loader:
    optim.zero_grad()
    logits = model(x)
    loss = criterion(logits, y)
    loss.backward()
    optim.step()

Thank you for providing latent_ode!
It is a fantastic concept and paper. I'm just having a hard time understanding how all of the files come together.

@YuliaRubanova
Copy link
Owner

The Latent ODE code is designed to handle datasets where time series have different lengths and each time series is measured at different times. It can be used with multivariate time series. Different dimensions of the data might be measured at different times as well.

All the datasets are loaded in lib/parse_datasets.py using torch.utils.data.DataLoader.

Easy case: all time series share observation times

This approach will be the easiest to use in Jupyter notebooks. A good example is the Periodic_1d dataset (use input parameter --dataset periodic).

Data format

  • a dataset in the form of a tensor [B x T x D], where B is the number of time series in the dataset, T is the number of time points in each time series, and D is the data dimensionality.
  • a vector [T] with observation times. Observation times can be irregular.

For example, in periodic dataset, observation times are generated here; dataset is generated here.

Steps

  1. Pass the dataset through DataLoader like here.
    In the code, train_y is the dataset; time_steps_extrap are the shared time points (see the sketch below).
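A minimal sketch of this setup; the shapes, batch size, and random data are illustrative, not taken from the repo:

import torch
from torch.utils.data import DataLoader

B, T, D = 100, 50, 3                              # time series, time points, dimensions
train_y = torch.randn(B, T, D)                    # dataset tensor [B x T x D]
time_steps_extrap = torch.sort(torch.rand(T))[0]  # shared (possibly irregular) times [T]

dataloader = DataLoader(train_y, batch_size=20, shuffle=False)
for batch in dataloader:                          # batch: [20 x T x D]
    pass  # feed batch and time_steps_extrap to the model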

Optional
Here we assumed that all time series share the same time points. You can use the input parameter --sample-tp to randomize the time points used in each batch. For example, if the dataset has 100 time points and we use --sample-tp 30, each batch will have 30 randomly sampled time points out of the original 100.
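Conceptually, --sample-tp 30 does something like the following sketch (names and shapes here are illustrative):

import torch

B, T, D, n_sample = 100, 100, 3, 30
values = torch.randn(B, T, D)
times = torch.sort(torch.rand(T))[0]

idx = torch.sort(torch.randperm(T)[:n_sample])[0]  # sorted random subset of time indices
batch_times = times[idx]                           # [30]
batch_values = values[:, idx]                      # [B x 30 x D]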

Time series have different observation times and/or have different length

This case is useful for real data. Good examples are the PhysioNet class (--dataset physionet) and the PersonActivity class (--dataset activity).

To use Latent ODE on your dataset, I recommend formatting your data as a list of records, as described below. You can use the collate function for DataLoader from here.

Data format
Since each time series has a different length, we represent the dataset as a list of records: [record1, record2, record3, …]. Each record represents one time series (e.g. one patient in Physionet).

Each record has the following format (a minimal example follows the list):
(record_id, observation_times, values, mask, labels)

  • record_id: an id of this time series
  • observation_times: a 1-dimensional numpy array containing the T time points at which values were observed
  • values: a (T, D) numpy array containing observed D-dimensional values at T time points
  • mask: a (T, D) tensor containing 1 where values were observed and 0 otherwise. Useful if different dimensions are observed at different times. If all dimensions are observed at the same time, fill the mask with ones.
  • labels: a list of labels for the current patient, if labels are available. Otherwise None.
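For illustration, a single record could be built like this (all values invented; only the tuple layout follows the format above):

import numpy as np
import torch

T, D = 4, 2
record_id = 'patient_001'
observation_times = np.array([0.0, 1.5, 3.0, 7.2])  # [T], irregular
values = np.random.randn(T, D)                      # [T x D] observed values
mask = torch.ones(T, D)                             # all dimensions observed at all times
labels = None                                       # or a list/vector of labels

record = (record_id, observation_times, values, mask, labels)
dataset = [record]                                  # list of records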

Pipeline of the Physionet dataset
To use it on your dataset, you only need to replace step 1 to produce the list of records.

  1. The Physionet class loads the dataset from files (each patient is stored in its own file) and outputs the list of records (format described above) like so.

  2. The Physionet class is called in lib/parse_datasets.py to get the list of records. The list of records is then split into train/test here.

  3. The DataLoader takes the list of records and collates them into batches like so. The function that collates records into batches is here.

  4. During the training, the model calls data loader to get a new batch.
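For intuition, here is a hypothetical, much-simplified collate function in the spirit of step 3: it aligns every record in a batch onto the union of their observation times and uses the mask to mark which entries are real. The repo's actual collate function is the one linked above.

import torch

def collate_records(batch):
    # batch: list of (record_id, times, values, mask, labels) with torch tensors
    combined_tt = torch.unique(torch.cat([times for _, times, _, _, _ in batch]))
    B, T = len(batch), len(combined_tt)
    D = batch[0][2].shape[1]

    combined_vals = torch.zeros(B, T, D)
    combined_mask = torch.zeros(B, T, D)
    for b, (_, times, vals, mask, _) in enumerate(batch):
        idx = torch.searchsorted(combined_tt, times)  # positions in the union grid
        combined_vals[b, idx] = vals
        combined_mask[b, idx] = mask
    return combined_tt, combined_vals, combined_mask

It would be passed to the DataLoader as collate_fn=collate_records.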

@anandijain

Thank you for this, this clarifies a ton!

You mention that labels is a list.
I have a list of astronomical objects that each time curve could be (not classifying per time point).
Since the classification target (astronomical object class) does not have any semantic meaning, it wouldn't make much sense to do regression on it.

Do I make a list of the labels and then just have the labels in the record tuple be an index?

Does the shape of labels need to be the length of the time points?

Thanks again!

@YuliaRubanova
Owner

Concerning the labels, we support two tasks right now:

  1. Binary classification per time series (see Physionet).

  2. Multi-class classification per time point (see PersonActivity)

Your case seems to be multi-class classification per time series, which is a blend of the two set-ups that we have now.

You can do the following:

  1. Convert the labels to a one-hot encoding. For each time series, labels will be a binary vector [C], where C is the number of classes. This solves the problem of semantic meaning that you mentioned.

  2. Use the collate function for Physionet from here, since it creates labels per time series.

    Set N_labels = C, where C is the number of classes.

    Note that we used normalization for the Physionet dataset here. You can try with and without normalization and see if it helps training in your case.

  3. In lib/parse_datasets.py, set classif_per_tp to False for your dataset like so. This signals to the model that classification should be run per time series rather than per time point.

  4. Lastly, we need to modify the loss.
    The function for computing the classification loss is called here if you are using the Latent ODE model, or here for ODE-RNN. Make sure to use the multi-class CE loss instead of the binary CE loss. The multi-class CE loss compute_multiclass_CE_loss expects the number of time points to be the third dimension. You can hack around this by adding a third dimension of size 1 to label_predictions just before you call compute_multiclass_CE_loss (see the sketch after this list).
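A minimal sketch of steps 1 and 4; the tensors are placeholders, and the shapes follow the discussion in this thread:

import torch
import torch.nn.functional as F

C = 14                                                     # number of classes
label = F.one_hot(torch.tensor(3), num_classes=C).float()  # step 1: one-hot vector [C]

# step 4: label_predictions comes out as [n_traj_samples, n_traj, C];
# add a time dimension of size 1 so compute_multiclass_CE_loss sees
# [n_traj_samples, n_traj, 1, C]
label_predictions = torch.randn(1, 50, C)                  # placeholder predictions
label_predictions = label_predictions.unsqueeze(2)         # [1, 50, 1, C]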

@anandijain

Thanks so much
you rock!

@anandijain

  1. C == 14, like this: [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
    I have made my dataset produce records with binary classification labels using the specified record tuple form: (id, t, vals, mask, labels).

  2. I changed N_labels in physionet to 14.
    I'm using the physionet collate fn and I think I'm using it correctly.

  3. I've changed classif_per_tp to False in parse_datasets.

  4. I'm getting stuck on this part. I unsqueezed the third dimension, but I'm getting the following error:

  File "latent_pandas.py", line 274, in <module>
    train_res = model.compute_all_losses(batch_dict, n_traj_samples = 1, kl_coef = kl_coef)
  File "/mnt/c/Users/***/home/fermi/cosmoNODE/cosmoNODE/latent_ode/lib/base_models.py", line 322, in compute_all_losses
    mask = batch_dict["mask_predicted_data"])
  File "/mnt/c/Users/***/home/fermi/cosmoNODE/cosmoNODE/latent_ode/lib/likelihood_eval.py", line 118, in compute_multiclass_CE_loss
    pred_mask = pred_mask.reshape(n_traj_samples * n_traj * n_tp,  n_dims)
RuntimeError: shape '[50, 14]' is invalid for input of size 4391800

For context, I've duplicated run_models into latent_pandas.

label_predictions is of size [1, 50, 1, 14]. (I set n_traj_samples to 1 because I didn't know exactly what it did.)

This might be a cryptic error message, but I'm not exactly sure what I need to change.

If it's helpful, here is my dataloader:

import pandas as pd
import torch

import cosmoNODE
# one_hot is assumed to be a helper defined elsewhere in cosmoNODE

class FluxNet(object):
	def __init__(self):
		self.df = pd.read_csv(cosmoNODE.__path__[0] + '/../demos/data/training_set.csv')
		self.meta = pd.read_csv(cosmoNODE.__path__[0] + '/../demos/data/training_set_metadata.csv')
		self.merged = pd.merge(self.df, self.meta, on='object_id')

		self.mins = self.merged.min()
		self.maxes = self.merged.max()

		self.params = self.merged.columns.drop(['object_id', 'mjd', 'target'])
		self.classes = sorted(self.merged['target'].unique())
		self.num_classes = len(self.classes)

		self.labels = []

		self.groups = self.merged.groupby('object_id')
		self.curves = []
		self.get_curves()
		self.length = len(self.curves)
		print('fluxnet loaded')

	def get_curves(self):
		for i, group in enumerate(self.groups):
			object_id = group[0]
			data = group[1]
			times = torch.tensor(data['mjd'].values, dtype=torch.float)
			values = torch.tensor(data.drop(['mjd', 'target'], axis=1).fillna(0).values, dtype=torch.float)
			mask = torch.ones(values.shape, dtype=torch.float)
			target = data['target'].iloc[0]
			# label = labels.index(data['target'].iloc[0])
			label = one_hot(self.classes, target)
			self.labels.append(label)
			record = (object_id, times, values, mask, label)
			self.curves.append(record)

	def __getitem__(self, index):
		return self.curves[index]

	def __len__(self):
		# number of light curves in the dataset
		return self.length
	
	def get_label(self, index):
		return self.labels[index]

Thanks for your help

@manikanthr5

manikanthr5 commented Oct 7, 2019

Hi @YuliaRubanova

Thank you for your amazing paper and this code repository. I went through your code and your explanation in this issue, but I was unable to work out how to use your code for a generic multivariate time series classification task with each feature observed at all times.

Could you please guide me on how to run your code given the following type of data (multivariate time series classification)?

Data Description:
Input: a tensor [N x T x C] where:
N: number of examples in the train/test set, e.g. 1000
T: time dimension, e.g. 5/50/100 time steps
C: input features for each observation, e.g. 5 features

Output: 0 or 1.
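One way to adapt this to the record format described earlier in the thread would be something like the following sketch (all names hypothetical; shared observation times assumed):

import torch

N, T, C = 1000, 50, 5
X = torch.randn(N, T, C)                # features [N x T x C]
y = torch.randint(0, 2, (N,)).float()   # one binary label per series
times = torch.linspace(0.0, 1.0, T)     # shared observation times

records = []
for i in range(N):
    mask = torch.ones(T, C)             # every feature observed at every time
    records.append((i, times, X[i], mask, y[i:i+1]))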

Thank you for your time.

@wfcgdut

wfcgdut commented Sep 20, 2024


Hello, may I ask what dependencies are needed to run this code, and whether there is a file similar to requirements.txt?
