Ginkgo Generated from branch based on main. Ginkgo version 1.9.0
A numerical linear algebra library targeting many-core architectures
Loading...
Searching...
No Matches
range.hpp
1// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_BASE_RANGE_HPP_
6#define GKO_PUBLIC_CORE_BASE_RANGE_HPP_
7
8
9#include <type_traits>
10
11#include <ginkgo/core/base/math.hpp>
12#include <ginkgo/core/base/types.hpp>
13#include <ginkgo/core/base/utils.hpp>
14
15
16namespace gko {
17
18
46struct span {
54 GKO_ATTRIBUTES constexpr span(size_type point) noexcept
55 : span{point, point + 1}
56 {}
57
64 GKO_ATTRIBUTES constexpr span(size_type begin, size_type end) noexcept
65 : begin{begin}, end{end}
66 {}
67
73 GKO_ATTRIBUTES constexpr bool is_valid() const { return begin <= end; }
74
80 GKO_ATTRIBUTES constexpr size_type length() const { return end - begin; }
81
86
91};
92
93
94GKO_ATTRIBUTES GKO_INLINE constexpr bool operator<(const span& first,
95 const span& second)
96{
97 return first.end < second.begin;
98}
99
100
101GKO_ATTRIBUTES GKO_INLINE constexpr bool operator<=(const span& first,
102 const span& second)
103{
104 return first.end <= second.begin;
105}
106
107
108GKO_ATTRIBUTES GKO_INLINE constexpr bool operator>(const span& first,
109 const span& second)
110{
111 return second < first;
112}
113
114
115GKO_ATTRIBUTES GKO_INLINE constexpr bool operator>=(const span& first,
116 const span& second)
117{
118 return second <= first;
119}
120
121
122GKO_ATTRIBUTES GKO_INLINE constexpr bool operator==(const span& first,
123 const span& second)
124{
125 return first.begin == second.begin && first.end == second.end;
126}
127
128
129GKO_ATTRIBUTES GKO_INLINE constexpr bool operator!=(const span& first,
130 const span& second)
131{
132 return !(first == second);
133}
134
135
136namespace detail {
137
138
139template <size_type CurrentDimension = 0, typename FirstRange,
140 typename SecondRange>
141GKO_ATTRIBUTES constexpr GKO_INLINE
142 std::enable_if_t<(CurrentDimension >= max(FirstRange::dimensionality,
143 SecondRange::dimensionality)),
144 bool>
145 equal_dimensions(const FirstRange&, const SecondRange&)
146{
147 return true;
148}
149
150template <size_type CurrentDimension = 0, typename FirstRange,
151 typename SecondRange>
152GKO_ATTRIBUTES constexpr GKO_INLINE
153 std::enable_if_t<(CurrentDimension < max(FirstRange::dimensionality,
154 SecondRange::dimensionality)),
155 bool>
156 equal_dimensions(const FirstRange& first, const SecondRange& second)
157{
158 return first.length(CurrentDimension) == second.length(CurrentDimension) &&
159 equal_dimensions<CurrentDimension + 1>(first, second);
160}
161
166template <class...>
167struct head;
168
172template <class First, class... Rest>
173struct head<First, Rest...> {
174 using type = First;
175};
176
180template <class... T>
181using head_t = typename head<T...>::type;
182
183
184} // namespace detail
185
186
296template <typename Accessor>
297class range {
298public:
302 using accessor = Accessor;
303
307 static constexpr size_type dimensionality = accessor::dimensionality;
308
312 ~range() = default;
313
322 template <
323 typename... AccessorParams,
324 typename = std::enable_if_t<
325 sizeof...(AccessorParams) != 1 ||
326 !std::is_same<
327 range, std::decay<detail::head_t<AccessorParams...>>>::value>>
328 GKO_ATTRIBUTES constexpr explicit range(AccessorParams&&... params)
329 : accessor_{std::forward<AccessorParams>(params)...}
330 {}
331
344 template <typename... DimensionTypes>
345 GKO_ATTRIBUTES constexpr auto operator()(DimensionTypes&&... dimensions)
346 const -> decltype(std::declval<accessor>()(
347 std::forward<DimensionTypes>(dimensions)...))
348 {
349 static_assert(sizeof...(DimensionTypes) <= dimensionality,
350 "Too many dimensions in range call");
351 return accessor_(std::forward<DimensionTypes>(dimensions)...);
352 }
353
362 template <typename OtherAccessor>
363 GKO_ATTRIBUTES const range& operator=(
364 const range<OtherAccessor>& other) const
365 {
366 GKO_ASSERT(detail::equal_dimensions(*this, other));
367 accessor_.copy_from(other);
368 return *this;
369 }
370
384 GKO_ATTRIBUTES const range& operator=(const range& other) const
385 {
386 GKO_ASSERT(detail::equal_dimensions(*this, other));
387 accessor_.copy_from(other.get_accessor());
388 return *this;
389 }
390
391 range(const range& other) = default;
392
400 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
401 {
402 return accessor_.length(dimension);
403 }
404
412 GKO_ATTRIBUTES constexpr const accessor* operator->() const noexcept
413 {
414 return &accessor_;
415 }
416
422 GKO_ATTRIBUTES constexpr const accessor& get_accessor() const noexcept
423 {
424 return accessor_;
425 }
426
427private:
428 accessor accessor_;
429};
430
431
432// implementation of range operations follows
433// (you probably should not have to look at this unless you're interested in the
434// gory details)
435
436
437namespace detail {
438
439
/**
 * Tag selecting which of the two operands of a binary operation are ranges
 * and which are plain (scalar) operands.
 */
enum class operation_kind {
    range_by_range,
    scalar_by_range,
    range_by_scalar
};
441
442
443template <typename Accessor, typename Operation>
444struct implement_unary_operation {
445 using accessor = Accessor;
446 static constexpr size_type dimensionality = accessor::dimensionality;
447
448 GKO_ATTRIBUTES constexpr explicit implement_unary_operation(
449 const Accessor& operand)
450 : operand{operand}
451 {}
452
453 template <typename... DimensionTypes>
454 GKO_ATTRIBUTES constexpr auto operator()(
455 const DimensionTypes&... dimensions) const
456 -> decltype(Operation::evaluate(std::declval<accessor>(),
457 dimensions...))
458 {
459 return Operation::evaluate(operand, dimensions...);
460 }
461
462 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
463 {
464 return operand.length(dimension);
465 }
466
467 template <typename OtherAccessor>
468 GKO_ATTRIBUTES void copy_from(const OtherAccessor& other) const = delete;
469
470 const accessor operand;
471};
472
473
474template <operation_kind Kind, typename FirstOperand, typename SecondOperand,
475 typename Operation>
476struct implement_binary_operation {};
477
478template <typename FirstAccessor, typename SecondAccessor, typename Operation>
479struct implement_binary_operation<operation_kind::range_by_range, FirstAccessor,
480 SecondAccessor, Operation> {
481 using first_accessor = FirstAccessor;
482 using second_accessor = SecondAccessor;
483 static_assert(first_accessor::dimensionality ==
484 second_accessor::dimensionality,
485 "Both ranges need to have the same number of dimensions");
486 static constexpr size_type dimensionality = first_accessor::dimensionality;
487
488 GKO_ATTRIBUTES explicit implement_binary_operation(
489 const FirstAccessor& first, const SecondAccessor& second)
490 : first{first}, second{second}
491 {
492 GKO_ASSERT(gko::detail::equal_dimensions(first, second));
493 }
494
495 template <typename... DimensionTypes>
496 GKO_ATTRIBUTES constexpr auto operator()(
497 const DimensionTypes&... dimensions) const
498 -> decltype(Operation::evaluate_range_by_range(
499 std::declval<first_accessor>(), std::declval<second_accessor>(),
500 dimensions...))
501 {
502 return Operation::evaluate_range_by_range(first, second, dimensions...);
503 }
504
505 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
506 {
507 return first.length(dimension);
508 }
509
510 template <typename OtherAccessor>
511 GKO_ATTRIBUTES void copy_from(const OtherAccessor& other) const = delete;
512
513 const first_accessor first;
514 const second_accessor second;
515};
516
517template <typename FirstOperand, typename SecondAccessor, typename Operation>
518struct implement_binary_operation<operation_kind::scalar_by_range, FirstOperand,
519 SecondAccessor, Operation> {
520 using second_accessor = SecondAccessor;
521 static constexpr size_type dimensionality = second_accessor::dimensionality;
522
523 GKO_ATTRIBUTES constexpr explicit implement_binary_operation(
524 const FirstOperand& first, const SecondAccessor& second)
525 : first{first}, second{second}
526 {}
527
528 template <typename... DimensionTypes>
529 GKO_ATTRIBUTES constexpr auto operator()(
530 const DimensionTypes&... dimensions) const
531 -> decltype(Operation::evaluate_scalar_by_range(
532 std::declval<FirstOperand>(), std::declval<second_accessor>(),
533 dimensions...))
534 {
535 return Operation::evaluate_scalar_by_range(first, second,
536 dimensions...);
537 }
538
539 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
540 {
541 return second.length(dimension);
542 }
543
544 template <typename OtherAccessor>
545 GKO_ATTRIBUTES void copy_from(const OtherAccessor& other) const = delete;
546
547 const FirstOperand first;
548 const second_accessor second;
549};
550
551template <typename FirstAccessor, typename SecondOperand, typename Operation>
552struct implement_binary_operation<operation_kind::range_by_scalar,
553 FirstAccessor, SecondOperand, Operation> {
554 using first_accessor = FirstAccessor;
555 static constexpr size_type dimensionality = first_accessor::dimensionality;
556
557 GKO_ATTRIBUTES constexpr explicit implement_binary_operation(
558 const FirstAccessor& first, const SecondOperand& second)
559 : first{first}, second{second}
560 {}
561
562 template <typename... DimensionTypes>
563 GKO_ATTRIBUTES constexpr auto operator()(
564 const DimensionTypes&... dimensions) const
565 -> decltype(Operation::evaluate_range_by_scalar(
566 std::declval<first_accessor>(), std::declval<SecondOperand>(),
567 dimensions...))
568 {
569 return Operation::evaluate_range_by_scalar(first, second,
570 dimensions...);
571 }
572
573 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
574 {
575 return first.length(dimension);
576 }
577
578 template <typename OtherAccessor>
579 GKO_ATTRIBUTES void copy_from(const OtherAccessor& other) const = delete;
580
581 const first_accessor first;
582 const SecondOperand second;
583};
584
585
586} // namespace detail
587
/**
 * Declares, inside namespace `accessor`, a struct template
 * `_operation_deprecated_name<Operand>` that derives from
 * `_operation_name<Operand>` and is annotated with
 * GKO_DEPRECATED("Please use <_operation_name>").
 *
 * The derived struct has an empty body, i.e. it adds no members of its own.
 * The trailing `static_assert(true, ...)` lets an invocation be terminated
 * with a semicolon.
 *
 * @param _operation_deprecated_name  name of the deprecated struct template
 * @param _operation_name  name of the struct template to use instead
 */
588#define GKO_DEPRECATED_UNARY_RANGE_OPERATION(_operation_deprecated_name, \
589 _operation_name) \
590 namespace accessor { \
591 template <typename Operand> \
592 struct GKO_DEPRECATED("Please use " #_operation_name) \
593 _operation_deprecated_name : _operation_name<Operand> {}; \
594 } \
595 static_assert(true, \
596 "This assert is used to counter the false positive extra " \
597 "semi-colon warnings")
598
599
/**
 * Defines the accessor type for a unary range operation and binds it to a
 * function or operator.
 *
 * Inside namespace `accessor` it defines `_operation_name<Operand>`, which
 * derives from `::gko::detail::implement_unary_operation<Operand,
 * ::gko::_operator>` and inherits its constructors. It then expands
 * GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR, which also supplies the
 * trailing `static_assert`.
 *
 * @param _operation_name  name of the accessor struct template to define
 * @param _operator_name  name of the function/operator to bind it to
 * @param _operator  operation type, named relative to `::gko`
 */
600#define GKO_ENABLE_UNARY_RANGE_OPERATION(_operation_name, _operator_name, \
601 _operator) \
602 namespace accessor { \
603 template <typename Operand> \
604 struct _operation_name \
605 : ::gko::detail::implement_unary_operation<Operand, \
606 ::gko::_operator> { \
607 using ::gko::detail::implement_unary_operation< \
608 Operand, ::gko::_operator>::implement_unary_operation; \
609 }; \
610 } \
611 GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR(_operation_name, _operator_name)
612
613
/**
 * Defines a function template `_operator_name(const range<Accessor>&)` that
 * returns `range<accessor::_operation_name<Accessor>>`, constructed from the
 * accessor of the operand.
 *
 * The trailing `static_assert(true, ...)` lets an invocation be terminated
 * with a semicolon.
 *
 * @param _operation_name  accessor struct template in namespace `accessor`
 * @param _operator_name  name of the function/operator to define
 */
