Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keep/providers/snmp_provider/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from keep.providers.snmp_provider.snmp_provider import SnmpProvider
184 changes: 184 additions & 0 deletions keep/providers/snmp_provider/snmp_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""
SNMP Provider for Keep.
Supports receiving SNMP Traps (v1, v2c) and converting them into Keep Alerts.
"""

import dataclasses
import datetime
import logging
import threading
import uuid
import pydantic
from pysnmp.entity import engine, config
from pysnmp.entity.rfc3413 import ntfrcv
from pysnmp.carrier.asyncio.dgram import udp

from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus
from keep.contextmanager.contextmanager import ContextManager
from keep.providers.base.base_provider import BaseProvider
from keep.providers.models.provider_config import ProviderConfig, ProviderScope


@pydantic.dataclasses.dataclass
class SnmpProviderAuthConfig:
"""
SNMP authentication configuration.
"""
bind_address: str = dataclasses.field(
default="0.0.0.0",
metadata={
"required": True,
"description": "The address to bind the SNMP Trap listener to",
"hint": "0.0.0.0 to listen on all interfaces",
},
)
port: int = dataclasses.field(
default=162,
metadata={
"required": True,
"description": "The UDP port to listen for Traps",
"hint": "Default is 162. Note: ports < 1024 require root privileges.",
},
)
community: str = dataclasses.field(
default="public",
metadata={
"required": True,
"description": "SNMP Community string",
"hint": "e.g. public",
},
)


class SnmpProvider(BaseProvider):
"""
SNMP provider class.
"""

PROVIDER_DISPLAY_NAME = "SNMP"
PROVIDER_CATEGORY = ["Monitoring"]
PROVIDER_TAGS = ["alert", "network"]

def __init__(
self, context_manager: ContextManager, provider_id: str, config: ProviderConfig
):
super().__init__(context_manager, provider_id, config)
self.snmp_engine = None
self.consume = False

def validate_config(self):
"""
Validates required configuration for SNMP provider.
"""
self.authentication_config = SnmpProviderAuthConfig(
**self.config.authentication
)

def dispose(self):
"""
Dispose the provider.
"""
self.stop_consume()

def _trap_callback(self, snmp_engine, state_reference, context_engine_id, context_name, var_binds, cb_ctx):
"""
Callback function executed when a Trap is received.
"""
self.logger.info("SNMP Trap received from %s", cb_ctx)

# Extract OIDs and values
trap_data = {}
for name, val in var_binds:
oid = str(name)
value = str(val.prettyPrint())
trap_data[oid] = value
self.logger.debug("OID: %s, Value: %s", oid, value)

# Create a Keep Alert
try:
alert = self._format_trap_to_alert(trap_data, cb_ctx)
self._push_alert(alert)
except Exception:
self.logger.exception("Failed to push SNMP trap alert to Keep")

def _format_trap_to_alert(self, trap_data: dict, source_info: any) -> dict:
"""
Converts raw trap data into a Keep-compatible Alert dictionary.
"""
# Attempt to find a meaningful name from the Trap (e.g., sysName or specific OID)
# 1.3.6.1.6.3.1.1.4.1.0 is snmpTrapOID.0
trap_oid = trap_data.get("1.3.6.1.6.3.1.1.4.1.0", "unknown-trap")

return {
"id": str(uuid.uuid4()),
"name": f"SNMP Trap: {trap_oid}",
"status": AlertStatus.FIRING,
"lastReceived": datetime.datetime.now(datetime.timezone.utc).isoformat(),
"environment": self.config.details.get("environment", "production"),
"service": "network-device",
"source": ["snmp"],
"message": f"Received SNMP Trap {trap_oid} from {source_info}",
"description": "SNMP Trap captured by Keep SNMP Provider",
"severity": AlertSeverity.CRITICAL, # Traps are usually critical by default
"fingerprint": hashlib.sha256(f"{trap_oid}-{source_info}".encode()).hexdigest(),
"payload": trap_data, # Include all OIDs in the payload
}

def start_consume(self):
"""
Starts the SNMP Trap listener.
"""
self.logger.info(
"Starting SNMP Trap listener on %s:%s",
self.authentication_config.bind_address,
self.authentication_config.port,
)
self.consume = True
self.snmp_engine = engine.SnmpEngine()

# Configure Community-based security (SNMP v1/v2c)
config.addV1System(self.snmp_engine, "keep-area", self.authentication_config.community)

# Configure Transport Endpoint
try:
config.addTransport(
self.snmp_engine,
udp.domainName,
udp.UdpTransport().openServerMode(
(self.authentication_config.bind_address, self.authentication_config.port)
),
)
except Exception as e:
self.logger.error("Failed to bind SNMP port %s: %s", self.authentication_config.port, e)
self.consume = False
return

# Register Callback
ntfrcv.NotificationReceiver(self.snmp_engine, self._trap_callback)

self.snmp_engine.transportDispatcher.jobStarted(1)

try:
while self.consume:
# Run the dispatcher loop
self.snmp_engine.transportDispatcher.runDispatcher()
except Exception:
self.logger.exception("Error in SNMP Trap dispatcher loop")
finally:
self.snmp_engine.transportDispatcher.closeDispatcher()
self.logger.info("SNMP Trap listener stopped")

def stop_consume(self):
"""
Stops the SNMP Trap listener.
"""
self.consume = False
if self.snmp_engine:
self.snmp_engine.transportDispatcher.jobFinished(1)

def status(self):
if self.consume:
return {"status": "running", "error": ""}
return {"status": "stopped", "error": ""}

import hashlib
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ google-generativeai = "^0.8.4"
retry2 = "^0.9.5"
requests-aws4auth = "^1.3.1"
awscli = "^1.40.8"
pysnmp-lextudio = "^6.3.0"
[tool.poetry.group.dev.dependencies]
pre-commit = "^3.0.4"
pre-commit-hooks = "^4.4.0"
Expand Down
66 changes: 66 additions & 0 deletions simulate_snmp_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import time
import threading
import logging
from keep.providers.snmp_provider.snmp_provider import SnmpProvider
from keep.providers.models.provider_config import ProviderConfig
from keep.contextmanager.contextmanager import ContextManager
from pysnmp.hlapi import *

# Set up logging to see what's happening
logging.basicConfig(level=logging.DEBUG)

def start_provider():
context_manager = ContextManager(tenant_id="test-tenant")
config = ProviderConfig(
authentication={
"bind_address": "127.0.0.1",
"port": 1162,
"community": "public",
}
)
provider = SnmpProvider(
context_manager=context_manager,
provider_id="snmp-test",
config=config
)

# Mock _push_alert to see if it's called
provider._push_alert = lambda alert: print(f"\n[SUCCESS] Alert Pushed to Keep: {alert['name']}\nPayload: {alert['payload']}\n")

provider.start_consume()

def send_trap():
print("Sending SNMP Trap to localhost:1162...")
errorIndication, errorStatus, errorIndex, varBinds = next(
sendNotification(
SnmpEngine(),
CommunityData('public', mpModel=0),
UdpTransportTarget(('127.0.0.1', 1162)),
ContextData(),
'trap',
NotificationType(
ObjectIdentity('1.3.6.1.6.3.1.1.5.2') # Cold Start Trap
).addVarBinds(
('1.3.6.1.2.1.1.5.0', OctetString('Test-Device'))
)
)
)
if errorIndication:
print(f"Error sending trap: {errorIndication}")
else:
print("Trap sent successfully!")

if __name__ == "__main__":
# Start provider in a background thread
t = threading.Thread(target=start_provider, daemon=True)
t.start()

# Wait for provider to start
time.sleep(2)

# Send a trap
send_trap()

# Wait to see the result
time.sleep(5)
print("Simulation finished.")
32 changes: 32 additions & 0 deletions tests/test_snmp_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import unittest
from unittest.mock import MagicMock
from keep.providers.snmp_provider.snmp_provider import SnmpProvider
from keep.providers.models.provider_config import ProviderConfig
from keep.contextmanager.contextmanager import ContextManager

class TestSnmpProvider(unittest.TestCase):
def setUp(self):
self.context_manager = ContextManager(tenant_id="test-tenant")
self.config = ProviderConfig(
authentication={
"bind_address": "127.0.0.1",
"port": 1162,
"community": "public",
}
)
self.provider = SnmpProvider(
context_manager=self.context_manager,
provider_id="snmp-test",
config=self.config
)

def test_initialization(self):
self.assertEqual(self.provider.provider_id, "snmp-test")
self.assertEqual(self.provider.authentication_config.port, 1162)

def test_status(self):
status = self.provider.status()
self.assertEqual(status["status"], "stopped")

if __name__ == "__main__":
unittest.main()
Loading