diff --git a/cmd/main.go b/cmd/main.go index fb0edfe..fbce685 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -24,6 +24,7 @@ type cfg struct { CertKey string `mapstructure:"cert-key" config:"zerodefault"` PluginDir string `mapstructure:"plugin-dir" config:"zerodefault"` ListOptions string `mapstructure:"list-options" valid:"in(all|selected)"` + SecureLoader bool `mapstructure:"secure-loader" config:"zerodefault"` } func (o cfg) Validate() error { @@ -76,8 +77,22 @@ func main() { } // Load sub-attesters from the path specified in config.yaml - pluginManager, err := plugin.CreateGoPluginManager( - cfg.PluginDir, log.Named("plugin")) + pluginLoader, err := plugin.CreateGoPluginLoader(cfg.PluginDir, log.Named("plugin")) + if err != nil { + log.Fatalf("could not create the plugin loader: %v", err) + } + if cfg.SecureLoader { + subs, err := config.GetSubs(v, "plugins") + if err != nil { + log.Fatalf("failed to enable secure loader: %v", err) + } + if err := pluginLoader.SetChecksum(subs["plugins"]); err != nil { + log.Fatalf("secure loader failed to set plugin checksum: %v", err) + } + } + + pluginManager, err := plugin.CreateGoPluginManagerWithLoader( + pluginLoader, log.Named("plugin")) if err != nil { log.Fatalf("could not create the plugin manager: %v", err) diff --git a/plugin/goplugin_context.go b/plugin/goplugin_context.go index 0903361..a1be09f 100644 --- a/plugin/goplugin_context.go +++ b/plugin/goplugin_context.go @@ -3,8 +3,10 @@ package plugin import ( + "crypto/sha256" "fmt" "os/exec" + "path/filepath" "strings" "github.com/hashicorp/go-plugin" @@ -42,15 +44,31 @@ func createPluginContext( path string, logger *zap.SugaredLogger, ) (*PluginContext, error) { - client := plugin.NewClient( - &plugin.ClientConfig{ - HandshakeConfig: handshakeConfig, - Plugins: loader.pluginMap, - Cmd: exec.Command(path), - Logger: log.NewInternalLogger(logger), - AllowedProtocols: []plugin.Protocol{plugin.ProtocolGRPC}, - }, - ) + cfg := &plugin.ClientConfig{ + HandshakeConfig: handshakeConfig, + Plugins: loader.pluginMap, + Cmd: exec.Command(path), + Logger: log.NewInternalLogger(logger), + AllowedProtocols: []plugin.Protocol{plugin.ProtocolGRPC}, + } + + if len(loader.pluginChecksum) > 0 { + basename := filepath.Base(path) + pluginName := strings.TrimSuffix(basename, filepath.Ext(basename)) + + checksum, ok := loader.pluginChecksum[pluginName] + if !ok { + return nil, fmt.Errorf("the checksum for plugin %s is missing", pluginName) + } + + secureConfig := &plugin.SecureConfig{ + Checksum:[]byte(checksum), + Hash: sha256.New(), + } + cfg.SecureConfig = secureConfig + } + + client := plugin.NewClient(cfg) rpcClient, err := client.Client() if err != nil { diff --git a/plugin/goplugin_loader.go b/plugin/goplugin_loader.go index f5d3782..19cec1e 100644 --- a/plugin/goplugin_loader.go +++ b/plugin/goplugin_loader.go @@ -3,10 +3,12 @@ package plugin import ( + "encoding/hex" "errors" "fmt" "github.com/hashicorp/go-plugin" + "github.com/spf13/viper" "go.uber.org/zap" ) @@ -23,6 +25,7 @@ type GoPluginLoader struct { logger *zap.SugaredLogger loadedByName map[string]*PluginContext + pluginChecksum map[string][]byte // This gets specified as Plugins when creating a new go-plugin client. pluginMap map[string]plugin.Plugin @@ -44,6 +47,7 @@ func CreateGoPluginLoader( func (o *GoPluginLoader) Init(dir string) error { o.pluginMap = make(map[string]plugin.Plugin) o.loadedByName = make(map[string]*PluginContext) + o.pluginChecksum = make(map[string][]byte) o.Location = dir return nil @@ -55,6 +59,36 @@ func (o *GoPluginLoader) Close() { } } +func (o *GoPluginLoader) SetChecksum(v *viper.Viper) error { + for _, name := range v.AllKeys() { + sha256sum := v.Get(name) + switch t := sha256sum.(type) { + case string: + o.logger.Debugw("registered plugin checksum", + "name", name, + "checksum", t, + ) + + checksum, err := hex.DecodeString(t) + if err != nil { + return fmt.Errorf( + "failed to load sha256 checksum for %s: %v", + name, err, + ) + } + + o.pluginChecksum[name] = checksum + default: + return fmt.Errorf( + "invalid checksum for plugin %q: expected string, got %T", + name, t, + ) + } + } + + return nil +} + func RegisterGoPluginUsing(loader *GoPluginLoader, name string) error { if _, ok := loader.pluginMap[name]; ok { return fmt.Errorf("plugin for %q is already registred", name)