diff --git a/CHANGELOG.md b/CHANGELOG.md
index 23008b734..44145b897 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 ## [Unreleased]

 ### Added
+- Fused inplace-dropout in FFN layer in Transformer
 - `--force-decode` option for marian-decoder
 - `--output-sampling` now works with ensembles (requires proper normalization via e.g `--weights 0.5 0.5`)
@@ -24,6 +25,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 - Fixed fp16 training/inference with factors-combine concat method

 ### Changed
+- Parameter synchronization in local sharding mode now executes a hash checksum before syncing
 - Make guided-alignment faster via sparse memory layout, add alignment points for EOS, remove losses other than ce
 - Negative `--workspace -N` value allocates workspace as total available GPU memory minus N megabytes.
 - Set default parameters for cost-scaling to 8.f 10000 1.f 8.f, i.e. when scaling scale by 8 and do not try to automatically scale up or down. This seems most stable.
diff --git a/VERSION b/VERSION
index 316ba050f..d15b7998b 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-v1.11.9
+v1.11.11
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index e4599c407..f095f2eb8 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -177,6 +177,7 @@ set_target_properties(marian PROPERTIES ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY
 if(CUDA_FOUND)
 cuda_add_library(marian_cuda
   tensors/gpu/device.cu
+  tensors/gpu/hash.cu
   tensors/gpu/algorithm.cu
   tensors/gpu/prod.cpp
   tensors/gpu/prod.cu
diff --git a/src/common/hash.h b/src/common/hash.h
index 7aca30de2..c2df2a63e 100644
--- a/src/common/hash.h
+++ b/src/common/hash.h
@@ -18,8 +18,7 @@ inline void hash_combine(HashType& seed, T const& v) {

 // Hash a whole chunk of memory, mostly used for diagnostics
 template <class HashType, class T>
-inline HashType hashMem(const T* beg, size_t len) {
-  HashType seed = 0;
+inline HashType hashMem(const T* beg, size_t len, HashType seed = 0) {
   for(auto it = beg; it < beg + len; ++it)
     hash_combine(seed, *it);
   return seed;
diff --git a/src/functional/operators.h b/src/functional/operators.h
index a14f153f1..80b40ff40 100644
--- a/src/functional/operators.h
+++ b/src/functional/operators.h
@@ -750,5 +750,50 @@ UNARY(sReLUBack, ReLUback, Ops<ElementType>::reluBack(x));
 BINARY(sPReLU, PReLU, Ops<ElementType>::prelu(x, y));
 BINARY(sPReLUBack, PReLUback, Ops<ElementType>::preluBack(x, y));

+#ifdef __CUDACC__
+// only visible to nvcc
+
+DEVICE_INLINE uint32_t gf2u(float f32) {
+  // binary cast, bits stay the same
+  return __float_as_uint(f32);
+}
+
+DEVICE_INLINE float gu2f(uint32_t u32) {
+  // binary cast, bits stay the same
+  return __uint_as_float(u32);
+}
+
+// this is an adaptation of murmurhash3 as a binary operator; all the
+// magic numbers are present in the cpu implementation
+DEVICE_INLINE uint32_t murmur3_u32(uint32_t seed, uint32_t key) {
+  uint32_t h = seed;
+  uint32_t k = key;
+
+  k *= 0xcc9e2d51;
+  k = (k << 15) | (k >> 17);
+  k *= 0x1b873593;
+
+  h ^= k;
+
+  h = (h << 13) | (h >> 19);
+  h = h * 5 + 0xe6546b64;
+
+  return h;
+}
+
+DEVICE_INLINE float murmur3_f32(float seed, float key) {
+  // We cast from float to uint32_t and the hash back to float.
+  // Not great, but it allows us to hack the float-specific reduction function to accumulate a hash value.
+  // This is not exactly murmurhash3 since we do a tree-reduction of hashes while murmurhash3 combines
+  // values linearly in memory order. But when tested this seems to work just as well for hashing purposes.
+  return gu2f(murmur3_u32(gf2u(seed), gf2u(key)));
+}
+
+// Define a binary operator that allows for hashing inside the Marian low-level operator framework.
+// For now, gpu-side only.
+BINARY(Murmur, murmur, murmur3_f32(x, y));
+
+#endif
+
 } // end namespace functional
 } // end namespace marian
diff --git a/src/tensors/gpu/add_all.inc b/src/tensors/gpu/add_all.inc
index b6cb34173..ba466d895 100644
--- a/src/tensors/gpu/add_all.inc
+++ b/src/tensors/gpu/add_all.inc
@@ -36,10 +36,11 @@ template void AggregateAll, BinaryFunctor, Assignee<1>>, BinaryFunctor, Assignee<2>>>(std::shared_ptr, BinaryFunctor, Assignee<1>>, float, BinaryFunctor, Assignee<2>>, float, marian::Tensor, marian::Tensor);
 template void marian::AggregateAll >, marian::functional::Assignee<2> >, marian::functional::BinaryFunctor, marian::functional::Assignee<2> > >(std::shared_ptr, marian::functional::BinaryFunctor >, marian::functional::Assignee<2> >, float, marian::functional::BinaryFunctor, marian::functional::Assignee<2> >, float, IntrusivePtr, IntrusivePtr, IntrusivePtr);
 template void marian::AggregateAll, marian::functional::Assignee<2> > >, marian::functional::BinaryFunctor, marian::functional::Assignee<2> > >(std::shared_ptr, marian::functional::UnaryFunctor, marian::functional::Assignee<2> > >, float, marian::functional::BinaryFunctor, marian::functional::Assignee<2> >, float, IntrusivePtr, IntrusivePtr, IntrusivePtr);
-template void marian::AggregateAll,marian::functional::UnaryFunctor > >,marian::functional::BinaryFunctor,marian::functional::Assignee<2> > >(std::shared_ptr,marian::functional::BinaryFunctor,marian::functional::UnaryFunctor > >,float,marian::functional::BinaryFunctor,marian::functional::Assignee<2> >,float,IntrusivePtr,IntrusivePtr,IntrusivePtr);
+template void marian::AggregateAll,marian::functional::UnaryFunctor > >,marian::functional::BinaryFunctor,marian::functional::Assignee<2> > >(std::shared_ptr,marian::functional::BinaryFunctor,marian::functional::UnaryFunctor > >,float,marian::functional::BinaryFunctor,marian::functional::Assignee<2> >,float,IntrusivePtr,IntrusivePtr,IntrusivePtr);
 template void marian::AggregateAll >, marian::functional::BinaryFunctor, marian::functional::Assignee<2> > >(std::shared_ptr, marian::functional::UnaryFunctor >, float, marian::functional::BinaryFunctor, marian::functional::Assignee<2> >, float, IntrusivePtr, IntrusivePtr);
 template void marian::AggregateAll, marian::functional::UnaryFunctor > > >, marian::functional::BinaryFunctor, marian::functional::Assignee<2> > >(std::shared_ptr, marian::functional::BinaryFunctor, marian::functional::UnaryFunctor > > >, float, marian::functional::BinaryFunctor, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor, marian::Tensor);
 template void marian::AggregateAll, marian::functional::UnaryFunctor > > >, marian::functional::BinaryFunctor, marian::functional::Assignee<2> > >(std::shared_ptr, marian::functional::BinaryFunctor, marian::functional::UnaryFunctor > > >, float, marian::functional::BinaryFunctor, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void marian::AggregateAll, marian::functional::BinaryFunctor, marian::functional::Assignee<2> > >(std::shared_ptr, marian::functional::Assignee<1>, float, marian::functional::BinaryFunctor, marian::functional::Assignee<2> >, float, IntrusivePtr, IntrusivePtr);
 #if COMPILE_FP16
 template void AggregateAll<__half, float, BinaryFunctor>, Assignee<2>>, BinaryFunctor, Assignee<2>>>(std::shared_ptr, BinaryFunctor>, Assignee<2>>, float, BinaryFunctor, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
@@ -77,8 +78,9 @@ template void AggregateAll<__half, float, Assignee<1>, BinaryFunctor, Assignee<1>>, BinaryFunctor, Assignee<2>>>(std::shared_ptr, BinaryFunctor, Assignee<1>>, float, BinaryFunctor, Assignee<2>>, float, marian::Tensor, marian::Tensor);
 template void marian::AggregateAll<__half, float, marian::functional::BinaryFunctor >, marian::functional::Assignee<2> >, marian::functional::BinaryFunctor, marian::functional::Assignee<2> > >(std::shared_ptr, marian::functional::BinaryFunctor >, marian::functional::Assignee<2> >, float, marian::functional::BinaryFunctor, marian::functional::Assignee<2> >, float, IntrusivePtr, IntrusivePtr, IntrusivePtr);
 template void marian::AggregateAll<__half, float, marian::functional::UnaryFunctor, marian::functional::Assignee<2> > >, marian::functional::BinaryFunctor, marian::functional::Assignee<2> > >(std::shared_ptr, marian::functional::UnaryFunctor, marian::functional::Assignee<2> > >, float, marian::functional::BinaryFunctor, marian::functional::Assignee<2> >, float, IntrusivePtr, IntrusivePtr, IntrusivePtr);
-template void marian::AggregateAll<__half,float,marian::functional::BinaryFunctor,marian::functional::UnaryFunctor > >,marian::functional::BinaryFunctor,marian::functional::Assignee<2> > >(std::shared_ptr,marian::functional::BinaryFunctor,marian::functional::UnaryFunctor > >,float,marian::functional::BinaryFunctor,marian::functional::Assignee<2> >,float,IntrusivePtr,IntrusivePtr,IntrusivePtr);
+template void marian::AggregateAll<__half,float,marian::functional::BinaryFunctor,marian::functional::UnaryFunctor > >,marian::functional::BinaryFunctor,marian::functional::Assignee<2> > >(std::shared_ptr,marian::functional::BinaryFunctor,marian::functional::UnaryFunctor > >,float,marian::functional::BinaryFunctor,marian::functional::Assignee<2> >,float,IntrusivePtr,IntrusivePtr,IntrusivePtr);
 template void marian::AggregateAll<__half, float, marian::functional::UnaryFunctor >, marian::functional::BinaryFunctor, marian::functional::Assignee<2> > >(std::shared_ptr, marian::functional::UnaryFunctor >, float, marian::functional::BinaryFunctor, marian::functional::Assignee<2> >, float, IntrusivePtr, IntrusivePtr);
 template void marian::AggregateAll<__half, float, marian::functional::BinaryFunctor, marian::functional::UnaryFunctor > > >, marian::functional::BinaryFunctor, marian::functional::Assignee<2> > >(std::shared_ptr, marian::functional::BinaryFunctor, marian::functional::UnaryFunctor > > >, float, marian::functional::BinaryFunctor, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor, marian::Tensor);
 template void marian::AggregateAll<__half, float, marian::functional::BinaryFunctor, marian::functional::UnaryFunctor > > >, marian::functional::BinaryFunctor, marian::functional::Assignee<2> > >(std::shared_ptr, marian::functional::BinaryFunctor, marian::functional::UnaryFunctor > > >, float, marian::functional::BinaryFunctor, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void marian::AggregateAll<__half, float, marian::functional::Assignee<1>, marian::functional::BinaryFunctor, marian::functional::Assignee<2> > >(std::shared_ptr, marian::functional::Assignee<1>, float, marian::functional::BinaryFunctor, marian::functional::Assignee<2> >, float, IntrusivePtr, IntrusivePtr);
 #endif
diff --git a/src/tensors/gpu/hash.cu b/src/tensors/gpu/hash.cu
new file mode 100644
index 000000000..e132cdc4f
--- /dev/null
+++ b/src/tensors/gpu/hash.cu
@@ -0,0 +1,57 @@
+#include "tensors/gpu/add_all.h"
+#include "functional/operators.h"
+// clang-format on
+
+#include <cuda.h>
+
+#if COMPILE_FP16
+#include <cuda_fp16.h>
+#endif
+
+namespace marian {
+namespace gpu {
+
+// cpu-side conversion of float to uint32_t via bit-wise cast
+uint32_t f2u(float f32) {
+  uint32_t u32;
+  std::memcpy(&u32, &f32, 4);
+  return u32;
+}
+
+// cpu-side conversion of uint32_t to float via bit-wise cast
+float u2f(uint32_t u32) {
+  float f32;
+  std::memcpy(&f32, &u32, 4);
+  return f32;
+}
+
+// Computes a murmur3-ish hash value for a Marian tensor.
+uint32_t hashTensor(Tensor tensor, uint32_t seed, Ptr<Allocator> allocator) {
+  // we first accumulate into a single value via a binary murmurhash3-like operator,
+  // see functional/operators.h for details.
+  using namespace functional;
+  uint32_t h = 0;
+  if(tensor->type() == Type::float32)
+    h = f2u(AggregateAllAndReturn<float, float>(allocator, _1, u2f(seed), murmur(_1, _2), 1, tensor));
+#if COMPILE_FP16
+  else if(tensor->type() == Type::float16)
+    // internally, a half value gets cast to a float value before hashing or combining. This is the same
+    // mechanics as for summing where we cast to a larger type for better precision.
+    h = f2u(AggregateAllAndReturn<__half, float>(allocator, _1, u2f(seed), murmur(_1, _2), 1, tensor));
+#endif
+  else
+    ABORT("Hashing of tensors not supported for type {}", tensor->type());
+
+  // finalization according to the murmurhash3 implementation
+  uint32_t len = (uint32_t)tensor->size();
+  h ^= len;
+  h ^= h >> 16;
+  h *= 0x85ebca6b;
+  h ^= h >> 13;
+  h *= 0xc2b2ae35;
+  h ^= h >> 16;
+  return h;
+}
+
+} // namespace gpu
+} // namespace marian
\ No newline at end of file
diff --git a/src/tensors/tensor.cpp b/src/tensors/tensor.cpp
index 02de17bc5..e9a07ab46 100644
--- a/src/tensors/tensor.cpp
+++ b/src/tensors/tensor.cpp
@@ -138,13 +138,13 @@ void TensorBase::set(const io::Item& item) {
                memory_->data());
 }

-size_t TensorBase::hash() {
-  io::Item temp;
-  size_t seed = 0;
-  get(temp, "temp");
-  for(auto c : temp.bytes)
-    util::hash_combine(seed, c);
-  return seed;
+size_t TensorBase::hash(size_t seed, Ptr<Allocator> allocator) {
+#ifdef CUDA_FOUND
+  if(backend_->getDeviceId().type == DeviceType::gpu)
+    return marian::gpu::hashTensor(this, (uint32_t)seed, allocator);
+  else // we assume CPU
+#endif
+    return marian::util::hashMem(memory_->data(), memory_->size(), seed);
 }

 } // namespace marian
diff --git a/src/tensors/tensor.h b/src/tensors/tensor.h
index a70714043..48e3aaec9 100644
--- a/src/tensors/tensor.h
+++ b/src/tensors/tensor.h
@@ -3,8 +3,10 @@
 #include "common/definitions.h"
 #include "common/shape.h"
 #include "common/types.h"
+#include "tensors/allocator.h"
 #include "tensors/backend.h"
 #include "tensors/memory_piece.h"
+
 #ifdef CUDA_FOUND
 #include "tensors/gpu/algorithm.h"
 #endif
@@ -327,7 +329,14 @@ class TensorBase {
     DISPATCH_BY_TYPE2(type_, debug, precision, dispCols);
   }

-  size_t hash();
+  // Computes a hash value for the given tensor. For a cpu-side tensor this is
+  // going to be the hash function from stdlib (64-bit); for gpu-side tensors
+  // it is going to be the result of a murmurhash3-like hash (32-bit).
+  // The argument seed can be used to define a new random hash function.
+  // The allocator argument can be used to allocate memory via the standard
+  // Marian allocator instead of cudaMalloc (the default).
+  // The hashes are not the same for cpu and gpu!
+  size_t hash(size_t seed = 0, Ptr<Allocator> allocator = nullptr);
 };
diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h
index 1fc4542d8..178bb6920 100644
--- a/src/tensors/tensor_operators.h
+++ b/src/tensors/tensor_operators.h
@@ -436,5 +436,12 @@ static inline float L2Norm(marian::Tensor in, Ptr<Allocator> allocator) {
 // clang-format off
 DISPATCH5(PoolingWithMaskingForward, marian::Tensor, marian::Tensor, marian::Tensor, int, bool)
 DISPATCH6(PoolingWithMaskingBackward, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, int, bool)
+
+#ifdef CUDA_FOUND
+namespace gpu {
+  uint32_t hashTensor(Tensor tensor, uint32_t seed, Ptr<Allocator> allocator);
+}
+#endif
+
 // clang-format on
 } // namespace marian
diff --git a/src/training/graph_group.cpp b/src/training/graph_group.cpp
index d9a77a708..0ba1b279d 100644
--- a/src/training/graph_group.cpp
+++ b/src/training/graph_group.cpp
@@ -91,6 +91,42 @@ void GraphGroup::initGraphsAndOpts() {
   }
 }

+void GraphGroup::syncParametersAndShards() {
+  // In local sharding mode we have seen that parameters can diverge occasionally due to non-determinism in NCCL.
+  // Here, we try to catch this and, if caught, re-sync everything (also optimizer state) across nodes.
+  if(shardingMode_ == ShardingMode::local) {
+    std::vector<size_t> hashes(mpi_->numMPIProcesses(), 0);
+    // compute hash value of parameters of 0-th graph (we only need to check one graph per node)
+    for(int i = 0; i < hashes.size(); i++) {
+      if(i == mpi_->myMPIRank()) {
+        hashes[i] = graphs_[0]->params()->vals()->hash(); // this is quite fast with the on-GPU implementation
+        LOG(debug, "Parameter hash for graph 0 on node {}: {}", mpi_->myMPIRank(), hashes[i]);
+      }
+    }
+
+    // Collect hashes from all nodes, note the changing rootRank.
+    // After this, hashes contains the hashes from all nodes.
+    for(int i = 0; i < hashes.size(); i++)
+      mpi_->bCast(&hashes[i], 1, mpi_->getDataType(&hashes[i]), /*rootRank=*/i);
+
+    // If any of the hashes diverges, re-sync.
+    if(std::any_of(hashes.begin(), hashes.end(), [&hashes](size_t v){ return v != hashes[0]; })) {
+      if(isMainProcess()) {
+        LOG(warn, "Parameters diverged:");
+        for(int i = 0; i < hashes.size(); i++)
+          LOG(warn, "\tGot hash {} for node {}", hashes[i], i);
+        LOG(warn, "Syncing all parameters and optimizer shards across {} MPI processes", mpi_->numMPIProcesses());
+      }
+
+      comm_->broadcastParams();
+      comm_->broadcastShards(optimizerShards_);
+
+      if(isMainProcess())
+        LOG(warn, "Re-synced all shards");
+    }
+  }
+}
+
 // increase cost-scaling factor if no NaN has been detected for a
 // given number of iterations. Usually we increase by 2 which adds
 // one more bit for precision.
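
Note (illustrative, not part of the patch): the hash used above can be sketched on the CPU as follows. The sketch reuses the magic constants of murmur3_u32 in src/functional/operators.h and the finalization in src/tensors/gpu/hash.cu, but folds values in linearly instead of by tree-reduction, so its output will not match the GPU hash; the names murmur3_combine, murmur3_finalize and hashFloats are invented for this sketch only.

// Hypothetical CPU-side reference for the murmur3-style combine + finalization.
#include <cstdint>
#include <cstring>
#include <vector>

static uint32_t murmur3_combine(uint32_t seed, uint32_t key) { // same constants as murmur3_u32
  uint32_t h = seed, k = key;
  k *= 0xcc9e2d51;
  k = (k << 15) | (k >> 17);
  k *= 0x1b873593;
  h ^= k;
  h = (h << 13) | (h >> 19);
  return h * 5 + 0xe6546b64;
}

static uint32_t murmur3_finalize(uint32_t h, uint32_t len) { // same finalization as hashTensor
  h ^= len;
  h ^= h >> 16; h *= 0x85ebca6b;
  h ^= h >> 13; h *= 0xc2b2ae35;
  h ^= h >> 16;
  return h;
}

// Hash a float buffer: bit-cast each value to uint32_t (as f2u/gf2u do) and fold it in.
static uint32_t hashFloats(const std::vector<float>& vals, uint32_t seed = 0) {
  uint32_t h = seed;
  for(float f : vals) {
    uint32_t bits;
    std::memcpy(&bits, &f, 4);
    h = murmur3_combine(h, bits);
  }
  return murmur3_finalize(h, (uint32_t)vals.size());
}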
diff --git a/src/training/graph_group.h b/src/training/graph_group.h
index 9f1362e75..0895caa77 100644
--- a/src/training/graph_group.h
+++ b/src/training/graph_group.h
@@ -85,6 +85,7 @@ class GraphGroup {
   GraphGroup(Ptr<Options> options);
   void initGraphsAndOpts();
+  void syncParametersAndShards();

   virtual ~GraphGroup() {}
diff --git a/src/training/graph_group_sync.cpp b/src/training/graph_group_sync.cpp
index c90a384e4..a3eee8a7b 100644
--- a/src/training/graph_group_sync.cpp
+++ b/src/training/graph_group_sync.cpp
@@ -346,11 +346,7 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num
     scheduler_->update(localLoss, numReadBatches, updateBatchSize, updateTargetWords, gradNorm);

     if(scheduler_->syncing()) {
-      if(shardingMode_ == ShardingMode::local) {
-        LOG(debug, "Syncing all parameters and optimizer shards across {} MPI processes", mpi_->numMPIProcesses());
-        comm_->broadcastParams();
-        comm_->broadcastShards(optimizerShards_);
-      }
+      syncParametersAndShards();
     }

     // save intermediate model (and optimizer state) to file
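
Note (illustrative, not part of the patch): the divergence check in GraphGroup::syncParametersAndShards boils down to a gather-and-compare pattern. The standalone sketch below expresses it against raw MPI rather than Marian's mpi_/comm_ wrappers; parametersDiverged is a hypothetical name. When it returns true, the patch re-broadcasts parameters and optimizer shards from the main process via comm_->broadcastParams() and comm_->broadcastShards(optimizerShards_).

#include <mpi.h>
#include <algorithm>
#include <cstdint>
#include <vector>

// Each rank hashes its local copy of the parameters; every rank then broadcasts
// its hash so that all ranks see the full vector and can detect divergence.
bool parametersDiverged(uint64_t localHash, MPI_Comm comm) {
  int rank = 0, size = 0;
  MPI_Comm_rank(comm, &rank);
  MPI_Comm_size(comm, &size);

  std::vector<uint64_t> hashes(size, 0);
  hashes[rank] = localHash;
  for(int root = 0; root < size; ++root) // note the changing root, as in the patch
    MPI_Bcast(&hashes[root], 1, MPI_UINT64_T, root, comm);

  return std::any_of(hashes.begin(), hashes.end(),
                     [&](uint64_t h) { return h != hashes[0]; });
}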