Skip to content

Commit

Permalink
[Quality] fix c++ binaries formatting (#859)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jul 8, 2024
1 parent 2330f08 commit 3934fe1
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 64 deletions.
76 changes: 76 additions & 0 deletions tensordict/csrc/utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#include "utils.h"

namespace py = pybind11;

py::tuple _unravel_key_to_tuple(const py::object &key) {
bool is_tuple = py::isinstance<py::tuple>(key);
bool is_str = py::isinstance<py::str>(key);

if (is_tuple) {
py::list newkey;
for (const auto &subkey : key) {
if (py::isinstance<py::str>(subkey)) {
newkey.append(subkey);
} else {
auto _key = _unravel_key_to_tuple(subkey.cast<py::object>());
if (_key.size() == 0) {
return py::make_tuple();
}
newkey += _key;
}
}
return py::tuple(newkey);
}
if (is_str) {
return py::make_tuple(key);
} else {
return py::make_tuple();
}
}

py::object unravel_key(const py::object &key) {
bool is_tuple = py::isinstance<py::tuple>(key);
bool is_str = py::isinstance<py::str>(key);

if (is_tuple) {
py::list newkey;
int count = 0;
for (const auto &subkey : key) {
if (py::isinstance<py::str>(subkey)) {
newkey.append(subkey);
count++;
} else {
auto _key = _unravel_key_to_tuple(subkey.cast<py::object>());
count += _key.size();
newkey += _key;
}
}
if (count == 1) {
return newkey[0];
}
return py::tuple(newkey);
}
if (is_str) {
return key;
} else {
throw std::runtime_error("key should be a Sequence<NestedKey>");
}
}

py::list unravel_key_list(const py::list &keys) {
py::list newkeys;
for (const auto &key : keys) {
auto _key = unravel_key(key.cast<py::object>());
newkeys.append(_key);
}
return newkeys;
}

py::list unravel_key_list(const py::tuple &keys) {
return unravel_key_list(py::list(keys));
}
68 changes: 4 additions & 64 deletions tensordict/csrc/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,70 +8,10 @@

namespace py = pybind11;

py::tuple _unravel_key_to_tuple(const py::object &key) {
bool is_tuple = py::isinstance<py::tuple>(key);
bool is_str = py::isinstance<py::str>(key);
py::tuple _unravel_key_to_tuple(const py::object &key);

if (is_tuple) {
py::list newkey;
for (const auto &subkey : key) {
if (py::isinstance<py::str>(subkey)) {
newkey.append(subkey);
} else {
auto _key = _unravel_key_to_tuple(subkey.cast<py::object>());
if (_key.size() == 0) {
return py::make_tuple();
}
newkey += _key;
}
}
return py::tuple(newkey);
}
if (is_str) {
return py::make_tuple(key);
} else {
return py::make_tuple();
}
}
py::object unravel_key(const py::object &key);

py::object unravel_key(const py::object &key) {
bool is_tuple = py::isinstance<py::tuple>(key);
bool is_str = py::isinstance<py::str>(key);
py::list unravel_key_list(const py::list &keys);

if (is_tuple) {
py::list newkey;
int count = 0;
for (const auto &subkey : key) {
if (py::isinstance<py::str>(subkey)) {
newkey.append(subkey);
count++;
} else {
auto _key = _unravel_key_to_tuple(subkey.cast<py::object>());
count += _key.size();
newkey += _key;
}
}
if (count == 1) {
return newkey[0];
}
return py::tuple(newkey);
}
if (is_str) {
return key;
} else {
throw std::runtime_error("key should be a Sequence<NestedKey>");
}
}

py::list unravel_key_list(const py::list &keys) {
py::list newkeys;
for (const auto &key : keys) {
auto _key = unravel_key(key.cast<py::object>());
newkeys.append(_key);
}
return newkeys;
}

py::list unravel_key_list(const py::tuple &keys) {
return unravel_key_list(py::list(keys));
}
py::list unravel_key_list(const py::tuple &keys);

0 comments on commit 3934fe1

Please sign in to comment.