Skip to content

Commit

Permalink
chore: release v3.1.0 (merge)
Browse files Browse the repository at this point in the history
  • Loading branch information
eduardocarvp committed Jan 12, 2021
2 parents cab643b + 679e1d4 commit eeca3e8
Show file tree
Hide file tree
Showing 18 changed files with 155 additions and 59 deletions.
14 changes: 14 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@

# [3.1.0](https://github.com/dreamquark-ai/tabnet/compare/v3.0.0...v3.1.0) (2021-01-12)


### Bug Fixes

* n_a not being used ([7ae20c9](https://github.com/dreamquark-ai/tabnet/commit/7ae20c98a601da95040b9ecf79eac19f1d3e4a7b))


### Features

* save and load preds_mapper ([cab643b](https://github.com/dreamquark-ai/tabnet/commit/cab643b156fdecfded51d70d29072fc43f397bbb))



# [3.0.0](https://github.com/dreamquark-ai/tabnet/compare/v2.0.1...v3.0.0) (2020-12-15)


Expand Down
21 changes: 17 additions & 4 deletions docs/_modules/pytorch_tabnet/abstract_model.html
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ <h1>Source code for pytorch_tabnet.abstract_model</h1><div class="highlight"><pr
<span class="n">validate_eval_set</span><span class="p">,</span>
<span class="n">create_dataloaders</span><span class="p">,</span>
<span class="n">define_device</span><span class="p">,</span>
<span class="n">ComplexEncoder</span><span class="p">,</span>
<span class="p">)</span>
<span class="kn">from</span> <span class="nn">pytorch_tabnet.callbacks</span> <span class="kn">import</span> <span class="p">(</span>
<span class="n">CallbackContainer</span><span class="p">,</span>
Expand Down Expand Up @@ -491,6 +492,10 @@ <h1>Source code for pytorch_tabnet.abstract_model</h1><div class="highlight"><pr

<span class="bp">self</span><span class="o">.</span><span class="n">network</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">update_state_dict</span><span class="p">)</span></div>

<div class="viewcode-block" id="TabModel.load_class_attrs"><a class="viewcode-back" href="../../generated_docs/pytorch_tabnet.html#pytorch_tabnet.abstract_model.TabModel.load_class_attrs">[docs]</a> <span class="k">def</span> <span class="nf">load_class_attrs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">class_attrs</span><span class="p">):</span>
<span class="k">for</span> <span class="n">attr_name</span><span class="p">,</span> <span class="n">attr_value</span> <span class="ow">in</span> <span class="n">class_attrs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">attr_name</span><span class="p">,</span> <span class="n">attr_value</span><span class="p">)</span></div>

<div class="viewcode-block" id="TabModel.save_model"><a class="viewcode-back" href="../../generated_docs/pytorch_tabnet.html#pytorch_tabnet.abstract_model.TabModel.save_model">[docs]</a> <span class="k">def</span> <span class="nf">save_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;Saving TabNet model in two distinct files.</span>

Expand All @@ -506,19 +511,26 @@ <h1>Source code for pytorch_tabnet.abstract_model</h1><div class="highlight"><pr

<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">saved_params</span> <span class="o">=</span> <span class="p">{}</span>
<span class="n">init_params</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_params</span><span class="p">()</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">val</span><span class="p">,</span> <span class="nb">type</span><span class="p">):</span>
<span class="c1"># Don&#39;t save torch specific params</span>
<span class="k">continue</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">saved_params</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">val</span>
<span class="n">init_params</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">val</span>
<span class="n">saved_params</span><span class="p">[</span><span class="s2">&quot;init_params&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">init_params</span>

<span class="n">class_attrs</span> <span class="o">=</span> <span class="p">{</span>
<span class="s2">&quot;preds_mapper&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">preds_mapper</span>
<span class="p">}</span>
<span class="n">saved_params</span><span class="p">[</span><span class="s2">&quot;class_attrs&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">class_attrs</span>

<span class="c1"># Create folder</span>
<span class="n">Path</span><span class="p">(</span><span class="n">path</span><span class="p">)</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>

<span class="c1"># Save models params</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">path</span><span class="p">)</span><span class="o">.</span><span class="n">joinpath</span><span class="p">(</span><span class="s2">&quot;model_params.json&quot;</span><span class="p">),</span> <span class="s2">&quot;w&quot;</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s2">&quot;utf8&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">json</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="n">saved_params</span><span class="p">,</span> <span class="n">f</span><span class="p">)</span>
<span class="n">json</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="n">saved_params</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="bp">cls</span><span class="o">=</span><span class="n">ComplexEncoder</span><span class="p">)</span>

<span class="c1"># Save state_dict</span>
<span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">network</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">Path</span><span class="p">(</span><span class="n">path</span><span class="p">)</span><span class="o">.</span><span class="n">joinpath</span><span class="p">(</span><span class="s2">&quot;network.pt&quot;</span><span class="p">))</span>
Expand All @@ -539,7 +551,7 @@ <h1>Source code for pytorch_tabnet.abstract_model</h1><div class="highlight"><pr
<span class="k">with</span> <span class="n">zipfile</span><span class="o">.</span><span class="n">ZipFile</span><span class="p">(</span><span class="n">filepath</span><span class="p">)</span> <span class="k">as</span> <span class="n">z</span><span class="p">:</span>
<span class="k">with</span> <span class="n">z</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="s2">&quot;model_params.json&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">loaded_params</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
<span class="n">loaded_params</span><span class="p">[</span><span class="s2">&quot;device_name&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">device_name</span>
<span class="n">loaded_params</span><span class="p">[</span><span class="s2">&quot;init_params&quot;</span><span class="p">][</span><span class="s2">&quot;device_name&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">device_name</span>
<span class="k">with</span> <span class="n">z</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="s2">&quot;network.pt&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="k">try</span><span class="p">:</span>
<span class="n">saved_state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">,</span> <span class="n">map_location</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
Expand All @@ -554,11 +566,12 @@ <h1>Source code for pytorch_tabnet.abstract_model</h1><div class="highlight"><pr
<span class="k">except</span> <span class="ne">KeyError</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">KeyError</span><span class="p">(</span><span class="s2">&quot;Your zip file is missing at least one component&quot;</span><span class="p">)</span>

<span class="bp">self</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">loaded_params</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">loaded_params</span><span class="p">[</span><span class="s2">&quot;init_params&quot;</span><span class="p">])</span>

<span class="bp">self</span><span class="o">.</span><span class="n">_set_network</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">network</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">saved_state_dict</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">network</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">load_class_attrs</span><span class="p">(</span><span class="n">loaded_params</span><span class="p">[</span><span class="s2">&quot;class_attrs&quot;</span><span class="p">])</span>

<span class="k">return</span></div>

Expand Down
6 changes: 3 additions & 3 deletions docs/_modules/pytorch_tabnet/callbacks.html
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,8 @@ <h1>Source code for pytorch_tabnet.callbacks</h1><div class="highlight"><pre>
<span class="sd"> minimum change in monitored value to qualify as improvement.</span>
<span class="sd"> This number should be positive.</span>
<span class="sd"> patience : integer</span>
<span class="sd"> number of epochs to wait for improvment before terminating.</span>
<span class="sd"> the counter be reset after each improvment</span>
<span class="sd"> number of epochs to wait for improvement before terminating.</span>
<span class="sd"> the counter be reset after each improvement</span>

<span class="sd"> &quot;&quot;&quot;</span>

Expand Down Expand Up @@ -312,7 +312,7 @@ <h1>Source code for pytorch_tabnet.callbacks</h1><div class="highlight"><pre>
<span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">network</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">best_weights</span><span class="p">)</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">stopped_epoch</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">msg</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Early stopping occured at epoch </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">stopped_epoch</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">msg</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Early stopping occurred at epoch </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">stopped_epoch</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">msg</span> <span class="o">+=</span> <span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot; with best_epoch = </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">best_epoch</span><span class="si">}</span><span class="s2"> and &quot;</span>
<span class="o">+</span> <span class="sa">f</span><span class="s2">&quot;best_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">early_stopping_metric</span><span class="si">}</span><span class="s2"> = </span><span class="si">{</span><span class="nb">round</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">best_loss</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span>
Expand Down
8 changes: 4 additions & 4 deletions docs/_modules/pytorch_tabnet/metrics.html
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ <h1>Source code for pytorch_tabnet.metrics</h1><div class="highlight"><pre>
<span class="sd"> y_pred : torch.Tensor or np.array</span>
<span class="sd"> Reconstructed prediction (with embeddings)</span>
<span class="sd"> embedded_x : torch.Tensor</span>
<span class="sd"> Orginal input embedded by network</span>
<span class="sd"> Original input embedded by network</span>
<span class="sd"> obf_vars : torch.Tensor</span>
<span class="sd"> Binary mask for obfuscated variables.</span>
<span class="sd"> 1 means the variable was obfuscated so reconstruction is based on this.</span>
Expand Down Expand Up @@ -217,7 +217,7 @@ <h1>Source code for pytorch_tabnet.metrics</h1><div class="highlight"><pre>
<span class="sd"> y_pred : torch.Tensor or np.array</span>
<span class="sd"> Reconstructed prediction (with embeddings)</span>
<span class="sd"> embedded_x : torch.Tensor</span>
<span class="sd"> Orginal input embedded by network</span>
<span class="sd"> Original input embedded by network</span>
<span class="sd"> obf_vars : torch.Tensor</span>
<span class="sd"> Binary mask for obfuscated variables.</span>
<span class="sd"> 1 means the variables was obfuscated so reconstruction is based on this.</span>
Expand Down Expand Up @@ -509,7 +509,7 @@ <h1>Source code for pytorch_tabnet.metrics</h1><div class="highlight"><pre>
<div class="viewcode-block" id="RMSLE"><a class="viewcode-back" href="../../generated_docs/pytorch_tabnet.html#pytorch_tabnet.metrics.RMSLE">[docs]</a><span class="k">class</span> <span class="nc">RMSLE</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Mean squared logarithmic error regression loss.</span>
<span class="sd"> Scikit-imeplementation:</span>
<span class="sd"> Scikit-implementation:</span>
<span class="sd"> https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_log_error.html</span>
<span class="sd"> Note: In order to avoid error, negative predictions are clipped to 0.</span>
<span class="sd"> This means that you should clip negative predictions manually after calling predict.</span>
Expand Down Expand Up @@ -557,7 +557,7 @@ <h1>Source code for pytorch_tabnet.metrics</h1><div class="highlight"><pre>
<span class="sd"> y_pred : torch.Tensor or np.array</span>
<span class="sd"> Reconstructed prediction (with embeddings)</span>
<span class="sd"> embedded_x : torch.Tensor</span>
<span class="sd"> Orginal input embedded by network</span>
<span class="sd"> Original input embedded by network</span>
<span class="sd"> obf_vars : torch.Tensor</span>
<span class="sd"> Binary mask for obfuscated variables.</span>
<span class="sd"> 1 means the variables was obfuscated so reconstruction is based on this.</span>
Expand Down
Loading

0 comments on commit eeca3e8

Please sign in to comment.