From 72c393de18bc5504b720965dd6bc89d84243bb75 Mon Sep 17 00:00:00 2001 From: Jared Goodwin Date: Thu, 20 Jul 2023 15:33:54 -0700 Subject: [PATCH] Hotfix for resolving scripting shells through DI. --- Agent/Program.cs | 2 +- Agent/Services/AgentHubConnection.cs | 2 +- Agent/Services/ExternalScriptingShell.cs | 242 ++++++++++--------- Agent/Services/{PSCore.cs => PsCoreShell.cs} | 31 ++- Agent/Services/ScriptExecutor.cs | 39 +-- 5 files changed, 160 insertions(+), 156 deletions(-) rename Agent/Services/{PSCore.cs => PsCoreShell.cs} (85%) diff --git a/Agent/Program.cs b/Agent/Program.cs index 1fab73ae6..66ffd7dbd 100644 --- a/Agent/Program.cs +++ b/Agent/Program.cs @@ -95,7 +95,7 @@ private static void RegisterServices(IServiceCollection services) services.AddSingleton(); services.AddHostedService(services => services.GetRequiredService()); services.AddScoped(); - services.AddTransient(); + services.AddTransient(); services.AddTransient(); services.AddScoped(); services.AddScoped(); diff --git a/Agent/Services/AgentHubConnection.cs b/Agent/Services/AgentHubConnection.cs index d3766d112..207717e63 100644 --- a/Agent/Services/AgentHubConnection.cs +++ b/Agent/Services/AgentHubConnection.cs @@ -356,7 +356,7 @@ private void RegisterMessageHandlers() { try { - var session = PSCore.GetCurrent(senderConnectionId); + var session = PsCoreShell.GetCurrent(senderConnectionId); var completion = session.GetCompletions(inputText, currentIndex, forward); var completionModel = completion.ToPwshCompletion(); await _hubConnection.InvokeAsync("ReturnPowerShellCompletions", completionModel, intent, senderConnectionId).ConfigureAwait(false); diff --git a/Agent/Services/ExternalScriptingShell.cs b/Agent/Services/ExternalScriptingShell.cs index e75baf5d6..de0c91914 100644 --- a/Agent/Services/ExternalScriptingShell.cs +++ b/Agent/Services/ExternalScriptingShell.cs @@ -2,13 +2,10 @@ using Microsoft.Extensions.Logging; using Remotely.Shared.Enums; using Remotely.Shared.Models; -using Remotely.Shared.Utilities; using System; using System.Collections.Concurrent; -using System.Collections.Generic; using System.Diagnostics; using System.Linq; -using System.Text; using System.Threading; using System.Threading.Tasks; @@ -16,16 +13,29 @@ namespace Remotely.Agent.Services { public interface IExternalScriptingShell { - ScriptResult WriteInput(string input, TimeSpan timeout); + Process ShellProcess { get; } + Task Init(ScriptingShell shell, string shellProcessName, string lineEnding, string connectionId); + Task WriteInput(string input, TimeSpan timeout); } public class ExternalScriptingShell : IExternalScriptingShell { - private static readonly ConcurrentDictionary _sessions = new(); + private static readonly ConcurrentDictionary _sessions = new(); private readonly IConfigService _configService; private readonly ILogger _logger; + private readonly ManualResetEvent _outputDone = new(false); + private readonly SemaphoreSlim _writeLock = new(1, 1); + private string _errorOut = string.Empty; + private string _lastInputID = string.Empty; private string _lineEnding; + private System.Timers.Timer _processIdleTimeout = new(TimeSpan.FromMinutes(10)) + { + AutoReset = false + }; + + private string _senderConnectionId; private ScriptingShell _shell; + private string _standardOut = string.Empty; public ExternalScriptingShell( IConfigService configService, @@ -34,47 +44,31 @@ public ExternalScriptingShell( _configService = configService; _logger = logger; } + public Process ShellProcess { get; set; } - private string ErrorOut { get; set; } - - private string LastInputID { get; set; } - - private ManualResetEvent OutputDone { get; } = new(false); - - private System.Timers.Timer ProcessIdleTimeout { get; set; } - - private string SenderConnectionId { get; set; } - - private Process ShellProcess { get; set; } - - private string StandardOut { get; set; } - - private Stopwatch Stopwatch { get; set; } // TODO: Turn into cache and factory. - public static ExternalScriptingShell GetCurrent(ScriptingShell shell, string senderConnectionId) + public static async Task GetCurrent(ScriptingShell shell, string senderConnectionId) { if (_sessions.TryGetValue($"{shell}-{senderConnectionId}", out var session) && session.ShellProcess?.HasExited != true) { - session.ProcessIdleTimeout.Stop(); - session.ProcessIdleTimeout.Start(); return session; } else { - session = Program.Services.GetRequiredService(); + session = Program.Services.GetRequiredService(); switch (shell) { case ScriptingShell.WinPS: - session.Init(shell, "powershell.exe", "\r\n", senderConnectionId); + await session.Init(shell, "powershell.exe", "\r\n", senderConnectionId); break; case ScriptingShell.Bash: - session.Init(shell, "bash", "\n", senderConnectionId); + await session.Init(shell, "bash", "\n", senderConnectionId); break; case ScriptingShell.CMD: - session.Init(shell, "cmd.exe", "\r\n", senderConnectionId); + await session.Init(shell, "cmd.exe", "\r\n", senderConnectionId); break; default: throw new ArgumentException($"Unknown external scripting shell type: {shell}"); @@ -84,135 +78,143 @@ public static ExternalScriptingShell GetCurrent(ScriptingShell shell, string sen } } - public ScriptResult WriteInput(string input, TimeSpan timeout) + public async Task Init(ScriptingShell shell, string shellProcessName, string lineEnding, string connectionId) + { + _shell = shell; + _lineEnding = lineEnding; + _senderConnectionId = connectionId; + + var psi = new ProcessStartInfo(shellProcessName) + { + WindowStyle = ProcessWindowStyle.Hidden, + Verb = "RunAs", + UseShellExecute = false, + RedirectStandardError = true, + RedirectStandardInput = true, + RedirectStandardOutput = true + }; + + var connectionInfo = _configService.GetConnectionInfo(); + psi.Environment.Add("DeviceId", connectionInfo.DeviceID); + psi.Environment.Add("ServerUrl", connectionInfo.Host); + + ShellProcess = new Process + { + StartInfo = psi + }; + ShellProcess.ErrorDataReceived += ShellProcess_ErrorDataReceived; + ShellProcess.OutputDataReceived += ShellProcess_OutputDataReceived; + + ShellProcess.Start(); + + ShellProcess.BeginErrorReadLine(); + ShellProcess.BeginOutputReadLine(); + + _processIdleTimeout = new System.Timers.Timer(TimeSpan.FromMinutes(10)) + { + AutoReset = false + }; + _processIdleTimeout.Elapsed += ProcessIdleTimeout_Elapsed; + _processIdleTimeout.Start(); + + if (shell == ScriptingShell.WinPS) + { + await WriteInput("$VerbosePreference = \"Continue\";", TimeSpan.FromSeconds(5)); + await WriteInput("$DebugPreference = \"Continue\";", TimeSpan.FromSeconds(5)); + await WriteInput("$InformationPreference = \"Continue\";", TimeSpan.FromSeconds(5)); + await WriteInput("$WarningPreference = \"Continue\";", TimeSpan.FromSeconds(5)); + } + } + + public async Task WriteInput(string input, TimeSpan timeout) { + await _writeLock.WaitAsync(); + var sw = Stopwatch.StartNew(); + try { - StandardOut = ""; - ErrorOut = ""; - Stopwatch = Stopwatch.StartNew(); - lock (ShellProcess) - { - LastInputID = Guid.NewGuid().ToString(); - OutputDone.Reset(); - ShellProcess.StandardInput.Write(input + _lineEnding); - ShellProcess.StandardInput.Write("echo " + LastInputID + _lineEnding); - - var result = Task.WhenAny( - Task.Run(() => - { - return ShellProcess.WaitForExit((int)timeout.TotalMilliseconds); - }), - Task.Run(() => - { - return OutputDone.WaitOne(); - - })).ConfigureAwait(false).GetAwaiter().GetResult(); - - if (!result.Result) + _processIdleTimeout.Stop(); + _processIdleTimeout.Start(); + _outputDone.Reset(); + + _standardOut = ""; + _errorOut = ""; + _lastInputID = Guid.NewGuid().ToString(); + + ShellProcess.StandardInput.Write(input + _lineEnding); + ShellProcess.StandardInput.Write("echo " + _lastInputID + _lineEnding); + + var result = await Task.WhenAny( + Task.Run(() => + { + return ShellProcess.WaitForExit((int)timeout.TotalMilliseconds); + }), + Task.Run(() => { - return GeneratePartialResult(input); - } + return _outputDone.WaitOne(); + + })).ConfigureAwait(false).GetAwaiter().GetResult(); + + if (!result) + { + return GeneratePartialResult(input, sw.Elapsed); } - return GenerateCompletedResult(input); + + return GenerateCompletedResult(input, sw.Elapsed); } catch (Exception ex) { _logger.LogError(ex, "Error while writing input to scripting shell."); - ErrorOut += Environment.NewLine + ex.Message; + _errorOut += Environment.NewLine + ex.Message; // Something's wrong. Let the next command start a new session. RemoveSession(); } + finally + { + _writeLock.Release(); + } - return GeneratePartialResult(input); + return GeneratePartialResult(input, sw.Elapsed); } - private ScriptResult GenerateCompletedResult(string input) + private ScriptResult GenerateCompletedResult(string input, TimeSpan runtime) { return new ScriptResult() { Shell = _shell, - RunTime = Stopwatch.Elapsed, + RunTime = runtime, ScriptInput = input, - SenderConnectionID = SenderConnectionId, + SenderConnectionID = _senderConnectionId, DeviceID = _configService.GetConnectionInfo().DeviceID, - StandardOutput = StandardOut.Split(Environment.NewLine), - ErrorOutput = ErrorOut.Split(Environment.NewLine), - HadErrors = !string.IsNullOrWhiteSpace(ErrorOut) || + StandardOutput = _standardOut.Split(Environment.NewLine), + ErrorOutput = _errorOut.Split(Environment.NewLine), + HadErrors = !string.IsNullOrWhiteSpace(_errorOut) || (ShellProcess.HasExited && ShellProcess.ExitCode != 0) }; } - private ScriptResult GeneratePartialResult(string input) + private ScriptResult GeneratePartialResult(string input, TimeSpan runtime) { var partialResult = new ScriptResult() { Shell = _shell, - RunTime = Stopwatch.Elapsed, + RunTime = runtime, ScriptInput = input, - SenderConnectionID = SenderConnectionId, + SenderConnectionID = _senderConnectionId, DeviceID = _configService.GetConnectionInfo().DeviceID, - StandardOutput = StandardOut.Split(Environment.NewLine), + StandardOutput = _standardOut.Split(Environment.NewLine), ErrorOutput = (new[] { "WARNING: The command execution timed out and was forced to return before finishing. " + "The results may be partial, and the terminal process has been reset. " + "Please note that interactive commands aren't supported."}) - .Concat(ErrorOut.Split(Environment.NewLine)) + .Concat(_errorOut.Split(Environment.NewLine)) .ToArray(), - HadErrors = !string.IsNullOrWhiteSpace(ErrorOut) || + HadErrors = !string.IsNullOrWhiteSpace(_errorOut) || (ShellProcess.HasExited && ShellProcess.ExitCode != 0) }; ProcessIdleTimeout_Elapsed(this, null); return partialResult; } - - private void Init(ScriptingShell shell, string shellProcessName, string lineEnding, string connectionId) - { - _shell = shell; - _lineEnding = lineEnding; - SenderConnectionId = connectionId; - - var psi = new ProcessStartInfo(shellProcessName) - { - WindowStyle = ProcessWindowStyle.Hidden, - Verb = "RunAs", - UseShellExecute = false, - RedirectStandardError = true, - RedirectStandardInput = true, - RedirectStandardOutput = true - }; - - var connectionInfo = _configService.GetConnectionInfo(); - psi.Environment.Add("DeviceId", connectionInfo.DeviceID); - psi.Environment.Add("ServerUrl", connectionInfo.Host); - - ShellProcess = new Process - { - StartInfo = psi - }; - ShellProcess.ErrorDataReceived += ShellProcess_ErrorDataReceived; - ShellProcess.OutputDataReceived += ShellProcess_OutputDataReceived; - - ShellProcess.Start(); - - ShellProcess.BeginErrorReadLine(); - ShellProcess.BeginOutputReadLine(); - - ProcessIdleTimeout = new System.Timers.Timer(TimeSpan.FromMinutes(10).TotalMilliseconds) - { - AutoReset = false - }; - ProcessIdleTimeout.Elapsed += ProcessIdleTimeout_Elapsed; - ProcessIdleTimeout.Start(); - - if (shell == ScriptingShell.WinPS) - { - WriteInput("$VerbosePreference = \"Continue\";", TimeSpan.FromSeconds(5)); - WriteInput("$DebugPreference = \"Continue\";", TimeSpan.FromSeconds(5)); - WriteInput("$InformationPreference = \"Continue\";", TimeSpan.FromSeconds(5)); - WriteInput("$WarningPreference = \"Continue\";", TimeSpan.FromSeconds(5)); - } - } private void ProcessIdleTimeout_Elapsed(object sender, System.Timers.ElapsedEventArgs e) { RemoveSession(); @@ -221,26 +223,26 @@ private void ProcessIdleTimeout_Elapsed(object sender, System.Timers.ElapsedEven private void RemoveSession() { ShellProcess?.Kill(); - _sessions.TryRemove(SenderConnectionId, out _); + _sessions.TryRemove(_senderConnectionId, out _); } private void ShellProcess_ErrorDataReceived(object sender, DataReceivedEventArgs e) { if (e?.Data != null) { - ErrorOut += e.Data + Environment.NewLine; + _errorOut += e.Data + Environment.NewLine; } } private void ShellProcess_OutputDataReceived(object sender, DataReceivedEventArgs e) { - if (e?.Data?.Contains(LastInputID) == true) + if (e?.Data?.Contains(_lastInputID) == true) { - OutputDone.Set(); + _outputDone.Set(); } else { - StandardOut += e.Data + Environment.NewLine; + _standardOut += e.Data + Environment.NewLine; } } diff --git a/Agent/Services/PSCore.cs b/Agent/Services/PsCoreShell.cs similarity index 85% rename from Agent/Services/PSCore.cs rename to Agent/Services/PsCoreShell.cs index c7af5f9d8..c7de39495 100644 --- a/Agent/Services/PSCore.cs +++ b/Agent/Services/PsCoreShell.cs @@ -6,32 +6,30 @@ using System.Diagnostics; using System.Linq; using System.Management.Automation; -using System.Management.Automation.Runspaces; -using System.Timers; -using static Immense.RemoteControl.Desktop.Native.Windows.User32; +using System.Threading.Tasks; namespace Remotely.Agent.Services { - public interface IPSCore + public interface IPsCoreShell { - string SenderConnectionId { get; } + string SenderConnectionId { get; set; } CommandCompletion GetCompletions(string inputText, int currentIndex, bool? forward); - ScriptResult WriteInput(string input); + Task WriteInput(string input); } - public class PSCore : IPSCore + public class PsCoreShell : IPsCoreShell { - private static readonly ConcurrentDictionary _sessions = new ConcurrentDictionary(); + private static readonly ConcurrentDictionary _sessions = new(); private readonly IConfigService _configService; private readonly ConnectionInfo _connectionInfo; - private readonly ILogger _logger; + private readonly ILogger _logger; private readonly PowerShell _powershell; private CommandCompletion _lastCompletion; private string _lastInputText; - public PSCore( + public PsCoreShell( IConfigService configService, - ILogger logger) + ILogger logger) { _configService = configService; _logger = logger; @@ -49,9 +47,9 @@ public PSCore( _powershell.Invoke(); } - public string SenderConnectionId { get; private set; } + public string SenderConnectionId { get; set; } // TODO: Turn into cache and factory. - public static PSCore GetCurrent(string senderConnectionId) + public static IPsCoreShell GetCurrent(string senderConnectionId) { if (_sessions.TryGetValue(senderConnectionId, out var session)) { @@ -59,7 +57,7 @@ public static PSCore GetCurrent(string senderConnectionId) } else { - session = Program.Services.GetRequiredService(); + session = Program.Services.GetRequiredService(); session.SenderConnectionId = senderConnectionId; _sessions.AddOrUpdate(senderConnectionId, session, (id, b) => session); return session; @@ -83,7 +81,7 @@ public CommandCompletion GetCompletions(string inputText, int currentIndex, bool return _lastCompletion; } - public ScriptResult WriteInput(string input) + public async Task WriteInput(string input) { var deviceId = _configService.GetConnectionInfo().DeviceID; var sw = Stopwatch.StartNew(); @@ -100,7 +98,8 @@ public ScriptResult WriteInput(string input) using var ps = PowerShell.Create(); ps.AddScript("$args[0] | Out-String"); ps.AddArgument(results); - var hostOutput = (string)ps.Invoke()[0].BaseObject; + var result = await ps.InvokeAsync(); + var hostOutput = result[0].BaseObject.ToString(); var verboseOut = _powershell.Streams.Verbose.ReadAll().Select(x => x.Message); var debugOut = _powershell.Streams.Debug.ReadAll().Select(x => x.Message); diff --git a/Agent/Services/ScriptExecutor.cs b/Agent/Services/ScriptExecutor.cs index bc8013c74..69bbaf906 100644 --- a/Agent/Services/ScriptExecutor.cs +++ b/Agent/Services/ScriptExecutor.cs @@ -5,9 +5,6 @@ using Remotely.Shared.Models; using Remotely.Shared.Utilities; using System; -using System.IO; -using System.Linq; -using System.Net; using System.Net.Http; using System.Net.Http.Json; using System.Text.Json; @@ -44,7 +41,7 @@ public async Task RunCommandFromApi(ScriptingShell shell, try { - var result = ExecuteScriptContent(shell, requestID, command, TimeSpan.FromMinutes(AppConstants.ScriptRunExpirationMinutes)); + var result = await ExecuteScriptContent(shell, requestID, command, TimeSpan.FromMinutes(AppConstants.ScriptRunExpirationMinutes)); result.InputType = ScriptInputType.Api; result.SenderUserName = senderUsername; @@ -69,7 +66,7 @@ public async Task RunCommandFromTerminal(ScriptingShell shell, { try { - var result = ExecuteScriptContent(shell, senderConnectionID, command, timeout); + var result = await ExecuteScriptContent(shell, senderConnectionID, command, timeout); result.InputType = scriptInputType; result.SenderUserName = senderUsername; @@ -112,7 +109,7 @@ public async Task RunScript(Guid savedScriptId, hc.DefaultRequestHeaders.Add("Authorization", authToken); var savedScript = await hc.GetFromJsonAsync(url); - var result = ExecuteScriptContent(savedScript.Shell, + var result = await ExecuteScriptContent(savedScript.Shell, Guid.NewGuid().ToString(), savedScript.Content, TimeSpan.FromMinutes(AppConstants.ScriptRunExpirationMinutes)); @@ -130,8 +127,8 @@ public async Task RunScript(Guid savedScriptId, } } - // TODO: Async/await. - private ScriptResult ExecuteScriptContent(ScriptingShell shell, + private async Task ExecuteScriptContent( + ScriptingShell shell, string terminalSessionId, string command, TimeSpan timeout) @@ -139,31 +136,37 @@ private ScriptResult ExecuteScriptContent(ScriptingShell shell, switch (shell) { case ScriptingShell.PSCore: - return PSCore.GetCurrent(terminalSessionId).WriteInput(command); + return await PsCoreShell + .GetCurrent(terminalSessionId) + .WriteInput(command); case ScriptingShell.WinPS: if (EnvironmentHelper.IsWindows) { - return ExternalScriptingShell - .GetCurrent(ScriptingShell.WinPS, terminalSessionId) - .WriteInput(command, timeout); + var instance = await ExternalScriptingShell + .GetCurrent(ScriptingShell.WinPS, terminalSessionId); + return await instance.WriteInput(command, timeout); } break; case ScriptingShell.CMD: if (EnvironmentHelper.IsWindows) { - return ExternalScriptingShell - .GetCurrent(ScriptingShell.CMD, terminalSessionId) - .WriteInput(command, timeout); + var instance = await ExternalScriptingShell + .GetCurrent(ScriptingShell.CMD, terminalSessionId); + + return await instance.WriteInput(command, timeout); } break; case ScriptingShell.Bash: - return ExternalScriptingShell - .GetCurrent(ScriptingShell.Bash, terminalSessionId) - .WriteInput(command, timeout); + { + var instance = await ExternalScriptingShell + .GetCurrent(ScriptingShell.Bash, terminalSessionId); + + return await instance.WriteInput(command, timeout); + } default: break; }