Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
294 changes: 293 additions & 1 deletion music_assistant/providers/party_mode/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@
ConfigValueType,
ProviderConfig,
)
from music_assistant_models.enums import ConfigEntryType, MediaType, PlaybackState, ProviderFeature
from music_assistant_models.enums import (
ConfigEntryType,
EventType,
MediaType,
PlaybackState,
ProviderFeature,
)
from music_assistant_models.errors import InvalidDataError
from music_assistant_models.queue_item import QueueItem

Expand Down Expand Up @@ -117,6 +123,39 @@ class PartyModeConfig(DataClassDictMixin):
qr_instruction_text: str


@dataclass
class PartyModeGuestRequest(DataClassDictMixin):
"""Details of a guest action, included in state events for HA tracking.

Provides rich track metadata so HA automations can build dashboards
(e.g., top requested songs, genre breakdowns, request history).
"""

action: str # "add", "boost", or "skip"
track_name: str | None
track_artist: str | None
track_uri: str | None
track_album: str | None
track_duration: int | None # seconds
track_genres: list[str] | None
track_image_url: str | None


@dataclass
class PartyModeState(DataClassDictMixin):
"""State data for party mode, exposed via events and API for HA integration."""

enabled: bool
guest_access_enabled: bool
player_id: str | None
join_url: str | None
join_code: str | None
total_guest_additions: int
total_guest_boosts: int
total_guest_skips: int
last_request: PartyModeGuestRequest | None = None


