diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..d4a2c44 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,21 @@ +# http://editorconfig.org + +root = true + +[*] +indent_style = space +indent_size = 4 +trim_trailing_whitespace = true +insert_final_newline = true +charset = utf-8 +end_of_line = lf + +[*.bat] +indent_style = tab +end_of_line = crlf + +[LICENSE] +insert_final_newline = false + +[Makefile] +indent_style = tab diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md new file mode 100644 index 0000000..f712034 --- /dev/null +++ b/.github/ISSUE_TEMPLATE.md @@ -0,0 +1,15 @@ +* xgboost2sql version: +* Python version: +* Operating System: + +### Description + +Describe what you were trying to get done. +Tell us what happened, what went wrong, and what you expected to happen. + +### What I Did + +``` +Paste the command(s) you ran and the output. +If there was a crash, please include the traceback here. +``` diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4c915d1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,106 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# dotenv +.env + +# virtualenv +.venv +venv/ +ENV/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# IDE settings +.vscode/ +.idea/ diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..043f5dc --- /dev/null +++ b/.travis.yml @@ -0,0 +1,28 @@ +# Config file for automatic testing at travis-ci.com + +language: python +python: + - 3.8 + - 3.7 + - 3.6 + +# Command to install dependencies, e.g. pip install -r requirements.txt --use-mirrors +install: pip install -U tox-travis + +# Command to run tests, e.g. python setup.py test +script: tox + +# Assuming you have installed the travis-ci CLI tool, after you +# create the Github repo and add it to Travis, run the +# following command to finish PyPI deployment setup: +# $ travis encrypt --add deploy.password +deploy: + provider: pypi + distributions: sdist bdist_wheel + user: RyanZheng + password: + secure: PLEASE_REPLACE_ME + on: + tags: true + repo: ZhengRyan/xgboost2sql + python: 3.8 diff --git a/AUTHORS.rst b/AUTHORS.rst new file mode 100644 index 0000000..c4a54fa --- /dev/null +++ b/AUTHORS.rst @@ -0,0 +1,13 @@ +======= +Credits +======= + +Development Lead +---------------- + +* RyanZheng + +Contributors +------------ + +None yet. Why not be the first? diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst new file mode 100644 index 0000000..c9c0b86 --- /dev/null +++ b/CONTRIBUTING.rst @@ -0,0 +1,128 @@ +.. highlight:: shell + +============ +Contributing +============ + +Contributions are welcome, and they are greatly appreciated! Every little bit +helps, and credit will always be given. + +You can contribute in many ways: + +Types of Contributions +---------------------- + +Report Bugs +~~~~~~~~~~~ + +Report bugs at https://github.com/ZhengRyan/xgboost2sql/issues. + +If you are reporting a bug, please include: + +* Your operating system name and version. +* Any details about your local setup that might be helpful in troubleshooting. +* Detailed steps to reproduce the bug. + +Fix Bugs +~~~~~~~~ + +Look through the GitHub issues for bugs. Anything tagged with "bug" and "help +wanted" is open to whoever wants to implement it. + +Implement Features +~~~~~~~~~~~~~~~~~~ + +Look through the GitHub issues for features. Anything tagged with "enhancement" +and "help wanted" is open to whoever wants to implement it. + +Write Documentation +~~~~~~~~~~~~~~~~~~~ + +xgboost2sql could always use more documentation, whether as part of the +official xgboost2sql docs, in docstrings, or even on the web in blog posts, +articles, and such. + +Submit Feedback +~~~~~~~~~~~~~~~ + +The best way to send feedback is to file an issue at https://github.com/ZhengRyan/xgboost2sql/issues. + +If you are proposing a feature: + +* Explain in detail how it would work. +* Keep the scope as narrow as possible, to make it easier to implement. +* Remember that this is a volunteer-driven project, and that contributions + are welcome :) + +Get Started! +------------ + +Ready to contribute? Here's how to set up `xgboost2sql` for local development. + +1. Fork the `xgboost2sql` repo on GitHub. +2. Clone your fork locally:: + + $ git clone git@github.com:your_name_here/xgboost2sql.git + +3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development:: + + $ mkvirtualenv xgboost2sql + $ cd xgboost2sql/ + $ python setup.py develop + +4. Create a branch for local development:: + + $ git checkout -b name-of-your-bugfix-or-feature + + Now you can make your changes locally. + +5. When you're done making changes, check that your changes pass flake8 and the + tests, including testing other Python versions with tox:: + + $ flake8 xgboost2sql tests + $ python setup.py test or pytest + $ tox + + To get flake8 and tox, just pip install them into your virtualenv. + +6. Commit your changes and push your branch to GitHub:: + + $ git add . + $ git commit -m "Your detailed description of your changes." + $ git push origin name-of-your-bugfix-or-feature + +7. Submit a pull request through the GitHub website. + +Pull Request Guidelines +----------------------- + +Before you submit a pull request, check that it meets these guidelines: + +1. The pull request should include tests. +2. If the pull request adds functionality, the docs should be updated. Put + your new functionality into a function with a docstring, and add the + feature to the list in README.rst. +3. The pull request should work for Python 3.5, 3.6, 3.7 and 3.8, and for PyPy. Check + https://travis-ci.com/ZhengRyan/xgboost2sql/pull_requests + and make sure that the tests pass for all supported Python versions. + +Tips +---- + +To run a subset of tests:: + +$ pytest tests.test_xgboost2sql + + +Deploying +--------- + +A reminder for the maintainers on how to deploy. +Make sure all your changes are committed (including an entry in HISTORY.rst). +Then run:: + +$ bump2version patch # possible: major / minor / patch +$ git push +$ git push --tags + +Travis will then deploy to PyPI if tests pass. diff --git a/HISTORY.rst b/HISTORY.rst new file mode 100644 index 0000000..4a6d9ae --- /dev/null +++ b/HISTORY.rst @@ -0,0 +1,8 @@ +======= +History +======= + +0.1.0 (2023-06-04) +------------------ + +* First release on PyPI. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..dda4589 --- /dev/null +++ b/LICENSE @@ -0,0 +1,22 @@ +MIT License + +Copyright (c) 2023, RyanZheng + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..965b2dd --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,11 @@ +include AUTHORS.rst +include CONTRIBUTING.rst +include HISTORY.rst +include LICENSE +include README.rst + +recursive-include tests * +recursive-exclude * __pycache__ +recursive-exclude * *.py[co] + +recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..7106847 --- /dev/null +++ b/Makefile @@ -0,0 +1,87 @@ +.PHONY: clean clean-build clean-pyc clean-test coverage dist docs help install lint lint/flake8 +.DEFAULT_GOAL := help + +define BROWSER_PYSCRIPT +import os, webbrowser, sys + +from urllib.request import pathname2url + +webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) +endef +export BROWSER_PYSCRIPT + +define PRINT_HELP_PYSCRIPT +import re, sys + +for line in sys.stdin: + match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) + if match: + target, help = match.groups() + print("%-20s %s" % (target, help)) +endef +export PRINT_HELP_PYSCRIPT + +BROWSER := python -c "$$BROWSER_PYSCRIPT" + +help: + @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) + +clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts + +clean-build: ## remove build artifacts + rm -fr build/ + rm -fr dist/ + rm -fr .eggs/ + find . -name '*.egg-info' -exec rm -fr {} + + find . -name '*.egg' -exec rm -f {} + + +clean-pyc: ## remove Python file artifacts + find . -name '*.pyc' -exec rm -f {} + + find . -name '*.pyo' -exec rm -f {} + + find . -name '*~' -exec rm -f {} + + find . -name '__pycache__' -exec rm -fr {} + + +clean-test: ## remove test and coverage artifacts + rm -fr .tox/ + rm -f .coverage + rm -fr htmlcov/ + rm -fr .pytest_cache + +lint/flake8: ## check style with flake8 + flake8 xgboost2sql tests + +lint: lint/flake8 ## check style + +test: ## run tests quickly with the default Python + pytest + +test-all: ## run tests on every Python version with tox + tox + +coverage: ## check code coverage quickly with the default Python + coverage run --source xgboost2sql -m pytest + coverage report -m + coverage html + $(BROWSER) htmlcov/index.html + +docs: ## generate Sphinx HTML documentation, including API docs + rm -f docs/xgboost2sql.rst + rm -f docs/modules.rst + sphinx-apidoc -o docs/ xgboost2sql + $(MAKE) -C docs clean + $(MAKE) -C docs html + $(BROWSER) docs/_build/html/index.html + +servedocs: docs ## compile the docs watching for changes + watchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D . + +release: dist ## package and upload a release + twine upload dist/* + +dist: clean ## builds source and wheel package + python setup.py sdist + python setup.py bdist_wheel + ls -l dist + +install: clean ## install the package to the active Python's site-packages + python setup.py install diff --git a/README.md b/README.md new file mode 100644 index 0000000..9bae7d1 --- /dev/null +++ b/README.md @@ -0,0 +1,492 @@ +# XGBoost模型转sql语句工具包 +现在是大数据量的时代,我们开发的模型要应用在特别大的待预测集上,使用单机的python,需要预测2、3天,甚至更久,中途很有可能中断。因此需要通过分布式的方式来预测。这个工具包就是实现了将训练好的python模型,转换成sql语句。将生成的sql语句可以放到大数据环境中进行分布式执行预测,能比单机的python预测快好几个量级 + + +## 思想碰撞 + +| 微信 | 微信公众号 | +| :---: | :----: | +| RyanZheng.png | 魔都数据干饭人.png | +| RyanZheng | 魔都数据干饭人 | + + +> 仓库地址:https://github.com/ZhengRyan/xgboost2sql +> +> 微信公众号文章: +> +> pipy包:https://pypi.org/project/xgboost2sql/ + + + +## 环境准备 +可以不用单独创建虚拟环境,因为对包的依赖没有版本要求 + +### `xgboost2sql` 安装 +pip install(pip安装) + +```bash +pip install xgboost2sql # to install +pip install -U xgboost2sql # to upgrade +``` + +Source code install(源码安装) + +```bash +python setup.py install +``` + +### 运行样例 +####【注意:::核验对比python模型预测出来的结果和sql语句预测出来的结果是否一致请查看"https://github.com/ZhengRyan/xgboost2sql/examples/tutorial_code.ipynb"教程代码】 + ++ 导入相关依赖 +```python +import xgboost as xgb +from sklearn.datasets import make_classification +from sklearn.model_selection import train_test_split +from xgboost2sql import XGBoost2Sql +``` + ++ 训练1个xgboost二分类模型 +```python +X, y = make_classification(n_samples=10000, + n_features=10, + n_informative=3, + n_redundant=2, + n_repeated=0, + n_classes=2, + weights=[0.7, 0.3], + flip_y=0.1, + random_state=1024) + +X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1024) + +###训练模型 +model = xgb.XGBClassifier(n_estimators=3) +model.fit(X_train, y_train) +xgb.to_graphviz(model) +``` ++ 使用xgboost2sql工具包将模型转换成的sql语句 +```python +xgb2sql = XGBoost2Sql() +sql_str = xgb2sql.transform(model) +``` + ++ 将sql语句保存 +```python +xgb2sql.save() +``` + ++ 将sql语句打印出来 +```python +print(sql_str) +``` + +```sql +select key,1 / (1 + exp(-(tree_1_score + tree_2_score + tree_3_score)+(-0.0))) as score + from ( + select key, + --tree1 + case when (f9<-1.64164519 or f9 is null) then + case when (f3<-4.19117069 or f3 is null) then + case when (f2<-1.31743848 or f2 is null) then + -0.150000006 + else + -0.544186056 + end + else + case when (f3<1.23432565 or f3 is null) then + case when (f7<-2.55682254 or f7 is null) then + -0.200000018 + else + case when (f5<0.154983491 or f5 is null) then + 0.544721723 + else + case when (f3<0.697217584 or f3 is null) then + -0.150000006 + else + 0.333333373 + end + end + end + else + case when (f5<-1.0218116 or f5 is null) then + case when (f0<-0.60882163 or f0 is null) then + case when (f2<0.26019755 or f2 is null) then + 0.0666666701 + else + -0.300000012 + end + else + -0.520000041 + end + else + 0.333333373 + end + end + end + else + case when (f9<1.60392439 or f9 is null) then + case when (f3<0.572542191 or f3 is null) then + case when (f5<0.653370142 or f5 is null) then + case when (f7<-0.765973091 or f7 is null) then + case when (f3<-0.432390809 or f3 is null) then + 0.204000011 + else + -0.485454559 + end + else + case when (f3<-1.20459461 or f3 is null) then + -0.5104478 + else + 0.441509455 + end + end + else + case when (f7<0.133017987 or f7 is null) then + case when (f8<0.320554674 or f8 is null) then + -0.290322572 + else + 0.368339777 + end + else + case when (f8<-0.211985052 or f8 is null) then + 0.504000008 + else + -0.525648415 + end + end + end + else + case when (f7<2.22314501 or f7 is null) then + case when (f8<-0.00532855326 or f8 is null) then + case when (f8<-0.204920739 or f8 is null) then + -0.533991575 + else + -0.200000018 + end + else + 0.428571463 + end + else + case when (f3<1.33772755 or f3 is null) then + case when (f0<-0.975171864 or f0 is null) then + 0.163636371 + else + 0.51818186 + end + else + -0 + end + end + end + else + case when (f3<1.77943277 or f3 is null) then + case when (f7<-0.469875157 or f7 is null) then + case when (f3<-0.536645889 or f3 is null) then + case when (f9<1.89841866 or f9 is null) then + -0 + else + 0.333333373 + end + else + case when (f4<-2.43660188 or f4 is null) then + 0.150000006 + else + 0.551020443 + end + end + else + case when (f1<-0.0788691565 or f1 is null) then + 0.150000006 + else + -0.375 + end + end + else + case when (f4<-1.73232496 or f4 is null) then + -0.150000006 + else + case when (f6<-1.6080606 or f6 is null) then + -0.150000006 + else + case when (f7<-0.259483218 or f7 is null) then + -0.558620751 + else + -0.300000012 + end + end + end + end + end + end + as tree_1_score, +--tree2 + case when (f9<-1.64164519 or f9 is null) then + case when (f3<-4.19117069 or f3 is null) then + case when (f0<0.942570388 or f0 is null) then + -0.432453066 + else + -0.128291652 + end + else + case when (f3<1.23432565 or f3 is null) then + case when (f7<-2.55682254 or f7 is null) then + -0.167702854 + else + case when (f5<0.154983491 or f5 is null) then + case when (f1<2.19985676 or f1 is null) then + 0.41752997 + else + 0.115944751 + end + else + 0.115584135 + end + end + else + case when (f5<-1.0218116 or f5 is null) then + case when (f0<-0.60882163 or f0 is null) then + -0.119530827 + else + -0.410788596 + end + else + 0.28256765 + end + end + end + else + case when (f9<1.60392439 or f9 is null) then + case when (f3<0.460727394 or f3 is null) then + case when (f5<0.653370142 or f5 is null) then + case when (f7<-0.933565617 or f7 is null) then + case when (f3<-0.572475374 or f3 is null) then + 0.182491601 + else + -0.377898693 + end + else + case when (f3<-1.20459461 or f3 is null) then + -0.392539263 + else + 0.352721155 + end + end + else + case when (f7<0.207098693 or f7 is null) then + case when (f8<0.498489976 or f8 is null) then + -0.193351224 + else + 0.29298231 + end + else + case when (f8<-0.117464997 or f8 is null) then + 0.400667101 + else + -0.402199954 + end + end + end + else + case when (f7<1.98268723 or f7 is null) then + case when (f8<-0.00532855326 or f8 is null) then + case when (f7<1.36281848 or f7 is null) then + -0.408002198 + else + -0.236123681 + end + else + case when (f5<1.14038813 or f5 is null) then + 0.404326111 + else + -0.110877581 + end + end + else + case when (f3<1.56952488 or f3 is null) then + case when (f5<2.14646816 or f5 is null) then + 0.409404457 + else + 0.0696995854 + end + else + -0.32059738 + end + end + end + else + case when (f3<1.77943277 or f3 is null) then + case when (f7<-0.469875157 or f7 is null) then + case when (f3<-0.536645889 or f3 is null) then + case when (f9<1.89841866 or f9 is null) then + -0 + else + 0.28256765 + end + else + 0.419863999 + end + else + case when (f3<0.444227457 or f3 is null) then + -0.34664312 + else + 0.0693304539 + end + end + else + case when (f4<-1.10089087 or f4 is null) then + case when (f3<2.3550868 or f3 is null) then + 0.0147894565 + else + -0.331404865 + end + else + -0.421277165 + end + end + end + end + as tree_2_score, +--tree3 + case when (f9<-1.64164519 or f9 is null) then + case when (f3<-4.19117069 or f3 is null) then + case when (f4<-1.30126143 or f4 is null) then + -0.0772174299 + else + -0.374165356 + end + else + case when (f3<1.23432565 or f3 is null) then + case when (f7<-2.55682254 or f7 is null) then + -0.142005175 + else + case when (f5<0.154983491 or f5 is null) then + case when (f7<3.59379435 or f7 is null) then + 0.352122813 + else + 0.132789165 + end + else + 0.0924336985 + end + end + else + case when (f5<-1.0218116 or f5 is null) then + case when (f0<-0.60882163 or f0 is null) then + -0.0954768136 + else + -0.351594836 + end + else + 0.245992288 + end + end + end + else + case when (f9<1.60392439 or f9 is null) then + case when (f3<0.347133756 or f3 is null) then + case when (f5<0.661561131 or f5 is null) then + case when (f7<-0.933565617 or f7 is null) then + case when (f3<-0.472413659 or f3 is null) then + 0.116336405 + else + -0.313245147 + end + else + case when (f3<-1.5402329 or f3 is null) then + -0.352897167 + else + 0.311400592 + end + end + else + case when (f7<0.275665522 or f7 is null) then + case when (f8<0.403402805 or f8 is null) then + -0.292606086 + else + 0.220064178 + end + else + case when (f8<-0.0442957953 or f8 is null) then + 0.350784421 + else + -0.336107522 + end + end + end + else + case when (f7<1.77503061 or f7 is null) then + case when (f8<0.196157426 or f8 is null) then + case when (f7<1.36281848 or f7 is null) then + -0.3376683 + else + -0.0711223111 + end + else + case when (f7<-0.661211252 or f7 is null) then + 0.434363276 + else + -0.219307661 + end + end + else + case when (f3<1.37940335 or f3 is null) then + case when (f6<1.34894884 or f6 is null) then + 0.367155522 + else + 0.124757253 + end + else + -0.293739736 + end + end + end + else + case when (f3<1.77943277 or f3 is null) then + case when (f7<-0.469875157 or f7 is null) then + case when (f3<-0.536645889 or f3 is null) then + case when (f9<1.89841866 or f9 is null) then + -0 + else + 0.245992288 + end + else + case when (f0<1.60565615 or f0 is null) then + 0.357973605 + else + 0.193993196 + end + end + else + case when (f9<1.89456153 or f9 is null) then + -0.276471078 + else + 0.111896731 + end + end + else + case when (f1<1.35706067 or f1 is null) then + case when (f4<-1.10089087 or f4 is null) then + case when (f3<2.3550868 or f3 is null) then + 0.0119848112 + else + -0.284813672 + end + else + -0.376859784 + end + else + case when (f2<-0.25748384 or f2 is null) then + 0.0723158419 + else + -0.253415495 + end + end + end + end + end + as tree_3_score + from data_table) +``` + + diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..4786e82 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = python -msphinx +SPHINXPROJ = xgboost2sql +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/authors.rst b/docs/authors.rst new file mode 100644 index 0000000..e122f91 --- /dev/null +++ b/docs/authors.rst @@ -0,0 +1 @@ +.. include:: ../AUTHORS.rst diff --git a/docs/conf.py b/docs/conf.py new file mode 100755 index 0000000..29c03e4 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python +# +# xgboost2sql documentation build configuration file, created by +# sphinx-quickstart on Fri Jun 9 13:47:02 2017. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +# If extensions (or modules to document with autodoc) are in another +# directory, add these directories to sys.path here. If the directory is +# relative to the documentation root, use os.path.abspath to make it +# absolute, like shown here. +# +import os +import sys +sys.path.insert(0, os.path.abspath('..')) + +import xgboost2sql + +# -- General configuration --------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +# +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. +extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode'] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +# source_suffix = ['.rst', '.md'] +source_suffix = '.rst' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = 'xgboost2sql' +copyright = "2023, RyanZheng" +author = "RyanZheng" + +# The version info for the project you're documenting, acts as replacement +# for |version| and |release|, also used in various other places throughout +# the built documents. +# +# The short X.Y version. +version = xgboost2sql.__version__ +# The full version, including alpha/beta/rc tags. +release = xgboost2sql.__version__ + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This patterns also effect to html_static_path and html_extra_path +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = False + + +# -- Options for HTML output ------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'alabaster' + +# Theme options are theme-specific and customize the look and feel of a +# theme further. For a list of options available for each theme, see the +# documentation. +# +# html_theme_options = {} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + + +# -- Options for HTMLHelp output --------------------------------------- + +# Output file base name for HTML help builder. +htmlhelp_basename = 'xgboost2sqldoc' + + +# -- Options for LaTeX output ------------------------------------------ + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, author, documentclass +# [howto, manual, or own class]). +latex_documents = [ + (master_doc, 'xgboost2sql.tex', + 'xgboost2sql Documentation', + 'RyanZheng', 'manual'), +] + + +# -- Options for manual page output ------------------------------------ + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + (master_doc, 'xgboost2sql', + 'xgboost2sql Documentation', + [author], 1) +] + + +# -- Options for Texinfo output ---------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + (master_doc, 'xgboost2sql', + 'xgboost2sql Documentation', + author, + 'xgboost2sql', + 'One line description of project.', + 'Miscellaneous'), +] + + + diff --git a/docs/contributing.rst b/docs/contributing.rst new file mode 100644 index 0000000..e582053 --- /dev/null +++ b/docs/contributing.rst @@ -0,0 +1 @@ +.. include:: ../CONTRIBUTING.rst diff --git a/docs/history.rst b/docs/history.rst new file mode 100644 index 0000000..2506499 --- /dev/null +++ b/docs/history.rst @@ -0,0 +1 @@ +.. include:: ../HISTORY.rst diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..9e75919 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,20 @@ +Welcome to xgboost2sql's documentation! +====================================== + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + readme + installation + usage + modules + contributing + authors + history + +Indices and tables +================== +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/installation.rst b/docs/installation.rst new file mode 100644 index 0000000..5601a84 --- /dev/null +++ b/docs/installation.rst @@ -0,0 +1,51 @@ +.. highlight:: shell + +============ +Installation +============ + + +Stable release +-------------- + +To install xgboost2sql, run this command in your terminal: + +.. code-block:: console + + $ pip install xgboost2sql + +This is the preferred method to install xgboost2sql, as it will always install the most recent stable release. + +If you don't have `pip`_ installed, this `Python installation guide`_ can guide +you through the process. + +.. _pip: https://pip.pypa.io +.. _Python installation guide: http://docs.python-guide.org/en/latest/starting/installation/ + + +From sources +------------ + +The sources for xgboost2sql can be downloaded from the `Github repo`_. + +You can either clone the public repository: + +.. code-block:: console + + $ git clone git://github.com/ZhengRyan/xgboost2sql + +Or download the `tarball`_: + +.. code-block:: console + + $ curl -OJL https://github.com/ZhengRyan/xgboost2sql/tarball/master + +Once you have a copy of the source, you can install it with: + +.. code-block:: console + + $ python setup.py install + + +.. _Github repo: https://github.com/ZhengRyan/xgboost2sql +.. _tarball: https://github.com/ZhengRyan/xgboost2sql/tarball/master diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..b688a60 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,36 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=python -msphinx +) +set SOURCEDIR=. +set BUILDDIR=_build +set SPHINXPROJ=xgboost2sql + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The Sphinx module was not found. Make sure you have Sphinx installed, + echo.then set the SPHINXBUILD environment variable to point to the full + echo.path of the 'sphinx-build' executable. Alternatively you may add the + echo.Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% + +:end +popd diff --git a/docs/readme.rst b/docs/readme.rst new file mode 100644 index 0000000..72a3355 --- /dev/null +++ b/docs/readme.rst @@ -0,0 +1 @@ +.. include:: ../README.rst diff --git a/docs/usage.rst b/docs/usage.rst new file mode 100644 index 0000000..0b70e11 --- /dev/null +++ b/docs/usage.rst @@ -0,0 +1,7 @@ +===== +Usage +===== + +To use xgboost2sql in a project:: + + import xgboost2sql diff --git a/examples/tutorial_code.ipynb b/examples/tutorial_code.ipynb new file mode 100644 index 0000000..89cefb3 --- /dev/null +++ b/examples/tutorial_code.ipynb @@ -0,0 +1,1974 @@ +{ + "cells": [ + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "一、训练1个xgboost二分类模型\n", + "二、使用模型对测试数据集进行预测\n", + "三、使用xgboost2sql包将模型转换成的sql语句\n", + "四、使用模型转换成的sql语句对测试数据集进行预测\n", + "五、对比python模型预测出来的结果和sql语句预测出来的结果是否一致(很重要,一定要认真核对)\n", + "六、将sql保存成文件" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pip install xgboost2sql" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "0\n", + "\n", + "f9<-1.64164519\n", + "\n", + "\n", + "\n", + "1\n", + "\n", + "f3<-4.19117069\n", + "\n", + "\n", + "\n", + "0->1\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "2\n", + "\n", + "f9<1.60392439\n", + "\n", + "\n", + "\n", + "0->2\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "3\n", + "\n", + "f2<-1.31743848\n", + "\n", + "\n", + "\n", + "1->3\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "4\n", + "\n", + "f3<1.23432565\n", + "\n", + "\n", + "\n", + "1->4\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "5\n", + "\n", + "f3<0.572542191\n", + "\n", + "\n", + "\n", + "2->5\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "6\n", + "\n", + "f3<1.77943277\n", + "\n", + "\n", + "\n", + "2->6\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "7\n", + "\n", + "leaf=-0.150000006\n", + "\n", + "\n", + "\n", + "3->7\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "8\n", + "\n", + "leaf=-0.544186056\n", + "\n", + "\n", + "\n", + "3->8\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "9\n", + "\n", + "f7<-2.55682254\n", + "\n", + "\n", + "\n", + "4->9\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "10\n", + "\n", + "f5<-1.0218116\n", + "\n", + "\n", + "\n", + "4->10\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "15\n", + "\n", + "leaf=-0.200000018\n", + "\n", + "\n", + "\n", + "9->15\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "16\n", + "\n", + "f5<0.154983491\n", + "\n", + "\n", + "\n", + "9->16\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "17\n", + "\n", + "f0<-0.60882163\n", + "\n", + "\n", + "\n", + "10->17\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "18\n", + "\n", + "leaf=0.333333373\n", + "\n", + "\n", + "\n", + "10->18\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "27\n", + "\n", + "leaf=0.544721723\n", + "\n", + "\n", + "\n", + "16->27\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "28\n", + "\n", + "f3<0.697217584\n", + "\n", + "\n", + "\n", + "16->28\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "45\n", + "\n", + "leaf=-0.150000006\n", + "\n", + "\n", + "\n", + "28->45\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "46\n", + "\n", + "leaf=0.333333373\n", + "\n", + "\n", + "\n", + "28->46\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "29\n", + "\n", + "f2<0.26019755\n", + "\n", + "\n", + "\n", + "17->29\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "30\n", + "\n", + "leaf=-0.520000041\n", + "\n", + "\n", + "\n", + "17->30\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "47\n", + "\n", + "leaf=0.0666666701\n", + "\n", + "\n", + "\n", + "29->47\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "48\n", + "\n", + "leaf=-0.300000012\n", + "\n", + "\n", + "\n", + "29->48\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "11\n", + "\n", + "f5<0.653370142\n", + "\n", + "\n", + "\n", + "5->11\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "12\n", + "\n", + "f7<2.22314501\n", + "\n", + "\n", + "\n", + "5->12\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "13\n", + "\n", + "f7<-0.469875157\n", + "\n", + "\n", + "\n", + "6->13\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "14\n", + "\n", + "f4<-1.73232496\n", + "\n", + "\n", + "\n", + "6->14\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "19\n", + "\n", + "f7<-0.765973091\n", + "\n", + "\n", + "\n", + "11->19\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "20\n", + "\n", + "f7<0.133017987\n", + "\n", + "\n", + "\n", + "11->20\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "21\n", + "\n", + "f8<-0.00532855326\n", + "\n", + "\n", + "\n", + "12->21\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "22\n", + "\n", + "f3<1.33772755\n", + "\n", + "\n", + "\n", + "12->22\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "31\n", + "\n", + "f3<-0.432390809\n", + "\n", + "\n", + "\n", + "19->31\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "32\n", + "\n", + "f3<-1.20459461\n", + "\n", + "\n", + "\n", + "19->32\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "33\n", + "\n", + "f8<0.320554674\n", + "\n", + "\n", + "\n", + "20->33\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "34\n", + "\n", + "f8<-0.211985052\n", + "\n", + "\n", + "\n", + "20->34\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "49\n", + "\n", + "leaf=0.204000011\n", + "\n", + "\n", + "\n", + "31->49\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "50\n", + "\n", + "leaf=-0.485454559\n", + "\n", + "\n", + "\n", + "31->50\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "51\n", + "\n", + "leaf=-0.5104478\n", + "\n", + "\n", + "\n", + "32->51\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "52\n", + "\n", + "leaf=0.441509455\n", + "\n", + "\n", + "\n", + "32->52\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "53\n", + "\n", + "leaf=-0.290322572\n", + "\n", + "\n", + "\n", + "33->53\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "54\n", + "\n", + "leaf=0.368339777\n", + "\n", + "\n", + "\n", + "33->54\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "55\n", + "\n", + "leaf=0.504000008\n", + "\n", + "\n", + "\n", + "34->55\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "56\n", + "\n", + "leaf=-0.525648415\n", + "\n", + "\n", + "\n", + "34->56\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "35\n", + "\n", + "f8<-0.204920739\n", + "\n", + "\n", + "\n", + "21->35\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "36\n", + "\n", + "leaf=0.428571463\n", + "\n", + "\n", + "\n", + "21->36\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "37\n", + "\n", + "f0<-0.975171864\n", + "\n", + "\n", + "\n", + "22->37\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "38\n", + "\n", + "leaf=-0\n", + "\n", + "\n", + "\n", + "22->38\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "57\n", + "\n", + "leaf=-0.533991575\n", + "\n", + "\n", + "\n", + "35->57\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "58\n", + "\n", + "leaf=-0.200000018\n", + "\n", + "\n", + "\n", + "35->58\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "59\n", + "\n", + "leaf=0.163636371\n", + "\n", + "\n", + "\n", + "37->59\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "60\n", + "\n", + "leaf=0.51818186\n", + "\n", + "\n", + "\n", + "37->60\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "23\n", + "\n", + "f3<-0.536645889\n", + "\n", + "\n", + "\n", + "13->23\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "24\n", + "\n", + "f1<-0.0788691565\n", + "\n", + "\n", + "\n", + "13->24\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "25\n", + "\n", + "leaf=-0.150000006\n", + "\n", + "\n", + "\n", + "14->25\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "26\n", + "\n", + "f6<-1.6080606\n", + "\n", + "\n", + "\n", + "14->26\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "39\n", + "\n", + "f9<1.89841866\n", + "\n", + "\n", + "\n", + "23->39\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "40\n", + "\n", + "f4<-2.43660188\n", + "\n", + "\n", + "\n", + "23->40\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "41\n", + "\n", + "leaf=0.150000006\n", + "\n", + "\n", + "\n", + "24->41\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "42\n", + "\n", + "leaf=-0.375\n", + "\n", + "\n", + "\n", + "24->42\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "61\n", + "\n", + "leaf=-0\n", + "\n", + "\n", + "\n", + "39->61\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "62\n", + "\n", + "leaf=0.333333373\n", + "\n", + "\n", + "\n", + "39->62\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "63\n", + "\n", + "leaf=0.150000006\n", + "\n", + "\n", + "\n", + "40->63\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "64\n", + "\n", + "leaf=0.551020443\n", + "\n", + "\n", + "\n", + "40->64\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "43\n", + "\n", + "leaf=-0.150000006\n", + "\n", + "\n", + "\n", + "26->43\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "44\n", + "\n", + "f7<-0.259483218\n", + "\n", + "\n", + "\n", + "26->44\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n", + "65\n", + "\n", + "leaf=-0.558620751\n", + "\n", + "\n", + "\n", + "44->65\n", + "\n", + "\n", + "yes, missing\n", + "\n", + "\n", + "\n", + "66\n", + "\n", + "leaf=-0.300000012\n", + "\n", + "\n", + "\n", + "44->66\n", + "\n", + "\n", + "no\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "###训练1个xgboost二分类模型\n", + "import xgboost as xgb\n", + "from sklearn.datasets import make_classification\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "X, y = make_classification(n_samples=10000,\n", + " n_features=10,\n", + " n_informative=3,\n", + " n_redundant=2,\n", + " n_repeated=0,\n", + " n_classes=2,\n", + " weights=[0.7, 0.3],\n", + " flip_y=0.1,\n", + " random_state=1024)\n", + "\n", + "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1024)\n", + "\n", + "###训练模型\n", + "model = xgb.XGBClassifier(n_estimators=3)\n", + "model.fit(X_train, y_train)\n", + "xgb.to_graphviz(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " key python_pred_res\n", + "0 0 0.788244247\n", + "1 1 0.220293656\n", + "2 2 0.220293656\n", + "3 3 0.790651679\n", + "4 4 0.220293656\n", + "... ... ...\n", + "2495 2495 0.220293656\n", + "2496 2496 0.220293656\n", + "2497 2497 0.217607751\n", + "2498 2498 0.220293656\n", + "2499 2499 0.762217879\n", + "\n", + "[2500 rows x 2 columns]\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "pd.set_option('display.float_format',lambda x : '%.9f' % x)\n", + "###使用模型对测试数据集进行预测\n", + "test_pred = model.predict_proba(X_test)[:, 1]\n", + "test_pred = pd.DataFrame(test_pred,columns=['python_pred_res'])\n", + "test_pred.reset_index(inplace=True)\n", + "test_pred.rename(columns={'index':'key'}, inplace=True)\n", + "print(test_pred)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " select key,1 / (1 + exp(-((tree_1_score + tree_2_score + tree_3_score)+(-0.0)))) as score\n", + " from (\n", + " select key,\n", + " --tree1\n", + "\t\tcase when (f9<-1.64164519 or f9 is null) then\n", + "\t\t\tcase when (f3<-4.19117069 or f3 is null) then\n", + "\t\t\t\tcase when (f2<-1.31743848 or f2 is null) then\n", + "\t\t\t\t\t-0.150000006\n", + "\t\t\t\telse\n", + "\t\t\t\t\t-0.544186056\n", + "\t\t\t\tend\n", + "\t\t\telse\n", + "\t\t\t\tcase when (f3<1.23432565 or f3 is null) then\n", + "\t\t\t\t\tcase when (f7<-2.55682254 or f7 is null) then\n", + "\t\t\t\t\t\t-0.200000018\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f5<0.154983491 or f5 is null) then\n", + "\t\t\t\t\t\t\t0.544721723\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f3<0.697217584 or f3 is null) then\n", + "\t\t\t\t\t\t\t\t-0.150000006\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.333333373\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\telse\n", + "\t\t\t\t\tcase when (f5<-1.0218116 or f5 is null) then\n", + "\t\t\t\t\t\tcase when (f0<-0.60882163 or f0 is null) then\n", + "\t\t\t\t\t\t\tcase when (f2<0.26019755 or f2 is null) then\n", + "\t\t\t\t\t\t\t\t0.0666666701\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.300000012\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t-0.520000041\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\t0.333333373\n", + "\t\t\t\t\tend\n", + "\t\t\t\tend\n", + "\t\t\tend\n", + "\t\telse\n", + "\t\t\tcase when (f9<1.60392439 or f9 is null) then\n", + "\t\t\t\tcase when (f3<0.572542191 or f3 is null) then\n", + "\t\t\t\t\tcase when (f5<0.653370142 or f5 is null) then\n", + "\t\t\t\t\t\tcase when (f7<-0.765973091 or f7 is null) then\n", + "\t\t\t\t\t\t\tcase when (f3<-0.432390809 or f3 is null) then\n", + "\t\t\t\t\t\t\t\t0.204000011\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.485454559\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f3<-1.20459461 or f3 is null) then\n", + "\t\t\t\t\t\t\t\t-0.5104478\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.441509455\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f7<0.133017987 or f7 is null) then\n", + "\t\t\t\t\t\t\tcase when (f8<0.320554674 or f8 is null) then\n", + "\t\t\t\t\t\t\t\t-0.290322572\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.368339777\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f8<-0.211985052 or f8 is null) then\n", + "\t\t\t\t\t\t\t\t0.504000008\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.525648415\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\telse\n", + "\t\t\t\t\tcase when (f7<2.22314501 or f7 is null) then\n", + "\t\t\t\t\t\tcase when (f8<-0.00532855326 or f8 is null) then\n", + "\t\t\t\t\t\t\tcase when (f8<-0.204920739 or f8 is null) then\n", + "\t\t\t\t\t\t\t\t-0.533991575\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.200000018\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t0.428571463\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f3<1.33772755 or f3 is null) then\n", + "\t\t\t\t\t\t\tcase when (f0<-0.975171864 or f0 is null) then\n", + "\t\t\t\t\t\t\t\t0.163636371\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.51818186\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t-0\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\tend\n", + "\t\t\telse\n", + "\t\t\t\tcase when (f3<1.77943277 or f3 is null) then\n", + "\t\t\t\t\tcase when (f7<-0.469875157 or f7 is null) then\n", + "\t\t\t\t\t\tcase when (f3<-0.536645889 or f3 is null) then\n", + "\t\t\t\t\t\t\tcase when (f9<1.89841866 or f9 is null) then\n", + "\t\t\t\t\t\t\t\t-0\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.333333373\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f4<-2.43660188 or f4 is null) then\n", + "\t\t\t\t\t\t\t\t0.150000006\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.551020443\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f1<-0.0788691565 or f1 is null) then\n", + "\t\t\t\t\t\t\t0.150000006\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t-0.375\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\telse\n", + "\t\t\t\t\tcase when (f4<-1.73232496 or f4 is null) then\n", + "\t\t\t\t\t\t-0.150000006\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f6<-1.6080606 or f6 is null) then\n", + "\t\t\t\t\t\t\t-0.150000006\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f7<-0.259483218 or f7 is null) then\n", + "\t\t\t\t\t\t\t\t-0.558620751\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.300000012\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\tend\n", + "\t\t\tend\n", + "\t\tend\n", + "\t\tas tree_1_score,\n", + "--tree2\n", + "\t\tcase when (f9<-1.64164519 or f9 is null) then\n", + "\t\t\tcase when (f3<-4.19117069 or f3 is null) then\n", + "\t\t\t\tcase when (f0<0.942570388 or f0 is null) then\n", + "\t\t\t\t\t-0.432453066\n", + "\t\t\t\telse\n", + "\t\t\t\t\t-0.128291652\n", + "\t\t\t\tend\n", + "\t\t\telse\n", + "\t\t\t\tcase when (f3<1.23432565 or f3 is null) then\n", + "\t\t\t\t\tcase when (f7<-2.55682254 or f7 is null) then\n", + "\t\t\t\t\t\t-0.167702854\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f5<0.154983491 or f5 is null) then\n", + "\t\t\t\t\t\t\tcase when (f1<2.19985676 or f1 is null) then\n", + "\t\t\t\t\t\t\t\t0.41752997\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.115944751\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t0.115584135\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\telse\n", + "\t\t\t\t\tcase when (f5<-1.0218116 or f5 is null) then\n", + "\t\t\t\t\t\tcase when (f0<-0.60882163 or f0 is null) then\n", + "\t\t\t\t\t\t\t-0.119530827\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t-0.410788596\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\t0.28256765\n", + "\t\t\t\t\tend\n", + "\t\t\t\tend\n", + "\t\t\tend\n", + "\t\telse\n", + "\t\t\tcase when (f9<1.60392439 or f9 is null) then\n", + "\t\t\t\tcase when (f3<0.460727394 or f3 is null) then\n", + "\t\t\t\t\tcase when (f5<0.653370142 or f5 is null) then\n", + "\t\t\t\t\t\tcase when (f7<-0.933565617 or f7 is null) then\n", + "\t\t\t\t\t\t\tcase when (f3<-0.572475374 or f3 is null) then\n", + "\t\t\t\t\t\t\t\t0.182491601\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.377898693\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f3<-1.20459461 or f3 is null) then\n", + "\t\t\t\t\t\t\t\t-0.392539263\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.352721155\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f7<0.207098693 or f7 is null) then\n", + "\t\t\t\t\t\t\tcase when (f8<0.498489976 or f8 is null) then\n", + "\t\t\t\t\t\t\t\t-0.193351224\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.29298231\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f8<-0.117464997 or f8 is null) then\n", + "\t\t\t\t\t\t\t\t0.400667101\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.402199954\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\telse\n", + "\t\t\t\t\tcase when (f7<1.98268723 or f7 is null) then\n", + "\t\t\t\t\t\tcase when (f8<-0.00532855326 or f8 is null) then\n", + "\t\t\t\t\t\t\tcase when (f7<1.36281848 or f7 is null) then\n", + "\t\t\t\t\t\t\t\t-0.408002198\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.236123681\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f5<1.14038813 or f5 is null) then\n", + "\t\t\t\t\t\t\t\t0.404326111\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.110877581\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f3<1.56952488 or f3 is null) then\n", + "\t\t\t\t\t\t\tcase when (f5<2.14646816 or f5 is null) then\n", + "\t\t\t\t\t\t\t\t0.409404457\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.0696995854\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t-0.32059738\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\tend\n", + "\t\t\telse\n", + "\t\t\t\tcase when (f3<1.77943277 or f3 is null) then\n", + "\t\t\t\t\tcase when (f7<-0.469875157 or f7 is null) then\n", + "\t\t\t\t\t\tcase when (f3<-0.536645889 or f3 is null) then\n", + "\t\t\t\t\t\t\tcase when (f9<1.89841866 or f9 is null) then\n", + "\t\t\t\t\t\t\t\t-0\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.28256765\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t0.419863999\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f3<0.444227457 or f3 is null) then\n", + "\t\t\t\t\t\t\t-0.34664312\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t0.0693304539\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\telse\n", + "\t\t\t\t\tcase when (f4<-1.10089087 or f4 is null) then\n", + "\t\t\t\t\t\tcase when (f3<2.3550868 or f3 is null) then\n", + "\t\t\t\t\t\t\t0.0147894565\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t-0.331404865\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\t-0.421277165\n", + "\t\t\t\t\tend\n", + "\t\t\t\tend\n", + "\t\t\tend\n", + "\t\tend\n", + "\t\tas tree_2_score,\n", + "--tree3\n", + "\t\tcase when (f9<-1.64164519 or f9 is null) then\n", + "\t\t\tcase when (f3<-4.19117069 or f3 is null) then\n", + "\t\t\t\tcase when (f4<-1.30126143 or f4 is null) then\n", + "\t\t\t\t\t-0.0772174299\n", + "\t\t\t\telse\n", + "\t\t\t\t\t-0.374165356\n", + "\t\t\t\tend\n", + "\t\t\telse\n", + "\t\t\t\tcase when (f3<1.23432565 or f3 is null) then\n", + "\t\t\t\t\tcase when (f7<-2.55682254 or f7 is null) then\n", + "\t\t\t\t\t\t-0.142005175\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f5<0.154983491 or f5 is null) then\n", + "\t\t\t\t\t\t\tcase when (f7<3.59379435 or f7 is null) then\n", + "\t\t\t\t\t\t\t\t0.352122813\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.132789165\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t0.0924336985\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\telse\n", + "\t\t\t\t\tcase when (f5<-1.0218116 or f5 is null) then\n", + "\t\t\t\t\t\tcase when (f0<-0.60882163 or f0 is null) then\n", + "\t\t\t\t\t\t\t-0.0954768136\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t-0.351594836\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\t0.245992288\n", + "\t\t\t\t\tend\n", + "\t\t\t\tend\n", + "\t\t\tend\n", + "\t\telse\n", + "\t\t\tcase when (f9<1.60392439 or f9 is null) then\n", + "\t\t\t\tcase when (f3<0.347133756 or f3 is null) then\n", + "\t\t\t\t\tcase when (f5<0.661561131 or f5 is null) then\n", + "\t\t\t\t\t\tcase when (f7<-0.933565617 or f7 is null) then\n", + "\t\t\t\t\t\t\tcase when (f3<-0.472413659 or f3 is null) then\n", + "\t\t\t\t\t\t\t\t0.116336405\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.313245147\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f3<-1.5402329 or f3 is null) then\n", + "\t\t\t\t\t\t\t\t-0.352897167\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.311400592\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f7<0.275665522 or f7 is null) then\n", + "\t\t\t\t\t\t\tcase when (f8<0.403402805 or f8 is null) then\n", + "\t\t\t\t\t\t\t\t-0.292606086\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.220064178\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f8<-0.0442957953 or f8 is null) then\n", + "\t\t\t\t\t\t\t\t0.350784421\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.336107522\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\telse\n", + "\t\t\t\t\tcase when (f7<1.77503061 or f7 is null) then\n", + "\t\t\t\t\t\tcase when (f8<0.196157426 or f8 is null) then\n", + "\t\t\t\t\t\t\tcase when (f7<1.36281848 or f7 is null) then\n", + "\t\t\t\t\t\t\t\t-0.3376683\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.0711223111\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f7<-0.661211252 or f7 is null) then\n", + "\t\t\t\t\t\t\t\t0.434363276\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.219307661\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f3<1.37940335 or f3 is null) then\n", + "\t\t\t\t\t\t\tcase when (f6<1.34894884 or f6 is null) then\n", + "\t\t\t\t\t\t\t\t0.367155522\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.124757253\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t-0.293739736\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\tend\n", + "\t\t\telse\n", + "\t\t\t\tcase when (f3<1.77943277 or f3 is null) then\n", + "\t\t\t\t\tcase when (f7<-0.469875157 or f7 is null) then\n", + "\t\t\t\t\t\tcase when (f3<-0.536645889 or f3 is null) then\n", + "\t\t\t\t\t\t\tcase when (f9<1.89841866 or f9 is null) then\n", + "\t\t\t\t\t\t\t\t-0\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.245992288\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f0<1.60565615 or f0 is null) then\n", + "\t\t\t\t\t\t\t\t0.357973605\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.193993196\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f9<1.89456153 or f9 is null) then\n", + "\t\t\t\t\t\t\t-0.276471078\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t0.111896731\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\telse\n", + "\t\t\t\t\tcase when (f1<1.35706067 or f1 is null) then\n", + "\t\t\t\t\t\tcase when (f4<-1.10089087 or f4 is null) then\n", + "\t\t\t\t\t\t\tcase when (f3<2.3550868 or f3 is null) then\n", + "\t\t\t\t\t\t\t\t0.0119848112\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.284813672\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t-0.376859784\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f2<-0.25748384 or f2 is null) then\n", + "\t\t\t\t\t\t\t0.0723158419\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t-0.253415495\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\tend\n", + "\t\t\tend\n", + "\t\tend\n", + "\t\tas tree_3_score\n", + " from data_table)\n", + " \n" + ] + } + ], + "source": [ + "from xgboost2sql import XGBoost2Sql\n", + "###使用xgboost2sql包将模型转换成的sql语句\n", + "xgb2sql = XGBoost2Sql()\n", + "sql_str = xgb2sql.transform(model)\n", + "print(sql_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
keyscore
00.0000000000.788244248
11.0000000000.220293659
\n", + "
" + ], + "text/plain": [ + " key score\n", + "0 0.000000000 0.788244248\n", + "1 1.000000000 0.220293659" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "###使用模型转换成的sql语句对测试数据集进行预测\n", + "from pyspark.sql import SparkSession\n", + "spark = SparkSession.builder.getOrCreate()\n", + "\n", + "X_test_df = pd.DataFrame(X_test)\n", + "X_test_df.columns = X_test_df.columns.map(lambda x: \"f\"+str(x))\n", + "X_test_df.reset_index(inplace=True)\n", + "X_test_df.rename(columns={'index':'key'}, inplace=True)\n", + "values = X_test_df.values.tolist()\n", + "columns = X_test_df.columns.tolist()\n", + "spark_df = spark.createDataFrame(values, columns)\n", + "spark_df.createOrReplaceTempView('data_table')\n", + "sql_pred_pysdf = spark.sql(sql_str)\n", + "sql_pred_df = sql_pred_pysdf.toPandas()\n", + "sql_pred_df.head(2)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "count 2500.000000000\n", + "mean -0.000000003\n", + "std 0.000000011\n", + "min -0.000000068\n", + "25% -0.000000003\n", + "50% -0.000000001\n", + "75% -0.000000000\n", + "max 0.000000056\n", + "Name: diff, dtype: float64" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "###对比python模型预测出来的结果和sql语句预测出来的结果是否一致\n", + "test_pred_sql_pred_df = test_pred.merge(sql_pred_df, on='key')\n", + "test_pred_sql_pred_df['diff'] = test_pred_sql_pred_df['python_pred_res'] - test_pred_sql_pred_df['score']\n", + "test_pred_sql_pred_df['diff'].describe()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " select key,1 / (1 + exp(-((tree_1_score + tree_2_score + tree_3_score)+(-0.0)))) as score\n", + " from (\n", + " select key,\n", + " --tree1\n", + "\t\tcase when (f9<-1.64164519 or f9 is null) then\n", + "\t\t\tcase when (f3<-4.19117069 or f3 is null) then\n", + "\t\t\t\tcase when (f2<-1.31743848 or f2 is null) then\n", + "\t\t\t\t\t-0.150000006\n", + "\t\t\t\telse\n", + "\t\t\t\t\t-0.544186056\n", + "\t\t\t\tend\n", + "\t\t\telse\n", + "\t\t\t\tcase when (f3<1.23432565 or f3 is null) then\n", + "\t\t\t\t\tcase when (f7<-2.55682254 or f7 is null) then\n", + "\t\t\t\t\t\t-0.200000018\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f5<0.154983491 or f5 is null) then\n", + "\t\t\t\t\t\t\t0.544721723\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f3<0.697217584 or f3 is null) then\n", + "\t\t\t\t\t\t\t\t-0.150000006\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.333333373\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\telse\n", + "\t\t\t\t\tcase when (f5<-1.0218116 or f5 is null) then\n", + "\t\t\t\t\t\tcase when (f0<-0.60882163 or f0 is null) then\n", + "\t\t\t\t\t\t\tcase when (f2<0.26019755 or f2 is null) then\n", + "\t\t\t\t\t\t\t\t0.0666666701\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.300000012\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t-0.520000041\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\t0.333333373\n", + "\t\t\t\t\tend\n", + "\t\t\t\tend\n", + "\t\t\tend\n", + "\t\telse\n", + "\t\t\tcase when (f9<1.60392439 or f9 is null) then\n", + "\t\t\t\tcase when (f3<0.572542191 or f3 is null) then\n", + "\t\t\t\t\tcase when (f5<0.653370142 or f5 is null) then\n", + "\t\t\t\t\t\tcase when (f7<-0.765973091 or f7 is null) then\n", + "\t\t\t\t\t\t\tcase when (f3<-0.432390809 or f3 is null) then\n", + "\t\t\t\t\t\t\t\t0.204000011\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.485454559\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f3<-1.20459461 or f3 is null) then\n", + "\t\t\t\t\t\t\t\t-0.5104478\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.441509455\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f7<0.133017987 or f7 is null) then\n", + "\t\t\t\t\t\t\tcase when (f8<0.320554674 or f8 is null) then\n", + "\t\t\t\t\t\t\t\t-0.290322572\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.368339777\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f8<-0.211985052 or f8 is null) then\n", + "\t\t\t\t\t\t\t\t0.504000008\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.525648415\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\telse\n", + "\t\t\t\t\tcase when (f7<2.22314501 or f7 is null) then\n", + "\t\t\t\t\t\tcase when (f8<-0.00532855326 or f8 is null) then\n", + "\t\t\t\t\t\t\tcase when (f8<-0.204920739 or f8 is null) then\n", + "\t\t\t\t\t\t\t\t-0.533991575\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.200000018\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t0.428571463\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f3<1.33772755 or f3 is null) then\n", + "\t\t\t\t\t\t\tcase when (f0<-0.975171864 or f0 is null) then\n", + "\t\t\t\t\t\t\t\t0.163636371\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.51818186\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t-0\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\tend\n", + "\t\t\telse\n", + "\t\t\t\tcase when (f3<1.77943277 or f3 is null) then\n", + "\t\t\t\t\tcase when (f7<-0.469875157 or f7 is null) then\n", + "\t\t\t\t\t\tcase when (f3<-0.536645889 or f3 is null) then\n", + "\t\t\t\t\t\t\tcase when (f9<1.89841866 or f9 is null) then\n", + "\t\t\t\t\t\t\t\t-0\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.333333373\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f4<-2.43660188 or f4 is null) then\n", + "\t\t\t\t\t\t\t\t0.150000006\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.551020443\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f1<-0.0788691565 or f1 is null) then\n", + "\t\t\t\t\t\t\t0.150000006\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t-0.375\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\telse\n", + "\t\t\t\t\tcase when (f4<-1.73232496 or f4 is null) then\n", + "\t\t\t\t\t\t-0.150000006\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f6<-1.6080606 or f6 is null) then\n", + "\t\t\t\t\t\t\t-0.150000006\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f7<-0.259483218 or f7 is null) then\n", + "\t\t\t\t\t\t\t\t-0.558620751\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.300000012\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\tend\n", + "\t\t\tend\n", + "\t\tend\n", + "\t\tas tree_1_score,\n", + "--tree2\n", + "\t\tcase when (f9<-1.64164519 or f9 is null) then\n", + "\t\t\tcase when (f3<-4.19117069 or f3 is null) then\n", + "\t\t\t\tcase when (f0<0.942570388 or f0 is null) then\n", + "\t\t\t\t\t-0.432453066\n", + "\t\t\t\telse\n", + "\t\t\t\t\t-0.128291652\n", + "\t\t\t\tend\n", + "\t\t\telse\n", + "\t\t\t\tcase when (f3<1.23432565 or f3 is null) then\n", + "\t\t\t\t\tcase when (f7<-2.55682254 or f7 is null) then\n", + "\t\t\t\t\t\t-0.167702854\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f5<0.154983491 or f5 is null) then\n", + "\t\t\t\t\t\t\tcase when (f1<2.19985676 or f1 is null) then\n", + "\t\t\t\t\t\t\t\t0.41752997\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.115944751\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t0.115584135\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\telse\n", + "\t\t\t\t\tcase when (f5<-1.0218116 or f5 is null) then\n", + "\t\t\t\t\t\tcase when (f0<-0.60882163 or f0 is null) then\n", + "\t\t\t\t\t\t\t-0.119530827\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t-0.410788596\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\t0.28256765\n", + "\t\t\t\t\tend\n", + "\t\t\t\tend\n", + "\t\t\tend\n", + "\t\telse\n", + "\t\t\tcase when (f9<1.60392439 or f9 is null) then\n", + "\t\t\t\tcase when (f3<0.460727394 or f3 is null) then\n", + "\t\t\t\t\tcase when (f5<0.653370142 or f5 is null) then\n", + "\t\t\t\t\t\tcase when (f7<-0.933565617 or f7 is null) then\n", + "\t\t\t\t\t\t\tcase when (f3<-0.572475374 or f3 is null) then\n", + "\t\t\t\t\t\t\t\t0.182491601\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.377898693\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f3<-1.20459461 or f3 is null) then\n", + "\t\t\t\t\t\t\t\t-0.392539263\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.352721155\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f7<0.207098693 or f7 is null) then\n", + "\t\t\t\t\t\t\tcase when (f8<0.498489976 or f8 is null) then\n", + "\t\t\t\t\t\t\t\t-0.193351224\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.29298231\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f8<-0.117464997 or f8 is null) then\n", + "\t\t\t\t\t\t\t\t0.400667101\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.402199954\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\telse\n", + "\t\t\t\t\tcase when (f7<1.98268723 or f7 is null) then\n", + "\t\t\t\t\t\tcase when (f8<-0.00532855326 or f8 is null) then\n", + "\t\t\t\t\t\t\tcase when (f7<1.36281848 or f7 is null) then\n", + "\t\t\t\t\t\t\t\t-0.408002198\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.236123681\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f5<1.14038813 or f5 is null) then\n", + "\t\t\t\t\t\t\t\t0.404326111\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.110877581\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f3<1.56952488 or f3 is null) then\n", + "\t\t\t\t\t\t\tcase when (f5<2.14646816 or f5 is null) then\n", + "\t\t\t\t\t\t\t\t0.409404457\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.0696995854\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t-0.32059738\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\tend\n", + "\t\t\telse\n", + "\t\t\t\tcase when (f3<1.77943277 or f3 is null) then\n", + "\t\t\t\t\tcase when (f7<-0.469875157 or f7 is null) then\n", + "\t\t\t\t\t\tcase when (f3<-0.536645889 or f3 is null) then\n", + "\t\t\t\t\t\t\tcase when (f9<1.89841866 or f9 is null) then\n", + "\t\t\t\t\t\t\t\t-0\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.28256765\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t0.419863999\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f3<0.444227457 or f3 is null) then\n", + "\t\t\t\t\t\t\t-0.34664312\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t0.0693304539\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\telse\n", + "\t\t\t\t\tcase when (f4<-1.10089087 or f4 is null) then\n", + "\t\t\t\t\t\tcase when (f3<2.3550868 or f3 is null) then\n", + "\t\t\t\t\t\t\t0.0147894565\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t-0.331404865\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\t-0.421277165\n", + "\t\t\t\t\tend\n", + "\t\t\t\tend\n", + "\t\t\tend\n", + "\t\tend\n", + "\t\tas tree_2_score,\n", + "--tree3\n", + "\t\tcase when (f9<-1.64164519 or f9 is null) then\n", + "\t\t\tcase when (f3<-4.19117069 or f3 is null) then\n", + "\t\t\t\tcase when (f4<-1.30126143 or f4 is null) then\n", + "\t\t\t\t\t-0.0772174299\n", + "\t\t\t\telse\n", + "\t\t\t\t\t-0.374165356\n", + "\t\t\t\tend\n", + "\t\t\telse\n", + "\t\t\t\tcase when (f3<1.23432565 or f3 is null) then\n", + "\t\t\t\t\tcase when (f7<-2.55682254 or f7 is null) then\n", + "\t\t\t\t\t\t-0.142005175\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f5<0.154983491 or f5 is null) then\n", + "\t\t\t\t\t\t\tcase when (f7<3.59379435 or f7 is null) then\n", + "\t\t\t\t\t\t\t\t0.352122813\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.132789165\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t0.0924336985\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\telse\n", + "\t\t\t\t\tcase when (f5<-1.0218116 or f5 is null) then\n", + "\t\t\t\t\t\tcase when (f0<-0.60882163 or f0 is null) then\n", + "\t\t\t\t\t\t\t-0.0954768136\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t-0.351594836\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\t0.245992288\n", + "\t\t\t\t\tend\n", + "\t\t\t\tend\n", + "\t\t\tend\n", + "\t\telse\n", + "\t\t\tcase when (f9<1.60392439 or f9 is null) then\n", + "\t\t\t\tcase when (f3<0.347133756 or f3 is null) then\n", + "\t\t\t\t\tcase when (f5<0.661561131 or f5 is null) then\n", + "\t\t\t\t\t\tcase when (f7<-0.933565617 or f7 is null) then\n", + "\t\t\t\t\t\t\tcase when (f3<-0.472413659 or f3 is null) then\n", + "\t\t\t\t\t\t\t\t0.116336405\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.313245147\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f3<-1.5402329 or f3 is null) then\n", + "\t\t\t\t\t\t\t\t-0.352897167\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.311400592\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f7<0.275665522 or f7 is null) then\n", + "\t\t\t\t\t\t\tcase when (f8<0.403402805 or f8 is null) then\n", + "\t\t\t\t\t\t\t\t-0.292606086\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.220064178\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f8<-0.0442957953 or f8 is null) then\n", + "\t\t\t\t\t\t\t\t0.350784421\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.336107522\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\telse\n", + "\t\t\t\t\tcase when (f7<1.77503061 or f7 is null) then\n", + "\t\t\t\t\t\tcase when (f8<0.196157426 or f8 is null) then\n", + "\t\t\t\t\t\t\tcase when (f7<1.36281848 or f7 is null) then\n", + "\t\t\t\t\t\t\t\t-0.3376683\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.0711223111\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f7<-0.661211252 or f7 is null) then\n", + "\t\t\t\t\t\t\t\t0.434363276\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.219307661\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f3<1.37940335 or f3 is null) then\n", + "\t\t\t\t\t\t\tcase when (f6<1.34894884 or f6 is null) then\n", + "\t\t\t\t\t\t\t\t0.367155522\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.124757253\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t-0.293739736\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\tend\n", + "\t\t\telse\n", + "\t\t\t\tcase when (f3<1.77943277 or f3 is null) then\n", + "\t\t\t\t\tcase when (f7<-0.469875157 or f7 is null) then\n", + "\t\t\t\t\t\tcase when (f3<-0.536645889 or f3 is null) then\n", + "\t\t\t\t\t\t\tcase when (f9<1.89841866 or f9 is null) then\n", + "\t\t\t\t\t\t\t\t-0\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.245992288\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\tcase when (f0<1.60565615 or f0 is null) then\n", + "\t\t\t\t\t\t\t\t0.357973605\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t0.193993196\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f9<1.89456153 or f9 is null) then\n", + "\t\t\t\t\t\t\t-0.276471078\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t0.111896731\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\telse\n", + "\t\t\t\t\tcase when (f1<1.35706067 or f1 is null) then\n", + "\t\t\t\t\t\tcase when (f4<-1.10089087 or f4 is null) then\n", + "\t\t\t\t\t\t\tcase when (f3<2.3550868 or f3 is null) then\n", + "\t\t\t\t\t\t\t\t0.0119848112\n", + "\t\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t\t-0.284813672\n", + "\t\t\t\t\t\t\tend\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t-0.376859784\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\telse\n", + "\t\t\t\t\t\tcase when (f2<-0.25748384 or f2 is null) then\n", + "\t\t\t\t\t\t\t0.0723158419\n", + "\t\t\t\t\t\telse\n", + "\t\t\t\t\t\t\t-0.253415495\n", + "\t\t\t\t\t\tend\n", + "\t\t\t\t\tend\n", + "\t\t\t\tend\n", + "\t\t\tend\n", + "\t\tend\n", + "\t\tas tree_3_score\n", + " from data_table)\n", + " \n" + ] + } + ], + "source": [ + "###可以看出差异只有小数位级别,说明sql_str这个语句就是1个模型了。可以放到大数据环境中进行分布式执行,能比单机的python预测快好几个量级\n", + "print(sql_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "###将sql保存\n", + "xgb2sql.save()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/images/RyanZheng.png b/images/RyanZheng.png new file mode 100644 index 0000000..fadb7b4 Binary files /dev/null and b/images/RyanZheng.png differ diff --git "a/images/\351\255\224\351\203\275\346\225\260\346\215\256\345\271\262\351\245\255\344\272\272.png" "b/images/\351\255\224\351\203\275\346\225\260\346\215\256\345\271\262\351\245\255\344\272\272.png" new file mode 100644 index 0000000..82cc4a6 Binary files /dev/null and "b/images/\351\255\224\351\203\275\346\225\260\346\215\256\345\271\262\351\245\255\344\272\272.png" differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..a340243 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +xgboost +pyspark + diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..3f759ae --- /dev/null +++ b/setup.cfg @@ -0,0 +1,20 @@ +[bumpversion] +current_version = 0.1.0 +commit = True +tag = True + +[bumpversion:file:setup.py] +search = version='{current_version}' +replace = version='{new_version}' + +[bumpversion:file:xgboost2sql/__init__.py] +search = __version__ = '{current_version}' +replace = __version__ = '{new_version}' + +[bdist_wheel] +universal = 1 + +[flake8] +exclude = docs +[tool:pytest] +collect_ignore = ['setup.py'] diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..4fd708c --- /dev/null +++ b/setup.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python + +"""The setup script.""" + +from setuptools import setup, find_packages + +with open('README.rst') as readme_file: + readme = readme_file.read() + +with open('HISTORY.rst') as history_file: + history = history_file.read() + +requirements = ['Click>=7.0', ] + +test_requirements = ['pytest>=3', ] + +setup( + author="RyanZheng", + author_email='zhengruiping000@163.com', + python_requires='>=3.6', + classifiers=[ + 'Development Status :: 2 - Pre-Alpha', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: MIT License', + 'Natural Language :: English', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + ], + description="Convert the trained xgboost model to sql", + entry_points={ + 'console_scripts': [ + 'xgboost2sql=xgboost2sql.cli:main', + ], + }, + install_requires=requirements, + license="MIT license", + long_description=readme + '\n\n' + history, + include_package_data=True, + keywords='xgboost2sql', + name='xgboost2sql', + packages=find_packages(include=['xgboost2sql', 'xgboost2sql.*']), + test_suite='tests', + tests_require=test_requirements, + url='https://github.com/ZhengRyan/xgboost2sql', + version='0.1.0', + zip_safe=False, +) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..ab1bec4 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Unit test package for xgboost2sql.""" diff --git a/tests/test_xgboost2sql.py b/tests/test_xgboost2sql.py new file mode 100644 index 0000000..e1a6dd3 --- /dev/null +++ b/tests/test_xgboost2sql.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python + +"""Tests for `xgboost2sql` package.""" + +import xgboost as xgb +from sklearn.datasets import make_classification +from sklearn.model_selection import train_test_split + +from xgboost2sql import XGBoost2Sql + +X, y = make_classification(n_samples=10000, + n_features=10, + n_informative=3, + n_redundant=2, + n_repeated=0, + n_classes=2, + weights=[0.7, 0.3], + flip_y=0.1, + random_state=1024) + +X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1024) + +###训练模型 +model = xgb.XGBClassifier(n_estimators=3) +model.fit(X_train, y_train) +xgb.to_graphviz(model) + +###使用xgboost2sql包将模型转换成的sql语句 +xgb2sql = XGBoost2Sql() +sql_str = xgb2sql.transform(model) +print(sql_str) +xgb2sql.save() diff --git a/xgboost2sql/__init__.py b/xgboost2sql/__init__.py new file mode 100644 index 0000000..fb496e3 --- /dev/null +++ b/xgboost2sql/__init__.py @@ -0,0 +1,7 @@ +"""Top-level package for xgboost2sql.""" + +__author__ = """RyanZheng""" +__email__ = 'zhengruiping000@163.com' +__version__ = '0.1.0' + +from .xgboost2sql import XGBoost2Sql diff --git a/xgboost2sql/xgboost2sql.py b/xgboost2sql/xgboost2sql.py new file mode 100644 index 0000000..6388f94 --- /dev/null +++ b/xgboost2sql/xgboost2sql.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python +# ! -*- coding: utf-8 -*- + +''' +@File: xgboost2sql.py +@Author: RyanZheng +@Email: ryan.zhengrp@gmail.com +@Created Time on: 2020-05-17 +''' + +import codecs +import json +import math +import warnings + +import xgboost + + +class XGBoost2Sql: + sql_str = ''' + select {0},1 / (1 + exp(-(({1})+({2})))) as score + from ( + select {0}, + {3} + from {4}) + ''' + code_str = '' + + def transform(self, xgboost_model, keep_columns=['key'], table_name='data_table'): + """ + + Args: + xgboost_model:训练好的xgboost模型 + keep_columns:使用sql语句进行预测时,需要保留的列。默认主键保留 + table_name:待预测数据集的表名 + + Returns:xgboost模型的sql预测语句 + + """ + + strs = self.get_dump_model(xgboost_model) + ### 读模型保存为txt的文件 + # with open('xgboost_model', 'r') as f: + # strs = f.read() + ### 读模型保存为txt的文件 + + logit = self.get_model_config(xgboost_model) + + # 解析的逻辑 + columns_l = [] + tree_list = strs.split('booster') + for i in range(1, len(tree_list)): + tree_str = tree_list[i] + lines = tree_str.split('\n') + v_lines = lines[1:-1] + + self.code_str += '--tree' + str(i) + '\n' + is_right = False + self.pre_tree(v_lines, is_right, 1) + columns_l.append('tree_{}_score'.format(i)) + if i == len(tree_list) - 1: + self.code_str += '\t\tas tree_{}_score'.format(i) + else: + self.code_str += '\t\tas tree_{}_score,\n'.format(i) + columns = ' + '.join(columns_l) + self.sql_str = self.sql_str.format(','.join(keep_columns), columns, logit, self.code_str, table_name) + return self.sql_str + + def get_dump_model(self, xgb_model): + """ + + Args: + xgb_model:xgboost模型 + + Returns: + + """ + if isinstance(xgb_model, xgboost.XGBClassifier): + xgb_model = xgb_model.get_booster() + # joblib.dump(xgb_model, 'xgb.ml') + # xgb_model.dump_model('xgb.txt') + # xgb_model.save_model('xgb.json') + ret = xgb_model.get_dump() + tree_strs = '' + for i, _ in enumerate(ret): + tree_strs += 'booster[{}]:\n'.format(i) + tree_strs += ret[i] + return tree_strs + + def get_model_config(self, xgb_model): + """ + + Args: + xgb_model:xgboost模型 + + Returns: + + """ + if isinstance(xgb_model, xgboost.XGBClassifier): + xgb_model = xgb_model.get_booster() + + try: + ###-math.log((1 / x) - 1) + x = float(json.loads(xgb_model.save_config())['learner']['learner_model_param']['base_score']) + return -math.log((1 - x) / x) + except: + warnings.warn( + 'xgboost model version less than :: 1.0.0, ' + 'If the base_score parameter is not 0.5 when developing the model, ' + 'Insert the base_score value into the formula "-math.log((1-x)/x)" ' + 'and replace the -0.0 value at +(-0.0) in the first sentence of the generated sql statement with the calculated value') + warnings.warn( + 'xgboost 模型的版本低于1.0.0,如果开发模型时, base_score 参数不是0.5,' + '请将base_score的参数取值带入"-math.log((1 - x) / x)"公式,计算出的值,替换掉生成的sql语句第1句中的+(-0.0)处的-0.0取值') + return 0 + + def pre_tree(self, lines, is_right, n): + """ + + Args: + lines:二叉树行 + is_right:是否右边 + n:第几层 + + Returns: + + """ + n += 1 + res = '' + if len(lines) <= 1: + str = lines[0].strip() + if 'leaf=' in str: + tmp = str.split('leaf=') + if len(tmp) > 1: + if is_right: + format = '\t' * (n - 1) + res = format + 'else\n' + format + '\t' + tmp[1].strip() + '\n' + format + 'end' + else: + format = '\t' * n + res = format + tmp[1].strip() + self.code_str += res + '\n' + return + v = lines[0].strip() + start_index = v.find('[') + median_index = v.find('<') + end_index = v.find(']') + v_name = v[start_index + 1:median_index].strip() + v_value = v[median_index:end_index] + ynm = v[end_index + 1:].strip().split(',') + yes_v = int(ynm[0].replace('yes=', '').strip()) + no_v = int(ynm[1].replace('no=', '').strip()) + miss_v = int(ynm[2].replace('missing=', '').strip()) + z_lines = lines[1:] + + if is_right: + format = '\t' * (n - 1) + res = res + format + 'else' + '\n' + if miss_v == yes_v: + format = '\t' * n + res = res + format + 'case when (' + v_name + v_value + ' or ' + v_name + ' is null' + ') then' + else: + format = '\t' * n + res = res + format + 'case when (' + v_name + v_value + ' and ' + v_name + ' is null' + ') then' + self.code_str += res + '\n' + left_right = self.get_tree_str(z_lines, yes_v, no_v) + + left_lines = left_right[0] + right_lines = left_right[1] + self.pre_tree(left_lines, False, n) + self.pre_tree(right_lines, True, n) + if is_right: + format = '\t' * (n - 1) + self.code_str += format + 'end\n' + + def get_tree_str(self, lines, yes_flag, no_flag): + """ + + Args: + lines:二叉树行 + yes_flag:左边 + no_flag:右边 + + Returns: + + """ + res = [] + left_n = 0 + right_n = 0 + for i in range(len(lines)): + tmp = lines[i].strip() + f_index = tmp.find(':') + next_flag = int(tmp[:f_index]) + if next_flag == yes_flag: + left_n = i + if next_flag == no_flag: + right_n = i + if right_n > left_n: + res.append(list(lines[left_n:right_n])) + res.append(list(lines[right_n:])) + else: + res.append(lines[left_n:]) + res.append(lines[right_n:left_n]) + return res + + def save(self, filename='xgb_model.sql'): + """ + + Args: + filename:sql语句保存的位置 + + Returns: + + """ + with codecs.open(filename, 'w', encoding='utf-8') as f: + f.write(self.sql_str) + + +if __name__ == '__main__': + ###训练1个xgboost二分类模型 + import xgboost as xgb + from sklearn.datasets import make_classification + from sklearn.model_selection import train_test_split + + X, y = make_classification(n_samples=10000, + n_features=10, + n_informative=3, + n_redundant=2, + n_repeated=0, + n_classes=2, + weights=[0.7, 0.3], + flip_y=0.1, + random_state=1024) + + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1024) + + ###训练模型 + model = xgb.XGBClassifier(n_estimators=3) + model.fit(X_train, y_train) + xgb.to_graphviz(model) + + ###使用xgboost2sql包将模型转换成的sql语句 + xgb2sql = XGBoost2Sql() + sql_str = xgb2sql.transform(model) + print(sql_str) + xgb2sql.save()