614#define GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR(_operation_name, \
615 _operator_name) \
616 template <typename Accessor> \
617 GKO_ATTRIBUTES constexpr GKO_INLINE \
618 range<accessor::_operation_name<Accessor>> \
619 _operator_name(const range<Accessor>& operand) \
620 { \
621 return range<accessor::_operation_name<Accessor>>( \
622 operand.get_accessor()); \
623 } \
624 static_assert(true, \
625 "This assert is used to counter the false positive extra " \
626 "semi-colon warnings")
627
628
/**
 * Defines a struct `_name` describing an element-wise unary operation.
 *
 * The variadic arguments form an expression written in terms of a variable
 * named `operand`. The struct gets
 *  - a private static `simple_evaluate_impl(operand)` returning that
 *    expression, and
 *  - a public static `evaluate(accessor, dimensions...)` that applies it to
 *    `accessor(dimensions...)`.
 *
 * The macro does not emit a trailing semicolon; the invocation supplies it.
 *
 * @param _name  name of the struct to define
 * @param ...  the expression to evaluate, using the identifier `operand`
 */
629#define GKO_DEFINE_SIMPLE_UNARY_OPERATION(_name, ...) \
630 struct _name { \
631 private: \
632 template <typename Operand> \
633 GKO_ATTRIBUTES static constexpr auto simple_evaluate_impl( \
634 const Operand& operand) -> decltype(__VA_ARGS__) \
635 { \
636 return __VA_ARGS__; \
637 } \
638 \
639 public: \
640 template <typename AccessorType, typename... DimensionTypes> \
641 GKO_ATTRIBUTES static constexpr auto evaluate( \
642 const AccessorType& accessor, const DimensionTypes&... dimensions) \
643 -> decltype(simple_evaluate_impl(accessor(dimensions...))) \
644 { \
645 return simple_evaluate_impl(accessor(dimensions...)); \
646 } \
647 }
648
649
650namespace accessor {
651namespace detail {
652
653
654// unary arithmetic
655GKO_DEFINE_SIMPLE_UNARY_OPERATION(unary_plus, +operand);
656GKO_DEFINE_SIMPLE_UNARY_OPERATION(unary_minus, -operand);
657
658// unary logical
659GKO_DEFINE_SIMPLE_UNARY_OPERATION(logical_not, !operand);
660
661// unary bitwise
662GKO_DEFINE_SIMPLE_UNARY_OPERATION(bitwise_not, ~(operand));
663
664// common functions
665GKO_DEFINE_SIMPLE_UNARY_OPERATION(zero_operation, zero(operand));
666GKO_DEFINE_SIMPLE_UNARY_OPERATION(one_operation, one(operand));
667GKO_DEFINE_SIMPLE_UNARY_OPERATION(abs_operation, abs(operand));
668GKO_DEFINE_SIMPLE_UNARY_OPERATION(real_operation, real(operand));
669GKO_DEFINE_SIMPLE_UNARY_OPERATION(imag_operation, imag(operand));
670GKO_DEFINE_SIMPLE_UNARY_OPERATION(conj_operation, conj(operand));
671GKO_DEFINE_SIMPLE_UNARY_OPERATION(squared_norm_operation,
672 squared_norm(operand));
673
674} // namespace detail
675} // namespace accessor
676
677
678// unary arithmetic
679GKO_ENABLE_UNARY_RANGE_OPERATION(unary_plus, operator+,
680 accessor::detail::unary_plus);
681GKO_ENABLE_UNARY_RANGE_OPERATION(unary_minus, operator-,
682 accessor::detail::unary_minus);
683
684// unary logical
685GKO_ENABLE_UNARY_RANGE_OPERATION(logical_not, operator!,
686 accessor::detail::logical_not);
687
688// unary bitwise
689GKO_ENABLE_UNARY_RANGE_OPERATION(bitwise_not, operator~,
690 accessor::detail::bitwise_not);
691
692// common unary functions
693
694GKO_ENABLE_UNARY_RANGE_OPERATION(zero_operation, zero,
695 accessor::detail::zero_operation);
696GKO_ENABLE_UNARY_RANGE_OPERATION(one_operation, one,
697 accessor::detail::one_operation);
698GKO_ENABLE_UNARY_RANGE_OPERATION(abs_operation, abs,
699 accessor::detail::abs_operation);
700GKO_ENABLE_UNARY_RANGE_OPERATION(real_operation, real,
701 accessor::detail::real_operation);
702GKO_ENABLE_UNARY_RANGE_OPERATION(imag_operation, imag,
703 accessor::detail::imag_operation);
704GKO_ENABLE_UNARY_RANGE_OPERATION(conj_operation, conj,
705 accessor::detail::conj_operation);
706GKO_ENABLE_UNARY_RANGE_OPERATION(squared_norm_operation, squared_norm,
707 accessor::detail::squared_norm_operation);
708
709GKO_DEPRECATED_UNARY_RANGE_OPERATION(one_operaton, one_operation);
710GKO_DEPRECATED_UNARY_RANGE_OPERATION(abs_operaton, abs_operation);
711GKO_DEPRECATED_UNARY_RANGE_OPERATION(real_operaton, real_operation);
712GKO_DEPRECATED_UNARY_RANGE_OPERATION(imag_operaton, imag_operation);
713GKO_DEPRECATED_UNARY_RANGE_OPERATION(conj_operaton, conj_operation);
714GKO_DEPRECATED_UNARY_RANGE_OPERATION(squared_norm_operaton,
716
717namespace accessor {
718
719
720template <typename Accessor>
722 using accessor = Accessor;
723 static constexpr size_type dimensionality = accessor::dimensionality;
724
725 GKO_ATTRIBUTES constexpr explicit transpose_operation(
726 const Accessor& operand)
727 : operand{operand}
728 {}
729
730 template <typename FirstDimensionType, typename SecondDimensionType,
731 typename... DimensionTypes>
732 GKO_ATTRIBUTES constexpr auto operator()(
733 const FirstDimensionType& first_dim,
734 const SecondDimensionType& second_dim,
735 const DimensionTypes&... dims) const
736 -> decltype(std::declval<accessor>()(second_dim, first_dim, dims...))
737 {
738 return operand(second_dim, first_dim, dims...);
739 }
740
741 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
742 {
743 return dimension < 2 ? operand.length(dimension ^ 1)
744 : operand.length(dimension);
745 }
746
747 template <typename OtherAccessor>
748 GKO_ATTRIBUTES void copy_from(const OtherAccessor& other) const = delete;
749
750 const accessor operand;
751};
752
753
754} // namespace accessor
755
756
757GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR(transpose_operation, transpose);
758
759
760#undef GKO_DEPRECATED_UNARY_RANGE_OPERATION
761#undef GKO_DEFINE_SIMPLE_UNARY_OPERATION
762#undef GKO_ENABLE_UNARY_RANGE_OPERATION
763
764
/**
 * Defines the accessor type for a binary range operation and binds it to a
 * function or operator.
 *
 * Inside namespace `accessor` it defines
 * `_operation_name<Kind, FirstOperand, SecondOperand>`, which derives from
 * `::gko::detail::implement_binary_operation<Kind, FirstOperand,
 * SecondOperand, ::gko::_operator>` and inherits its constructors. It then
 * expands GKO_BIND_RANGE_OPERATION_TO_OPERATOR to create the overloads of
 * `_operator_name`. The trailing `static_assert(true, ...)` lets an
 * invocation be terminated with a semicolon.
 *
 * @param _operation_name  name of the accessor struct template to define
 * @param _operator_name  name of the function/operator to bind it to
 * @param _operator  operation type, named relative to `::gko`
 */
