Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix IAsyncEnumerable controller methods to allow setting headers #57924

Open
wants to merge 3 commits into
base: release/9.0-rc2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions src/Http/Http.Extensions/src/HttpResponseJsonExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
using Microsoft.AspNetCore.Http.Json;
using Microsoft.AspNetCore.Internal;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;

Expand Down Expand Up @@ -91,7 +92,9 @@ public static Task WriteAsJsonAsync<TValue>(
response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset;

var startTask = Task.CompletedTask;
if (!response.HasStarted)
// Don't call StartAsync for IAsyncEnumerable. Headers might be set at the beginning of the generator which isn't invoked until
// JsonSerializer starts iterating over the IAsyncEnumerable.
if (!response.HasStarted && !AsyncEnumerableHelper.IsIAsyncEnumerable(typeof(TValue)))
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
startTask = response.StartAsync(cancellationToken);
Expand Down Expand Up @@ -132,7 +135,9 @@ public static Task WriteAsJsonAsync<TValue>(
response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset;

var startTask = Task.CompletedTask;
if (!response.HasStarted)
// Don't call StartAsync for IAsyncEnumerable. Headers might be set at the beginning of the generator which isn't invoked until
// JsonSerializer starts iterating over the IAsyncEnumerable.
if (!response.HasStarted && !AsyncEnumerableHelper.IsIAsyncEnumerable(typeof(TValue)))
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
startTask = response.StartAsync(cancellationToken);
Expand Down Expand Up @@ -185,7 +190,9 @@ public static Task WriteAsJsonAsync(
response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset;

var startTask = Task.CompletedTask;
if (!response.HasStarted)
// Don't call StartAsync for IAsyncEnumerable. Headers might be set at the beginning of the generator which isn't invoked until
// JsonSerializer starts iterating over the IAsyncEnumerable.
if (!response.HasStarted && value is not null && !AsyncEnumerableHelper.IsIAsyncEnumerable(value.GetType()))
Copy link
Member

@halter73 halter73 Sep 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a method to AsyncEnumerableHelper like the following?

Task StartResponseIfPossibleAsync(this HttpResponse response, Type? typeToSerialize)

It seems like it would reduce a lot of repetitive HasStarted and null checks.

{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
startTask = response.StartAsync(cancellationToken);
Expand Down Expand Up @@ -305,7 +312,9 @@ public static Task WriteAsJsonAsync(
response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset;

var startTask = Task.CompletedTask;
if (!response.HasStarted)
// Don't call StartAsync for IAsyncEnumerable. Headers might be set at the beginning of the generator which isn't invoked until
// JsonSerializer starts iterating over the IAsyncEnumerable.
if (!response.HasStarted && !AsyncEnumerableHelper.IsIAsyncEnumerable(type))
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
startTask = response.StartAsync(cancellationToken);
Expand Down Expand Up @@ -368,7 +377,9 @@ public static Task WriteAsJsonAsync(
response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset;

var startTask = Task.CompletedTask;
if (!response.HasStarted)
// Don't call StartAsync for IAsyncEnumerable. Headers might be set at the beginning of the generator which isn't invoked until
// JsonSerializer starts iterating over the IAsyncEnumerable.
if (!response.HasStarted && !AsyncEnumerableHelper.IsIAsyncEnumerable(type))
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
startTask = response.StartAsync(cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
<Compile Remove="$(RepoRoot)src\Components\Endpoints\src\FormMapping\HttpContextFormDataProvider.cs" LinkBase="SharedFormMapping" />
<Compile Remove="$(RepoRoot)src\Components\Endpoints\src\FormMapping\BrowserFileFromFormFile.cs" LinkBase="SharedFormMapping" />
<Compile Include="$(SharedSourceRoot)ContentTypeConstants.cs" LinkBase="Shared" />
<Compile Include="$(SharedSourceRoot)Reflection\AsyncEnumerableHelper.cs" LinkBase="Shared" />
</ItemGroup>

<ItemGroup>
Expand Down
117 changes: 116 additions & 1 deletion src/Http/Http.Extensions/test/HttpResponseJsonExtensionsTests.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.IO.Pipelines;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.AspNetCore.Http.Features;

#nullable enable

Expand Down Expand Up @@ -481,6 +482,83 @@ public async Task WriteAsJsonAsync_NullValue_WithJsonTypeInfo_JsonResponse()
Assert.Equal("null", data);
}

[Fact]
public async Task WriteAsJsonAsyncGeneric_AsyncEnumerableStartAsyncNotCalled()
{
// Arrange
var body = new MemoryStream();
var context = new DefaultHttpContext();
context.Response.Body = body;
var responseBodyFeature = new TestHttpResponseBodyFeature(context.Features.GetRequiredFeature<IHttpResponseBodyFeature>());
context.Features.Set<IHttpResponseBodyFeature>(responseBodyFeature);

// Act
await context.Response.WriteAsJsonAsync(AsyncEnumerable());

// Assert
Assert.Equal(ContentTypeConstants.JsonContentTypeWithCharset, context.Response.ContentType);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);

Assert.Equal("[1,2]", Encoding.UTF8.GetString(body.ToArray()));

async IAsyncEnumerable<int> AsyncEnumerable()
{
Assert.False(responseBodyFeature.StartCalled);
await Task.Yield();
Comment on lines +506 to +507
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be better to test this after yielding?

Suggested change
Assert.False(responseBodyFeature.StartCalled);
await Task.Yield();
await Task.Yield();
Assert.False(responseBodyFeature.StartCalled);

I wonder if we should add an end-to-end test where we try adding a header after yielding.

yield return 1;
yield return 2;
}
}

[Fact]
public async Task WriteAsJsonAsync_AsyncEnumerableStartAsyncNotCalled()
{
// Arrange
var body = new MemoryStream();
var context = new DefaultHttpContext();
context.Response.Body = body;
var responseBodyFeature = new TestHttpResponseBodyFeature(context.Features.GetRequiredFeature<IHttpResponseBodyFeature>());
context.Features.Set<IHttpResponseBodyFeature>(responseBodyFeature);

// Act
await context.Response.WriteAsJsonAsync(AsyncEnumerable(), typeof(IAsyncEnumerable<int>));

// Assert
Assert.Equal(ContentTypeConstants.JsonContentTypeWithCharset, context.Response.ContentType);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);

Assert.Equal("[1,2]", Encoding.UTF8.GetString(body.ToArray()));

async IAsyncEnumerable<int> AsyncEnumerable()
{
Assert.False(responseBodyFeature.StartCalled);
await Task.Yield();
yield return 1;
yield return 2;
}
}

[Fact]
public async Task WriteAsJsonAsync_StartAsyncCalled()
{
// Arrange
var body = new MemoryStream();
var context = new DefaultHttpContext();
context.Response.Body = body;
var responseBodyFeature = new TestHttpResponseBodyFeature(context.Features.GetRequiredFeature<IHttpResponseBodyFeature>());
context.Features.Set<IHttpResponseBodyFeature>(responseBodyFeature);

// Act
await context.Response.WriteAsJsonAsync(new int[] {1, 2}, typeof(int[]));

// Assert
Assert.Equal(ContentTypeConstants.JsonContentTypeWithCharset, context.Response.ContentType);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);

Assert.Equal("[1,2]", Encoding.UTF8.GetString(body.ToArray()));
Assert.True(responseBodyFeature.StartCalled);
}

public class TestObject
{
public string? StringProperty { get; set; }
Expand Down Expand Up @@ -530,4 +608,41 @@ public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationTo
return new ValueTask(tcs.Task);
}
}

public class TestHttpResponseBodyFeature : IHttpResponseBodyFeature
{
private readonly IHttpResponseBodyFeature _inner;

public bool StartCalled;

public TestHttpResponseBodyFeature(IHttpResponseBodyFeature inner)
{
_inner = inner;
}

public Stream Stream => _inner.Stream;

public PipeWriter Writer => _inner.Writer;

public Task CompleteAsync()
{
return _inner.CompleteAsync();
}

public void DisableBuffering()
{
_inner.DisableBuffering();
}

public Task SendFileAsync(string path, long offset, long? count, CancellationToken cancellationToken = default)
{
return _inner.SendFileAsync(path, offset, count, cancellationToken);
}

public Task StartAsync(CancellationToken cancellationToken = default)
{
StartCalled = true;
return _inner.StartAsync(cancellationToken);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Internal;

namespace Microsoft.AspNetCore.Mvc.Formatters;

Expand Down Expand Up @@ -88,10 +89,17 @@ public sealed override async Task WriteResponseBodyAsync(OutputFormatterWriteCon
try
{
var responseWriter = httpContext.Response.BodyWriter;

if (!httpContext.Response.HasStarted)
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
await httpContext.Response.StartAsync();
var typeToCheck = context.ObjectType ?? context.Object?.GetType();
// Don't call StartAsync for IAsyncEnumerable methods. Headers might be set in the controller method which isn't invoked until
// JsonSerializer starts iterating over the IAsyncEnumerable.
if (typeToCheck is not null && !AsyncEnumerableHelper.IsIAsyncEnumerable(typeToCheck))
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
await httpContext.Response.StartAsync();
}
}

if (jsonTypeInfo is not null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,13 @@ public async Task ExecuteAsync(ActionContext context, JsonResult result)
var responseWriter = response.BodyWriter;
if (!response.HasStarted)
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
await response.StartAsync();
// Don't call StartAsync for IAsyncEnumerable methods. Headers might be set in the controller method which isn't invoked until
// JsonSerializer starts iterating over the IAsyncEnumerable.
if (!AsyncEnumerableHelper.IsIAsyncEnumerable(objectType))
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
await response.StartAsync();
}
}

await JsonSerializer.SerializeAsync(responseWriter, value, objectType, jsonSerializerOptions, context.HttpContext.RequestAborted);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Microsoft.AspNetCore.Mvc.RouteAttribute</Description>
<Compile Include="$(SharedSourceRoot)HttpParseResult.cs" LinkBase="Shared" />
<Compile Include="$(SharedSourceRoot)HttpRuleParser.cs" LinkBase="Shared" />
<Compile Include="$(SharedSourceRoot)Json\JsonSerializerExtensions.cs" LinkBase="Shared" />
<Compile Include="$(SharedSourceRoot)Reflection\AsyncEnumerableHelper.cs" LinkBase="Shared" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.IO.Pipelines;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.DotNet.RemoteExecutor;
using Microsoft.Extensions.Primitives;
Expand Down Expand Up @@ -113,6 +115,121 @@ public async Task WriteResponseBodyAsync_ForLargeAsyncEnumerable()
Assert.Equal(expected.ToArray(), body.ToArray());
}

// Regression test: https://github.com/dotnet/aspnetcore/issues/57895
[Fact]
public async Task WriteResponseBodyAsync_AsyncEnumerableStartAsyncNotCalled()
{
// Arrange
TestHttpResponseBodyFeature responseBodyFeature = null;
var expected = new MemoryStream();
await JsonSerializer.SerializeAsync(expected, AsyncEnumerable(), new JsonSerializerOptions(JsonSerializerDefaults.Web));
var formatter = GetOutputFormatter();
var mediaType = MediaTypeHeaderValue.Parse("application/json; charset=utf-8");
var encoding = CreateOrGetSupportedEncoding(formatter, "utf-8", isDefaultEncoding: true);

var body = new MemoryStream();

var actionContext = GetActionContext(mediaType, body);
responseBodyFeature = new TestHttpResponseBodyFeature(actionContext.HttpContext.Features.Get<IHttpResponseBodyFeature>());
actionContext.HttpContext.Features.Set<IHttpResponseBodyFeature>(responseBodyFeature);

var asyncEnumerable = AsyncEnumerable();
var outputFormatterContext = new OutputFormatterWriteContext(
actionContext.HttpContext,
new TestHttpResponseStreamWriterFactory().CreateWriter,
asyncEnumerable.GetType(),
asyncEnumerable)
{
ContentType = new StringSegment(mediaType.ToString()),
};

// Act
await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8"));

// Assert
Assert.Equal(expected.ToArray(), body.ToArray());

async IAsyncEnumerable<int> AsyncEnumerable()
{
// StartAsync shouldn't be called by SystemTestJsonOutputFormatter when using IAsyncEnumerable
// This allows Controller methods to set Headers, etc.
Assert.False(responseBodyFeature?.StartCalled ?? false);
await Task.Yield();
yield return 1;
}
}

[Fact]
public async Task WriteResponseBodyAsync_StartAsyncCalled()
{
// Arrange
TestHttpResponseBodyFeature responseBodyFeature = null;
var expected = new MemoryStream();
await JsonSerializer.SerializeAsync(expected, 1, new JsonSerializerOptions(JsonSerializerDefaults.Web));
var formatter = GetOutputFormatter();
var mediaType = MediaTypeHeaderValue.Parse("application/json; charset=utf-8");
var encoding = CreateOrGetSupportedEncoding(formatter, "utf-8", isDefaultEncoding: true);

var body = new MemoryStream();

var actionContext = GetActionContext(mediaType, body);
responseBodyFeature = new TestHttpResponseBodyFeature(actionContext.HttpContext.Features.Get<IHttpResponseBodyFeature>());
actionContext.HttpContext.Features.Set<IHttpResponseBodyFeature>(responseBodyFeature);

var outputFormatterContext = new OutputFormatterWriteContext(
actionContext.HttpContext,
new TestHttpResponseStreamWriterFactory().CreateWriter,
typeof(int),
1)
{
ContentType = new StringSegment(mediaType.ToString()),
};

// Act
await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8"));

// Assert
Assert.Equal(expected.ToArray(), body.ToArray());
Assert.True(responseBodyFeature.StartCalled);
}

public class TestHttpResponseBodyFeature : IHttpResponseBodyFeature
{
private readonly IHttpResponseBodyFeature _inner;

public bool StartCalled;

public TestHttpResponseBodyFeature(IHttpResponseBodyFeature inner)
{
_inner = inner;
}

public Stream Stream => _inner.Stream;

public PipeWriter Writer => _inner.Writer;

public Task CompleteAsync()
{
return _inner.CompleteAsync();
}

public void DisableBuffering()
{
_inner.DisableBuffering();
}

public Task SendFileAsync(string path, long offset, long? count, CancellationToken cancellationToken = default)
{
return _inner.SendFileAsync(path, offset, count, cancellationToken);
}

public Task StartAsync(CancellationToken cancellationToken = default)
{
StartCalled = true;
return _inner.StartAsync(cancellationToken);
}
}

[Fact]
public async Task WriteResponseBodyAsync_AsyncEnumerableConnectionCloses()
{
Expand Down
Loading
Loading