Skip to content

Commit

Permalink
Ensure all server variables specify the right character set.
Browse files Browse the repository at this point in the history
Signed-off-by: Bradley Grainger <[email protected]>
  • Loading branch information
bgrainger committed Jul 17, 2024
1 parent 3326a0e commit 3f50c14
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 11 deletions.
5 changes: 3 additions & 2 deletions src/MySqlConnector/Core/ServerSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,8 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella
if (m_useCompression)
m_payloadHandler = new CompressedPayloadHandler(m_payloadHandler.ByteHandler);

if (ok.ClientCharacterSet != (ServerVersion.Version >= ServerVersions.SupportsUtf8Mb4 ? "utf8mb4" : "utf8"))
// send 'SET NAMES' to set the character set and collation unless the server reports that it's already using the desired character set (e.g., MariaDB >= 11.5)
if (ok.NewCharacterSet != (ServerVersion.Version >= ServerVersions.SupportsUtf8Mb4 ? CharacterSet.Utf8Mb4Binary : CharacterSet.Utf8Mb3Binary))
{
// set 'collation_connection' to the server default
await SendAsync(m_setNamesPayload, ioBehavior, cancellationToken).ConfigureAwait(false);
Expand All @@ -550,7 +551,7 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella
{
await GetRealServerDetailsAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);
}
else if (ok.ConnectionId is int newConnectionId && newConnectionId != ConnectionId)
else if (ok.NewConnectionId is int newConnectionId && newConnectionId != ConnectionId)
{
Log.ChangingConnectionId(m_logger, Id, ConnectionId, newConnectionId, ServerVersion.OriginalString, ServerVersion.OriginalString);
ConnectionId = newConnectionId;
Expand Down
39 changes: 30 additions & 9 deletions src/MySqlConnector/Protocol/Payloads/OkPayload.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ internal sealed class OkPayload
public int WarningCount { get; }
public string? StatusInfo { get; }
public string? NewSchema { get; }
public string? ClientCharacterSet { get; }
public int? ConnectionId { get; }
public CharacterSet? NewCharacterSet { get; }
public int? NewConnectionId { get; }

public const byte Signature = 0x00;

Expand Down Expand Up @@ -61,7 +61,9 @@ public static void Verify(ReadOnlySpan<byte> span, IServerCapabilities serverCap
var serverStatus = (ServerStatus) reader.ReadUInt16();
var warningCount = (int) reader.ReadUInt16();
string? newSchema = null;
string? clientCharacterSet = null;
CharacterSet clientCharacterSet = default;
CharacterSet connectionCharacterSet = default;
CharacterSet resultsCharacterSet = default;
int? connectionId = null;
ReadOnlySpan<byte> statusBytes;

Expand Down Expand Up @@ -93,7 +95,21 @@ public static void Verify(ReadOnlySpan<byte> span, IServerCapabilities serverCap
var systemVariableValue = systemVariableValueLength == -1 ? default : reader.ReadByteString(systemVariableValueLength);
if (systemVariableName.SequenceEqual("character_set_client"u8) && systemVariableValueLength != 0)
{
clientCharacterSet = Encoding.ASCII.GetString(systemVariableValue);
clientCharacterSet = systemVariableValue.SequenceEqual("utf8mb4"u8) ? CharacterSet.Utf8Mb4Binary :
systemVariableValue.SequenceEqual("utf8"u8) ? CharacterSet.Utf8Mb3Binary :
CharacterSet.None;
}
else if (systemVariableName.SequenceEqual("character_set_connection"u8) && systemVariableValueLength != 0)
{
connectionCharacterSet = systemVariableValue.SequenceEqual("utf8mb4"u8) ? CharacterSet.Utf8Mb4Binary :
systemVariableValue.SequenceEqual("utf8"u8) ? CharacterSet.Utf8Mb3Binary :
CharacterSet.None;
}
else if (systemVariableName.SequenceEqual("character_set_results"u8) && systemVariableValueLength != 0)
{
resultsCharacterSet = systemVariableValue.SequenceEqual("utf8mb4"u8) ? CharacterSet.Utf8Mb4Binary :
systemVariableValue.SequenceEqual("utf8"u8) ? CharacterSet.Utf8Mb3Binary :
CharacterSet.None;
}
else if (systemVariableName.SequenceEqual("connection_id"u8))
{
Expand Down Expand Up @@ -129,32 +145,37 @@ public static void Verify(ReadOnlySpan<byte> span, IServerCapabilities serverCap
{
var statusInfo = statusBytes.Length == 0 ? null : Encoding.UTF8.GetString(statusBytes);

if (affectedRowCount == 0 && lastInsertId == 0 && warningCount == 0 && statusInfo is null && newSchema is null && clientCharacterSet is null && connectionId is null)
// detect the connection character set as utf8mb4 (or utf8) if all three system variables are set to the same value
var characterSet = clientCharacterSet == CharacterSet.Utf8Mb4Binary && connectionCharacterSet == CharacterSet.Utf8Mb4Binary && resultsCharacterSet == CharacterSet.Utf8Mb4Binary ? CharacterSet.Utf8Mb4Binary :
clientCharacterSet == CharacterSet.Utf8Mb3Binary && connectionCharacterSet == CharacterSet.Utf8Mb3Binary && resultsCharacterSet == CharacterSet.Utf8Mb3Binary ? CharacterSet.Utf8Mb3Binary :
CharacterSet.None;

if (affectedRowCount == 0 && lastInsertId == 0 && warningCount == 0 && statusInfo is null && newSchema is null && clientCharacterSet is CharacterSet.None && connectionId is null)
{
if (serverStatus == ServerStatus.AutoCommit)
return s_autoCommitOk;
if (serverStatus == (ServerStatus.AutoCommit | ServerStatus.SessionStateChanged))
return s_autoCommitSessionStateChangedOk;
}

return new OkPayload(affectedRowCount, lastInsertId, serverStatus, warningCount, statusInfo, newSchema, clientCharacterSet, connectionId);
return new OkPayload(affectedRowCount, lastInsertId, serverStatus, warningCount, statusInfo, newSchema, characterSet, connectionId);
}
else
{
return null;
}
}

private OkPayload(ulong affectedRowCount, ulong lastInsertId, ServerStatus serverStatus, int warningCount, string? statusInfo, string? newSchema, string? clientCharacterSet, int? connectionId)
private OkPayload(ulong affectedRowCount, ulong lastInsertId, ServerStatus serverStatus, int warningCount, string? statusInfo, string? newSchema, CharacterSet newCharacterSet, int? connectionId)
{
AffectedRowCount = affectedRowCount;
LastInsertId = lastInsertId;
ServerStatus = serverStatus;
WarningCount = warningCount;
StatusInfo = statusInfo;
NewSchema = newSchema;
ClientCharacterSet = clientCharacterSet;
ConnectionId = connectionId;
NewCharacterSet = newCharacterSet;
NewConnectionId = connectionId;
}

private static readonly OkPayload s_autoCommitOk = new(0, 0, ServerStatus.AutoCommit, 0, default, default, default, default);
Expand Down

0 comments on commit 3f50c14

Please sign in to comment.