// Sierra Toolkit "Version of the Day" listing: ParallelComm implementation
/*------------------------------------------------------------------------*/
/* Copyright 2010 Sandia Corporation.                                     */
/* Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive    */
/* license for use of this work by or on behalf of the U.S. Government.   */
/* Export of this program may require a license from the                  */
/* United States Government.                                              */
/*------------------------------------------------------------------------*/

#include <stdlib.h>
#include <stdexcept>
#include <sstream>
#include <vector>

#include <stk_util/parallel/ParallelComm.hpp>
#include <stk_util/parallel/ParallelReduce.hpp>

namespace stk_classic {

//-----------------------------------------------------------------------

#if defined( STK_HAS_MPI )

enum { STK_MPI_TAG_SIZING = 0 , STK_MPI_TAG_DATA = 1 };

// Communicate in sparse or dense mode, as directed during allocation

namespace {

bool all_to_all_dense( ParallelMachine p_comm ,
                       const CommBuffer * const send ,
                       const CommBuffer * const recv ,
                       std::ostream & msg )
{
  typedef unsigned char * ucharp ;

  static const char method[] = "stk_classic::CommAll::communicate" ;

  int result ;

  {
    const unsigned p_size = parallel_machine_size( p_comm );

    std::vector<int> tmp( p_size * 4 );

    int * const send_counts = (tmp.empty() ? NULL : & tmp[0]) ;
    int * const send_displs = send_counts + p_size ;
    int * const recv_counts = send_displs + p_size ;
    int * const recv_displs = recv_counts + p_size ;

    unsigned char * const ps = static_cast<ucharp>(send[0].buffer());
    unsigned char * const pr = static_cast<ucharp>(recv[0].buffer());

    for ( unsigned i = 0 ; i < p_size ; ++i ) {
      const CommBuffer & send_buf = send[i] ;
      const CommBuffer & recv_buf = recv[i] ;

      send_counts[i] = send_buf.capacity();
      recv_counts[i] = recv_buf.capacity();

      send_displs[i] = static_cast<ucharp>(send_buf.buffer()) - ps ;
      recv_displs[i] = static_cast<ucharp>(recv_buf.buffer()) - pr ;
    }

    result = MPI_Alltoallv( ps , send_counts , send_displs , MPI_BYTE ,
                            pr , recv_counts , recv_displs , MPI_BYTE ,
                            p_comm );

    if ( MPI_SUCCESS != result ) {
      msg << method << " GLOBAL ERROR: " << result << " == MPI_Alltoallv" ;
    }
  }

  return MPI_SUCCESS == result ;
}
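/* Note (added commentary, not part of the original source):
 * all_to_all_dense above relies on the send[] and recv[] buffers being
 * carved out of one contiguous allocation (see CommBuffer::allocate
 * below), so the MPI_Alltoallv byte displacements can be recovered by
 * subtracting the base pointer of buffer zero.  For example, with
 * p_size == 2 and 16-byte alignment (4 * sizeof(int) with 4-byte ints),
 * buffers of sizes { 4 , 8 } yield send_displs == { 0 , 16 }.
 */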
bool all_to_all_sparse( ParallelMachine p_comm ,
                        const CommBuffer * const send ,
                        const CommBuffer * const recv ,
                        std::ostream & msg )
{
  static const char method[] = "stk_classic::CommAll::communicate" ;
  static const int mpi_tag = STK_MPI_TAG_DATA ;

  int result = MPI_SUCCESS ;

  {
    const unsigned p_size = parallel_machine_size( p_comm );
    const unsigned p_rank = parallel_machine_rank( p_comm );

    //------------------------------
    // Receive count

    unsigned num_recv = 0 ;

    for ( unsigned i = 0 ; i < p_size ; ++i ) {
      if ( recv[i].capacity() ) { ++num_recv ; }
    }

    //------------------------------
    // Post receives for specific processors with specific sizes

    MPI_Request request_null = MPI_REQUEST_NULL ;
    std::vector<MPI_Request> request( num_recv , request_null );
    std::vector<MPI_Status>  status( num_recv );

    unsigned count = 0 ;

    for ( unsigned i = 0 ; result == MPI_SUCCESS && i < p_size ; ++i ) {
      const unsigned recv_size = recv[i].capacity();
      void * const   recv_buf  = recv[i].buffer();
      if ( recv_size ) {
        result = MPI_Irecv( recv_buf , recv_size , MPI_BYTE ,
                            i , mpi_tag , p_comm , & request[count] );
        ++count ;
      }
    }

    if ( MPI_SUCCESS != result ) {
      msg << method << " LOCAL[" << p_rank << "] ERROR: "
          << result << " == MPI_Irecv , " ;
    }

    //------------------------------
    // Sync to allow ready sends and for a potential error

    int local_error = MPI_SUCCESS == result ? 0 : 1 ;
    int global_error = 0 ;

    result = MPI_Allreduce( & local_error , & global_error ,
                            1 , MPI_INT , MPI_SUM , p_comm );

    if ( MPI_SUCCESS != result ) {
      msg << method << " GLOBAL ERROR: " << result << " == MPI_Allreduce" ;
    }
    else if ( global_error ) {
      result = MPI_ERR_UNKNOWN ;
    }
    else {
      // Everything is local from here on out, no more syncs

      //------------------------------
      // Ready-send the buffers, rotate the send processor
      // in a simple attempt to smooth out the communication traffic.

      for ( unsigned i = 0 ; MPI_SUCCESS == result && i < p_size ; ++i ) {
        const int dst = ( i + p_rank ) % p_size ;
        const unsigned send_size = send[dst].capacity();
        void * const   send_buf  = send[dst].buffer();
        if ( send_size ) {
          result = MPI_Rsend( send_buf , send_size , MPI_BYTE ,
                              dst , mpi_tag , p_comm );
        }
      }

      if ( MPI_SUCCESS != result ) {
        msg << method << " LOCAL ERROR: " << result << " == MPI_Rsend , " ;
      }
      else {
        MPI_Request * const p_request = (request.empty() ? NULL : & request[0]) ;
        MPI_Status  * const p_status  = (status.empty()  ? NULL : & status[0]) ;

        result = MPI_Waitall( num_recv , p_request , p_status );
      }

      if ( MPI_SUCCESS != result ) {
        msg << method << " LOCAL[" << p_rank << "] ERROR: "
            << result << " == MPI_Waitall , " ;
      }
      else {

        for ( unsigned i = 0 ; i < num_recv ; ++i ) {
          MPI_Status * const recv_status = & status[i] ;
          const int recv_proc = recv_status->MPI_SOURCE ;
          const int recv_tag  = recv_status->MPI_TAG ;
          const int recv_plan = recv[recv_proc].capacity();
          int recv_count = 0 ;

          MPI_Get_count( recv_status , MPI_BYTE , & recv_count );

          if ( recv_tag != mpi_tag || recv_count != recv_plan ) {
            msg << method << " LOCAL[" << p_rank << "] ERROR: Recv["
                << recv_proc << "] Size( "
                << recv_count << " != " << recv_plan << " ) , " ;
            result = MPI_ERR_UNKNOWN ;
          }
        }
      }
    }
  }

  return MPI_SUCCESS == result ;
}

}
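/* Note (added commentary, not part of the original source):
 * The MPI_Rsend calls in all_to_all_sparse above are only legal when the
 * matching receives are already posted.  The MPI_Allreduce doubles as
 * that guarantee: no rank can leave the reduction before every rank has
 * posted its MPI_Irecv calls.  The ( i + p_rank ) % p_size rotation
 * staggers destinations so that all ranks do not send to rank 0 first.
 */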
#else

// Not parallel

namespace {

bool all_to_all_dense( ParallelMachine ,
                       const CommBuffer * const send ,
                       const CommBuffer * const recv ,
                       std::ostream & )
{ return send == recv ; }

bool all_to_all_sparse( ParallelMachine ,
                        const CommBuffer * const send ,
                        const CommBuffer * const recv ,
                        std::ostream & )
{ return send == recv ; }

}

#endif

//----------------------------------------------------------------------

namespace {

inline
size_t align_quad( size_t n )
{
  enum { Size = 4 * sizeof(int) };
  return n + CommBufferAlign<Size>::align(n);
}

}

//----------------------------------------------------------------------

void CommBuffer::pack_overflow() const
{
  std::ostringstream os ;
  os << "stk_classic::CommBuffer::pack<T>(...){ overflow by " ;
  os << remaining() ;
  os << " bytes. }" ;
  throw std::overflow_error( os.str() );
}

void CommBuffer::unpack_overflow() const
{
  std::ostringstream os ;
  os << "stk_classic::CommBuffer::unpack<T>(...){ overflow by " ;
  os << remaining();
  os << " bytes. }" ;
  throw std::overflow_error( os.str() );
}
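/* Note (added commentary, not part of the original source):
 * align_quad above pads a byte count up to the next multiple of
 * 4 * sizeof(int), i.e. 16 bytes where int is 4 bytes.  For example
 * align_quad(5) == 16 and align_quad(32) == 32, so every buffer carved
 * out by CommBuffer::allocate below starts on an aligned boundary.
 */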
}" ; 00248 throw std::overflow_error( os.str() ); 00249 } 00250 00251 void CommAll::rank_error( const char * method , unsigned p ) const 00252 { 00253 std::ostringstream os ; 00254 os << "stk_classic::CommAll::" << method 00255 << "(" << p << ") ERROR: Not in [0:" << m_size << ")" ; 00256 throw std::range_error( os.str() ); 00257 } 00258 00259 //---------------------------------------------------------------------- 00260 00261 CommBuffer::CommBuffer() 00262 : m_beg(NULL), m_ptr(NULL), m_end(NULL) 00263 { } 00264 00265 CommBuffer::~CommBuffer() 00266 { } 00267 00268 void CommBuffer::deallocate( const unsigned number , CommBuffer * buffers ) 00269 { 00270 if ( NULL != buffers ) { 00271 for ( unsigned i = 0 ; i < number ; ++i ) { 00272 ( buffers + i )->~CommBuffer(); 00273 } 00274 free( buffers ); 00275 } 00276 } 00277 00278 CommBuffer * CommBuffer::allocate( 00279 const unsigned number , const unsigned * const size ) 00280 { 00281 const size_t n_base = align_quad( number * sizeof(CommBuffer) ); 00282 size_t n_size = n_base ; 00283 00284 if ( NULL != size ) { 00285 for ( unsigned i = 0 ; i < number ; ++i ) { 00286 n_size += align_quad( size[i] ); 00287 } 00288 } 00289 00290 // Allocate space for buffers 00291 00292 void * const p_malloc = malloc( n_size ); 00293 00294 CommBuffer * const b_base = 00295 p_malloc != NULL ? reinterpret_cast<CommBuffer*>(p_malloc) 00296 : reinterpret_cast<CommBuffer*>( NULL ); 00297 00298 if ( p_malloc != NULL ) { 00299 00300 for ( unsigned i = 0 ; i < number ; ++i ) { 00301 new( b_base + i ) CommBuffer(); 00302 } 00303 00304 if ( NULL != size ) { 00305 00306 ucharp ptr = reinterpret_cast<ucharp>( p_malloc ); 00307 00308 ptr += n_base ; 00309 00310 for ( unsigned i = 0 ; i < number ; ++i ) { 00311 CommBuffer & b = b_base[i] ; 00312 b.m_beg = ptr ; 00313 b.m_ptr = ptr ; 00314 b.m_end = ptr + size[i] ; 00315 ptr += align_quad( size[i] ); 00316 } 00317 } 00318 } 00319 00320 return b_base ; 00321 } 00322 00323 //---------------------------------------------------------------------- 00324 //---------------------------------------------------------------------- 00325 00326 CommAll::~CommAll() 00327 { 00328 try { 00329 CommBuffer::deallocate( m_size , m_send ); 00330 if ( 1 < m_size ) { CommBuffer::deallocate( m_size , m_recv ); } 00331 } catch(...){} 00332 m_comm = parallel_machine_null(); 00333 m_size = 0 ; 00334 m_rank = 0 ; 00335 m_send = NULL ; 00336 m_recv = NULL ; 00337 } 00338 00339 CommAll::CommAll() 00340 : m_comm( parallel_machine_null() ), 00341 m_size( 0 ), m_rank( 0 ), 00342 m_bound( 0 ), 00343 m_max( 0 ), 00344 m_send(NULL), 00345 m_recv(NULL) 00346 {} 00347 00348 CommAll::CommAll( ParallelMachine comm ) 00349 : m_comm( comm ), 00350 m_size( parallel_machine_size( comm ) ), 00351 m_rank( parallel_machine_rank( comm ) ), 00352 m_bound( 0 ), 00353 m_max( 0 ), 00354 m_send(NULL), 00355 m_recv(NULL) 00356 { 00357 m_send = CommBuffer::allocate( m_size , NULL ); 00358 00359 if ( NULL == m_send ) { 00360 std::string msg("stk_classic::CommAll::CommAll FAILED malloc"); 00361 throw std::runtime_error(msg); 00362 } 00363 } 00364 00365 bool CommAll::allocate_buffers( const unsigned num_msg_bounds , 00366 const bool symmetric , 00367 const bool local_flag ) 00368 { 00369 const unsigned zero = 0 ; 00370 std::vector<unsigned> tmp( m_size , zero ); 00371 00372 for ( unsigned i = 0 ; i < m_size ; ++i ) { 00373 tmp[i] = m_send[i].size(); 00374 } 00375 00376 const unsigned * const send_size = (tmp.empty() ? 
//----------------------------------------------------------------------

void CommAll::reset_buffers()
{
  if ( m_send ) {
    CommBuffer * m = m_send ;
    CommBuffer * const me = m + m_size ;
    for ( ; m != me ; ++m ) { m->reset(); }
  }
  if ( m_recv && 1 < m_size ) {
    CommBuffer * m = m_recv ;
    CommBuffer * const me = m + m_size ;
    for ( ; m != me ; ++m ) { m->reset(); }
  }
}

//----------------------------------------------------------------------

void CommAll::swap_send_recv()
{
  if ( m_recv == NULL ) {
    // ERROR
    std::string
      msg("stk_classic::CommAll::swap_send_recv(){ NULL recv buffers }" );
    throw std::logic_error( msg );
  }

  CommBuffer * tmp_msg = m_send ;
  m_send = m_recv ;
  m_recv = tmp_msg ;
}

//----------------------------------------------------------------------

bool CommAll::allocate_buffers( ParallelMachine comm ,
                                const unsigned num_msg_bounds ,
                                const unsigned * const send_size ,
                                const unsigned * const recv_size ,
                                const bool local_flag )
{
  static const char method[] = "stk_classic::CommAll::allocate_buffers" ;
  const unsigned uzero = 0 ;

  // When m_size <= 1 the receive buffers alias the send buffers
  // (see below), so deallocate them only once.
  CommBuffer::deallocate( m_size , m_send );
  if ( 1 < m_size ) { CommBuffer::deallocate( m_size , m_recv ); }

  m_comm = comm ;
  m_size = parallel_machine_size( comm );
  m_rank = parallel_machine_rank( comm );
  m_bound = num_msg_bounds ;

  std::ostringstream msg ;

  //--------------------------------
  // Buffer allocation

  {
    const bool send_none = NULL == send_size ;

    std::vector<unsigned> tmp_send ;

    if ( send_none ) { tmp_send.resize( m_size , uzero ); }

    const unsigned * const send =
      send_none ? (tmp_send.empty() ? NULL : & tmp_send[0]) : send_size ;

    m_send = CommBuffer::allocate( m_size , send );

    if ( 1 < m_size ) {

      std::vector<unsigned> tmp_recv ;

      const bool recv_tbd = NULL == recv_size ;

      if ( recv_tbd ) { // Had better be globally consistent.

        tmp_recv.resize( m_size , uzero );

        unsigned * const r = (tmp_recv.empty() ? NULL : & tmp_recv[0]) ;

        comm_sizes( m_comm , m_bound , m_max , send , r );
      }

      const unsigned * const recv =
        recv_tbd ? (tmp_recv.empty() ? NULL : & tmp_recv[0]) : recv_size ;

      m_recv = CommBuffer::allocate( m_size , recv );
    }
    else {
      m_recv = m_send ;
    }
  }

  bool error_alloc = m_send == NULL || m_recv == NULL ;

  //--------------------------------
  // Propagation of error flag, input flag, and quick/cheap/approximate
  // verification of send and receive messages.
  // Is the number and total size of messages consistent?
  // Sum message counts and sizes for grouped processors.
  // Sent are positive and received are negative.
  // Should finish with all total counts of zero.
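  /* Note (added commentary, not part of the original source):
   * The check below folds every processor into one of NPSum == 7 groups
   * (rank % 7) and accumulates, per group, the message count and total
   * byte count: positive for sends, negative for receives.  After the
   * global sum, a consistent plan cancels exactly.  E.g. if rank 1 plans
   * a 100-byte send to rank 3 but rank 3 plans no matching receive,
   * group (3 % 7) retains a +1 count and +100 bytes and the
   * runtime_error below fires on every rank.  Opposite mismatches that
   * collide in one group could in principle cancel, which is why the
   * comment above calls the check approximate.
   */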

  enum { NPSum = 7 };
  enum { Length = 2 + 2 * NPSum };

  int local_result[ Length ];
  int global_result[ Length ];

  Copy<Length>( local_result , 0 );

  local_result[ Length - 2 ] = error_alloc ;
  local_result[ Length - 1 ] = local_flag ;

  if ( ! error_alloc ) {

    const unsigned r = 2 * ( m_rank % NPSum );

    for ( unsigned i = 0 ; i < m_size ; ++i ) {
      const unsigned n_send = m_send[i].capacity();
      const unsigned n_recv = m_recv[i].capacity();

      const unsigned s = 2 * ( i % NPSum );

      local_result[s]   += n_send ? 1 : 0 ;
      local_result[s+1] += n_send ;

      local_result[r]   -= n_recv ? 1 : 0 ;
      local_result[r+1] -= n_recv ;
    }
  }

  if (m_size > 1) {
    all_reduce_sum( m_comm , local_result , global_result , Length );
  }
  else {
    Copy<Length>(global_result, local_result);
  }

  bool global_flag ;

  error_alloc = global_result[ Length - 2 ] ;
  global_flag = global_result[ Length - 1 ] ;

  bool ok = true ;

  for ( unsigned i = 0 ; ok && i < 2 * NPSum ; ++i ) {
    ok = 0 == global_result[i] ;
  }

  if ( error_alloc || ! ok ) {
    msg << method << " ERROR:" ;
    if ( error_alloc ) { msg << " Failed memory allocation ," ; }
    if ( ! ok )        { msg << " Parallel inconsistent send/receive ," ; }
    throw std::runtime_error( msg.str() );
  }

  return global_flag ;
}

//----------------------------------------------------------------------

void CommAll::communicate()
{
  static const char method[] = "stk_classic::CommAll::communicate" ;

  std::ostringstream msg ;

  // Verify the send buffers have been filled, reset the buffer pointers

  for ( unsigned i = 0 ; i < m_size ; ++i ) {

    if ( m_send[i].remaining() ) {
      msg << method << " LOCAL[" << m_rank << "] ERROR: Send[" << i
          << "] Buffer not filled." ;
      throw std::underflow_error( msg.str() );
    }
/*
    m_send[i].reset();
*/
    m_recv[i].reset();
  }

  if ( 1 < m_size ) {
    bool ok ;

    if ( m_bound < m_max ) {
      ok = all_to_all_dense( m_comm , m_send , m_recv , msg );
    }
    else {
      ok = all_to_all_sparse( m_comm , m_send , m_recv , msg );
    }

    if ( ! ok ) { throw std::runtime_error( msg.str() ); }
  }
}
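/* Example (illustrative sketch, not part of the original source):
 * the usual CommAll cycle is size, allocate, pack, communicate, unpack.
 * The first packing pass runs against the unallocated buffers and only
 * records sizes; the second fills the allocated buffers.
 *
 *   stk_classic::CommAll comm_all( mpi_comm );
 *   comm_all.send_buffer( dst ).pack<double>( value );   // sizing pass
 *   comm_all.allocate_buffers( p_size / 4 );             // sparse bound
 *   comm_all.send_buffer( dst ).pack<double>( value );   // packing pass
 *   comm_all.communicate();
 *   while ( comm_all.recv_buffer( src ).remaining() ) {
 *     double v ;
 *     comm_all.recv_buffer( src ).unpack<double>( v );
 *   }
 */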
//----------------------------------------------------------------------
//----------------------------------------------------------------------

CommBroadcast::CommBroadcast( ParallelMachine comm , unsigned root_rank )
  : m_comm( comm ),
    m_size( parallel_machine_size( comm ) ),
    m_rank( parallel_machine_rank( comm ) ),
    m_root_rank( root_rank ),
    m_buffer()
{}

bool CommBroadcast::allocate_buffer( const bool local_flag )
{
  static const char method[] = "stk_classic::CommBroadcast::allocate_buffer" ;

  unsigned root_rank_min  = m_root_rank ;
  unsigned root_rank_max  = m_root_rank ;
  unsigned root_send_size = m_root_rank == m_rank ? m_buffer.size() : 0 ;
  unsigned flag = local_flag ;

  all_reduce( m_comm , ReduceMin<1>( & root_rank_min ) &
                       ReduceMax<1>( & root_rank_max ) &
                       ReduceMax<1>( & root_send_size ) &
                       ReduceBitOr<1>( & flag ) );

  if ( root_rank_min != root_rank_max ) {
    std::string msg ;
    msg.append( method );
    msg.append( " FAILED: inconsistent root processor" );
    throw std::runtime_error( msg );
  }

  m_buffer.m_beg = static_cast<CommBuffer::ucharp>( malloc( root_send_size ) );
  m_buffer.m_ptr = m_buffer.m_beg ;
  m_buffer.m_end = m_buffer.m_beg + root_send_size ;

  return flag ;
}

CommBroadcast::~CommBroadcast()
{
  try {
    if ( m_buffer.m_beg ) { free( static_cast<void*>( m_buffer.m_beg ) ); }
  } catch(...) {}
  m_buffer.m_beg = NULL ;
  m_buffer.m_ptr = NULL ;
  m_buffer.m_end = NULL ;
}

CommBuffer & CommBroadcast::recv_buffer()
{
  return m_buffer ;
}

CommBuffer & CommBroadcast::send_buffer()
{
  static const char method[] = "stk_classic::CommBroadcast::send_buffer" ;

  if ( m_root_rank != m_rank ) {
    std::string msg ;
    msg.append( method );
    msg.append( " FAILED: is not root processor" );
    throw std::runtime_error( msg );
  }

  return m_buffer ;
}

void CommBroadcast::communicate()
{
#if defined( STK_HAS_MPI )
  {
    const int count = m_buffer.capacity();
    void * const buf = m_buffer.buffer();

    const int result = MPI_Bcast( buf, count, MPI_BYTE, m_root_rank, m_comm);

    if ( MPI_SUCCESS != result ) {
      std::ostringstream msg ;
      msg << "stk_classic::CommBroadcast::communicate ERROR : "
          << result << " == MPI_Bcast" ;
      throw std::runtime_error( msg.str() );
    }
  }
#endif

  m_buffer.reset();
}
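/* Example (illustrative sketch, not part of the original source;
 * assumes my_rank was obtained from parallel_machine_rank(mpi_comm)):
 * the root packs twice, mirroring the CommAll size-then-fill pattern.
 *
 *   stk_classic::CommBroadcast bc( mpi_comm , 0 );
 *   if ( 0 == my_rank ) { bc.send_buffer().pack<int>( n ); }   // sizing pass
 *   bc.allocate_buffer( false );
 *   if ( 0 == my_rank ) { bc.send_buffer().pack<int>( n ); }   // fill pass
 *   bc.communicate();
 *   int n_out = 0 ;
 *   bc.recv_buffer().unpack<int>( n_out );
 */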
//----------------------------------------------------------------------
//----------------------------------------------------------------------

CommGather::~CommGather()
{
  try {
    free( static_cast<void*>( m_send.m_beg ) );

    if ( NULL != m_recv_count ) { free( static_cast<void*>( m_recv_count ) ); }

    if ( NULL != m_recv ) { CommBuffer::deallocate( m_size , m_recv ); }
  } catch(...){}
}

void CommGather::reset()
{
  m_send.reset();

  if ( NULL != m_recv ) {
    for ( unsigned i = 0 ; i < m_size ; ++i ) { m_recv[i].reset(); }
  }
}

CommBuffer & CommGather::recv_buffer( unsigned p )
{
  static CommBuffer empty ;

  return m_size <= p ? empty : (
         m_size <= 1 ? m_send : m_recv[p] );
}

//----------------------------------------------------------------------

CommGather::CommGather( ParallelMachine comm ,
                        unsigned root_rank , unsigned send_size )
  : m_comm( comm ),
    m_size( parallel_machine_size( comm ) ),
    m_rank( parallel_machine_rank( comm ) ),
    m_root_rank( root_rank ),
    m_send(),
    m_recv(NULL),
    m_recv_count(NULL),
    m_recv_displ(NULL)
{
  m_send.m_beg = static_cast<CommBuffer::ucharp>( malloc( send_size ) );
  m_send.m_ptr = m_send.m_beg ;
  m_send.m_end = m_send.m_beg + send_size ;

#if defined( STK_HAS_MPI )

  if ( 1 < m_size ) {

    const bool is_root = m_rank == m_root_rank ;

    if ( is_root ) {
      m_recv_count = static_cast<int*>( malloc(2*m_size*sizeof(int)) );
      m_recv_displ = m_recv_count + m_size ;
    }

    MPI_Gather( & send_size , 1 , MPI_INT ,
                m_recv_count , 1 , MPI_INT ,
                m_root_rank , m_comm );

    if ( is_root ) {
      m_recv = CommBuffer::allocate( m_size ,
                 reinterpret_cast<unsigned*>( m_recv_count ) );

      for ( unsigned i = 0 ; i < m_size ; ++i ) {
        m_recv_displ[i] = m_recv[i].m_beg - m_recv[0].m_beg ;
      }
    }
  }

#endif

}


void CommGather::communicate()
{
#if defined( STK_HAS_MPI )

  if ( 1 < m_size ) {

    const int send_count = m_send.capacity();

    void * const send_buf = m_send.buffer();
    void * const recv_buf = m_rank == m_root_rank ? m_recv->buffer() : NULL ;

    MPI_Gatherv( send_buf , send_count , MPI_BYTE ,
                 recv_buf , m_recv_count , m_recv_displ , MPI_BYTE ,
                 m_root_rank , m_comm );
  }

#endif

  reset();
}
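/* Note (added commentary, not part of the original source):
 * The CommGather constructor performs the sizing MPI_Gather up front, so
 * communicate() can issue a single MPI_Gatherv whose byte displacements
 * are recovered, as in all_to_all_dense, by subtracting buffer base
 * pointers within the root's one contiguous allocation.  In a serial run
 * recv_buffer(0) simply aliases the send buffer.
 */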
//----------------------------------------------------------------------
//----------------------------------------------------------------------

#if defined( STK_HAS_MPI )

bool comm_dense_sizes( ParallelMachine comm ,
                       const unsigned * const send_size ,
                       unsigned * const recv_size ,
                       bool local_flag )
{
  static const char method[] = "stk_classic::comm_dense_sizes" ;

  const unsigned zero = 0 ;
  const unsigned p_size = parallel_machine_size( comm );

  std::vector<unsigned> send_buf( p_size * 2 , zero );
  std::vector<unsigned> recv_buf( p_size * 2 , zero );

  for ( unsigned i = 0 ; i < p_size ; ++i ) {
    const unsigned i2 = i * 2 ;
    send_buf[i2]   = send_size[i] ;
    send_buf[i2+1] = local_flag ;
  }

  {
    unsigned * const ps = (send_buf.empty() ? NULL : & send_buf[0]) ;
    unsigned * const pr = (recv_buf.empty() ? NULL : & recv_buf[0]) ;
    const int result =
      MPI_Alltoall( ps , 2 , MPI_UNSIGNED , pr , 2 , MPI_UNSIGNED , comm );

    if ( MPI_SUCCESS != result ) {
      std::string msg ;
      msg.append( method );
      msg.append( " FAILED: MPI_SUCCESS != MPI_Alltoall" );
      throw std::runtime_error( msg );
    }
  }

  bool global_flag = false ;

  for ( unsigned i = 0 ; i < p_size ; ++i ) {
    const unsigned i2 = i * 2 ;
    recv_size[i] = recv_buf[i2] ;
    if ( recv_buf[i2+1] ) { global_flag = true ; }
  }

  return global_flag ;
}

//----------------------------------------------------------------------

namespace {

extern "C" {

void sum_np_max_2_op(
  void * inv , void * outv , int * len , ParallelDatatype * )
{
  const int np = *len - 2 ;
  unsigned * ind  = (unsigned *) inv ;
  unsigned * outd = (unsigned *) outv ;

  // Sum all but the last two,
  // the last two are maximum

  for ( int i = 0 ; i < np ; ++i ) {
    *outd += *ind ;
    ++outd ;
    ++ind ;
  }
  if ( outd[0] < ind[0] ) { outd[0] = ind[0] ; }
  if ( outd[1] < ind[1] ) { outd[1] = ind[1] ; }
}

}

}
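/* Note (added commentary, not part of the original source):
 * comm_sizes below reduces a vector of p_size + 2 unsigneds with the
 * operation above: entry i sums the 0/1 "someone sends to rank i" flags,
 * so after the reduction entry p_rank is this rank's receive count; the
 * last two entries carry the maxima of the per-rank send counts and of
 * the local_flag values.  E.g. with three ranks where only rank 0 sends
 * (to ranks 1 and 2), the reduced vector is { 0 , 1 , 1 , 2 , flag }.
 */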

bool comm_sizes( ParallelMachine comm ,
                 const unsigned num_msg_bound ,
                 unsigned & num_msg_maximum ,
                 const unsigned * const send_size ,
                 unsigned * const recv_size ,
                 bool local_flag )
{
  static const char method[] = "stk_classic::comm_sizes" ;
  const unsigned uzero = 0 ;

  static MPI_Op mpi_op = MPI_OP_NULL ;

  if ( mpi_op == MPI_OP_NULL ) {
    // Is fully commutative
    MPI_Op_create( sum_np_max_2_op , 1 , & mpi_op );
  }

  const unsigned p_size = parallel_machine_size( comm );
  const unsigned p_rank = parallel_machine_rank( comm );

  int result ;

  std::ostringstream msg ;

  num_msg_maximum = 0 ;

  unsigned num_recv = 0 ;
  unsigned max_msg  = 0 ;
  bool     global_flag = false ;

  {
    std::vector<unsigned> send_buf( p_size + 2 , uzero );
    std::vector<unsigned> recv_buf( p_size + 2 , uzero );

    unsigned * const p_send = (send_buf.empty() ? NULL : & send_buf[0]) ;
    unsigned * const p_recv = (recv_buf.empty() ? NULL : & recv_buf[0]) ;

    for ( unsigned i = 0 ; i < p_size ; ++i ) {
      recv_size[i] = 0 ; // Zero output
      if ( send_size[i] ) {
        send_buf[i] = 1 ;
        ++max_msg ;
      }
    }
    send_buf[p_size]   = max_msg ;
    send_buf[p_size+1] = local_flag ;

    result = MPI_Allreduce(p_send,p_recv,p_size+2,MPI_UNSIGNED,mpi_op,comm);

    if ( result != MPI_SUCCESS ) {
      // PARALLEL ERROR
      msg << method << " ERROR: " << result << " == MPI_Allreduce" ;
      throw std::runtime_error( msg.str() );
    }

    num_recv    = recv_buf[ p_rank ] ;
    max_msg     = recv_buf[ p_size ] ;
    global_flag = recv_buf[ p_size + 1 ] ;

    // max_msg is now the maximum send count;
    // loop over receive counts to determine
    // if a receive count is larger.

    for ( unsigned i = 0 ; i < p_size ; ++i ) {
      if ( max_msg < recv_buf[i] ) { max_msg = recv_buf[i] ; }
    }
  }

  num_msg_maximum = max_msg ;

  if ( num_msg_bound < max_msg ) {
    // Dense, pay for an all-to-all

    result =
      MPI_Alltoall( (void*) send_size , 1 , MPI_UNSIGNED ,
                    recv_size , 1 , MPI_UNSIGNED , comm );

    if ( MPI_SUCCESS != result ) {
      // LOCAL ERROR ?
      msg << method << " ERROR: " << result << " == MPI_Alltoall" ;
      throw std::runtime_error( msg.str() );
    }
  }
  else if ( max_msg ) {
    // Sparse, just do point-to-point

    const int mpi_tag = STK_MPI_TAG_SIZING ;

    MPI_Request request_null = MPI_REQUEST_NULL ;
    std::vector<MPI_Request> request( num_recv , request_null );
    std::vector<MPI_Status>  status( num_recv );
    std::vector<unsigned>    buf( num_recv );

    // Post receives for point-to-point message sizes

    for ( unsigned i = 0 ; i < num_recv ; ++i ) {
      unsigned    * const p_buf     = & buf[i] ;
      MPI_Request * const p_request = & request[i] ;
      result = MPI_Irecv( p_buf , 1 , MPI_UNSIGNED ,
                          MPI_ANY_SOURCE , mpi_tag , comm , p_request );
      if ( MPI_SUCCESS != result ) {
        // LOCAL ERROR
        msg << method << " ERROR: " << result << " == MPI_Irecv" ;
        throw std::runtime_error( msg.str() );
      }
    }

    // Send the point-to-point message sizes,
    // rotate the sends in an attempt to balance the message traffic.

    for ( unsigned i = 0 ; i < p_size ; ++i ) {
      int dst = ( i + p_rank ) % p_size ;
      unsigned value = send_size[dst] ;
      if ( value ) {
        result = MPI_Send( & value , 1 , MPI_UNSIGNED , dst , mpi_tag , comm );
        if ( MPI_SUCCESS != result ) {
          // LOCAL ERROR
          msg << method << " ERROR: " << result << " == MPI_Send" ;
          throw std::runtime_error( msg.str() );
        }
      }
    }

    // Wait for all receives

    {
      MPI_Request * const p_request = (request.empty() ? NULL : & request[0]) ;
      MPI_Status  * const p_status  = (status.empty()  ? NULL : & status[0]) ;
      result = MPI_Waitall( num_recv , p_request , p_status );
    }
    if ( MPI_SUCCESS != result ) {
      // LOCAL ERROR ?
      msg << method << " ERROR: " << result << " == MPI_Waitall" ;
      throw std::runtime_error( msg.str() );
    }

    // Set the receive message sizes

    for ( unsigned i = 0 ; i < num_recv ; ++i ) {
      MPI_Status * const recv_status = & status[i] ;
      const int recv_proc = recv_status->MPI_SOURCE ;
      const int recv_tag  = recv_status->MPI_TAG ;
      int recv_count = 0 ;

      MPI_Get_count( recv_status , MPI_UNSIGNED , & recv_count );

      if ( recv_tag != mpi_tag || recv_count != 1 ) {
        msg << method << " ERROR: Received buffer mismatch " ;
        msg << "P" << p_rank << " <- P" << recv_proc ;
        msg << " " << 1 << " != " << recv_count ;
        throw std::runtime_error( msg.str() );
      }

      const unsigned r_size = buf[i] ;
      recv_size[ recv_proc ] = r_size ;
    }
  }

  return global_flag ;
}
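/* Note (added commentary, not part of the original source):
 * num_msg_bound is the caller's crossover point between the two sizing
 * strategies.  E.g. with 1024 ranks and a bound of 256: if every rank
 * talks to at most 256 peers the sizes travel as point-to-point
 * messages, otherwise one MPI_Alltoall is cheaper.  CommAll records the
 * observed maximum in m_max and reuses the same test in communicate()
 * to pick all_to_all_dense versus all_to_all_sparse for the payloads.
 */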

//----------------------------------------------------------------------
//----------------------------------------------------------------------

#else

bool comm_sizes( ParallelMachine ,
                 const unsigned ,
                 unsigned & num_msg_maximum ,
                 const unsigned * const send_size ,
                 unsigned * const recv_size ,
                 bool local_flag )
{
  num_msg_maximum = send_size[0] ? 1 : 0 ;

  recv_size[0] = send_size[0] ;

  return local_flag ;
}

bool comm_dense_sizes( ParallelMachine ,
                       const unsigned * const send_size ,
                       unsigned * const recv_size ,
                       bool local_flag )
{
  recv_size[0] = send_size[0] ;

  return local_flag ;
}

//----------------------------------------------------------------------

#endif

} // namespace stk_classic