[question]A problem of how to use MlpLstmPolicy in GAIL training? #1148

Closed
LongchaoDa opened this issue Jan 12, 2022 · 3 comments
Labels
question Further information is requested

Comments

@LongchaoDa

LongchaoDa commented Jan 12, 2022

I am training a GAIL model with MlpLstmPolicy in Stable Baselines 2, but I cannot get the training process to run, even though I put the following assertion in the TRPO part:

assert issubclass(self.policy, LstmPolicy)

Are there any other changes I should make? If there is no other solution, how can I write a custom LSTM policy that is compatible with GAIL myself?

Looking forward to your reply!

Here is the error:

Traceback (most recent call last):
  File "/home/.../train-recurrentGail.py", line 18, in <module>
    model.learn(total_timesteps=100000)
  File "/home/.../model/gail/model.py", line 57, in learn
    return super().learn(total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps)
  File "/home/.../model/gail/myTrpo.py", line 364, in learn
    seg = seg_gen.__next__()
  File "/home/.../common/runners.py", line 118, in traj_segment_generator
    action, vpred, states, _ = policy.step(observation.reshape(-1, *observation.shape), states, done)
  File "/home/.../stable_baselines/common/policies.py", line 508, in step
    {self.obs_ph: obs, self.states_ph: state, self.dones_ph: mask})
  File "/home/.../python/client/session.py", line 956, in run
    run_metadata_ptr)
  File "/home/.../python/client/session.py", line 1156, in _run
    (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape () for Tensor 'input_1/dones_ph:0', which has shape '(1,)'
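The error message says the recurrent policy's `dones_ph` placeholder expects shape `(1,)` (one done flag per environment), while the runner at line 118 passes `done` as a scalar with shape `()`. Note that the same call already reshapes `observation` into a batch of one. The sketch below illustrates the mismatch under the assumption of a single environment. `as_batch_mask` and `reset_states` are hypothetical helpers, not Stable Baselines functions, and `reset_states` only illustrates the common pattern of zeroing an LSTM state where an episode has ended. Wrapping `done` this way might get past this particular error, but recurrent policies are still not supported in SB2's GAIL pretraining, so further errors are likely.

```python
import numpy as np


def as_batch_mask(done, n_envs=1):
    """Hypothetical helper: turn a scalar (or list of) done flag(s) into a 1-D mask.

    The recurrent policy's dones_ph placeholder has shape (n_envs,), which is (1,)
    for a single environment, whereas a scalar bool has shape ().
    """
    mask = np.asarray(done).reshape(-1)  # () -> (1,); (n,) stays (n,)
    if mask.shape != (n_envs,):
        raise ValueError(f"expected {n_envs} done flag(s), got shape {mask.shape}")
    return mask


def reset_states(states, mask):
    """Illustrative only: zero the state rows of environments that just finished."""
    keep = 1.0 - mask.astype(np.float32)  # 0.0 where done, 1.0 otherwise
    return states * keep[:, None]  # broadcast over the state dimension


done = True
print(np.asarray(done).shape)  # ()
mask = as_batch_mask(done)
print(mask.shape)  # (1,)

# Hypothetical (n_envs, state_size) array standing in for the LSTM state.
states = np.ones((1, 4), dtype=np.float32)
print(reset_states(states, mask))  # [[0. 0. 0. 0.]]
```

The reshape mirrors what the runner already does for `observation`: the policy is built for a batch of `n_envs` inputs, so every fed value, including the done mask, needs that leading batch dimension.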

@LongchaoDa LongchaoDa changed the title A problem of how to use MlpLstmPolicy in GAIL training?[question] [question]A problem of how to use MlpLstmPolicy in GAIL training? Jan 12, 2022
@Miffyli Miffyli added the question Further information is requested label Jan 12, 2022
@Miffyli
Collaborator

Miffyli commented Jan 12, 2022

LSTM policies are not supported in pretraining. There was a PR adding this (#315), but it has since died out. For better imitation learning algorithms, see the imitation library, for example. Note that SB2 is no longer maintained.

You may close this issue if you have no further questions.

@LongchaoDa
Author

Thank you. So you mean the structure of Stable Baselines 2 makes it difficult to support a custom LSTM policy in GAIL training?
Maybe I will turn to the link you mentioned to explore more.

@Miffyli
Collaborator

Miffyli commented Jan 13, 2022

> Thank you. So you mean the structure of Stable Baselines 2 makes it difficult to support a custom LSTM policy in GAIL training?

Yes, it requires a bit of work (you may check PR #315).

Closing issue as resolved :).

@Miffyli Miffyli closed this as completed Jan 13, 2022