Skip to content

Commit 7bc5f40

Browse files
sfc-gh-anavalosSnowflake Authors
and
Snowflake Authors
authored
Project import generated by Copybara. (#131)
GitOrigin-RevId: 376d560591c49a1cbb8de1922d03cb51867613b5 Co-authored-by: Snowflake Authors <[email protected]>
1 parent 38d2497 commit 7bc5f40

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+856
-575
lines changed

.bazelrc

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Common Default
22

33
# Wrapper to make sure tests are run.
4-
# Allow at most 3 hours for eternal tests.
5-
test --run_under='//bazel:test_wrapper' --test_timeout=-1,-1,-1,10800
4+
# Allow at most 4 hours for eternal tests.
5+
test --run_under='//bazel:test_wrapper' --test_timeout=-1,-1,-1,14400
66

77
# Since integration tests are located in different packages than code under test,
88
# the default instrumentation filter would exclude the code under test. This

CHANGELOG.md

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,20 @@
11
# Release History
22

3-
## 1.7.1
3+
## 1.7.2
4+
5+
### Bug Fixes
6+
7+
- Model Explainability: Fix issue that explain is enabled for scikit-learn pipeline
8+
whose task is UNKNOWN and fails later when invoked.
9+
10+
### Behavior Changes
11+
12+
### New Features
13+
14+
- Registry: Support asynchronous model inference service creation with the `block` option
15+
in `ModelVersion.create_service()` set to True by default.
16+
17+
## 1.7.1 (2024-11-05)
418

519
### Bug Fixes
620

README.md

+29-6
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ and deployment process, and includes two key components.
1212

1313
### Snowpark ML Development
1414

15-
[Snowpark ML Development](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#snowpark-ml-development)
15+
[Snowpark ML Development](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#ml-modeling)
1616
provides a collection of python APIs enabling efficient ML model development directly in Snowflake:
1717

1818
1. Modeling API (`snowflake.ml.modeling`) for data preprocessing, feature engineering and model training in Snowflake.
@@ -26,14 +26,21 @@ their native data loader formats.
2626
1. FileSet API: FileSet provides a Python fsspec-compliant API for materializing data into a Snowflake internal stage
2727
from a query or Snowpark Dataframe along with a number of convenience APIs.
2828

29-
### Snowpark Model Management [Public Preview]
29+
### Snowflake MLOps
3030

31-
[Snowpark Model Management](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#snowpark-ml-ops) complements
32-
the Snowpark ML Development API, and provides model management capabilities along with integrated deployment into Snowflake.
31+
Snowflake MLOps contains suit of tools and objects to make ML development cycle. It complements
32+
the Snowpark ML Development API, and provides end to end development to deployment within Snowflake.
3333
Currently, the API consists of:
3434

35-
1. Registry: A python API for managing models within Snowflake which also supports deployment of ML models into Snowflake
36-
as native MODEL object running with Snowflake Warehouse.
35+
1. [Registry](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#snowflake-model-registry): A python API
36+
allows secure deployment and management of models in Snowflake, supporting models trained both inside and outside of
37+
Snowflake.
38+
2. [Feature Store](https://docs.snowflake.com/en/developer-guide/snowpark-ml/index#snowflake-feature-store): A fully
39+
integrated solution for defining, managing, storing and discovering ML features derived from your data. The
40+
Snowflake Feature Store supports automated, incremental refresh from batch and streaming data sources, so that
41+
feature pipelines need be defined only once to be continuously updated with new data.
42+
3. [Datasets](https://docs.snowflake.com/developer-guide/snowflake-ml/overview#snowflake-datasets): Dataset provide an
43+
immutable, versioned snapshot of your data suitable for ingestion by your machine learning models.
3744

3845
## Getting started
3946

@@ -80,3 +87,19 @@ conda install \
8087

8188
Note that until a `snowflake-ml-python` package version is available in the official Snowflake conda channel, there may
8289
be compatibility issues. Server-side functionality that `snowflake-ml-python` depends on may not yet be released.
90+
91+
### Verifying the package
92+
93+
1. Install cosign.
94+
This example is using golang installation: [installing-cosign-with-go](https://edu.chainguard.dev/open-source/sigstore/cosign/how-to-install-cosign/#installing-cosign-with-go).
95+
1. Download the file from the repository like [pypi](https://pypi.org/project/snowflake-ml-python/#files).
96+
1. Download the signature files from the [release tag](https://github.com/snowflakedb/snowflake-ml-python/releases/tag/1.7.0).
97+
1. Verify signature on projects signed using Jenkins job:
98+
99+
```sh
100+
cosign verify-blob snowflake_ml_python-1.7.0.tar.gz --key snowflake-ml-python-1.7.0.pub --signature resources.linux.snowflake_ml_python-1.7.0.tar.gz.sig
101+
102+
cosign verify-blob snowflake_ml_python-1.7.0.tar.gz --key snowflake-ml-python-1.7.0.pub --signature resources.linux.snowflake_ml_python-1.7.0
103+
```
104+
105+
NOTE: Version 1.7.0 is used as example here. Please choose the the latest version.

bazel/environments/conda-env-snowflake.yml

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ dependencies:
3636
- protobuf==3.20.3
3737
- psutil==5.9.0
3838
- pyarrow==10.0.1
39+
- pyjwt==2.8.0
3940
- pytest-rerunfailures==12.0
4041
- pytest-xdist==3.5.0
4142
- pytest==7.4.0

bazel/environments/conda-env.yml

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ dependencies:
3636
- protobuf==3.20.3
3737
- psutil==5.9.0
3838
- pyarrow==10.0.1
39+
- pyjwt==2.8.0
3940
- pytest-rerunfailures==12.0
4041
- pytest-xdist==3.5.0
4142
- pytest==7.4.0

bazel/environments/conda-gpu-env.yml

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ dependencies:
3737
- protobuf==3.20.3
3838
- psutil==5.9.0
3939
- pyarrow==10.0.1
40+
- pyjwt==2.8.0
4041
- pytest-rerunfailures==12.0
4142
- pytest-xdist==3.5.0
4243
- pytest==7.4.0

ci/conda_recipe/meta.yaml

+2-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ build:
1717
noarch: python
1818
package:
1919
name: snowflake-ml-python
20-
version: 1.7.1
20+
version: 1.7.2
2121
requirements:
2222
build:
2323
- python
@@ -35,6 +35,7 @@ requirements:
3535
- packaging>=20.9,<25
3636
- pandas>=1.0.0,<3
3737
- pyarrow
38+
- pyjwt>=2.0.0, <3
3839
- pytimeparse>=1.1.8,<2
3940
- pyyaml>=6.0,<7
4041
- requests

ci/targets/quarantine/prod3.txt

-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,3 @@
55
//tests/integ/snowflake/ml/modeling/preprocessing:k_bins_discretizer_test
66
//tests/integ/snowflake/ml/modeling/linear_model:logistic_regression_test
77
//tests/integ/snowflake/ml/registry/model:registry_mlflow_model_test
8-
//tests/integ/snowflake/ml/registry/services/...

docs/README.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,5 @@ The following files are in the `docs/source` directory:
3232
- `index.rst`: ReStructuredText (RST) file that will be built as the index page.
3333
It mainly as a landing point and indicates the subp-ackages to include in the API reference.
3434
Currently these include the Modeling and FileSet/FileSystem APIs.
35-
- `fileset.rst`, `modeling.rst`, `registry.rst`: RST files that direct Sphinx to include the specific classes in each submodule.
35+
- RST files that direct Sphinx to include the specific classes in each submodule.
36+
- `fileset.rst`, `modeling.rst`, `monitoring.rst`, `registry.rst`

docs/source/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,5 @@ Table of Contents
3232
fileset
3333
model
3434
modeling
35+
monitoring
3536
registry

docs/source/monitoring.rst

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
===========================
2+
snowflake.ml.monitoring
3+
===========================
4+
5+
.. automodule:: snowflake.ml.monitoring
6+
:noindex:
7+
8+
snowflake.ml.monitoring.model_monitor
9+
-------------------------------------
10+
11+
.. currentmodule:: snowflake.ml.monitoring.model_monitor
12+
13+
.. rubric:: Classes
14+
15+
.. autosummary::
16+
:toctree: api/monitoring
17+
18+
ModelMonitor
19+
20+
snowflake.ml.monitoring.entities
21+
-------------------------------------
22+
23+
.. currentmodule:: snowflake.ml.monitoring.entities
24+
25+
.. rubric:: Classes
26+
27+
.. autosummary::
28+
:toctree: api/monitoring
29+
30+
model_monitor_config.ModelMonitorConfig
31+
model_monitor_config.ModelMonitorSourceConfig

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ peft==0.5.0
3232
protobuf==3.20.3
3333
psutil==5.9.0
3434
pyarrow==10.0.1
35+
pyjwt==2.8.0
3536
pytest-rerunfailures==12.0
3637
pytest-xdist==3.5.0
3738
pytest==7.4.0

requirements.yml

+3
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@
174174
- name: pyarrow
175175
dev_version: 10.0.1
176176
version_requirements: ''
177+
- name: pyjwt
178+
dev_version: 2.8.0
179+
version_requirements: '>=2.0.0, <3'
177180
- name: pytest
178181
dev_version: 7.4.0
179182
tags:

snowflake/ml/_internal/utils/BUILD.bazel

+5
Original file line numberDiff line numberDiff line change
@@ -249,3 +249,8 @@ py_test(
249249
"//snowflake/ml/test_utils:mock_session",
250250
],
251251
)
252+
253+
py_library(
254+
name = "jwt_generator",
255+
srcs = ["jwt_generator.py"],
256+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import base64
2+
import hashlib
3+
import logging
4+
from datetime import datetime, timedelta, timezone
5+
from typing import Optional
6+
7+
import jwt
8+
from cryptography.hazmat.primitives import serialization
9+
from cryptography.hazmat.primitives.asymmetric import types
10+
11+
logger = logging.getLogger(__name__)
12+
13+
ISSUER = "iss"
14+
EXPIRE_TIME = "exp"
15+
ISSUE_TIME = "iat"
16+
SUBJECT = "sub"
17+
18+
19+
class JWTGenerator:
20+
"""
21+
Creates and signs a JWT with the specified private key file, username, and account identifier. The JWTGenerator
22+
keeps the generated token and only regenerates the token if a specified period of time has passed.
23+
"""
24+
25+
_DEFAULT_LIFETIME = timedelta(minutes=59) # The tokens will have a 59-minute lifetime
26+
_DEFAULT_RENEWAL_DELTA = timedelta(minutes=54) # Tokens will be renewed after 54 minutes
27+
ALGORITHM = "RS256" # Tokens will be generated using RSA with SHA256
28+
29+
def __init__(
30+
self,
31+
account: str,
32+
user: str,
33+
private_key: types.PRIVATE_KEY_TYPES,
34+
lifetime: Optional[timedelta] = None,
35+
renewal_delay: Optional[timedelta] = None,
36+
) -> None:
37+
"""
38+
Create a new JWTGenerator object.
39+
40+
Args:
41+
account: The account identifier.
42+
user: The username.
43+
private_key: The private key used to sign the JWT.
44+
lifetime: The lifetime of the token.
45+
renewal_delay: The time before the token expires to renew it.
46+
"""
47+
48+
# Construct the fully qualified name of the user in uppercase.
49+
self.account = JWTGenerator._prepare_account_name_for_jwt(account)
50+
self.user = user.upper()
51+
self.qualified_username = self.account + "." + self.user
52+
self.private_key = private_key
53+
self.public_key_fp = JWTGenerator._calculate_public_key_fingerprint(self.private_key)
54+
55+
self.issuer = self.qualified_username + "." + self.public_key_fp
56+
self.lifetime = lifetime or JWTGenerator._DEFAULT_LIFETIME
57+
self.renewal_delay = renewal_delay or JWTGenerator._DEFAULT_RENEWAL_DELTA
58+
self.renew_time = datetime.now(timezone.utc)
59+
self.token: Optional[str] = None
60+
61+
logger.info(
62+
"""Creating JWTGenerator with arguments
63+
account : %s, user : %s, lifetime : %s, renewal_delay : %s""",
64+
self.account,
65+
self.user,
66+
self.lifetime,
67+
self.renewal_delay,
68+
)
69+
70+
@staticmethod
71+
def _prepare_account_name_for_jwt(raw_account: str) -> str:
72+
account = raw_account
73+
if ".global" not in account:
74+
# Handle the general case.
75+
idx = account.find(".")
76+
if idx > 0:
77+
account = account[0:idx]
78+
else:
79+
# Handle the replication case.
80+
idx = account.find("-")
81+
if idx > 0:
82+
account = account[0:idx]
83+
# Use uppercase for the account identifier.
84+
return account.upper()
85+
86+
def get_token(self) -> str:
87+
now = datetime.now(timezone.utc) # Fetch the current time
88+
if self.token is not None and self.renew_time > now:
89+
return self.token
90+
91+
# If the token has expired or doesn't exist, regenerate the token.
92+
logger.info(
93+
"Generating a new token because the present time (%s) is later than the renewal time (%s)",
94+
now,
95+
self.renew_time,
96+
)
97+
# Calculate the next time we need to renew the token.
98+
self.renew_time = now + self.renewal_delay
99+
100+
# Create our payload
101+
payload = {
102+
# Set the issuer to the fully qualified username concatenated with the public key fingerprint.
103+
ISSUER: self.issuer,
104+
# Set the subject to the fully qualified username.
105+
SUBJECT: self.qualified_username,
106+
# Set the issue time to now.
107+
ISSUE_TIME: now,
108+
# Set the expiration time, based on the lifetime specified for this object.
109+
EXPIRE_TIME: now + self.lifetime,
110+
}
111+
112+
# Regenerate the actual token
113+
token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM)
114+
# If you are using a version of PyJWT prior to 2.0, jwt.encode returns a byte string instead of a string.
115+
# If the token is a byte string, convert it to a string.
116+
if isinstance(token, bytes):
117+
token = token.decode("utf-8")
118+
self.token = token
119+
logger.info(
120+
"Generated a JWT with the following payload: %s",
121+
jwt.decode(self.token, key=self.private_key.public_key(), algorithms=[JWTGenerator.ALGORITHM]),
122+
)
123+
124+
return token
125+
126+
@staticmethod
127+
def _calculate_public_key_fingerprint(private_key: types.PRIVATE_KEY_TYPES) -> str:
128+
# Get the raw bytes of public key.
129+
public_key_raw = private_key.public_key().public_bytes(
130+
serialization.Encoding.DER, serialization.PublicFormat.SubjectPublicKeyInfo
131+
)
132+
133+
# Get the sha256 hash of the raw bytes.
134+
sha256hash = hashlib.sha256()
135+
sha256hash.update(public_key_raw)
136+
137+
# Base64-encode the value and prepend the prefix 'SHA256:'.
138+
public_key_fp = "SHA256:" + base64.b64encode(sha256hash.digest()).decode("utf-8")
139+
logger.info("Public key fingerprint is %s", public_key_fp)
140+
141+
return public_key_fp

0 commit comments

Comments
 (0)