10#ifndef IFPACK2_DETAILS_CHEBYSHEVKERNEL_DEF_HPP
11#define IFPACK2_DETAILS_CHEBYSHEVKERNEL_DEF_HPP
13#include "Tpetra_CrsMatrix.hpp"
14#include "Tpetra_MultiVector.hpp"
15#include "Tpetra_Operator.hpp"
16#include "Tpetra_Vector.hpp"
17#include "Tpetra_Export_decl.hpp"
18#include "Tpetra_Import_decl.hpp"
19#include "Kokkos_ArithTraits.hpp"
20#include "Teuchos_Assert.hpp"
22#include "KokkosSparse_spmv_impl.hpp"
32template<
class WVector,
42 static_assert (
static_cast<int> (WVector::rank) == 1,
43 "WVector must be a rank 1 View.");
44 static_assert (
static_cast<int> (DVector::rank) == 1,
45 "DVector must be a rank 1 View.");
46 static_assert (
static_cast<int> (BVector::rank) == 1,
47 "BVector must be a rank 1 View.");
48 static_assert (
static_cast<int> (XVector_colMap::rank) == 1,
49 "XVector_colMap must be a rank 1 View.");
50 static_assert (
static_cast<int> (XVector_domMap::rank) == 1,
51 "XVector_domMap must be a rank 1 View.");
53 using execution_space =
typename AMatrix::execution_space;
54 using LO =
typename AMatrix::non_const_ordinal_type;
55 using value_type =
typename AMatrix::non_const_value_type;
56 using team_policy =
typename Kokkos::TeamPolicy<execution_space>;
57 using team_member =
typename team_policy::member_type;
58 using ATV = Kokkos::ArithTraits<value_type>;
69 const LO rows_per_team;
90 const size_t numRows = m_A.numRows ();
91 const size_t numCols = m_A.numCols ();
104 using KAT = Kokkos::ArithTraits<residual_value_type>;
107 (Kokkos::TeamThreadRange (
dev, 0, rows_per_team),
108 [&] (
const LO&
loop) {
110 static_cast<LO
> (
dev.league_rank ()) * rows_per_team +
loop;
111 if (
lclRow >= m_A.numRows ()) {
114 const KokkosSparse::SparseRowViewConst<AMatrix>
A_row = m_A.rowConst(
lclRow);
118 Kokkos::parallel_reduce
126 (Kokkos::PerThread(
dev),
154chebyshev_kernel_vector
165 using execution_space =
typename AMatrix::execution_space;
167 if (A.numRows () == 0) {
172 int vector_length = -1;
175 const int64_t rows_per_team = KokkosSparse::Impl::spmv_launch_parameters<execution_space>(A.numRows(), A.nnz(),
rows_per_thread, team_size, vector_length);
178 using Kokkos::Dynamic;
179 using Kokkos::Static;
180 using Kokkos::Schedule;
181 using Kokkos::TeamPolicy;
192 policyDynamic = policy_type_dynamic (worksets, team_size, vector_length);
193 policyStatic = policy_type_static (worksets, team_size, vector_length);
197 using w_vec_type =
typename WVector::non_const_type;
198 using d_vec_type =
typename DVector::const_type;
199 using b_vec_type =
typename BVector::const_type;
200 using matrix_type = AMatrix;
201 using x_colMap_vec_type =
typename XVector_colMap::const_type;
202 using x_domMap_vec_type =
typename XVector_domMap::non_const_type;
203 using scalar_type =
typename Kokkos::ArithTraits<Scalar>::val_type;
205 if (beta == Kokkos::ArithTraits<Scalar>::zero ()) {
206 constexpr bool use_beta =
false;
209 ChebyshevKernelVectorFunctor<w_vec_type, d_vec_type,
210 b_vec_type, matrix_type,
211 x_colMap_vec_type, x_domMap_vec_type,
215 functor_type func (alpha, w, d, b, A, x_colMap, x_domMap, beta, rows_per_team);
217 Kokkos::parallel_for (kernel_label, policyDynamic, func);
219 Kokkos::parallel_for (kernel_label, policyStatic, func);
222 ChebyshevKernelVectorFunctor<w_vec_type, d_vec_type,
223 b_vec_type, matrix_type,
224 x_colMap_vec_type, x_domMap_vec_type,
228 functor_type func (alpha, w, d, b, A, x_colMap, x_domMap, beta, rows_per_team);
230 Kokkos::parallel_for (kernel_label, policyDynamic, func);
232 Kokkos::parallel_for (kernel_label, policyStatic, func);
236 constexpr bool use_beta =
true;
239 ChebyshevKernelVectorFunctor<w_vec_type, d_vec_type,
240 b_vec_type, matrix_type,
241 x_colMap_vec_type, x_domMap_vec_type,
245 functor_type func (alpha, w, d, b, A, x_colMap, x_domMap, beta, rows_per_team);
247 Kokkos::parallel_for (kernel_label, policyDynamic, func);
249 Kokkos::parallel_for (kernel_label, policyStatic, func);
252 ChebyshevKernelVectorFunctor<w_vec_type, d_vec_type,
253 b_vec_type, matrix_type,
254 x_colMap_vec_type, x_domMap_vec_type,
258 functor_type func (alpha, w, d, b, A, x_colMap, x_domMap, beta, rows_per_team);
260 Kokkos::parallel_for (kernel_label, policyDynamic, func);
262 Kokkos::parallel_for (kernel_label, policyStatic, func);
269template<
class TpetraOperatorType>
270ChebyshevKernel<TpetraOperatorType>::
271ChebyshevKernel (
const Teuchos::RCP<const operator_type>& A,
272 const bool useNativeSpMV):
273 useNativeSpMV_(useNativeSpMV)
278template<
class TpetraOperatorType>
280ChebyshevKernel<TpetraOperatorType>::
281setMatrix (
const Teuchos::RCP<const operator_type>& A)
283 if (A_op_.get () != A.get ()) {
287 V1_ = std::unique_ptr<multivector_type> (
nullptr);
289 using Teuchos::rcp_dynamic_cast;
290 Teuchos::RCP<const crs_matrix_type> A_crs =
291 rcp_dynamic_cast<const crs_matrix_type> (A);
292 if (A_crs.is_null ()) {
293 A_crs_ = Teuchos::null;
294 imp_ = Teuchos::null;
295 exp_ = Teuchos::null;
299 TEUCHOS_ASSERT( A_crs->isFillComplete () );
301 auto G = A_crs->getCrsGraph ();
302 imp_ = G->getImporter ();
303 exp_ = G->getExporter ();
304 if (!imp_.is_null ()) {
305 if (X_colMap_.get () ==
nullptr ||
306 !X_colMap_->getMap()->isSameAs (*(imp_->getTargetMap ()))) {
307 X_colMap_ = std::unique_ptr<vector_type> (
new vector_type (imp_->getTargetMap ()));
315template<
class TpetraOperatorType>
317ChebyshevKernel<TpetraOperatorType>::
318compute (multivector_type& W,
330 W_vec_ = W.getVectorNonConst (0);
331 B_vec_ = B.getVectorNonConst (0);
332 X_vec_ = X.getVectorNonConst (0);
333 TEUCHOS_ASSERT( ! A_crs_.is_null () );
334 fusedCase (*W_vec_, alpha, D_inv, *B_vec_, *A_crs_, *X_vec_, beta);
337 TEUCHOS_ASSERT( ! A_op_.is_null () );
338 unfusedCase (W, alpha, D_inv, B, *A_op_, X, beta);
342template<
class TpetraOperatorType>
343typename ChebyshevKernel<TpetraOperatorType>::vector_type&
344ChebyshevKernel<TpetraOperatorType>::
345importVector (vector_type& X_domMap)
347 if (imp_.is_null ()) {
351 X_colMap_->doImport (X_domMap, *imp_, Tpetra::REPLACE);
356template<
class TpetraOperatorType>
358ChebyshevKernel<TpetraOperatorType>::
359canFuse (
const multivector_type& B)
const
366 return B.getNumVectors () == size_t (1) &&
367 ! A_crs_.is_null () &&
371template<
class TpetraOperatorType>
373ChebyshevKernel<TpetraOperatorType>::
374unfusedCase (multivector_type& W,
378 const operator_type& A,
382 using STS = Teuchos::ScalarTraits<SC>;
383 if (V1_.get () ==
nullptr) {
384 using MV = multivector_type;
385 const size_t numVecs = B.getNumVectors ();
386 V1_ = std::unique_ptr<MV> (
new MV (B.getMap (), numVecs));
388 const SC one = Teuchos::ScalarTraits<SC>::one ();
391 Tpetra::deep_copy (*V1_, B);
392 A.apply (X, *V1_, Teuchos::NO_TRANS, -one, one);
395 W.elementWiseMultiply (alpha, D_inv, *V1_, beta);
398 X.update (STS::one(), W, STS::one());
401template<
class TpetraOperatorType>
403ChebyshevKernel<TpetraOperatorType>::
404fusedCase (vector_type& W,
408 const crs_matrix_type& A,
412 vector_type& X_colMap = importVector (X);
414 using Impl::chebyshev_kernel_vector;
415 using STS = Teuchos::ScalarTraits<SC>;
417 auto A_lcl = A.getLocalMatrixDevice ();
419 auto Dinv_lcl = Kokkos::subview(D_inv.getLocalViewDevice(Tpetra::Access::ReadOnly), Kokkos::ALL(), 0);
420 auto B_lcl = Kokkos::subview(B.getLocalViewDevice(Tpetra::Access::ReadOnly), Kokkos::ALL(), 0);
421 auto X_domMap_lcl = Kokkos::subview(X.getLocalViewDevice(Tpetra::Access::ReadWrite), Kokkos::ALL(), 0);
422 auto X_colMap_lcl = Kokkos::subview(X_colMap.getLocalViewDevice(Tpetra::Access::ReadOnly), Kokkos::ALL(), 0);
424 const bool do_X_update = !imp_.is_null ();
425 if (beta == STS::zero ()) {
426 auto W_lcl = Kokkos::subview(W.getLocalViewDevice(Tpetra::Access::OverwriteAll), Kokkos::ALL(), 0);
427 chebyshev_kernel_vector (alpha, W_lcl, Dinv_lcl,
429 X_colMap_lcl, X_domMap_lcl,
434 auto W_lcl = Kokkos::subview(W.getLocalViewDevice(Tpetra::Access::ReadWrite), Kokkos::ALL(), 0);
435 chebyshev_kernel_vector (alpha, W_lcl, Dinv_lcl,
437 X_colMap_lcl, X_domMap_lcl,
442 X.update(STS::one (), W, STS::one ());
// Explicit-instantiation macro: expands to an explicit instantiation of
// Ifpack2::Details::ChebyshevKernel specialized on
// Tpetra::Operator<SC, LO, GO, NT>.
// NOTE(review): SC/LO/GO/NT are presumably the Tpetra scalar, local-ordinal,
// global-ordinal and node types -- confirm against the instantiation site.
448#define IFPACK2_DETAILS_CHEBYSHEVKERNEL_INSTANT(SC,LO,GO,NT) \
449 template class Ifpack2::Details::ChebyshevKernel<Tpetra::Operator<SC, LO, GO, NT> >;
Ifpack2's implementation of Trilinos::Details::LinearSolver interface.
Definition Ifpack2_Details_LinearSolver_decl.hpp:77
Ifpack2 implementation details.
Preconditioners and smoothers for Tpetra sparse matrices.
Definition Ifpack2_AdditiveSchwarz_decl.hpp:41
Functor for computing W := alpha * D * (B - A*X) + beta * W and X := X+W.
Definition Ifpack2_Details_ChebyshevKernel_def.hpp:41