Skip to content

Commit

Permalink
Add TestMatrix functionality to qtest (#2037)
Browse files Browse the repository at this point in the history
I believe this should address the ideas @Manvi-Agrawal proposed in
#2026. What do you think Manvi?

Closes #2026

---------

Co-authored-by: Manvi-Agrawal <[email protected]>
Co-authored-by: orpuente-MS <[email protected]>
  • Loading branch information
3 people authored Jan 10, 2025
1 parent 3e83ec3 commit d021574
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 15 deletions.
45 changes: 43 additions & 2 deletions library/qtest/src/Functions.qs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the MIT License.

import Util.TestCaseResult, Util.OutputMessage;
import Std.Arrays.Mapped, Std.Arrays.All;
import Std.Arrays.Mapped, Std.Arrays.All, Std.Arrays.Enumerated;

/// # Summary
/// Runs a number of test cases and returns true if all tests passed, false otherwise.
Expand Down Expand Up @@ -47,6 +47,47 @@ function RunAllTestCases<'T : Eq + Show>(test_cases : (String, () -> 'T, 'T)[])
Mapped((name, case, result) -> TestCase(name, case, result), test_cases)
}

/// # Summary
/// Given a function to test, an array of test cases of the form (input, expected_output), and a test mode, runs the test cases and returns the result of the test mode.
///
/// # Inputs
/// - `test_suite_name` : A string representing the name of the test suite.
/// - `func` : The function to test.
/// - `test_cases` : An array of tuples of the form (input, expected_output).
/// - `mode` : A function that takes an array of tuples of the form (test_name, test_case, expected_output) and returns a value of type 'U.
/// Intended to be either `Qtest.Functions.CheckAllTestCases` or `Qtest.Functions.RunAllTestCases`.
///
/// # Example
/// ```qsharp
/// TestMatrix("Add One", x -> x + 1, [(2, 3), (3, 4)], CheckAllTestCases);
/// ```

function TestMatrix<'T, 'O : Show + Eq, 'U>(
test_suite_name : String,
func : 'T -> 'O,
test_cases : ('T, 'O)[],
mode : ((String, () -> 'O, 'O)[]) -> 'U
) : 'U {
let test_cases_qs = Mapped((ix, (input, expected)) -> (test_suite_name + $" {ix + 1}", () -> func(input), expected), Enumerated(test_cases));
mode(test_cases_qs)
}

function RunTestMatrix<'T : Show, 'O : Show + Eq>(
test_suite_name : String,
func : 'T -> 'O,
test_cases : ('T, 'O)[]
) : TestCaseResult[] {
TestMatrix(test_suite_name, func, test_cases, RunAllTestCases)
}

function CheckTestMatrix<'T : Show, 'O : Show + Eq>(
test_suite_name : String,
func : 'T -> 'O,
test_cases : ('T, 'O)[]
) : Bool {
TestMatrix(test_suite_name, func, test_cases, CheckAllTestCases)
}

/// Internal (non-exported) helper function. Runs a test case and produces a `TestCaseResult`
function TestCase<'T : Eq + Show>(name : String, test_case : () -> 'T, expected : 'T) : TestCaseResult {
let result = test_case();
Expand All @@ -57,4 +98,4 @@ function TestCase<'T : Eq + Show>(name : String, test_case : () -> 'T, expected
}
}

export CheckAllTestCases, RunAllTestCases;
export CheckAllTestCases, RunAllTestCases, TestMatrix, RunTestMatrix, CheckTestMatrix;
76 changes: 68 additions & 8 deletions library/qtest/src/Operations.qs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the MIT License.

import Util.TestCaseResult, Util.OutputMessage;
import Std.Arrays.Mapped, Std.Arrays.All;
import Std.Arrays.Mapped, Std.Arrays.All, Std.Arrays.Enumerated;

