[Fix] Don't dispose streams in DefaultRestClient (#2652)

* Duplicate file streams before sending

* Other code needs to dispose their objects

* Another resource to dispose

* Stop disposing and copying streams in SendAsync

* Fix inverted boolean check

Co-authored-by: Dmitry <dimson-n@users.noreply.github.com>

* Await results for using statement to work

---------

Co-authored-by: Dmitry <dimson-n@users.noreply.github.com>
This commit is contained in:
Ben Reilly
2023-04-14 19:07:12 -04:00
committed by GitHub
parent 69cce5baf2
commit 84431decfd
4 changed files with 66 additions and 59 deletions

View File

@@ -1068,11 +1068,11 @@ namespace Discord.Rest
/// <returns>
/// A task that represents the asynchronous creation operation. The task result contains the created sticker.
/// </returns>
public Task<CustomSticker> CreateStickerAsync(string name, string path, IEnumerable<string> tags, string description = null,
public async Task<CustomSticker> CreateStickerAsync(string name, string path, IEnumerable<string> tags, string description = null,
RequestOptions options = null)
{
var fs = File.OpenRead(path);
return CreateStickerAsync(name, fs, Path.GetFileName(fs.Name), tags, description,options);
using var fs = File.OpenRead(path);
return await CreateStickerAsync(name, fs, Path.GetFileName(fs.Name), tags, description,options);
}
/// <summary>
/// Creates a new sticker in this guild

View File

@@ -228,6 +228,7 @@ namespace Discord.Rest
fileName ??= Path.GetFileName(filePath);
Preconditions.NotNullOrEmpty(fileName, nameof(fileName), "File Name must not be empty or null");
using var fileStream = !string.IsNullOrEmpty(filePath) ? new MemoryStream(File.ReadAllBytes(filePath), false) : null;
var args = new API.Rest.CreateWebhookMessageParams
{
Content = text,
@@ -235,7 +236,7 @@ namespace Discord.Rest
IsTTS = isTTS,
Embeds = embeds.Select(x => x.ToModel()).ToArray(),
Components = component?.Components.Select(x => new API.ActionRowComponent(x)).ToArray() ?? Optional<API.ActionRowComponent[]>.Unspecified,
File = !string.IsNullOrEmpty(filePath) ? new MultipartFile(new MemoryStream(File.ReadAllBytes(filePath), false), fileName) : Optional<MultipartFile>.Unspecified
File = fileStream != null ? new MultipartFile(fileStream, fileName) : Optional<MultipartFile>.Unspecified
};
if (ephemeral)

View File

@@ -1,7 +1,7 @@
using Discord.Net.Converters;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Globalization;
using System.IO;
using System.Linq;
@@ -101,18 +101,39 @@ namespace Discord.Net.Rest
IEnumerable<KeyValuePair<string, IEnumerable<string>>> requestHeaders = null)
{
string uri = Path.Combine(_baseUrl, endpoint);
using (var restRequest = new HttpRequestMessage(GetMethod(method), uri))
{
// HttpRequestMessage implements IDisposable but we do not need to dispose it as it merely disposes of its Content property,
// which we can do as needed. And regarding that, we do not want to take responsibility for disposing of content provided by
// the caller of this function, since it's possible that the caller wants to reuse it or is forced to reuse it because of a
// 429 response. Therefore, by convention, we only dispose the content objects created in this function (if any).
//
// See this comment explaining why this is safe: https://github.com/aspnet/Security/issues/886#issuecomment-229181249
// See also the source for HttpRequestMessage: https://github.com/microsoft/referencesource/blob/master/System/net/System/Net/Http/HttpRequestMessage.cs
#pragma warning disable IDISP004
var restRequest = new HttpRequestMessage(GetMethod(method), uri);
#pragma warning restore IDISP004
if (reason != null)
restRequest.Headers.Add("X-Audit-Log-Reason", Uri.EscapeDataString(reason));
if (requestHeaders != null)
foreach (var header in requestHeaders)
restRequest.Headers.Add(header.Key, header.Value);
var content = new MultipartFormDataContent("Upload----" + DateTime.Now.ToString(CultureInfo.InvariantCulture));
MemoryStream memoryStream = null;
if (multipartParams != null)
static StreamContent GetStreamContent(Stream stream)
{
foreach (var p in multipartParams)
if (stream.CanSeek)
{
// Reset back to the beginning; it may have been used elsewhere or in a previous request.
stream.Position = 0;
}
#pragma warning disable IDISP004
return new StreamContent(stream);
#pragma warning restore IDISP004
}
foreach (var p in multipartParams ?? ImmutableDictionary<string, object>.Empty)
{
switch (p.Value)
{
@@ -122,22 +143,10 @@ namespace Discord.Net.Rest
case byte[] byteArrayValue:
{ content.Add(new ByteArrayContent(byteArrayValue), p.Key); continue; }
case Stream streamValue:
{ content.Add(new StreamContent(streamValue), p.Key); continue; }
{ content.Add(GetStreamContent(streamValue), p.Key); continue; }
case MultipartFile fileValue:
{
var stream = fileValue.Stream;
if (!stream.CanSeek)
{
memoryStream = new MemoryStream();
await stream.CopyToAsync(memoryStream).ConfigureAwait(false);
memoryStream.Position = 0;
#pragma warning disable IDISP001
stream = memoryStream;
#pragma warning restore IDISP001
}
var streamContent = new StreamContent(stream);
var extension = fileValue.Filename.Split('.').Last();
var streamContent = GetStreamContent(fileValue.Stream);
if (fileValue.ContentType != null)
streamContent.Headers.ContentType = new MediaTypeHeaderValue(fileValue.ContentType);
@@ -151,12 +160,9 @@ namespace Discord.Net.Rest
throw new InvalidOperationException($"Unsupported param type \"{p.Value.GetType().Name}\".");
}
}
}
restRequest.Content = content;
var result = await SendInternalAsync(restRequest, cancelToken, headerOnly).ConfigureAwait(false);
memoryStream?.Dispose();
return result;
}
return await SendInternalAsync(restRequest, cancelToken, headerOnly).ConfigureAwait(false);
}
private async Task<RestResponse> SendInternalAsync(HttpRequestMessage request, CancellationToken cancelToken, bool headerOnly)

View File

@@ -1558,11 +1558,11 @@ namespace Discord.WebSocket
/// <returns>
/// A task that represents the asynchronous creation operation. The task result contains the created sticker.
/// </returns>
public Task<SocketCustomSticker> CreateStickerAsync(string name, string path, IEnumerable<string> tags, string description = null,
public async Task<SocketCustomSticker> CreateStickerAsync(string name, string path, IEnumerable<string> tags, string description = null,
RequestOptions options = null)
{
var fs = File.OpenRead(path);
return CreateStickerAsync(name, fs, Path.GetFileName(fs.Name), tags, description, options);
using var fs = File.OpenRead(path);
return await CreateStickerAsync(name, fs, Path.GetFileName(fs.Name), tags, description, options);
}
/// <summary>
/// Creates a new sticker in this guild