diff --git a/keep-ui/public/snmp.svg b/keep-ui/public/snmp.svg new file mode 100644 index 0000000000..87e0763de4 --- /dev/null +++ b/keep-ui/public/snmp.svg @@ -0,0 +1 @@ + diff --git a/keep/providers/snmp_provider/__init__.py b/keep/providers/snmp_provider/__init__.py new file mode 100644 index 0000000000..061ab4b1f0 --- /dev/null +++ b/keep/providers/snmp_provider/__init__.py @@ -0,0 +1 @@ +from keep.providers.snmp_provider.snmp_provider import SnmpProvider as SnmpProvider diff --git a/keep/providers/snmp_provider/snmp_provider.py b/keep/providers/snmp_provider/snmp_provider.py new file mode 100644 index 0000000000..29ee0cb3cf --- /dev/null +++ b/keep/providers/snmp_provider/snmp_provider.py @@ -0,0 +1,224 @@ +import asyncio +from typing import Any, Optional +import pydantic +from dataclasses import field +from pysnmp.hlapi import * # type: ignore +from pysnmp.entity import engine, config # type: ignore +from pysnmp.entity.rfc3413 import ntfrcv # type: ignore +from pysnmp.carrier.asyncio.dgram import udp # type: ignore + +from keep.contextmanager.contextmanager import ContextManager +from keep.providers.base.base_provider import BaseProvider +from keep.providers.models.provider_config import ProviderConfig +from keep.api.models.alert import AlertSeverity, AlertStatus + + +@pydantic.dataclasses.dataclass +class SnmpProviderAuthConfig: + """ + SNMP authentication configuration. + """ + + tags: list[str] = field(default_factory=list) + port: int = field( + default=162, + metadata={ + "required": False, + "description": "Port to listen for SNMP traps", + "hint": "Default is 162", + }, + ) + community: str = field( + default="public", + metadata={ + "required": False, + "description": "SNMP v2c Community String", + "hint": "Default is public", + }, + ) + # SNMP v3 auth + v3_user: Optional[str] = field( + default=None, + metadata={ + "required": False, + "description": "SNMP v3 Security Name", + "hint": "Username for SNMP v3", + }, + ) + v3_auth_key: Optional[str] = field( + default=None, + metadata={ + "required": False, + "description": "SNMP v3 Auth Key", + "sensitive": True, + }, + ) + v3_priv_key: Optional[str] = field( + default=None, + metadata={ + "required": False, + "description": "SNMP v3 Priv Key", + "sensitive": True, + }, + ) + v3_auth_proto: str = field( + default="sha", + metadata={ + "required": False, + "description": "SNMP v3 Auth Protocol", + "hint": "sha, md5, etc.", + }, + ) + v3_priv_proto: str = field( + default="aes", + metadata={ + "required": False, + "description": "SNMP v3 Priv Protocol", + "hint": "aes, des, etc.", + }, + ) + + +class SnmpProvider(BaseProvider): + """ + SNMP provider class for receiving traps. + """ + + PROVIDER_DISPLAY_NAME = "SNMP" + PROVIDER_CATEGORY = ["Monitoring"] + PROVIDER_TAGS = ["alert", "topology"] + + def __init__( + self, context_manager: ContextManager, provider_id: str, config: ProviderConfig + ): + super().__init__(context_manager, provider_id, config) + self.snmp_engine = engine.SnmpEngine() + self.consume = False + + def validate_config(self) -> None: + self.authentication_config = SnmpProviderAuthConfig( + **self.config.authentication # type: ignore + ) + + def dispose(self) -> None: + self.stop_consume() + + def _trap_callback( + self, + snmpEngine: Any, + stateReference: Any, + contextEngineId: Any, + contextName: Any, + varBinds: Any, + cbCtx: Any, + ) -> None: + """ + Callback executed when a trap is received. + """ + self.logger.info("SNMP Trap received") + + # Extract basic info + alert_data: dict[str, Any] = { + "source": ["snmp"], + "severity": AlertSeverity.INFO, + "status": AlertStatus.FIRING, + "description": "", + "varbinds": {}, + } + + # Trap OID is usually one of the varbinds + trap_oid = "unknown" + for name, val in varBinds: + name_str = str(name) + val_str = str(val) + alert_data["varbinds"][name_str] = val_str + + # snmpTrapOID.0 = 1.3.6.1.6.3.1.1.4.1.0 + if "1.3.6.1.6.3.1.1.4.1.0" in name_str: + trap_oid = val_str + + alert_data["description"] += f"{name_str}: {val_str}\n" + + alert_data["name"] = f"SNMP Trap: {trap_oid}" + + try: + self._push_alert(alert_data) + except Exception: + self.logger.exception("Failed to push SNMP alert to Keep") + + def start_consume(self) -> None: + """ + Start listening for SNMP traps. + """ + self.consume = True + self.logger.info( + f"Starting SNMP Trap listener on port {self.authentication_config.port}" + ) + + # 1. Setup SNMP v2c Community + config.addV1System( + self.snmp_engine, "keep-area", self.authentication_config.community + ) + + # 2. Setup SNMP v3 User (if configured) + if self.authentication_config.v3_user: + auth_proto = ( + config.usmHMACSHAAuthProtocol + if self.authentication_config.v3_auth_proto.lower() == "sha" + else config.usmHMACMD5AuthProtocol + ) + priv_proto = ( + config.usmAesCfb128Protocol + if self.authentication_config.v3_priv_proto.lower() == "aes" + else config.usmDesPrivProtocol + ) + + config.addV3User( + self.snmp_engine, + self.authentication_config.v3_user, + auth_proto, + self.authentication_config.v3_auth_key, + priv_proto, + self.authentication_config.v3_priv_key, + ) + + # 3. Transport Setup + config.addTransport( + self.snmp_engine, + udp.domainName, + udp.UdpAsyncioTransport().openServerMode( + ("0.0.0.0", self.authentication_config.port) + ), + ) + + # 4. Register Notification Receiver + ntfrcv.NotificationReceiver(self.snmp_engine, self._trap_callback) + + # 5. Run the background loop + self.logger.info("SNMP listener active") + try: + # Thread-safe event loop management for Python 3.12+ + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + async def listen() -> None: + while self.consume: + await asyncio.sleep(0.1) # Faster polling for stop signal + + if loop.is_running(): + # If we're in a thread with a running loop (like some test runners) + asyncio.run_coroutine_threadsafe(listen(), loop) + else: + loop.run_until_complete(listen()) + except Exception: + self.logger.exception("Error in SNMP listener loop") + finally: + self.logger.info("SNMP listener stopped") + + def stop_consume(self) -> None: + self.consume = False + # Unsubscribe/close transport if needed + # self.snmp_engine.transportDispatcher.closeDispatcher() - usually handles it diff --git a/poetry.lock b/poetry.lock index 00198e7ff3..c89d8c8e5d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.3.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.3.4 and should not be changed by hand. [[package]] name = "aiofiles" @@ -4552,6 +4552,24 @@ files = [ {file = "pyproject_hooks-1.2.0.tar.gz", hash = "sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8"}, ] +[[package]] +name = "pysnmp" +version = "7.1.22" +description = "A Python library for SNMP" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "pysnmp-7.1.22-py3-none-any.whl", hash = "sha256:57e704a6ba2bbf571d16cd5dc08b89bb3fa0ebeb5f4f26b87fececad3b3de7a6"}, + {file = "pysnmp-7.1.22.tar.gz", hash = "sha256:37ac595c7f0c1c00514505939b4dcf5b4fd5a9ffe51b0349f60bb640c11b0f77"}, +] + +[package.dependencies] +pyasn1 = ">=0.4.8,<0.5.0 || >0.5.0" + +[package.extras] +dev = ["Sphinx (>=7.0.0,<8.0.0)", "bump2version (>=1.0.1)", "codecov (>=2.1.12)", "cryptography (>=44.0.1)", "doc8 (>=1.0.0)", "flake8 (>=7.0.0)", "furo (>=2023.1.1)", "jinja2 (>=3.1.6)", "pep8-naming (>=0.14.1)", "pre-commit (==2.21.0)", "pysmi (>=1.6.1)", "pytest (>=7.2.0)", "pytest-asyncio (>=0.21.1)", "pytest-codecov (>=0.4.0)", "pytest-cov (>=4.1.0)", "ruff (>=0.11.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-notfound-page (>=1.0.0)", "sphinx-polyversion (>=1.0.0)", "sphinx-sitemap-lextudio (>=2.5.2)"] + [[package]] name = "pytest" version = "8.3.4" @@ -6156,4 +6174,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.11,<3.14" -content-hash = "824ecd671f3022b7d2e54af04d744f5aaa3ede100d810e5b9fdcb4980c4f7592" +content-hash = "d5e962f69a2978dd54a762cca4059b2d27b1bb4b53d6c1621ff53a2a9c0a3d24" diff --git a/pyproject.toml b/pyproject.toml index e4e70695be..1aa8f60ad5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ psycopg = "^3.2.3" prometheus-client = "^0.21.1" psycopg2-binary = "^2.9.10" urllib3 = "<2.7.0" +pysnmp = "^7.1.0" prometheus-fastapi-instrumentator = "^7.0.0" slowapi = "^0.1.9" diff --git a/tests/test_snmp_provider.py b/tests/test_snmp_provider.py new file mode 100644 index 0000000000..f5c890e66a --- /dev/null +++ b/tests/test_snmp_provider.py @@ -0,0 +1,94 @@ +import pytest +from unittest.mock import MagicMock, patch +from keep.contextmanager.contextmanager import ContextManager +from keep.providers.snmp_provider.snmp_provider import SnmpProvider +from keep.providers.models.provider_config import ProviderConfig +from keep.api.models.alert import AlertSeverity, AlertStatus + +class TestSnmpProvider: + @pytest.fixture + def context_manager(self): + return ContextManager(tenant_id="test-tenant") + + @pytest.fixture + def snmp_config(self): + return ProviderConfig( + authentication={ + "port": 1162, + "community": "public", + "v3_user": "test-user", + "v3_auth_key": "auth-key", + "v3_priv_key": "priv-key" + } + ) + + @pytest.fixture + def snmp_provider(self, context_manager, snmp_config): + return SnmpProvider(context_manager, "test-snmp", snmp_config) + + def test_validate_config(self, snmp_provider): + snmp_provider.validate_config() + assert snmp_provider.authentication_config.port == 1162 + assert snmp_provider.authentication_config.community == "public" + assert snmp_provider.authentication_config.v3_user == "test-user" + + @patch("keep.providers.snmp_provider.snmp_provider.SnmpProvider._push_alert") + def test_trap_callback(self, mock_push_alert, snmp_provider): + # Prepare mock varbinds + # 1.3.6.1.6.3.1.1.4.1.0 is the OID for snmpTrapOID.0 + var_binds = [ + ("1.3.6.1.2.1.1.3.0", "12345"), # sysUpTime + ("1.3.6.1.6.3.1.1.4.1.0", "1.3.6.1.4.1.2021.251.1"), # trap OID + ("1.3.6.1.4.1.2021.251.1.1", "Sample alert message") # custom varbind + ] + + # Simulate callback + snmp_provider._trap_callback( + snmpEngine=None, + stateReference=None, + contextEngineId=None, + contextName=None, + varBinds=var_binds, + cbCtx=None + ) + + # Verify push_alert was called + mock_push_alert.assert_called_once() + args, _ = mock_push_alert.call_args + alert_data = args[0] + + assert alert_data["source"] == ["snmp"] + assert "SNMP Trap: 1.3.6.1.4.1.2021.251.1" in alert_data["name"] + assert "1.3.6.1.4.1.2021.251.1.1: Sample alert message" in alert_data["description"] + assert alert_data["severity"] == AlertSeverity.INFO + assert alert_data["status"] == AlertStatus.FIRING + + @patch("pysnmp.entity.config.addV1System") + @patch("pysnmp.entity.config.addV3User") + @patch("pysnmp.entity.config.addTransport") + @patch("pysnmp.entity.rfc3413.ntfrcv.NotificationReceiver") + @patch("pysnmp.carrier.asyncio.dgram.udp.UdpAsyncioTransport") + def test_start_consume_logic(self, mock_udp, mock_ntfrcv, mock_transport, mock_v3, mock_v1, snmp_provider): + snmp_provider.validate_config() + + # Mock asyncio loop and related functions + mock_loop = MagicMock() + mock_loop.is_running.return_value = False + + with patch("asyncio.get_event_loop", return_value=mock_loop): + # We want to test the setup logic without actually running the infinite loop + # setting consume = False here will make start_consume's loop exit immediately + # if we don't overwrite it inside the method, but start_consume does overwrite it. + # So we mock run_until_complete to ensure we don't hang. + mock_loop.run_until_complete = MagicMock() + + snmp_provider.start_consume() + + # Verify setup calls + mock_v1.assert_called_with(snmp_provider.snmp_engine, "keep-area", "public") + mock_v3.assert_called() + mock_transport.assert_called() + mock_ntfrcv.assert_called_with(snmp_provider.snmp_engine, snmp_provider._trap_callback) + + # Verify the loop was attempted + mock_loop.run_until_complete.assert_called() diff --git a/uv.lock b/uv.lock new file mode 100644 index 0000000000..9431a635b1 --- /dev/null +++ b/uv.lock @@ -0,0 +1,3 @@ +version = 1 +revision = 3 +requires-python = ">=3.11"