Skip to content

Commit

Permalink
Merge pull request #84 from spacemeshos/78-refactor-initializer
Browse files Browse the repository at this point in the history
Refactor initializer to be simpler to use and race free
  • Loading branch information
fasmat authored Nov 9, 2022
2 parents 1fcfbdf + a7e19fd commit 87da468
Show file tree
Hide file tree
Showing 4 changed files with 574 additions and 462 deletions.
231 changes: 112 additions & 119 deletions initialization/initialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"path/filepath"
"sync"
"sync/atomic"
"time"

"golang.org/x/sync/errgroup"

Expand All @@ -29,11 +28,19 @@ type (
ComputeProvider = gpu.ComputeProvider
)

// Status is the initialization state reported by Initializer.Status.
type Status int

const (
// StatusNotStarted is reported when no labels have been written yet.
StatusNotStarted Status = iota
// StatusStarted is reported when at least one label has been written but the
// number of written labels does not equal the target.
StatusStarted
// StatusInitializing is reported when the Initializer's mutex cannot be
// acquired, i.e. while another operation such as Initialize is holding it.
StatusInitializing
// StatusCompleted is reported when the number of written labels equals
// NumUnits * LabelsPerUnit.
StatusCompleted
// StatusError is reported when the number of written labels cannot be read
// from the disk state.
StatusError
)

// Sentinel errors returned by the Initializer.
var (
// ErrNotInitializing is returned by Stop when no initialization is running.
ErrNotInitializing = errors.New("not initializing")
// ErrAlreadyInitializing is returned by Initialize when an initialization is
// already in progress.
ErrAlreadyInitializing = errors.New("already initializing")
// ErrCannotResetWhileInitializing is returned by Reset while an
// initialization is in progress.
ErrCannotResetWhileInitializing = errors.New("cannot reset while initializing")
// ErrStopped is returned by the label-computing worker in initFile when it
// observes a stop request.
ErrStopped = errors.New("gpu-post: stopped")
// ErrStateMetadataFileMissing reports that the metadata file is missing.
ErrStateMetadataFileMissing = errors.New("metadata file is missing")
)

Expand All @@ -45,67 +52,105 @@ func CPUProviderID() int {
return gpu.CPUProviderID()
}

// initializeOption collects the settings that NewInitializer assembles from its
// functional options before it constructs an Initializer.
type initializeOption struct {
// commitment is set by WithCommitment, which only accepts 32-byte values.
commitment []byte
// cfg is set by WithConfig; verify rejects a nil value.
cfg *Config
// initOpts is set by WithInitOpts; verify rejects a nil value.
initOpts *config.InitOpts
// logger is set by WithLogger; NewInitializer pre-populates it with
// shared.DisabledLogger{}.
logger Logger
}

// verify reports whether the collected options are sufficient to build an
// Initializer.
//
// A config, init options and a commitment are all mandatory. The commitment
// check is needed because WithCommitment validates the length only when the
// option is actually supplied; without it, a caller that omits the option would
// silently obtain an Initializer with a nil commitment. Once all mandatory
// values are present, the config and init options are validated together via
// config.Validate and its result is returned.
func (opts *initializeOption) verify() error {
	if opts.cfg == nil {
		return errors.New("no config provided")
	}

	if opts.initOpts == nil {
		return errors.New("no init options provided")
	}

	// WithCommitment guarantees a 32-byte value whenever it was applied, so an
	// empty slice here means the option was never given.
	if len(opts.commitment) == 0 {
		return errors.New("no commitment provided")
	}

	return config.Validate(*opts.cfg, *opts.initOpts)
}

// initializeOptionFunc is a functional option accepted by NewInitializer. It
// mutates the initializeOption being assembled; a non-nil error makes
// NewInitializer fail with that error.
type initializeOptionFunc func(*initializeOption) error

// WithCommitment returns an option that sets the commitment used by the
// Initializer.
//
// When the option is applied it returns an error unless commitment is exactly
// 32 bytes long. On success the option stores its own copy of the bytes, so
// that later modifications of the caller's slice cannot change the commitment
// of an already constructed Initializer.
func WithCommitment(commitment []byte) initializeOptionFunc {
	return func(opts *initializeOption) error {
		if len(commitment) != 32 {
			return fmt.Errorf("invalid `id` length; expected: 32, given: %v", len(commitment))
		}

		// Detach from the caller's backing array (slices are reference types).
		opts.commitment = append([]byte(nil), commitment...)
		return nil
	}
}

