Skip to content

Commit

Permalink
test: add test file
Browse files Browse the repository at this point in the history
  • Loading branch information
archeno committed Aug 29, 2024
1 parent 42ff243 commit 60024c7
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 2 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
dist/
dist/
**/__pycache__/
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ classifiers = [
# path = "src/fy_polyfit/polyfit.py"
source = "vcs"


[project.scripts]
polyfit = "fy_polyfit.polyfit:main"

Expand Down
Binary file removed src/fy_polyfit/__pycache__/polyfit.cpython-311.pyc
Binary file not shown.
12 changes: 11 additions & 1 deletion src/fy_polyfit/polyfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ def cal_rmse(y_true, y_pred):
return np.sqrt(np.mean((y_pred - y_true) **2))


def PolyfitLinear(x, y):
return np.polyfit(x, y, 1)

def PolyfitQuadratic(x, y):
return np.polyfit(x, y, 2)

def main():
parser = argparse.ArgumentParser(description='从指定文件中读取数据,默认为data.txt, 无头部,空格分隔,两列')
parser.add_argument('filename', nargs='?', default='data.txt',
Expand Down Expand Up @@ -108,8 +114,12 @@ def main():
ax[1].text(0.5, 0.8, formula2,color='blue',fontsize=16, transform=ax[1].transAxes, verticalalignment='top', horizontalalignment='center')

plt.tight_layout()

# plt.draw()
# 保存图片到images文件夹下
plt.savefig('../../images/fitting.png')
plt.show()

except FileNotFoundError as e:
print(e)

Expand Down
Empty file added tests/__init__.py
Empty file.
36 changes: 36 additions & 0 deletions tests/test_polyfit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import sys, os
import numpy as np
import pytest
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from fy_polyfit.polyfit import PolyfitLinear, PolyfitQuadratic

def test_polyfit():
x = np.array([1, 2, 3, 4, 5])
y = 2 * x + 1 + np.random.uniform(-0.5, 0.5, size=x.shape)

# 调用 PolyfitLinear
coefficients = PolyfitLinear(x, y)

# 期望的系数
expected_coefficients = np.polyfit(x, y, 1)

# 断言系数接近期望值
np.testing.assert_almost_equal(coefficients, expected_coefficients, decimal=2)

def test_polyfit_quadratic():
# 示例数据
x = np.array([1, 2, 3, 4, 5])
y = 3 * x**2 + 2 * x + 1 + np.random.uniform(-1, 1, size=x.shape)

# 调用 PolyfitQuadratic
coefficients = PolyfitQuadratic(x, y)

# 期望的系数
expected_coefficients = np.polyfit(x, y, 2)

# 断言系数接近期望值
np.testing.assert_almost_equal(coefficients, expected_coefficients, decimal=2)


if __name__ == "__main__":
pytest.main()

0 comments on commit 60024c7

Please sign in to comment.