Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pytorch: add patch for CVE-2024-27318, CVE-2022-1941 #10469

Open
wants to merge 1 commit into
base: 3.0-dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
352 changes: 352 additions & 0 deletions SPECS/pytorch/CVE-2022-1941.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,352 @@
# Patch generated by Archana Choudhary <[email protected]>
# Source: https://github.com/protocolbuffers/protobuf/commit/55815e423bb82cc828836bbd60c79c1f9a195763

diff --color -ruN a/third_party/protobuf/src/google/protobuf/extension_set_inl.h b/third_party/protobuf/src/google/protobuf/extension_set_inl.h
--- a/third_party/protobuf/src/google/protobuf/extension_set_inl.h 2024-03-27 22:28:55.000000000 +0000
+++ b/third_party/protobuf/src/google/protobuf/extension_set_inl.h 2024-09-18 11:49:16.390834276 +0000
@@ -206,16 +206,21 @@
const char* ptr, const Msg* containing_type,
internal::InternalMetadata* metadata, internal::ParseContext* ctx) {
std::string payload;
- uint32 type_id = 0;
- bool payload_read = false;
+ uint32 type_id;
+ enum class State { kNoTag, kHasType, kHasPayload, kDone };
+ State state = State::kNoTag;
+
while (!ctx->Done(&ptr)) {
uint32 tag = static_cast<uint8>(*ptr++);
if (tag == WireFormatLite::kMessageSetTypeIdTag) {
uint64 tmp;
ptr = ParseBigVarint(ptr, &tmp);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
- type_id = tmp;
- if (payload_read) {
+ if (state == State::kNoTag) {
+ type_id = tmp;
+ state = State::kHasType;
+ } else if (state == State::kHasPayload) {
+ type_id = tmp;
ExtensionInfo extension;
bool was_packed_on_wire;
if (!FindExtension(2, type_id, containing_type, ctx, &extension,
@@ -241,20 +246,24 @@
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
tmp_ctx.EndedAtLimit());
}
- type_id = 0;
+ state = State::kDone;
}
} else if (tag == WireFormatLite::kMessageSetMessageTag) {
- if (type_id != 0) {
+ if (state == State::kHasType) {
ptr = ParseFieldMaybeLazily(static_cast<uint64>(type_id) * 8 + 2, ptr,
containing_type, metadata, ctx);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr);
- type_id = 0;
+ state = State::kDone;
} else {
+ std::string tmp;
int32 size = ReadSize(&ptr);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
- ptr = ctx->ReadString(ptr, size, &payload);
+ ptr = ctx->ReadString(ptr, size, &tmp);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
- payload_read = true;
+ if (state == State::kNoTag) {
+ payload = std::move(tmp);
+ state = State::kHasPayload;
+ }
}
} else {
ptr = ReadTag(ptr - 1, &tag);
diff --color -ruN a/third_party/protobuf/src/google/protobuf/wire_format.cc b/third_party/protobuf/src/google/protobuf/wire_format.cc
--- a/third_party/protobuf/src/google/protobuf/wire_format.cc 2024-03-27 22:28:55.000000000 +0000
+++ b/third_party/protobuf/src/google/protobuf/wire_format.cc 2024-09-18 11:49:16.390834276 +0000
@@ -659,9 +659,11 @@
const char* _InternalParse(const char* ptr, internal::ParseContext* ctx) {
// Parse a MessageSetItem
auto metadata = reflection->MutableInternalMetadata(msg);
+ enum class State { kNoTag, kHasType, kHasPayload, kDone };
+ State state = State::kNoTag;
+
std::string payload;
uint32 type_id = 0;
- bool payload_read = false;
while (!ctx->Done(&ptr)) {
// We use 64 bit tags in order to allow typeid's that span the whole
// range of 32 bit numbers.
@@ -670,8 +672,11 @@
uint64 tmp;
ptr = ParseBigVarint(ptr, &tmp);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
- type_id = tmp;
- if (payload_read) {
+ if (state == State::kNoTag) {
+ type_id = tmp;
+ state = State::kHasType;
+ } else if (state == State::kHasPayload) {
+ type_id = tmp;
const FieldDescriptor* field;
if (ctx->data().pool == nullptr) {
field = reflection->FindKnownExtensionByNumber(type_id);
@@ -698,17 +703,17 @@
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
tmp_ctx.EndedAtLimit());
}
- type_id = 0;
+ state = State::kDone;
}
continue;
} else if (tag == WireFormatLite::kMessageSetMessageTag) {
- if (type_id == 0) {
+ if (state == State::kNoTag) {
int32 size = ReadSize(&ptr);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
ptr = ctx->ReadString(ptr, size, &payload);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
- payload_read = true;
- } else {
+ state = State::kHasPayload;
+ } else if (state == State::kHasType) {
// We're now parsing the payload
const FieldDescriptor* field = nullptr;
if (descriptor->IsExtensionNumber(type_id)) {
@@ -722,7 +727,12 @@
ptr = WireFormat::_InternalParseAndMergeField(
msg, ptr, ctx, static_cast<uint64>(type_id) * 8 + 2, reflection,
field);
- type_id = 0;
+ state = State::kDone;
+ } else {
+ int32 size = ReadSize(&ptr);
+ GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
+ ptr = ctx->Skip(ptr, size);
+ GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
}
} else {
// An unknown field in MessageSetItem.
diff --color -ruN a/third_party/protobuf/src/google/protobuf/wire_format_lite.h b/third_party/protobuf/src/google/protobuf/wire_format_lite.h
--- a/third_party/protobuf/src/google/protobuf/wire_format_lite.h 2024-03-27 22:28:55.000000000 +0000
+++ b/third_party/protobuf/src/google/protobuf/wire_format_lite.h 2024-09-18 11:49:16.390834276 +0000
@@ -1798,6 +1798,9 @@
// we can parse it later.
std::string message_data;

+ enum class State { kNoTag, kHasType, kHasPayload, kDone };
+ State state = State::kNoTag;
+
while (true) {
const uint32 tag = input->ReadTagNoLastTag();
if (tag == 0) return false;
@@ -1806,26 +1809,34 @@
case WireFormatLite::kMessageSetTypeIdTag: {
uint32 type_id;
if (!input->ReadVarint32(&type_id)) return false;
- last_type_id = type_id;
-
- if (!message_data.empty()) {
+ if (state == State::kNoTag) {
+ last_type_id = type_id;
+ state = State::kHasType;
+ } else if (state == State::kHasPayload) {
// We saw some message data before the type_id. Have to parse it
// now.
io::CodedInputStream sub_input(
reinterpret_cast<const uint8*>(message_data.data()),
static_cast<int>(message_data.size()));
sub_input.SetRecursionLimit(input->RecursionBudget());
- if (!ms.ParseField(last_type_id, &sub_input)) {
+ if (!ms.ParseField(type_id, &sub_input)) {
return false;
}
message_data.clear();
+ state = State::kDone;
}

break;
}

case WireFormatLite::kMessageSetMessageTag: {
- if (last_type_id == 0) {
+ if (state == State::kHasType) {
+ // Already saw type_id, so we can parse this directly.
+ if (!ms.ParseField(last_type_id, input)) {
+ return false;
+ }
+ state = State::kDone;
+ } else if (state == State::kNoTag) {
// We haven't seen a type_id yet. Append this data to message_data.
uint32 length;
if (!input->ReadVarint32(&length)) return false;
@@ -1836,11 +1847,9 @@
auto ptr = reinterpret_cast<uint8*>(&message_data[0]);
ptr = io::CodedOutputStream::WriteVarint32ToArray(length, ptr);
if (!input->ReadRaw(ptr, length)) return false;
+ state = State::kHasPayload;
} else {
- // Already saw type_id, so we can parse this directly.
- if (!ms.ParseField(last_type_id, input)) {
- return false;
- }
+ if (!ms.SkipField(tag, input)) return false;
}

break;
diff --color -ruN a/third_party/protobuf/src/google/protobuf/wire_format_unittest.cc b/third_party/protobuf/src/google/protobuf/wire_format_unittest.cc
--- a/third_party/protobuf/src/google/protobuf/wire_format_unittest.cc 2024-03-27 22:28:55.000000000 +0000
+++ b/third_party/protobuf/src/google/protobuf/wire_format_unittest.cc 2024-09-18 11:49:16.394834273 +0000
@@ -47,6 +47,7 @@
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
#include <google/protobuf/descriptor.h>
+#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/wire_format_lite.h>
#include <google/protobuf/testing/googletest.h>
#include <gmock/gmock.h>
@@ -585,30 +586,56 @@
EXPECT_EQ(message_set.DebugString(), dynamic_message_set.DebugString());
}

-TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
+namespace {
+std::string BuildMessageSetItemStart() {
std::string data;
{
- unittest::TestMessageSetExtension1 message;
- message.set_i(123);
- // Build a MessageSet manually with its message content put before its
- // type_id.
io::StringOutputStream output_stream(&data);
io::CodedOutputStream coded_output(&output_stream);
coded_output.WriteTag(WireFormatLite::kMessageSetItemStartTag);
+ }
+ return data;
+}
+std::string BuildMessageSetItemEnd() {
+ std::string data;
+ {
+ io::StringOutputStream output_stream(&data);
+ io::CodedOutputStream coded_output(&output_stream);
+ coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
+ }
+ return data;
+}
+std::string BuildMessageSetTestExtension1(int value = 123) {
+ std::string data;
+ {
+ unittest::TestMessageSetExtension1 message;
+ message.set_i(value);
+ io::StringOutputStream output_stream(&data);
+ io::CodedOutputStream coded_output(&output_stream);
// Write the message content first.
WireFormatLite::WriteTag(WireFormatLite::kMessageSetMessageNumber,
WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
&coded_output);
coded_output.WriteVarint32(message.ByteSizeLong());
message.SerializeWithCachedSizes(&coded_output);
- // Write the type id.
- uint32 type_id = message.GetDescriptor()->extension(0)->number();
+ }
+ return data;
+}
+std::string BuildMessageSetItemTypeId(int extension_number) {
+ std::string data;
+ {
+ io::StringOutputStream output_stream(&data);
+ io::CodedOutputStream coded_output(&output_stream);
WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber,
- type_id, &coded_output);
- coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
+ extension_number, &coded_output);
}
+ return data;
+}
+void ValidateTestMessageSet(const std::string& test_case,
+ const std::string& data) {
+ SCOPED_TRACE(test_case);
{
- proto2_wireformat_unittest::TestMessageSet message_set;
+ ::proto2_wireformat_unittest::TestMessageSet message_set;
ASSERT_TRUE(message_set.ParseFromString(data));

EXPECT_EQ(123,
@@ -616,10 +643,15 @@
.GetExtension(
unittest::TestMessageSetExtension1::message_set_extension)
.i());
+
+ // Make sure it does not contain anything else.
+ message_set.ClearExtension(
+ unittest::TestMessageSetExtension1::message_set_extension);
+ EXPECT_EQ(message_set.SerializeAsString(), "");
}
{
// Test parse the message via Reflection.
- proto2_wireformat_unittest::TestMessageSet message_set;
+ ::proto2_wireformat_unittest::TestMessageSet message_set;
io::CodedInputStream input(reinterpret_cast<const uint8*>(data.data()),
data.size());
EXPECT_TRUE(WireFormat::ParseAndMergePartial(&input, &message_set));
@@ -631,6 +663,61 @@
unittest::TestMessageSetExtension1::message_set_extension)
.i());
}
+ {
+ // Test parse the message via DynamicMessage.
+ DynamicMessageFactory factory;
+ std::unique_ptr<Message> msg(
+ factory
+ .GetPrototype(
+ ::proto2_wireformat_unittest::TestMessageSet::descriptor())
+ ->New());
+ msg->ParseFromString(data);
+ auto* reflection = msg->GetReflection();
+ std::vector<const FieldDescriptor*> fields;
+ reflection->ListFields(*msg, &fields);
+ ASSERT_EQ(fields.size(), 1);
+ const auto& sub = reflection->GetMessage(*msg, fields[0]);
+ reflection = sub.GetReflection();
+ EXPECT_EQ(123, reflection->GetInt32(
+ sub, sub.GetDescriptor()->FindFieldByName("i")));
+ }
+}
+} // namespace
+
+TEST(WireFormatTest, ParseMessageSetWithAnyTagOrder) {
+ std::string start = BuildMessageSetItemStart();
+ std::string end = BuildMessageSetItemEnd();
+ std::string id = BuildMessageSetItemTypeId(
+ unittest::TestMessageSetExtension1::descriptor()->extension(0)->number());
+ std::string message = BuildMessageSetTestExtension1();
+
+ ValidateTestMessageSet("id + message", start + id + message + end);
+ ValidateTestMessageSet("message + id", start + message + id + end);
+}
+
+TEST(WireFormatTest, ParseMessageSetWithDuplicateTags) {
+ std::string start = BuildMessageSetItemStart();
+ std::string end = BuildMessageSetItemEnd();
+ std::string id = BuildMessageSetItemTypeId(
+ unittest::TestMessageSetExtension1::descriptor()->extension(0)->number());
+ std::string other_id = BuildMessageSetItemTypeId(123456);
+ std::string message = BuildMessageSetTestExtension1();
+ std::string other_message = BuildMessageSetTestExtension1(321);
+
+ // Double id
+ ValidateTestMessageSet("id + other_id + message",
+ start + id + other_id + message + end);
+ ValidateTestMessageSet("id + message + other_id",
+ start + id + message + other_id + end);
+ ValidateTestMessageSet("message + id + other_id",
+ start + message + id + other_id + end);
+ // Double message
+ ValidateTestMessageSet("id + message + other_message",
+ start + id + message + other_message + end);
+ ValidateTestMessageSet("message + id + other_message",
+ start + message + id + other_message + end);
+ ValidateTestMessageSet("message + other_message + id",
+ start + message + other_message + id + end);
}

void SerializeReverseOrder(
Loading
Loading