Skip to content

Commit

Permalink
feat(NODE-6258): add signal support to cursor APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
nbbeeken committed Jan 15, 2025
1 parent e2aa15c commit ab02d53
Show file tree
Hide file tree
Showing 24 changed files with 1,237 additions and 123 deletions.
20 changes: 12 additions & 8 deletions src/client-side-encryption/auto_encrypter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { kDecorateResult } from '../constants';
import { getMongoDBClientEncryption } from '../deps';
import { MongoRuntimeError } from '../error';
import { MongoClient, type MongoClientOptions } from '../mongo_client';
import { type Abortable } from '../mongo_types';
import { MongoDBCollectionNamespace } from '../utils';
import { autoSelectSocketOptions } from './client_encryption';
import * as cryptoCallbacks from './crypto_callbacks';
Expand Down Expand Up @@ -372,8 +373,10 @@ export class AutoEncrypter {
async encrypt(
ns: string,
cmd: Document,
options: CommandOptions = {}
options: CommandOptions & Abortable = {}
): Promise<Document | Uint8Array> {
options.signal?.throwIfAborted();

if (this._bypassEncryption) {
// If `bypassAutoEncryption` has been specified, don't encrypt
return cmd;
Expand All @@ -398,7 +401,7 @@ export class AutoEncrypter {
socketOptions: autoSelectSocketOptions(this._client.s.options)
});

return deserialize(await stateMachine.execute(this, context, options.timeoutContext), {
return deserialize(await stateMachine.execute(this, context, options), {
promoteValues: false,
promoteLongs: false
});
Expand All @@ -407,7 +410,12 @@ export class AutoEncrypter {
/**
* Decrypt a command response
*/
async decrypt(response: Uint8Array, options: CommandOptions = {}): Promise<Uint8Array> {
async decrypt(
response: Uint8Array,
options: CommandOptions & Abortable = {}
): Promise<Uint8Array> {
options.signal?.throwIfAborted();

const context = this._mongocrypt.makeDecryptionContext(response);

context.id = this._contextCounter++;
Expand All @@ -419,11 +427,7 @@ export class AutoEncrypter {
socketOptions: autoSelectSocketOptions(this._client.s.options)
});

return await stateMachine.execute(
this,
context,
options.timeoutContext?.csotEnabled() ? options.timeoutContext : undefined
);
return await stateMachine.execute(this, context, options);
}

/**
Expand Down
10 changes: 6 additions & 4 deletions src/client-side-encryption/client_encryption.ts
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ export class ClientEncryption {
TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS }));

const dataKey = deserialize(
await stateMachine.execute(this, context, timeoutContext)
await stateMachine.execute(this, context, { timeoutContext })
) as DataKey;

const { db: dbName, collection: collectionName } = MongoDBCollectionNamespace.fromString(
Expand Down Expand Up @@ -293,7 +293,9 @@ export class ClientEncryption {
resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS })
);

const { v: dataKeys } = deserialize(await stateMachine.execute(this, context, timeoutContext));
const { v: dataKeys } = deserialize(
await stateMachine.execute(this, context, { timeoutContext })
);
if (dataKeys.length === 0) {
return {};
}
Expand Down Expand Up @@ -696,7 +698,7 @@ export class ClientEncryption {
? TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS }))
: undefined;

const { v } = deserialize(await stateMachine.execute(this, context, timeoutContext));
const { v } = deserialize(await stateMachine.execute(this, context, { timeoutContext }));

return v;
}
Expand Down Expand Up @@ -780,7 +782,7 @@ export class ClientEncryption {
this._timeoutMS != null
? TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS }))
: undefined;
const { v } = deserialize(await stateMachine.execute(this, context, timeoutContext));
const { v } = deserialize(await stateMachine.execute(this, context, { timeoutContext }));
return v;
}
}
Expand Down
116 changes: 80 additions & 36 deletions src/client-side-encryption/state_machine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@ import { CursorTimeoutContext } from '../cursor/abstract_cursor';
import { getSocks, type SocksLib } from '../deps';
import { MongoOperationTimeoutError } from '../error';
import { type MongoClient, type MongoClientOptions } from '../mongo_client';
import { type Abortable } from '../mongo_types';
import { Timeout, type TimeoutContext, TimeoutError } from '../timeout';
import { BufferPool, MongoDBCollectionNamespace, promiseWithResolvers } from '../utils';
import {
addAbortListener,
BufferPool,
kDispose,
MongoDBCollectionNamespace,
promiseWithResolvers
} from '../utils';
import { autoSelectSocketOptions, type DataKey } from './client_encryption';
import { MongoCryptError } from './errors';
import { type MongocryptdManager } from './mongocryptd_manager';
Expand Down Expand Up @@ -189,7 +196,7 @@ export class StateMachine {
async execute(
executor: StateMachineExecutable,
context: MongoCryptContext,
timeoutContext?: TimeoutContext
options: { timeoutContext?: TimeoutContext } & Abortable
): Promise<Uint8Array> {
const keyVaultNamespace = executor._keyVaultNamespace;
const keyVaultClient = executor._keyVaultClient;
Expand All @@ -199,6 +206,7 @@ export class StateMachine {
let result: Uint8Array | null = null;

while (context.state !== MONGOCRYPT_CTX_DONE && context.state !== MONGOCRYPT_CTX_ERROR) {
options.signal?.throwIfAborted();
debug(`[context#${context.id}] ${stateToString.get(context.state) || context.state}`);

switch (context.state) {
Expand All @@ -214,7 +222,7 @@ export class StateMachine {
metaDataClient,
context.ns,
filter,
timeoutContext
options
);
if (collInfo) {
context.addMongoOperationResponse(collInfo);
Expand All @@ -235,9 +243,9 @@ export class StateMachine {
// When we are using the shared library, we don't have a mongocryptd manager.
const markedCommand: Uint8Array = mongocryptdManager
? await mongocryptdManager.withRespawn(
this.markCommand.bind(this, mongocryptdClient, context.ns, command, timeoutContext)
this.markCommand.bind(this, mongocryptdClient, context.ns, command, options)
)
: await this.markCommand(mongocryptdClient, context.ns, command, timeoutContext);
: await this.markCommand(mongocryptdClient, context.ns, command, options);

context.addMongoOperationResponse(markedCommand);
context.finishMongoOperation();
Expand All @@ -246,12 +254,7 @@ export class StateMachine {

case MONGOCRYPT_CTX_NEED_MONGO_KEYS: {
const filter = context.nextMongoOperation();
const keys = await this.fetchKeys(
keyVaultClient,
keyVaultNamespace,
filter,
timeoutContext
);
const keys = await this.fetchKeys(keyVaultClient, keyVaultNamespace, filter, options);

if (keys.length === 0) {
// See docs on EMPTY_V
Expand All @@ -273,7 +276,7 @@ export class StateMachine {
}

case MONGOCRYPT_CTX_NEED_KMS: {
await Promise.all(this.requests(context, timeoutContext));
await Promise.all(this.requests(context, options));
context.finishKMSRequests();
break;
}
Expand Down Expand Up @@ -315,11 +318,13 @@ export class StateMachine {
* @param kmsContext - A C++ KMS context returned from the bindings
* @returns A promise that resolves when the KMS reply has be fully parsed
*/
async kmsRequest(request: MongoCryptKMSRequest, timeoutContext?: TimeoutContext): Promise<void> {
async kmsRequest(
request: MongoCryptKMSRequest,
options?: { timeoutContext?: TimeoutContext } & Abortable
): Promise<void> {
const parsedUrl = request.endpoint.split(':');
const port = parsedUrl[1] != null ? Number.parseInt(parsedUrl[1], 10) : HTTPS_PORT;
const socketOptions = autoSelectSocketOptions(this.options.socketOptions || {});
const options: tls.ConnectionOptions & {
const socketOptions: tls.ConnectionOptions & {
host: string;
port: number;
autoSelectFamily?: boolean;
Expand All @@ -328,7 +333,7 @@ export class StateMachine {
host: parsedUrl[0],
servername: parsedUrl[0],
port,
...socketOptions
...autoSelectSocketOptions(this.options.socketOptions || {})
};
const message = request.message;
const buffer = new BufferPool();
Expand Down Expand Up @@ -363,7 +368,7 @@ export class StateMachine {
throw error;
}
try {
await this.setTlsOptions(providerTlsOptions, options);
await this.setTlsOptions(providerTlsOptions, socketOptions);
} catch (err) {
throw onerror(err);
}
Expand All @@ -380,23 +385,25 @@ export class StateMachine {
.once('close', () => rejectOnNetSocketError(onclose()))
.once('connect', () => resolveOnNetSocketConnect());

let abortListener;

try {
if (this.options.proxyOptions && this.options.proxyOptions.proxyHost) {
const netSocketOptions = {
...socketOptions,
host: this.options.proxyOptions.proxyHost,
port: this.options.proxyOptions.proxyPort || 1080,
...socketOptions
port: this.options.proxyOptions.proxyPort || 1080
};
netSocket.connect(netSocketOptions);
await willConnect;

try {
socks ??= loadSocks();
options.socket = (
socketOptions.socket = (
await socks.SocksClient.createConnection({
existing_socket: netSocket,
command: 'connect',
destination: { host: options.host, port: options.port },
destination: { host: socketOptions.host, port: socketOptions.port },
proxy: {
// host and port are ignored because we pass existing_socket
host: 'iLoveJavaScript',
Expand All @@ -412,7 +419,7 @@ export class StateMachine {
}
}

socket = tls.connect(options, () => {
socket = tls.connect(socketOptions, () => {
socket.write(message);
});

Expand All @@ -422,6 +429,11 @@ export class StateMachine {
resolve
} = promiseWithResolvers<void>();

abortListener = addAbortListener(options?.signal, function () {
destroySockets();
rejectOnTlsSocketError(this.reason);
});

socket
.once('error', err => rejectOnTlsSocketError(onerror(err)))
.once('close', () => rejectOnTlsSocketError(onclose()))
Expand All @@ -436,8 +448,11 @@ export class StateMachine {
resolve();
}
});
await (timeoutContext?.csotEnabled()
? Promise.all([willResolveKmsRequest, Timeout.expires(timeoutContext?.remainingTimeMS)])
await (options?.timeoutContext?.csotEnabled()
? Promise.all([
willResolveKmsRequest,
Timeout.expires(options.timeoutContext?.remainingTimeMS)
])
: willResolveKmsRequest);
} catch (error) {
if (error instanceof TimeoutError)
Expand All @@ -446,16 +461,17 @@ export class StateMachine {
} finally {
// There's no need for any more activity on this socket at this point.
destroySockets();
abortListener?.[kDispose]();
}
}

*requests(context: MongoCryptContext, timeoutContext?: TimeoutContext) {
*requests(context: MongoCryptContext, options?: { timeoutContext?: TimeoutContext } & Abortable) {
for (
let request = context.nextKMSRequest();
request != null;
request = context.nextKMSRequest()
) {
yield this.kmsRequest(request, timeoutContext);
yield this.kmsRequest(request, options);
}
}

Expand Down Expand Up @@ -516,14 +532,16 @@ export class StateMachine {
client: MongoClient,
ns: string,
filter: Document,
timeoutContext?: TimeoutContext
options?: { timeoutContext?: TimeoutContext } & Abortable
): Promise<Uint8Array | null> {
const { db } = MongoDBCollectionNamespace.fromString(ns);

const cursor = client.db(db).listCollections(filter, {
promoteLongs: false,
promoteValues: false,
timeoutContext: timeoutContext && new CursorTimeoutContext(timeoutContext, Symbol())
timeoutContext:
options?.timeoutContext && new CursorTimeoutContext(options?.timeoutContext, Symbol()),
signal: options?.signal
});

// There is always exactly zero or one matching documents, so this should always exhaust the cursor
Expand All @@ -547,17 +565,30 @@ export class StateMachine {
client: MongoClient,
ns: string,
command: Uint8Array,
timeoutContext?: TimeoutContext
options?: { timeoutContext?: TimeoutContext } & Abortable
): Promise<Uint8Array> {
const { db } = MongoDBCollectionNamespace.fromString(ns);
const bsonOptions = { promoteLongs: false, promoteValues: false };
const rawCommand = deserialize(command, bsonOptions);

const commandOptions: {
timeoutMS?: number;
signal?: AbortSignal;
} = {
timeoutMS: undefined,
signal: undefined
};

if (options?.timeoutContext?.csotEnabled()) {
commandOptions.timeoutMS = options.timeoutContext.remainingTimeMS;
}
if (options?.signal) {
commandOptions.signal = options.signal;
}

const response = await client.db(db).command(rawCommand, {
...bsonOptions,
...(timeoutContext?.csotEnabled()
? { timeoutMS: timeoutContext?.remainingTimeMS }
: undefined)
...commandOptions
});

return serialize(response, this.bsonOptions);
Expand All @@ -575,17 +606,30 @@ export class StateMachine {
client: MongoClient,
keyVaultNamespace: string,
filter: Uint8Array,
timeoutContext?: TimeoutContext
options?: { timeoutContext?: TimeoutContext } & Abortable
): Promise<Array<DataKey>> {
const { db: dbName, collection: collectionName } =
MongoDBCollectionNamespace.fromString(keyVaultNamespace);

const commandOptions: {
timeoutContext?: CursorTimeoutContext;
signal?: AbortSignal;
} = {
timeoutContext: undefined,
signal: undefined
};

if (options?.timeoutContext != null) {
commandOptions.timeoutContext = new CursorTimeoutContext(options.timeoutContext, Symbol());
}
if (options?.signal != null) {
commandOptions.signal = options.signal;
}

return client
.db(dbName)
.collection<DataKey>(collectionName, { readConcern: { level: 'majority' } })
.find(deserialize(filter), {
timeoutContext: timeoutContext && new CursorTimeoutContext(timeoutContext, Symbol())
})
.find(deserialize(filter), commandOptions)
.toArray();
}
}
Loading

0 comments on commit ab02d53

Please sign in to comment.