Skip to content

Commit

Permalink
polishing
Browse files Browse the repository at this point in the history
  • Loading branch information
SermetPekin committed Dec 9, 2024
1 parent 52dcc59 commit 64d5244
Show file tree
Hide file tree
Showing 19 changed files with 332 additions and 8,374 deletions.
21 changes: 11 additions & 10 deletions include/adam.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,9 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
// */
class AdamOptimizer
{
/*
arXiv preprint arXiv:1412.6980 , Diederik P Kingma, Jimmy Ba

/*
Adam maintains two moving averages for each parameter:
First Moment Estimate (Mean):
Expand All @@ -54,6 +52,9 @@ class AdamOptimizer
*/
class AdamOptimizer
{

public:
double lr; // Learning rate
double beta1; // Exponential decay rate for the first moment
Expand Down Expand Up @@ -81,30 +82,30 @@ class AdamOptimizer
}
}

// Perform one Adam update step over all registered parameters.
// Implements Kingma & Ba (arXiv:1412.6980): exponential moving averages of the
// gradient (m) and squared gradient (v), bias-corrected by 1 - beta^t, then a
// learning-rate-scaled update of each parameter's data.
void step()
{
    t++; // advance the time step exactly once per call (t drives the bias correction)
    for (auto &param : params)
    {
        double g = param->grad; // gradient of this parameter

        // First moment estimate (mean): EMA of gradients.
        m[param.get()] = beta1 * m[param.get()] + (1.0 - beta1) * g;

        // Second moment estimate (uncentered variance): EMA of squared gradients.
        v[param.get()] = beta2 * v[param.get()] + (1.0 - beta2) * g * g;

        // Bias-corrected estimates — compensates for m and v starting at zero.
        double m_hat = m[param.get()] / (1.0 - std::pow(beta1, t));
        double v_hat = v[param.get()] / (1.0 - std::pow(beta2, t));

        // Parameter update; epsilon guards against division by zero.
        param->data -= lr * m_hat / (std::sqrt(v_hat) + epsilon);
    }
}

// Zero gradients for the next step

void zero_grad()
{
for (auto &param : params)
Expand Down
147 changes: 0 additions & 147 deletions include/column.hpp

This file was deleted.

128 changes: 0 additions & 128 deletions include/csv_table.hpp

This file was deleted.

12 changes: 6 additions & 6 deletions include/data_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,20 @@ using vv_double = std::vector<std::vector<double>>;

inline
void display_data(const ColRows& inputs, const ColRows& targets, const std::vector<std::string>& column_names ) {
// Print column headers
// column headers
for (const auto& col_name : column_names) {
std::cout << std::setw(15) << std::left << col_name;
}
std::cout << std::setw(15) << std::left << "target";
std::cout << "\n";

// Print separator line
// separator line
for (size_t i = 0; i < column_names.size() + 1; ++i) {
std::cout << std::setw(15) << std::setfill('-') << "" << std::setfill(' ');
}
std::cout << "\n";

// Print rows of data
// rows of data
for (size_t i = 0; i < inputs.size(); ++i) {
for (const auto& value : inputs[i]) {
std::cout << std::setw(15) << std::left << value->data;
Expand All @@ -78,7 +78,7 @@ void display_data(const ColRows& inputs, const ColRows& targets, const std::vect
std::cout << "\n";
}

// Print final separator line
std::cout << "========================\n";
}

Expand All @@ -105,7 +105,7 @@ inline void write_to_csv(const std::vector<std::vector<T>> &data, const std::str
{
std::ofstream file(filename);

// Check if the file is open

if (!file.is_open())
{
std::cerr << "Error: Could not open file " << filename << " for writing.\n";
Expand Down Expand Up @@ -195,7 +195,7 @@ inline vv_double convert_to_double_with_encoding(const vv_string &data, bool has
}
}

// Encode the target column using the encoding map

std::string target_value = data[i].back();
if (target_encoding_map.find(target_value) == target_encoding_map.end())
{
Expand Down
4 changes: 2 additions & 2 deletions include/micrograd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ THE SOFTWARE.
// #include "dataset.hpp"
#include "loss.hpp"
#include "csv.hpp"
#include "column.hpp"
#include "csv_table.hpp"
// #include "column.hpp"
// #include "csv_table.hpp"
#include "mlp.hpp"
#include "sgd.hpp"
#include "console_utils.hpp"
Expand Down
Loading

0 comments on commit 64d5244

Please sign in to comment.