Skip to content

Commit 4923740

Browse files
committed
Add env module
1 parent 1effeaa commit 4923740

File tree

4 files changed

+184
-0
lines changed

4 files changed

+184
-0
lines changed

src/env.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
use std::path::PathBuf;
2+
use std::result::Result as StdResult;
3+
4+
use mlua::{Lua, Result, Table};
5+
6+
/// Returns the current working directory
7+
fn current_dir(_lua: &Lua, _: ()) -> Result<StdResult<PathBuf, String>> {
8+
let dir = lua_try!(std::env::current_dir());
9+
Ok(Ok(dir))
10+
}
11+
12+
/// Changes the current working directory to the specified path
13+
fn set_current_dir(_lua: &Lua, path: String) -> Result<StdResult<bool, String>> {
14+
lua_try!(std::env::set_current_dir(path));
15+
Ok(Ok(true))
16+
}
17+
18+
/// Returns the full filesystem path of the current running executable
19+
fn current_exe(_lua: &Lua, _: ()) -> Result<StdResult<PathBuf, String>> {
20+
let exe = lua_try!(std::env::current_exe());
21+
Ok(Ok(exe))
22+
}
23+
24+
/// Returns the path of the current user’s home directory if known
25+
fn home_dir(_lua: &Lua, _: ()) -> Result<Option<PathBuf>> {
26+
Ok(std::env::home_dir())
27+
}
28+
29+
/// Fetches the environment variable key from the current process
30+
fn var(_lua: &Lua, key: String) -> Result<Option<String>> {
31+
Ok(std::env::var(key).ok())
32+
}
33+
34+
/// Returns a table containing all environment variables of the current process
35+
fn vars(lua: &Lua, _: ()) -> Result<Table> {
36+
lua.create_table_from(std::env::vars())
37+
}
38+
39+
/// Sets the environment variable key to the value in the current process
40+
///
41+
/// If value is Nil, the environment variable will be removed
42+
fn set_var(_lua: &Lua, (key, value): (String, Option<String>)) -> Result<()> {
43+
match value {
44+
Some(v) => unsafe { std::env::set_var(key, v) },
45+
None => unsafe { std::env::remove_var(key) },
46+
}
47+
Ok(())
48+
}
49+
50+
/// A loader for the `env` module.
51+
fn loader(lua: &Lua) -> Result<Table> {
52+
let t = lua.create_table()?;
53+
t.set("current_dir", lua.create_function(current_dir)?)?;
54+
t.set("set_current_dir", lua.create_function(set_current_dir)?)?;
55+
t.set("current_exe", lua.create_function(current_exe)?)?;
56+
t.set("home_dir", lua.create_function(home_dir)?)?;
57+
t.set("var", lua.create_function(var)?)?;
58+
t.set("vars", lua.create_function(vars)?)?;
59+
t.set("set_var", lua.create_function(set_var)?)?;
60+
61+
// Constants
62+
t.set("ARCH", std::env::consts::ARCH)?;
63+
t.set("FAMILY", std::env::consts::FAMILY)?;
64+
t.set("OS", std::env::consts::OS)?;
65+
66+
Ok(t)
67+
}
68+
69+
/// Registers the `yaml` module in the given Lua state.
70+
pub fn register(lua: &Lua, name: Option<&str>) -> Result<Table> {
71+
let name = name.unwrap_or("@env");
72+
let value = loader(lua)?;
73+
lua.register_module(name, &value)?;
74+
Ok(value)
75+
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ pub(crate) mod time;
1111

1212
pub mod assertions;
1313
pub mod bytes;
14+
pub mod env;
1415
pub mod testing;
1516

1617
#[cfg(feature = "json")]

tests/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ fn run_file(modname: &str) -> Result<()> {
99

1010
// Preload all modules
1111
mlua_stdlib::assertions::register(&lua, None)?;
12+
mlua_stdlib::env::register(&lua, None)?;
1213
let testing = mlua_stdlib::testing::register(&lua, None)?;
1314

1415
#[cfg(feature = "json")]
@@ -50,6 +51,7 @@ macro_rules! include_tests {
5051

5152
include_tests! {
5253
assertions,
54+
env,
5355
#[cfg(feature = "json")] json,
5456
#[cfg(feature = "regex")] regex,
5557
#[cfg(feature = "yaml")] yaml,

tests/lua/env_tests.lua

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
local env = require("@env")
2+
3+
testing:test("current_dir", function(t)
4+
local dir, err = env.current_dir()
5+
t.assert_eq(err, nil)
6+
t.assert(type(dir) == "string", "current_dir should return a string")
7+
t.assert(#dir > 0, "current_dir should not be empty")
8+
end)
9+
10+
testing:test("set_current_dir", function(t)
11+
if env.FAMILY ~= "unix" then
12+
t.skip("Skipping set_current_dir test on Windows")
13+
end
14+
15+
local original_dir = env.current_dir()
16+
local parent_dir = original_dir .. "/.."
17+
local ok, err1 = env.set_current_dir(parent_dir)
18+
t.assert_eq(err1, nil)
19+
t.assert_eq(ok, true)
20+
21+
-- Verify the directory changed
22+
local new_dir, err2 = env.current_dir()
23+
t.assert_eq(err2, nil)
24+
t.assert(new_dir ~= original_dir, "directory should have changed")
25+
26+
-- Change back to original directory
27+
local _, err3 = env.set_current_dir(original_dir)
28+
t.assert_eq(err3, nil)
29+
30+
-- Test invalid directory
31+
local _, err5 = env.set_current_dir("/nonexistent/directory/path")
32+
t.assert(err5 ~= nil, "should fail for nonexistent directory")
33+
t.assert(type(err5) == "string", "error should be a string")
34+
end)
35+
36+
testing:test("current_exe", function(t)
37+
local exe, err = env.current_exe()
38+
t.assert_eq(err, nil)
39+
t.assert(type(exe) == "string", "current_exe should return a string")
40+
t.assert(#exe > 0, "current_exe should not be empty")
41+
-- The executable path should be a valid path (contains forward slash on Unix systems)
42+
if env.FAMILY == "unix" then
43+
t.assert(exe:match("/"), "executable should be a full path")
44+
end
45+
end)
46+
47+
testing:test("home_dir", function(t)
48+
local home = env.home_dir()
49+
-- home_dir can return nil if home directory is not known
50+
if home ~= nil then
51+
t.assert(type(home) == "string", "home_dir should return a string when available")
52+
t.assert(#home > 0, "home_dir should not be empty when available")
53+
end
54+
end)
55+
56+
testing:test("var", function(t)
57+
-- Test getting a variable that likely doesn't exist
58+
local value = env.var("MLUA_STDLIB_NONEXISTENT_VAR")
59+
t.assert_eq(value, nil, "nonexistent variable should return nil")
60+
61+
-- Test getting PATH (should exist on most systems)
62+
local path = env.var("PATH")
63+
t.assert(type(path) == "string", "PATH should be a string")
64+
t.assert(#path > 0, "PATH should not be empty")
65+
end)
66+
67+
testing:test("set_var", function(t)
68+
local test_key = "MLUA_STDLIB_TEST_VAR"
69+
local test_value = "test_value_123"
70+
71+
-- Ensure the variable doesn't exist initially
72+
local initial = env.var(test_key)
73+
t.assert_eq(initial, nil, "test variable should not exist initially")
74+
75+
-- Set the variable
76+
env.set_var(test_key, test_value)
77+
t.assert_eq(env.var(test_key), test_value, "variable should be set correctly")
78+
79+
-- Update the variable
80+
local new_value = "updated_value_456"
81+
env.set_var(test_key, new_value)
82+
t.assert_eq(env.var(test_key), new_value, "variable should be updated correctly")
83+
84+
-- Remove the variable
85+
env.set_var(test_key, nil)
86+
t.assert_eq(env.var(test_key), nil, "variable should be removed when set to nil")
87+
end)
88+
89+
testing:test("vars", function(t)
90+
local all_vars = env.vars()
91+
t.assert(type(all_vars) == "table", "vars should return a table")
92+
93+
-- Check that common environment variables exist
94+
local path = all_vars["PATH"]
95+
t.assert(type(path) == "string", "PATH in vars should be a string")
96+
97+
-- Set a test variable and verify it appears in vars
98+
local test_key = "MLUA_STDLIB_TEST_VARS"
99+
local test_value = "test_vars_value"
100+
env.set_var(test_key, test_value)
101+
t.assert_eq(env.vars()[test_key], test_value, "test variable should appear in vars")
102+
103+
-- Clean up
104+
env.set_var(test_key, nil)
105+
t.assert_eq(env.vars()[test_key], nil, "test variable should be removed from vars")
106+
end)

0 commit comments

Comments
 (0)