Ginkgo Generated from branch based on main. Ginkgo version 1.9.0
A numerical linear algebra library targeting many-core architectures
Loading...
Searching...
No Matches
batch_solver_base.hpp
1// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_SOLVER_BATCH_SOLVER_BASE_HPP_
6#define GKO_PUBLIC_CORE_SOLVER_BATCH_SOLVER_BASE_HPP_
7
8
9#include <ginkgo/core/base/abstract_factory.hpp>
10#include <ginkgo/core/base/batch_lin_op.hpp>
11#include <ginkgo/core/base/batch_multi_vector.hpp>
12#include <ginkgo/core/base/utils_helper.hpp>
13#include <ginkgo/core/log/batch_logger.hpp>
14#include <ginkgo/core/matrix/batch_identity.hpp>
15#include <ginkgo/core/stop/batch_stop_enum.hpp>
16
17
18namespace gko {
19namespace batch {
20namespace solver {
21
22
30public:
36 std::shared_ptr<const BatchLinOp> get_system_matrix() const
37 {
38 return this->system_matrix_;
39 }
40
46 std::shared_ptr<const BatchLinOp> get_preconditioner() const
47 {
48 return this->preconditioner_;
49 }
50
56 double get_tolerance() const { return this->residual_tol_; }
57
64 void reset_tolerance(double res_tol)
65 {
66 if (res_tol < 0) {
67 GKO_INVALID_STATE("Tolerance cannot be negative!");
68 }
69 this->residual_tol_ = res_tol;
70 }
71
77 int get_max_iterations() const { return this->max_iterations_; }
78
85 void reset_max_iterations(int max_iterations)
86 {
87 if (max_iterations < 0) {
88 GKO_INVALID_STATE("Max iterations cannot be negative!");
89 }
90 this->max_iterations_ = max_iterations;
91 }
92
98 ::gko::batch::stop::tolerance_type get_tolerance_type() const
99 {
100 return this->tol_type_;
101 }
102
108 void reset_tolerance_type(::gko::batch::stop::tolerance_type tol_type)
109 {
110 if (tol_type == ::gko::batch::stop::tolerance_type::absolute ||
111 tol_type == ::gko::batch::stop::tolerance_type::relative) {
112 this->tol_type_ = tol_type;
113 } else {
114 GKO_INVALID_STATE("Invalid tolerance type specified!");
115 }
116 }
117
118protected:
119 BatchSolver() {}
120
121 BatchSolver(std::shared_ptr<const BatchLinOp> system_matrix,
122 std::shared_ptr<const BatchLinOp> gen_preconditioner,
123 const double res_tol, const int max_iterations,
124 const ::gko::batch::stop::tolerance_type tol_type)
125 : system_matrix_{std::move(system_matrix)},
126 preconditioner_{std::move(gen_preconditioner)},
127 residual_tol_{res_tol},
128 max_iterations_{max_iterations},
129 tol_type_{tol_type},
130 workspace_{}
131 {}
132
133 void set_system_matrix_base(std::shared_ptr<const BatchLinOp> system_matrix)
134 {
135 this->system_matrix_ = std::move(system_matrix);
136 }
137
138 void set_preconditioner_base(std::shared_ptr<const BatchLinOp> precond)
139 {
140 this->preconditioner_ = std::move(precond);
141 }
142
// State shared by every batched solver. It is read by the get_* accessors
// and updated by the reset_* and set_*_base members of this class.
143 std::shared_ptr<const BatchLinOp> system_matrix_{};  // system operator; may be null
144 std::shared_ptr<const BatchLinOp> preconditioner_{};  // generated preconditioner; may be null
145 double residual_tol_{};  // reset_tolerance rejects negative values
146 int max_iterations_{};  // reset_max_iterations rejects negative values
147 ::gko::batch::stop::tolerance_type tol_type_{};  // reset_tolerance_type accepts absolute or relative
148 mutable array<unsigned char> workspace_{};  // raw byte buffer, mutable so const apply code can hand it out as logger storage
149};
150
151
152template <typename Parameters, typename Factory>
154 : enable_parameters_type<Parameters, Factory> {
162
170
175 ::gko::batch::stop::tolerance_type GKO_FACTORY_PARAMETER_SCALAR(
176 tolerance_type, ::gko::batch::stop::tolerance_type::absolute);
177
182 std::shared_ptr<const BatchLinOpFactory> GKO_DEFERRED_FACTORY_PARAMETER(
184
189 std::shared_ptr<const BatchLinOp> GKO_FACTORY_PARAMETER_SCALAR(
191};
192
193
202template <typename ConcreteSolver, typename ValueType,
203 typename PolymorphicBase = BatchLinOp>
205 : public BatchSolver,
206 public EnableBatchLinOp<ConcreteSolver, PolymorphicBase> {
207public:
208 using real_type = remove_complex<ValueType>;
209
210 const ConcreteSolver* apply(ptr_param<const MultiVector<ValueType>> b,
212 {
213 this->validate_application_parameters(b.get(), x.get());
214 auto exec = this->get_executor();
215 this->apply_impl(make_temporary_clone(exec, b).get(),
216 make_temporary_clone(exec, x).get());
217 return self();
218 }
219
220 const ConcreteSolver* apply(ptr_param<const MultiVector<ValueType>> alpha,
224 {
225 this->validate_application_parameters(alpha.get(), b.get(), beta.get(),
226 x.get());
227 auto exec = this->get_executor();
228 this->apply_impl(make_temporary_clone(exec, alpha).get(),
229 make_temporary_clone(exec, b).get(),
230 make_temporary_clone(exec, beta).get(),
231 make_temporary_clone(exec, x).get());
232 return self();
233 }
234
235 ConcreteSolver* apply(ptr_param<const MultiVector<ValueType>> b,
237 {
238 this->validate_application_parameters(b.get(), x.get());
239 auto exec = this->get_executor();
240 this->apply_impl(make_temporary_clone(exec, b).get(),
241 make_temporary_clone(exec, x).get());
242 return self();
243 }
244
245 ConcreteSolver* apply(ptr_param<const MultiVector<ValueType>> alpha,
249 {
250 this->validate_application_parameters(alpha.get(), b.get(), beta.get(),
251 x.get());
252 auto exec = this->get_executor();
253 this->apply_impl(make_temporary_clone(exec, alpha).get(),
254 make_temporary_clone(exec, b).get(),
255 make_temporary_clone(exec, beta).get(),
256 make_temporary_clone(exec, x).get());
257 return self();
258 }
259
260protected:
// Provides self(). The apply overloads and setters of this class use it to
// access this object as a (const) ConcreteSolver*. NOTE(review): the macro is
// defined outside this header; confirm its exact expansion there.
261 GKO_ENABLE_SELF(ConcreteSolver);
262
263 explicit EnableBatchSolver(std::shared_ptr<const Executor> exec)
265 {}
266
267 template <typename FactoryParameters>
268 explicit EnableBatchSolver(std::shared_ptr<const Executor> exec,
269 std::shared_ptr<const BatchLinOp> system_matrix,
270 const FactoryParameters& params)
271 : BatchSolver(system_matrix, nullptr, params.tolerance,
272 params.max_iterations, params.tolerance_type),
274 exec, gko::transpose(system_matrix->get_size()))
275 {
276 GKO_ASSERT_BATCH_HAS_SQUARE_DIMENSIONS(system_matrix_);
277
278 using value_type = typename ConcreteSolver::value_type;
279 using Identity = matrix::Identity<value_type>;
280 using real_type = remove_complex<value_type>;
281
282 if (params.generated_preconditioner) {
283 GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(params.generated_preconditioner,
284 this);
285 preconditioner_ = std::move(params.generated_preconditioner);
286 } else if (params.preconditioner) {
287 preconditioner_ = params.preconditioner->generate(system_matrix_);
288 } else {
289 auto id = Identity::create(exec, system_matrix->get_size());
290 preconditioner_ = std::move(id);
291 }
292 // We use a workspace here to store the logger data (iteration count
293 // and solver residual), and require a minimum size of
294 // `sizeof(real_type)+ sizeof(int)`
295 const size_type workspace_size =
296 system_matrix->get_num_batch_items() * 32;
297 workspace_.set_executor(exec);
298 workspace_.resize_and_reset(workspace_size);
299 }
300
301 void set_system_matrix(std::shared_ptr<const BatchLinOp> new_system_matrix)
302 {
303 auto exec = self()->get_executor();
304 if (new_system_matrix) {
305 GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(self(), new_system_matrix);
306 GKO_ASSERT_BATCH_HAS_SQUARE_DIMENSIONS(new_system_matrix);
307 if (new_system_matrix->get_executor() != exec) {
308 new_system_matrix = gko::clone(exec, new_system_matrix);
309 }
310 }
311 this->set_system_matrix_base(new_system_matrix);
312 }
313
314 void set_preconditioner(std::shared_ptr<const BatchLinOp> new_precond)
315 {
316 auto exec = self()->get_executor();
317 if (new_precond) {
318 GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(self(), new_precond);
319 GKO_ASSERT_BATCH_HAS_SQUARE_DIMENSIONS(new_precond);
320 if (new_precond->get_executor() != exec) {
321 new_precond = gko::clone(exec, new_precond);
322 }
323 }
324 this->set_preconditioner_base(new_precond);
325 }
326
327 EnableBatchSolver& operator=(const EnableBatchSolver& other)
328 {
329 if (&other != this) {
330 this->set_size(other.get_size());
331 this->set_system_matrix(other.get_system_matrix());
332 this->set_preconditioner(other.get_preconditioner());
333 this->reset_tolerance(other.get_tolerance());
336 }
337
338 return *this;
339 }
340
341 EnableBatchSolver& operator=(EnableBatchSolver&& other)
342 {
343 if (&other != this) {
344 this->set_size(other.get_size());
345 this->set_system_matrix(other.get_system_matrix());
346 this->set_preconditioner(other.get_preconditioner());
347 this->reset_tolerance(other.get_tolerance());
350 other.set_system_matrix(nullptr);
351 other.set_preconditioner(nullptr);
352 }
353 return *this;
354 }
355
358 other.self()->get_executor(), other.self()->get_size())
359 {
360 *this = other;
361 }
362
365 other.self()->get_executor(), other.self()->get_size())
366 {
367 *this = std::move(other);
368 }
369
370 void apply_impl(const MultiVector<ValueType>* b,
371 MultiVector<ValueType>* x) const
372 {
373 auto exec = this->get_executor();
374 if (b->get_common_size()[1] > 1) {
375 GKO_NOT_IMPLEMENTED;
376 }
377 auto workspace_view = workspace_.as_view();
378 auto log_data_ = std::make_unique<log::detail::log_data<real_type>>(
379 exec, b->get_num_batch_items(), workspace_view);
380
381 this->solver_apply(b, x, log_data_.get());
382
383 this->template log<gko::log::Logger::batch_solver_completed>(
384 log_data_->iter_counts, log_data_->res_norms);
385 }
386
387 void apply_impl(const MultiVector<ValueType>* alpha,
388 const MultiVector<ValueType>* b,
389 const MultiVector<ValueType>* beta,
390 MultiVector<ValueType>* x) const
391 {
392 auto x_clone = x->clone();
393 this->apply(b, x_clone.get());
394 x->scale(beta);
395 x->add_scaled(alpha, x_clone.get());
396 }
397
398 virtual void solver_apply(const MultiVector<ValueType>* b,
400 log::detail::log_data<real_type>* info) const = 0;
401};
402
403
404} // namespace solver
405} // namespace batch
406} // namespace gko
407
408
409#endif // GKO_PUBLIC_CORE_SOLVER_BATCH_SOLVER_BASE_HPP_
Definition batch_lin_op.hpp:59
The EnableBatchLinOp mixin can be used to provide sensible default implementations of the majority of the BatchLinOp interface.
Definition batch_lin_op.hpp:252
MultiVector stores multiple vectors in a batched fashion and is useful for batched operations.
Definition logger.hpp:41
void scale(ptr_param< const MultiVector< ValueType > > alpha)
Scales the vector with a scalar (aka: BLAS scal).
dim< 2 > get_common_size() const
Returns the common size of the batch items.
Definition batch_multi_vector.hpp:144
size_type get_num_batch_items() const
Returns the number of batch items.
Definition batch_multi_vector.hpp:134
void add_scaled(ptr_param< const MultiVector< ValueType > > alpha, ptr_param< const MultiVector< ValueType > > b)
Adds b scaled by alpha to the vector (aka: BLAS axpy).
The batch Identity matrix, which represents a batch of Identity matrices.
Definition batch_identity.hpp:32
The BatchSolver is a base class for all batched solvers and provides the common getters and setters for these batched solver classes.
Definition batch_solver_base.hpp:29
std::shared_ptr< const BatchLinOp > get_system_matrix() const
Returns the system operator (matrix) of the linear system.
Definition batch_solver_base.hpp:36
void reset_max_iterations(int max_iterations)
Set the maximum number of iterations for the solver to use, independent of the factory that created it.
Definition batch_solver_base.hpp:85
double get_tolerance() const
Get the residual tolerance used by the solver.
Definition batch_solver_base.hpp:56
int get_max_iterations() const
Get the maximum number of iterations set on the solver.
Definition batch_solver_base.hpp:77
void reset_tolerance(double res_tol)
Update the residual tolerance to be used by the solver.
Definition batch_solver_base.hpp:64
::gko::batch::stop::tolerance_type get_tolerance_type() const
Get the tolerance type.
Definition batch_solver_base.hpp:98
void reset_tolerance_type(::gko::batch::stop::tolerance_type tol_type)
Set the type of tolerance check to use inside the solver.
Definition batch_solver_base.hpp:108
std::shared_ptr< const BatchLinOp > get_preconditioner() const
Returns the generated preconditioner.
Definition batch_solver_base.hpp:46
This mixin provides apply and common iterative solver functionality to all the batched solvers.
Definition batch_solver_base.hpp:206
The enable_parameters_type mixin is used to create a base implementation of the factory parameters structure.
Definition abstract_factory.hpp:211
This class is used for function parameters in the place of raw pointers.
Definition utils_helper.hpp:41
#define GKO_FACTORY_PARAMETER_SCALAR(_name, _default)
Creates a scalar factory parameter in the factory parameters structure.
Definition abstract_factory.hpp:445
The Ginkgo namespace.
Definition abstract_factory.hpp:20
typename detail::remove_complex_s< T >::type remove_complex
Obtains the type with the complex part removed, either from a complex/scalar type or from the template parameter of a class.
Definition math.hpp:260
std::size_t size_type
Integral type used for allocation quantities.
Definition types.hpp:89
detail::cloned_type< Pointer > clone(const Pointer &p)
Creates a unique clone of the object pointed to by p.
Definition utils_helper.hpp:173
batch_dim< 2, DimensionType > transpose(const batch_dim< 2, DimensionType > &input)
Returns a batch_dim object with its dimensions swapped for batched operators.
Definition batch_dim.hpp:119
detail::temporary_clone< detail::pointee< Ptr > > make_temporary_clone(std::shared_ptr< const Executor > exec, Ptr &&ptr)
Creates a temporary_clone.
Definition temporary_clone.hpp:208
STL namespace.
int max_iterations
Default maximum number of iterations allowed.
Definition batch_solver_base.hpp:161
double tolerance
Default residual tolerance.
Definition batch_solver_base.hpp:169
std::shared_ptr< const BatchLinOpFactory > preconditioner
The preconditioner to be used by the iterative solver.
Definition batch_solver_base.hpp:183
::gko::batch::stop::tolerance_type tolerance_type
Specifies which type of tolerance check is used: absolute, or relative (to the L2 norm of the right-hand side).
Definition batch_solver_base.hpp:176
std::shared_ptr< const BatchLinOp > generated_preconditioner
Already generated preconditioner.
Definition batch_solver_base.hpp:190