Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/Accessor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ namespace cytnx {
Accessor(const Accessor &rhs);
// copy assignment:
Accessor &operator=(const Accessor &rhs);

// check equality
bool operator==(const Accessor &rhs) const;
///@endcond

int type() const { return this->_type; }
Expand Down
10 changes: 7 additions & 3 deletions include/Tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1006,10 +1006,13 @@ namespace cytnx {
/**
@brief get elements using Accessor (C++ API) / slices (python API)
@param[in] accessors the Accessor (C++ API) / slices (python API) to get the elements.
@param[out] removed an ascending list of indices that were removed from the original shape of
the Tensor.
@return [Tensor]
@see \link cytnx::Accessor Accessor\endlink for cordinate with Accessor in C++ API.
@note
1. the return will be a new Tensor instance, which not share memory with the current Tensor.
The return will be a new Tensor instance, which does not share memory with the current
Tensor.

## Equivalently:
One can also using more intruisive way to get the slice using [] operator.
Expand All @@ -1024,9 +1027,10 @@ namespace cytnx {
#### output>
\verbinclude example/Tensor/get.py.out
*/
Tensor get(const std::vector<cytnx::Accessor> &accessors) const {
Tensor get(const std::vector<cytnx::Accessor> &accessors,
std::vector<cytnx_int64> &removed = *new std::vector<cytnx_int64>()) const {
Tensor out;
out._impl = this->_impl->get(accessors);
out._impl = this->_impl->get(accessors, removed);
return out;
}

Expand Down
119 changes: 102 additions & 17 deletions include/UniTensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,12 +672,12 @@ namespace cytnx {
if (this->is_diag()) {
cytnx_error_msg(
in.shape() != this->_block.shape(),
"[ERROR][DenseUniTensor] put_block, the input tensor shape does not match.%s", "\n");
"[ERROR][DenseUniTensor][put_block] the input tensor shape does not match.%s", "\n");
this->_block = in.clone();
} else {
cytnx_error_msg(
in.shape() != this->shape(),
"[ERROR][DenseUniTensor] put_block, the input tensor shape does not match.%s", "\n");
"[ERROR][DenseUniTensor][put_block] the input tensor shape does not match.%s", "\n");
this->_block = in.clone();
}
}
Expand All @@ -699,12 +699,12 @@ namespace cytnx {
if (this->is_diag()) {
cytnx_error_msg(
in.shape() != this->_block.shape(),
"[ERROR][DenseUniTensor] put_block, the input tensor shape does not match.%s", "\n");
"[ERROR][DenseUniTensor][put_block] the input tensor shape does not match.%s", "\n");
this->_block = in;
} else {
cytnx_error_msg(
in.shape() != this->shape(),
"[ERROR][DenseUniTensor] put_block, the input tensor shape does not match.%s", "\n");
"[ERROR][DenseUniTensor][put_block] the input tensor shape does not match.%s", "\n");
this->_block = in;
}
}
Expand All @@ -721,10 +721,57 @@ namespace cytnx {
}
// this will only work on non-symm tensor (DenseUniTensor)
boost::intrusive_ptr<UniTensor_base> get(const std::vector<Accessor> &accessors) {
boost::intrusive_ptr<UniTensor_base> out(new DenseUniTensor());
out->Init_by_Tensor(this->_block.get(accessors), false, 0); // wrapping around.
return out;
if (accessors.empty()) return this;
DenseUniTensor *out = this->clone_meta();
std::vector<cytnx_int64> removed; // bonds to be removed
if (this->_is_diag) {
if (accessors.size() == 2) {
if (accessors[0] == accessors[1]) {
std::vector<Accessor> acc_in(1, accessors[0]);
return this->get(acc_in);
} else { // convert to dense UniTensor
out->_block = this->_block;
out->to_dense_();
return out->get(accessors);
}
} else { // for one accessor element, use this accessor for both bonds
cytnx_error_msg(accessors.size() > 1,
"[ERROR][DenseUniTensor][get] for diagonal UniTensors, only one or two "
"accessor elements are allowed.%s",
"\n");
out->_block = this->_block.get(accessors, removed);
if (removed.empty()) { // change dimension of bonds
for (cytnx_int64 idx = out->_bonds.size() - 1; idx >= 0; idx--) {
out->_bonds[idx]._impl->_dim = out->_block.shape()[0];
}
} else { // erase all bonds
out->_is_braket_form = false;
out->_is_diag = false;
out->_rowrank = 0;
out->_labels = std::vector<std::string>();
out->_bonds = std::vector<Bond>();
}
}
} else { // non-diagonal
out->_block = this->_block.get(accessors, removed);
for (cytnx_int64 idx = removed.size() - 1; idx >= 0; idx--) {
out->_labels.erase(out->_labels.begin() + removed[idx]);
out->_bonds.erase(out->_bonds.begin() + removed[idx]);
if (removed[idx] < this->_rowrank) out->_rowrank--;
}
// adapt dimensions on bonds
auto dims = out->_block.shape();
for (cytnx_int64 idx = 0; idx < out->_bonds.size(); idx++) {
out->_bonds[idx]._impl->_dim = dims[idx];
}
// update_braket
if (out->is_tag() && !out->_is_braket_form) {
out->_is_braket_form = out->_update_braket();
}
}
return boost::intrusive_ptr<UniTensor_base>(out);
}

