diff --git a/keep/providers/snmp_provider/__init__.py b/keep/providers/snmp_provider/__init__.py new file mode 100644 index 0000000000..b8fc57c827 --- /dev/null +++ b/keep/providers/snmp_provider/__init__.py @@ -0,0 +1 @@ +from keep.providers.snmp_provider.snmp_provider import SnmpProvider diff --git a/keep/providers/snmp_provider/snmp_provider.py b/keep/providers/snmp_provider/snmp_provider.py new file mode 100644 index 0000000000..595e2b9117 --- /dev/null +++ b/keep/providers/snmp_provider/snmp_provider.py @@ -0,0 +1,184 @@ +""" +SNMP Provider for Keep. +Supports receiving SNMP Traps (v1, v2c) and converting them into Keep Alerts. +""" + +import dataclasses +import datetime +import logging +import threading +import uuid +import pydantic +from pysnmp.entity import engine, config +from pysnmp.entity.rfc3413 import ntfrcv +from pysnmp.carrier.asyncio.dgram import udp + +from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus +from keep.contextmanager.contextmanager import ContextManager +from keep.providers.base.base_provider import BaseProvider +from keep.providers.models.provider_config import ProviderConfig, ProviderScope + + +@pydantic.dataclasses.dataclass +class SnmpProviderAuthConfig: + """ + SNMP authentication configuration. + """ + bind_address: str = dataclasses.field( + default="0.0.0.0", + metadata={ + "required": True, + "description": "The address to bind the SNMP Trap listener to", + "hint": "0.0.0.0 to listen on all interfaces", + }, + ) + port: int = dataclasses.field( + default=162, + metadata={ + "required": True, + "description": "The UDP port to listen for Traps", + "hint": "Default is 162. Note: ports < 1024 require root privileges.", + }, + ) + community: str = dataclasses.field( + default="public", + metadata={ + "required": True, + "description": "SNMP Community string", + "hint": "e.g. public", + }, + ) + + +class SnmpProvider(BaseProvider): + """ + SNMP provider class. + """ + + PROVIDER_DISPLAY_NAME = "SNMP" + PROVIDER_CATEGORY = ["Monitoring"] + PROVIDER_TAGS = ["alert", "network"] + + def __init__( + self, context_manager: ContextManager, provider_id: str, config: ProviderConfig + ): + super().__init__(context_manager, provider_id, config) + self.snmp_engine = None + self.consume = False + + def validate_config(self): + """ + Validates required configuration for SNMP provider. + """ + self.authentication_config = SnmpProviderAuthConfig( + **self.config.authentication + ) + + def dispose(self): + """ + Dispose the provider. + """ + self.stop_consume() + + def _trap_callback(self, snmp_engine, state_reference, context_engine_id, context_name, var_binds, cb_ctx): + """ + Callback function executed when a Trap is received. + """ + self.logger.info("SNMP Trap received from %s", cb_ctx) + + # Extract OIDs and values + trap_data = {} + for name, val in var_binds: + oid = str(name) + value = str(val.prettyPrint()) + trap_data[oid] = value + self.logger.debug("OID: %s, Value: %s", oid, value) + + # Create a Keep Alert + try: + alert = self._format_trap_to_alert(trap_data, cb_ctx) + self._push_alert(alert) + except Exception: + self.logger.exception("Failed to push SNMP trap alert to Keep") + + def _format_trap_to_alert(self, trap_data: dict, source_info: any) -> dict: + """ + Converts raw trap data into a Keep-compatible Alert dictionary. + """ + # Attempt to find a meaningful name from the Trap (e.g., sysName or specific OID) + # 1.3.6.1.6.3.1.1.4.1.0 is snmpTrapOID.0 + trap_oid = trap_data.get("1.3.6.1.6.3.1.1.4.1.0", "unknown-trap") + + return { + "id": str(uuid.uuid4()), + "name": f"SNMP Trap: {trap_oid}", + "status": AlertStatus.FIRING, + "lastReceived": datetime.datetime.now(datetime.timezone.utc).isoformat(), + "environment": self.config.details.get("environment", "production"), + "service": "network-device", + "source": ["snmp"], + "message": f"Received SNMP Trap {trap_oid} from {source_info}", + "description": "SNMP Trap captured by Keep SNMP Provider", + "severity": AlertSeverity.CRITICAL, # Traps are usually critical by default + "fingerprint": hashlib.sha256(f"{trap_oid}-{source_info}".encode()).hexdigest(), + "payload": trap_data, # Include all OIDs in the payload + } + + def start_consume(self): + """ + Starts the SNMP Trap listener. + """ + self.logger.info( + "Starting SNMP Trap listener on %s:%s", + self.authentication_config.bind_address, + self.authentication_config.port, + ) + self.consume = True + self.snmp_engine = engine.SnmpEngine() + + # Configure Community-based security (SNMP v1/v2c) + config.addV1System(self.snmp_engine, "keep-area", self.authentication_config.community) + + # Configure Transport Endpoint + try: + config.addTransport( + self.snmp_engine, + udp.domainName, + udp.UdpTransport().openServerMode( + (self.authentication_config.bind_address, self.authentication_config.port) + ), + ) + except Exception as e: + self.logger.error("Failed to bind SNMP port %s: %s", self.authentication_config.port, e) + self.consume = False + return + + # Register Callback + ntfrcv.NotificationReceiver(self.snmp_engine, self._trap_callback) + + self.snmp_engine.transportDispatcher.jobStarted(1) + + try: + while self.consume: + # Run the dispatcher loop + self.snmp_engine.transportDispatcher.runDispatcher() + except Exception: + self.logger.exception("Error in SNMP Trap dispatcher loop") + finally: + self.snmp_engine.transportDispatcher.closeDispatcher() + self.logger.info("SNMP Trap listener stopped") + + def stop_consume(self): + """ + Stops the SNMP Trap listener. + """ + self.consume = False + if self.snmp_engine: + self.snmp_engine.transportDispatcher.jobFinished(1) + + def status(self): + if self.consume: + return {"status": "running", "error": ""} + return {"status": "stopped", "error": ""} + +import hashlib diff --git a/pyproject.toml b/pyproject.toml index e4e70695be..830ce83090 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,6 +97,7 @@ google-generativeai = "^0.8.4" retry2 = "^0.9.5" requests-aws4auth = "^1.3.1" awscli = "^1.40.8" +pysnmp-lextudio = "^6.3.0" [tool.poetry.group.dev.dependencies] pre-commit = "^3.0.4" pre-commit-hooks = "^4.4.0" diff --git a/simulate_snmp_test.py b/simulate_snmp_test.py new file mode 100644 index 0000000000..94ccb6786d --- /dev/null +++ b/simulate_snmp_test.py @@ -0,0 +1,66 @@ +import time +import threading +import logging +from keep.providers.snmp_provider.snmp_provider import SnmpProvider +from keep.providers.models.provider_config import ProviderConfig +from keep.contextmanager.contextmanager import ContextManager +from pysnmp.hlapi import * + +# Set up logging to see what's happening +logging.basicConfig(level=logging.DEBUG) + +def start_provider(): + context_manager = ContextManager(tenant_id="test-tenant") + config = ProviderConfig( + authentication={ + "bind_address": "127.0.0.1", + "port": 1162, + "community": "public", + } + ) + provider = SnmpProvider( + context_manager=context_manager, + provider_id="snmp-test", + config=config + ) + + # Mock _push_alert to see if it's called + provider._push_alert = lambda alert: print(f"\n[SUCCESS] Alert Pushed to Keep: {alert['name']}\nPayload: {alert['payload']}\n") + + provider.start_consume() + +def send_trap(): + print("Sending SNMP Trap to localhost:1162...") + errorIndication, errorStatus, errorIndex, varBinds = next( + sendNotification( + SnmpEngine(), + CommunityData('public', mpModel=0), + UdpTransportTarget(('127.0.0.1', 1162)), + ContextData(), + 'trap', + NotificationType( + ObjectIdentity('1.3.6.1.6.3.1.1.5.2') # Cold Start Trap + ).addVarBinds( + ('1.3.6.1.2.1.1.5.0', OctetString('Test-Device')) + ) + ) + ) + if errorIndication: + print(f"Error sending trap: {errorIndication}") + else: + print("Trap sent successfully!") + +if __name__ == "__main__": + # Start provider in a background thread + t = threading.Thread(target=start_provider, daemon=True) + t.start() + + # Wait for provider to start + time.sleep(2) + + # Send a trap + send_trap() + + # Wait to see the result + time.sleep(5) + print("Simulation finished.") diff --git a/tests/test_snmp_provider.py b/tests/test_snmp_provider.py new file mode 100644 index 0000000000..f7f0279bdc --- /dev/null +++ b/tests/test_snmp_provider.py @@ -0,0 +1,32 @@ +import unittest +from unittest.mock import MagicMock +from keep.providers.snmp_provider.snmp_provider import SnmpProvider +from keep.providers.models.provider_config import ProviderConfig +from keep.contextmanager.contextmanager import ContextManager + +class TestSnmpProvider(unittest.TestCase): + def setUp(self): + self.context_manager = ContextManager(tenant_id="test-tenant") + self.config = ProviderConfig( + authentication={ + "bind_address": "127.0.0.1", + "port": 1162, + "community": "public", + } + ) + self.provider = SnmpProvider( + context_manager=self.context_manager, + provider_id="snmp-test", + config=self.config + ) + + def test_initialization(self): + self.assertEqual(self.provider.provider_id, "snmp-test") + self.assertEqual(self.provider.authentication_config.port, 1162) + + def test_status(self): + status = self.provider.status() + self.assertEqual(status["status"], "stopped") + +if __name__ == "__main__": + unittest.main()