From 75b839bce8f13349a1d28a94a68335f0fad3ed23 Mon Sep 17 00:00:00 2001 From: Pankaj Saini Date: Sun, 21 Apr 2024 17:57:18 -0700 Subject: [PATCH 1/2] Adding a type extension to check for if any type contains native object dangerous for deserializing. --- .../TypeExtensionsTests.cs | 94 +++++++++++++++++++ src/DurableTask.Core/Common/TypeExtension.cs | 85 +++++++++++++++++ 2 files changed, 179 insertions(+) create mode 100644 Test/DurableTask.Core.Tests/TypeExtensionsTests.cs create mode 100644 src/DurableTask.Core/Common/TypeExtension.cs diff --git a/Test/DurableTask.Core.Tests/TypeExtensionsTests.cs b/Test/DurableTask.Core.Tests/TypeExtensionsTests.cs new file mode 100644 index 000000000..578312ef7 --- /dev/null +++ b/Test/DurableTask.Core.Tests/TypeExtensionsTests.cs @@ -0,0 +1,94 @@ +// ---------------------------------------------------------------------------------- +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +namespace DurableTask.Core.Tests +{ + using DurableTask.Core.Common; + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System; + using System.Threading.Tasks; + using System.Threading; + + [TestClass] + public class TypeExtensionsTests + { + [TestMethod] + [DataRow(typeof(Task), true)] + [DataRow(typeof(CancellationToken), true)] + [DataRow(typeof(Semaphore), true)] + [DataRow(typeof(Task<(bool, string)>), true)] + [DataRow(typeof(TestClass1), false)] + [DataRow(typeof(TestClass2), false)] + [DataRow(typeof(string), false)] + public void IsEqualOrContainsCancellationTokenType(Type typeContaining, bool isTrue) + { + if (isTrue) + { + Assert.IsTrue(typeContaining.IsEqualOrContainsNativeType()); + } + else + { + Assert.IsFalse(typeContaining.IsEqualOrContainsNativeType()); + } + } + + [TestMethod] + [DataRow(typeof(Task), typeof(CancellationToken), true)] + [DataRow(typeof(CancellationToken), typeof(CancellationToken), true)] + [DataRow(typeof(Task), typeof(TestClass1), true)] + [DataRow(typeof(Task<(bool, string)>), typeof(string), true)] + [DataRow(typeof(Task<(bool, string)>), typeof(bool), true)] + [DataRow(typeof(TestClass1), typeof(TestClass1), true)] + [DataRow(typeof(TestClass1), typeof(string), true)] + [DataRow(typeof(TestClass1), typeof(double), true)] + [DataRow(typeof(TestClass1), typeof(TestClass2), true)] + [DataRow(typeof(TestClass2), typeof(double), true)] + [DataRow(typeof(TestClass2), typeof(string), false)] + public void Test_ContainsType(Type typeContaining, Type typeContained, bool isTrue) + { + if (isTrue) + { + Assert.IsTrue(typeContaining.IsEqualOrContainsType(typeContained)); + } + else + { + Assert.IsFalse(typeContaining.IsEqualOrContainsType(typeContained)); + } + } + + internal class TestClass1 + { + public string abc; + + public TestClass2 t2; + + public TestClass2 NewT2 { get; set; } + + public TestClass1() + { + abc = string.Empty; + t2 = new TestClass2(); + } + } + + internal class TestClass2 + { + public double def; + + public TestClass2() + { + def = 1.0; + } + } + } +} diff --git a/src/DurableTask.Core/Common/TypeExtension.cs b/src/DurableTask.Core/Common/TypeExtension.cs new file mode 100644 index 000000000..8f672bf76 --- /dev/null +++ b/src/DurableTask.Core/Common/TypeExtension.cs @@ -0,0 +1,85 @@ +// ---------------------------------------------------------------------------------- +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +namespace DurableTask.Core.Common +{ + using Microsoft.Win32.SafeHandles; + using System; + using System.Collections.Generic; + using System.Linq; + using System.Reflection; + using System.Threading; + + internal static class TypeExtension + { + /// + /// Checks if a type has any dangerous object which can conflict with native OS event handle. + /// E.g., CancellationToken, etc. + /// + /// + /// + public static bool IsEqualOrContainsNativeType(this Type typeContaining) + { + return typeContaining.IsEqualOrContainsType(typeof(WaitHandle)) + || typeContaining.IsEqualOrContainsType(typeof(SafeWaitHandle)); + } + + public static bool IsEqualOrContainsType(this Type typeContaining, Type typeContained) + { + if (typeContaining == typeContained) + { + return true; + } + + List processedTypes = new List(); + return typeContaining.ContainsType(typeContained, processedTypes); + } + + private static bool ContainsType(this Type typeContaining, Type typeContained, List processedTypes) + { + if (processedTypes.Any(x => x == typeContaining)) + { + // Self-reference, no point processing it again. + return false; + } + else + { + processedTypes.Add(typeContaining); + } + + // Get all properties and fields of typeT + PropertyInfo[] properties = typeContaining.GetProperties(); + FieldInfo[] fields = typeContaining.GetFields(); + + // Check properties + foreach (var property in properties) + { + if (property.PropertyType == typeContained) + return true; + else if (!property.PropertyType.IsPrimitive && property.PropertyType.ContainsType(typeContained, processedTypes)) + return true; + } + + // Check fields + foreach (var field in fields) + { + if (field.FieldType == typeContained) + return true; + else if (!field.FieldType.IsPrimitive && field.FieldType.ContainsType(typeContained, processedTypes)) + return true; + } + + return false; + } + } +} From 97b64dcb5fc9fdc7c6357a878a12496c40da8e22 Mon Sep 17 00:00:00 2001 From: Pankaj Saini Date: Thu, 25 Apr 2024 20:19:35 -0700 Subject: [PATCH 2/2] Adding validation for TaskActivity and TaskOrchestration registration. --- .../ReflectionBasedTaskActivityTests.cs | 146 ++++++++++++++++++ .../TaskRegistrationTests.cs | 144 +++++++++++++++++ src/DurableTask.Core/Common/TypeExtension.cs | 10 ++ .../ReflectionBasedTaskActivity.cs | 49 ++++++ src/DurableTask.Core/TaskHubWorker.cs | 62 +++++++- 5 files changed, 410 insertions(+), 1 deletion(-) create mode 100644 Test/DurableTask.Core.Tests/ReflectionBasedTaskActivityTests.cs create mode 100644 Test/DurableTask.Core.Tests/TaskRegistrationTests.cs diff --git a/Test/DurableTask.Core.Tests/ReflectionBasedTaskActivityTests.cs b/Test/DurableTask.Core.Tests/ReflectionBasedTaskActivityTests.cs new file mode 100644 index 000000000..aa5825ea3 --- /dev/null +++ b/Test/DurableTask.Core.Tests/ReflectionBasedTaskActivityTests.cs @@ -0,0 +1,146 @@ +// ---------------------------------------------------------------------------------- +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +namespace DurableTask.Core.Tests +{ + using System; + using System.Collections.Concurrent; + using System.Collections.Generic; + using System.Diagnostics; + using System.IO; + using System.Linq; + using System.Reflection; + using System.Runtime.Serialization; + using System.Text; + using System.Threading; + using System.Threading.Tasks; + using System.Xml; + using DurableTask.Core.Command; + using DurableTask.Core.History; + using DurableTask.Emulator; + using DurableTask.Test.Orchestrations; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class ReflectionBasedTaskActivityTests + { + public TaskHubWorker worker { get; private set; } + + [TestInitialize] + public void Initialize() + { + var service = new LocalOrchestrationService(); + this.worker = new TaskHubWorker(service); + } + + [TestMethod] + public void Test_AddTaskActivitiesFromInterface() + { + IReflectionBasedTaskActivityTest reflectionBasedTaskActivityTest = new ReflectionBasedTaskActivityTest(); + + worker.AddTaskActivitiesFromInterface(reflectionBasedTaskActivityTest); + worker.AddTaskActivitiesFromInterface(reflectionBasedTaskActivityTest, useFullyQualifiedMethodNames: true); + } + + [TestMethod] + public void Test_ReflectionBasedTaskActivity() + { + IReflectionBasedTaskActivityTest reflectionBasedTaskActivityTest = new ReflectionBasedTaskActivityTest(); + + worker.AddTaskActivitiesFromInterface(reflectionBasedTaskActivityTest); + worker.AddTaskActivitiesFromInterface(reflectionBasedTaskActivityTest, useFullyQualifiedMethodNames: true); + } + + [TestMethod] + public void Test_ReflectionBasedTaskActivityWithGeneric() + { + IReflectionBasedTaskActivityWithGenericTest reflectionBasedTaskActivityTest = new ReflectionBasedTaskActivityTest(); + + worker.AddTaskActivitiesFromInterface(reflectionBasedTaskActivityTest); + worker.AddTaskActivitiesFromInterface(reflectionBasedTaskActivityTest, useFullyQualifiedMethodNames: true); + } + + [TestMethod] + public void Test_AddTaskActivitiesFromInterfaceWithCancellationToken() + { + IReflectionBasedTaskActivityWithCancellationTokenTest reflectionBasedTaskActivityTest = new ReflectionBasedTaskActivityTest(); + Action action = () => + { + worker.AddTaskActivitiesFromInterface(reflectionBasedTaskActivityTest); + }; + + Assert.ThrowsException(action); + + action = () => + { + worker.AddTaskActivitiesFromInterface(reflectionBasedTaskActivityTest, useFullyQualifiedMethodNames: true); + }; + + Assert.ThrowsException(action); + } + + [TestMethod] + public void Test_ReflectionBasedTaskActivityWithCancellationToken() + { + IReflectionBasedTaskActivityWithCancellationTokenTest reflectionBasedTaskActivityTest = new ReflectionBasedTaskActivityTest(); + Action action = () => + { + worker.AddTaskActivitiesFromInterface(reflectionBasedTaskActivityTest); + }; + + Assert.ThrowsException(action); + + action = () => + { + worker.AddTaskActivitiesFromInterface(reflectionBasedTaskActivityTest, useFullyQualifiedMethodNames: true); + }; + + Assert.ThrowsException(action); + } + + public interface IReflectionBasedTaskActivityTest + { + void TestMethod1(); + void TestMethod2(int n1, int n2); + } + + public interface IReflectionBasedTaskActivityWithGenericTest + { + void TestMethodG3(T obj); + } + + public interface IReflectionBasedTaskActivityWithCancellationTokenTest : IReflectionBasedTaskActivityTest + { + void TestMethod3(int n, CancellationToken cancellationToken); + } + + public class ReflectionBasedTaskActivityTest : IReflectionBasedTaskActivityWithCancellationTokenTest, IReflectionBasedTaskActivityWithGenericTest + { + public void TestMethod1() + { + } + + public void TestMethod2(int n1, int n2) + { + } + + public void TestMethod3(int n, CancellationToken cancellationToken) + { + } + + public void TestMethodG3(T obj) + { + } + } + } +} diff --git a/Test/DurableTask.Core.Tests/TaskRegistrationTests.cs b/Test/DurableTask.Core.Tests/TaskRegistrationTests.cs new file mode 100644 index 000000000..8f41322f6 --- /dev/null +++ b/Test/DurableTask.Core.Tests/TaskRegistrationTests.cs @@ -0,0 +1,144 @@ +// ---------------------------------------------------------------------------------- +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +namespace DurableTask.Core.Tests +{ + using System; + using System.Collections.Concurrent; + using System.Collections.Generic; + using System.Diagnostics; + using System.IO; + using System.Linq; + using System.Runtime.Serialization; + using System.Text; + using System.Threading; + using System.Threading.Tasks; + using System.Xml; + using DurableTask.Core.Command; + using DurableTask.Core.History; + using DurableTask.Emulator; + using DurableTask.Test.Orchestrations; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TaskRegistrationTests + { + public TaskHubWorker worker { get; private set; } + + [TestInitialize] + public void Initialize() + { + var service = new LocalOrchestrationService(); + this.worker = new TaskHubWorker(service); + } + + [TestMethod] + [DataRow(typeof(TaskOrchInValid))] + [DataRow(typeof(TaskOrchInValidGeneric))] + [DataRow(typeof(TaskOrchInValidGeneric))] + [DataRow(typeof(TaskOrchGeneric))] + [DataRow(typeof(TaskOrchGeneric))] + [DataRow(typeof(TaskOrchGeneric))] + public void Test_RegistrationOfInValidTaskOrchestration(Type type) + { + Action action = () => + { + this.worker.AddTaskOrchestrations(new[] { type }); + }; + + Assert.ThrowsException(action); + } + + [TestMethod] + [DataRow(typeof(TaskOrchValid))] + [DataRow(typeof(TaskOrchValidGeneric))] + [DataRow(typeof(TaskOrchValidGeneric))] + [DataRow(typeof(TaskOrchGeneric))] + [DataRow(typeof(TaskOrchGeneric))] + public void Test_RegistrationOfValidTaskOrchestration(Type type) + { + this.worker.AddTaskOrchestrations(new[] { type }); + } + + [TestMethod] + [DataRow(typeof(SampleTaskActivity))] + [DataRow(typeof(SampleTaskActivity))] + [DataRow(typeof(SampleTaskActivity))] + public void Test_RegistrationOfInValidTaskActivities(Type type) + { + Action action = () => + { + this.worker.AddTaskActivities(new[] { type }); + }; + + Assert.ThrowsException(action); + } + + [TestMethod] + [DataRow(typeof(SampleTaskActivity))] + [DataRow(typeof(SampleTaskActivity))] + [DataRow(typeof(SampleTaskActivity))] + public void Test_RegistrationOfValidTaskActivities(Type type) + { + this.worker.AddTaskActivities(new[] { type }); + } + + public class TaskOrchValid : TaskOrchestration + { + public override Task RunTask(OrchestrationContext context, string input) + { + throw new NotImplementedException(); + } + } + + public class TaskOrchValidGeneric : TaskOrchestration + { + public override Task RunTask(OrchestrationContext context, string input) + { + throw new NotImplementedException(); + } + } + + public class TaskOrchInValid : TaskOrchestration + { + public override Task RunTask(OrchestrationContext context, CancellationToken input) + { + throw new NotImplementedException(); + } + } + + public class TaskOrchInValidGeneric : TaskOrchestration + { + public override Task RunTask(OrchestrationContext context, CancellationToken input) + { + throw new NotImplementedException(); + } + } + + public class TaskOrchGeneric : TaskOrchestration + { + public override Task RunTask(OrchestrationContext context, Tin input) + { + throw new NotImplementedException(); + } + } + + public class SampleTaskActivity : TaskActivity + { + protected override Tout Execute(TaskContext context, Tin input) + { + throw new NotImplementedException(); + } + } + } +} diff --git a/src/DurableTask.Core/Common/TypeExtension.cs b/src/DurableTask.Core/Common/TypeExtension.cs index 8f672bf76..76d19e0b2 100644 --- a/src/DurableTask.Core/Common/TypeExtension.cs +++ b/src/DurableTask.Core/Common/TypeExtension.cs @@ -22,6 +22,16 @@ namespace DurableTask.Core.Common internal static class TypeExtension { + internal static void HandleTypeValidationError(string metadataInfo) + { + var isCancellationTokenWarningOnly = true; + + if (isCancellationTokenWarningOnly) + { + throw new InvalidOperationException($"Dangerous type deserialization found while converting method to task-activity for {metadataInfo}."); + } + } + /// /// Checks if a type has any dangerous object which can conflict with native OS event handle. /// E.g., CancellationToken, etc. diff --git a/src/DurableTask.Core/ReflectionBasedTaskActivity.cs b/src/DurableTask.Core/ReflectionBasedTaskActivity.cs index b935f8c1c..1bf5eac1e 100644 --- a/src/DurableTask.Core/ReflectionBasedTaskActivity.cs +++ b/src/DurableTask.Core/ReflectionBasedTaskActivity.cs @@ -42,6 +42,12 @@ public ReflectionBasedTaskActivity(object activityObject, MethodInfo methodInfo) ActivityObject = activityObject; MethodInfo = methodInfo; genericArguments = methodInfo.GetGenericArguments(); + + // Add validations for parameters, generic types and return types. + // For parameters. + // For return types. + // For generic arguments. + ValidatePotentialSerializationOfNativeType(methodInfo); } /// @@ -93,6 +99,19 @@ public override async Task RunAsync(TaskContext context, string input) } Type[] genericTypeArguments = this.GetGenericTypeArguments(jArray); + + // Validate generic parameters before deserializing. + if (genericTypeArguments != null) + { + foreach (Type genericTypeArgument in genericTypeArguments) + { + if (genericTypeArgument.IsEqualOrContainsNativeType()) + { + TypeExtension.HandleTypeValidationError($"methodName: {MethodInfo.Name}, genericType: {genericTypeArgument.FullName}"); + } + } + } + object[] inputParameters = this.GetInputParameters(jArray, parameterCount, methodParameters, genericTypeArguments); string serializedReturn = string.Empty; @@ -232,5 +251,35 @@ private object[] GetInputParameters(JArray jArray, int parameterCount, Parameter return inputParameters; } + + private void ValidatePotentialSerializationOfNativeType(MethodInfo methodInfo) + { + foreach (var parameterInfo in methodInfo.GetParameters()) + { + if (parameterInfo.ParameterType.IsEqualOrContainsNativeType()) + { + TypeExtension.HandleTypeValidationError($"methodName: {methodInfo.Name}, parameterName: {parameterInfo.Name}, parameterType: {parameterInfo.ParameterType}"); + } + } + + // Only for return types, if it is of type Task, check for its generic type. + // Why? execution of task-activity doesn't serialize/deserialize Task await, but only the generic type. + if (methodInfo.ReturnType == typeof(Task) || methodInfo.ReturnType.BaseType == typeof(Task) || (methodInfo.ReturnType == typeof(Task<>))) + { + var genericTypes = methodInfo.ReturnType.GetGenericArguments(); + + foreach (var genericType in genericTypes) + { + if (genericType.IsEqualOrContainsNativeType()) + { + TypeExtension.HandleTypeValidationError($"methodName: {methodInfo.Name}, returnType: {genericType.FullName}"); + } + } + } + else if (methodInfo.ReturnType.IsEqualOrContainsNativeType()) + { + TypeExtension.HandleTypeValidationError($"methodName: {methodInfo.Name}, returnType: {methodInfo.ReturnType.FullName}"); + } + } } } \ No newline at end of file diff --git a/src/DurableTask.Core/TaskHubWorker.cs b/src/DurableTask.Core/TaskHubWorker.cs index 24271dc78..2bfa7e17c 100644 --- a/src/DurableTask.Core/TaskHubWorker.cs +++ b/src/DurableTask.Core/TaskHubWorker.cs @@ -21,6 +21,7 @@ namespace DurableTask.Core using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; + using DurableTask.Core.Common; using DurableTask.Core.Entities; using DurableTask.Core.Exceptions; using DurableTask.Core.Logging; @@ -315,6 +316,9 @@ public TaskHubWorker AddTaskOrchestrations(params Type[] taskOrchestrationTypes) { foreach (Type type in taskOrchestrationTypes) { + // Validate for generic type argument of Taskxxxx. + ValidateTaskType(type, typeof(TaskOrchestration<,,,>)); + ObjectCreator creator = new DefaultObjectCreator(type); this.orchestrationManager.Add(creator); } @@ -333,6 +337,11 @@ public TaskHubWorker AddTaskOrchestrations(params ObjectCreator creator in taskOrchestrationCreators) { + var instance = creator.Create(); + + // Validate for generic type argument of Taskxxxx. + ValidateTaskType(instance.GetType(), typeof(TaskOrchestration<,,,>)); + this.orchestrationManager.Add(creator); } @@ -357,7 +366,7 @@ public TaskHubWorker AddTaskEntities(params Type[] taskEntityTypes) type.Name, string.Empty, type); - + this.entityManager.Add(creator); } @@ -394,6 +403,9 @@ public TaskHubWorker AddTaskActivities(params TaskActivity[] taskActivityObjects { foreach (TaskActivity instance in taskActivityObjects) { + // Validate for generic type argument of Taskxxxx. + ValidateTaskType(instance.GetType(), typeof(AsyncTaskActivity<,>)); + ObjectCreator creator = new DefaultObjectCreator(instance); this.activityManager.Add(creator); } @@ -409,6 +421,9 @@ public TaskHubWorker AddTaskActivities(params Type[] taskActivityTypes) { foreach (Type type in taskActivityTypes) { + // Validate for generic type argument of Taskxxxx. + ValidateTaskType(type, typeof(AsyncTaskActivity<,>)); + ObjectCreator creator = new DefaultObjectCreator(type); this.activityManager.Add(creator); } @@ -642,5 +657,50 @@ private void ValidateActivitiesInterfaceType(Type @interface, object activities) throw new ArgumentException($"type {activities.GetType().FullName} does not implement {@interface.FullName}"); } } + + private static void ValidateTaskType(Type sourceType, Type targetType) + { + var type = GetTargetType(sourceType, targetType); + + if (type != null) + { + var genericArguments = type.GetGenericArguments(); + + if (genericArguments != null) + { + foreach (var genericArgument in genericArguments) + { + if (genericArgument.IsEqualOrContainsNativeType()) + { + TypeExtension.HandleTypeValidationError($"GenericArgument: {genericArgument.FullName} for type: {type.FullName} for source type: {sourceType.FullName}."); + } + } + } + } + } + + private static Type GetTargetType(Type derivedType, Type targetType) + { + if (derivedType != null) + { + if(derivedType.Name.Equals(targetType.Name)) + { + return derivedType; + } + + Type currentType = derivedType.BaseType; + while (currentType != null) + { + if (currentType.Name.Equals(targetType.Name)) + { + return currentType; + } + currentType = currentType.BaseType; + } + return null; + } + + return null; + } } } \ No newline at end of file