// this will only work on non-symm tensor (DenseUniTensor)
void set(const std::vector<Accessor> &accessors, const Tensor &rhs) {
this->_block.set(accessors, rhs);
Expand Down Expand Up @@ -1680,7 +1727,7 @@ namespace cytnx {
true,
"[ERROR] cannot perform elementwise arithmetic '+' btwn Scalar and BlockUniTensor.\n %s "
"\n",
"This operation will destroy block structure. [Suggest] using get/set_block(s) to do "
"This operation will destroy block structure. [Suggest] using get/put_block(s) to do "
"operation on the block(s).");
}

Expand All @@ -1693,15 +1740,15 @@ namespace cytnx {
true,
"[ERROR] cannot perform elementwise arithmetic '+' btwn Scalar and BlockUniTensor.\n %s "
"\n",
"This operation will destroy block structure. [Suggest] using get/set_block(s) to do "
"This operation will destroy block structure. [Suggest] using get/put_block(s) to do "
"operation on the block(s).");
}
void lSub_(const Scalar &lhs) {
cytnx_error_msg(
true,
"[ERROR] cannot perform elementwise arithmetic '+' btwn Scalar and BlockUniTensor.\n %s "
"\n",
"This operation will destroy block structure. [Suggest] using get/set_block(s) to do "
"This operation will destroy block structure. [Suggest] using get/put_block(s) to do "
"operation on the block(s).");
}

Expand All @@ -1710,7 +1757,7 @@ namespace cytnx {
true,
"[ERROR] cannot perform elementwise arithmetic '+' btwn Scalar and BlockUniTensor.\n %s "
"\n",
"This operation will destroy block structure. [Suggest] using get/set_block(s) to do "
"This operation will destroy block structure. [Suggest] using get/put_block(s) to do "
"operation on the block(s).");
}
void Div_(const Scalar &rhs);
Expand All @@ -1719,7 +1766,7 @@ namespace cytnx {
true,
"[ERROR] cannot perform elementwise arithmetic '+' btwn Scalar and BlockUniTensor.\n %s "
"\n",
"This operation will destroy block structure. [Suggest] using get/set_block(s) to do "
"This operation will destroy block structure. [Suggest] using get/put_block(s) to do "
"operation on the block(s).");
}
void from_(const boost::intrusive_ptr<UniTensor_base> &rhs, const bool &force,
Expand Down Expand Up @@ -2475,7 +2522,7 @@ namespace cytnx {
"[ERROR] cannot perform elementwise arithmetic '+' btwn Scalar and "
"BlockFermionicUniTensor.\n %s "
"\n",
"This operation will destroy block structure. [Suggest] using get/set_block(s) to do "
"This operation will destroy block structure. [Suggest] using get/put_block(s) to do "
"operation on the block(s).");
}

Expand All @@ -2489,7 +2536,7 @@ namespace cytnx {
"[ERROR] cannot perform elementwise arithmetic '+' btwn Scalar and "
"BlockFermionicUniTensor.\n %s "
"\n",
"This operation will destroy block structure. [Suggest] using get/set_block(s) to do "
"This operation will destroy block structure. [Suggest] using get/put_block(s) to do "
"operation on the block(s).");
}
void lSub_(const Scalar &lhs) {
Expand All @@ -2498,7 +2545,7 @@ namespace cytnx {
"[ERROR] cannot perform elementwise arithmetic '+' btwn Scalar and "
"BlockFermionicUniTensor.\n %s "
"\n",
"This operation will destroy block structure. [Suggest] using get/set_block(s) to do "
"This operation will destroy block structure. [Suggest] using get/put_block(s) to do "
"operation on the block(s).");
}

