From a667894b6c28aa73a20068a141d586407b8a1d04 Mon Sep 17 00:00:00 2001 From: Justin Boswell Date: Thu, 30 Nov 2023 23:07:03 -0800 Subject: [PATCH] Added support for template deduction guides --- cxxheaderparser/parser.py | 59 ++++++++++++++++++++++ tests/test_template.py | 103 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 162 insertions(+) diff --git a/cxxheaderparser/parser.py b/cxxheaderparser/parser.py index 65be148..63183de 100644 --- a/cxxheaderparser/parser.py +++ b/cxxheaderparser/parser.py @@ -2563,6 +2563,54 @@ def _parse_operator_conversion( _class_enum_stage2 = {":", "final", "explicit", "{"} + def _maybe_parse_deduction_guide( + self, + parsed_type: Type, + mods: ParsedTypeModifiers, + doxygen: typing.Optional[str], + template: TemplateDeclTypeVar, + location: Location, + ) -> bool: + """ + Parses a deduction guide, e.g. + template + MyClass(T) -> MyClass(U); + """ + + tok = self.lex.token_if("(") + if not tok: + return False + + # scan past any possible parameters... + toks = self._consume_balanced_tokens(tok) + # ... and check to see if the next token is an arrow, and thus a trailing return + if not self.lex.token_peek_if("ARROW"): + self.lex.return_tokens(toks) + return False + + # return the function params so they can be parsed properly, leaving out the leading ( as + # _parse_function expects that to have already been consumed + self.lex.return_tokens(toks[1:]) + + # use auto as the return type, which will force parsing of trailing return type + return_type = Type(PQName([AutoSpecifier()])) + pqname = parsed_type.typename + msvc_convention = self.lex.token_if_val(*self._msvc_conventions) + return self._parse_function( + mods, + return_type, + pqname, + None, + template, + doxygen, + location, + False, + False, + False, + False, + msvc_convention, + ) + def _parse_declarations( self, tok: LexToken, @@ -2593,6 +2641,17 @@ def _parse_declarations( ): return + # check to see if this is a class template deduction guide + if ( + parsed_type is not None + and not is_typedef + and not is_friend + and self._maybe_parse_deduction_guide( + parsed_type, mods, doxygen, template, location + ) + ): + return + var_ok = True if is_typedef: diff --git a/tests/test_template.py b/tests/test_template.py index 344e98f..31511fd 100644 --- a/tests/test_template.py +++ b/tests/test_template.py @@ -2163,3 +2163,106 @@ def test_member_class_template_specialization() -> None: ] ) ) + + +def test_template_deduction_guide() -> None: + content = """ + template > + Error(std::basic_string_view) -> Error; + """ + data = parse_string(content, cleandoc=True) + + assert data == ParsedData( + namespace=NamespaceScope( + functions=[ + Function( + return_type=Type( + typename=PQName( + segments=[ + NameSpecifier( + name="Error", + specialization=TemplateSpecialization( + args=[ + TemplateArgument( + arg=Type( + typename=PQName( + segments=[ + NameSpecifier(name="std"), + NameSpecifier( + name="string" + ), + ] + ) + ) + ) + ] + ), + ) + ] + ) + ), + name=PQName(segments=[NameSpecifier(name="Error")]), + parameters=[ + Parameter( + type=Type( + typename=PQName( + segments=[ + NameSpecifier(name="std"), + NameSpecifier( + name="basic_string_view", + specialization=TemplateSpecialization( + args=[ + TemplateArgument( + arg=Type( + typename=PQName( + segments=[ + NameSpecifier( + name="CharT" + ) + ] + ) + ) + ), + TemplateArgument( + arg=Type( + typename=PQName( + segments=[ + NameSpecifier( + name="Traits" + ) + ] + ) + ) + ), + ] + ), + ), + ] + ) + ) + ) + ], + has_trailing_return=True, + template=TemplateDecl( + params=[ + TemplateTypeParam(typekey="class", name="CharT"), + TemplateTypeParam( + typekey="class", + name="Traits", + default=Value( + tokens=[ + Token(value="std"), + Token(value="::"), + Token(value="char_traits"), + Token(value="<"), + Token(value="CharT"), + Token(value=">"), + ] + ), + ), + ] + ), + ) + ] + ) + )