// WithInitOpts returns an option that makes NewInitializer use the given init
// options. The option stores a pointer to WithInitOpts' own copy of initOpts
// and never fails.
func WithInitOpts(initOpts config.InitOpts) initializeOptionFunc {
	store := func(target *initializeOption) error {
		target.initOpts = &initOpts
		return nil
	}
	return store
}

// WithConfig returns an option that makes NewInitializer use the given config.
// The option stores a pointer to WithConfig's own copy of cfg and never fails.
func WithConfig(cfg Config) initializeOptionFunc {
	store := func(target *initializeOption) error {
		target.cfg = &cfg
		return nil
	}
	return store
}

// WithLogger returns an option that replaces the logger used by the
// Initializer. Applying the option fails with "logger is nil" when logger is
// nil; in that case the previously configured logger is left untouched.
func WithLogger(logger Logger) initializeOptionFunc {
	return func(target *initializeOption) error {
		if logger != nil {
			target.logger = logger
			return nil
		}
		return errors.New("logger is nil")
	}
}

// Initializer holds the configuration, commitment, disk state and logger used
// to write labels to the data directory.
//
// NOTE(review): this span is a rendered diff, so removed and added fields are
// interleaved and some fields (numLabelsWritten, diskState, mtx) appear twice.
type Initializer struct {
// numLabelsWritten is the progress counter returned by SessionNumLabelsWritten.
numLabelsWritten atomic.Uint64
// numLabelsWrittenChan receives best-effort progress updates (see
// updateSessionNumLabelsWritten).
numLabelsWrittenChan chan uint64
numLabelsWritten atomic.Uint64

cfg Config
opts InitOpts
commitment []byte

// diskState is queried for the number of labels already written.
diskState *DiskState
// initializing is read by isInitializing under mtx.
initializing bool
mtx sync.RWMutex

// stopChan is closed by Stop; doneChan is closed when Initialize returns.
stopChan chan struct{}
doneChan chan struct{}
diskState *DiskState
mtx sync.RWMutex

logger Logger
}

