Eigenをcerealでシリアライズする
自分で若干コードを書かないとEigenオブジェクトをシリアライズできなかったのでメモ。
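結論から言うと、C++17環境であれば以下のように `load()`, `save()` を定義してやればよい（以下のコードを `cereal_eigen.h` として保存しておく。後述のテストコードはこの名前でインクルードしている）。バイナリアーカイブでは係数をメモリのまま一括で書き出し（動的サイズの次元についてのみ rows / cols も保存する）、JSON や XML などのテキストアーカイブでは rows, cols, data の3要素として書き出す。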
#pragma once #include <cereal/cereal.hpp> #include <cereal/types/vector.hpp> #include <Eigen/Dense> #include <vector> namespace cereal { //Eigenのデータの中身を配列として出力するためのラッパ template <class T> struct DataWrapper { DataWrapper(T& m) : mat(m) {} T& mat; template <class Archive> void save(Archive& ar) const { ar(cereal::make_size_tag(mat.size())); for (int iter = 0, end = static_cast<int>(mat.size()); iter != end; ++iter) ar(*(mat.data() + iter)); } template <class Archive> void load(Archive& ar) { cereal::size_type n_rows; ar(cereal::make_size_tag(n_rows)); for (int iter = 0, end = static_cast<int>(mat.size()); iter != end; ++iter) ar(*(mat.data() + iter)); } }; template <class Archive, class Derived, cereal::traits::DisableIf<cereal::traits::is_text_archive<Archive>::value> = cereal::traits::sfinae> inline void save(Archive& ar, Eigen::PlainObjectBase<Derived> const& m) { using ArrT = Eigen::PlainObjectBase<Derived>; if constexpr (ArrT::RowsAtCompileTime == Eigen::Dynamic) ar(m.rows()); if constexpr (ArrT::ColsAtCompileTime == Eigen::Dynamic) ar(m.cols()); ar(binary_data(m.data(), m.size() * sizeof(typename Derived::Scalar))); } template <class Archive, class Derived, cereal::traits::DisableIf<cereal::traits::is_text_archive<Archive>::value> = cereal::traits::sfinae> inline void load(Archive& ar, Eigen::PlainObjectBase<Derived>& m) { using ArrT = Eigen::PlainObjectBase<Derived>; Eigen::Index rows = ArrT::RowsAtCompileTime, cols = ArrT::ColsAtCompileTime; if constexpr (ArrT::RowsAtCompileTime == Eigen::Dynamic) ar(rows); if constexpr (ArrT::ColsAtCompileTime == Eigen::Dynamic) ar(cols); m.resize(rows, cols); ar(binary_data(m.data(), static_cast<std::size_t>(rows * cols * sizeof(typename Derived::Scalar)))); } template <class Archive, class Derived, cereal::traits::EnableIf<cereal::traits::is_text_archive<Archive>::value> = cereal::traits::sfinae> inline void save(Archive& ar, Eigen::PlainObjectBase<Derived> const& m) { using ArrT = Eigen::PlainObjectBase<Derived>; 
ar(cereal::make_nvp("rows", m.rows())); ar(cereal::make_nvp("cols", m.cols())); ar(cereal::make_nvp("data", DataWrapper(m))); } template <class Archive, class Derived, cereal::traits::EnableIf<cereal::traits::is_text_archive<Archive>::value> = cereal::traits::sfinae> inline void load(Archive& ar, Eigen::PlainObjectBase<Derived>& m) { using ArrT = Eigen::PlainObjectBase<Derived>; Eigen::Index rows = ArrT::RowsAtCompileTime, cols = ArrT::ColsAtCompileTime; ar(rows); ar(cols); m.resize(rows, cols); ar(DataWrapper(m)); } }
コードは以下の Stack Overflow のページをかなり参考にした（リンク先: stackoverflow.com）。
以下はテストコード。ランダムな値を入れたクラスをバイナリと JSON の両方で保存し、それぞれ読み込み直して内容を表示する。
#include "cereal_eigen.h"

#include <cereal/archives/binary.hpp>
#include <cereal/archives/json.hpp>

#include <filesystem>
#include <fstream>
#include <iostream>
#include <random>

namespace fs = std::filesystem;

// Test payload: one dynamic-size matrix and two fixed-size Eigen objects,
// filled with random values so that both the dynamic and fixed-size paths of
// the Eigen serializers are exercised.
class TestClass {
 public:
  TestClass() {
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<double> distd(-1.0, 1.0);
    std::uniform_int_distribution<int> disti(0, 100);
    mat = Eigen::MatrixXd::NullaryExpr(5, 3, [&]() { return distd(gen); });
    arr << disti(gen), disti(gen), disti(gen);
    vec3d << distd(gen), distd(gen), distd(gen);
  }

  // Prints all members to stdout.
  void Print() const {
    std::cout << "mat: " << '\n' << mat << '\n';
    std::cout << "arr: " << arr << '\n';
    std::cout << "evec: " << vec3d << '\n';
  }

  template <class Archive>
  void serialize(Archive& archive) {
    archive(CEREAL_NVP(vec3d), CEREAL_NVP(arr), CEREAL_NVP(mat));
  }

  Eigen::MatrixXd mat;
  Eigen::Vector3d vec3d;
  Eigen::Array<int, 1, 3> arr;
};

int main() {
  fs::create_directory("data");
  const fs::path path_binary = "data/class.cereal";
  const fs::path path_json = "data/class.json";

  TestClass t;
  t.Print();

  // cereal archives finish writing when destroyed, so each archive gets its
  // own scope that closes before the file is read back below.
  {
    std::ofstream out(path_binary, std::ios::binary);
    cereal::BinaryOutputArchive archive_o(out);
    archive_o(CEREAL_NVP(t));
  }
  {
    std::ofstream out2(path_json);
    cereal::JSONOutputArchive archive_o2(out2);
    archive_o2(CEREAL_NVP(t));
  }

  std::cout << '\n';
  std::cout << "Load from a binary file" << '\n';
  TestClass t_binary;
  {
    std::ifstream in(path_binary, std::ios::binary);
    cereal::BinaryInputArchive archive_i(in);
    archive_i(t_binary);
  }
  t_binary.Print();

  std::cout << '\n';
  std::cout << "Load from a json file" << '\n';
  TestClass t_json;
  {
    std::ifstream in2(path_json);
    cereal::JSONInputArchive archive_i2(in2);
    // Read by the same name ("t") that CEREAL_NVP(t) used when saving.
    archive_i2(cereal::make_nvp("t", t_json));
  }
  t_json.Print();
}
出力される JSON ファイル（data/class.json）はこんな感じ。テキスト形式では、固定サイズの `Vector3d` や `Array<int, 1, 3>` でも rows と cols が出力される。
{ "t": { "vec3d": { "rows": 3, "cols": 1, "data": [ 0.3161424837703686, 0.09854066925246485, 0.6277061473999255 ] }, "arr": { "rows": 1, "cols": 3, "data": [ 94, 16, 40 ] }, "mat": { "rows": 5, "cols": 3, "data": [ -0.26028257842889648, 0.25791645918797637, 0.6778558335502285, -0.05895612669355521, 0.08350015176958703, 0.5584214393023781, 0.4016409143850439, 0.6676428322540648, 0.10603756836576128, -0.48229906250794388, 0.7801566848260653, -0.8826490291364187, -0.9816820696869016, -0.9338585559435026, 0.5447137103816357 ] } } }