From cf9ffbc9bf3580086cebc262a50f4e6d01efe6c3 Mon Sep 17 00:00:00 2001 From: Gil Date: Thu, 15 Nov 2018 09:55:16 +0000 Subject: [PATCH] AnyParameterProperties refactor (#4412) * combined parameter flags in a single mask * added mask_attribute to keep track of parameter availabilities * updated constructors and getters * added some documentation * changed properties default * typesafe bitmasking * refactored code to use enum class and bitmask operators --- src/shogun/base/AnyParameter.h | 86 ++++++++++++++++++---- src/shogun/base/SGObject.h | 14 ++-- src/shogun/lib/bitmask_operators.h | 110 +++++++++++++++++++++++++++++ 3 files changed, 190 insertions(+), 20 deletions(-) create mode 100644 src/shogun/lib/bitmask_operators.h diff --git a/src/shogun/base/AnyParameter.h b/src/shogun/base/AnyParameter.h index 23a1ba44457..4cc8f80db66 100644 --- a/src/shogun/base/AnyParameter.h +++ b/src/shogun/base/AnyParameter.h @@ -1,7 +1,14 @@ +/* + * This software is distributed under BSD 3-clause license (see LICENSE file). + * + * Authors: Heiko Strathmann, Gil Hoben + */ + #ifndef __ANYPARAMETER_H__ #define __ANYPARAMETER_H__ #include +#include #include @@ -22,48 +29,101 @@ namespace shogun GRADIENT_AVAILABLE = 1 }; + /** parameter properties */ + enum class ParameterProperties + { + HYPER = 1u << 0, + GRADIENT = 1u << 1, + MODEL = 1u << 2 + }; + + enableEnumClassBitmask(ParameterProperties); + + /** @brief Class AnyParameterProperties keeps track of of parameter meta + * information, such as properties and descriptions The parameter properties + * can be either true or false. These properties describe if a parameter is + * for example a hyperparameter or if it has a gradient. + */ class AnyParameterProperties { public: + /** Default constructor where all parameter properties are false + */ AnyParameterProperties() - : m_description(), m_model_selection(MS_NOT_AVAILABLE), - m_gradient(GRADIENT_NOT_AVAILABLE) + : m_description("No description given"), + m_attribute_mask(ParameterProperties()) { } + /** Constructor + * @param description parameter description + * @param hyperparameter set to true for parameters that determine + * how training is performed, e.g. regularisation parameters + * @param gradient set to true for parameters required for gradient + * updates + * @param model set to true for parameters used in inference, e.g. + * weights and bias + * */ AnyParameterProperties( std::string description, - EModelSelectionAvailability model_selection = MS_NOT_AVAILABLE, - EGradientAvailability gradient = GRADIENT_NOT_AVAILABLE) - : m_description(description), m_model_selection(model_selection), + EModelSelectionAvailability hyperparameter = MS_NOT_AVAILABLE, + EGradientAvailability gradient = GRADIENT_NOT_AVAILABLE, + bool model = false) + : m_description(description), m_model_selection(hyperparameter), m_gradient(gradient) { + m_attribute_mask = ParameterProperties(); + if (hyperparameter) + m_attribute_mask |= ParameterProperties::HYPER; + if (gradient) + m_attribute_mask |= ParameterProperties::GRADIENT; + if (model) + m_attribute_mask |= ParameterProperties::MODEL; } + /** Mask constructor + * @param description parameter description + * @param attribute_mask mask encoding parameter properties + * */ + AnyParameterProperties( + std::string description, ParameterProperties attribute_mask) + : m_description(description) + { + m_attribute_mask = attribute_mask; + } + /** Copy contructor */ AnyParameterProperties(const AnyParameterProperties& other) : m_description(other.m_description), m_model_selection(other.m_model_selection), - m_gradient(other.m_gradient) + m_gradient(other.m_gradient), + m_attribute_mask(other.m_attribute_mask) { } - - std::string get_description() const + const std::string& get_description() const { return m_description; } - EModelSelectionAvailability get_model_selection() const { - return m_model_selection; + return static_cast( + static_cast( + m_attribute_mask & ParameterProperties::HYPER) > 0); } - EGradientAvailability get_gradient() const { - return m_gradient; + return static_cast( + static_cast( + m_attribute_mask & ParameterProperties::GRADIENT) > 0); + } + bool get_model() const + { + return static_cast( + m_attribute_mask & ParameterProperties::MODEL); } private: std::string m_description; EModelSelectionAvailability m_model_selection; EGradientAvailability m_gradient; + ParameterProperties m_attribute_mask; }; class AnyParameter @@ -116,6 +176,6 @@ namespace shogun Any m_value; AnyParameterProperties m_properties; }; -} +} // namespace shogun #endif diff --git a/src/shogun/base/SGObject.h b/src/shogun/base/SGObject.h index 8464880e781..7aefd5b82f8 100644 --- a/src/shogun/base/SGObject.h +++ b/src/shogun/base/SGObject.h @@ -675,8 +675,7 @@ class CSGObject template void watch_param( const std::string& name, T* value, - AnyParameterProperties properties = AnyParameterProperties( - "Unknown parameter", MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE)) + AnyParameterProperties properties = AnyParameterProperties()) { BaseTag tag(name); create_parameter(tag, AnyParameter(make_any_ref(value), properties)); @@ -693,8 +692,7 @@ class CSGObject template void watch_param( const std::string& name, T** value, S* len, - AnyParameterProperties properties = AnyParameterProperties( - "Unknown parameter", MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE)) + AnyParameterProperties properties = AnyParameterProperties()) { BaseTag tag(name); create_parameter( @@ -714,8 +712,7 @@ class CSGObject template void watch_param( const std::string& name, T** value, S* rows, S* cols, - AnyParameterProperties properties = AnyParameterProperties( - "Unknown parameter", MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE)) + AnyParameterProperties properties = AnyParameterProperties()) { BaseTag tag(name); create_parameter( @@ -733,7 +730,10 @@ class CSGObject { BaseTag tag(name); AnyParameterProperties properties( - "Dynamic parameter", MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE); + "Dynamic parameter", + ParameterProperties::HYPER | + ParameterProperties::GRADIENT | + ParameterProperties::MODEL); std::function bind_method = std::bind(method, dynamic_cast(this)); create_parameter(tag, AnyParameter(make_any(bind_method), properties)); diff --git a/src/shogun/lib/bitmask_operators.h b/src/shogun/lib/bitmask_operators.h new file mode 100644 index 00000000000..5417dea17cb --- /dev/null +++ b/src/shogun/lib/bitmask_operators.h @@ -0,0 +1,110 @@ +#ifndef JSS_BITMASK_HPP +#define JSS_BITMASK_HPP + +// (C) Copyright 2015 Just Software Solutions Ltd +// +// Distributed under the Boost Software License, Version 1.0. +// +// Boost Software License - Version 1.0 - August 17th, 2003 +// +// Permission is hereby granted, free of charge, to any person or +// organization obtaining a copy of the software and accompanying +// documentation covered by this license (the "Software") to use, +// reproduce, display, distribute, execute, and transmit the +// Software, and to prepare derivative works of the Software, and +// to permit third-parties to whom the Software is furnished to +// do so, all subject to the following: +// +// The copyright notices in the Software and this entire +// statement, including the above license grant, this restriction +// and the following disclaimer, must be included in all copies +// of the Software, in whole or in part, and all derivative works +// of the Software, unless such copies or derivative works are +// solely in the form of machine-executable object code generated +// by a source language processor. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY +// KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE +// WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +// COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE +// LIABLE FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN +// CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include + +namespace shogun { + + template + struct enable_bitmask_operators { + static constexpr bool enable = false; + }; + + #define enableEnumClassBitmask(T) template<> \ + struct enable_bitmask_operators \ + { \ + static constexpr bool enable = true; \ + } + + template + typename std::enable_if::enable, E>::type + operator|(E lhs, E rhs) { + typedef typename std::underlying_type::type underlying; + return static_cast( + static_cast(lhs) | static_cast(rhs)); + } + + template + typename std::enable_if::enable, E>::type + operator&(E lhs, E rhs) { + typedef typename std::underlying_type::type underlying; + return static_cast( + static_cast(lhs) & static_cast(rhs)); + } + + template + typename std::enable_if::enable, E>::type + operator^(E lhs, E rhs) { + typedef typename std::underlying_type::type underlying; + return static_cast( + static_cast(lhs) ^ static_cast(rhs)); + } + + template + typename std::enable_if::enable, E>::type + operator~(E lhs) { + typedef typename std::underlying_type::type underlying; + return static_cast( + ~static_cast(lhs)); + } + + template + typename std::enable_if::enable, E &>::type + operator|=(E &lhs, E rhs) { + typedef typename std::underlying_type::type underlying; + lhs = static_cast( + static_cast(lhs) | static_cast(rhs)); + return lhs; + } + + template + typename std::enable_if::enable, E &>::type + operator&=(E &lhs, E rhs) { + typedef typename std::underlying_type::type underlying; + lhs = static_cast( + static_cast(lhs) & static_cast(rhs)); + return lhs; + } + + template + typename std::enable_if::enable, E &>::type + operator^=(E &lhs, E rhs) { + typedef typename std::underlying_type::type underlying; + lhs = static_cast( + static_cast(lhs) ^ static_cast(rhs)); + return lhs; + } +} +#endif \ No newline at end of file