われがわログ

最適化アルゴリズムとかプログラミングについて書きたい

Eigenをcerealでシリアライズする

自分で若干コードを書かないとEigenオブジェクトをシリアライズできなかったのでメモ。

結論から言うと、C++17環境であれば以下のようにload(), save()を定義してやればよい。

#pragma once
#include <cereal/cereal.hpp>
#include <cereal/types/vector.hpp>
#include <Eigen/Dense>
#include <vector>

namespace cereal
{
    //Eigenのデータの中身を配列として出力するためのラッパ
    template <class T>
    struct DataWrapper
    {
        DataWrapper(T& m) : mat(m) {}
        T& mat;
        template <class Archive>
        void save(Archive& ar) const
        {
            ar(cereal::make_size_tag(mat.size()));
            for (int iter = 0, end = static_cast<int>(mat.size()); iter != end; ++iter)
                ar(*(mat.data() + iter));
        }

        template <class Archive>
        void load(Archive& ar)
        {
            cereal::size_type n_rows;
            ar(cereal::make_size_tag(n_rows));
            for (int iter = 0, end = static_cast<int>(mat.size()); iter != end; ++iter)
                ar(*(mat.data() + iter));
        }
    };

    template <class Archive, class Derived, cereal::traits::DisableIf<cereal::traits::is_text_archive<Archive>::value> = cereal::traits::sfinae>
    inline void save(Archive& ar, Eigen::PlainObjectBase<Derived> const& m) {
        using ArrT = Eigen::PlainObjectBase<Derived>;
        if constexpr (ArrT::RowsAtCompileTime == Eigen::Dynamic) ar(m.rows());
        if constexpr (ArrT::ColsAtCompileTime == Eigen::Dynamic) ar(m.cols());
        ar(binary_data(m.data(), m.size() * sizeof(typename Derived::Scalar)));
    }

    template <class Archive, class Derived, cereal::traits::DisableIf<cereal::traits::is_text_archive<Archive>::value> = cereal::traits::sfinae>
    inline void load(Archive& ar, Eigen::PlainObjectBase<Derived>& m) {
        using ArrT = Eigen::PlainObjectBase<Derived>;
        Eigen::Index rows = ArrT::RowsAtCompileTime, cols = ArrT::ColsAtCompileTime;
        if constexpr (ArrT::RowsAtCompileTime == Eigen::Dynamic) ar(rows);
        if constexpr (ArrT::ColsAtCompileTime == Eigen::Dynamic) ar(cols);
        m.resize(rows, cols);
        ar(binary_data(m.data(), static_cast<std::size_t>(rows * cols * sizeof(typename Derived::Scalar))));
    }

    template <class Archive, class Derived, cereal::traits::EnableIf<cereal::traits::is_text_archive<Archive>::value> = cereal::traits::sfinae>
    inline void save(Archive& ar, Eigen::PlainObjectBase<Derived> const& m) {
        using ArrT = Eigen::PlainObjectBase<Derived>;
        ar(cereal::make_nvp("rows", m.rows()));
        ar(cereal::make_nvp("cols", m.cols()));
        ar(cereal::make_nvp("data", DataWrapper(m)));
    }

    template <class Archive, class Derived, cereal::traits::EnableIf<cereal::traits::is_text_archive<Archive>::value> = cereal::traits::sfinae>
    inline void load(Archive& ar, Eigen::PlainObjectBase<Derived>& m) {
        using ArrT = Eigen::PlainObjectBase<Derived>;
        Eigen::Index rows = ArrT::RowsAtCompileTime, cols = ArrT::ColsAtCompileTime;
        ar(rows);
        ar(cols);
        m.resize(rows, cols);
        ar(DataWrapper(m));
    }
}

コードは以下のページをかなり参考にした。 stackoverflow.com

stackoverflow.com

uscilab.github.io

以下はテストコード。

#include "cereal_eigen.h"
#include <random>
#include <fstream>
#include <filesystem>
#include <cereal/archives/binary.hpp>
#include <cereal/archives/json.hpp>

namespace fs = std::filesystem;

class TestClass {
public:
    TestClass() {
        std::random_device rd;
        std::mt19937 gen(rd());
        std::uniform_real_distribution<double> distd(-1.0, 1.0);
        std::uniform_int_distribution<int> disti(0, 100);

        mat = Eigen::MatrixXd::NullaryExpr(5, 3, [&]() {return distd(gen); });;
        arr << disti(gen), disti(gen), disti(gen);
        vec3d << distd(gen), distd(gen), distd(gen);;
    }
    void Print() {
        std::cout << "mat: " << std::endl << mat << std::endl;
        std::cout << "arr: " << arr << std::endl;
        std::cout << "evec: " << vec3d << std::endl;
    }
    template<class Archive>
    void serialize(Archive& archive)
    {
        archive(CEREAL_NVP(vec3d), CEREAL_NVP(arr), CEREAL_NVP(mat));
    }

    Eigen::MatrixXd mat;
    Eigen::Vector3d vec3d;
    Eigen::Array<int, 1, 3> arr;
};

int main() {
    fs::create_directory("data");
    fs::path path_binary = "data/class.cereal";
    fs::path path_json = "data/class.json";
    auto t = TestClass();
    t.Print();
    {
        std::ofstream out(path_binary, std::ios::binary);
        cereal::BinaryOutputArchive archive_o(out);
        archive_o(CEREAL_NVP(t));

        std::ofstream out2(path_json);
        cereal::JSONOutputArchive archive_o2(out2);
        archive_o2(CEREAL_NVP(t));
    }

    std::cout << std::endl;
    std::cout << "Load from a binary file" << std::endl;

    TestClass t_binary;
    {
        std::ifstream in(path_binary, std::ios::binary);
        cereal::BinaryInputArchive archive_i(in);
        archive_i(t_binary);
    }
    t_binary.Print();

    std::cout << std::endl;
    std::cout << "Load from a json file" << std::endl;
    TestClass t_json;
    {
        std::ifstream in2(path_json);
        cereal::JSONInputArchive archive_i2(in2);
        archive_i2(t_json);
    }
    t_json.Print();
}

出力されるJSONファイルはこんな感じ

{
    "t": {
        "vec3d": {
            "rows": 3,
            "cols": 1,
            "data": [
                0.3161424837703686,
                0.09854066925246485,
                0.6277061473999255
            ]
        },
        "arr": {
            "rows": 1,
            "cols": 3,
            "data": [
                94,
                16,
                40
            ]
        },
        "mat": {
            "rows": 5,
            "cols": 3,
            "data": [
                -0.26028257842889648,
                0.25791645918797637,
                0.6778558335502285,
                -0.05895612669355521,
                0.08350015176958703,
                0.5584214393023781,
                0.4016409143850439,
                0.6676428322540648,
                0.10603756836576128,
                -0.48229906250794388,
                0.7801566848260653,
                -0.8826490291364187,
                -0.9816820696869016,
                -0.9338585559435026,
                0.5447137103816357
            ]
        }
    }
}