Skip to content

Commit

Permalink
cleanup of deprecated test methods
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 716500241
Change-Id: I75de675b87746e49d250dc312aacf718707245d7
  • Loading branch information
TF-Agents Team authored and copybara-github committed Jan 17, 2025
1 parent 9ecbb2f commit bd31aec
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion tf_agents/policies/categorical_q_policy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def testBuild(self):
self.assertLen(policy.variables(), 2)

def testMultipleActionsRaiseError(self):
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
TypeError, '.*action_spec must be a BoundedTensorSpec.*'
):
# Replace the action_spec for this test.
Expand Down
4 changes: 2 additions & 2 deletions tf_agents/policies/py_tf_policy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def testRaiseValueErrorWithoutSession(self):
if tf.executing_eagerly():
self.skipTest('b/123770140: Handling sessions with eager mode is buggy')
policy = py_tf_policy.PyTFPolicy(self._tf_policy)
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
AttributeError,
"No TensorFlow session-like object was set on this 'PyTFPolicy'.*",
):
Expand Down Expand Up @@ -376,7 +376,7 @@ def testSaveWrappedPolicyRestoreOuterCheckAssertConsumed(self, batch_size=5):
)
# 3). Restoring the actor policy while checking that all variables in
# the checkpoint were found in the graph should fail.
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
AssertionError, 'not bound to checkpointed values*'
):
policy_restored.restore(
Expand Down
4 changes: 2 additions & 2 deletions tf_agents/policies/q_policy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def testBuild(self):

def testMultipleActionsRaiseError(self):
action_spec = [tensor_spec.BoundedTensorSpec([], tf.int32, 0, 1)] * 2
with self.assertRaisesRegexp(ValueError, 'Only scalar actions'):
with self.assertRaisesRegex(ValueError, 'Only scalar actions'):
q_policy.QPolicy(self._time_step_spec, action_spec, q_network=DummyNet())

def testAction(self):
Expand Down Expand Up @@ -186,7 +186,7 @@ def testActionSpecsIncompatible(self):
network_action_spec = tensor_spec.BoundedTensorSpec([2], tf.int32, 0, 1)
q_net = DummyNetWithActionSpec(network_action_spec)

with self.assertRaisesRegexp(
with self.assertRaisesRegex(
ValueError, 'action_spec must be compatible with q_network.action_spec'
):
q_policy.QPolicy(self._time_step_spec, self._action_spec, q_net)
Expand Down
2 changes: 1 addition & 1 deletion tf_agents/policies/scripted_py_policy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def testEpisodeLength(self):
self.assertEqual(action_script[1][1], action_step.action)
action_step = policy.action(self._time_step, action_step.state)
self.assertEqual(action_script[1][1], action_step.action)
with self.assertRaisesRegexp(ValueError, '.*Episode is longer than.*'):
with self.assertRaisesRegex(ValueError, '.*Episode is longer than.*'):
policy.action(self._time_step, action_step.state)

def testPolicyStateSpecIsEmpty(self):
Expand Down
4 changes: 2 additions & 2 deletions tf_agents/policies/tf_policy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def testValidateArgsDisabled(self):
self.assertAllEqual([[1], [1], [1], [1]], action)

def testMismatchedDtypes(self):
with self.assertRaisesRegexp(TypeError, ".*dtype that doesn't match.*"):
with self.assertRaisesRegex(TypeError, ".*dtype that doesn't match.*"):
policy = TFPolicyMismatchedDtypes()
observation = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
time_step = ts.restart(observation)
Expand All @@ -271,7 +271,7 @@ def testMatchedDtypes(self):
policy.action(time_step)

def testMismatchedDtypesListAction(self):
with self.assertRaisesRegexp(TypeError, ".*dtype that doesn't match.*"):
with self.assertRaisesRegex(TypeError, ".*dtype that doesn't match.*"):
policy = TFPolicyMismatchedDtypesListAction()
observation = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
time_step = ts.restart(observation)
Expand Down

0 comments on commit bd31aec

Please sign in to comment.