protobuf 字段检查工具

背景

最近遇到一个需求,将jsonnet转为对应protobuf对象。官方都有相应的库,如下所示是一个jsonnet转为对应protobuf对象的案例,message是protobuf对象的父类:

bool GetProtoFromJsonnetFile(const std::string& file_name,
                             google::protobuf::Message* message) {
  std::ifstream fin(file_name.c_str());

  if (!fin) {
    std::cerr << "Fail to read jsonnet file " << file_name;
    return false;
  }

  std::string data;
  jsonnet::Jsonnet js;
  js.init();

  if (!js.evaluateFile(file_name, &data)) {
    std::cerr << "INVALID json format, filename=" << file_name;
    return false;
  }

  google::protobuf::util::JsonParseOptions options;
  options.ignore_unknown_fields = true;
  auto status =
      google::protobuf::util::JsonStringToMessage(data, message, options);
  if (!status.ok()) {
    std::cerr << "Fail to parse jsonnet file" << status.error_message();
    return false;
  }

  json json_data;
  json_data = json::parse(data);
  std::vector<std::string> paths;
  if (!ComparePbField(json_data, *message, paths)) {
    std::cerr << file_name << " invalid field found:";
    for (auto& s : paths) {
      std::cerr << "    " << s;
    }
  }

  return true;
}

但是可能的误操作导致json中会有些多余的字段,不会进行提醒,报错结果不够全面,所以编写了这个工具,学习下protobuf的反射机制、nlohmann json库的api。

  • 实现思路:

jsonnet读取后生成protobuf对象后进行处理,因为相应库函数忽略了无效字段,jsonnest中的对象字段一定大于等于proto对象的字段。由此设计ComparePbField函数,流程大致为:递归解析json对象时同时对jsonnet生成的proto对象进行反射查找对应字段,未找到时将维护的递归路径数组封装打印日志。

具体步骤

1. 仅使用 protobuf api

Protobuf可以对proto这类txt文本可以使用google::protobuf::TextFormat接口来解析生成pb对象,注意AllowUnknownField(true)即可生成并警告无效字段。

相关api文档

#include <google/protobuf/text_format.h>  
google::protobuf::TextFormat::Parser par;
par.AllowUnknownField(true);

然后调用JsonStringToMessage 返回值:

修改options.ignore_unknown_fields为false(该选项默认值false),即可在JsonStringToMessage的返回值status 调用error_code() 进行处理。 但是忽略无效字段选项关闭后,无论是 缺少必要字段 还是 多余未知字段 返回code都为3 (INVALID_ARGUMENT) 无法进行区分。如下打印日志所示:

# 缺少必要字段 class_name
code: 3, detail: (components[0].config): missing field name
# 修改原字段 msg_type
code: 3, detail: (components[0].config.task) trigger_wpolicy: Cannot find field.
# 添加无效字段 do_not_skip
code: 3, detail: (components[0].config.task.input_channels[1]) do_not_skip: Cannot find field.

2. 反射相关api的学习

反射即为动态获取类的属性与方法,c++原生并不支持反射,protobuf的反射通过Descriptor生成Reflection来获取并创建修改类内字段来实现反射。 protobuf 类的基本父类是 google::protobuf::Message。Descriptor有常用的以下接口:

const std::string & name() const; // 获取message自身名字
int field_count() const; // 获取该message中有多少字段
const FileDescriptor* file() const; // The .proto file in which this message type was defined. Never nullptr.

// 获取类 FieldDescriptor:
const FieldDescriptor* field(int index) const; // 根据定义顺序索引获取,即从0开始到最大定义的条目
const FieldDescriptor* FindFieldByNumber(int number) const; // 根据定义的message里面的顺序值获取(option string name=3,3即为number)
const FieldDescriptor* FindFieldByName(const string& name) const; // 根据field name获取
Descriptor功能
FileDescriptor获取Proto文件中的Descriptor和ServiceDescriptor
Descriptor获取类message属性和方法,包括FieldDescriptor和EnumDescriptor
FieldDescriptor获取message中各个字段的类型、标签、名称等
EnumDescriptor获取Enum中的各个字段名称、值等
ServiceDescriptor获取service中的MethodDescriptor
MethodDescriptor获取各个RPC中的request、response、名称等
  • 通过proto文件获取描述符
const FileDescriptor* fileDescriptor = DescriptorPool::generated_pool()->FindFileByName(file);
  • 通过类名获取描述符
// DescriptorPool包含了程序编译的时候所链接的全部 protobuf Message types。
// 然后通过其提供的 FindMessageTypeByName 方法即可根据type name 查找到Descriptor。
auto descriptor = google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName("Person");
// 利用Descriptor拿到类型注册的instance. 这个是不可修改的.
auto prototype = google::protobuf::MessageFactory::generated_factory()->GetPrototype(descriptor);
// 构造一个可用的消息.
auto instance = prototype->New(); //创建新的 person message对象。

