diff --git a/data/cmsTritonChecksumTool b/data/cmsTritonChecksumTool new file mode 100755 index 00000000000..eb13e33619c --- /dev/null +++ b/data/cmsTritonChecksumTool @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 + +import os, sys +from collections import OrderedDict +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter +from google.protobuf import text_format, message, descriptor +from tritonclient import grpc + +def get_checksum(filename, chunksize=4096): + import hashlib + with open(filename, 'rb') as f: + file_hash = hashlib.md5() + while chunk := f.read(chunksize): + file_hash.update(chunk) + return file_hash.hexdigest() + +def update_config(args): + # update config path to be output path (in case view is called) + if args.copy: + args.config = "config.pbtxt" + if isinstance(args.copy,str): + args.config = os.path.join(args.copy, args.config) + + with open(args.config,'w') as outfile: + text_format.PrintMessage(args.model, outfile, use_short_repeated_primitives=True) + +def cfg_common(args): + args.model = grpc.model_config_pb2.ModelConfig() + if hasattr(args,'config'): + with open(args.config,'r') as infile: + text_format.Parse(infile.read(), args.model) + +def cfg_checksum(args): + agents = args.model.model_repository_agents.agents + checksum_agent = next((agent for agent in agents if agent.name=="checksum"), None) + if checksum_agent is None: + checksum_agent = agents.add(name="checksum") + + incorrect = [] + missing = [] + + from glob import glob + config_dir = os.path.dirname(args.config) + for filename in glob(os.path.join(config_dir,"*/*")): + if os.path.islink(os.path.dirname(filename)): continue + checksum = get_checksum(filename) + # key = algorithm:[filename relative to config.pbtxt dir] + filename = os.path.relpath(filename, config_dir) + filekey = "MD5:{}".format(filename) + if filekey in checksum_agent.parameters and checksum!=checksum_agent.parameters[filekey]: + incorrect.append(filename) + if args.update and args.force: + checksum_agent.parameters[filekey] = checksum + elif filekey not in checksum_agent.parameters: + missing.append(filename) + if args.update: + checksum_agent.parameters[filekey] = checksum + else: + continue + + needs_update = len(missing)>0 + needs_force_update = len(incorrect)>0 + + if not args.quiet: + if needs_update: + print("\n".join(["Missing checksums:"]+missing)) + if needs_force_update: + print("\n".join(["Incorrect checksums:"]+incorrect)) + + if needs_force_update: + if not (args.update and args.force): + extra_args = [arg for arg in ["--update","--force"] if arg not in sys.argv] + raise RuntimeError("\n".join([ + "Incorrect checksum(s) found, indicating existing model file(s) has been changed.", + "This violates policy. To override, run the following command, and provide a justification in your PR:", + "{} {}".format(" ".join(sys.argv), " ".join(extra_args)) + ])) + else: + update_config(args) + elif needs_update: + if not args.update: + extra_args = [arg for arg in ["--update"] if arg not in sys.argv] + raise RuntimeError("\n".join([ + "Missing checksum(s) found, indicating new model file(s).", + "To update, run the following command:", + "{} {}".format(" ".join(sys.argv), " ".join(extra_args)) + ])) + else: + update_config(args) + +if __name__=="__main__": + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter, description="handle model file checksums") + parser.add_argument("-c", "--config", type=str, default="", required=True, help="path to input config.pbtxt file") + parser.add_argument("--copy", metavar="dir", default=False, const=True, nargs='?', type=str, + help="make a copy of config.pbtxt instead of editing in place ([dir] = output path for copy; if omitted, current directory is used)" + ) + parser.add_argument("--update", default=False, action="store_true", help="update missing checksums in config.pbtxt") + parser.add_argument("--force", default=False, action="store_true", help="force update all checksums in config.pbtxt") + parser.add_argument("--quiet", default=False, action="store_true", help="suppress printouts") + + args = parser.parse_args() + + cfg_common(args) + + cfg_checksum(args) diff --git a/pip/requirements.txt b/pip/requirements.txt index 3358a89bb74..bbc0a14991b 100644 --- a/pip/requirements.txt +++ b/pip/requirements.txt @@ -359,6 +359,8 @@ toolz==0.7.1 tornado==6.3.3 tqdm==4.65.0 traitlets==5.9.0 +# always sync version number with triton-inference-client.spec (C++ version) +tritonclient==2.25.0 trove-classifiers==2023.3.9 typed-ast==1.5.4 typing-extensions==4.5.0 diff --git a/protobuf.spec b/protobuf.spec index 29467042de0..b0e0962a68c 100644 --- a/protobuf.spec +++ b/protobuf.spec @@ -14,6 +14,8 @@ Source: https://github.com/protocolbuffers/protobuf/archive/v%{realversion}.zip Requires: zlib BuildRequires: cmake ninja +# improves text_format printing +Patch0: protobuf_text_format %prep %setup -n %{n}-%{realversion} diff --git a/protobuf_text_format.patch b/protobuf_text_format.patch new file mode 100644 index 00000000000..bebf7006d79 --- /dev/null +++ b/protobuf_text_format.patch @@ -0,0 +1,15 @@ +diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py +index a6d8bcf64..24da4cac5 100644 +--- a/python/google/protobuf/text_format.py ++++ b/python/google/protobuf/text_format.py +@@ -470,9 +470,7 @@ class _Printer(object): + entry_submsg = value.GetEntryClass()(key=key, value=value[key]) + self.PrintField(field, entry_submsg) + elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: +- if (self.use_short_repeated_primitives +- and field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE +- and field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_STRING): ++ if self.use_short_repeated_primitives: + self._PrintShortRepeatedPrimitivesValue(field, value) + else: + for element in value: diff --git a/python_tools.spec b/python_tools.spec index 7d318b85046..cc08f9b338f 100644 --- a/python_tools.spec +++ b/python_tools.spec @@ -156,6 +156,7 @@ Requires: py3-subprocess32 Requires: py3-kiwisolver Requires: py3-pillow Requires: py3-pydot +Requires: py3-tritonclient Requires: py3-astroid Requires: py3-hepdata-lib diff --git a/scram-project-build.file b/scram-project-build.file index f3c919df7d5..38171c0b9dd 100644 --- a/scram-project-build.file +++ b/scram-project-build.file @@ -312,6 +312,9 @@ find external/%cmsplatf -type l | xargs ls -l %{?PatchReleaseSymlinkRelocate:%PatchReleaseSymlinkRelocate} echo "%{cmsroot}" > %{i}/config/scram_basedir +# install checksum tool for local use (via scram) +cp data/cmsTritonChecksumTool %i/bin/ + %post export SCRAM_ARCH=%cmsplatf cd $RPM_INSTALL_PREFIX/%pkgrel diff --git a/triton-inference-client.spec b/triton-inference-client.spec index 2692763c69d..97b24e9d21c 100644 --- a/triton-inference-client.spec +++ b/triton-inference-client.spec @@ -1,3 +1,4 @@ +# always sync version number with tritonclient in pip/requirements.txt (python version) ### RPM external triton-inference-client 2.25.0 ## INCLUDE cpp-standard %define branch r22.08