Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TestMatrix functionality to qtest #2037

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions library/qtest/src/Functions.qs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the MIT License.

import Util.TestCaseResult, Util.OutputMessage;
import Std.Arrays.Mapped, Std.Arrays.All;
import Std.Arrays.Mapped, Std.Arrays.All, Std.Arrays.Enumerated;

/// # Summary
/// Runs a number of test cases and returns true if all tests passed, false otherwise.
Expand Down Expand Up @@ -47,6 +47,47 @@ function RunAllTestCases<'T : Eq + Show>(test_cases : (String, () -> 'T, 'T)[])
Mapped((name, case, result) -> TestCase(name, case, result), test_cases)
}

/// # Summary
/// Given a function to test and an array of test cases of the form (input, expected_output), and a test mode, runs the test cases and returns the result of the test mode.
///
/// # Inputs
/// - `test_suite_name` : A string representing the name of the test suite.
/// - `func` : The function to test.
/// - `test_cases` : An array of tuples of the form (input, expected_output).
/// - `mode` : A function that takes an array of tuples of the form (test_name, test_case, expected_output) and returns a value of type 'U.
/// Intended to be either `Qtest.Functions.CheckAllTestCases` or `Qtest.Functions.RunAllTestCases`.
///
/// # Example
/// ```qsharp
/// TestMatrix("Add One", x -> x + 1, [(2, 3), (3, 4)], CheckAllTestCases);
/// ```

function TestMatrix<'T, 'O : Show + Eq, 'U>(
test_suite_name : String,
func : 'T -> 'O,
test_cases : ('T, 'O)[],
mode : ((String, () -> 'O, 'O)[]) -> 'U
) : 'U {
let test_cases_qs = Mapped((ix, (input, expected)) -> (test_suite_name + $" {ix + 1}", () -> func(input), expected), Enumerated(test_cases));
mode(test_cases_qs)
}

function RunTestMatrix<'T : Show, 'O : Show + Eq>(
test_suite_name : String,
func : 'T -> 'O,
test_cases : ('T, 'O)[]
) : TestCaseResult[] {
TestMatrix(test_suite_name, func, test_cases, RunAllTestCases)
}

function CheckTestMatrix<'T : Show, 'O : Show + Eq>(
test_suite_name : String,
func : 'T -> 'O,
test_cases : ('T, 'O)[]
) : Bool {
TestMatrix(test_suite_name, func, test_cases, CheckAllTestCases)
}

/// Internal (non-exported) helper function. Runs a test case and produces a `TestCaseResult`
function TestCase<'T : Eq + Show>(name : String, test_case : () -> 'T, expected : 'T) : TestCaseResult {
let result = test_case();
Expand All @@ -57,4 +98,4 @@ function TestCase<'T : Eq + Show>(name : String, test_case : () -> 'T, expected
}
}

export CheckAllTestCases, RunAllTestCases;
export CheckAllTestCases, RunAllTestCases, TestMatrix, RunTestMatrix, CheckTestMatrix;
76 changes: 68 additions & 8 deletions library/qtest/src/Operations.qs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the MIT License.

import Util.TestCaseResult, Util.OutputMessage;
import Std.Arrays.Mapped, Std.Arrays.All;
import Std.Arrays.Mapped, Std.Arrays.All, Std.Arrays.Enumerated;

/// # Summary
/// Runs a number of test cases and returns true if all tests passed, false otherwise.
Expand All @@ -26,7 +26,6 @@ operation CheckAllTestCases<'T : Eq + Show>(test_cases : (String, Int, (Qubit[])
OutputMessage(test_results);

All(test_case -> test_case.did_pass, test_results)

}

/// # Summary
Expand All @@ -43,11 +42,7 @@ operation CheckAllTestCases<'T : Eq + Show>(test_cases : (String, Int, (Qubit[])
/// ```qsharp
/// RunAllTestCases([("Should return 42", () -> 42, 42)]);
/// ```
operation RunAllTestCases<'T : Eq + Show>(test_cases : (String, Int, (Qubit[]) => (), (Qubit[]) => 'T, 'T)[]) : TestCaseResult[] {
let num_tests = Length(test_cases);

let num_tests = Length(test_cases);

