From 9d4aa5784257bd54576883f2dcae203f9a4788a3 Mon Sep 17 00:00:00 2001 From: Kevin Pedro Date: Mon, 20 Nov 2023 18:14:19 -0600 Subject: [PATCH 01/12] tool to inspect and edit config.pbtxt files --- .../SonicTriton/scripts/cmsTriton | 60 +-- .../SonicTriton/scripts/cmsTritonConfigTool | 367 ++++++++++++++++++ 2 files changed, 374 insertions(+), 53 deletions(-) create mode 100755 HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool diff --git a/HeterogeneousCore/SonicTriton/scripts/cmsTriton b/HeterogeneousCore/SonicTriton/scripts/cmsTriton index addbfb2c247c7..9c84be2b62616 100755 --- a/HeterogeneousCore/SonicTriton/scripts/cmsTriton +++ b/HeterogeneousCore/SonicTriton/scripts/cmsTriton @@ -34,7 +34,7 @@ get_sandbox(){ usage() { ECHO="echo -e" - $ECHO "cmsTriton [options] [start|stop]" + $ECHO "cmsTriton [options] [start|stop|check]" $ECHO $ECHO "Options:" $ECHO "-c \t don't cleanup temporary dir (for debugging)" @@ -338,57 +338,6 @@ wait_server(){ echo "server is ready!" } -edit_model(){ - MODELNAME=$1 - NUMINSTANCES=$2 - - cp -r $MODELNAME $TMPDIR/$LOCALMODELREPO/ - COPY_EXIT=$? - if [ "$COPY_EXIT" -ne 0 ]; then - echo "Could not copy $MODELNAME into $TMPDIR/$LOCALMODELREPO/" - exit "$COPY_EXIT" - fi - IFS='/' read -ra ADDR <<< "$MODELNAME" - CONFIG=$TMPDIR/$LOCALMODELREPO/${ADDR[-1]}/config.pbtxt - - PLATFORM=$(grep -m 1 "^platform:" "$CONFIG") - - if [[ $PLATFORM == *"ensemble"* ]]; then - #recurse over submodels of ensemble model - MODELLOC=$(echo ""${ADDR[@]:0:${#ADDR[@]}-1} | sed "s/ /\//g") - SUBNAME=$(grep "model_name:" "$CONFIG" | sed 's/model_name://; s/"//g') - for SUBMODEL in ${SUBNAME}; do - SUBMODEL=${MODELLOC}/${SUBMODEL} - edit_model $SUBMODEL "$INSTANCES" - done - else - #This is not an ensemble model, so we should edit the config file - cat <> $CONFIG -instance_group [ - { - count: $NUMINSTANCES - kind: KIND_CPU - } -] - -EOF - if [[ $PLATFORM == *"onnx"* ]]; then - cat <> $CONFIG -parameters { key: "intra_op_thread_count" value: { string_value: "1" } } -parameters { key: "inter_op_thread_count" value: { string_value: "1" } } -EOF - elif [[ $PLATFORM == *"tensorflow"* ]]; then - cat <> $CONFIG -parameters { key: "TF_NUM_INTRA_THREADS" value: { string_value: "1" } } -parameters { key: "TF_NUM_INTER_THREADS" value: { string_value: "1" } } -parameters { key: "TF_USE_PER_SESSION_THREADS" value: { string_value: "1" } } -EOF - else - echo "Warning: thread (instance) control not implemented for $PLATFORM" - fi - fi -} - list_models(){ # make list of model repositories LOCALMODELREPO="local_model_repo" @@ -411,7 +360,12 @@ list_models(){ MODEL="$(dirname "$MODEL")" fi if [ "$INSTANCES" -gt 0 ]; then - edit_model $MODEL "$INSTANCES" + $DRYRUN cmsTritonConfigTool threadcontrol -c ${MODEL}/config.pbtxt --copy $TMPDIR/$LOCALMODELREPO --nThreads $INSTANCES + TOOL_EXIT=$? + if [ "$TOOL_EXIT" -ne 0 ]; then + echo "Could not apply threadcontrol to $MODEL" + exit "$TOOL_EXIT" + fi else REPOS+=("$(dirname "$MODEL")") fi diff --git a/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool b/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool new file mode 100755 index 0000000000000..42740035fd800 --- /dev/null +++ b/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool @@ -0,0 +1,367 @@ +#!/usr/bin/env python3 + +import os, sys, json, pathlib, shutil +from collections import OrderedDict +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, RawTextHelpFormatter, RawDescriptionHelpFormatter, Action +from google.protobuf import text_format, json_format, message, descriptor +from google.protobuf.internal import type_checkers +from tritonclient import grpc + +# convenience definition +# (from ConfigArgParse) +class ArgumentDefaultsRawHelpFormatter( + ArgumentDefaultsHelpFormatter, + RawTextHelpFormatter, + RawDescriptionHelpFormatter): + """HelpFormatter that adds default values AND doesn't do line-wrapping""" +pass + +class DictAction(Action): + val_type = None + def __call__(self, parser, namespace, values, option_string=None): + if self.val_type is None: + self.val_type = self.type + result = {} + if len(values)%2!=0: + parser.error("{} args must come in pairs".format(self.dest)) + for i in range(0, len(values), 2): + result[values[i]] = self.val_type(values[i+1]) + setattr(namespace, self.dest, result) + +message_classes = {cls.__name__ : cls for cls in message.Message.__subclasses__()} + +_FieldDescriptor = descriptor.FieldDescriptor +cpp_to_python = { + _FieldDescriptor.CPPTYPE_INT32: int, + _FieldDescriptor.CPPTYPE_INT64: int, + _FieldDescriptor.CPPTYPE_UINT32: int, + _FieldDescriptor.CPPTYPE_UINT64: int, + _FieldDescriptor.CPPTYPE_DOUBLE: float, + _FieldDescriptor.CPPTYPE_FLOAT: float, + _FieldDescriptor.CPPTYPE_BOOL: bool, + _FieldDescriptor.CPPTYPE_STRING: str, +} +checker_to_type = {val.__class__:cpp_to_python[key] for key,val in type_checkers._VALUE_CHECKERS.items()} +# for some reason, this one is not in the map +checker_to_type[type_checkers.UnicodeValueChecker] = str + +kind_to_int = {v.name:v.number for v in grpc.model_config_pb2._MODELINSTANCEGROUP_KIND.values} +thread_control_parameters = { + "onnx": ["intra_op_thread_count", "inter_op_thread_count"], + "tensorflow": ["TF_NUM_INTRA_THREADS", "TF_NUM_INTER_THREADS", "TF_USE_PER_SESSION_THREADS"], +} + +def get_type(obj): + obj_type = obj.__class__.__name__ + entry_type = None + entry_class = None + if obj_type=="RepeatedCompositeFieldContainer" or obj_type=="MessageMap": + entry_type = obj._message_descriptor.name + entry_class = message_classes[entry_type] + elif obj_type=="RepeatedScalarFieldContainer": + entry_class = checker_to_type[obj._type_checker.__class__] + entry_type = entry_class.__name__ + elif obj_type=="ScalarMap": + entry_class = obj.GetEntryClass()().value.__class__ + entry_type = entry_class.__name__ + return { + "class": obj.__class__, + "type": obj_type+("<"+entry_type+">" if entry_type is not None else ""), + "entry_class": entry_class, + "entry_type": entry_type, + } + +def get_fields(obj, name, level=0, verbose=False): + prefix = ' '*level + obj_info = {"name": name, "fields": []} + obj_info.update(get_type(obj)) + if verbose: print(prefix+obj_info["type"],name) + field_obj = None + if hasattr(obj, "DESCRIPTOR"): + field_obj = obj + elif obj_info["entry_class"] is not None and hasattr(obj_info["entry_class"], "DESCRIPTOR"): + field_obj = obj_info["entry_class"]() + field_list = [] + if field_obj is not None: + field_list = [f.name for f in field_obj.DESCRIPTOR.fields] + for field in field_list: + obj_info["fields"].append(get_fields(getattr(field_obj,field),field,level+1,verbose)) + return obj_info + +def msg_json(val, defaults=False): + return json_format.MessageToJson(val, preserving_proto_field_name=True, including_default_value_fields=defaults, indent=0).replace(",\n",", ").replace("\n","") + +def print_fields(obj, info, level=0, json=False, defaults=False): + def print_subfields(obj,level): + fields = obj.DESCRIPTOR.fields if defaults else [f[0] for f in obj.ListFields()] + for field in fields: + print_fields(getattr(obj,field.name), next(f for f in info["fields"] if f["name"]==field.name), level=level, json=json, defaults=defaults) + + prefix = ' ' + print(prefix*level+info["type"],info["name"]) + if hasattr(obj, "DESCRIPTOR"): + if json and level>0: + print(prefix*(level+1)+msg_json(obj, defaults)) + else: + print_subfields(obj,level+1) + elif info["type"].startswith("RepeatedCompositeFieldContainer"): + if json: + print(prefix*(level+1)+str([msg_json(val, defaults) for val in obj])) + else: + for ientry,entry in enumerate(obj): + print(prefix*(level+1)+"{}: ".format(ientry)) + print_subfields(entry,level+2) + elif info["type"].startswith("MessageMap"): + if json: + print(prefix*(level+1)+str({key:msg_json(val, defaults) for key,val in obj.items()})) + else: + for key,val in obj.items(): + print(prefix*(level+1)+"{}: ".format(key)) + print_subfields(val,level+2) + else: + print(prefix*(level+1)+str(obj)) + +def edit_builtin(model,dest,val): + setattr(model,dest,val) + +def edit_scalar_list(model,dest,val): + item = getattr(model,dest) + item.clear() + item.extend(val) + +def edit_scalar_map(model,dest,val): + item = getattr(model,dest) + item.clear() + item.update(val) + +def edit_msg(model,dest,val): + item = getattr(model,dest) + json_format.ParseDict(val,item) + +def edit_msg_list(model,dest,val): + item = getattr(model,dest) + item.clear() + for v in vals: + m = item.add() + json_format.ParseDict(v,m) + +def edit_msg_map(model,dest,val): + item = getattr(model,dest) + item.clear() + for k,v in vals.items(): + m = item.get_or_create(k) + json_format.ParseDict(v,m) + +def add_edit_args(parser, model_info): + group = parser.add_argument_group("fields", description="ModelConfig fields to edit") + dests = {} + for field in model_info["fields"]: + argname = "--{}".format(field["name"].replace("_","-")) + val_type = None + editor = None + if field["class"].__module__=="builtins": + kwargs = dict(type=field["class"]) + editor = edit_builtin + elif field["type"].startswith("RepeatedScalarFieldContainer"): + kwargs = dict(type=field["entry_class"], nargs='*') + editor = edit_scalar_list + elif field["type"].startswith("ScalarMap"): + kwargs = dict(type=str, nargs='*', metavar="key value", action=DictAction) + val_type = field["entry_class"] + editor = edit_scalar_map + elif field["type"].startswith("RepeatedCompositeFieldContainer"): + kwargs = dict(type=json.loads, nargs='*', + help="provide {} values in json format".format(field["entry_type"]) + ) + editor = edit_msg_list + elif field["type"].startswith("MessageMap"): + kwargs = dict(type=str, nargs='*', metavar="key value", action=DictAction, + help="provide {} values in json format".format(field["entry_type"]) + ) + editor = edit_msg_map + val_type = json.loads + else: + kwargs = dict(type=json.loads, + help="provide {} values in json format".format(field["type"]) + ) + edit = edit_msg + action = group.add_argument(argname, **kwargs) + if val_type is not None: action.val_type = val_type + dests[action.dest] = editor + return parser, dests + +def get_checksum(filename, chunksize=4096): + import hashlib + with open(filename, 'rb') as f: + file_hash = hashlib.md5() + while chunk := f.read(chunksize): + file_hash.update(chunk) + return file_hash.hexdigest() + +def update_config(args): + # update config path to be output path (in case view is called) + if args.copy: + args.config = "config.pbtxt" + if isinstance(args.copy,str): + args.config = os.path.join(args.copy, args.config) + + with open(args.config,'w') as outfile: + text_format.PrintMessage(args.model, outfile, use_short_repeated_primitives=True) + +def cfg_common(args): + args.model = grpc.model_config_pb2.ModelConfig() + if hasattr(args,'config'): + with open(args.config,'r') as infile: + text_format.Parse(infile.read(), args.model) + +def cfg_schema(args): + get_fields(args.model, "ModelConfig", verbose=True) + +def cfg_view(args): + print("Contents of {}".format(args.config)) + print_fields(args.model, args.model_info, json=args.json, defaults=args.defaults) + +def cfg_edit(args): + for dest,editor,val in [(dest,editor,getattr(args,dest)) for dest,editor in args.edit_dests.items() if getattr(args,dest) is not None]: + editor(args.model,dest,val) + + update_config(args) + + if args.view: + cfg_view(args) + +def cfg_checksum(args): + agents = args.model.model_repository_agents.agents + checksum_agent = next((agent for agent in agents if agent.name=="checksum"), None) + if checksum_agent is None: + checksum_agent = agents.add(name="checksum") + + incorrect = [] + missing = [] + + from glob import glob + config_dir = os.path.dirname(args.config) + for filename in glob(os.path.join(config_dir,"*/*")): + checksum = get_checksum(filename) + # key = algorithm:[filename relative to config.pbtxt dir] + filename = os.path.relpath(filename, config_dir) + filekey = "MD5:{}".format(filename) + if filekey in checksum_agent.parameters and checksum!=checksum_agent.parameters[filekey]: + incorrect.append(filename) + elif filekey not in checksum_agent.parameters: + missing.append(filename) + else: + continue + if args.update: + checksum_agent.parameters[filekey] = checksum + + if len(incorrect)>0 or len(missing)>0: + if not args.quiet: + if len(incorrect)>0: + print("\n".join(["Incorrect checksums:"]+incorrect)) + if len(missing)>0: + print("\n".join(["Missing checksums:"]+missing)) + if args.update: + update_config(args) + else: + sys.exit(1) + + if args.view: + cfg_view(args) + +def cfg_threadcontrol(args): + # copy the entire model, not just config.pbtxt + config_dir = os.path.dirname(args.config) + copy_dir = args.copy + new_config_dir = os.path.join(copy_dir, pathlib.Path(config_dir).name) + shutil.copytree(config_dir, new_config_dir) + + platform = args.model.platform + if platform=="ensemble": + repo_dir = pathlib.Path(config_dir).parent + for step in args.model.ensemble_scheduling.step: + # update args and run recursively + args.config = os.path.join(repo_dir,step.model_name,"config.pbtxt") + args.copy = copy_dir + cfg_common(args) + cfg_threadcontrol(args) + return + + # is it correct to do this even if found_params is false at the end? + args.model.instance_group.add(count=args.nThreads, kind=kind_to_int['KIND_CPU']) + + found_params = False + for key,val in thread_control_parameters.items(): + if key in platform: # partial matching + for param in val: + item = args.model.parameters.get_or_create(key) + item.string_value = "1" + found_params = True + break + if not found_params: + print("Warning: thread (instance) control not implemented for {}".format(platform)) + + args.copy = new_config_dir + update_config(args) + + if args.view: + cfg_view(args) + +if __name__=="__main__": + # initial common operations + model_info = get_fields(grpc.model_config_pb2.ModelConfig(), "ModelConfig") + edit_dests = None + + _parser_common = ArgumentParser(add_help=False) + _parser_common.add_argument("-c", "--config", type=str, default="", required=True, help="path to input config.pbtxt file") + + parser = ArgumentParser(formatter_class=ArgumentDefaultsRawHelpFormatter) + subparsers = parser.add_subparsers(dest="command") + + parser_schema = subparsers.add_parser("schema", help="view ModelConfig schema", + description="""Display all fields in the ModelConfig object, with type information. + (For collection types, the subfields of the entry type are shown.)""", + ) + parser_schema.set_defaults(func=cfg_schema) + + _parser_view_args = ArgumentParser(add_help=False) + _parser_view_args.add_argument("--json", default=False, action="store_true", help="display in json format") + _parser_view_args.add_argument("--defaults", default=False, action="store_true", help="show fields with default values") + + parser_view = subparsers.add_parser("view", parents=[_parser_common, _parser_view_args], help="view config.pbtxt contents") + parser_view.set_defaults(func=cfg_view) + + _parser_copy_view = ArgumentParser(add_help=False) + _parser_copy_view.add_argument("--view", default=False, action="store_true", help="view file after editing") + + _parser_copy = ArgumentParser(add_help=False, parents=[_parser_copy_view]) + _parser_copy.add_argument("--copy", metavar="dir", default=False, const=True, nargs='?', type=str, + help="make a copy of config.pbtxt instead of editing in place ([dir] = output path for copy; if omitted, current directory is used)" + ) + + parser_edit = subparsers.add_parser("edit", parents=[_parser_common, _parser_copy, _parser_view_args], help="edit config.pbtxt contents") + parser_edit, edit_dests = add_edit_args(parser_edit, model_info) + parser_edit.set_defaults(func=cfg_edit) + + parser_checksum = subparsers.add_parser("checksum", parents=[_parser_common, _parser_copy, _parser_view_args], help="handle model file checksums") + parser_checksum.add_argument("--update", default=False, action="store_true", help="update checksums in config.pbtxt") + parser_checksum.add_argument("--quiet", default=False, action="store_true", help="suppress printouts") + parser_checksum.set_defaults(func=cfg_checksum) + + _parser_copy_req = ArgumentParser(add_help=False, parents=[_parser_copy_view]) + _parser_copy_req.add_argument("--copy", metavar="dir", type=str, required=True, + help="local model repository directory to copy model(s)" + ) + + parser_threadcontrol = subparsers.add_parser("threadcontrol", parents=[_parser_common, _parser_copy_req, _parser_view_args], help="enable thread controls") + parser_threadcontrol.add_argument("--nThreads", type=int, required=True, help="number of threads") + parser_threadcontrol.set_defaults(func=cfg_threadcontrol) + + args = parser.parse_args() + args.model_info = model_info + if edit_dests is not None: + args.edit_dests = edit_dests + + cfg_common(args) + + args.func(args) From d7303c5d1134e7b59b738a3a3e92bdac353fea05 Mon Sep 17 00:00:00 2001 From: Kevin Pedro Date: Mon, 4 Dec 2023 18:09:35 -0600 Subject: [PATCH 02/12] perform version check in client --- .../SonicTriton/interface/triton_utils.h | 1 + .../SonicTriton/src/TritonClient.cc | 59 ++++++++++++++++--- .../SonicTriton/src/triton_utils.cc | 1 + 3 files changed, 54 insertions(+), 7 deletions(-) diff --git a/HeterogeneousCore/SonicTriton/interface/triton_utils.h b/HeterogeneousCore/SonicTriton/interface/triton_utils.h index 159da808edcab..d6c7612a5159c 100644 --- a/HeterogeneousCore/SonicTriton/interface/triton_utils.h +++ b/HeterogeneousCore/SonicTriton/interface/triton_utils.h @@ -83,6 +83,7 @@ extern template std::string triton_utils::printColl(const edm::Span& coll, const std::string& delim); extern template std::string triton_utils::printColl(const std::vector& coll, const std::string& delim); +extern template std::string triton_utils::printColl(const std::vector& coll, const std::string& delim); extern template std::string triton_utils::printColl(const std::unordered_set& coll, const std::string& delim); diff --git a/HeterogeneousCore/SonicTriton/src/TritonClient.cc b/HeterogeneousCore/SonicTriton/src/TritonClient.cc index c57a8355d07a1..14d586b0b547f 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonClient.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonClient.cc @@ -9,11 +9,18 @@ #include "grpc_client.h" #include "grpc_service.pb.h" +#include "model_config.pb.h" -#include +#include "google/protobuf/text_format.h" +#include "google/protobuf/io/zero_copy_stream_impl.h" + +#include #include #include +#include +#include #include +#include #include #include @@ -75,22 +82,60 @@ TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& d //convert seconds to microseconds options_[0].client_timeout_ = params.getUntrackedParameter("timeout") * 1e6; - //config needed for batch size - inference::ModelConfigResponse modelConfigResponse; - TRITON_THROW_IF_ERROR(client_->ModelConfig(&modelConfigResponse, options_[0].model_name_, options_[0].model_version_), - "TritonClient(): unable to get model config"); - inference::ModelConfig modelConfig(modelConfigResponse.config()); + //get fixed parameters from local config + inference::ModelConfig localModelConfig; + { + const std::string& localModelConfigPath(params.getParameter("modelConfigPath").fullPath()); + int fileDescriptor = open(localModelConfigPath.c_str(), O_RDONLY); + if (fileDescriptor < 0) + throw TritonException("LocalFailure") << "TritonClient(): unable to open local model config: " << localModelConfigPath; + google::protobuf::io::FileInputStream localModelConfigInput(fileDescriptor); + localModelConfigInput.SetCloseOnDelete(true); + if (!google::protobuf::TextFormat::Parse(&localModelConfigInput, &localModelConfig)) + throw TritonException("LocalFailure") << "TritonClient(): unable to parse local model config: " << localModelConfigPath; + } //check batch size limitations (after i/o setup) //triton uses max batch size = 0 to denote a model that does not support native batching (using the outer dimension) //but for models that do support batching (native or otherwise), a given event may set batch size 0 to indicate no valid input is present //so set the local max to 1 and keep track of "no outer dim" case - maxOuterDim_ = modelConfig.max_batch_size(); + maxOuterDim_ = localModelConfig.max_batch_size(); noOuterDim_ = maxOuterDim_ == 0; maxOuterDim_ = std::max(1u, maxOuterDim_); //propagate batch size setBatchSize(1); + //compare model checksums to remote config to enforce versioning + inference::ModelConfigResponse modelConfigResponse; + TRITON_THROW_IF_ERROR(client_->ModelConfig(&modelConfigResponse, options_[0].model_name_, options_[0].model_version_), + "TritonClient(): unable to get model config"); + inference::ModelConfig remoteModelConfig(modelConfigResponse.config()); + + std::map> checksums; + size_t fileCounter = 0; + for (const auto& modelConfig: {localModelConfig, remoteModelConfig}) { + const auto& agents = modelConfig.model_repository_agents().agents(); + for (const auto& agent : agents) { + if (agent.name() == "checksum") { + const auto& params = agent.parameters(); + for (const auto& [key, val]: params) { + // only check the requested version + if (key.compare(0, options_[0].model_version_.size()+1, options_[0].model_version_+"/")==0) + checksums[key][fileCounter] = val; + } + break; + } + } + ++fileCounter; + } + std::vector incorrect; + for (const auto& [key, val]: checksums) { + if (checksums[key][0] != checksums[key][1]) + incorrect.push_back(key); + } + if (!incorrect.empty()) + throw TritonException("ModelVersioning") << "The following files have incorrect checksums on the remote server: " << triton_utils::printColl(incorrect, ", "); + //get model info inference::ModelMetadataResponse modelMetadata; TRITON_THROW_IF_ERROR(client_->ModelMetadata(&modelMetadata, options_[0].model_name_, options_[0].model_version_), diff --git a/HeterogeneousCore/SonicTriton/src/triton_utils.cc b/HeterogeneousCore/SonicTriton/src/triton_utils.cc index 3dc872d6e1b42..a71190d951e46 100644 --- a/HeterogeneousCore/SonicTriton/src/triton_utils.cc +++ b/HeterogeneousCore/SonicTriton/src/triton_utils.cc @@ -21,4 +21,5 @@ template std::string triton_utils::printColl(const edm::Span& coll, const std::string& delim); template std::string triton_utils::printColl(const std::vector& coll, const std::string& delim); +template std::string triton_utils::printColl(const std::vector& coll, const std::string& delim); template std::string triton_utils::printColl(const std::unordered_set& coll, const std::string& delim); From 713f6869b468f824bb1502db126b660c8cb22b32 Mon Sep 17 00:00:00 2001 From: Kevin Pedro Date: Mon, 11 Dec 2023 13:52:29 -0600 Subject: [PATCH 03/12] nicer printing for repeated messages --- .../SonicTriton/scripts/cmsTritonConfigTool | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool b/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool index 42740035fd800..2e73407ebba78 100755 --- a/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool +++ b/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool @@ -28,6 +28,35 @@ class DictAction(Action): result[values[i]] = self.val_type(values[i+1]) setattr(namespace, self.dest, result) +# patched version for more succinct printing of repeated messages +def PrintMessage(self, message): + if self.message_formatter and self._TryCustomFormatMessage(message): + return + if (message.DESCRIPTOR.full_name == text_format._ANY_FULL_TYPE_NAME and + self._TryPrintAsAnyMessage(message)): + return + fields = message.ListFields() + if self.use_index_order: + fields.sort(key=lambda x: x[0].number if x[0].is_extension else x[0].index) + for field, value in fields: + if text_format._IsMapEntry(field): + for key in sorted(value): + entry_submsg = value.GetEntryClass()(key=key, value=value[key]) + self.PrintField(field, entry_submsg) + elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if self.use_short_repeated_primitives: # no other conditions here + self._PrintShortRepeatedPrimitivesValue(field, value) + else: + for element in value: + self.PrintField(field, element) + else: + self.PrintField(field, value) + + if self.print_unknown_fields: + self._PrintUnknownFields(unknown_fields.UnknownFieldSet(message)) + +text_format._Printer.PrintMessage = PrintMessage + message_classes = {cls.__name__ : cls for cls in message.Message.__subclasses__()} _FieldDescriptor = descriptor.FieldDescriptor From 9c25faddd2f552561eee7eb3457ff0214701b54b Mon Sep 17 00:00:00 2001 From: Kevin Pedro Date: Mon, 11 Dec 2023 14:39:57 -0600 Subject: [PATCH 04/12] document new tool --- HeterogeneousCore/SonicTriton/README.md | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/HeterogeneousCore/SonicTriton/README.md b/HeterogeneousCore/SonicTriton/README.md index 314b5d4d15986..b30f5c77ce8c7 100644 --- a/HeterogeneousCore/SonicTriton/README.md +++ b/HeterogeneousCore/SonicTriton/README.md @@ -124,14 +124,18 @@ In a SONIC Triton producer, the basic flow should follow this pattern: ## Services +### `cmsTriton` + A script [`cmsTriton`](./scripts/cmsTriton) is provided to launch and manage local servers. -The script has two operations (`start` and `stop`) and the following options: +The script has three operations (`start`, `stop`, `check`) and the following options: * `-c`: don't cleanup temporary dir (for debugging) +* `-C [dir]`: directory containing Nvidia compatibility drivers (checks CMSSW_BASE by default if available) * `-D`: dry run: print container commands rather than executing them * `-d`: use Docker instead of Apptainer * `-f`: force reuse of (possibly) existing container instance * `-g`: use GPU instead of CPU * `-i` [name]`: server image name (default: fastml/triton-torchgeo:22.07-py3-geometric) +* `-I [num]`: number of model instances (default: 0 -> means no local editing of config files) * `-M [dir]`: model repository (can be given more than once) * `-m [dir]`: specific model directory (can be given more than one) * `-n [name]`: name of container instance, also used for hidden temporary dir (default: triton_server_instance) @@ -148,6 +152,7 @@ Additional details and caveats: * The `start` and `stop` operations for a given container instance should always be executed in the same directory if a relative path is used for the hidden temporary directory (including the default from the container instance name), in order to ensure that everything is properly cleaned up. +* The `check` operation just checks if the server can run on the current system, based on driver compatibility. * A model repository is a folder that contains multiple model directories, while a model directory contains the files for a specific file. (In the example below, `$CMSSW_BASE/src/HeterogeneousCore/SonicTriton/data/models` is a model repository, while `$CMSSW_BASE/src/HeterogeneousCore/SonicTriton/data/models/resnet50_netdef` is a model directory.) @@ -155,6 +160,23 @@ If a model repository is provided, all of the models it contains will be provide * Older versions of Apptainer (Singularity) have a short timeout that may cause launching the server to fail the first time the command is executed. The `-r` (retry) flag exists to work around this issue. +### `cmsTritonConfigTool` + +The `config.pbtxt` files used for model configuration are written in the protobuf text format. +To ease modification of these files, a dedicated Python tool [`cmsTritonConfigTool`](./scripts/cmsTritonConfigTool) is provided. +The tool has several modes of operation (each with its own options, which can be viewed using `--help`): +* `schema`: displays all field names and types for the Triton ModelConfig message class. +* `view`: displays the field values from a provided `config.pbtxt` file. +* `edit`: allows changing any field value in a `config.pbtxt` file. Non-primitive types are specified using JSON format. +* `checksum`: checks and updates checksums for model files (to enforce versioning). +* `threadcontrol`: adds job- and ML framework-specific thread control settings. + +The `edit` mode is intended for generic modifications, and only supports overwriting existing values +(not modifying, removing, deleting, etc.). +Additional dedicated modes, like `checksum` and `threadcontrol`, can easily be added for more complicated tasks. + +### `TritonService` + A central `TritonService` is provided to keep track of all available servers and which models they can serve. The servers will automatically be assigned to clients at startup. If some models are not served by any server, the `TritonService` can launch a fallback server using the `cmsTriton` script described above. From b19ae521175f561bc2dd66267ecbd7c1adfcea6d Mon Sep 17 00:00:00 2001 From: Kevin Pedro Date: Mon, 11 Dec 2023 16:28:42 -0600 Subject: [PATCH 05/12] enforce policy against changing existing model files --- .../SonicTriton/scripts/cmsTritonConfigTool | 42 ++++++++++++++----- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool b/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool index 2e73407ebba78..9ff6a9faebdf4 100755 --- a/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool +++ b/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool @@ -277,23 +277,44 @@ def cfg_checksum(args): filekey = "MD5:{}".format(filename) if filekey in checksum_agent.parameters and checksum!=checksum_agent.parameters[filekey]: incorrect.append(filename) + if args.update and args.force: + checksum_agent.parameters[filekey] = checksum elif filekey not in checksum_agent.parameters: missing.append(filename) + if args.update: + checksum_agent.parameters[filekey] = checksum else: continue - if args.update: - checksum_agent.parameters[filekey] = checksum - if len(incorrect)>0 or len(missing)>0: - if not args.quiet: - if len(incorrect)>0: + needs_update = len(missing)>0 + needs_force_update = len(incorrect)>0 + + if not args.quiet: + if needs_update: + print("\n".join(["Missing checksums:"]+missing)) + if needs_force_update: print("\n".join(["Incorrect checksums:"]+incorrect)) - if len(missing)>0: - print("\n".join(["Missing checksums:"]+missing)) - if args.update: + + if needs_force_update: + if not (args.update and args.force): + extra_args = [arg for arg in ["--update","--force"] if arg not in sys.argv] + raise RuntimeError("\n".join([ + "Incorrect checksum(s) found, indicating existing model file(s) has been changed.", + "This violates policy. To override, run the following command, and provide a justification in your PR:", + "{} {}".format(" ".join(sys.argv), " ".join(extra_args)) + ])) + else: update_config(args) + elif needs_update: + if not args.update: + extra_args = [arg for arg in ["--update"] if arg not in sys.argv] + raise RuntimeError("\n".join([ + "Missing checksum(s) found, indicating new model file(s).", + "To update, run the following command:", + "{} {}".format(" ".join(sys.argv), " ".join(extra_args)) + ])) else: - sys.exit(1) + update_config(args) if args.view: cfg_view(args) @@ -373,7 +394,8 @@ if __name__=="__main__": parser_edit.set_defaults(func=cfg_edit) parser_checksum = subparsers.add_parser("checksum", parents=[_parser_common, _parser_copy, _parser_view_args], help="handle model file checksums") - parser_checksum.add_argument("--update", default=False, action="store_true", help="update checksums in config.pbtxt") + parser_checksum.add_argument("--update", default=False, action="store_true", help="update missing checksums in config.pbtxt") + parser_checksum.add_argument("--force", default=False, action="store_true", help="force update all checksums in config.pbtxt") parser_checksum.add_argument("--quiet", default=False, action="store_true", help="suppress printouts") parser_checksum.set_defaults(func=cfg_checksum) From a4f04da6259abd72fba4b7e1ab3043a2405dac3b Mon Sep 17 00:00:00 2001 From: Kevin Pedro Date: Wed, 10 Jan 2024 15:14:36 -0600 Subject: [PATCH 06/12] patch protobuf instead of monkey-patching --- .../SonicTriton/scripts/cmsTritonConfigTool | 29 ------------------- 1 file changed, 29 deletions(-) diff --git a/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool b/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool index 9ff6a9faebdf4..5ca689667aa7b 100755 --- a/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool +++ b/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool @@ -28,35 +28,6 @@ class DictAction(Action): result[values[i]] = self.val_type(values[i+1]) setattr(namespace, self.dest, result) -# patched version for more succinct printing of repeated messages -def PrintMessage(self, message): - if self.message_formatter and self._TryCustomFormatMessage(message): - return - if (message.DESCRIPTOR.full_name == text_format._ANY_FULL_TYPE_NAME and - self._TryPrintAsAnyMessage(message)): - return - fields = message.ListFields() - if self.use_index_order: - fields.sort(key=lambda x: x[0].number if x[0].is_extension else x[0].index) - for field, value in fields: - if text_format._IsMapEntry(field): - for key in sorted(value): - entry_submsg = value.GetEntryClass()(key=key, value=value[key]) - self.PrintField(field, entry_submsg) - elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: - if self.use_short_repeated_primitives: # no other conditions here - self._PrintShortRepeatedPrimitivesValue(field, value) - else: - for element in value: - self.PrintField(field, element) - else: - self.PrintField(field, value) - - if self.print_unknown_fields: - self._PrintUnknownFields(unknown_fields.UnknownFieldSet(message)) - -text_format._Printer.PrintMessage = PrintMessage - message_classes = {cls.__name__ : cls for cls in message.Message.__subclasses__()} _FieldDescriptor = descriptor.FieldDescriptor From d404d1d7349fcc5548dd254c02ebe8990c69384d Mon Sep 17 00:00:00 2001 From: Kevin Pedro Date: Wed, 10 Jan 2024 15:27:50 -0600 Subject: [PATCH 07/12] code format --- .../SonicTriton/src/TritonClient.cc | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/HeterogeneousCore/SonicTriton/src/TritonClient.cc b/HeterogeneousCore/SonicTriton/src/TritonClient.cc index 14d586b0b547f..7d87b8afa386f 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonClient.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonClient.cc @@ -88,11 +88,13 @@ TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& d const std::string& localModelConfigPath(params.getParameter("modelConfigPath").fullPath()); int fileDescriptor = open(localModelConfigPath.c_str(), O_RDONLY); if (fileDescriptor < 0) - throw TritonException("LocalFailure") << "TritonClient(): unable to open local model config: " << localModelConfigPath; + throw TritonException("LocalFailure") + << "TritonClient(): unable to open local model config: " << localModelConfigPath; google::protobuf::io::FileInputStream localModelConfigInput(fileDescriptor); localModelConfigInput.SetCloseOnDelete(true); if (!google::protobuf::TextFormat::Parse(&localModelConfigInput, &localModelConfig)) - throw TritonException("LocalFailure") << "TritonClient(): unable to parse local model config: " << localModelConfigPath; + throw TritonException("LocalFailure") + << "TritonClient(): unable to parse local model config: " << localModelConfigPath; } //check batch size limitations (after i/o setup) @@ -111,16 +113,16 @@ TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& d "TritonClient(): unable to get model config"); inference::ModelConfig remoteModelConfig(modelConfigResponse.config()); - std::map> checksums; + std::map> checksums; size_t fileCounter = 0; - for (const auto& modelConfig: {localModelConfig, remoteModelConfig}) { + for (const auto& modelConfig : {localModelConfig, remoteModelConfig}) { const auto& agents = modelConfig.model_repository_agents().agents(); for (const auto& agent : agents) { if (agent.name() == "checksum") { const auto& params = agent.parameters(); - for (const auto& [key, val]: params) { + for (const auto& [key, val] : params) { // only check the requested version - if (key.compare(0, options_[0].model_version_.size()+1, options_[0].model_version_+"/")==0) + if (key.compare(0, options_[0].model_version_.size() + 1, options_[0].model_version_ + "/") == 0) checksums[key][fileCounter] = val; } break; @@ -129,12 +131,13 @@ TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& d ++fileCounter; } std::vector incorrect; - for (const auto& [key, val]: checksums) { + for (const auto& [key, val] : checksums) { if (checksums[key][0] != checksums[key][1]) incorrect.push_back(key); } if (!incorrect.empty()) - throw TritonException("ModelVersioning") << "The following files have incorrect checksums on the remote server: " << triton_utils::printColl(incorrect, ", "); + throw TritonException("ModelVersioning") << "The following files have incorrect checksums on the remote server: " + << triton_utils::printColl(incorrect, ", "); //get model info inference::ModelMetadataResponse modelMetadata; From df717aa90846e4f3aefecaed8a26bcd3ca6b2a33 Mon Sep 17 00:00:00 2001 From: Kevin Pedro Date: Tue, 16 Jan 2024 15:47:56 -0600 Subject: [PATCH 08/12] skip symlinked dirs --- HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool | 1 + 1 file changed, 1 insertion(+) diff --git a/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool b/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool index 5ca689667aa7b..aee2181e15cf8 100755 --- a/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool +++ b/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool @@ -242,6 +242,7 @@ def cfg_checksum(args): from glob import glob config_dir = os.path.dirname(args.config) for filename in glob(os.path.join(config_dir,"*/*")): + if os.path.islink(os.path.dirname(filename)): continue checksum = get_checksum(filename) # key = algorithm:[filename relative to config.pbtxt dir] filename = os.path.relpath(filename, config_dir) From 4893c62e078de22608c1ec572a08579fa491bfd1 Mon Sep 17 00:00:00 2001 From: Kevin Pedro Date: Tue, 16 Jan 2024 19:10:59 -0600 Subject: [PATCH 09/12] add overall version check function and unit test --- HeterogeneousCore/SonicTriton/README.md | 1 + .../SonicTriton/scripts/cmsTritonConfigTool | 102 ++++++++++++++---- .../SonicTriton/test/BuildFile.xml | 1 + 3 files changed, 86 insertions(+), 18 deletions(-) diff --git a/HeterogeneousCore/SonicTriton/README.md b/HeterogeneousCore/SonicTriton/README.md index b30f5c77ce8c7..7eed0f67989fa 100644 --- a/HeterogeneousCore/SonicTriton/README.md +++ b/HeterogeneousCore/SonicTriton/README.md @@ -169,6 +169,7 @@ The tool has several modes of operation (each with its own options, which can be * `view`: displays the field values from a provided `config.pbtxt` file. * `edit`: allows changing any field value in a `config.pbtxt` file. Non-primitive types are specified using JSON format. * `checksum`: checks and updates checksums for model files (to enforce versioning). +* `versioncheck`: checks and updates checksums for all `config.pbtxt` files in `$CMSSW_SEARCH_PATH`. * `threadcontrol`: adds job- and ML framework-specific thread control settings. The `edit` mode is intended for generic modifications, and only supports overwriting existing values diff --git a/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool b/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool index aee2181e15cf8..9a1ef54c57da6 100755 --- a/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool +++ b/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool @@ -2,7 +2,8 @@ import os, sys, json, pathlib, shutil from collections import OrderedDict -from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, RawTextHelpFormatter, RawDescriptionHelpFormatter, Action +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, RawTextHelpFormatter, RawDescriptionHelpFormatter, Action, Namespace +from enum import Enum from google.protobuf import text_format, json_format, message, descriptor from google.protobuf.internal import type_checkers from tritonclient import grpc @@ -28,6 +29,11 @@ class DictAction(Action): result[values[i]] = self.val_type(values[i+1]) setattr(namespace, self.dest, result) +class TritonChecksumStatus(Enum): + CORRECT = 0 + MISSING = 1 + INCORRECT = 2 + message_classes = {cls.__name__ : cls for cls in message.Message.__subclasses__()} _FieldDescriptor = descriptor.FieldDescriptor @@ -88,6 +94,9 @@ def get_fields(obj, name, level=0, verbose=False): obj_info["fields"].append(get_fields(getattr(field_obj,field),field,level+1,verbose)) return obj_info +def get_model_info(): + return get_fields(grpc.model_config_pb2.ModelConfig(), "ModelConfig") + def msg_json(val, defaults=False): return json_format.MessageToJson(val, preserving_proto_field_name=True, including_default_value_fields=defaults, indent=0).replace(",\n",", ").replace("\n","") @@ -198,6 +207,12 @@ def get_checksum(filename, chunksize=4096): file_hash.update(chunk) return file_hash.hexdigest() +def get_checksum_update_cmd(force=False): + extra_args = ["--update"] + if force: extra_args.append("--force") + extra_args = [arg for arg in extra_args if arg not in sys.argv] + return "{} {}".format(" ".join(sys.argv), " ".join(extra_args)) + def update_config(args): # update config path to be output path (in case view is called) if args.copy: @@ -209,6 +224,8 @@ def update_config(args): text_format.PrintMessage(args.model, outfile, use_short_repeated_primitives=True) def cfg_common(args): + if not hasattr(args,'model_info'): + args.model_info = get_model_info() args.model = grpc.model_config_pb2.ModelConfig() if hasattr(args,'config'): with open(args.config,'r') as infile: @@ -231,6 +248,10 @@ def cfg_edit(args): cfg_view(args) def cfg_checksum(args): + # internal parameter + if not hasattr(args, "should_return"): + args.should_return = False + agents = args.model.model_repository_agents.agents checksum_agent = next((agent for agent in agents if agent.name=="checksum"), None) if checksum_agent is None: @@ -265,32 +286,72 @@ def cfg_checksum(args): if needs_update: print("\n".join(["Missing checksums:"]+missing)) if needs_force_update: - print("\n".join(["Incorrect checksums:"]+incorrect)) + print("\n".join(["Incorrect checksums:"]+incorrect)) if needs_force_update: if not (args.update and args.force): - extra_args = [arg for arg in ["--update","--force"] if arg not in sys.argv] - raise RuntimeError("\n".join([ - "Incorrect checksum(s) found, indicating existing model file(s) has been changed.", - "This violates policy. To override, run the following command, and provide a justification in your PR:", - "{} {}".format(" ".join(sys.argv), " ".join(extra_args)) - ])) + if args.should_return: + return TritonChecksumStatus.INCORRECT + else: + raise RuntimeError("\n".join([ + "Incorrect checksum(s) found, indicating existing model file(s) has been changed, which violates policy.", + "To override, run the following command (and provide a justification in your PR):", + get_checksum_update_cmd(force=True) + ])) else: update_config(args) elif needs_update: if not args.update: - extra_args = [arg for arg in ["--update"] if arg not in sys.argv] - raise RuntimeError("\n".join([ - "Missing checksum(s) found, indicating new model file(s).", - "To update, run the following command:", - "{} {}".format(" ".join(sys.argv), " ".join(extra_args)) - ])) + if args.should_return: + return TritonChecksumStatus.MISSING + else: + raise RuntimeError("\n".join([ + "Missing checksum(s) found, indicating new model file(s).", + "To update, run the following command:", + get_checksum_update_cmd(force=False) + ])) else: update_config(args) if args.view: cfg_view(args) + if args.should_return: + return TritonChecksumStatus.CORRECT + +def cfg_versioncheck(args): + incorrect = [] + missing = [] + + for path in os.environ['CMSSW_SEARCH_PATH'].split(':'): + for dirpath, dirnames, filenames in os.walk(path): + for filename in filenames: + if filename=="config.pbtxt": + filepath = os.path.join(dirpath,filename) + checksum_args = Namespace( + config=filepath, should_return=True, + copy=False, json=False, defaults=False, view=False, + update=args.update, force=args.force, quiet=True + ) + cfg_common(checksum_args) + status = cfg_checksum(checksum_args) + if status==TritonChecksumStatus.INCORRECT: + incorrect.append(filepath) + elif status==TritonChecksumStatus.MISSING: + missing.append(filepath) + + msg = [] + instr = [] + if len(missing)>0: + msg.extend(["","The following files have missing checksum(s), indicating new model file(s):"]+missing) + instr.extend(["","To update missing checksums, run the following command:",get_checksum_update_cmd(force=False)]) + if len(incorrect)>0: + msg.extend(["","The following files have incorrect checksum(s), indicating existing model file(s) have been changed, which violates policy:"]+incorrect) + instr.extend(["","To override incorrect checksums, run the following command (and provide a justification in your PR):",get_checksum_update_cmd(force=True)]) + + if len(msg)>0: + raise RuntimeError("\n".join(msg+instr)) + def cfg_threadcontrol(args): # copy the entire model, not just config.pbtxt config_dir = os.path.dirname(args.config) @@ -331,7 +392,7 @@ def cfg_threadcontrol(args): if __name__=="__main__": # initial common operations - model_info = get_fields(grpc.model_config_pb2.ModelConfig(), "ModelConfig") + model_info = get_model_info() edit_dests = None _parser_common = ArgumentParser(add_help=False) @@ -365,12 +426,17 @@ if __name__=="__main__": parser_edit, edit_dests = add_edit_args(parser_edit, model_info) parser_edit.set_defaults(func=cfg_edit) - parser_checksum = subparsers.add_parser("checksum", parents=[_parser_common, _parser_copy, _parser_view_args], help="handle model file checksums") - parser_checksum.add_argument("--update", default=False, action="store_true", help="update missing checksums in config.pbtxt") - parser_checksum.add_argument("--force", default=False, action="store_true", help="force update all checksums in config.pbtxt") + _parser_checksum_update = ArgumentParser(add_help=False) + _parser_checksum_update.add_argument("--update", default=False, action="store_true", help="update missing checksums") + _parser_checksum_update.add_argument("--force", default=False, action="store_true", help="force update all checksums") + + parser_checksum = subparsers.add_parser("checksum", parents=[_parser_common, _parser_copy, _parser_view_args, _parser_checksum_update], help="handle model file checksums") parser_checksum.add_argument("--quiet", default=False, action="store_true", help="suppress printouts") parser_checksum.set_defaults(func=cfg_checksum) + parser_versioncheck = subparsers.add_parser("versioncheck", parents=[_parser_checksum_update], help="check all model checksums") + parser_versioncheck.set_defaults(func=cfg_versioncheck) + _parser_copy_req = ArgumentParser(add_help=False, parents=[_parser_copy_view]) _parser_copy_req.add_argument("--copy", metavar="dir", type=str, required=True, help="local model repository directory to copy model(s)" diff --git a/HeterogeneousCore/SonicTriton/test/BuildFile.xml b/HeterogeneousCore/SonicTriton/test/BuildFile.xml index 272fba3da2cc8..daac8cb67d83b 100644 --- a/HeterogeneousCore/SonicTriton/test/BuildFile.xml +++ b/HeterogeneousCore/SonicTriton/test/BuildFile.xml @@ -1,5 +1,6 @@ + From 26d99a300f940749edc6c9197a80d2ad9d00491c Mon Sep 17 00:00:00 2001 From: Kevin Pedro Date: Wed, 17 Jan 2024 09:28:16 -0600 Subject: [PATCH 10/12] add test dependency on cmsswdata --- HeterogeneousCore/SonicTriton/test/BuildFile.xml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/HeterogeneousCore/SonicTriton/test/BuildFile.xml b/HeterogeneousCore/SonicTriton/test/BuildFile.xml index daac8cb67d83b..843b0ffecca71 100644 --- a/HeterogeneousCore/SonicTriton/test/BuildFile.xml +++ b/HeterogeneousCore/SonicTriton/test/BuildFile.xml @@ -1,6 +1,8 @@ - + + + From a2a88f1d53e1a44e5fd741ce08b6f3a581ac0870 Mon Sep 17 00:00:00 2001 From: Kevin Pedro Date: Wed, 17 Jan 2024 09:50:25 -0600 Subject: [PATCH 11/12] remove unnecessary paths --- HeterogeneousCore/SonicTriton/test/BuildFile.xml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/HeterogeneousCore/SonicTriton/test/BuildFile.xml b/HeterogeneousCore/SonicTriton/test/BuildFile.xml index 843b0ffecca71..e4ff7a0bb56f3 100644 --- a/HeterogeneousCore/SonicTriton/test/BuildFile.xml +++ b/HeterogeneousCore/SonicTriton/test/BuildFile.xml @@ -1,6 +1,6 @@ - - - + + + From 5edfefea2e32aacc81f686e6c00338d283aeb764 Mon Sep 17 00:00:00 2001 From: Kevin Pedro Date: Fri, 26 Jan 2024 14:46:03 -0600 Subject: [PATCH 12/12] improve agent finding pattern --- .../SonicTriton/src/TritonClient.cc | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/HeterogeneousCore/SonicTriton/src/TritonClient.cc b/HeterogeneousCore/SonicTriton/src/TritonClient.cc index 7d87b8afa386f..201ad40d35a0e 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonClient.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonClient.cc @@ -117,15 +117,13 @@ TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& d size_t fileCounter = 0; for (const auto& modelConfig : {localModelConfig, remoteModelConfig}) { const auto& agents = modelConfig.model_repository_agents().agents(); - for (const auto& agent : agents) { - if (agent.name() == "checksum") { - const auto& params = agent.parameters(); - for (const auto& [key, val] : params) { - // only check the requested version - if (key.compare(0, options_[0].model_version_.size() + 1, options_[0].model_version_ + "/") == 0) - checksums[key][fileCounter] = val; - } - break; + auto agent = std::find_if(agents.begin(), agents.end(), [](auto const& a) { return a.name() == "checksum"; }); + if (agent != agents.end()) { + const auto& params = agent->parameters(); + for (const auto& [key, val] : params) { + // only check the requested version + if (key.compare(0, options_[0].model_version_.size() + 1, options_[0].model_version_ + "/") == 0) + checksums[key][fileCounter] = val; } } ++fileCounter;