From 6dc20f4b94c2ea78e8865a16c849b1ebf7c83109 Mon Sep 17 00:00:00 2001 From: Tim Tucker Date: Fri, 20 Sep 2024 10:04:43 -0400 Subject: [PATCH 1/3] Lazy load AudioDevice properties, add additional helper functions, move redundant code into helper functions, make variable names more consistent, add typing --- pycaw/utils.py | 364 ++++++++++++++++++++++++++++++++------------- tests/test_core.py | 5 +- 2 files changed, 265 insertions(+), 104 deletions(-) diff --git a/pycaw/utils.py b/pycaw/utils.py index e3fd056..27227af 100644 --- a/pycaw/utils.py +++ b/pycaw/utils.py @@ -1,8 +1,11 @@ import warnings +from typing import List, TypeVar, Generic +from __future__ import annotations import comtypes import psutil from _ctypes import COMError +from comtypes import IUnknown from pycaw.api.audioclient import IChannelAudioVolume, ISimpleAudioVolume from pycaw.api.audiopolicy import IAudioSessionControl2, IAudioSessionManager2 @@ -18,22 +21,98 @@ IID_Empty, ) +# Define a type variable that extends IUnknown +COMInterface = TypeVar('COMInterface', bound='IUnknown') class AudioDevice: """ https://stackoverflow.com/a/20982715/185510 """ - def __init__(self, id, state, properties, dev): - self.id = id - self.state = state - self.properties = properties + def __init__(self, dev): self._dev = dev - self._volume = None + self._id = None + self._state = None + self._properties = None + self._interfaces: dict[str, COMInterface] = {} def __str__(self): return "AudioDevice: %s" % (self.FriendlyName) + def ActivateInterface(self, interface: COMInterface) -> COMInterface: + interface_id = interface._iid_ + if interface_id is None: + return None + + interface_name = str(interface_id) + + if interface_name not in self._interfaces: + activated = self._dev.Activate(interface_id, comtypes.CLSCTX_ALL, None) + interface = activated.QueryInterface(interface) + self._interfaces[interface_name] = interface + + return self._interfaces[interface_name] + + @property + def id(self): + if self._id is None: + self._id = self._dev.GetId() + return self._id + + @property + def state(self) -> AudioDeviceState: + if self._state is None: + state = self._dev.GetState() + self._state = AudioDeviceState(state) + return self._state + + @staticmethod + def getProperty(self, key): + store = self._dev.OpenPropertyStore(STGM.STGM_READ.value) + if store is None: + return None + + propCount = store.GetCount() + for j in range(propCount): + pk = store.GetAt(j) + name = str(pk) + + if name != key: + continue + + try: + value = store.GetValue(pk) + v = value.GetValue() + value.clear() + return v + except COMError as exc: + warnings.warn( + "COMError attempting to get property %r " + "from device %r: %r" % (j, id, exc) + ) + return {} + + @property + def properties(self) -> dict: + if self._properties is None: + store = self._dev.OpenPropertyStore(STGM.STGM_READ.value) + if store is None: + return {} + + properties = {} + propCount = store.GetCount() + for j in range(propCount): + pk = store.GetAt(j) + + value = AudioDevice.getProperty(self, pk) + if value is None: + continue + + name = str(pk) + properties[name] = value + self._properties = properties + return self._properties + @property def FriendlyName(self): DEVPKEY_Device_FriendlyName = ( @@ -43,13 +122,53 @@ def FriendlyName(self): return value @property - def EndpointVolume(self): - if self._volume is None: - iface = self._dev.Activate( - IAudioEndpointVolume._iid_, comtypes.CLSCTX_ALL, None - ) - self._volume = iface.QueryInterface(IAudioEndpointVolume) - return self._volume + def EndpointVolume(self) -> IAudioEndpointVolume: + interface = self.ActivateInterface(IAudioEndpointVolume) + return interface + + @property + def IsMuted(self) -> bool: + endpointVolume = self.EndpointVolume + return endpointVolume.GetMute() == 1 + + def Mute(self): + endpointVolume = self.EndpointVolume + endpointVolume.SetMute(1, None) + + def UnMute(self): + endpointVolume = self.EndpointVolume + endpointVolume.SetMute(0, None) + + @property + def AudioSessionManager(self) -> IAudioSessionManager2: + return self.ActivateInterface(IAudioSessionManager2) + + @property + def Sessions(self) -> List[AudioSession]: + # Only Active devices should have audio sessions + if self.state != AudioDeviceState.Active: + return [] + + mgr = self.AudioSessionManager + + # Assume no sessions if there's no session enumerator + if not hasattr(mgr, "GetSessionEnumerator"): + return [] + + sessions = [] + + sessionEnumerator = mgr.GetSessionEnumerator() + count = sessionEnumerator.GetCount() + for i in range(count): + ctl = sessionEnumerator.GetSession(i) + if ctl is None: + continue + ctl2 = ctl.QueryInterface(IAudioSessionControl2) + if ctl2 is not None: + audio_session = AudioSession(ctl2) + sessions.append(audio_session) + + return sessions class AudioSession: @@ -57,11 +176,10 @@ class AudioSession: https://stackoverflow.com/a/20982715/185510 """ - def __init__(self, audio_session_control2): + def __init__(self, audio_session_control2: IAudioSessionControl2): self._ctl = audio_session_control2 self._process = None - self._volume = None - self._channelVolume = None + self._interfaces: dict[str, COMInterface] = {} self._callback = None def __str__(self): @@ -73,7 +191,7 @@ def __str__(self): return "Pid: %s" % (self.ProcessId) @property - def Process(self): + def Process(self) -> psutil.Process | None: if self._process is None and self.ProcessId != 0: try: self._process = psutil.Process(self.ProcessId) @@ -120,13 +238,13 @@ def DisplayName(self): return s @DisplayName.setter - def DisplayName(self, value): + def DisplayName(self, value: str) -> str: s = self._ctl.GetDisplayName() if s != value: self._ctl.SetDisplayName(value, IID_Empty) @property - def IconPath(self): + def IconPath(self) -> str: """ Please, note that this returns an empty string if the client hadn't called the setter method before. @@ -135,21 +253,25 @@ def IconPath(self): return s @IconPath.setter - def IconPath(self, value): + def IconPath(self, value: str): s = self._ctl.GetIconPath() if s != value: self._ctl.SetIconPath(value, IID_Empty) + def QueryInterface(self, interface: COMInterface) -> COMInterface: + interface_name = str(interface._iid_) + + if interface_name not in self._interfaces: + self._interfaces[interface_name] = self._ctl.QueryInterface(interface) + + return self._interfaces[interface_name] + @property def SimpleAudioVolume(self): - if self._volume is None: - self._volume = self._ctl.QueryInterface(ISimpleAudioVolume) - return self._volume + return self.QueryInterface(ISimpleAudioVolume) def channelAudioVolume(self): - if self._channelVolume is None: - self._channelVolume = self._ctl.QueryInterface(IChannelAudioVolume) - return self._channelVolume + return self.QueryInterface(IChannelAudioVolume) def register_notification(self, callback): if self._callback is None: @@ -166,126 +288,162 @@ class AudioUtilities: https://stackoverflow.com/a/20982715/185510 """ + @staticmethod + def GetDefaultEndpoint(dataFlow: EDataFlow = EDataFlow.eAll.value): + deviceEnumerator = AudioUtilities.GetDeviceEnumerator() + if deviceEnumerator is None: + return None + + return deviceEnumerator.GetDefaultAudioEndpoint( + dataFlow, ERole.eMultimedia.value + ) + @staticmethod def GetSpeakers(): """ get the speakers (1st render + multimedia) device """ - deviceEnumerator = comtypes.CoCreateInstance( - CLSID_MMDeviceEnumerator, IMMDeviceEnumerator, comtypes.CLSCTX_INPROC_SERVER - ) - speakers = deviceEnumerator.GetDefaultAudioEndpoint( - EDataFlow.eRender.value, ERole.eMultimedia.value - ) - return speakers + return AudioUtilities.GetDefaultEndpoint(EDataFlow.eRender.value) @staticmethod def GetMicrophone(): """ get the microphone (1st capture + multimedia) device """ - deviceEnumerator = comtypes.CoCreateInstance( - CLSID_MMDeviceEnumerator, IMMDeviceEnumerator, comtypes.CLSCTX_INPROC_SERVER - ) - microphone = deviceEnumerator.GetDefaultAudioEndpoint( - EDataFlow.eCapture.value, ERole.eMultimedia.value - ) - return microphone + return AudioUtilities.GetDefaultEndpoint(EDataFlow.eCapture.value) @staticmethod - def GetAudioSessionManager(): + def GetAudioSessionManager() -> IAudioSessionManager2 | None: speakers = AudioUtilities.GetSpeakers() if speakers is None: return None - # win7+ only - o = speakers.Activate(IAudioSessionManager2._iid_, comtypes.CLSCTX_ALL, None) - mgr = o.QueryInterface(IAudioSessionManager2) - return mgr + + device = AudioUtilities.CreateDevice(speakers) + return device.AudioSessionManager @staticmethod - def GetAllSessions(): - audio_sessions = [] - mgr = AudioUtilities.GetAudioSessionManager() - if mgr is None: - return audio_sessions - sessionEnumerator = mgr.GetSessionEnumerator() - count = sessionEnumerator.GetCount() - for i in range(count): - ctl = sessionEnumerator.GetSession(i) - if ctl is None: + def GetAllSessions() -> List[AudioSession]: + # TODO: the current behavior here isn't really getting "all sessions" + # Leaving behavior as-is for backward compatibility + speakers = AudioUtilities.GetSpeakers() + if speakers is None: + return [] + + device = AudioUtilities.CreateDevice(speakers) + return device.Sessions + + @staticmethod + def GetSessions( + dataFlow: EDataFlow = EDataFlow.eAll.value, + sessionState: AudioDeviceState = None, + ) -> List[AudioSession]: + # Only active devices can have sessions associated with them + devices = AudioUtilities.GetDevices(dataFlow, DEVICE_STATE.ACTIVE.value) + if devices is None: + return [] + + sessions = [] + for device in devices: + deviceSessions = device.Sessions + if deviceSessions is None: continue - ctl2 = ctl.QueryInterface(IAudioSessionControl2) - if ctl2 is not None: - audio_session = AudioSession(ctl2) - audio_sessions.append(audio_session) - return audio_sessions + + for deviceSession in deviceSessions: + if sessionState is None or sessionState == deviceSession.State: + sessions.append(deviceSession) + + return sessions + + @staticmethod + def GetPlaybackSessions(sessionState: AudioDeviceState = None) -> List[AudioSession]: + return AudioUtilities.GetSessions(EDataFlow.eRender.value, sessionState) + + @staticmethod + def GetRecordingSessions(sessionState: AudioDeviceState = None) -> List[AudioSession]: + return AudioUtilities.GetSessions(EDataFlow.eCapture.value, sessionState) @staticmethod - def GetProcessSession(id): - for session in AudioUtilities.GetAllSessions(): + def GetProcessSession(id) -> AudioSession | None: + # Need to look at all sessions to ensure + # we find the one we're looking for + for session in AudioUtilities.GetSessions(): + if session is None: + continue if session.ProcessId == id: return session # session.Dispose() return None @staticmethod - def CreateDevice(dev): + def CreateDevice(dev) -> AudioDevice | None: if dev is None: return None - id = dev.GetId() - state = dev.GetState() - properties = {} - store = dev.OpenPropertyStore(STGM.STGM_READ.value) - if store is not None: - propCount = store.GetCount() - for j in range(propCount): - try: - pk = store.GetAt(j) - value = store.GetValue(pk) - v = value.GetValue() - except COMError as exc: - warnings.warn( - "COMError attempting to get property %r " - "from device %r: %r" % (j, dev, exc) - ) - continue - value.clear() - name = str(pk) - properties[name] = v - audioState = AudioDeviceState(state) - return AudioDevice(id, audioState, properties, dev) - @staticmethod - def GetAllDevices(): - devices = [] - deviceEnumerator = comtypes.CoCreateInstance( - CLSID_MMDeviceEnumerator, IMMDeviceEnumerator, comtypes.CLSCTX_INPROC_SERVER - ) - if deviceEnumerator is None: - return devices + device = AudioDevice(dev) - collection = deviceEnumerator.EnumAudioEndpoints( - EDataFlow.eAll.value, DEVICE_STATE.MASK_ALL.value - ) + return device + + @staticmethod + def CreateDevices(collection) -> List[AudioDevice]: if collection is None: - return devices + return [] + + devices = [] count = collection.GetCount() for i in range(count): dev = collection.Item(i) - if dev is not None: - devices.append(AudioUtilities.CreateDevice(dev)) + device = AudioUtilities.CreateDevice(dev) + + if device is not None: + devices.append(device) + return devices @staticmethod - def GetDeviceEnumerator(): + def GetDevices(flow=EDataFlow.eAll.value, deviceState=DEVICE_STATE.ACTIVE.value) -> List[AudioDevice]: + """ + Get devices based on filteres for flow direction and device state. + Default to returning active devices. + """ + deviceEnumerator = AudioUtilities.GetDeviceEnumerator() + if deviceEnumerator is None: + return [] + + collection = deviceEnumerator.EnumAudioEndpoints(flow, deviceState) + return AudioUtilities.CreateDevices(collection) + + @staticmethod + def GetAllDevices() -> List[IMMDeviceEnumerator]: + deviceEnumerator = AudioUtilities.GetDeviceEnumerator() + if deviceEnumerator is None: + return [] + + collection = deviceEnumerator.EnumAudioEndpoints( + EDataFlow.eAll.value, DEVICE_STATE.MASK_ALL.value + ) + return AudioUtilities.CreateDevices(collection) + + @staticmethod + def GetDeviceEnumerator() -> IMMDeviceEnumerator: """ Get an instance of IMMDeviceEnumerator. """ - device_enumerator = comtypes.CoCreateInstance( + deviceEnumerator = comtypes.CoCreateInstance( CLSID_MMDeviceEnumerator, IMMDeviceEnumerator, comtypes.CLSCTX_INPROC_SERVER ) - return device_enumerator + return deviceEnumerator + + @staticmethod + def GetDevice(devId): + """ + Get AudioDevice. + One input argument: + - devId: id of the device + """ + deviceEnumerator = AudioUtilities.GetDeviceEnumerator() + dev = deviceEnumerator.GetDevice(devId) + return AudioUtilities.CreateDevice(dev) @staticmethod def GetEndpointDataFlow(devId, outputType=0): @@ -296,8 +454,8 @@ def GetEndpointDataFlow(devId, outputType=0): - outputType: 0 (default) for text, 1 for code. """ DataFlow = ["eRender", "eCapture", "eAll", "EDataFlow_enum_count"] - devEnum = AudioUtilities.GetDeviceEnumerator() - dev = devEnum.GetDevice(devId) + deviceEnumerator = AudioUtilities.GetDeviceEnumerator() + dev = deviceEnumerator.GetDevice(devId) value = dev.QueryInterface(IMMEndpoint).GetDataFlow() if outputType: return value diff --git a/tests/test_core.py b/tests/test_core.py index 437aa4a..477bb31 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -53,7 +53,10 @@ def test_device_failed_properties(self): store.GetValue = mock.Mock(side_effect=_ctypes.COMError(None, None, None)) dev.OpenPropertyStore = mock.Mock(return_value=store) with warnings.catch_warnings(record=True) as w: - AudioUtilities.CreateDevice(dev) + device = AudioUtilities.CreateDevice(dev) + # Properties are lazy-loaded and won't throw exceptions + # until accessed + dev = device.properties assert len(w) == 1 assert "COMError attempting to get property 0 from device" in str(w[0].message) From 25fcce2dbe615576885103858e5908176a432b1c Mon Sep 17 00:00:00 2001 From: Tim Tucker Date: Sat, 21 Sep 2024 09:16:21 -0400 Subject: [PATCH 2/3] lint --- pycaw/utils.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/pycaw/utils.py b/pycaw/utils.py index 27227af..eb42b03 100644 --- a/pycaw/utils.py +++ b/pycaw/utils.py @@ -1,7 +1,8 @@ -import warnings -from typing import List, TypeVar, Generic from __future__ import annotations +import warnings +from typing import Generic, List, TypeVar + import comtypes import psutil from _ctypes import COMError @@ -22,7 +23,8 @@ ) # Define a type variable that extends IUnknown -COMInterface = TypeVar('COMInterface', bound='IUnknown') +COMInterface = TypeVar("COMInterface", bound="IUnknown") + class AudioDevice: """ @@ -65,7 +67,7 @@ def state(self) -> AudioDeviceState: state = self._dev.GetState() self._state = AudioDeviceState(state) return self._state - + @staticmethod def getProperty(self, key): store = self._dev.OpenPropertyStore(STGM.STGM_READ.value) @@ -76,10 +78,10 @@ def getProperty(self, key): for j in range(propCount): pk = store.GetAt(j) name = str(pk) - + if name != key: continue - + try: value = store.GetValue(pk) v = value.GetValue() @@ -103,11 +105,11 @@ def properties(self) -> dict: propCount = store.GetCount() for j in range(propCount): pk = store.GetAt(j) - + value = AudioDevice.getProperty(self, pk) if value is None: continue - + name = str(pk) properties[name] = value self._properties = properties @@ -355,11 +357,15 @@ def GetSessions( return sessions @staticmethod - def GetPlaybackSessions(sessionState: AudioDeviceState = None) -> List[AudioSession]: + def GetPlaybackSessions( + sessionState: AudioDeviceState = None, + ) -> List[AudioSession]: return AudioUtilities.GetSessions(EDataFlow.eRender.value, sessionState) @staticmethod - def GetRecordingSessions(sessionState: AudioDeviceState = None) -> List[AudioSession]: + def GetRecordingSessions( + sessionState: AudioDeviceState = None, + ) -> List[AudioSession]: return AudioUtilities.GetSessions(EDataFlow.eCapture.value, sessionState) @staticmethod @@ -401,7 +407,9 @@ def CreateDevices(collection) -> List[AudioDevice]: return devices @staticmethod - def GetDevices(flow=EDataFlow.eAll.value, deviceState=DEVICE_STATE.ACTIVE.value) -> List[AudioDevice]: + def GetDevices( + flow=EDataFlow.eAll.value, deviceState=DEVICE_STATE.ACTIVE.value + ) -> List[AudioDevice]: """ Get devices based on filteres for flow direction and device state. Default to returning active devices. From ff9019207d0114d32f5293f2fe12985a81b6697c Mon Sep 17 00:00:00 2001 From: Tim Tucker Date: Mon, 23 Sep 2024 08:02:16 -0400 Subject: [PATCH 3/3] Remove unused import --- pycaw/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pycaw/utils.py b/pycaw/utils.py index eb42b03..c326eb6 100644 --- a/pycaw/utils.py +++ b/pycaw/utils.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Generic, List, TypeVar +from typing import List, TypeVar import comtypes import psutil