765#define GKO_ENABLE_BINARY_RANGE_OPERATION(_operation_name, _operator_name, \
766 _operator) \
767 namespace accessor { \
768 template <::gko::detail::operation_kind Kind, typename FirstOperand, \
769 typename SecondOperand> \
770 struct _operation_name \
771 : ::gko::detail::implement_binary_operation< \
772 Kind, FirstOperand, SecondOperand, ::gko::_operator> { \
773 using ::gko::detail::implement_binary_operation< \
774 Kind, FirstOperand, SecondOperand, \
775 ::gko::_operator>::implement_binary_operation; \
776 }; \
777 } \
778 GKO_BIND_RANGE_OPERATION_TO_OPERATOR(_operation_name, _operator_name); \
779 static_assert(true, \
780 "This assert is used to counter the false positive extra " \
781 "semi-colon warnings")
782
783
/**
 * Defines four overloads of the function template `_operator_name`, each
 * returning a `range` over `accessor::_operation_name<kind, ...>`:
 *  1. two ranges with the same accessor type (range_by_range),
 *  2. two ranges with different accessor types (range_by_range),
 *  3. a range and an arbitrary second operand (range_by_scalar),
 *  4. an arbitrary first operand and a range (scalar_by_range).
 *
 * Range operands are passed on as their accessors (`get_accessor()`), other
 * operands are passed on as they are. The trailing
 * `static_assert(true, ...)` lets an invocation be terminated with a
 * semicolon.
 *
 * @param _operation_name  accessor struct template in namespace `accessor`
 * @param _operator_name  name of the function/operator to define
 */