async def setup(
mass: MusicAssistant, manifest: ProviderManifest, config: ProviderConfig
) -> ProviderInstanceType:
Expand Down Expand Up @@ -380,6 +419,9 @@ def __init__(
super().__init__(mass, manifest, config, supported_features)
self._unregister_handles: list[Callable[[], None]] = []
self._queue_lock = asyncio.Lock()
self._total_guest_additions: int = 0
self._total_guest_boosts: int = 0
self._total_guest_skips: int = 0

async def loaded_in_mass(self) -> None:
"""Call after the provider has been loaded."""
Expand All @@ -402,6 +444,36 @@ async def loaded_in_mass(self) -> None:
self._unregister_handles.append(
self.mass.register_api_command("party_mode/skip", self.skip_current)
)
# HA integration API commands
self._unregister_handles.append(
self.mass.register_api_command(
"party_mode/state", self.get_party_mode_state, required_role="user"
)
)
self._unregister_handles.append(
self.mass.register_api_command(
"party_mode/enable_guest_access",
self.enable_guest_access,
required_role="user",
)
)
self._unregister_handles.append(
self.mass.register_api_command(
"party_mode/disable_guest_access",
self.disable_guest_access,
required_role="user",
)
)
self._unregister_handles.append(
self.mass.register_api_command(
"party_mode/regenerate_code",
self.regenerate_join_code,
required_role="user",
)
)

# Signal initial state
self._signal_state_event()

async def unload(self, is_removed: bool = False) -> None:
"""Call when the provider is being unloaded.
Expand All @@ -410,6 +482,23 @@ async def unload(self, is_removed: bool = False) -> None:
"""
self.logger.debug("Party mode unload called, is_removed=%s", is_removed)

# Signal disabled state before unloading
if is_removed:
self.mass.signal_event(
EventType.PARTY_MODE_UPDATED,
object_id=self.instance_id,
data=PartyModeState(
enabled=False,
guest_access_enabled=False,
player_id=None,
join_url=None,
join_code=None,
total_guest_additions=self._total_guest_additions,
total_guest_boosts=self._total_guest_boosts,
total_guest_skips=self._total_guest_skips,
).to_dict(),
)

# Unregister all API commands
for unregister in self._unregister_handles:
unregister()
Expand Down Expand Up @@ -622,13 +711,27 @@ async def add_to_queue(
await self.mass.player_queues.play_index(queue_id, insert_index)
started_playback = True

self._total_guest_additions += 1
if boost:
self._total_guest_boosts += 1

self.logger.info(
"Guest added to queue: %s (boost=%s, started_playback=%s)",
uri,
boost,
started_playback,
)

# Build request info for HA tracking
media_item = await self.mass.music.get_item_by_uri(uri)
self._signal_state_event(
last_request=self._build_guest_request(
action="boost" if boost else "add",
uri=uri,
media_item=media_item,
)
)

return {
"success": True,
"queue_id": queue_id,
Expand Down Expand Up @@ -753,18 +856,207 @@ async def skip_current(self) -> dict[str, Any]:
if queue.state != PlaybackState.PLAYING:
raise InvalidDataError("Nothing is currently playing")

# Capture current track info before skipping
current_media_item = (
queue.current_item.media_item
if queue.current_item and queue.current_item.media_item
else None
)
current_uri = current_media_item.uri if current_media_item else None

# Skip to next track
await self.mass.player_queues.next(queue_id)

self._total_guest_skips += 1
self.logger.info("Guest skipped current track on queue: %s", queue_id)
self._signal_state_event(
last_request=self._build_guest_request(
action="skip",
uri=current_uri,
media_item=current_media_item,
)
)

return {
"success": True,
"queue_id": queue_id,
}

# ==================== HA Integration API Commands ====================

async def get_party_mode_state(self) -> dict[str, Any]:
"""Get the current party mode state for HA integration.

:returns: Dict with party mode state including enabled status, counters, and join info.
"""
state = await self._build_state()
return state.to_dict()

async def enable_guest_access(self) -> dict[str, Any]:
"""Enable guest access for party mode.

Updates the config and triggers a provider reload.

:returns: Result dict with success status.
"""
if self.config.get_value(CONF_ENABLE_GUEST_ACCESS):
return {"success": True, "message": "Guest access is already enabled"}

await self.mass.config.save_provider_config(
provider_domain=self.domain,
values={CONF_ENABLE_GUEST_ACCESS: True},
instance_id=self.instance_id,
)
return {"success": True}

async def disable_guest_access(self) -> dict[str, Any]:
"""Disable guest access for party mode.

Updates the config, revokes tokens, and triggers a provider reload.

:returns: Result dict with success status.
"""
if not self.config.get_value(CONF_ENABLE_GUEST_ACCESS):
return {"success": True, "message": "Guest access is already disabled"}

await self.mass.config.save_provider_config(
provider_domain=self.domain,
values={CONF_ENABLE_GUEST_ACCESS: False},
instance_id=self.instance_id,
)
return {"success": True}

async def regenerate_join_code(self) -> dict[str, Any]:
"""Revoke the current join code and generate a new one.

:returns: Result dict with the new join code and URL.
"""
if not self.config.get_value(CONF_ENABLE_GUEST_ACCESS):
raise InvalidDataError("Guest access is disabled")

auth = self.mass.webserver.auth
guest_user = await self._get_or_create_party_guest_user()

# Revoke existing codes
await auth.revoke_join_codes(guest_user)

# Generate a new code
code, _expires_at = await auth.generate_join_code(
user=guest_user,
expires_in_hours=8,
max_uses=0,
device_name="Party Mode Guest",
)

url = await self.get_party_mode_url()
self.logger.info("Regenerated join code for party mode")
self._signal_state_event()

return {
"success": True,
"join_code": code,
"join_url": url,
}

# ==================== Helper Methods ====================

@staticmethod
def _build_guest_request(
action: str,
uri: str | None,
media_item: Any = None,
) -> PartyModeGuestRequest:
"""Build a PartyModeGuestRequest from a media item.

:param action: The action type ("add", "boost", or "skip").
:param uri: The URI of the media item.
:param media_item: The resolved media item (Track, Radio, etc.), or None.
:returns: PartyModeGuestRequest with track metadata.
"""
track_name: str | None = None
track_artist: str | None = None
track_album: str | None = None
track_duration: int | None = None
track_genres: list[str] | None = None
track_image_url: str | None = None

if media_item:
track_name = media_item.name
if hasattr(media_item, "artist_str"):
track_artist = media_item.artist_str
if hasattr(media_item, "album") and media_item.album:
track_album = media_item.album.name
if hasattr(media_item, "duration") and media_item.duration:
track_duration = media_item.duration
if media_item.metadata and media_item.metadata.genres:
track_genres = sorted(media_item.metadata.genres)
if media_item.image:
track_image_url = media_item.image.path

return PartyModeGuestRequest(
action=action,
track_name=track_name,
track_artist=track_artist,
track_uri=uri,
track_album=track_album,
track_duration=track_duration,
track_genres=track_genres,
track_image_url=track_image_url,
)

async def _build_state(self) -> PartyModeState:
"""Build the current party mode state.

:returns: PartyModeState with current status.
"""
guest_access_enabled = bool(self.config.get_value(CONF_ENABLE_GUEST_ACCESS))
player_id = str(self.config.get_value(CONF_PARTY_MODE_PLAYER) or "")

join_code: str | None = None
join_url: str | None = None
if guest_access_enabled:
auth = self.mass.webserver.auth
guest_user = await auth.get_user_by_username(PARTY_MODE_GUEST_USER)
if guest_user:
join_code = await auth.get_active_join_code(guest_user)
join_url = await self.get_party_mode_url()

return PartyModeState(
enabled=True,
guest_access_enabled=guest_access_enabled,
player_id=player_id or None,
join_url=join_url,
join_code=join_code,
total_guest_additions=self._total_guest_additions,
total_guest_boosts=self._total_guest_boosts,
total_guest_skips=self._total_guest_skips,
)

def _signal_state_event(self, last_request: PartyModeGuestRequest | None = None) -> None:
"""Signal a party mode state update event.

Fires an async task to build state and emit the event, since
building the state requires async calls (join code lookup).

:param last_request: Optional details of the guest action that triggered this event.
"""
self.mass.create_task(self._async_signal_state_event(last_request))

async def _async_signal_state_event(
self, last_request: PartyModeGuestRequest | None = None
) -> None:
"""Build state and fire the party mode updated event.

:param last_request: Optional details of the guest action that triggered this event.
"""
state = await self._build_state()
state.last_request = last_request
self.mass.signal_event(
EventType.PARTY_MODE_UPDATED,
object_id=self.instance_id,
data=state.to_dict(),
)

async def _revoke_guest_tokens(self) -> None:
"""Revoke all guest access tokens and codes for party mode.

Expand Down
Loading