// NewInitializer validates its inputs and returns an Initializer whose
// diskState is created for the configured data directory and BitsPerLabel.
//
// NOTE(review): this span is a rendered diff. The first signature (explicit
// cfg/opts/commitment arguments) and the statements using those names belong
// to the removed version; the variadic-option signature and the statements
// using `options` belong to the version that replaces it.
func NewInitializer(cfg Config, opts config.InitOpts, commitment []byte) (*Initializer, error) {
if len(commitment) != 32 {
return nil, fmt.Errorf("invalid `id` length; expected: 32, given: %v", len(commitment))
func NewInitializer(opts ...initializeOptionFunc) (*Initializer, error) {
// Start from defaults: logging is disabled unless an option overrides it.
options := &initializeOption{
logger: shared.DisabledLogger{},
}

// Apply the options in the order given; the first failing option aborts.
for _, opt := range opts {
if err := opt(options); err != nil {
return nil, err
}
}

if err := config.Validate(cfg, opts); err != nil {
if err := options.verify(); err != nil {
return nil, err
}

return &Initializer{
cfg: cfg,
opts: opts,
commitment: commitment,
diskState: NewDiskState(opts.DataDir, uint(cfg.BitsPerLabel)),
logger: shared.DisabledLogger{},
cfg: *options.cfg,
opts: *options.initOpts,
commitment: options.commitment,
diskState: NewDiskState(options.initOpts.DataDir, uint(options.cfg.BitsPerLabel)),
logger: options.logger,
}, nil
}

// Initialize is the process in which the prover commits to store some data, by having its storage filled with
// pseudo-random data with respect to a specific id. This data is the result of a computationally-expensive operation.
func (init *Initializer) Initialize() error {
init.mtx.Lock()

if init.initializing {
init.mtx.Unlock()
func (init *Initializer) Initialize(ctx context.Context) error {
if !init.mtx.TryLock() {
return ErrAlreadyInitializing
}

init.stopChan = make(chan struct{})
init.doneChan = make(chan struct{})
init.numLabelsWrittenChan = make(chan uint64)

init.initializing = true
init.mtx.Unlock()

defer func() {
init.mtx.Lock()
defer init.mtx.Unlock()
init.initializing = false

close(init.doneChan)
close(init.numLabelsWrittenChan)
}()
defer init.mtx.Unlock()

if numLabelsWritten, err := init.diskState.NumLabelsWritten(); err != nil {
return err
Expand All @@ -130,57 +175,23 @@ func (init *Initializer) Initialize() error {
init.opts.NumFiles, init.opts.NumUnits, init.cfg.LabelsPerUnit, init.cfg.BitsPerLabel, init.opts.DataDir)

for i := 0; i < int(init.opts.NumFiles); i++ {
if err := init.initFile(uint(init.opts.ComputeProviderID), i, numLabels, fileNumLabels); err != nil {
if err := init.initFile(ctx, uint(init.opts.ComputeProviderID), i, numLabels, fileNumLabels); err != nil {
return err
}
}

return nil
}

// isInitializing returns the current value of the initializing flag, read
// while holding the read lock.
func (init *Initializer) isInitializing() bool {
	init.mtx.RLock()
	running := init.initializing
	init.mtx.RUnlock()

	return running
}

// Stop asks a running initialization to stop and waits for it to finish.
//
// It returns ErrNotInitializing when isInitializing reports false. Otherwise it
// closes stopChan, calls gpu.Stop (returning an error unless the result is
// gpu.StopResultOk) and then waits up to 5 seconds for doneChan, returning a
// "stop timeout" error if nothing arrives in time.
//
// NOTE(review): stopChan is closed with no guard other than isInitializing, so
// a second Stop call made before Initialize returns would presumably close it
// again and panic — confirm callers never do this.
func (init *Initializer) Stop() error {
if !init.isInitializing() {
return ErrNotInitializing
}

// Signal the compute worker first, then ask the GPU layer to stop.
close(init.stopChan)
if res := gpu.Stop(); res != gpu.StopResultOk {
return fmt.Errorf("gpu stop error: %s", res)
}

// Wait for Initialize to finish, but never longer than 5 seconds.
select {
case <-init.doneChan:
case <-time.After(5 * time.Second):
return errors.New("stop timeout")
}

return nil
}

// SessionNumLabelsWrittenChan returns the receive side of the channel stored in
// numLabelsWrittenChan. The field is read while holding the read lock.
func (init *Initializer) SessionNumLabelsWrittenChan() <-chan uint64 {
	init.mtx.RLock()
	updates := init.numLabelsWrittenChan
	init.mtx.RUnlock()

	return updates
}

// SessionNumLabelsWritten returns the value currently held by the
// numLabelsWritten counter.
func (init *Initializer) SessionNumLabelsWritten() uint64 {
	written := init.numLabelsWritten.Load()
	return written
}

func (init *Initializer) Reset() error {
if init.isInitializing() {
if !init.mtx.TryLock() {
return ErrCannotResetWhileInitializing
}

if err := init.VerifyStarted(); err != nil {
return err
}
defer init.mtx.Unlock()

files, err := os.ReadDir(init.opts.DataDir)
if err != nil {
Expand All @@ -203,38 +214,30 @@ func (init *Initializer) Reset() error {
return nil
}

// NOTE(review): this span is a rendered diff. The Started, Completed and
// VerifyStarted helpers (which return bool/error values) appear to be the
// removed code, interleaved line by line with the Status method that replaces
// them — confirm against the actual file before relying on either reading.
//
// Started reports whether at least one label has been written.
func (init *Initializer) Started() (bool, error) {
numLabelsWritten, err := init.diskState.NumLabelsWritten()
if err != nil {
return false, err
// Status reports the current initialization state:
// StatusInitializing when the mutex cannot be acquired, StatusError when the
// number of written labels cannot be read, StatusCompleted when it equals
// NumUnits * LabelsPerUnit, StatusStarted when it is greater than zero and
// StatusNotStarted otherwise.
func (init *Initializer) Status() Status {
if !init.mtx.TryLock() {
return StatusInitializing
}
defer init.mtx.Unlock()

return numLabelsWritten > 0, nil
}

// Completed reports whether the number of written labels equals the target of
// NumUnits * LabelsPerUnit.
func (init *Initializer) Completed() (bool, error) {
numLabelsWritten, err := init.diskState.NumLabelsWritten()
if err != nil {
return false, err
return StatusError
}

target := uint64(init.opts.NumUnits) * uint64(init.cfg.LabelsPerUnit)
return numLabelsWritten == target, nil
}

// VerifyStarted returns shared.ErrInitNotStarted when Started reports false,
// and forwards any error returned by Started.
func (init *Initializer) VerifyStarted() error {
ok, err := init.Started()
if err != nil {
return err
if numLabelsWritten == target {
return StatusCompleted
}
if !ok {
return shared.ErrInitNotStarted

if numLabelsWritten > 0 {
return StatusStarted
}

return nil
return StatusNotStarted
}

func (init *Initializer) initFile(computeProviderID uint, fileIndex int, numLabels uint64, fileNumLabels uint64) error {
func (init *Initializer) initFile(ctx context.Context, computeProviderID uint, fileIndex int, numLabels uint64, fileNumLabels uint64) error {
fileOffset := uint64(fileIndex) * fileNumLabels
fileTargetPosition := fileOffset + fileNumLabels
batchSize := uint64(config.DefaultComputeBatchSize)
Expand All @@ -254,7 +257,7 @@ func (init *Initializer) initFile(computeProviderID uint, fileIndex int, numLabe
if numLabelsWritten > 0 {
if numLabelsWritten == fileNumLabels {
init.logger.Info("initialization: file #%v already initialized; number of labels: %v, start position: %v", fileIndex, numLabelsWritten, fileOffset)
init.updateSessionNumLabelsWritten(fileTargetPosition)
init.numLabelsWritten.Store(fileTargetPosition)
return nil
}

Expand All @@ -263,7 +266,7 @@ func (init *Initializer) initFile(computeProviderID uint, fileIndex int, numLabe
if err := writer.Truncate(fileNumLabels); err != nil {
return err
}
init.updateSessionNumLabelsWritten(fileTargetPosition)
init.numLabelsWritten.Store(fileTargetPosition)
return nil
}

Expand All @@ -275,17 +278,22 @@ func (init *Initializer) initFile(computeProviderID uint, fileIndex int, numLabe
currentPosition := numLabelsWritten
outputChan := make(chan []byte, 1024)

errGroup, ctx := errgroup.WithContext(context.Background())
errGroup, ctx := errgroup.WithContext(ctx)

// Start compute worker.
errGroup.Go(func() error {
defer close(outputChan)

for currentPosition < fileNumLabels {
select {
case <-init.stopChan:
case <-ctx.Done():
init.logger.Info("initialization: stopped")
return ErrStopped

if res := gpu.Stop(); res != gpu.StopResultOk {
return fmt.Errorf("gpu stop error: %s", res)
}

return ctx.Err()
default:
}

Expand All @@ -307,7 +315,7 @@ func (init *Initializer) initFile(computeProviderID uint, fileIndex int, numLabe
outputChan <- output
currentPosition += uint64(batchSize)

init.updateSessionNumLabelsWritten(fileOffset + currentPosition)
init.numLabelsWritten.Store(fileOffset + currentPosition)
}
return nil
})
Expand Down Expand Up @@ -344,17 +352,6 @@ func (init *Initializer) initFile(computeProviderID uint, fileIndex int, numLabe
return nil
}

// updateSessionNumLabelsWritten records numLabelsWritten in the progress
// counter and then offers the same value on numLabelsWrittenChan without
// blocking.
func (init *Initializer) updateSessionNumLabelsWritten(numLabelsWritten uint64) {
	init.numLabelsWritten.Store(numLabelsWritten)

	updates := init.numLabelsWrittenChan
	select {
	case updates <- numLabelsWritten:
		// A receiver (or free buffer slot) took the update.
	default:
		// Nobody is ready to receive: drop the update instead of blocking, so
		// that initialization never stalls waiting for a reader.
	}
}

func (init *Initializer) verifyMetadata(m *Metadata) error {
if !bytes.Equal(init.commitment, m.Commitment) {
return ConfigMismatchError{
Expand Down Expand Up @@ -419,7 +416,3 @@ func (init *Initializer) saveMetadata() error {
// loadMetadata loads the metadata from the Initializer's data directory by
// delegating to LoadMetadata, and returns its results unchanged.
func (init *Initializer) loadMetadata() (*Metadata, error) {
	dataDir := init.opts.DataDir
	return LoadMetadata(dataDir)
}

// SetLogger replaces the logger used by the Initializer. The assignment is
// performed without taking init.mtx.
func (init *Initializer) SetLogger(logger Logger) {
init.logger = logger
}
Loading

0 comments on commit 87da468

Please sign in to comment.