784#define GKO_BIND_RANGE_OPERATION_TO_OPERATOR(_operation_name, _operator_name) \
785 template <typename Accessor> \
786 GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name< \
787 ::gko::detail::operation_kind::range_by_range, Accessor, Accessor>> \
788 _operator_name(const range<Accessor>& first, \
789 const range<Accessor>& second) \
790 { \
791 return range<accessor::_operation_name< \
792 ::gko::detail::operation_kind::range_by_range, Accessor, \
793 Accessor>>(first.get_accessor(), second.get_accessor()); \
794 } \
795 \
796 template <typename FirstAccessor, typename SecondAccessor> \
797 GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name< \
798 ::gko::detail::operation_kind::range_by_range, FirstAccessor, \
799 SecondAccessor>> \
800 _operator_name(const range<FirstAccessor>& first, \
801 const range<SecondAccessor>& second) \
802 { \
803 return range<accessor::_operation_name< \
804 ::gko::detail::operation_kind::range_by_range, FirstAccessor, \
805 SecondAccessor>>(first.get_accessor(), second.get_accessor()); \
806 } \
807 \
808 template <typename FirstAccessor, typename SecondOperand> \
809 GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name< \
810 ::gko::detail::operation_kind::range_by_scalar, FirstAccessor, \
811 SecondOperand>> \
812 _operator_name(const range<FirstAccessor>& first, \
813 const SecondOperand& second) \
814 { \
815 return range<accessor::_operation_name< \
816 ::gko::detail::operation_kind::range_by_scalar, FirstAccessor, \
817 SecondOperand>>(first.get_accessor(), second); \
818 } \
819 \
820 template <typename FirstOperand, typename SecondAccessor> \
821 GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name< \
822 ::gko::detail::operation_kind::scalar_by_range, FirstOperand, \
823 SecondAccessor>> \
824 _operator_name(const FirstOperand& first, \
825 const range<SecondAccessor>& second) \
826 { \
827 return range<accessor::_operation_name< \
828 ::gko::detail::operation_kind::scalar_by_range, FirstOperand, \
829 SecondAccessor>>(first, second.get_accessor()); \
830 } \
831 static_assert(true, \
832 "This assert is used to counter the false positive extra " \
833 "semi-colon warnings")
834
835
/**
 * Declares a struct `_deprecated_name` that derives from `_name`, has an
 * empty body and is annotated with GKO_DEPRECATED("Please use <_name>").
 *
 * The macro does not emit a trailing semicolon; the invocation supplies it.
 *
 * @param _deprecated_name  name of the deprecated struct
 * @param _name  name of the struct to use instead
 */