/// # Summary
/// Runs a number of test cases and returns true if all tests passed, false otherwise.
Expand All @@ -27,7 +27,6 @@ operation CheckAllTestCases<'T : Eq + Show>(test_cases : (String, Int, Qubit[] =
OutputMessage(test_results);

All(test_case -> test_case.did_pass, test_results)

}

/// # Summary
Expand All @@ -44,11 +43,7 @@ operation CheckAllTestCases<'T : Eq + Show>(test_cases : (String, Int, Qubit[] =
/// ```qsharp
/// RunAllTestCases([("0b0001 == 1", 4, (qs) => X(qs[0]), (qs) => MeasureSignedInteger(qs, 4), 1)]);
/// ```
operation RunAllTestCases<'T : Eq + Show>(test_cases : (String, Int, (Qubit[]) => (), (Qubit[]) => 'T, 'T)[]) : TestCaseResult[] {
let num_tests = Length(test_cases);

let num_tests = Length(test_cases);

operation RunAllTestCases<'T : Eq + Show>(test_cases : (String, Int, Qubit[] => Unit, Qubit[] => 'T, 'T)[]) : TestCaseResult[] {
MappedOperation((name, num_qubits, prepare_state, case, result) => {
use qubits = Qubit[num_qubits];
prepare_state(qubits);
Expand All @@ -68,6 +63,71 @@ operation MappedOperation<'T, 'U>(mapper : ('T => 'U), array : 'T[]) : 'U[] {
mapped
}


/// # Summary
/// Given an operation on some qubits `func` which returns some value to test and a number of qubits to use `num_qubits`,
/// runs a number of test cases of the form `(Qubit[] => Unit, 'O)` where the first element is a qubit
/// state preparation operation and the second element is the expected output of the operation.
/// Returns the result of the `mode` function which takes a list of test cases and returns a value of type `'U`.
///
/// # Input
/// - `test_suite_name` : A string representing the name of the test suite.
/// - `func` : An operation which takes an array of qubits and returns a value of type `'O`.
/// - `num_qubits` : The number of qubits to use in the test. These are allocated before the test and reset before each test case.
/// - `test_cases` : A list of test cases, each of the form `(Qubit[] => Unit, 'O)`. The lambda operation should set up the qubits
/// in a specific state for `func` to operate on.
/// - `mode` : A function which takes a list of test cases and returns a value of type `'U`. Intended to be either `Qtest.Operations.CheckAllTestCases` or `Qtest.Operations.RunAllTestCases`.
///
/// # Example
/// ```qsharp
/// let test_cases: (Qubit[] => Unit, Int)[] = [
/// (qs => { X(qs[0]); X(qs[3]); }, 0b1001),
/// (qs => { X(qs[0]); X(qs[1]); }, 0b0011)
/// ];
///
/// let res : Util.TestCaseResult[] = Operations.TestMatrix(
/// // test name
/// "QubitTestMatrix",
/// // operation to test
/// qs => MeasureInteger(qs),
/// // number of qubits
/// 4,
/// // test cases
/// test_cases,
/// // test mode
/// Operations.RunAllTestCases
/// );
/// ```

operation TestMatrix<'O : Show + Eq, 'U>(
test_suite_name : String,
func : Qubit[] => 'O,
num_qubits : Int,
test_cases : (Qubit[] => Unit, 'O)[],
mode : ((String, Int, Qubit[] => Unit, Qubit[] => 'O, 'O)[]) => 'U
) : 'U {
let test_cases_qs = Mapped((ix, (qubit_prep_function, expected)) -> (test_suite_name + $" {ix + 1}", num_qubits, qubit_prep_function, func, expected), Enumerated(test_cases));
mode(test_cases_qs)
}

operation CheckTestMatrix<'O : Show + Eq>(
test_suite_name : String,
func : Qubit[] => 'O,
num_qubits : Int,
test_cases : (Qubit[] => Unit, 'O)[]
) : Bool {
TestMatrix(test_suite_name, func, num_qubits, test_cases, CheckAllTestCases)
}

operation RunTestMatrix<'O : Show + Eq>(
test_suite_name : String,
func : Qubit[] => 'O,
num_qubits : Int,
test_cases : (Qubit[] => Unit, 'O)[]
) : TestCaseResult[] {
TestMatrix(test_suite_name, func, num_qubits, test_cases, RunAllTestCases)
}

/// Internal (non-exported) helper function. Runs a test case and produces a `TestCaseResult`
operation TestCase<'T : Eq + Show>(name : String, qubits : Qubit[], test_case : (Qubit[]) => 'T, expected : 'T) : TestCaseResult {
let result = test_case(qubits);
Expand All @@ -78,4 +138,4 @@ operation TestCase<'T : Eq + Show>(name : String, qubits : Qubit[], test_case :
}
}

export CheckAllTestCases, RunAllTestCases;
export CheckAllTestCases, RunAllTestCases, TestMatrix, CheckTestMatrix, RunTestMatrix;
54 changes: 49 additions & 5 deletions library/qtest/src/Tests.qs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,51 @@
// Licensed under the MIT License.

import Std.Diagnostics.Fact;
import Std.Arrays.All;

function Main() : Unit {
operation Main() : Unit {
FunctionTestMatrixTests();
OperationTestMatrixTests();
BasicTests();
}

operation OperationTestMatrixTests() : Unit {
let test_cases : (Qubit[] => Unit, Int)[] = [
(qs => { X(qs[0]); X(qs[3]); }, 0b1001),
(qs => { X(qs[0]); X(qs[1]); }, 0b0011)
];

let res1 : Util.TestCaseResult[] = Operations.TestMatrix(
"QubitTestMatrix",
qs => MeasureInteger(qs),
4,
test_cases,
Operations.RunAllTestCases
);

let res2 : Util.TestCaseResult[] = Operations.RunTestMatrix(
"QubitTestMatrix",
qs => MeasureInteger(qs),
4,
test_cases,
);

Fact(All(x -> x.did_pass, res1) and All(x -> x.did_pass, res2), "RunTestMatrix and TestMatrix did not return the same results");
}

function FunctionTestMatrixTests() : Unit {
let all_passed = Functions.TestMatrix("Return 42", TestCaseOne, [((), 42), ((), 42)], Functions.CheckAllTestCases);
Fact(all_passed, "basic test matrix did not pass");

let at_least_one_failed = not Functions.TestMatrix("Return 42", TestCaseOne, [((), 42), ((), 43)], Functions.CheckAllTestCases);
Fact(at_least_one_failed, "basic test matrix did not report failure");

let results = Functions.TestMatrix("AddOne", AddOne, [(5, 6), (6, 7)], Functions.RunAllTestCases);
Fact(Length(results) == 2, "test matrix did not return results for all test cases");
Fact(All(result -> result.did_pass, results), "test matrix did not pass all test cases");
}

function BasicTests() : Unit {
let sample_tests = [
("Should return 42", TestCaseOne, 43),
("Should add one", () -> AddOne(5), 42),
Expand All @@ -27,9 +70,10 @@ function Main() : Unit {
"Test harness did not return results for all test cases."
);

Fact(run_all_result[0].did_pass, "test one passed when it should have failed");
Fact(run_all_result[1].did_pass, "test two failed when it should have passed");
Fact(run_all_result[2].did_pass, "test three passed when it should have failed");
Fact(not run_all_result[0].did_pass, "test one passed when it should have failed");
Fact(not run_all_result[1].did_pass, "test two passed when it should have failed");
Fact(run_all_result[2].did_pass, "test three failed when it should have passed");

}

function TestCaseOne() : Int {
Expand All @@ -38,4 +82,4 @@ function TestCaseOne() : Int {

function AddOne(x : Int) : Int {
x + 1
}
}

0 comments on commit d021574

Please sign in to comment.