diff --git a/include/Accessor.hpp b/include/Accessor.hpp index f9b4d8ed3..52b1dd47a 100644 --- a/include/Accessor.hpp +++ b/include/Accessor.hpp @@ -84,6 +84,9 @@ namespace cytnx { Accessor(const Accessor &rhs); // copy assignment: Accessor &operator=(const Accessor &rhs); + + // check equality + bool operator==(const Accessor &rhs) const; ///@endcond int type() const { return this->_type; } diff --git a/include/Tensor.hpp b/include/Tensor.hpp index dd68a021f..0f9cfba1b 100644 --- a/include/Tensor.hpp +++ b/include/Tensor.hpp @@ -1006,10 +1006,13 @@ namespace cytnx { /** @brief get elements using Accessor (C++ API) / slices (python API) @param[in] accessors the Accessor (C++ API) / slices (python API) to get the elements. + @param[out] removed an ascending list of indices that were removed from the original shape of + the Tensor. @return [Tensor] @see \link cytnx::Accessor Accessor\endlink for cordinate with Accessor in C++ API. @note - 1. the return will be a new Tensor instance, which not share memory with the current Tensor. + The return will be a new Tensor instance, which does not share memory with the current + Tensor. ## Equivalently: One can also using more intruisive way to get the slice using [] operator. @@ -1024,9 +1027,10 @@ namespace cytnx { #### output> \verbinclude example/Tensor/get.py.out */ - Tensor get(const std::vector &accessors) const { + Tensor get(const std::vector &accessors, + std::vector &removed = *new std::vector()) const { Tensor out; - out._impl = this->_impl->get(accessors); + out._impl = this->_impl->get(accessors, removed); return out; } diff --git a/include/UniTensor.hpp b/include/UniTensor.hpp index 0ece48653..e6463fa2b 100644 --- a/include/UniTensor.hpp +++ b/include/UniTensor.hpp @@ -672,12 +672,12 @@ namespace cytnx { if (this->is_diag()) { cytnx_error_msg( in.shape() != this->_block.shape(), - "[ERROR][DenseUniTensor] put_block, the input tensor shape does not match.%s", "\n"); + "[ERROR][DenseUniTensor][put_block] the input tensor shape does not match.%s", "\n"); this->_block = in.clone(); } else { cytnx_error_msg( in.shape() != this->shape(), - "[ERROR][DenseUniTensor] put_block, the input tensor shape does not match.%s", "\n"); + "[ERROR][DenseUniTensor][put_block] the input tensor shape does not match.%s", "\n"); this->_block = in.clone(); } } @@ -699,12 +699,12 @@ namespace cytnx { if (this->is_diag()) { cytnx_error_msg( in.shape() != this->_block.shape(), - "[ERROR][DenseUniTensor] put_block, the input tensor shape does not match.%s", "\n"); + "[ERROR][DenseUniTensor][put_block] the input tensor shape does not match.%s", "\n"); this->_block = in; } else { cytnx_error_msg( in.shape() != this->shape(), - "[ERROR][DenseUniTensor] put_block, the input tensor shape does not match.%s", "\n"); + "[ERROR][DenseUniTensor][put_block] the input tensor shape does not match.%s", "\n"); this->_block = in; } } @@ -721,10 +721,57 @@ namespace cytnx { } // this will only work on non-symm tensor (DenseUniTensor) boost::intrusive_ptr get(const std::vector &accessors) { - boost::intrusive_ptr out(new DenseUniTensor()); - out->Init_by_Tensor(this->_block.get(accessors), false, 0); // wrapping around. - return out; + if (accessors.empty()) return this; + DenseUniTensor *out = this->clone_meta(); + std::vector removed; // bonds to be removed + if (this->_is_diag) { + if (accessors.size() == 2) { + if (accessors[0] == accessors[1]) { + std::vector acc_in(1, accessors[0]); + return this->get(acc_in); + } else { // convert to dense UniTensor + out->_block = this->_block; + out->to_dense_(); + return out->get(accessors); + } + } else { // for one accessor element, use this accessor for both bonds + cytnx_error_msg(accessors.size() > 1, + "[ERROR][DenseUniTensor][get] for diagonal UniTensors, only one or two " + "accessor elements are allowed.%s", + "\n"); + out->_block = this->_block.get(accessors, removed); + if (removed.empty()) { // change dimension of bonds + for (cytnx_int64 idx = out->_bonds.size() - 1; idx >= 0; idx--) { + out->_bonds[idx]._impl->_dim = out->_block.shape()[0]; + } + } else { // erase all bonds + out->_is_braket_form = false; + out->_is_diag = false; + out->_rowrank = 0; + out->_labels = std::vector(); + out->_bonds = std::vector(); + } + } + } else { // non-diagonal + out->_block = this->_block.get(accessors, removed); + for (cytnx_int64 idx = removed.size() - 1; idx >= 0; idx--) { + out->_labels.erase(out->_labels.begin() + removed[idx]); + out->_bonds.erase(out->_bonds.begin() + removed[idx]); + if (removed[idx] < this->_rowrank) out->_rowrank--; + } + // adapt dimensions on bonds + auto dims = out->_block.shape(); + for (cytnx_int64 idx = 0; idx < out->_bonds.size(); idx++) { + out->_bonds[idx]._impl->_dim = dims[idx]; + } + // update_braket + if (out->is_tag() && !out->_is_braket_form) { + out->_is_braket_form = out->_update_braket(); + } + } + return boost::intrusive_ptr(out); } + // this will only work on non-symm tensor (DenseUniTensor) void set(const std::vector &accessors, const Tensor &rhs) { this->_block.set(accessors, rhs); @@ -1680,7 +1727,7 @@ namespace cytnx { true, "[ERROR] cannot perform elementwise arithmetic '+' btwn Scalar and BlockUniTensor.\n %s " "\n", - "This operation will destroy block structure. [Suggest] using get/set_block(s) to do " + "This operation will destroy block structure. [Suggest] using get/put_block(s) to do " "operation on the block(s)."); } @@ -1693,7 +1740,7 @@ namespace cytnx { true, "[ERROR] cannot perform elementwise arithmetic '+' btwn Scalar and BlockUniTensor.\n %s " "\n", - "This operation will destroy block structure. [Suggest] using get/set_block(s) to do " + "This operation will destroy block structure. [Suggest] using get/put_block(s) to do " "operation on the block(s)."); } void lSub_(const Scalar &lhs) { @@ -1701,7 +1748,7 @@ namespace cytnx { true, "[ERROR] cannot perform elementwise arithmetic '+' btwn Scalar and BlockUniTensor.\n %s " "\n", - "This operation will destroy block structure. [Suggest] using get/set_block(s) to do " + "This operation will destroy block structure. [Suggest] using get/put_block(s) to do " "operation on the block(s)."); } @@ -1710,7 +1757,7 @@ namespace cytnx { true, "[ERROR] cannot perform elementwise arithmetic '+' btwn Scalar and BlockUniTensor.\n %s " "\n", - "This operation will destroy block structure. [Suggest] using get/set_block(s) to do " + "This operation will destroy block structure. [Suggest] using get/put_block(s) to do " "operation on the block(s)."); } void Div_(const Scalar &rhs); @@ -1719,7 +1766,7 @@ namespace cytnx { true, "[ERROR] cannot perform elementwise arithmetic '+' btwn Scalar and BlockUniTensor.\n %s " "\n", - "This operation will destroy block structure. [Suggest] using get/set_block(s) to do " + "This operation will destroy block structure. [Suggest] using get/put_block(s) to do " "operation on the block(s)."); } void from_(const boost::intrusive_ptr &rhs, const bool &force, @@ -2475,7 +2522,7 @@ namespace cytnx { "[ERROR] cannot perform elementwise arithmetic '+' btwn Scalar and " "BlockFermionicUniTensor.\n %s " "\n", - "This operation will destroy block structure. [Suggest] using get/set_block(s) to do " + "This operation will destroy block structure. [Suggest] using get/put_block(s) to do " "operation on the block(s)."); } @@ -2489,7 +2536,7 @@ namespace cytnx { "[ERROR] cannot perform elementwise arithmetic '+' btwn Scalar and " "BlockFermionicUniTensor.\n %s " "\n", - "This operation will destroy block structure. [Suggest] using get/set_block(s) to do " + "This operation will destroy block structure. [Suggest] using get/put_block(s) to do " "operation on the block(s)."); } void lSub_(const Scalar &lhs) { @@ -2498,7 +2545,7 @@ namespace cytnx { "[ERROR] cannot perform elementwise arithmetic '+' btwn Scalar and " "BlockFermionicUniTensor.\n %s " "\n", - "This operation will destroy block structure. [Suggest] using get/set_block(s) to do " + "This operation will destroy block structure. [Suggest] using get/put_block(s) to do " "operation on the block(s)."); } @@ -2508,7 +2555,7 @@ namespace cytnx { "[ERROR] cannot perform elementwise arithmetic '+' btwn Scalar and " "BlockFermionicUniTensor.\n %s " "\n", - "This operation will destroy block structure. [Suggest] using get/set_block(s) to do " + "This operation will destroy block structure. [Suggest] using get/put_block(s) to do " "operation on the block(s)."); } void Div_(const Scalar &rhs); @@ -2518,7 +2565,7 @@ namespace cytnx { "[ERROR] cannot perform elementwise arithmetic '+' btwn Scalar and " "BlockFermionicUniTensor.\n %s " "\n", - "This operation will destroy block structure. [Suggest] using get/set_block(s) to do " + "This operation will destroy block structure. [Suggest] using get/put_block(s) to do " "operation on the block(s)."); } void from_(const boost::intrusive_ptr &rhs, const bool &force); @@ -4247,11 +4294,49 @@ namespace cytnx { this->_impl->put_block_(in, new_qidx, force); in.permute_(new_order); } + + /** + @brief get elements using Accessor (C++ API) / slices (python API) + @param[in] accessors the Accessor (C++ API) / slices (python API) to get the elements. + @return [UniTensor] + @see Tensor::get, UniTensor::operator[] + @note + 1. the return will be a new UniTensor instance, which does not share memory with the current + UniTensor. + + 2. Equivalently, one can also use the [] operator to access elements. + */ UniTensor get(const std::vector &accessors) const { UniTensor out; out._impl = this->_impl->get(accessors); return out; } + + /** + @brief get elements using Accessor (C++ API) / slices (python API) + @see get() + */ + UniTensor operator[](const std::vector &accessors) const { + UniTensor out; + out._impl = this->_impl->get(accessors); + return out; + } + UniTensor operator[](const std::initializer_list &accessors) const { + std::vector acc_in = accessors; + return this->get(acc_in); + } + UniTensor operator[](const std::vector &accessors) const { + std::vector acc_in; + for (cytnx_int64 i = 0; i < accessors.size(); i++) { + acc_in.push_back(cytnx::Accessor(accessors[i])); + } + return this->get(acc_in); + } + UniTensor operator[](const std::initializer_list &accessors) const { + std::vector acc_in = accessors; + return (*this)[acc_in]; + } + void set(const std::vector &accessors, const Tensor &rhs) { this->_impl->set(accessors, rhs); } diff --git a/include/backend/Tensor_impl.hpp b/include/backend/Tensor_impl.hpp index 08f0b3d43..575221ae2 100644 --- a/include/backend/Tensor_impl.hpp +++ b/include/backend/Tensor_impl.hpp @@ -187,7 +187,9 @@ namespace cytnx { return this->_storage.at(RealRank); } - boost::intrusive_ptr get(const std::vector &accessors); + boost::intrusive_ptr get( + const std::vector &accessors, + std::vector &removed = *new std::vector()); boost::intrusive_ptr get_deprecated(const std::vector &accessors); void set(const std::vector &accessors, const boost::intrusive_ptr &rhs); diff --git a/pybind/unitensor_py.cpp b/pybind/unitensor_py.cpp index 0c285b882..d1624191f 100644 --- a/pybind/unitensor_py.cpp +++ b/pybind/unitensor_py.cpp @@ -316,8 +316,26 @@ void unitensor_binding(py::module &m) { std::vector accessors; if (self.is_diag()){ if (py::isinstance(locators)) { - cytnx_error_msg(true, - "[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n"); + py::tuple Args = locators.cast(); + cytnx_error_msg(Args.size() > 2, + "[ERROR][slicing] A diagonal UniTensor can only be accessed with one- or two dimensional slicing.%s", "\n"); + cytnx_uint64 cnt = 0; + // mixing of slice and ints + for (cytnx_uint32 axis = 0; axis < Args.size(); axis++) { + cnt++; + // check type: + if (py::isinstance(Args[axis])) { + py::slice sls = Args[axis].cast(); + if (!sls.compute((ssize_t)self.shape()[axis], &start, &stop, &step, &slicelength)) + throw py::error_already_set(); + accessors.push_back(cytnx::Accessor::range(cytnx_int64(start), cytnx_int64(stop), + cytnx_int64(step))); + } else { + accessors.push_back(cytnx::Accessor(Args[axis].cast())); + } + } + // cytnx_error_msg(true, + // "[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n"); } else if (py::isinstance(locators)) { py::slice sls = locators.cast(); if (!sls.compute((ssize_t)self.shape()[0], &start, &stop, &step, &slicelength)) @@ -385,8 +403,26 @@ void unitensor_binding(py::module &m) { std::vector accessors; if (self.is_diag()){ if (py::isinstance(locators)) { - cytnx_error_msg(true, - "[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n"); + py::tuple Args = locators.cast(); + cytnx_error_msg(Args.size() > 2, + "[ERROR][slicing] A diagonal UniTensor can only be accessed with one- or two dimensional slicing.%s", "\n"); + cytnx_uint64 cnt = 0; + // mixing of slice and ints + for (cytnx_uint32 axis = 0; axis < Args.size(); axis++) { + cnt++; + // check type: + if (py::isinstance(Args[axis])) { + py::slice sls = Args[axis].cast(); + if (!sls.compute((ssize_t)self.shape()[axis], &start, &stop, &step, &slicelength)) + throw py::error_already_set(); + accessors.push_back(cytnx::Accessor::range(cytnx_int64(start), cytnx_int64(stop), + cytnx_int64(step))); + } else { + accessors.push_back(cytnx::Accessor(Args[axis].cast())); + } + } + // cytnx_error_msg(true, + // "[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n"); } else if (py::isinstance(locators)) { py::slice sls = locators.cast(); if (!sls.compute((ssize_t)self.shape()[0], &start, &stop, &step, &slicelength)) @@ -453,8 +489,26 @@ void unitensor_binding(py::module &m) { std::vector accessors; if (self.is_diag()){ if (py::isinstance(locators)) { - cytnx_error_msg(true, - "[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n"); + py::tuple Args = locators.cast(); + cytnx_error_msg(Args.size() > 2, + "[ERROR][slicing] A diagonal UniTensor can only be accessed with one- or two dimensional slicing.%s", "\n"); + cytnx_uint64 cnt = 0; + // mixing of slice and ints + for (cytnx_uint32 axis = 0; axis < Args.size(); axis++) { + cnt++; + // check type: + if (py::isinstance(Args[axis])) { + py::slice sls = Args[axis].cast(); + if (!sls.compute((ssize_t)self.shape()[axis], &start, &stop, &step, &slicelength)) + throw py::error_already_set(); + accessors.push_back(cytnx::Accessor::range(cytnx_int64(start), cytnx_int64(stop), + cytnx_int64(step))); + } else { + accessors.push_back(cytnx::Accessor(Args[axis].cast())); + } + } + // cytnx_error_msg(true, + // "[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n"); } else if (py::isinstance(locators)) { py::slice sls = locators.cast(); if (!sls.compute((ssize_t)self.shape()[0], &start, &stop, &step, &slicelength)) diff --git a/src/Accessor.cpp b/src/Accessor.cpp index 08df57830..85b2908c5 100644 --- a/src/Accessor.cpp +++ b/src/Accessor.cpp @@ -114,6 +114,14 @@ namespace cytnx { return *this; } + // check equality + bool Accessor::operator==(const Accessor &rhs) const { + bool out = (this->_type == rhs._type) && (this->_min == rhs._min) && (this->_max == rhs._max) && + (this->loc == rhs.loc) && (this->_step == rhs._step) && + (this->idx_list == rhs.idx_list); + return out; + } + // get the real len from dim // if _type is all, pos will be null, and len == dim // if _type is range, pos will be the locator, and len == len(pos) diff --git a/src/backend/Tensor_impl.cpp b/src/backend/Tensor_impl.cpp index 74e92b887..651945591 100644 --- a/src/backend/Tensor_impl.cpp +++ b/src/backend/Tensor_impl.cpp @@ -158,8 +158,8 @@ namespace cytnx { // shadow new: // - boost::intrusive_ptr Tensor_impl::get( - const std::vector &accessors) { + boost::intrusive_ptr Tensor_impl::get(const std::vector &accessors, + std::vector &removed) { cytnx_error_msg(accessors.size() > this->_shape.size(), "%s", "The input indexes rank is out of range! (>Tensor's rank)."); @@ -234,10 +234,10 @@ namespace cytnx { // permute back: std::vector new_mapper(this->_mapper.begin(), this->_mapper.end()); std::vector new_shape; - std::vector remove_id; + // std::vector removed; for (unsigned int i = 0; i < out->_shape.size(); i++) { if (out->shape()[i] == 1 && (acc[i].type() == Accessor::Singl)) - remove_id.push_back(this->_mapper[this->_invmapper[i]]); + removed.push_back(this->_mapper[this->_invmapper[i]]); else new_shape.push_back(out->shape()[i]); } @@ -247,8 +247,8 @@ namespace cytnx { // cout << "inv_mapper" << endl; // cout << this->_invmapper << endl; - // cout << "remove_id" << endl; - // cout << remove_id << endl; + // cout << "removed" << endl; + // cout << removed << endl; // cout << "out shape raw" << endl; // cout << out->shape() << endl; @@ -262,10 +262,10 @@ namespace cytnx { std::vector perm; for (unsigned int i = 0; i < new_mapper.size(); i++) { perm.push_back(new_mapper[i]); - for (unsigned int j = 0; j < remove_id.size(); j++) { - if (new_mapper[i] > remove_id[j]) + for (unsigned int j = 0; j < removed.size(); j++) { + if (new_mapper[i] > removed[j]) perm.back() -= 1; - else if (new_mapper[i] == remove_id[j]) { + else if (new_mapper[i] == removed[j]) { perm.pop_back(); break; } @@ -371,10 +371,10 @@ namespace cytnx { // permute input to currect pos std::vector new_mapper(this->_mapper.begin(), this->_mapper.end()); std::vector new_shape; - std::vector remove_id; + std::vector removed; for (unsigned int i = 0; i < get_shape.size(); i++) { if (acc[i].type() == Accessor::Singl) - remove_id.push_back(this->_mapper[this->_invmapper[i]]); + removed.push_back(this->_mapper[this->_invmapper[i]]); else new_shape.push_back(get_shape[i]); } @@ -386,10 +386,10 @@ namespace cytnx { for (unsigned int i = 0; i < new_mapper.size(); i++) { perm.push_back(new_mapper[i]); - for (unsigned int j = 0; j < remove_id.size(); j++) { - if (new_mapper[i] > remove_id[j]) + for (unsigned int j = 0; j < removed.size(); j++) { + if (new_mapper[i] > removed[j]) perm.back() -= 1; - else if (new_mapper[i] == remove_id[j]) { + else if (new_mapper[i] == removed[j]) { perm.pop_back(); break; }