aboutsummaryrefslogtreecommitdiffstats
path: root/include/pybind11/eigen.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/pybind11/eigen.h')
-rw-r--r--include/pybind11/eigen.h261
1 files changed, 261 insertions, 0 deletions
diff --git a/include/pybind11/eigen.h b/include/pybind11/eigen.h
new file mode 100644
index 0000000..f2f0985
--- /dev/null
+++ b/include/pybind11/eigen.h
@@ -0,0 +1,261 @@
+/*
+ pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices
+
+ Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "numpy.h"
+#include <Eigen/Core>
+#include <Eigen/SparseCore>
+
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
+#endif
+
+NAMESPACE_BEGIN(pybind11)
+NAMESPACE_BEGIN(detail)
+
+template <typename T> class is_eigen_dense {
+private:
+ template<typename Derived> static std::true_type test(const Eigen::DenseBase<Derived> &);
+ static std::false_type test(...);
+public:
+ static constexpr bool value = decltype(test(std::declval<T>()))::value;
+};
+
+template <typename T> class is_eigen_sparse {
+private:
+ template<typename Derived> static std::true_type test(const Eigen::SparseMatrixBase<Derived> &);
+ static std::false_type test(...);
+public:
+ static constexpr bool value = decltype(test(std::declval<T>()))::value;
+};
+
+template<typename Type>
+struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value>::type> {
+ typedef typename Type::Scalar Scalar;
+ static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit;
+
+ bool load(handle src, bool) {
+ array_t<Scalar> buffer(src, true);
+ if (!buffer.check())
+ return false;
+
+ buffer_info info = buffer.request();
+ if (info.ndim == 1) {
+ typedef Eigen::Stride<Eigen::Dynamic, 0> Strides;
+ if (!Type::IsVectorAtCompileTime &&
+ !(Type::RowsAtCompileTime == Eigen::Dynamic &&
+ Type::ColsAtCompileTime == Eigen::Dynamic))
+ return false;
+
+ if (Type::SizeAtCompileTime != Eigen::Dynamic &&
+ info.shape[0] != (size_t) Type::SizeAtCompileTime)
+ return false;
+
+ auto strides = Strides(info.strides[0] / sizeof(Scalar), 0);
+
+ value = Eigen::Map<Type, 0, Strides>(
+ (Scalar *) info.ptr, info.shape[0], 1, strides);
+ } else if (info.ndim == 2) {
+ typedef Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic> Strides;
+
+ if ((Type::RowsAtCompileTime != Eigen::Dynamic && info.shape[0] != (size_t) Type::RowsAtCompileTime) ||
+ (Type::ColsAtCompileTime != Eigen::Dynamic && info.shape[1] != (size_t) Type::ColsAtCompileTime))
+ return false;
+
+ auto strides = Strides(
+ info.strides[rowMajor ? 0 : 1] / sizeof(Scalar),
+ info.strides[rowMajor ? 1 : 0] / sizeof(Scalar));
+
+ value = Eigen::Map<Type, 0, Strides>(
+ (Scalar *) info.ptr, info.shape[0], info.shape[1], strides);
+ } else {
+ return false;
+ }
+ return true;
+ }
+
+ static handle cast(const Type *src, return_value_policy policy, handle parent) {
+ return cast(*src, policy, parent);
+ }
+
+ static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
+ array result(buffer_info(
+ /* Pointer to buffer */
+ const_cast<Scalar *>(src.data()),
+ /* Size of one scalar */
+ sizeof(Scalar),
+ /* Python struct-style format descriptor */
+ format_descriptor<Scalar>::value,
+ /* Number of dimensions */
+ 2,
+ /* Buffer dimensions */
+ { (size_t) src.rows(),
+ (size_t) src.cols() },
+ /* Strides (in bytes) for each index */
+ { sizeof(Scalar) * (rowMajor ? src.cols() : 1),
+ sizeof(Scalar) * (rowMajor ? 1 : src.rows()) }
+ ));
+ return result.release();
+ }
+
+ template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
+
+ static PYBIND11_DESCR name() {
+ return _("numpy.ndarray[dtype=") + npy_format_descriptor<Scalar>::name() +
+ _(", shape=(") + rows() + _(", ") + cols() + _(")]");
+ }
+
+ operator Type*() { return &value; }
+ operator Type&() { return value; }
+
+private:
+ template <typename T = Type, typename std::enable_if<T::RowsAtCompileTime == Eigen::Dynamic, int>::type = 0>
+ static PYBIND11_DESCR rows() { return _("m"); }
+ template <typename T = Type, typename std::enable_if<T::RowsAtCompileTime != Eigen::Dynamic, int>::type = 0>
+ static PYBIND11_DESCR rows() { return _<T::RowsAtCompileTime>(); }
+ template <typename T = Type, typename std::enable_if<T::ColsAtCompileTime == Eigen::Dynamic, int>::type = 0>
+ static PYBIND11_DESCR cols() { return _("n"); }
+ template <typename T = Type, typename std::enable_if<T::ColsAtCompileTime != Eigen::Dynamic, int>::type = 0>
+ static PYBIND11_DESCR cols() { return _<T::ColsAtCompileTime>(); }
+
+private:
+ Type value;
+};
+
+template<typename Type>
+struct type_caster<Type, typename std::enable_if<is_eigen_sparse<Type>::value>::type> {
+ typedef typename Type::Scalar Scalar;
+ typedef typename std::remove_reference<decltype(*std::declval<Type>().outerIndexPtr())>::type StorageIndex;
+ typedef typename Type::Index Index;
+ static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit;
+
+ bool load(handle src, bool) {
+ object obj(src, true);
+ object sparse_module = module::import("scipy.sparse");
+ object matrix_type = sparse_module.attr(
+ rowMajor ? "csr_matrix" : "csc_matrix");
+
+ if (obj.get_type() != matrix_type.ptr()) {
+ try {
+ obj = matrix_type.call(obj);
+ } catch (const error_already_set &) {
+ PyErr_Clear();
+ return false;
+ }
+ }
+
+ auto valuesArray = array_t<Scalar>((object) obj.attr("data"));
+ auto innerIndicesArray = array_t<StorageIndex>((object) obj.attr("indices"));
+ auto outerIndicesArray = array_t<StorageIndex>((object) obj.attr("indptr"));
+ auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
+ auto nnz = obj.attr("nnz").cast<Index>();
+
+ if (!valuesArray.check() || !innerIndicesArray.check() ||
+ !outerIndicesArray.check())
+ return false;
+
+ buffer_info outerIndices = outerIndicesArray.request();
+ buffer_info innerIndices = innerIndicesArray.request();
+ buffer_info values = valuesArray.request();
+
+ value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>(
+ shape[0].cast<Index>(),
+ shape[1].cast<Index>(),
+ nnz,
+ static_cast<StorageIndex *>(outerIndices.ptr),
+ static_cast<StorageIndex *>(innerIndices.ptr),
+ static_cast<Scalar *>(values.ptr)
+ );
+
+ return true;
+ }
+
+ static handle cast(const Type *src, return_value_policy policy, handle parent) {
+ return cast(*src, policy, parent);
+ }
+
+ static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
+ const_cast<Type&>(src).makeCompressed();
+
+ object matrix_type = module::import("scipy.sparse").attr(
+ rowMajor ? "csr_matrix" : "csc_matrix");
+
+ array data(buffer_info(
+ // Pointer to buffer
+ const_cast<Scalar *>(src.valuePtr()),
+ // Size of one scalar
+ sizeof(Scalar),
+ // Python struct-style format descriptor
+ format_descriptor<Scalar>::value,
+ // Number of dimensions
+ 1,
+ // Buffer dimensions
+ { (size_t) src.nonZeros() },
+ // Strides
+ { sizeof(Scalar) }
+ ));
+
+ array outerIndices(buffer_info(
+ // Pointer to buffer
+ const_cast<StorageIndex *>(src.outerIndexPtr()),
+ // Size of one scalar
+ sizeof(StorageIndex),
+ // Python struct-style format descriptor
+ format_descriptor<StorageIndex>::value,
+ // Number of dimensions
+ 1,
+ // Buffer dimensions
+ { (size_t) (rowMajor ? src.rows() : src.cols()) + 1 },
+ // Strides
+ { sizeof(StorageIndex) }
+ ));
+
+ array innerIndices(buffer_info(
+ // Pointer to buffer
+ const_cast<StorageIndex *>(src.innerIndexPtr()),
+ // Size of one scalar
+ sizeof(StorageIndex),
+ // Python struct-style format descriptor
+ format_descriptor<StorageIndex>::value,
+ // Number of dimensions
+ 1,
+ // Buffer dimensions
+ { (size_t) src.nonZeros() },
+ // Strides
+ { sizeof(StorageIndex) }
+ ));
+
+ return matrix_type.call(
+ std::make_tuple(data, innerIndices, outerIndices),
+ std::make_pair(src.rows(), src.cols())
+ ).release();
+ }
+
+ template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
+
+ template <typename T = Type, typename std::enable_if<(T::Flags & Eigen::RowMajorBit) != 0, int>::type = 0>
+ static PYBIND11_DESCR name() { return _("scipy.sparse.csr_matrix[dtype=") + npy_format_descriptor<Scalar>::name() + _("]"); }
+ template <typename T = Type, typename std::enable_if<(T::Flags & Eigen::RowMajorBit) == 0, int>::type = 0>
+ static PYBIND11_DESCR name() { return _("scipy.sparse.csc_matrix[dtype=") + npy_format_descriptor<Scalar>::name() + _("]"); }
+
+ operator Type*() { return &value; }
+ operator Type&() { return value; }
+
+private:
+ Type value;
+};
+
+NAMESPACE_END(detail)
+NAMESPACE_END(pybind11)
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif