diff --git a/RecoTracker/LSTCore/interface/DenseLayer.h b/RecoTracker/LSTCore/interface/DenseLayer.h new file mode 100644 index 0000000000000..e9431750e74e3 --- /dev/null +++ b/RecoTracker/LSTCore/interface/DenseLayer.h @@ -0,0 +1,35 @@ +#ifndef RecoTracker_LSTCore_interface_DenseLayer_h +#define RecoTracker_LSTCore_interface_DenseLayer_h + +#include +#include +#include + +/** + * Represents a dense (fully connected) layer with fixed input and output sizes. + * + * IN: Number of input neurons + * OUT: Number of output neurons + */ +template +struct DenseLayer { + /** + * Biases: one float per output neuron. + */ + std::array biases{}; + + /** + * Weights: stored as IN rows of OUT columns. + */ + std::array, IN> weights{}; + + /** + * Returns the weight from input neuron index `in` to output neuron index `out`. + */ + float getWeight(std::size_t in, std::size_t out) const { return weights[in][out]; } + + static constexpr std::size_t inputSize = IN; + static constexpr std::size_t outputSize = OUT; +}; + +#endif \ No newline at end of file diff --git a/RecoTracker/LSTCore/interface/Dnn.h b/RecoTracker/LSTCore/interface/Dnn.h new file mode 100644 index 0000000000000..190ef1f243325 --- /dev/null +++ b/RecoTracker/LSTCore/interface/Dnn.h @@ -0,0 +1,140 @@ +#ifndef RecoTracker_LSTCore_interface_Dnn_h +#define RecoTracker_LSTCore_interface_Dnn_h + +#include +#include +#include +#include +#include +#include + +/** + * A general Dnn class that holds a sequence (tuple) of DenseLayer types, + * each with compile-time fixed dimensions. + * + * Layers: A parameter pack of layer types (e.g. DenseLayer<23,32>, DenseLayer<32,1>, etc.) + */ +template +class Dnn { +public: + Dnn() = default; + explicit Dnn(const std::string& filename) { load(filename); } + + /** + * Loads biases and weights for each layer in the tuple from a binary file. + */ + void load(const std::string& filename) { + std::ifstream file(filename, std::ios::binary); + if (!file) { + throw std::runtime_error("Failed to open file: " + filename); + } + + loadLayers<0>(file); + + if (!file.good()) { + throw std::runtime_error("Error reading from file: " + filename); + } + file.close(); + } + + /** + * Prints the biases and weights of each layer to stdout. + */ + void print() const { printLayers<0>(); } + + /** + * A const reference to the underlying tuple of layers. + */ + const std::tuple& getLayers() const { return layers_; } + + /** + * A reference to the underlying tuple of layers. + */ + std::tuple& getLayers() { return layers_; } + +private: + // Store all layers in a compile-time tuple + std::tuple layers_; + + /** + * Internal compile-time recursion for loading each layer from file + */ + template + typename std::enable_if::type loadLayers(std::ifstream&) { + // Base case: no more layers to load + } + + template + typename std::enable_if < I::type loadLayers(std::ifstream& file) { + auto& layer = std::get(layers_); + + // Read and verify header information + uint32_t layer_id, num_inputs, num_outputs; + file.read(reinterpret_cast(&layer_id), sizeof(layer_id)); + file.read(reinterpret_cast(&num_inputs), sizeof(num_inputs)); + file.read(reinterpret_cast(&num_outputs), sizeof(num_outputs)); + + // Verify the dimensions match our template parameters + if (num_inputs != layer.inputSize || num_outputs != layer.outputSize) { + throw std::runtime_error("Layer " + std::to_string(I) + + " dimension mismatch: " + "expected " + + std::to_string(layer.inputSize) + "x" + std::to_string(layer.outputSize) + ", got " + + std::to_string(num_inputs) + "x" + std::to_string(num_outputs)); + } + + // Verify layer index matches + if (layer_id != I + 1) { // Assumes 1-based layer IDs + throw std::runtime_error("Layer index mismatch: expected " + std::to_string(I + 1) + ", got " + + std::to_string(layer_id)); + } + + // Read biases + file.read(reinterpret_cast(layer.biases.data()), layer.biases.size() * sizeof(float)); + + // Read weights row by row + for (auto& row : layer.weights) { + file.read(reinterpret_cast(row.data()), row.size() * sizeof(float)); + } + + if (!file.good()) { + throw std::runtime_error("Failed to read parameters for layer " + std::to_string(I)); + } + + // Recurse to next layer + loadLayers(file); + } + + /** + * Internal compile-time recursion for printing each layer + */ + template + typename std::enable_if::type printLayers() const { + // Base case: no more layers to print + } + + template + typename std::enable_if < I::type printLayers() const { + const auto& layer = std::get(layers_); + std::cout << "\n=== Layer " << I + 1 << " ===\nInputs=" << layer.inputSize << ", Outputs=" << layer.outputSize + << "\n\nBiases:\n"; + + for (float b : layer.biases) { + std::cout << b << " "; + } + std::cout << "\n\nWeights:\n"; + + for (std::size_t in = 0; in < layer.inputSize; ++in) { + std::cout << " [ "; + for (std::size_t out = 0; out < layer.outputSize; ++out) { + std::cout << layer.getWeight(in, out) << " "; + } + std::cout << "]\n"; + } + + // Recurse to next layer + printLayers(); + } +}; + +#endif \ No newline at end of file diff --git a/RecoTracker/LSTCore/interface/DnnWeightsDevSoA.h b/RecoTracker/LSTCore/interface/DnnWeightsDevSoA.h new file mode 100644 index 0000000000000..43a4c1dd180ef --- /dev/null +++ b/RecoTracker/LSTCore/interface/DnnWeightsDevSoA.h @@ -0,0 +1,19 @@ +#ifndef RecoTracker_LSTCore_interface_DnnWeightsDevSoA_h +#define RecoTracker_LSTCore_interface_DnnWeightsDevSoA_h + +#include "RecoTracker/LSTCore/interface/DenseLayer.h" + +namespace lst { + + /** + * Data structure holding multiple dense layers for the DNN weights. + */ + struct DnnWeightsDevData { + DenseLayer<23, 32> layer1; + DenseLayer<32, 32> layer2; + DenseLayer<32, 1> layer3; + }; + +} // namespace lst + +#endif // RecoTracker_LSTCore_interface_DnnWeightsDevSoA_h \ No newline at end of file diff --git a/RecoTracker/LSTCore/interface/LSTESData.h b/RecoTracker/LSTCore/interface/LSTESData.h index bfa10186f8f2e..d0836ef5dc42b 100644 --- a/RecoTracker/LSTCore/interface/LSTESData.h +++ b/RecoTracker/LSTCore/interface/LSTESData.h @@ -6,6 +6,8 @@ #include "RecoTracker/LSTCore/interface/ModulesHostCollection.h" #include "RecoTracker/LSTCore/interface/PixelMap.h" +#include "RecoTracker/LSTCore/interface/DnnWeightsDevSoA.h" +#include "DataFormats/Portable/interface/PortableObject.h" #include "HeterogeneousCore/AlpakaInterface/interface/CopyToDevice.h" #include @@ -23,21 +25,25 @@ namespace lst { std::shared_ptr> endcapGeometry; // Host-side object that is shared between the LSTESData objects for different devices std::shared_ptr pixelMapping; - + // ==== New DNN weights pointer ==== + std::shared_ptr> dnnWeights; LSTESData(uint16_t const& nModulesIn, uint16_t const& nLowerModulesIn, unsigned int const& nPixelsIn, unsigned int const& nEndCapMapIn, std::shared_ptr> modulesIn, std::shared_ptr> endcapGeometryIn, - std::shared_ptr const& pixelMappingIn) + std::shared_ptr const& pixelMappingIn, + // New constructor argument for DNN + std::shared_ptr> dnnWeightsIn) : nModules(nModulesIn), nLowerModules(nLowerModulesIn), nPixels(nPixelsIn), nEndCapMap(nEndCapMapIn), modules(std::move(modulesIn)), endcapGeometry(std::move(endcapGeometryIn)), - pixelMapping(pixelMappingIn) {} + pixelMapping(pixelMappingIn), + dnnWeights(std::move(dnnWeightsIn)) {} }; std::unique_ptr> loadAndFillESHost(std::string& ptCutLabel); @@ -54,16 +60,22 @@ namespace cms::alpakatools { using TDev = alpaka::Dev; std::shared_ptr> deviceModules; std::shared_ptr> deviceEndcapGeometry; + // === New pointer for the copied DNN weights === + std::shared_ptr> deviceDnnWeights; if constexpr (std::is_same_v) { deviceModules = srcData.modules; deviceEndcapGeometry = srcData.endcapGeometry; + deviceDnnWeights = srcData.dnnWeights; } else { deviceModules = std::make_shared>( CopyToDevice>::copyAsync( queue, *srcData.modules)); deviceEndcapGeometry = std::make_shared>( CopyToDevice>::copyAsync(queue, *srcData.endcapGeometry)); + // Copy the DNN weights to device + deviceDnnWeights = std::make_shared>( + CopyToDevice>::copyAsync(queue, *srcData.dnnWeights)); } return lst::LSTESData>(srcData.nModules, @@ -72,7 +84,8 @@ namespace cms::alpakatools { srcData.nEndCapMap, std::move(deviceModules), std::move(deviceEndcapGeometry), - srcData.pixelMapping); + srcData.pixelMapping, + std::move(deviceDnnWeights)); } }; } // namespace cms::alpakatools diff --git a/RecoTracker/LSTCore/src/LSTESData.cc b/RecoTracker/LSTCore/src/LSTESData.cc index dad8522bbe2cd..e132a7bca46b3 100644 --- a/RecoTracker/LSTCore/src/LSTESData.cc +++ b/RecoTracker/LSTCore/src/LSTESData.cc @@ -3,7 +3,9 @@ #include "RecoTracker/LSTCore/interface/ModuleConnectionMap.h" #include "RecoTracker/LSTCore/interface/TiltedGeometry.h" #include "RecoTracker/LSTCore/interface/PixelMap.h" - +#include "RecoTracker/LSTCore/interface/Dnn.h" +#include "RecoTracker/LSTCore/interface/DenseLayer.h" +#include "RecoTracker/LSTCore/interface/DnnWeightsDevSoA.h" #include "ModuleMethods.h" #include @@ -111,11 +113,30 @@ std::unique_ptr> lst::loadAndFillESHost(s tiltedGeometry, moduleConnectionMap); auto pixelMappingPtr = std::make_shared(std::move(pixelMapping)); + + // === Load from the DNN instance === + auto model = + Dnn, DenseLayer<32, 32>, DenseLayer<32, 1>>("../standalone/analysis/DNN/network_weights.bin"); + + // Copy the loaded model into a host DnnWeightsDevData struct + lst::DnnWeightsDevData hostDnn; + { + auto const& layers = model.getLayers(); + hostDnn.layer1 = std::get<0>(layers); + hostDnn.layer2 = std::get<1>(layers); + hostDnn.layer3 = std::get<2>(layers); + } + + // Wrap it in a PortableHostObject so it can be copied to device + auto hostDnnWeights = std::make_shared>(cms::alpakatools::host()); + hostDnnWeights->value() = hostDnn; + return std::make_unique>(nModules, nLowerModules, nPixels, endcapGeometry.nEndCapMap, std::move(modulesBuffers), std::move(endcapGeometryDev), - pixelMappingPtr); + pixelMappingPtr, + hostDnnWeights); } diff --git a/RecoTracker/LSTCore/src/alpaka/LSTEvent.dev.cc b/RecoTracker/LSTCore/src/alpaka/LSTEvent.dev.cc index d77e242a1a88c..d1317e5008e6d 100644 --- a/RecoTracker/LSTCore/src/alpaka/LSTEvent.dev.cc +++ b/RecoTracker/LSTCore/src/alpaka/LSTEvent.dev.cc @@ -914,6 +914,7 @@ void LSTEvent::createQuintuplets() { quintupletsDC_->view(), rangesDC_->const_view(), nEligibleT5Modules, + dnnWeights_.data(), ptCut_); Vec3D const threadsPerBlockDupQuint{1, 16, 16}; diff --git a/RecoTracker/LSTCore/src/alpaka/LSTEvent.h b/RecoTracker/LSTCore/src/alpaka/LSTEvent.h index 02f1decef916b..e3c97c3ee4d89 100644 --- a/RecoTracker/LSTCore/src/alpaka/LSTEvent.h +++ b/RecoTracker/LSTCore/src/alpaka/LSTEvent.h @@ -25,6 +25,8 @@ #include "RecoTracker/LSTCore/interface/alpaka/ModulesDeviceCollection.h" #include "RecoTracker/LSTCore/interface/alpaka/ObjectRangesDeviceCollection.h" #include "RecoTracker/LSTCore/interface/alpaka/EndcapGeometryDevDeviceCollection.h" +#include "RecoTracker/LSTCore/interface/DnnWeightsDevSoA.h" +#include "DataFormats/Portable/interface/PortableObject.h" #include "Hit.h" #include "Kernels.h" @@ -78,6 +80,7 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { ModulesDeviceCollection const& modules_; PixelMap const& pixelMapping_; EndcapGeometryDevDeviceCollection const& endcapGeometry_; + PortableObject const& dnnWeights_; bool addObjects_; public: @@ -92,6 +95,7 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { modules_(*deviceESData->modules), pixelMapping_(*deviceESData->pixelMapping), endcapGeometry_(*deviceESData->endcapGeometry), + dnnWeights_(*deviceESData->dnnWeights), addObjects_(verbose) { if (pt_cut < 0.6f) { throw std::invalid_argument("Minimum pT cut must be at least 0.6 GeV. Provided value: " + diff --git a/RecoTracker/LSTCore/src/alpaka/NeuralNetwork.h b/RecoTracker/LSTCore/src/alpaka/NeuralNetwork.h index cc1bffa3d928b..9610f26654ed1 100644 --- a/RecoTracker/LSTCore/src/alpaka/NeuralNetwork.h +++ b/RecoTracker/LSTCore/src/alpaka/NeuralNetwork.h @@ -6,7 +6,7 @@ #include "RecoTracker/LSTCore/interface/alpaka/Common.h" #include "RecoTracker/LSTCore/interface/MiniDoubletsSoA.h" -#include "NeuralNetworkWeights.h" +#include "RecoTracker/LSTCore/interface/DnnWeightsDevSoA.h" namespace ALPAKA_ACCELERATOR_NAMESPACE::lst::t5dnn { @@ -24,10 +24,11 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst::t5dnn { } template - ALPAKA_FN_ACC ALPAKA_FN_INLINE void linear_layer(const float (&input)[IN_FEATURES], - float (&output)[OUT_FEATURES], - const float (&weights)[IN_FEATURES][OUT_FEATURES], - const float (&biases)[OUT_FEATURES]) { + ALPAKA_FN_ACC ALPAKA_FN_INLINE void linear_layer( + const float (&input)[IN_FEATURES], + float (&output)[OUT_FEATURES], + const std::array, IN_FEATURES>& weights, + const std::array& biases) { CMS_UNROLL_LOOP for (unsigned int i = 0; i < OUT_FEATURES; ++i) { output[i] = biases[i]; @@ -52,6 +53,7 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst::t5dnn { template ALPAKA_FN_ACC ALPAKA_FN_INLINE bool runInference(TAcc const& acc, + lst::DnnWeightsDevData const* dnnPtr, MiniDoubletsConst mds, const unsigned int mdIndex1, const unsigned int mdIndex2, @@ -126,15 +128,15 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst::t5dnn { float x_3[1]; // Layer 3 linear output // Layer 1: Linear + Relu - linear_layer(x, x_1, wgtT_layer1, bias_layer1); + linear_layer(x, x_1, dnnPtr->layer1.weights, dnnPtr->layer1.biases); relu_activation(x_1); // Layer 2: Linear + Relu - linear_layer(x_1, x_2, wgtT_layer2, bias_layer2); + linear_layer(x_1, x_2, dnnPtr->layer2.weights, dnnPtr->layer2.biases); relu_activation(x_2); // Layer 3: Linear + Sigmoid - linear_layer(x_2, x_3, wgtT_output_layer, bias_output_layer); + linear_layer(x_2, x_3, dnnPtr->layer3.weights, dnnPtr->layer3.biases); float x_5 = sigmoid_activation(acc, x_3[0]); // Get the bin index based on abs(eta) of first hit and t5_pt diff --git a/RecoTracker/LSTCore/src/alpaka/NeuralNetworkWeights.h b/RecoTracker/LSTCore/src/alpaka/NeuralNetworkWeights.h deleted file mode 100644 index 42f7b19f33898..0000000000000 --- a/RecoTracker/LSTCore/src/alpaka/NeuralNetworkWeights.h +++ /dev/null @@ -1,257 +0,0 @@ -#ifndef RecoTracker_LSTCore_src_alpaka_NeuralNetworkWeights_h -#define RecoTracker_LSTCore_src_alpaka_NeuralNetworkWeights_h - -#include - -namespace ALPAKA_ACCELERATOR_NAMESPACE::lst::t5dnn { - ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_layer1[32] = { - -1.3837075f, -0.0653152f, -0.7900129f, 0.0714758f, -1.1574365f, -1.4634879f, -0.9317133f, -0.1455518f, - -0.0459635f, -0.2055620f, 0.0586231f, -0.8943899f, -0.1009487f, 0.0166031f, -0.5451909f, -0.1384538f, - 1.2664700f, -1.8996916f, -0.0025585f, -0.1647783f, -1.9019107f, 0.0707104f, -0.2373025f, 0.0357050f, - -0.0048417f, 2.3127339f, -0.0508943f, -0.1116435f, -0.1610904f, -1.6463890f, -1.0739423f, -0.0962902f}; - - ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_layer1[23][32] = { - {-0.1881404f, -0.0534256f, 1.6563641f, 0.0401664f, 2.8318353f, 1.5049738f, 1.4111555f, -0.2339872f, - 0.0431970f, 0.1220361f, -0.0450153f, -1.6025578f, 0.0394025f, -0.3051167f, 1.9442217f, 0.1599094f, - 0.1376955f, 2.4181051f, -0.0226484f, -0.1801709f, -0.4861264f, -0.0268545f, 0.5463807f, 0.2420150f, - -0.1238829f, 0.2916382f, 0.1507791f, 0.7952659f, 0.2736979f, 3.2790639f, 1.2062043f, -0.0884467f}, - {-0.0469924f, 0.2013927f, 0.0307775f, -0.1241788f, -0.0100412f, 0.0422375f, 0.0211071f, -0.0359304f, - 0.0451861f, 0.0291862f, -0.2094866f, -0.0013007f, 0.1191471f, 0.0750159f, 0.0184378f, 0.0419437f, - -0.0207304f, -0.0444109f, 0.0013400f, -0.0699210f, -0.0668742f, -0.0880825f, -0.0107244f, 0.0363424f, - 0.1391699f, -0.0112885f, -0.0060098f, -0.0073863f, -0.0566143f, -0.0224207f, 0.0103718f, -0.0015193f}, - {0.4520382f, 0.1227609f, -1.3887709f, -0.0542129f, -3.2003114f, -0.8354173f, -1.3173198f, 0.3292131f, - -0.1657729f, -0.1982902f, 0.1599589f, -0.0417666f, -0.1461042f, -1.3237997f, -5.3609071f, -0.0981676f, - 0.2922535f, -1.8692241f, -0.0345302f, 0.1810613f, 0.4473544f, -0.0159401f, -0.7293931f, -1.4816793f, - -0.1431545f, -0.0955672f, -0.2370718f, -0.7204540f, 0.8451244f, -3.4310548f, -1.3518151f, 0.1551731f}, - {0.2670300f, 0.1343590f, 3.0347505f, -0.1783503f, 2.1586559f, 2.4137778f, 2.0080864f, -0.2545274f, - -0.1985905f, 0.1653812f, -0.1714860f, 4.1022782f, -0.1045471f, 4.4776497f, 3.3737848f, -0.0849546f, - -6.1899095f, 3.6970129f, 0.0007382f, 0.1675882f, 0.6014717f, -0.0287709f, 0.0495882f, 2.2192705f, - -0.1043157f, -4.7508621f, -0.0022774f, 0.3766513f, -0.7505829f, 1.9759512f, 1.6747239f, -0.1004091f}, - {0.6639504f, -0.0384022f, -10.0415087f, -0.0032648f, 0.3049855f, -2.0427964f, -1.1522077f, 0.0935732f, - 0.1232134f, 0.0868663f, -0.0230848f, -1.8257296f, -0.0799238f, 6.8892417f, -1.3941933f, 0.0445172f, - 0.9485117f, -2.5238073f, -0.0148513f, 0.2256772f, 0.5914315f, -0.1278037f, 0.1609928f, 11.3438406f, - -0.0831544f, 0.1928522f, 0.0361467f, 0.0137040f, 4.9549832f, 2.3954937f, 0.3917757f, 0.1206975f}, - {29.6590214f, -0.0836848f, -1.3028307f, -0.1391431f, -0.3703596f, 5.3762760f, 1.8429571f, 21.0697041f, - -0.1232606f, 0.0066067f, -0.0308768f, -0.9960231f, 0.1865301f, -1.2142091f, 0.9273136f, 0.0974103f, - 1.4067870f, 0.7268439f, 0.0035755f, 0.0619486f, -32.8901024f, -0.1950644f, -0.3978897f, -3.1790049f, - -0.1371673f, 0.1569460f, 0.0268667f, -0.4512640f, 0.3055371f, -0.2241473f, -0.6455348f, 0.1178979f}, - {-2.9178317f, -0.2023720f, -0.2946439f, -0.1851392f, -0.3493766f, -1.5397958f, -1.5902523f, 1.0981250f, - -0.1796725f, -0.0540953f, 0.0926500f, 2.0021629f, -0.1277778f, 3.3643394f, -7.5327554f, -0.0084912f, - 2.7298651f, 0.2535582f, 0.0474618f, -0.1377846f, -2.2746830f, -0.2016302f, -0.7150622f, 4.4011140f, - -0.1688751f, -1.2160714f, -0.0055839f, -1.1319760f, -2.2543004f, 0.6365916f, -1.4942099f, -0.0992425f}, - {-5.9751196f, -0.1597221f, -3.8946304f, 0.0537821f, 0.4741110f, 3.6895070f, 2.5116272f, 1.7058172f, - -0.0860321f, -0.1519644f, 0.1465356f, 1.4165760f, -0.0984433f, 1.6990343f, 4.0953226f, 0.1742475f, - -3.2570388f, 3.1653547f, 0.0135764f, 0.0092055f, -5.0966530f, -0.0542810f, 0.4907863f, 0.5900084f, - -0.1736992f, -4.9153452f, 0.2017547f, 0.2854181f, 3.1490057f, 0.2885774f, 0.9775900f, -0.2207156f}, - {0.3805595f, 0.0308984f, -9.5846119f, -0.0547350f, 1.9641919f, 2.0823991f, 9.9298115f, 0.0344243f, - -0.1557834f, -0.1847700f, -0.1195207f, 4.4698248f, 0.1492174f, 0.4272707f, 4.7265644f, 0.0200772f, - -14.3444443f, 4.9532328f, 0.0319610f, -0.0645846f, -0.6238102f, 0.1038110f, 0.2483765f, -5.1799927f, - 0.0782294f, 16.8777409f, 0.0196593f, 0.8423936f, -8.5921221f, -0.0184179f, -5.7857180f, -0.0551181f}, - {17.1570740f, 0.0265437f, -1.4766232f, -0.0528512f, 1.0128449f, 3.1529653f, -0.6560294f, 8.7189465f, - -0.1728377f, 0.1245629f, 0.1072764f, 0.2649773f, 0.0254132f, -0.8094708f, 1.8371828f, 0.1586192f, - 1.9410020f, 0.9662392f, -0.0839922f, -0.2894930f, -16.5091496f, -0.1079556f, -0.1204132f, -0.9694697f, - 0.0537786f, 0.2476868f, 0.0076408f, 0.1025890f, 0.1267423f, 0.4956081f, 0.1457323f, 0.1342634f}, - {-0.5389574f, 0.1333421f, -4.6338782f, -0.0645123f, -0.6526322f, -3.2958410f, -1.2309581f, -1.0803053f, - -0.1170542f, -0.0169311f, 0.1147491f, 2.9890807f, -0.1234096f, 0.6792320f, -3.9311285f, -0.0678321f, - -2.7922039f, 4.9413238f, 0.1060735f, -0.1114068f, -2.2443752f, -0.1649915f, -0.3656403f, 2.5320942f, - -0.0249616f, -4.5098810f, -0.1773834f, -1.9516623f, -1.6839710f, -0.1365123f, 1.0296160f, -0.0419825f}, - {-2.4413636f, 0.1075683f, -1.4518708f, 0.0537449f, 0.1154493f, -0.5463845f, 1.3964951f, 2.6729572f, - -0.0206257f, 0.1435281f, -0.1819518f, 0.4540120f, -0.1910136f, 1.7696143f, 2.3670278f, 0.1324464f, - -0.5837788f, -2.2784615f, 0.0345478f, -0.0980538f, -0.4999657f, 0.1178097f, 0.5756868f, -0.1058674f, - 0.1920418f, -3.5473657f, 0.2146371f, 0.2557987f, 1.3935618f, 0.3242345f, 0.2029733f, -0.1844350f}, - {-0.9069599f, -0.2032758f, -0.5786582f, 0.1395915f, 3.9338124f, -1.6806563f, 0.4269728f, -0.3697720f, - -0.0306356f, -0.0341866f, -0.0635755f, 1.8898975f, 0.1968578f, -17.2182655f, 1.4839698f, -0.0541308f, - 15.9838457f, 18.5951862f, 0.0078872f, -0.1186571f, -2.4982276f, 0.0033835f, 0.3749593f, -15.0238085f, - 0.0595601f, -16.8588371f, 0.1146287f, 0.1274172f, 19.3332062f, -7.0513921f, -5.4852023f, 0.1681230f}, - {-5.1457887f, 0.0335570f, 1.8620163f, 0.0560381f, -0.6397949f, -4.0867515f, 1.3578068f, -23.9992580f, - -0.1034287f, 0.1437906f, 0.1076568f, -0.6930848f, -0.1176134f, 2.2855785f, -0.8021089f, 0.0424611f, - -0.6139123f, -3.1381547f, 0.0188163f, -0.1728741f, 0.6676420f, -0.1124282f, 0.1077818f, 2.3839712f, - 0.1340676f, 1.3538554f, 0.0421035f, 0.4513423f, -0.1543196f, 0.5120541f, -0.8940096f, -0.1175765f}, - {2.1656792f, 0.1638565f, 4.5302448f, 0.0741160f, 3.3850696f, -4.8867540f, 2.8059542f, -0.0023008f, - -0.1248942f, -0.0075225f, -0.0082212f, -1.0955724f, -0.1462416f, -1.7098176f, -4.1775723f, 0.1950609f, - 3.6847639f, 1.6520064f, 0.0310502f, -0.0430167f, 3.4527576f, 0.1453262f, -1.0126116f, 1.8785841f, - -0.0615105f, 1.0451943f, -0.2653875f, -1.2223006f, -1.0100641f, 1.2076828f, 0.4882897f, -0.0618375f}, - {2.4578559f, -0.1464199f, -1.3086185f, 0.1208716f, -0.2079897f, -2.7138259f, -1.4107026f, -0.4483974f, - -0.1599056f, 0.0242936f, 0.1326804f, 0.8664415f, 0.0588684f, 0.7366717f, 2.3159802f, -0.1917707f, - -2.0800066f, -7.5100355f, 0.0585225f, 0.1582773f, 1.8128076f, -0.0756957f, 0.8521049f, 0.5539182f, - -0.1738797f, -0.2020151f, 0.2219591f, 0.1088298f, -1.9535940f, 2.4130275f, -0.0741222f, 0.1156681f}, - {-0.4152933f, -0.0679605f, -0.5760314f, -0.0201883f, -14.1784763f, 0.7755737f, -19.5469246f, 0.0381304f, - 0.0160074f, 0.1124380f, -0.0478151f, -2.3719466f, 0.0819727f, -12.5069208f, 2.0468810f, 0.0964909f, - 7.8784809f, -6.3555703f, -0.0429914f, -0.0162720f, -0.9493829f, 0.0296786f, -0.0244959f, -12.6325788f, - -0.1871653f, -9.8338795f, 0.0391840f, -0.1199073f, -11.7859421f, 8.7398720f, 19.4971046f, -0.1954873f}, - {-4.8962007f, -0.1695992f, 0.7760146f, -0.0199836f, -0.0576061f, -6.0196476f, -2.3023551f, -20.0125084f, - -0.1957836f, -0.0993785f, 0.1109372f, -0.0710161f, -0.0553650f, 0.2546394f, -1.7578228f, 0.1498791f, - -2.6269529f, 1.3973731f, 0.0464059f, -0.2307575f, 1.6730053f, -0.0038867f, 0.1040150f, 2.6721606f, - 0.2027777f, -1.2358316f, -0.0587254f, 0.0610504f, -0.1700777f, -0.4323797f, 1.0359807f, -0.0127435f}, - {1.1245984f, -0.1806923f, -1.5868790f, 0.1536594f, 1.6837788f, -1.6474472f, -3.9225550f, 0.4506312f, - 0.1854908f, -0.1023232f, -0.0306957f, -0.8615071f, 0.0945480f, 2.0585704f, 0.6044773f, 0.1269336f, - 2.4720187f, -4.5123949f, -0.0657749f, 0.1738364f, 2.4188614f, 0.0038840f, -0.2019601f, -0.3842189f, - -0.0493631f, 3.6777370f, -0.1003436f, 0.6174496f, 1.0476112f, 2.7601521f, 0.9059890f, -0.1691816f}, - {1.9658293f, 0.2083382f, 1.7833723f, 0.0662620f, -0.3932888f, -1.0642430f, 0.1807114f, -1.1486723f, - -0.0177136f, -0.1706942f, 0.1730027f, 0.6712329f, 0.0485299f, 0.6379296f, -0.2880911f, -0.1993632f, - -0.9471832f, 1.9425983f, 0.0328524f, 0.0777725f, 0.6454380f, 0.0143852f, 0.0192997f, 1.6793132f, - -0.1872064f, -1.5757623f, 0.0242778f, -0.5992475f, 2.2148299f, -3.5215647f, -2.9748621f, 0.0112703f}, - {0.3737165f, 0.0361593f, -0.1075856f, -0.0312021f, -0.0786010f, 1.3149793f, 0.0237401f, -0.0819654f, - -0.1388431f, -0.0306386f, -0.0704427f, -2.3997226f, -0.1392045f, 0.7729424f, 0.1253861f, -0.0819755f, - -0.7590774f, -0.3295609f, -0.0172208f, -0.0551179f, 0.4599459f, -0.1143881f, 2.7430685f, 0.3621114f, - -0.1475701f, 0.2296079f, -2.2224922f, -0.9080986f, 0.2101683f, 0.1190262f, -2.2205217f, -0.0811555f}, - {0.3946800f, -0.1204188f, 0.0543225f, -0.0392627f, 1.9454094f, 0.1865290f, 1.5276426f, -0.0342965f, - 0.0117116f, -0.1873923f, -0.1045035f, 1.8535231f, -0.0207077f, 0.0981549f, -0.0327459f, -0.1486938f, - 0.6359531f, -0.1314566f, -2.1469448f, -0.1665767f, 0.5134121f, -0.0341647f, -2.1786075f, -0.5976576f, - 0.0111857f, 0.3272055f, 2.1917374f, -1.6247722f, 1.6025572f, -1.9965295f, 0.3347488f, 0.1113990f}, - {0.0340557f, -0.1659652f, -0.0042457f, 0.0010229f, -2.1550148f, -0.4728722f, -1.3667214f, 0.2625635f, - -0.0302200f, -0.0322885f, 0.0227866f, 0.6977839f, 0.0050141f, -1.6183628f, 0.0869662f, -0.0775411f, - 0.4754244f, 0.4596581f, 2.1509945f, -0.0313832f, 0.0336208f, -0.1547154f, -0.6017126f, 0.0369996f, - -0.1102583f, -0.5788267f, 0.0017006f, 2.6352038f, -1.7847317f, 1.7510574f, 2.1478791f, -0.2251654f}, - }; - - ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_layer2[32] = { - -0.2689391f, 1.5461178f, -0.2424639f, 0.4424149f, -0.0411816f, -4.1070848f, 1.4709516f, -0.2439820f, - -0.1750926f, 2.8802166f, -0.1573734f, -1.3724055f, 0.3671952f, 1.8267332f, 1.5655776f, -0.7323843f, - 1.6318209f, 2.2198663f, -1.5951139f, -0.0870247f, 0.2806863f, -0.2407108f, 0.1310665f, -0.5246177f, - 0.1914421f, -0.3386542f, -0.6310596f, 3.2995102f, 0.7519229f, -0.1565450f, -0.1496341f, 1.0073272f}; - - ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_layer2[32][32] = { - {-0.1731049f, 1.7775618f, -0.2532010f, -0.2902778f, -0.1392802f, 4.2428946f, -0.1866968f, -0.1800365f, - -0.0634398f, 0.0763313f, 0.0472901f, -0.8030146f, 0.3161853f, -1.0713238f, -4.6514492f, -0.3908085f, - 1.1607268f, 0.8834935f, -0.1194544f, -0.0785166f, 0.4967587f, -0.0558136f, -0.9601135f, -0.1001592f, - 3.4427991f, -0.2144053f, -0.3632556f, 0.0117088f, 0.1742481f, -0.2540179f, -0.1705156f, -0.2627344f}, - {-0.1478276f, -0.1659575f, 0.1602777f, -0.0758106f, 0.1067696f, -0.0247068f, -0.1123443f, -0.1724832f, - -0.0013103f, -0.0685904f, 0.1537329f, 0.1042632f, -0.0360880f, -0.0679077f, 0.0672719f, 0.1597116f, - -0.0150259f, 0.0367102f, -0.0545881f, -0.0693004f, -0.1008447f, -0.0672846f, -0.1395939f, -0.0324785f, - -0.1051702f, -0.0530534f, -0.1019061f, -0.0921245f, 0.1195077f, 0.0453448f, 0.0257045f, -0.0622537f}, - {-0.0363173f, -0.1990481f, -0.0452148f, 0.4074381f, -0.0731660f, -0.0823270f, 0.3154473f, -0.1909118f, - -0.0165690f, 0.1325824f, -0.0760181f, 0.7768906f, -0.2702211f, -0.6023573f, 1.5904741f, 0.2384946f, - 0.7610655f, -2.8705251f, 0.5754877f, -0.1587478f, -0.5708794f, -0.3421216f, 0.5023443f, 1.2806857f, - 0.2158970f, -0.1364033f, -0.3398291f, 0.9066412f, -1.2935438f, 0.0273695f, -0.1850613f, -0.9301611f}, - {-0.1281746f, 0.1695392f, 0.0805936f, -0.0598281f, 0.1266985f, -0.1697189f, -0.1091505f, -0.1569477f, - 0.0363969f, -0.0628394f, 0.0107523f, 0.0659535f, -0.0568244f, -0.1299786f, 0.0005438f, -0.0806242f, - -0.0806848f, -0.0919798f, -0.0748445f, 0.0792912f, 0.0022868f, 0.0211520f, -0.0183716f, 0.1279848f, - -0.1518286f, -0.0113527f, 0.0824359f, -0.0178597f, 0.0272009f, 0.0288935f, 0.0123459f, 0.1685353f}, - {0.1099675f, -0.3914332f, -0.0647218f, -0.8259028f, -0.0283726f, -0.0860217f, -2.0489185f, 0.1042144f, - 0.1024824f, 0.0735443f, -0.1235109f, -3.3674469f, -0.1799957f, -7.1867313f, 1.6053666f, -0.5203959f, - 0.8686391f, -0.0675404f, -2.8893898f, -0.0796400f, 1.2672142f, -0.0371844f, -1.8065344f, -2.2551982f, - 0.0355568f, 0.0672171f, 0.7150316f, 1.3620002f, -0.4106106f, 0.0126076f, 0.0408083f, 1.5958146f}, - {0.0525989f, 1.8947815f, -0.2513640f, -0.3715420f, -0.1752283f, 1.3911799f, -0.7633898f, -0.1716654f, - -0.0145629f, -1.7601604f, -0.1943324f, -0.5716376f, -0.8281464f, -0.0308049f, -1.4709659f, -0.4294116f, - -0.1030817f, -0.1823493f, 0.7561242f, -0.1608112f, 0.3980689f, -0.2464017f, -1.3065518f, 0.0875702f, - -0.1504322f, -0.0352198f, -0.4051513f, 0.7010455f, -0.2363433f, -0.1118084f, -0.1329087f, -0.3257700f}, - {-0.1209070f, 0.1677164f, -0.1353413f, -0.0410048f, -0.1432644f, 0.2649301f, 0.2247741f, -0.0425357f, - -0.2644008f, 1.4204332f, -0.2540753f, 0.2481354f, 1.9494507f, -0.2003033f, -0.5938342f, -0.3314930f, - 1.5038266f, -2.4000788f, -1.6202501f, -0.0256936f, -0.2890913f, -0.2113032f, 0.9030544f, 1.1483711f, - 0.0545346f, -0.1961582f, -0.2267976f, 0.2372836f, 2.5995049f, -0.1469661f, -0.1017130f, 1.6176132f}, - {0.0542207f, 2.7658713f, -0.1700335f, -0.3357265f, -0.1097085f, 1.6508883f, 0.0132292f, 0.1211861f, - -0.0852982f, 0.9232512f, 0.0202751f, 0.3138782f, 0.2674713f, 0.1247260f, 0.3859081f, 0.3961721f, - 1.0556988f, 0.8574673f, -0.1462571f, -0.1600272f, 0.4117427f, -0.1561815f, 0.0553897f, -0.2753994f, - 5.8420453f, 0.0883128f, 0.3594444f, -0.7174141f, 0.5683901f, 0.0096710f, -0.0957449f, -0.0195320f}, - {0.1561092f, -0.0417566f, -0.1044470f, 0.1186895f, -0.1195878f, 0.0446987f, -0.1386125f, -0.0103878f, - 0.1173026f, 0.1349312f, -0.0676422f, -0.1452308f, 0.0093872f, 0.0069650f, 0.1739093f, -0.1592752f, - -0.1329019f, -0.0459163f, -0.1511888f, -0.0040456f, 0.0065862f, 0.0106182f, 0.0318060f, 0.1003269f, - 0.0249398f, 0.1661194f, -0.0286407f, -0.1062361f, 0.0026465f, -0.0091479f, -0.1493473f, 0.0519762f}, - {-0.0702637f, 0.1154817f, -0.0680643f, 0.1447217f, 0.1394082f, -0.0691432f, 0.0939426f, 0.0483852f, - 0.1437123f, -0.1085759f, 0.0333924f, -0.0683726f, 0.0707103f, -0.0723069f, 0.0124601f, -0.0309495f, - -0.0308395f, -0.0695953f, -0.1078720f, 0.0858701f, -0.0773453f, 0.0477413f, 0.0615588f, 0.1656474f, - 0.1718751f, -0.1125762f, 0.1753366f, -0.0557704f, 0.0921221f, 0.0372290f, -0.1084552f, -0.0438967f}, - {-0.0557757f, 0.0694144f, 0.1150911f, -0.0202319f, 0.0661389f, -0.0928373f, 0.0441888f, -0.0028318f, - -0.0039446f, 0.0294675f, 0.1353384f, 0.0427515f, 0.0695194f, 0.1329748f, 0.1339706f, 0.0713900f, - -0.1384726f, 0.0925476f, 0.1581103f, 0.0100842f, -0.1248652f, -0.0173615f, 0.1637451f, -0.0025173f, - -0.0331219f, -0.0335269f, 0.0949441f, 0.0538645f, 0.0834281f, 0.0137191f, -0.1360130f, 0.0074489f}, - {-0.0949665f, -0.2181539f, 0.0871969f, 3.0772011f, -0.1152011f, -0.0022047f, 1.2700632f, -0.1173392f, - -0.1678371f, -1.3448639f, -0.2893313f, 1.5105180f, -0.6029126f, -1.1568675f, 1.4823192f, 0.1635401f, - -2.2136483f, -1.4164798f, -0.4795305f, -0.0807557f, -1.6675406f, -0.0992591f, 2.1212378f, -0.9400231f, - -0.5339298f, -0.0342672f, -2.3564072f, 1.3407421f, -3.8635128f, -0.1171367f, -0.0364181f, -3.2491686f}, - {-0.1047117f, -0.0540412f, -0.1137928f, 0.1582367f, -0.0982449f, 0.0511854f, -0.0805884f, -0.1141258f, - 0.0931992f, -0.0227052f, 0.0780590f, -0.1288135f, -0.1186576f, -0.0754066f, -0.1234059f, -0.0091936f, - 0.0205475f, 0.1640417f, -0.1527465f, 0.0068472f, -0.1239804f, -0.0448335f, -0.0061169f, -0.0078998f, - 0.0253047f, 0.0712901f, 0.0024753f, -0.0259875f, -0.1238613f, 0.1096537f, -0.0953007f, 0.1385384f}, - {0.0521762f, 1.4885306f, -0.1298001f, 2.3033395f, -0.1589162f, -0.8458843f, 0.0631668f, -0.1424429f, - -0.0384785f, 0.5599840f, 0.0008631f, -1.5839294f, 1.9202064f, 0.6930331f, 0.4948464f, -0.6195241f, - -3.0526664f, 3.1423819f, -1.3433597f, -0.1167206f, -1.3491610f, -0.0901343f, -1.2291449f, 3.5039587f, - 0.4674770f, -0.3027362f, 0.8279622f, 0.3417586f, 0.1367343f, -0.1085793f, -0.1048759f, 1.2729272f}, - {-0.0029521f, 0.2439991f, -0.0858953f, -2.7804739f, -0.0220416f, 0.0256599f, -0.3304259f, -0.0586597f, - -0.0459698f, 0.1670698f, -0.1359344f, -0.3957845f, -1.6954739f, 0.3318155f, 0.9375985f, 0.5211958f, - 0.6071047f, -3.4249072f, 1.3199407f, 0.0136374f, 1.2692807f, 0.0233104f, -0.0731508f, 2.2171400f, - -0.6052189f, -0.0698463f, 1.6376522f, -1.1908000f, -0.1706121f, -0.0380146f, 0.0144418f, 1.5177792f}, - {-0.0314772f, 0.0523589f, -0.0517322f, -0.0100344f, 0.0714635f, -0.1646974f, 0.0800682f, 0.1132821f, - -0.0028872f, -0.1239987f, -0.1322138f, -0.1059789f, 0.1752418f, 0.0475279f, -0.0046871f, 0.1574167f, - -0.0231106f, -0.0261228f, 0.0236005f, 0.1663371f, 0.1059707f, 0.1229704f, 0.1427562f, -0.1648343f, - 0.0992667f, -0.0631751f, -0.1411413f, -0.0999486f, -0.0972435f, -0.1422556f, 0.0973614f, -0.0156000f}, - {-0.1309903f, -0.5060971f, -0.1911870f, 2.2349114f, 0.1010354f, 0.5538697f, 1.8757060f, -0.1538645f, - -0.2073075f, -1.8350753f, 0.0532570f, 1.8151909f, -0.6800886f, 0.2615838f, -0.6204563f, -0.1238837f, - -0.4772464f, -2.4070835f, -0.2783994f, -0.0211087f, -4.4925098f, -0.0790045f, 1.3566529f, -0.3650998f, - -0.4658130f, -0.0479139f, -1.9361999f, 2.1485121f, -3.1108823f, -0.0020647f, -0.0489678f, -0.4781263f}, - {-0.0099352f, -1.9572417f, 0.0918592f, 0.7327217f, -0.0609625f, -0.1969659f, 0.1922992f, -0.1091586f, - -0.2125459f, -1.9542989f, -0.1648019f, -0.9355955f, 0.9144324f, -5.0530005f, -0.2265045f, -0.5638458f, - 4.4370432f, -2.0318019f, -1.5679311f, 0.0221776f, -0.4063498f, -0.1160609f, 0.9651156f, -0.2401051f, - 0.1903293f, -0.2355373f, 0.2334733f, 0.1025979f, 0.7150746f, 0.0315593f, -0.0001765f, 0.0137871f}, - {0.0320691f, -1.8876421f, -0.1241799f, -3.1652985f, -0.1528286f, 2.1882250f, -2.5907574f, 0.0210803f, - -0.1545521f, 0.7706368f, -0.1652040f, -4.1518817f, 4.2974262f, 0.3074523f, 3.3711803f, -37.9055862f, - 1.0623894f, 0.4360786f, -2.6417589f, 0.1113010f, 3.8902094f, -0.1616735f, 0.5595753f, 1.5364015f, - -2.4740698f, -0.0240434f, -28.0232792f, 0.6092473f, 1.6978041f, -0.0458809f, 0.0664777f, 0.2603019f}, - {0.1044999f, 0.0054908f, 0.1407564f, -0.1701076f, -0.1274551f, 0.0443607f, 0.1182709f, -0.1103420f, - -0.1343671f, -0.0042888f, -0.1611361f, 0.0154269f, 0.2285106f, 0.0870507f, 0.0914433f, 0.0657276f, - -0.1664300f, -0.0342912f, 0.1037545f, -0.1175308f, 0.1135652f, 0.1325845f, -0.1459545f, -0.2156865f, - -0.1673723f, -0.1156510f, 0.0179541f, 0.0541515f, 0.0957617f, -0.1297485f, 0.1045326f, 0.2950188f}, - {-0.1401742f, -2.8181052f, -0.0588381f, -0.1517100f, -0.0608850f, -3.5837226f, -0.1528927f, -0.0211265f, - 0.0881796f, -0.4448619f, -0.1457623f, -0.8828475f, 0.1261238f, -1.0495204f, -3.7918513f, -0.4645159f, - -0.0800092f, 0.0624971f, 0.1528609f, -0.1069645f, 0.4319421f, 0.0651448f, -0.6571375f, -0.0323338f, - -4.6534319f, -0.0538999f, -0.2221518f, 0.0972160f, 0.1496329f, 0.0570569f, -0.1125795f, -0.0153687f}, - {-0.1065502f, 0.0606179f, -0.1400291f, -0.0220975f, -0.0613350f, -0.0038843f, -0.0132201f, 0.1678067f, - 0.1008587f, -0.1255144f, -0.0675021f, -0.0475353f, 0.0278098f, 0.0527470f, -0.0089845f, -0.0622052f, - 0.1088723f, 0.0053812f, 0.0627310f, -0.0226460f, -0.1096366f, -0.0505830f, -0.0301058f, -0.0775778f, - -0.0008928f, -0.1157909f, 0.0544982f, 0.0430219f, -0.0134386f, -0.1095094f, 0.1215172f, 0.0081556f}, - {-0.1747307f, -0.7465636f, -0.0497346f, -2.0686443f, 0.0190713f, -2.9156351f, -5.4731860f, -0.0728399f, - -0.0845178f, -14.8429976f, -0.1068359f, 1.8549156f, -3.1135283f, -0.0907917f, -0.0262453f, -8.8010912f, - -4.3007965f, -1.6772208f, -0.2576891f, -0.0163111f, -7.8583646f, 0.0697906f, -0.0943863f, -0.7450574f, - 1.1493169f, 0.0921000f, -0.2395420f, 0.5794312f, -4.2405462f, -0.0910322f, -0.1381017f, -1.0270567f}, - {-0.0446755f, -0.8131990f, -0.1741483f, -1.7555307f, 0.0153283f, 0.0734032f, -0.5930048f, -0.0398877f, - -0.0215982f, 0.0497884f, -0.0504920f, 0.0942539f, -1.1370168f, -0.8821361f, -0.0879569f, 0.3811991f, - 1.2224945f, 0.3782545f, 1.4800016f, 0.0494110f, 1.7101970f, -0.2885793f, -0.1778114f, -1.3913733f, - -0.0944610f, -0.3578439f, 0.3491475f, -3.0349872f, 0.8044587f, 0.0928676f, -0.0395946f, 0.2008810f}, - {0.0721043f, -0.1181163f, 0.0108281f, -0.1215726f, 0.1285277f, 0.0851443f, 0.0791321f, 0.1765833f, - -0.0324889f, -0.0150838f, -0.0051942f, 0.1685798f, 0.1521861f, 0.0283858f, 0.0326072f, 0.0346215f, - -0.1081120f, -0.0745824f, -0.1762613f, 0.0901582f, 0.1335704f, 0.1599123f, -0.0097813f, 0.0364541f, - -0.0391450f, -0.0079635f, 0.1014886f, 0.0130333f, 0.0438304f, -0.0074333f, 0.0845035f, -0.0471010f}, - {0.0360538f, -0.9701002f, -0.2217611f, -1.1626705f, 0.0548465f, 0.6605385f, -0.6693703f, -0.1432099f, - -0.0754442f, -0.2380328f, -0.0754142f, -2.3242903f, 3.5773275f, 0.0707042f, 0.2052065f, -1.3753067f, - -0.8530636f, 3.1850073f, -0.2901604f, -0.1291050f, -4.4672642f, -0.2425279f, 0.1252670f, 0.4261391f, - -0.8620862f, 0.1153403f, -0.1999598f, -4.7756801f, 2.8851914f, -0.1340472f, 0.0482952f, 1.7996837f}, - {-0.1654812f, 0.9604513f, 0.1770310f, -16.5736618f, -0.0350192f, -0.5557595f, -35.3047371f, -0.1299658f, - 0.0065243f, -3.0823336f, 0.0351931f, 4.9456911f, -1.4382623f, -1.6900688f, -1.9084880f, -3.1811504f, - -8.0212736f, -7.3994560f, 4.9219728f, 0.0433824f, 0.6197430f, 0.0308996f, 5.2004323f, 0.5327767f, - 1.0885966f, 0.1487215f, -21.4211712f, -1.8733859f, 1.9195696f, -0.0539309f, -0.0795544f, -3.1121061f}, - {-0.0058153f, 1.7521383f, -0.2205407f, 2.6318321f, -0.0038140f, -1.4131194f, 3.0181022f, 0.0373498f, - -0.1246315f, -1.8323456f, -0.1470954f, 2.9131169f, 1.1522563f, 0.6036215f, -3.3962972f, 7.0906253f, - -1.5353408f, -0.2648884f, 0.5501783f, -0.2262681f, -2.4874980f, -0.0533402f, 3.0222948f, 0.3296265f, - 1.4057258f, 0.0185255f, 6.1208682f, 0.7210779f, -0.3055671f, -0.2595702f, -0.1286864f, 0.6510819f}, - {-0.2145578f, 0.4758183f, -0.1186396f, -0.6096930f, -0.1574199f, -0.1929667f, -0.6877209f, -0.2098342f, - 0.0726678f, 0.1379885f, 0.0710437f, -1.1860796f, 0.6582619f, 0.2388466f, 0.0458675f, -0.0634391f, - -0.1678368f, -8.2454395f, -0.6461441f, -0.2063597f, 0.0304686f, 0.0319904f, -1.0730971f, 1.1281222f, - 0.1292592f, -0.3054110f, 0.7732272f, -1.0069786f, -0.0847367f, -0.2342585f, -0.1553642f, 1.5100089f}, - {-0.1022291f, 2.7367072f, -0.1738961f, -1.0328600f, -0.0864617f, -0.3224345f, -2.6092832f, -0.2382921f, - 0.0578183f, 0.4115438f, 0.0121692f, -1.0689495f, 0.5158959f, 2.9600139f, 0.8839240f, -0.7147520f, - -2.7168157f, 1.2148006f, 1.5884653f, -0.1227511f, 1.3176637f, -0.1335970f, -1.4691980f, 1.1131358f, - -0.1302031f, 0.0779746f, 0.2622980f, 0.0837635f, 2.7756395f, -0.0315265f, 0.0868374f, -4.2980185f}, - {0.0228074f, 2.1787968f, -0.1889012f, -0.8560471f, -0.1063542f, -0.2869910f, 0.2767612f, -0.1183861f, - -0.0992468f, 2.1517978f, -0.0428540f, 1.0697522f, 1.9683092f, 2.1042306f, -0.0426359f, -0.3499008f, - -0.9989156f, 0.0880459f, 2.9753070f, -0.1941337f, -3.1616704f, -0.0093505f, 1.4922180f, 2.8480091f, - 0.2656264f, -0.1299839f, -1.0458518f, -1.6748481f, -3.1420829f, -0.1360553f, -0.1117443f, -1.3989290f}, - {-0.0246332f, 0.1165779f, 0.0255498f, -0.0601489f, 0.1545041f, -0.0977981f, 0.1242626f, -0.1533627f, - -0.1294386f, -0.0231293f, -0.1460808f, 0.1763088f, 0.0953614f, -0.0716483f, -0.1003436f, 0.0804519f, - 0.1373295f, -0.0686773f, 0.1198382f, 0.1519430f, 0.1640775f, -0.1675753f, 0.0790529f, -0.1521838f, - 0.0378523f, 0.1039687f, -0.0701027f, 0.0509319f, 0.1355647f, 0.0978021f, 0.0391430f, 0.0241266f}, - }; - - ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_output_layer[1] = {-0.7420582f}; - - ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_output_layer[32][1] = { - {0.0381968f}, {1.0667214f}, {0.0505496f}, {-1.5677565f}, {0.0066824f}, {-0.9951485f}, {0.9438043f}, - {0.0068631f}, {-0.0216870f}, {0.6560486f}, {-0.0235629f}, {0.9653404f}, {0.6641668f}, {-0.5351945f}, - {-0.5303048f}, {1.9339687f}, {0.4359012f}, {-0.7492802f}, {-0.5728400f}, {0.0473893f}, {-0.5091293f}, - {-0.1926489f}, {-0.6562935f}, {-0.5583456f}, {-0.7618014f}, {-0.0316967f}, {1.1637378f}, {-0.5158406f}, - {-0.5268564f}, {0.0735416f}, {0.0270067f}, {-0.5614370f}, - }; - -} // namespace ALPAKA_ACCELERATOR_NAMESPACE::lst::t5dnn - -#endif diff --git a/RecoTracker/LSTCore/src/alpaka/Quintuplet.h b/RecoTracker/LSTCore/src/alpaka/Quintuplet.h index 679b1334038b2..83af8a3e08493 100644 --- a/RecoTracker/LSTCore/src/alpaka/Quintuplet.h +++ b/RecoTracker/LSTCore/src/alpaka/Quintuplet.h @@ -13,6 +13,7 @@ #include "RecoTracker/LSTCore/interface/ModulesSoA.h" #include "RecoTracker/LSTCore/interface/EndcapGeometry.h" #include "RecoTracker/LSTCore/interface/ObjectRangesSoA.h" +#include "RecoTracker/LSTCore/interface/DnnWeightsDevSoA.h" #include "NeuralNetwork.h" #include "Hit.h" @@ -1465,6 +1466,7 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { float& dBeta1, float& dBeta2, bool& tightCutFlag, + lst::DnnWeightsDevData const* dnnPtr, const float ptCut) { unsigned int firstSegmentIndex = triplets.segmentIndices()[innerTripletIndex][0]; unsigned int secondSegmentIndex = triplets.segmentIndices()[innerTripletIndex][1]; @@ -1504,6 +1506,7 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { innerRadius = triplets.radius()[innerTripletIndex]; bool inference = lst::t5dnn::runInference(acc, + dnnPtr, mds, firstMDIndex, secondMDIndex, @@ -1651,6 +1654,7 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { QuintupletsOccupancy quintupletsOccupancy, ObjectRangesConst ranges, uint16_t nEligibleT5Modules, + lst::DnnWeightsDevData const* dnnPtr, const float ptCut) const { auto const globalThreadIdx = alpaka::getIdx(acc); auto const gridThreadExtent = alpaka::getWorkDiv(acc); @@ -1709,6 +1713,7 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst { dBeta1, dBeta2, tightCutFlag, + dnnPtr, ptCut); if (success) { diff --git a/RecoTracker/LSTCore/standalone/analysis/DNN/network_weights.bin b/RecoTracker/LSTCore/standalone/analysis/DNN/network_weights.bin new file mode 100644 index 0000000000000..398bc20e6c95a Binary files /dev/null and b/RecoTracker/LSTCore/standalone/analysis/DNN/network_weights.bin differ diff --git a/RecoTracker/LSTCore/standalone/analysis/DNN/train_T5_DNN.ipynb b/RecoTracker/LSTCore/standalone/analysis/DNN/train_T5_DNN.ipynb index e7ec1b45283e5..6fc92257fb63f 100644 --- a/RecoTracker/LSTCore/standalone/analysis/DNN/train_T5_DNN.ipynb +++ b/RecoTracker/LSTCore/standalone/analysis/DNN/train_T5_DNN.ipynb @@ -873,6 +873,44 @@ "print_model_weights_biases(model)\n" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import struct\n", + "\n", + "def save_weights_binary(model, filename):\n", + " \"\"\"\n", + " Save weights in binary format with correct transposition\n", + " \"\"\"\n", + " with open(filename, 'wb') as f:\n", + " layer_id = 1\n", + " \n", + " for name, module in model.named_modules():\n", + " if isinstance(module, nn.Linear):\n", + " weights = module.weight.data.cpu().numpy()\n", + " biases = module.bias.data.cpu().numpy()\n", + " \n", + " # Write header with correct dimensions\n", + " f.write(struct.pack('III', \n", + " layer_id,\n", + " weights.shape[1], # num_inputs\n", + " weights.shape[0] # num_outputs \n", + " ))\n", + " \n", + " # Write biases\n", + " biases.astype(np.float32).tofile(f)\n", + " \n", + " # Write weights\n", + " weights.astype(np.float32).tofile(f)\n", + " \n", + " layer_id += 1\n", + "\n", + "save_weights_binary(model, \"network_weights.bin\")" + ] + }, { "cell_type": "code", "execution_count": 11,