Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 130 additions & 113 deletions src/mp/gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,109 @@ static bool BoxedType(const ::capnp::Type& type)
type.isFloat64() || type.isEnum());
}

struct Field
{
::capnp::StructSchema::Field param;
bool param_is_set = false;
::capnp::StructSchema::Field result;
bool result_is_set = false;
int args = 0;
bool retval = false;
bool optional = false;
bool requested = false;
bool skip = false;
kj::StringPtr exception;
};

struct FieldList
{
std::vector<Field> fields;
std::map<kj::StringPtr, int> field_idx; // name -> args index
bool has_result = false;

void addField(const ::capnp::StructSchema::Field& schema_field, bool param, bool result)
{
auto field_name = schema_field.getProto().getName();
auto inserted = field_idx.emplace(field_name, fields.size());
if (inserted.second) {
fields.emplace_back();
}
auto& field = fields[inserted.first->second];
if (param) {
field.param = schema_field;
field.param_is_set = true;
}
if (result) {
field.result = schema_field;
field.result_is_set = true;
}

if (!param && field_name == kj::StringPtr{"result"}) {
field.retval = true;
has_result = true;
}

if (AnnotationExists(schema_field.getProto(), SKIP_ANNOTATION_ID)) {
field.skip = true;
}
GetAnnotationText(schema_field.getProto(), EXCEPTION_ANNOTATION_ID, &field.exception);

int32_t count = 1;
if (!GetAnnotationInt32(schema_field.getProto(), COUNT_ANNOTATION_ID, &count)) {
if (schema_field.getType().isStruct()) {
GetAnnotationInt32(schema_field.getType().asStruct().getProto(),
COUNT_ANNOTATION_ID, &count);
} else if (schema_field.getType().isInterface()) {
GetAnnotationInt32(schema_field.getType().asInterface().getProto(),
COUNT_ANNOTATION_ID, &count);
}
}


if (inserted.second && !field.retval && !field.exception.size()) {
field.args = count;
}
}

void mergeFields()
{
for (auto& field : field_idx) {
auto has_field = field_idx.find("has" + Cap(field.first));
if (has_field != field_idx.end()) {
fields[has_field->second].skip = true;
fields[field.second].optional = true;
}
auto want_field = field_idx.find("want" + Cap(field.first));
if (want_field != field_idx.end() && fields[want_field->second].param_is_set) {
fields[want_field->second].skip = true;
fields[field.second].requested = true;
}
}
}
};

std::string AccessorType(kj::StringPtr base_name, const Field& field)
{
const auto& f = field.param_is_set ? field.param : field.result;
const auto field_name = f.getProto().getName();
const auto field_type = f.getType();

std::ostringstream out;
out << "Accessor<" << base_name << "_fields::" << Cap(field_name) << ", ";
if (!field.param_is_set) {
out << "FIELD_OUT";
} else if (field.result_is_set) {
out << "FIELD_IN | FIELD_OUT";
} else {
out << "FIELD_IN";
}
if (field.optional) out << " | FIELD_OPTIONAL";
if (field.requested) out << " | FIELD_REQUESTED";
if (BoxedType(field_type)) out << " | FIELD_BOXED";
out << ">";
return out.str();
}

