Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion src/cpp/cli/lemonade_client.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#include "lemon_cli/lemonade_client.h"
#include <httplib.h>
#include <iostream>
#include <algorithm>
#include <iomanip>
#include <regex>
#include <sstream>
#include <nlohmann/json.hpp>

Expand All @@ -13,6 +15,40 @@ static const int DEFAULT_CONNECTION_TIMEOUT_MS = 30000;
static const int DEFAULT_READ_TIMEOUT_MS = 30000;
static const int LONG_TIMEOUT_MS = 86400000;

static std::regex build_name_filter_regex(const std::string& name_filter) {
std::string regex_pattern;
regex_pattern.reserve(name_filter.size() * 2);

for (char ch : name_filter) {
switch (ch) {
case '*':
regex_pattern += ".*";
break;
case '\\':
case '^':
case '$':
case '.':
case '|':
case '?':
case '+':
case '(':
case ')':
case '[':
case ']':
case '{':
case '}':
regex_pattern += '\\';
regex_pattern += ch;
break;
default:
regex_pattern += ch;
break;
}
}

return std::regex(regex_pattern, std::regex_constants::ECMAScript | std::regex_constants::icase);
}

HttpError::HttpError(int status, std::string body, const std::string& message)
: std::runtime_error(message), status_code_(status), response_body_(std::move(body)) {}

Expand Down Expand Up @@ -297,10 +333,18 @@ std::vector<ModelInfo> LemonadeClient::get_models(bool show_all) const {
return models;
}

int LemonadeClient::list_models(bool show_all) const {
int LemonadeClient::list_models(bool show_all, const std::string& name_filter) const {
try {
std::vector<ModelInfo> models = get_models(show_all);

if (!name_filter.empty()) {
const std::regex filter_regex = build_name_filter_regex(name_filter);
models.erase(
std::remove_if(models.begin(), models.end(),
[&](const ModelInfo& m) { return !std::regex_search(m.id, filter_regex); }),
models.end());
}

if (models.empty()) {
std::cout << "No models available" << std::endl;
return 0;
Expand Down
8 changes: 6 additions & 2 deletions src/cpp/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ struct CliConfig {
int port = 13305;
std::string api_key;
std::string model;
std::string list_filter;
std::map<std::string, std::string> checkpoints;
std::string recipe;
std::vector<std::string> labels;
Expand Down Expand Up @@ -1012,7 +1013,10 @@ int main(int argc, char* argv[]) {
CLI::App* cleanup_cmd = app.add_subcommand("cleanup-cache", "Clean up orphaned files in HuggingFace cache")->group("Model management");

// List options
list_cmd->add_flag("--downloaded", config.downloaded, "Save model options for future loads");
list_cmd->add_flag("--downloaded", config.downloaded, "Show only downloaded models");
list_cmd->add_option("name_filter", config.list_filter,
"Optional case-insensitive model-name filter; supports * wildcards")
->type_name("NAME_FILTER");

// Backend management options
backends_install_cmd->add_option("spec", config.backend_spec, "Backend spec (recipe:backend)")->required()->type_name("SPEC");
Expand Down Expand Up @@ -1151,7 +1155,7 @@ int main(int argc, char* argv[]) {
}
return client.status(config.port);
} else if (list_cmd->count() > 0) {
return client.list_models(!config.downloaded);
return client.list_models(!config.downloaded, config.list_filter);
} else if (pull_cmd->count() > 0) {
if (config.model.empty()) {
std::cerr << "Error: 'lemonade pull' requires a model name or Hugging Face checkpoint." << std::endl;
Expand Down
2 changes: 1 addition & 1 deletion src/cpp/include/lemon_cli/lemonade_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class LemonadeClient {
~LemonadeClient();

// Model management commands
int list_models(bool show_all) const;
int list_models(bool show_all, const std::string& name_filter = "") const;
int pull_model(const nlohmann::json& model_data);
int delete_model(const std::string& model_name) const;
int load_model(const std::string& model_name, const nlohmann::json& recipe_options, bool save_options = false) const;
Expand Down
Loading