Expand All @@ -2508,7 +2555,7 @@ namespace cytnx {
"[ERROR] cannot perform elementwise arithmetic '+' btwn Scalar and "
"BlockFermionicUniTensor.\n %s "
"\n",
"This operation will destroy block structure. [Suggest] using get/set_block(s) to do "
"This operation will destroy block structure. [Suggest] using get/put_block(s) to do "
"operation on the block(s).");
}
void Div_(const Scalar &rhs);
Expand All @@ -2518,7 +2565,7 @@ namespace cytnx {
"[ERROR] cannot perform elementwise arithmetic '+' btwn Scalar and "
"BlockFermionicUniTensor.\n %s "
"\n",
"This operation will destroy block structure. [Suggest] using get/set_block(s) to do "
"This operation will destroy block structure. [Suggest] using get/put_block(s) to do "
"operation on the block(s).");
}
void from_(const boost::intrusive_ptr<UniTensor_base> &rhs, const bool &force);
Expand Down Expand Up @@ -4247,11 +4294,49 @@ namespace cytnx {
this->_impl->put_block_(in, new_qidx, force);
in.permute_(new_order);
}

/**
@brief get elements using Accessor (C++ API) / slices (python API)
@param[in] accessors the Accessor (C++ API) / slices (python API) to get the elements.
@return [UniTensor]
@see Tensor::get, UniTensor::operator[]
@note
1. the return will be a new UniTensor instance, which does not share memory with the current
UniTensor.

2. Equivalently, one can also use the [] operator to access elements.
*/
UniTensor get(const std::vector<Accessor> &accessors) const {
UniTensor out;
out._impl = this->_impl->get(accessors);
return out;
}

/**
@brief get elements using Accessor (C++ API) / slices (python API)
@see get()
*/
UniTensor operator[](const std::vector<cytnx::Accessor> &accessors) const {
UniTensor out;
out._impl = this->_impl->get(accessors);
return out;
}
UniTensor operator[](const std::initializer_list<cytnx::Accessor> &accessors) const {
std::vector<cytnx::Accessor> acc_in = accessors;
return this->get(acc_in);
}
UniTensor operator[](const std::vector<cytnx_int64> &accessors) const {
std::vector<cytnx::Accessor> acc_in;
for (cytnx_int64 i = 0; i < accessors.size(); i++) {
acc_in.push_back(cytnx::Accessor(accessors[i]));
}
return this->get(acc_in);
}
UniTensor operator[](const std::initializer_list<cytnx_int64> &accessors) const {
std::vector<cytnx_int64> acc_in = accessors;
return (*this)[acc_in];
}

void set(const std::vector<Accessor> &accessors, const Tensor &rhs) {
this->_impl->set(accessors, rhs);
}
Expand Down
4 changes: 3 additions & 1 deletion include/backend/Tensor_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ namespace cytnx {
return this->_storage.at(RealRank);
}