使用反射来将加载生成的module_config对象与protobuf各个字段对比检查。ref->ListFields只检查对象中设置的字段,descriptor->field_count()会计算类中全部字段。

/*        //ref取
std::vector<const google::protobuf::FieldDescriptor*> fields;
ref->ListFields(message, &fields);
for (auto field : fields) { ******
        // des取
for (int i = 0; i < descriptor->field_count(); ++i) {
    const google::protobuf::FieldDescriptor* field = descriptor->field(i); *******
*/


// 递归查询protobuf对象字段
void CompareField(const google::protobuf::Message& message) {
  const google::protobuf::Descriptor* des = message.GetDescriptor();
  const google::protobuf::Reflection* ref = message.GetReflection();
  if(des) {
    std::vector<const google::protobuf::FieldDescriptor*> fields;
    ref->ListFields(message, &fields);
    for (auto field : fields) {
      if (field->is_repeated()) {
        for(int i = 0; i < ref->FieldSize(message, field); i++) {
          // Recursively traverse repeated types
          if (field->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
            auto& sub_message = ref->GetRepeatedMessage(message, field, i);
            CompareField(sub_message);
          } else {      
            // skip
          }
        }
      } else {
        if (field->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
          // Recursively traverse compound message
          auto& sub_message = ref->GetMessage(message, field);
          CompareField(sub_message);
        } else {      
          // skip
        }
      }
    }
  }
}

// 常见的一些解析手段 ref->Get*** 获取字段对应value值
#define CASE_FIELD_TYPE(cpptype, method, valuetype)\
case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype:{\
    valuetype value = reflection->Get##method(message, field);\
    int wsize = field->name().size();\
    serialized_string->append(reinterpret_cast<char*>(&wsize), sizeof(wsize));\
    serialized_string->append(field->name().c_str(), field->name().size());\
    wsize = sizeof(value);\
    serialized_string->append(reinterpret_cast<char*>(&wsize), sizeof(wsize));\
    serialized_string->append(reinterpret_cast<char*>(&value), sizeof(value));\
    break;\
}
switch (field->cpp_type()) {
    CASE_FIELD_TYPE(INT32, Int32, int);
    CASE_FIELD_TYPE(UINT32, UInt32, uint32_t);
    CASE_FIELD_TYPE(FLOAT, Float, float);
    CASE_FIELD_TYPE(DOUBLE, Double, double);
    CASE_FIELD_TYPE(BOOL, Bool, bool);
    CASE_FIELD_TYPE(INT64, Int64, int64_t);
    CASE_FIELD_TYPE(UINT64, UInt64, uint64_t);
}

问题在于前JsonStringToMessage生成的对象已经忽略了无效字段,所以该方法反射出的字段很正常无法检查错误。

3. 自实现的递归工具

所以将上述protobuf反射字段对比json生成的对象字段,从而检查是否包含无效字段。使用递归回溯来记录path,print打印递归中不存在的字段。 json解析示例
对于递归过程中的field与字段映射一开始直接粗暴地用哈希表存字段key对应的field指针,后改为从GetDescriptor获取field。

inline void PrintPath(const vector<string>& path) {
  string path_s;
  for (auto& s : path) {
    path_s = path_s + s + ".";
  }
  path_s.pop_back();
  std::out << "Invalid field found: " << path_s;
}

void DoComparePbField(vector<string>& path, const json& js,
                      const google::protobuf::Message& message) {
  const google::protobuf::Reflection* ref = message.GetReflection();
  const google::protobuf::Descriptor* des = message.GetDescriptor();
  for (json::const_reverse_iterator it = js.crbegin(); it != js.crend(); ++it) {
    path.push_back(it.key());
    auto field = des->FindFieldByName(it.key());
    if (field) {
      if (it->is_array()) {
        for (size_t i = 0; i < it.value().size(); i++) {
          if (it.value().at(i).is_array() || it.value().at(i).is_object()) {
            path.back() = path.back() + "[" + std::to_string(i) + "]";
            auto& sub_message = ref->GetRepeatedMessage(message, field, i);
            DoComparePbField(path, it.value().at(i), sub_message);
            path.back() = it.key();
          }  // else skip
        }
      } else {
        if (it->is_object()) {
          auto& sub_message = ref->GetMessage(message, field);
          DoComparePbField(path, *it, sub_message);
        }  // else skip
      }
    } else {  // find invaild
      PrintPath(path);
    }
    path.pop_back();
  }
}

// Recursively compare json and prptobuf fields, then print the recursive path
// of invalid fields
void ComparePbField(const json& js, const google::protobuf::Message& message) {
  vector<string> path;
  if (!js.empty()) DoComparePbField(path, js, message);
}