-
Notifications
You must be signed in to change notification settings - Fork 522
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
pytorch: add patch for CVE-2024-27318, CVE-2022-1941
- Loading branch information
Showing
3 changed files
with
735 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,352 @@ | ||
# Patch generated by Archana Choudhary <[email protected]> | ||
# Source: https://github.com/protocolbuffers/protobuf/commit/55815e423bb82cc828836bbd60c79c1f9a195763 | ||
|
||
diff --color -ruN a/third_party/protobuf/src/google/protobuf/extension_set_inl.h b/third_party/protobuf/src/google/protobuf/extension_set_inl.h | ||
--- a/third_party/protobuf/src/google/protobuf/extension_set_inl.h 2024-03-27 22:28:55.000000000 +0000 | ||
+++ b/third_party/protobuf/src/google/protobuf/extension_set_inl.h 2024-09-18 11:49:16.390834276 +0000 | ||
@@ -206,16 +206,21 @@ | ||
const char* ptr, const Msg* containing_type, | ||
internal::InternalMetadata* metadata, internal::ParseContext* ctx) { | ||
std::string payload; | ||
- uint32 type_id = 0; | ||
- bool payload_read = false; | ||
+ uint32 type_id; | ||
+ enum class State { kNoTag, kHasType, kHasPayload, kDone }; | ||
+ State state = State::kNoTag; | ||
+ | ||
while (!ctx->Done(&ptr)) { | ||
uint32 tag = static_cast<uint8>(*ptr++); | ||
if (tag == WireFormatLite::kMessageSetTypeIdTag) { | ||
uint64 tmp; | ||
ptr = ParseBigVarint(ptr, &tmp); | ||
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); | ||
- type_id = tmp; | ||
- if (payload_read) { | ||
+ if (state == State::kNoTag) { | ||
+ type_id = tmp; | ||
+ state = State::kHasType; | ||
+ } else if (state == State::kHasPayload) { | ||
+ type_id = tmp; | ||
ExtensionInfo extension; | ||
bool was_packed_on_wire; | ||
if (!FindExtension(2, type_id, containing_type, ctx, &extension, | ||
@@ -241,20 +246,24 @@ | ||
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) && | ||
tmp_ctx.EndedAtLimit()); | ||
} | ||
- type_id = 0; | ||
+ state = State::kDone; | ||
} | ||
} else if (tag == WireFormatLite::kMessageSetMessageTag) { | ||
- if (type_id != 0) { | ||
+ if (state == State::kHasType) { | ||
ptr = ParseFieldMaybeLazily(static_cast<uint64>(type_id) * 8 + 2, ptr, | ||
containing_type, metadata, ctx); | ||
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr); | ||
- type_id = 0; | ||
+ state = State::kDone; | ||
} else { | ||
+ std::string tmp; | ||
int32 size = ReadSize(&ptr); | ||
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); | ||
- ptr = ctx->ReadString(ptr, size, &payload); | ||
+ ptr = ctx->ReadString(ptr, size, &tmp); | ||
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); | ||
- payload_read = true; | ||
+ if (state == State::kNoTag) { | ||
+ payload = std::move(tmp); | ||
+ state = State::kHasPayload; | ||
+ } | ||
} | ||
} else { | ||
ptr = ReadTag(ptr - 1, &tag); | ||
diff --color -ruN a/third_party/protobuf/src/google/protobuf/wire_format.cc b/third_party/protobuf/src/google/protobuf/wire_format.cc | ||
--- a/third_party/protobuf/src/google/protobuf/wire_format.cc 2024-03-27 22:28:55.000000000 +0000 | ||
+++ b/third_party/protobuf/src/google/protobuf/wire_format.cc 2024-09-18 11:49:16.390834276 +0000 | ||
@@ -659,9 +659,11 @@ | ||
const char* _InternalParse(const char* ptr, internal::ParseContext* ctx) { | ||
// Parse a MessageSetItem | ||
auto metadata = reflection->MutableInternalMetadata(msg); | ||
+ enum class State { kNoTag, kHasType, kHasPayload, kDone }; | ||
+ State state = State::kNoTag; | ||
+ | ||
std::string payload; | ||
uint32 type_id = 0; | ||
- bool payload_read = false; | ||
while (!ctx->Done(&ptr)) { | ||
// We use 64 bit tags in order to allow typeid's that span the whole | ||
// range of 32 bit numbers. | ||
@@ -670,8 +672,11 @@ | ||
uint64 tmp; | ||
ptr = ParseBigVarint(ptr, &tmp); | ||
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); | ||
- type_id = tmp; | ||
- if (payload_read) { | ||
+ if (state == State::kNoTag) { | ||
+ type_id = tmp; | ||
+ state = State::kHasType; | ||
+ } else if (state == State::kHasPayload) { | ||
+ type_id = tmp; | ||
const FieldDescriptor* field; | ||
if (ctx->data().pool == nullptr) { | ||
field = reflection->FindKnownExtensionByNumber(type_id); | ||
@@ -698,17 +703,17 @@ | ||
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) && | ||
tmp_ctx.EndedAtLimit()); | ||
} | ||
- type_id = 0; | ||
+ state = State::kDone; | ||
} | ||
continue; | ||
} else if (tag == WireFormatLite::kMessageSetMessageTag) { | ||
- if (type_id == 0) { | ||
+ if (state == State::kNoTag) { | ||
int32 size = ReadSize(&ptr); | ||
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); | ||
ptr = ctx->ReadString(ptr, size, &payload); | ||
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); | ||
- payload_read = true; | ||
- } else { | ||
+ state = State::kHasPayload; | ||
+ } else if (state == State::kHasType) { | ||
// We're now parsing the payload | ||
const FieldDescriptor* field = nullptr; | ||
if (descriptor->IsExtensionNumber(type_id)) { | ||
@@ -722,7 +727,12 @@ | ||
ptr = WireFormat::_InternalParseAndMergeField( | ||
msg, ptr, ctx, static_cast<uint64>(type_id) * 8 + 2, reflection, | ||
field); | ||
- type_id = 0; | ||
+ state = State::kDone; | ||
+ } else { | ||
+ int32 size = ReadSize(&ptr); | ||
+ GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); | ||
+ ptr = ctx->Skip(ptr, size); | ||
+ GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); | ||
} | ||
} else { | ||
// An unknown field in MessageSetItem. | ||
diff --color -ruN a/third_party/protobuf/src/google/protobuf/wire_format_lite.h b/third_party/protobuf/src/google/protobuf/wire_format_lite.h | ||
--- a/third_party/protobuf/src/google/protobuf/wire_format_lite.h 2024-03-27 22:28:55.000000000 +0000 | ||
+++ b/third_party/protobuf/src/google/protobuf/wire_format_lite.h 2024-09-18 11:49:16.390834276 +0000 | ||
@@ -1798,6 +1798,9 @@ | ||
// we can parse it later. | ||
std::string message_data; | ||
|
||
+ enum class State { kNoTag, kHasType, kHasPayload, kDone }; | ||
+ State state = State::kNoTag; | ||
+ | ||
while (true) { | ||
const uint32 tag = input->ReadTagNoLastTag(); | ||
if (tag == 0) return false; | ||
@@ -1806,26 +1809,34 @@ | ||
case WireFormatLite::kMessageSetTypeIdTag: { | ||
uint32 type_id; | ||
if (!input->ReadVarint32(&type_id)) return false; | ||
- last_type_id = type_id; | ||
- | ||
- if (!message_data.empty()) { | ||
+ if (state == State::kNoTag) { | ||
+ last_type_id = type_id; | ||
+ state = State::kHasType; | ||
+ } else if (state == State::kHasPayload) { | ||
// We saw some message data before the type_id. Have to parse it | ||
// now. | ||
io::CodedInputStream sub_input( | ||
reinterpret_cast<const uint8*>(message_data.data()), | ||
static_cast<int>(message_data.size())); | ||
sub_input.SetRecursionLimit(input->RecursionBudget()); | ||
- if (!ms.ParseField(last_type_id, &sub_input)) { | ||
+ if (!ms.ParseField(type_id, &sub_input)) { | ||
return false; | ||
} | ||
message_data.clear(); | ||
+ state = State::kDone; | ||
} | ||
|
||
break; | ||
} | ||
|
||
case WireFormatLite::kMessageSetMessageTag: { | ||
- if (last_type_id == 0) { | ||
+ if (state == State::kHasType) { | ||
+ // Already saw type_id, so we can parse this directly. | ||
+ if (!ms.ParseField(last_type_id, input)) { | ||
+ return false; | ||
+ } | ||
+ state = State::kDone; | ||
+ } else if (state == State::kNoTag) { | ||
// We haven't seen a type_id yet. Append this data to message_data. | ||
uint32 length; | ||
if (!input->ReadVarint32(&length)) return false; | ||
@@ -1836,11 +1847,9 @@ | ||
auto ptr = reinterpret_cast<uint8*>(&message_data[0]); | ||
ptr = io::CodedOutputStream::WriteVarint32ToArray(length, ptr); | ||
if (!input->ReadRaw(ptr, length)) return false; | ||
+ state = State::kHasPayload; | ||
} else { | ||
- // Already saw type_id, so we can parse this directly. | ||
- if (!ms.ParseField(last_type_id, input)) { | ||
- return false; | ||
- } | ||
+ if (!ms.SkipField(tag, input)) return false; | ||
} | ||
|
||
break; | ||
diff --color -ruN a/third_party/protobuf/src/google/protobuf/wire_format_unittest.cc b/third_party/protobuf/src/google/protobuf/wire_format_unittest.cc | ||
--- a/third_party/protobuf/src/google/protobuf/wire_format_unittest.cc 2024-03-27 22:28:55.000000000 +0000 | ||
+++ b/third_party/protobuf/src/google/protobuf/wire_format_unittest.cc 2024-09-18 11:49:16.394834273 +0000 | ||
@@ -47,6 +47,7 @@ | ||
#include <google/protobuf/io/zero_copy_stream_impl.h> | ||
#include <google/protobuf/io/zero_copy_stream_impl_lite.h> | ||
#include <google/protobuf/descriptor.h> | ||
+#include <google/protobuf/dynamic_message.h> | ||
#include <google/protobuf/wire_format_lite.h> | ||
#include <google/protobuf/testing/googletest.h> | ||
#include <gmock/gmock.h> | ||
@@ -585,30 +586,56 @@ | ||
EXPECT_EQ(message_set.DebugString(), dynamic_message_set.DebugString()); | ||
} | ||
|
||
-TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) { | ||
+namespace { | ||
+std::string BuildMessageSetItemStart() { | ||
std::string data; | ||
{ | ||
- unittest::TestMessageSetExtension1 message; | ||
- message.set_i(123); | ||
- // Build a MessageSet manually with its message content put before its | ||
- // type_id. | ||
io::StringOutputStream output_stream(&data); | ||
io::CodedOutputStream coded_output(&output_stream); | ||
coded_output.WriteTag(WireFormatLite::kMessageSetItemStartTag); | ||
+ } | ||
+ return data; | ||
+} | ||
+std::string BuildMessageSetItemEnd() { | ||
+ std::string data; | ||
+ { | ||
+ io::StringOutputStream output_stream(&data); | ||
+ io::CodedOutputStream coded_output(&output_stream); | ||
+ coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag); | ||
+ } | ||
+ return data; | ||
+} | ||
+std::string BuildMessageSetTestExtension1(int value = 123) { | ||
+ std::string data; | ||
+ { | ||
+ unittest::TestMessageSetExtension1 message; | ||
+ message.set_i(value); | ||
+ io::StringOutputStream output_stream(&data); | ||
+ io::CodedOutputStream coded_output(&output_stream); | ||
// Write the message content first. | ||
WireFormatLite::WriteTag(WireFormatLite::kMessageSetMessageNumber, | ||
WireFormatLite::WIRETYPE_LENGTH_DELIMITED, | ||
&coded_output); | ||
coded_output.WriteVarint32(message.ByteSizeLong()); | ||
message.SerializeWithCachedSizes(&coded_output); | ||
- // Write the type id. | ||
- uint32 type_id = message.GetDescriptor()->extension(0)->number(); | ||
+ } | ||
+ return data; | ||
+} | ||
+std::string BuildMessageSetItemTypeId(int extension_number) { | ||
+ std::string data; | ||
+ { | ||
+ io::StringOutputStream output_stream(&data); | ||
+ io::CodedOutputStream coded_output(&output_stream); | ||
WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber, | ||
- type_id, &coded_output); | ||
- coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag); | ||
+ extension_number, &coded_output); | ||
} | ||
+ return data; | ||
+} | ||
+void ValidateTestMessageSet(const std::string& test_case, | ||
+ const std::string& data) { | ||
+ SCOPED_TRACE(test_case); | ||
{ | ||
- proto2_wireformat_unittest::TestMessageSet message_set; | ||
+ ::proto2_wireformat_unittest::TestMessageSet message_set; | ||
ASSERT_TRUE(message_set.ParseFromString(data)); | ||
|
||
EXPECT_EQ(123, | ||
@@ -616,10 +643,15 @@ | ||
.GetExtension( | ||
unittest::TestMessageSetExtension1::message_set_extension) | ||
.i()); | ||
+ | ||
+ // Make sure it does not contain anything else. | ||
+ message_set.ClearExtension( | ||
+ unittest::TestMessageSetExtension1::message_set_extension); | ||
+ EXPECT_EQ(message_set.SerializeAsString(), ""); | ||
} | ||
{ | ||
// Test parse the message via Reflection. | ||
- proto2_wireformat_unittest::TestMessageSet message_set; | ||
+ ::proto2_wireformat_unittest::TestMessageSet message_set; | ||
io::CodedInputStream input(reinterpret_cast<const uint8*>(data.data()), | ||
data.size()); | ||
EXPECT_TRUE(WireFormat::ParseAndMergePartial(&input, &message_set)); | ||
@@ -631,6 +663,61 @@ | ||
unittest::TestMessageSetExtension1::message_set_extension) | ||
.i()); | ||
} | ||
+ { | ||
+ // Test parse the message via DynamicMessage. | ||
+ DynamicMessageFactory factory; | ||
+ std::unique_ptr<Message> msg( | ||
+ factory | ||
+ .GetPrototype( | ||
+ ::proto2_wireformat_unittest::TestMessageSet::descriptor()) | ||
+ ->New()); | ||
+ msg->ParseFromString(data); | ||
+ auto* reflection = msg->GetReflection(); | ||
+ std::vector<const FieldDescriptor*> fields; | ||
+ reflection->ListFields(*msg, &fields); | ||
+ ASSERT_EQ(fields.size(), 1); | ||
+ const auto& sub = reflection->GetMessage(*msg, fields[0]); | ||
+ reflection = sub.GetReflection(); | ||
+ EXPECT_EQ(123, reflection->GetInt32( | ||
+ sub, sub.GetDescriptor()->FindFieldByName("i"))); | ||
+ } | ||
+} | ||
+} // namespace | ||
+ | ||
+TEST(WireFormatTest, ParseMessageSetWithAnyTagOrder) { | ||
+ std::string start = BuildMessageSetItemStart(); | ||
+ std::string end = BuildMessageSetItemEnd(); | ||
+ std::string id = BuildMessageSetItemTypeId( | ||
+ unittest::TestMessageSetExtension1::descriptor()->extension(0)->number()); | ||
+ std::string message = BuildMessageSetTestExtension1(); | ||
+ | ||
+ ValidateTestMessageSet("id + message", start + id + message + end); | ||
+ ValidateTestMessageSet("message + id", start + message + id + end); | ||
+} | ||
+ | ||
+TEST(WireFormatTest, ParseMessageSetWithDuplicateTags) { | ||
+ std::string start = BuildMessageSetItemStart(); | ||
+ std::string end = BuildMessageSetItemEnd(); | ||
+ std::string id = BuildMessageSetItemTypeId( | ||
+ unittest::TestMessageSetExtension1::descriptor()->extension(0)->number()); | ||
+ std::string other_id = BuildMessageSetItemTypeId(123456); | ||
+ std::string message = BuildMessageSetTestExtension1(); | ||
+ std::string other_message = BuildMessageSetTestExtension1(321); | ||
+ | ||
+ // Double id | ||
+ ValidateTestMessageSet("id + other_id + message", | ||
+ start + id + other_id + message + end); | ||
+ ValidateTestMessageSet("id + message + other_id", | ||
+ start + id + message + other_id + end); | ||
+ ValidateTestMessageSet("message + id + other_id", | ||
+ start + message + id + other_id + end); | ||
+ // Double message | ||
+ ValidateTestMessageSet("id + message + other_message", | ||
+ start + id + message + other_message + end); | ||
+ ValidateTestMessageSet("message + id + other_message", | ||
+ start + message + id + other_message + end); | ||
+ ValidateTestMessageSet("message + other_message + id", | ||
+ start + message + other_message + id + end); | ||
} | ||
|
||
void SerializeReverseOrder( |
Oops, something went wrong.