836#define GKO_DEPRECATED_SIMPLE_BINARY_OPERATION(_deprecated_name, _name) \
837 struct GKO_DEPRECATED("Please use " #_name) _deprecated_name : _name {}
838
/**
 * Defines a struct `_name` describing an element-wise binary operation.
 *
 * The variadic arguments form an expression written in terms of two
 * variables named `first` and `second`. The struct gets a private static
 * `simple_evaluate_impl(first, second)` returning that expression, and three
 * public static entry points that feed it:
 *  - `evaluate_range_by_range`: `first(dims...)` and `second(dims...)`,
 *  - `evaluate_scalar_by_range`: `first` and `second(dims...)`,
 *  - `evaluate_range_by_scalar`: `first(dims...)` and `second`.
 *
 * The macro does not emit a trailing semicolon; the invocation supplies it.
 *
 * @param _name  name of the struct to define
 * @param ...  the expression to evaluate, using the identifiers `first` and
 *             `second`
 */
839#define GKO_DEFINE_SIMPLE_BINARY_OPERATION(_name, ...) \
840 struct _name { \
841 private: \
842 template <typename FirstOperand, typename SecondOperand> \
843 GKO_ATTRIBUTES constexpr static auto simple_evaluate_impl( \
844 const FirstOperand& first, const SecondOperand& second) \
845 -> decltype(__VA_ARGS__) \
846 { \
847 return __VA_ARGS__; \
848 } \
849 \
850 public: \
851 template <typename FirstAccessor, typename SecondAccessor, \
852 typename... DimensionTypes> \
853 GKO_ATTRIBUTES static constexpr auto evaluate_range_by_range( \
854 const FirstAccessor& first, const SecondAccessor& second, \
855 const DimensionTypes&... dims) \
856 -> decltype(simple_evaluate_impl(first(dims...), second(dims...))) \
857 { \
858 return simple_evaluate_impl(first(dims...), second(dims...)); \
859 } \
860 \
861 template <typename FirstOperand, typename SecondAccessor, \
862 typename... DimensionTypes> \
863 GKO_ATTRIBUTES static constexpr auto evaluate_scalar_by_range( \
864 const FirstOperand& first, const SecondAccessor& second, \
865 const DimensionTypes&... dims) \
866 -> decltype(simple_evaluate_impl(first, second(dims...))) \
867 { \
868 return simple_evaluate_impl(first, second(dims...)); \
869 } \
870 \
871 template <typename FirstAccessor, typename SecondOperand, \
872 typename... DimensionTypes> \
873 GKO_ATTRIBUTES static constexpr auto evaluate_range_by_scalar( \
874 const FirstAccessor& first, const SecondOperand& second, \
875 const DimensionTypes&... dims) \
876 -> decltype(simple_evaluate_impl(first(dims...), second)) \
877 { \
878 return simple_evaluate_impl(first(dims...), second); \
879 } \
880 }
881
882
883namespace accessor {
884namespace detail {
885
886
887// binary arithmetic
888GKO_DEFINE_SIMPLE_BINARY_OPERATION(add, first + second);
889GKO_DEFINE_SIMPLE_BINARY_OPERATION(sub, first - second);
890GKO_DEFINE_SIMPLE_BINARY_OPERATION(mul, first* second);
891GKO_DEFINE_SIMPLE_BINARY_OPERATION(div, first / second);
892GKO_DEFINE_SIMPLE_BINARY_OPERATION(mod, first % second);
893
894// relational
895GKO_DEFINE_SIMPLE_BINARY_OPERATION(less, first < second);
896GKO_DEFINE_SIMPLE_BINARY_OPERATION(greater, first > second);
897GKO_DEFINE_SIMPLE_BINARY_OPERATION(less_or_equal, first <= second);
898GKO_DEFINE_SIMPLE_BINARY_OPERATION(greater_or_equal, first >= second);
899GKO_DEFINE_SIMPLE_BINARY_OPERATION(equal, first == second);
900GKO_DEFINE_SIMPLE_BINARY_OPERATION(not_equal, first != second);
901
902// binary logical
903GKO_DEFINE_SIMPLE_BINARY_OPERATION(logical_or, first || second);
904GKO_DEFINE_SIMPLE_BINARY_OPERATION(logical_and, first&& second);
905
906// binary bitwise
907GKO_DEFINE_SIMPLE_BINARY_OPERATION(bitwise_or, first | second);
908GKO_DEFINE_SIMPLE_BINARY_OPERATION(bitwise_and, first& second);
909GKO_DEFINE_SIMPLE_BINARY_OPERATION(bitwise_xor, first ^ second);
910GKO_DEFINE_SIMPLE_BINARY_OPERATION(left_shift, first << second);
911GKO_DEFINE_SIMPLE_BINARY_OPERATION(right_shift, first >> second);
912
913// common binary functions
914GKO_DEFINE_SIMPLE_BINARY_OPERATION(max_operation, max(first, second));
915GKO_DEFINE_SIMPLE_BINARY_OPERATION(min_operation, min(first, second));
916
917GKO_DEPRECATED_SIMPLE_BINARY_OPERATION(max_operaton, max_operation);
918GKO_DEPRECATED_SIMPLE_BINARY_OPERATION(min_operaton, min_operation);
919} // namespace detail
920} // namespace accessor
921
922
923// binary arithmetic
924GKO_ENABLE_BINARY_RANGE_OPERATION(add, operator+, accessor::detail::add);
925GKO_ENABLE_BINARY_RANGE_OPERATION(sub, operator-, accessor::detail::sub);
926GKO_ENABLE_BINARY_RANGE_OPERATION(mul, operator*, accessor::detail::mul);
927GKO_ENABLE_BINARY_RANGE_OPERATION(div, operator/, accessor::detail::div);
928GKO_ENABLE_BINARY_RANGE_OPERATION(mod, operator%, accessor::detail::mod);
929
930// relational
931GKO_ENABLE_BINARY_RANGE_OPERATION(less, operator<, accessor::detail::less);
932GKO_ENABLE_BINARY_RANGE_OPERATION(greater, operator>,
933 accessor::detail::greater);
934GKO_ENABLE_BINARY_RANGE_OPERATION(less_or_equal, operator<=,
935 accessor::detail::less_or_equal);
936GKO_ENABLE_BINARY_RANGE_OPERATION(greater_or_equal, operator>=,
937 accessor::detail::greater_or_equal);
938GKO_ENABLE_BINARY_RANGE_OPERATION(equal, operator==, accessor::detail::equal);
939GKO_ENABLE_BINARY_RANGE_OPERATION(not_equal, operator!=,
940 accessor::detail::not_equal);
941
942// binary logical
943GKO_ENABLE_BINARY_RANGE_OPERATION(logical_or, operator||,
944 accessor::detail::logical_or);
945GKO_ENABLE_BINARY_RANGE_OPERATION(logical_and, operator&&,
946 accessor::detail::logical_and);
947
948// binary bitwise
949GKO_ENABLE_BINARY_RANGE_OPERATION(bitwise_or, operator|,
950 accessor::detail::bitwise_or);
951GKO_ENABLE_BINARY_RANGE_OPERATION(bitwise_and, operator&,
952 accessor::detail::bitwise_and);
953GKO_ENABLE_BINARY_RANGE_OPERATION(bitwise_xor, operator^,
954 accessor::detail::bitwise_xor);
955GKO_ENABLE_BINARY_RANGE_OPERATION(left_shift, operator<<,
956 accessor::detail::left_shift);
957GKO_ENABLE_BINARY_RANGE_OPERATION(right_shift, operator>>,
958 accessor::detail::right_shift);
959
960// common binary functions
961GKO_ENABLE_BINARY_RANGE_OPERATION(max_operation, max,
962 accessor::detail::max_operation);
963GKO_ENABLE_BINARY_RANGE_OPERATION(min_operation, min,
964 accessor::detail::min_operation);
965
966
967// special binary range functions
968namespace accessor {
969
970
971template <gko::detail::operation_kind Kind, typename FirstAccessor,
972 typename SecondAccessor>
974 static_assert(Kind == gko::detail::operation_kind::range_by_range,
975 "Matrix multiplication expects both operands to be ranges");
976 using first_accessor = FirstAccessor;
977 using second_accessor = SecondAccessor;
978 static_assert(first_accessor::dimensionality ==
979 second_accessor::dimensionality,
980 "Both ranges need to have the same number of dimensions");
981 static constexpr size_type dimensionality = first_accessor::dimensionality;
982
983 GKO_ATTRIBUTES explicit mmul_operation(const FirstAccessor& first,
984 const SecondAccessor& second)
985 : first{first}, second{second}
986 {
987 GKO_ASSERT(first.length(1) == second.length(0));
988 GKO_ASSERT(gko::detail::equal_dimensions<2>(first, second));
989 }
990
991 template <typename FirstDimension, typename SecondDimension,
992 typename... DimensionTypes>
993 GKO_ATTRIBUTES auto operator()(const FirstDimension& row,
994 const SecondDimension& col,
995 const DimensionTypes&... rest) const
996 -> decltype(std::declval<FirstAccessor>()(row, 0, rest...) *
997 std::declval<SecondAccessor>()(0, col, rest...) +
998 std::declval<FirstAccessor>()(row, 1, rest...) *
999 std::declval<SecondAccessor>()(1, col, rest...))
1000 {
1001 using result_type =
1002 decltype(first(row, 0, rest...) * second(0, col, rest...) +
1003 first(row, 1, rest...) * second(1, col, rest...));
1004 GKO_ASSERT(first.length(1) == second.length(0));
1005 auto result = zero<result_type>();
1006 const auto size = first.length(1);
1007 for (auto i = zero(size); i < size; ++i) {
1008 result += first(row, i, rest...) * second(i, col, rest...);
1009 }
1010 return result;
1011 }
1012
1013 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
1014 {
1015 return dimension == 1 ? second.length(1) : first.length(dimension);
1016 }
1017
1018 template <typename OtherAccessor>
1019 GKO_ATTRIBUTES void copy_from(const OtherAccessor& other) const = delete;
1020
1021 const first_accessor first;
1022 const second_accessor second;
1023};
1024
1025
1026} // namespace accessor
1027
1028
1029GKO_BIND_RANGE_OPERATION_TO_OPERATOR(mmul_operation, mmul);
1030
1031
1032#undef GKO_DEFINE_SIMPLE_BINARY_OPERATION
1033#undef GKO_ENABLE_BINARY_RANGE_OPERATION
1034
1035
1036} // namespace gko
1037
1038
1039#endif // GKO_PUBLIC_CORE_BASE_RANGE_HPP_
A range is a multidimensional view of the memory.
Definition range.hpp:297
Accessor accessor
The type of the underlying accessor.
Definition range.hpp:302
constexpr auto operator()(DimensionTypes &&... dimensions) const -> decltype(std::declval< accessor >()(std::forward< DimensionTypes >(dimensions)...))
Returns a value (or a sub-range) with the specified indexes.
Definition range.hpp:345
constexpr size_type length(size_type dimension) const
Returns the length of the specified dimension of the range.
Definition range.hpp:400
constexpr const accessor * operator->() const noexcept
Returns a pointer to the accessor.
Definition range.hpp:412
static constexpr size_type dimensionality
The number of dimensions of the range.
Definition range.hpp:307
const range & operator=(const range &other) const
Assigns another range to this range.
Definition range.hpp:384
constexpr const accessor & get_accessor() const noexcept
Returns a reference to the accessor.
Definition range.hpp:422
~range()=default
Use the default destructor.
const range & operator=(const range< OtherAccessor > &other) const
Definition range.hpp:363
constexpr range(AccessorParams &&... params)
Creates a new range.
Definition range.hpp:328
The Ginkgo namespace.
Definition abstract_factory.hpp:20
constexpr T one()
Returns the multiplicative identity for T.
Definition math.hpp:630
constexpr std::enable_if_t<!is_complex_s< T >::value, T > abs(const T &x)
Returns the absolute value of the object.
Definition math.hpp:931
constexpr T zero()
Returns the additive identity for T.
Definition math.hpp:602
constexpr auto imag(const T &x)
Returns the imaginary part of the object.
Definition math.hpp:885
std::size_t size_type
Integral type used for allocation quantities.
Definition types.hpp:89
constexpr T min(const T &x, const T &y)
Returns the smaller of the arguments.
Definition math.hpp:719
batch_dim< 2, DimensionType > transpose(const batch_dim< 2, DimensionType > &input)
Returns a batch_dim object with its dimensions swapped for batched operators.
Definition batch_dim.hpp:119
constexpr auto squared_norm(const T &x) -> decltype(real(conj(x) *x))
Returns the squared norm of the object.
Definition math.hpp:913
constexpr auto conj(const T &x)
Returns the conjugate of an object.
Definition math.hpp:899
constexpr T max(const T &x, const T &y)
Returns the larger of the arguments.
Definition math.hpp:701
constexpr auto real(const T &x)
Returns the real part of the object.
Definition math.hpp:869
Definition range.hpp:699
Definition range.hpp:710
Definition range.hpp:924
Definition range.hpp:952
Definition range.hpp:690
Definition range.hpp:950
Definition range.hpp:954
Definition range.hpp:705
Definition range.hpp:713
Definition range.hpp:927
Definition range.hpp:938
Definition range.hpp:937
Definition range.hpp:933
Definition range.hpp:703
Definition range.hpp:712
Definition range.hpp:956
Definition range.hpp:935
Definition range.hpp:931
Definition range.hpp:946
Definition range.hpp:686
Definition range.hpp:944
Definition range.hpp:962
Definition range.hpp:964
Definition range.hpp:973
Definition range.hpp:928
Definition range.hpp:926
Definition range.hpp:940
Definition range.hpp:697
Definition range.hpp:709
Definition range.hpp:701
Definition range.hpp:711
Definition range.hpp:958
Definition range.hpp:715
Definition range.hpp:925
Definition range.hpp:721
Definition range.hpp:682
Definition range.hpp:680
Definition range.hpp:695
A span is a lightweight structure used to create sub-ranges from other ranges.
Definition range.hpp:46
constexpr span(size_type begin, size_type end) noexcept
Creates a span.
Definition range.hpp:64
constexpr span(size_type point) noexcept
Creates a span representing the single point `point`.
Definition range.hpp:54
constexpr bool is_valid() const
Checks if a span is valid.
Definition range.hpp:73
constexpr size_type length() const
Returns the length of a span.
Definition range.hpp:80
const size_type begin
Beginning of the span.
Definition range.hpp:85
const size_type end
End of the span.
Definition range.hpp:90