diff options
author | ali-beep <54114435+ali-beep@users.noreply.github.com> | 2019-08-15 13:41:12 -0400 |
---|---|---|
committer | Wenzel Jakob <wenzel.jakob@epfl.ch> | 2019-08-15 19:41:11 +0200 |
commit | 5ef13eb680069680c41e89265d4f1105bd501846 (patch) | |
tree | 7e78747fd9032b25110e1a3f02cd653f1fdc1963 | |
parent | b2fdfd122827c1170c75702c05a4040997cf3bf5 (diff) | |
download | platform_external_python_pybind11-5ef13eb680069680c41e89265d4f1105bd501846.tar.gz platform_external_python_pybind11-5ef13eb680069680c41e89265d4f1105bd501846.tar.bz2 platform_external_python_pybind11-5ef13eb680069680c41e89265d4f1105bd501846.zip |
Add negative indexing support to stl_bind. (#1882)
-rw-r--r-- | include/pybind11/stl_bind.h | 65 | ||||
-rw-r--r-- | tests/test_stl_binders.py | 10 |
2 files changed, 52 insertions, 23 deletions
diff --git a/include/pybind11/stl_bind.h b/include/pybind11/stl_bind.h index 1f87252..d3adaed 100644 --- a/include/pybind11/stl_bind.h +++ b/include/pybind11/stl_bind.h @@ -115,6 +115,14 @@ void vector_modifiers(enable_if_t<is_copy_constructible<typename Vector::value_t using SizeType = typename Vector::size_type; using DiffType = typename Vector::difference_type; + auto wrap_i = [](DiffType i, SizeType n) { + if (i < 0) + i += n; + if (i < 0 || (SizeType)i >= n) + throw index_error(); + return i; + }; + cl.def("append", [](Vector &v, const T &value) { v.push_back(value); }, arg("x"), @@ -159,10 +167,13 @@ void vector_modifiers(enable_if_t<is_copy_constructible<typename Vector::value_t ); cl.def("insert", - [](Vector &v, SizeType i, const T &x) { - if (i > v.size()) + [](Vector &v, DiffType i, const T &x) { + // Can't use wrap_i; i == v.size() is OK + if (i < 0) + i += v.size(); + if (i < 0 || (SizeType)i > v.size()) throw index_error(); - v.insert(v.begin() + (DiffType) i, x); + v.insert(v.begin() + i, x); }, arg("i") , arg("x"), "Insert an item at a given position." @@ -180,11 +191,10 @@ void vector_modifiers(enable_if_t<is_copy_constructible<typename Vector::value_t ); cl.def("pop", - [](Vector &v, SizeType i) { - if (i >= v.size()) - throw index_error(); - T t = v[i]; - v.erase(v.begin() + (DiffType) i); + [wrap_i](Vector &v, DiffType i) { + i = wrap_i(i, v.size()); + T t = v[(SizeType) i]; + v.erase(v.begin() + i); return t; }, arg("i"), @@ -192,10 +202,9 @@ void vector_modifiers(enable_if_t<is_copy_constructible<typename Vector::value_t ); cl.def("__setitem__", - [](Vector &v, SizeType i, const T &t) { - if (i >= v.size()) - throw index_error(); - v[i] = t; + [wrap_i](Vector &v, DiffType i, const T &t) { + i = wrap_i(i, v.size()); + v[(SizeType)i] = t; } ); @@ -238,10 +247,9 @@ void vector_modifiers(enable_if_t<is_copy_constructible<typename Vector::value_t ); cl.def("__delitem__", - [](Vector &v, SizeType i) { - if (i >= v.size()) - throw index_error(); - v.erase(v.begin() + DiffType(i)); + [wrap_i](Vector &v, DiffType i) { + i = wrap_i(i, v.size()); + v.erase(v.begin() + i); }, "Delete the list elements at index ``i``" ); @@ -277,13 +285,21 @@ template <typename Vector, typename Class_> void vector_accessor(enable_if_t<!vector_needs_copy<Vector>::value, Class_> &cl) { using T = typename Vector::value_type; using SizeType = typename Vector::size_type; + using DiffType = typename Vector::difference_type; using ItType = typename Vector::iterator; + auto wrap_i = [](DiffType i, SizeType n) { + if (i < 0) + i += n; + if (i < 0 || (SizeType)i >= n) + throw index_error(); + return i; + }; + cl.def("__getitem__", - [](Vector &v, SizeType i) -> T & { - if (i >= v.size()) - throw index_error(); - return v[i]; + [wrap_i](Vector &v, DiffType i) -> T & { + i = wrap_i(i, v.size()); + return v[(SizeType)i]; }, return_value_policy::reference_internal // ref + keepalive ); @@ -303,12 +319,15 @@ template <typename Vector, typename Class_> void vector_accessor(enable_if_t<vector_needs_copy<Vector>::value, Class_> &cl) { using T = typename Vector::value_type; using SizeType = typename Vector::size_type; + using DiffType = typename Vector::difference_type; using ItType = typename Vector::iterator; cl.def("__getitem__", - [](const Vector &v, SizeType i) -> T { - if (i >= v.size()) + [](const Vector &v, DiffType i) -> T { + if (i < 0 && (i += v.size()) < 0) + throw index_error(); + if ((SizeType)i >= v.size()) throw index_error(); - return v[i]; + return v[(SizeType)i]; } ); diff --git a/tests/test_stl_binders.py b/tests/test_stl_binders.py index 52c8ac0..6d5a159 100644 --- a/tests/test_stl_binders.py +++ b/tests/test_stl_binders.py @@ -53,6 +53,16 @@ def test_vector_int(): v_int2.extend(x for x in range(5)) assert v_int2 == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4]) + # test negative indexing + assert v_int2[-1] == 4 + + # insert with negative index + v_int2.insert(-1, 88) + assert v_int2 == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 88, 4]) + + # delete negative index + del v_int2[-1] + assert v_int2 == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 88]) # related to the PyPy's buffer protocol. @pytest.unsupported_on_pypy |