// src_file is path to .capnp file to generate stub code from.
//
// src_prefix can be used to generate outputs in a different directory than the
Expand Down Expand Up @@ -332,6 +435,13 @@ static void Generate(kj::StringPtr src_prefix,

if (node.getProto().isStruct()) {
const auto& struc = node.asStruct();

FieldList fields;
for (const auto schema_field : struc.getFields()) {
fields.addField(schema_field, true, true);
}
fields.mergeFields();

std::ostringstream generic_name;
generic_name << node_name;
dec << "template<";
Expand All @@ -352,22 +462,18 @@ static void Generate(kj::StringPtr src_prefix,
dec << "struct ProxyStruct<" << message_namespace << "::" << generic_name.str() << ">\n";
dec << "{\n";
dec << " using Struct = " << message_namespace << "::" << generic_name.str() << ";\n";
for (const auto field : struc.getFields()) {
auto field_name = field.getProto().getName();
for (const auto& field : fields.fields) {
auto field_name = field.param.getProto().getName();
add_accessor(field_name);
dec << " using " << Cap(field_name) << "Accessor = Accessor<" << base_name
<< "_fields::" << Cap(field_name) << ", FIELD_IN | FIELD_OUT";
if (BoxedType(field.getType())) dec << " | FIELD_BOXED";
dec << ">;\n";
dec << " using " << Cap(field_name) << "Accessor = "
<< AccessorType(base_name, field) << ";\n";
}
dec << " using Accessors = std::tuple<";
size_t i = 0;
for (const auto field : struc.getFields()) {
if (AnnotationExists(field.getProto(), SKIP_ANNOTATION_ID)) {
continue;
}
for (const auto& field : fields.fields) {
if (field.skip) continue;
if (i) dec << ", ";
dec << Cap(field.getProto().getName()) << "Accessor";
dec << Cap(field.param.getProto().getName()) << "Accessor";
++i;
}
dec << ">;\n";
Expand All @@ -381,13 +487,11 @@ static void Generate(kj::StringPtr src_prefix,
inl << "public:\n";
inl << " using Struct = " << message_namespace << "::" << node_name << ";\n";
size_t i = 0;
for (const auto field : struc.getFields()) {
if (AnnotationExists(field.getProto(), SKIP_ANNOTATION_ID)) {
continue;
}
auto field_name = field.getProto().getName();
for (const auto& field : fields.fields) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In commit "mpgen: support primitive std::optional struct fields" (6dbfa56):

I think the hasX field here is intended as internal serialization plumbing for the logical optional field, so I would have expected it to stay hidden from the generated ProxyStruct surface rather than also getting its own public accessor alias.

I believe a minimal fix would be to skip alias generation for field.skip entries in the ProxyStruct loop, similar to how they are already skipped when building the Accessors tuple below. Something along these lines:

for (const auto& field : fields.fields) {
      if (field.skip) continue;
      auto field_name = field.param.getProto().getName();
      add_accessor(field_name);
      dec << "    using " << Cap(field_name) << "Accessor =
  "
          << AccessorType(base_name, field) << ";\n";
  

if (field.skip) continue;
auto field_name = field.param.getProto().getName();
auto member_name = field_name;
GetAnnotationText(field.getProto(), NAME_ANNOTATION_ID, &member_name);
GetAnnotationText(field.param.getProto(), NAME_ANNOTATION_ID, &member_name);
inl << " static decltype(auto) get(std::integral_constant<size_t, " << i << ">) { return "
<< "&" << proxied_class_type << "::" << member_name << "; }\n";
++i;
Expand Down Expand Up @@ -430,85 +534,14 @@ static void Generate(kj::StringPtr src_prefix,
const bool is_construct = method_name == kj::StringPtr{"construct"};
const bool is_destroy = method_name == kj::StringPtr{"destroy"};

struct Field
{
::capnp::StructSchema::Field param;
bool param_is_set = false;
::capnp::StructSchema::Field result;
bool result_is_set = false;
int args = 0;
bool retval = false;
bool optional = false;
bool requested = false;
bool skip = false;
kj::StringPtr exception;
};

std::vector<Field> fields;
std::map<kj::StringPtr, int> field_idx; // name -> args index
bool has_result = false;

auto add_field = [&](const ::capnp::StructSchema::Field& schema_field, bool param) {
if (AnnotationExists(schema_field.getProto(), SKIP_ANNOTATION_ID)) {
return;
}

auto field_name = schema_field.getProto().getName();
auto inserted = field_idx.emplace(field_name, fields.size());
if (inserted.second) {
fields.emplace_back();
}
auto& field = fields[inserted.first->second];
if (param) {
field.param = schema_field;
field.param_is_set = true;
} else {
field.result = schema_field;
field.result_is_set = true;
}

if (!param && field_name == kj::StringPtr{"result"}) {
field.retval = true;
has_result = true;
}

GetAnnotationText(schema_field.getProto(), EXCEPTION_ANNOTATION_ID, &field.exception);

int32_t count = 1;
if (!GetAnnotationInt32(schema_field.getProto(), COUNT_ANNOTATION_ID, &count)) {
if (schema_field.getType().isStruct()) {
GetAnnotationInt32(schema_field.getType().asStruct().getProto(),
COUNT_ANNOTATION_ID, &count);
} else if (schema_field.getType().isInterface()) {
GetAnnotationInt32(schema_field.getType().asInterface().getProto(),
COUNT_ANNOTATION_ID, &count);
}
}


if (inserted.second && !field.retval && !field.exception.size()) {
field.args = count;
}
};

FieldList fields;
for (const auto schema_field : method.getParamType().getFields()) {
add_field(schema_field, true);
fields.addField(schema_field, true, false);
}
for (const auto schema_field : method.getResultType().getFields()) {
add_field(schema_field, false);
}
for (auto& field : field_idx) {
auto has_field = field_idx.find("has" + Cap(field.first));
if (has_field != field_idx.end()) {
fields[has_field->second].skip = true;
fields[field.second].optional = true;
}
auto want_field = field_idx.find("want" + Cap(field.first));
if (want_field != field_idx.end() && fields[want_field->second].param_is_set) {
fields[want_field->second].skip = true;
fields[field.second].requested = true;
}
fields.addField(schema_field, false, true);
}
fields.mergeFields();

if (!is_construct && !is_destroy && (&method_interface == &interface)) {
methods << "template<>\n";
Expand All @@ -524,25 +557,11 @@ static void Generate(kj::StringPtr src_prefix,
std::ostringstream server_invoke_start;
std::ostringstream server_invoke_end;
int argc = 0;
for (const auto& field : fields) {
for (const auto& field : fields.fields) {
if (field.skip) continue;

const auto& f = field.param_is_set ? field.param : field.result;
auto field_name = f.getProto().getName();
auto field_type = f.getType();

std::ostringstream field_flags;
if (!field.param_is_set) {
field_flags << "FIELD_OUT";
} else if (field.result_is_set) {
field_flags << "FIELD_IN | FIELD_OUT";
} else {
field_flags << "FIELD_IN";
}
if (field.optional) field_flags << " | FIELD_OPTIONAL";
if (field.requested) field_flags << " | FIELD_REQUESTED";
if (BoxedType(field_type)) field_flags << " | FIELD_BOXED";

add_accessor(field_name);

std::ostringstream fwd_args;
Expand All @@ -569,8 +588,7 @@ static void Generate(kj::StringPtr src_prefix,
client_invoke << "MakeClientParam<";
}

client_invoke << "Accessor<" << base_name << "_fields::" << Cap(field_name) << ", "
<< field_flags.str() << ">>(";
client_invoke << AccessorType(base_name, field) << ">(";

if (field.retval) {
client_invoke << field_name;
Expand All @@ -586,8 +604,7 @@ static void Generate(kj::StringPtr src_prefix,
} else {
server_invoke_start << "MakeServerField<" << field.args;
}
server_invoke_start << ", Accessor<" << base_name << "_fields::" << Cap(field_name) << ", "
<< field_flags.str() << ">>(";
server_invoke_start << ", " << AccessorType(base_name, field) << ">(";
server_invoke_end << ")";
}

Expand All @@ -603,12 +620,12 @@ static void Generate(kj::StringPtr src_prefix,
def_client << "ProxyClient<" << message_namespace << "::" << node_name << ">::M" << method_ordinal
<< "::Result ProxyClient<" << message_namespace << "::" << node_name << ">::" << method_name
<< "(" << super_str << client_args.str() << ") {\n";
if (has_result) {
if (fields.has_result) {
def_client << " typename M" << method_ordinal << "::Result result;\n";
}
def_client << " clientInvoke(" << self_str << ", &" << message_namespace << "::" << node_name
<< "::Client::" << method_name << "Request" << client_invoke.str() << ");\n";
if (has_result) def_client << " return result;\n";
if (fields.has_result) def_client << " return result;\n";
def_client << "}\n";

server << " kj::Promise<void> " << method_name << "(" << Cap(method_name)
Expand Down
1 change: 1 addition & 0 deletions test/mp/test/foo-types.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <mp/type-map.h>
#include <mp/type-message.h>
#include <mp/type-number.h>
#include <mp/type-optional.h>
#include <mp/type-pointer.h>
#include <mp/type-set.h>
#include <mp/type-string.h>
Expand Down
2 changes: 2 additions & 0 deletions test/mp/test/foo.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ struct FooStruct $Proxy.wrap("mp::test::FooStruct") {
name @0 :Text;
setint @1 :List(Int32);
vbool @2 :List(Bool);
optionalInt @3 :Int32 $Proxy.name("optional_int");
hasOptionalInt @4 :Bool;
}

struct FooCustom $Proxy.wrap("mp::test::FooCustom") {
Expand Down
2 changes: 2 additions & 0 deletions test/mp/test/foo.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <set>
#include <vector>
Expand All @@ -21,6 +22,7 @@ struct FooStruct
std::string name;
std::set<int> setint;
std::vector<bool> vbool;
std::optional<int> optional_int;
};

enum class FooEnum : uint8_t { ONE = 1, TWO = 2, };
Expand Down
7 changes: 7 additions & 0 deletions test/mp/test/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ KJ_TEST("Call FooInterface methods")
in.vbool.push_back(false);
in.vbool.push_back(true);
in.vbool.push_back(false);
in.optional_int = 3;
FooStruct out = foo->pass(in);
KJ_EXPECT(in.name == out.name);
KJ_EXPECT(in.setint.size() == out.setint.size());
Expand All @@ -150,6 +151,12 @@ KJ_TEST("Call FooInterface methods")
for (size_t i = 0; i < in.vbool.size(); ++i) {
KJ_EXPECT(in.vbool[i] == out.vbool[i]);
}
KJ_EXPECT(in.optional_int == out.optional_int);

// Additional checks for std::optional member
KJ_EXPECT(foo->pass(in).optional_int == 3);
in.optional_int.reset();
KJ_EXPECT(!foo->pass(in).optional_int);

FooStruct err;
try {
Expand Down
Loading