From dd31365a4ad691cd4851caca623f8e07fdbe1ec6 Mon Sep 17 00:00:00 2001 From: xtqqczze <45661989+xtqqczze@users.noreply.github.com> Date: Sat, 23 Aug 2025 22:02:25 +0100 Subject: [PATCH] Refactor `GetFileShares` * Use unmanaged pointers instead of `nint` * Avoid pointer arithmetic, use Span instead * Fix PInvoke definitions * Avoid `PtrToStructure` and use blittable struct for performance. --- .../CommandCompletion/CompletionCompleters.cs | 77 ++++++++----------- .../Interop/Windows/NetApiBufferFree.cs | 2 +- .../engine/Interop/Windows/NetShareEnum.cs | 48 +++++++++--- 3 files changed, 72 insertions(+), 55 deletions(-) diff --git a/src/System.Management.Automation/engine/CommandCompletion/CompletionCompleters.cs b/src/System.Management.Automation/engine/CommandCompletion/CompletionCompleters.cs index 2b90bc02837..7208bb48369 100644 --- a/src/System.Management.Automation/engine/CommandCompletion/CompletionCompleters.cs +++ b/src/System.Management.Automation/engine/CommandCompletion/CompletionCompleters.cs @@ -17,6 +17,9 @@ using System.Management.Automation.Runspaces; using System.Reflection; using System.Runtime.InteropServices; +#if !UNIX +using System.Runtime.InteropServices.Marshalling; +#endif using System.Text; using System.Text.RegularExpressions; using System.Threading; @@ -5181,14 +5184,6 @@ private static string NewPathCompletionText(string parent, string leaf, StringCo @"(^Microsoft\.PowerShell\.Core\\FileSystem::|^FileSystem::|^)(?:\\\\|//)(?![.|?])([^\\/]+)(?:\\|/)([^\\/]*)$", RegexOptions.IgnoreCase | RegexOptions.Compiled); - [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] - private struct SHARE_INFO_1 - { - public string netname; - public int type; - public string remark; - } - private static readonly System.IO.EnumerationOptions _enumerationOptions = new System.IO.EnumerationOptions { MatchCasing = MatchCasing.CaseInsensitive, @@ -5200,50 +5195,46 @@ internal static List GetFileShares(string machine, bool ignoreHidden) #if UNIX return new List(); #else - nint shBuf = nint.Zero; - uint numEntries = 0; - uint totalEntries; - uint resumeHandle = 0; - try + unsafe { - int result = Interop.Windows.NetShareEnum( - machine, - level: 1, - out shBuf, - Interop.Windows.MAX_PREFERRED_LENGTH, - out numEntries, - out totalEntries, - ref resumeHandle); - - var shares = new List(); - if (result == Interop.Windows.ERROR_SUCCESS || result == Interop.Windows.ERROR_MORE_DATA) + Interop.Windows.SHARE_INFO_1* pShareInfo = null; + try { - for (int i = 0; i < numEntries; ++i) - { - nint curInfoPtr = shBuf + (Marshal.SizeOf() * i); - SHARE_INFO_1 shareInfo = Marshal.PtrToStructure(curInfoPtr); + uint result = Interop.Windows.NetShareEnum( + machine, + out pShareInfo, + out int count); - if ((shareInfo.type & Interop.Windows.STYPE_MASK) != Interop.Windows.STYPE_DISKTREE) - { - continue; - } + List shares = new(); - if (ignoreHidden && shareInfo.netname.EndsWith('$')) + if (result is Interop.Windows.ERROR_SUCCESS or Interop.Windows.ERROR_MORE_DATA) + { + foreach (Interop.Windows.SHARE_INFO_1 shareInfo in new ReadOnlySpan(pShareInfo, count)) { - continue; - } + if ((shareInfo.type & Interop.Windows.STYPE_MASK) != Interop.Windows.STYPE_DISKTREE) + { + continue; + } + + string share = Utf16StringMarshaller.ConvertToManaged(shareInfo.netname); + + if (ignoreHidden && share.EndsWith('$')) + { + continue; + } - shares.Add(shareInfo.netname); + shares.Add(share); + } } - } - return shares; - } - finally - { - if (shBuf != nint.Zero) + return shares; + } + finally { - Interop.Windows.NetApiBufferFree(shBuf); + if (pShareInfo is not null) + { + Interop.Windows.NetApiBufferFree(pShareInfo); + } } } #endif diff --git a/src/System.Management.Automation/engine/Interop/Windows/NetApiBufferFree.cs b/src/System.Management.Automation/engine/Interop/Windows/NetApiBufferFree.cs index 4f7d0541872..fca901726b3 100644 --- a/src/System.Management.Automation/engine/Interop/Windows/NetApiBufferFree.cs +++ b/src/System.Management.Automation/engine/Interop/Windows/NetApiBufferFree.cs @@ -11,6 +11,6 @@ internal static unsafe partial class Windows { [LibraryImport("Netapi32.dll")] - internal static partial uint NetApiBufferFree(nint Buffer); + internal static unsafe partial uint NetApiBufferFree(void* Buffer); } } diff --git a/src/System.Management.Automation/engine/Interop/Windows/NetShareEnum.cs b/src/System.Management.Automation/engine/Interop/Windows/NetShareEnum.cs index 7efad887f1a..bf44459bd6d 100644 --- a/src/System.Management.Automation/engine/Interop/Windows/NetShareEnum.cs +++ b/src/System.Management.Automation/engine/Interop/Windows/NetShareEnum.cs @@ -3,24 +3,50 @@ #nullable enable +using System; using System.Runtime.InteropServices; internal static partial class Interop { internal static unsafe partial class Windows { - internal const int MAX_PREFERRED_LENGTH = -1; - internal const int STYPE_DISKTREE = 0; - internal const int STYPE_MASK = 0x000000FF; + internal const uint MAX_PREFERRED_LENGTH = uint.MaxValue; + internal const uint STYPE_DISKTREE = 0u; + internal const uint STYPE_MASK = 255u; [LibraryImport("Netapi32.dll", StringMarshalling = StringMarshalling.Utf16)] - internal static partial int NetShareEnum( - string serverName, - int level, - out nint bufptr, - int prefMaxLen, - out uint entriesRead, - out uint totalEntries, - ref uint resumeHandle); + private static unsafe partial uint NetShareEnum( + string? servername, + uint level, + out byte* bufptr, + uint prefmaxlen, + out uint entriesread, + out uint totalentries, + uint* resume_handle); + + [StructLayout(LayoutKind.Sequential)] + internal unsafe struct SHARE_INFO_1 + { + public ushort* netname; + public uint type; + public ushort* remark; + } + + internal static uint NetShareEnum(string? servername, out T* pShareInfo, out int count) where T : unmanaged + { + uint level = (uint)GetLevelFromStructure(); + uint result = NetShareEnum(servername, level, out byte* pBuffer, MAX_PREFERRED_LENGTH, out uint entriesRead, out _, null); + pShareInfo = (T*)pBuffer; + count = (int)entriesRead; + return result; + } + + private static int GetLevelFromStructure() + { + if (typeof(T) == typeof(SHARE_INFO_1)) + return 1; + + throw new NotSupportedException(); + } } }