Skip to content

Commit

Permalink
Merged PR 25836: Check via hashing if re-syncing in local mode is req…
Browse files Browse the repository at this point in the history
…uired

* This adds GPU-side hashing to tensors (a hash based on murmurhash3)
* The hash is used to check if parameters across nodes have diverged, if yes, resync all parameters and optimizer shards. Before it would resync every N (100 or 200) updates. Now this can be skipped if nothing diverged.
  • Loading branch information
emjotde committed Sep 27, 2022
1 parent 1f2929d commit 2cd3055
Show file tree
Hide file tree
Showing 13 changed files with 173 additions and 18 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased]

### Added
- Fused inplace-dropout in FFN layer in Transformer
- `--force-decode` option for marian-decoder
- `--output-sampling` now works with ensembles (requires proper normalization via e.g `--weights 0.5 0.5`)

Expand All @@ -24,6 +25,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Fixed fp16 training/inference with factors-combine concat method

### Changed
- Parameter synchronization in local sharding model now executes hash checksum before syncing
- Make guided-alignment faster via sparse memory layout, add alignment points for EOS, remove losses other than ce
- Negative `--workspace -N` value allocates workspace as total available GPU memory minus N megabytes.
- Set default parameters for cost-scaling to 8.f 10000 1.f 8.f, i.e. when scaling scale by 8 and do not try to automatically scale up or down. This seems most stable.
Expand Down
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
v1.11.9
v1.11.11
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ set_target_properties(marian PROPERTIES ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY
if(CUDA_FOUND)
cuda_add_library(marian_cuda
tensors/gpu/device.cu
tensors/gpu/hash.cu
tensors/gpu/algorithm.cu
tensors/gpu/prod.cpp
tensors/gpu/prod.cu
Expand Down
3 changes: 1 addition & 2 deletions src/common/hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ inline void hash_combine(HashType& seed, T const& v) {

// Hash a whole chunk of memory, mostly used for diagnostics
template <class T, class HashType = std::size_t>
inline HashType hashMem(const T* beg, size_t len) {
HashType seed = 0;
inline HashType hashMem(const T* beg, size_t len, HashType seed = 0) {
for(auto it = beg; it < beg + len; ++it)
hash_combine(seed, *it);
return seed;
Expand Down
45 changes: 45 additions & 0 deletions src/functional/operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -750,5 +750,50 @@ UNARY(sReLUBack, ReLUback, Ops<ElementType>::reluBack(x));
BINARY(sPReLU, PReLU, Ops<ElementType>::prelu(x, y));
BINARY(sPReLUBack, PReLUback, Ops<ElementType>::preluBack(x, y));

#ifdef __CUDACC__
// only visible to nvcc (device-side code)

// Reinterpret the bits of a float as a uint32_t (binary cast, no numeric
// conversion; the bit pattern is preserved exactly).
DEVICE_INLINE uint32_t gf2u(float f32) {
  const uint32_t bits = __float_as_uint(f32);
  return bits;
}

// Reinterpret the bits of a uint32_t as a float (binary cast, no numeric
// conversion; the bit pattern is preserved exactly).
DEVICE_INLINE float gu2f(uint32_t u32) {
  const float value = __uint_as_float(u32);
  return value;
}

// this is an adaptation of murmurhash3 as binary operator, all the
// magic numbers are present in the cpu implementation
// Adaptation of murmurhash3 as a binary combine step; the magic constants
// (c1 = 0xcc9e2d51, c2 = 0x1b873593, rotations 15/13) are the same ones used
// in the reference CPU implementation.
DEVICE_INLINE uint32_t murmur3_u32(uint32_t seed, uint32_t key) {
  // mix the key block with the murmur3 c1/c2 constants
  uint32_t block = key * 0xcc9e2d51u;
  block = (block << 15) | (block >> 17);  // rotl32(block, 15)
  block *= 0x1b873593u;

  // fold the mixed block into the running hash state
  uint32_t state = seed ^ block;
  state = (state << 13) | (state >> 19);  // rotl32(state, 13)
  state = state * 5u + 0xe6546b64u;

  return state;
}

// Float-typed wrapper around murmur3_u32: bit-cast both operands to uint32_t,
// combine, and bit-cast the result back. This lets us reuse the float-specific
// reduction machinery to accumulate a hash value. Note that the reduction is a
// tree reduction rather than murmurhash3's linear, in-memory-order combine, so
// the result is not bit-identical to canonical murmurhash3 — but it has proven
// equally good for divergence-detection purposes.
DEVICE_INLINE float murmur3_f32(float seed, float key) {
  const uint32_t combined = murmur3_u32(gf2u(seed), gf2u(key));
  return gu2f(combined);
}

// Define a binary operator that allows for hashing inside the Marian low-level operator framework.
// For now, gpu-side only.
BINARY(Murmur, murmur, murmur3_f32(x, y));

#endif

} // end namespace functional
} // end namespace marian
6 changes: 4 additions & 2 deletions src/tensors/gpu/add_all.inc
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ template void AggregateAll<float, float, Assignee<1>, BinaryFunctor<elem::Plus,
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
template void marian::AggregateAll<float, float, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Assignee<2> >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Assignee<2> >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<float, float, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<float,float,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,float,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> >,float,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<float,float,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,float,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> >,float,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<float, float, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, float, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<float, float, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Neg, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Neg, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void marian::AggregateAll<float, float, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqr, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqr, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void marian::AggregateAll<float, float, marian::functional::Assignee<1>, marian::functional::BinaryFunctor<marian::functional::elem::Murmur, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::Assignee<1>, float, marian::functional::BinaryFunctor<marian::functional::elem::Murmur, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);

#if COMPILE_FP16
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
Expand Down Expand Up @@ -77,8 +78,9 @@ template void AggregateAll<__half, float, Assignee<1>, BinaryFunctor<elem::Plus,
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
template void marian::AggregateAll<__half, float, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Assignee<2> >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Assignee<2> >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<__half, float, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<__half,float,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,float,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> >,float,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<__half,float,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,float,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> >,float,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<__half, float, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, float, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<__half, float, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Neg, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Neg, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void marian::AggregateAll<__half, float, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqr, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqr, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void marian::AggregateAll<__half, float, marian::functional::Assignee<1>, marian::functional::BinaryFunctor<marian::functional::elem::Murmur, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::Assignee<1>, float, marian::functional::BinaryFunctor<marian::functional::elem::Murmur, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
#endif
57 changes: 57 additions & 0 deletions src/tensors/gpu/hash.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#include "tensors/gpu/add_all.h"
#include "functional/operators.h"
// clang-format on

#include <cstdint>
#include <cstring>

#if COMPILE_FP16
#include <cuda_fp16.h>
#endif

namespace marian {
namespace gpu {

// cpu-side conversion of float to uint32_t via bit-wise cast
// cpu-side conversion of float to uint32_t via bit-wise cast.
// memcpy (rather than a pointer cast) avoids strict-aliasing UB and compiles
// to a single move; equivalent to C++20 std::bit_cast.
uint32_t f2u(float f32) {
  uint32_t u32;
  static_assert(sizeof(u32) == sizeof(f32), "float and uint32_t must have the same size");
  std::memcpy(&u32, &f32, sizeof(u32));  // sizeof instead of magic constant 4
  return u32;
}

// cpu-side conversion of uint32_t to float via bit-wise cast
// cpu-side conversion of uint32_t to float via bit-wise cast.
// memcpy (rather than a pointer cast) avoids strict-aliasing UB and compiles
// to a single move; equivalent to C++20 std::bit_cast.
float u2f(uint32_t u32) {
  float f32;
  static_assert(sizeof(f32) == sizeof(u32), "float and uint32_t must have the same size");
  std::memcpy(&f32, &u32, sizeof(f32));  // sizeof instead of magic constant 4
  return f32;
}

// Computes a murmur3-ish hash value for a Marian tensor.
uint32_t hashTensor(Tensor tensor, uint32_t seed, Ptr<Allocator> allocator) {
// we first accumulate into single value via a binary mumurhash3-like operator,
// see functional/operators.h for details.
using namespace functional;
uint32_t h = 0;
if(tensor->type() == Type::float32)
h = f2u(AggregateAllAndReturn<float, float>(allocator, _1, u2f(seed), murmur(_1, _2), 1, tensor));
#if COMPILE_FP16
else if(tensor->type() == Type::float16)
// internally, a half value gets cast to a float value before hashing or combining. These is the same
// mechanics as for summing where we cast to a larger type for better precision.
h = f2u(AggregateAllAndReturn<half, float>(allocator, _1, u2f(seed), murmur(_1, _2), 1, tensor));
#endif
else
ABORT("Hashing of tensors not supported for type {}", tensor->type());

// finalization according to murmurhash3 implementation
uint32_t len = (uint32_t)tensor->size();
h ^= len;
h ^= h >> 16;
h *= 0x85ebca6b;
h ^= h >> 13;
h *= 0xc2b2ae35;
h ^= h >> 16;
return h;
}

} // namespace gpu
} // namespace marian
14 changes: 7 additions & 7 deletions src/tensors/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,13 @@ void TensorBase::set(const io::Item& item) {
memory_->data<char>());
}

size_t TensorBase::hash() {
io::Item temp;
size_t seed = 0;
get(temp, "temp");
for(auto c : temp.bytes)
util::hash_combine(seed, c);
return seed;
// Computes a content hash over the tensor's underlying memory.
// On a GPU backend (when compiled with CUDA) this dispatches to the
// device-side murmur3-style hash; otherwise the raw bytes are hashed on the
// CPU via util::hashMem.
// NOTE(review): the two branches use different hash functions, so a GPU hash
// and a CPU hash of identical content are not comparable with each other —
// callers should only compare hashes produced on the same device type.
// @param seed      starting seed folded into the hash
// @param allocator allocator for temporary buffers (used by the GPU path only)
size_t TensorBase::hash(size_t seed, Ptr<Allocator> allocator) {
#ifdef CUDA_FOUND
if(backend_->getDeviceId().type == DeviceType::gpu)
return marian::gpu::hashTensor(this, (uint32_t)seed, allocator);
else // we assume CPU
#endif
return marian::util::hashMem(memory_->data<char>(), memory_->size(), seed);
}

} // namespace marian
Expand Down
Loading

0 comments on commit 2cd3055

Please sign in to comment.