boost::intrusive_ptr<Tensor_impl> get(const std::vector<cytnx::Accessor> &accessors);
boost::intrusive_ptr<Tensor_impl> get(
const std::vector<cytnx::Accessor> &accessors,
std::vector<cytnx_int64> &removed = *new std::vector<cytnx_int64>());
boost::intrusive_ptr<Tensor_impl> get_deprecated(const std::vector<cytnx::Accessor> &accessors);
void set(const std::vector<cytnx::Accessor> &accessors,
const boost::intrusive_ptr<Tensor_impl> &rhs);
Expand Down
66 changes: 60 additions & 6 deletions pybind/unitensor_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,26 @@ void unitensor_binding(py::module &m) {
std::vector<cytnx::Accessor> accessors;
if (self.is_diag()){
if (py::isinstance<py::tuple>(locators)) {
cytnx_error_msg(true,
"[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n");
py::tuple Args = locators.cast<py::tuple>();
cytnx_error_msg(Args.size() > 2,
"[ERROR][slicing] A diagonal UniTensor can only be accessed with one- or two dimensional slicing.%s", "\n");
cytnx_uint64 cnt = 0;
// mixing of slice and ints
for (cytnx_uint32 axis = 0; axis < Args.size(); axis++) {
cnt++;
// check type:
if (py::isinstance<py::slice>(Args[axis])) {
py::slice sls = Args[axis].cast<py::slice>();
if (!sls.compute((ssize_t)self.shape()[axis], &start, &stop, &step, &slicelength))
throw py::error_already_set();
accessors.push_back(cytnx::Accessor::range(cytnx_int64(start), cytnx_int64(stop),
cytnx_int64(step)));
} else {
accessors.push_back(cytnx::Accessor(Args[axis].cast<cytnx_int64>()));
}
}
// cytnx_error_msg(true,
// "[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n");
} else if (py::isinstance<py::slice>(locators)) {
py::slice sls = locators.cast<py::slice>();
if (!sls.compute((ssize_t)self.shape()[0], &start, &stop, &step, &slicelength))
Expand Down Expand Up @@ -385,8 +403,26 @@ void unitensor_binding(py::module &m) {
std::vector<cytnx::Accessor> accessors;
if (self.is_diag()){
if (py::isinstance<py::tuple>(locators)) {
cytnx_error_msg(true,
"[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n");
py::tuple Args = locators.cast<py::tuple>();
cytnx_error_msg(Args.size() > 2,
"[ERROR][slicing] A diagonal UniTensor can only be accessed with one- or two dimensional slicing.%s", "\n");
cytnx_uint64 cnt = 0;
// mixing of slice and ints
for (cytnx_uint32 axis = 0; axis < Args.size(); axis++) {
cnt++;
// check type:
if (py::isinstance<py::slice>(Args[axis])) {
py::slice sls = Args[axis].cast<py::slice>();
if (!sls.compute((ssize_t)self.shape()[axis], &start, &stop, &step, &slicelength))
throw py::error_already_set();
accessors.push_back(cytnx::Accessor::range(cytnx_int64(start), cytnx_int64(stop),
cytnx_int64(step)));
} else {
accessors.push_back(cytnx::Accessor(Args[axis].cast<cytnx_int64>()));
}
}
// cytnx_error_msg(true,
// "[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n");
} else if (py::isinstance<py::slice>(locators)) {
py::slice sls = locators.cast<py::slice>();
if (!sls.compute((ssize_t)self.shape()[0], &start, &stop, &step, &slicelength))
Expand Down Expand Up @@ -453,8 +489,26 @@ void unitensor_binding(py::module &m) {
std::vector<cytnx::Accessor> accessors;
if (self.is_diag()){
if (py::isinstance<py::tuple>(locators)) {
cytnx_error_msg(true,
"[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n");
py::tuple Args = locators.cast<py::tuple>();
cytnx_error_msg(Args.size() > 2,
"[ERROR][slicing] A diagonal UniTensor can only be accessed with one- or two dimensional slicing.%s", "\n");
cytnx_uint64 cnt = 0;
// mixing of slice and ints
for (cytnx_uint32 axis = 0; axis < Args.size(); axis++) {
cnt++;
// check type:
if (py::isinstance<py::slice>(Args[axis])) {
py::slice sls = Args[axis].cast<py::slice>();
if (!sls.compute((ssize_t)self.shape()[axis], &start, &stop, &step, &slicelength))
throw py::error_already_set();
accessors.push_back(cytnx::Accessor::range(cytnx_int64(start), cytnx_int64(stop),
cytnx_int64(step)));
} else {
accessors.push_back(cytnx::Accessor(Args[axis].cast<cytnx_int64>()));
}
}
// cytnx_error_msg(true,
// "[ERROR] cannot get element using [tuple] on is_diag=True UniTensor since the block is rank-1, consider [int] or [int:int] instead.%s", "\n");
} else if (py::isinstance<py::slice>(locators)) {
py::slice sls = locators.cast<py::slice>();
if (!sls.compute((ssize_t)self.shape()[0], &start, &stop, &step, &slicelength))
Expand Down
8 changes: 8 additions & 0 deletions src/Accessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ namespace cytnx {
return *this;
}

// check equality
bool Accessor::operator==(const Accessor &rhs) const {
bool out = (this->_type == rhs._type) && (this->_min == rhs._min) && (this->_max == rhs._max) &&
(this->loc == rhs.loc) && (this->_step == rhs._step) &&
(this->idx_list == rhs.idx_list);
return out;
}

// get the real len from dim
// if _type is all, pos will be null, and len == dim
// if _type is range, pos will be the locator, and len == len(pos)
Expand Down
Loading
Loading