operation RunAllTestCases<'T : Eq + Show>(test_cases : (String, Int, Qubit[] => Unit, Qubit[] => 'T, 'T)[]) : TestCaseResult[] {
MappedOperation((name, num_qubits, prepare_state, case, result) => {
use qubits = Qubit[num_qubits];
prepare_state(qubits);
Expand All @@ -67,6 +62,71 @@ operation MappedOperation<'T, 'U>(mapper : ('T => 'U), array : 'T[]) : 'U[] {
mapped
}


/// # Summary
/// Given an operation on some qubits `func` which returns some value to test and a number of qubits to use `num_qubits`,
/// runs a number of test cases of the form `(Qubit[] => Unit, 'O)` where the first element is a qubit
/// state preparation operation and the second element is the expected output of the operation.
/// Returns the result of the `mode` function which takes a list of test cases and returns a value of type `'U`.
///
/// # Input
/// - `test_suite_name` : A string representing the name of the test suite.
/// - `func` : An operation which takes an array of qubits and returns a value of type `'O`.
/// - `num_qubits` : The number of qubits to use in the test. These are allocated before the test and reset before each test case.
/// - `test_cases` : A list of test cases, each of the form `(Qubit[] => Unit, 'O)`. The lambda operation should set up the qubits
/// in a specific state for `func` to operate on.
/// - `mode` : A function which takes a list of test cases and returns a value of type `'U`. Intended to be either `Qtest.Operations.CheckAllTestCases` or `Qtest.Operations.RunAllTestCases`.
///
/// # Example
/// ```qsharp
/// let test_cases: (Qubit[] => Unit, Int)[] = [
/// (qs => { X(qs[0]); X(qs[3]); }, 0b1001),
/// (qs => { X(qs[0]); X(qs[1]); }, 0b0011)
/// ];
///
/// let res : Util.TestCaseResult[] = Operations.TestMatrix(
/// // test name
/// "QubitTestMatrix",
/// // operation to test
/// qs => MeasureInteger(qs),
/// // number of qubits
/// 4,
/// // test cases
/// test_cases,
/// // test mode
/// Operations.RunAllTestCases
/// );
/// ```

operation TestMatrix<'O : Show + Eq, 'U>(
test_suite_name : String,
func : Qubit[] => 'O,
num_qubits : Int,
test_cases : (Qubit[] => Unit, 'O)[],
mode : ((String, Int, Qubit[] => Unit, Qubit[] => 'O, 'O)[]) => 'U
) : 'U {
let test_cases_qs = Mapped((ix, (qubit_prep_function, expected)) -> (test_suite_name + $" {ix + 1}", num_qubits, qubit_prep_function, func, expected), Enumerated(test_cases));
mode(test_cases_qs)
}

operation CheckTestMatrix<'O : Show + Eq>(
test_suite_name : String,
func : Qubit[] => 'O,
num_qubits : Int,
test_cases : (Qubit[] => Unit, 'O)[]
) : Bool {
TestMatrix(test_suite_name, func, num_qubits, test_cases, CheckAllTestCases)
}

operation RunTestMatrix<'O : Show + Eq>(
test_suite_name : String,
func : Qubit[] => 'O,
num_qubits : Int,
test_cases : (Qubit[] => Unit, 'O)[]
) : TestCaseResult[] {
TestMatrix(test_suite_name, func, num_qubits, test_cases, RunAllTestCases)
}

/// Internal (non-exported) helper function. Runs a test case and produces a `TestCaseResult`
operation TestCase<'T : Eq + Show>(name : String, qubits : Qubit[], test_case : (Qubit[]) => 'T, expected : 'T) : TestCaseResult {
let result = test_case(qubits);
Expand All @@ -77,4 +137,4 @@ operation TestCase<'T : Eq + Show>(name : String, qubits : Qubit[], test_case :
}
}

export CheckAllTestCases, RunAllTestCases;
export CheckAllTestCases, RunAllTestCases, TestMatrix, CheckTestMatrix, RunTestMatrix;
54 changes: 49 additions & 5 deletions library/qtest/src/Tests.qs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,51 @@
// Licensed under the MIT License.

import Std.Diagnostics.Fact;
import Std.Arrays.All;

function Main() : Unit {
operation Main() : Unit {
FunctionTestMatrixTests();
OperationTestMatrixTests();
BasicTests();
}

operation OperationTestMatrixTests() : Unit {
let test_cases : (Qubit[] => Unit, Int)[] = [
(qs => { X(qs[0]); X(qs[3]); }, 0b1001),
(qs => { X(qs[0]); X(qs[1]); }, 0b0011)
];

let res1 : Util.TestCaseResult[] = Operations.TestMatrix(
"QubitTestMatrix",
qs => MeasureInteger(qs),
4,
test_cases,
Operations.RunAllTestCases
);

let res2 : Util.TestCaseResult[] = Operations.RunTestMatrix(
"QubitTestMatrix",
qs => MeasureInteger(qs),
4,
test_cases,
);

Fact(All(x -> x.did_pass, res1) and All(x -> x.did_pass, res2), "RunTestMatrix and TestMatrix did not return the same results");
}

function FunctionTestMatrixTests() : Unit {
let all_passed = Functions.TestMatrix("Return 42", TestCaseOne, [((), 42), ((), 42)], Functions.CheckAllTestCases);
Fact(all_passed, "basic test matrix did not pass");

let at_least_one_failed = not Functions.TestMatrix("Return 42", TestCaseOne, [((), 42), ((), 43)], Functions.CheckAllTestCases);
Fact(at_least_one_failed, "basic test matrix did not report failure");

let results = Functions.TestMatrix("AddOne", AddOne, [(5, 6), (6, 7)], Functions.RunAllTestCases);
Fact(Length(results) == 2, "test matrix did not return results for all test cases");
Fact(All(result -> result.did_pass, results), "test matrix did not pass all test cases");
}

function BasicTests() : Unit {
let sample_tests = [
("Should return 42", TestCaseOne, 43),
("Should add one", () -> AddOne(5), 42),
Expand All @@ -27,9 +70,10 @@ function Main() : Unit {
"Test harness did not return results for all test cases."
);

Fact(run_all_result[0].did_pass, "test one passed when it should have failed");
Fact(run_all_result[1].did_pass, "test two failed when it should have passed");
Fact(run_all_result[2].did_pass, "test three passed when it should have failed");
Fact(not run_all_result[0].did_pass, "test one passed when it should have failed");
Fact(not run_all_result[1].did_pass, "test two passed when it should have failed");
Fact(run_all_result[2].did_pass, "test three failed when it should have passed");

}

function TestCaseOne() : Int {
Expand All @@ -38,4 +82,4 @@ function TestCaseOne() : Int {

function AddOne(x : Int) : Int {
x + 1
}
}
Loading