diff --git a/nettacker/core/lib/socket.py b/nettacker/core/lib/socket.py index ba63e022f..cd1267051 100644 --- a/nettacker/core/lib/socket.py +++ b/nettacker/core/lib/socket.py @@ -26,7 +26,13 @@ def create_tcp_socket(host, port, timeout): return None try: - socket_connection = ssl.wrap_socket(socket_connection) + # Create an SSL context without certificate or hostname verification + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + socket_connection = context.wrap_socket(socket_connection, server_hostname=host) ssl_flag = True except Exception: socket_connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) diff --git a/nettacker/core/lib/ssl.py b/nettacker/core/lib/ssl.py index 507085a1f..31b53ad2e 100644 --- a/nettacker/core/lib/ssl.py +++ b/nettacker/core/lib/ssl.py @@ -117,7 +117,10 @@ def create_tcp_socket(host, port, timeout): return None try: - socket_connection = ssl.wrap_socket(socket_connection) + context = ssl.create_default_context() + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + socket_connection = context.wrap_socket(socket_connection, server_hostname=host) ssl_flag = True except Exception: socket_connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) diff --git a/nettacker/lib/graph/d3_tree_v1/engine.py b/nettacker/lib/graph/d3_tree_v1/engine.py index c5b695978..1d682cf67 100644 --- a/nettacker/lib/graph/d3_tree_v1/engine.py +++ b/nettacker/lib/graph/d3_tree_v1/engine.py @@ -28,15 +28,19 @@ def start(events): normalisedjson = {"name": "Started attack", "children": {}} # get data for normalised_json for event in events: - if event["target"] not in normalisedjson["children"]: - normalisedjson["children"].update({event["target"]: {}}) - normalisedjson["children"][event["target"]].update({event["module_name"]: []}) - - if event["module_name"] not in normalisedjson["children"][event["target"]]: - normalisedjson["children"][event["target"]].update({event["module_name"]: []}) - normalisedjson["children"][event["target"]][event["module_name"]].append( - f"target: {event['target']}, module_name: {event['module_name']}, port: " - f"{event['port']}, event: {event['event']}" + target = event.get("target", "unknown_target") + module_name = event.get("module_name", "unknown_module") + port = event.get("port", "unknown_port") + event_name = event.get("event", "unknown_event") + + if target not in normalisedjson["children"]: + normalisedjson["children"].update({target: {}}) + normalisedjson["children"][target].update({module_name: []}) + + if module_name not in normalisedjson["children"][target]: + normalisedjson["children"][target].update({module_name: []}) + normalisedjson["children"][target][module_name].append( + f"target: {target}, module_name: {module_name}, port: {port}, event: {event_name}" ) # define a d3_structure_json d3_structure = {"name": "Starting attack", "children": []} diff --git a/report.html b/report.html new file mode 100644 index 000000000..c4bb4eaaf --- /dev/null +++ b/report.html @@ -0,0 +1,7 @@ + + /*css*/ +
datetargetmodule_nameportlogsjson_eventnowx + +
1
+ + \ No newline at end of file diff --git a/tests/core/lib/test_base.py b/tests/core/lib/test_base.py new file mode 100644 index 000000000..2fc058efd --- /dev/null +++ b/tests/core/lib/test_base.py @@ -0,0 +1,77 @@ +from unittest.mock import MagicMock, patch + +from nettacker.core.lib.base import BaseEngine + + +def test_filter_large_content_truncates(): + engine = BaseEngine() + content = "abcdefghij klm" + result = engine.filter_large_content(content, filter_rate=10) + assert result != content + assert result.startswith("abcdefghij") + assert "klm" not in result + + +@patch("nettacker.core.lib.base.submit_logs_to_db") +@patch("nettacker.core.lib.base.merge_logs_to_list", return_value=["logA"]) +@patch("nettacker.core.lib.base.remove_sensitive_header_keys") +def test_process_conditions_success(mock_remove, mock_merge, mock_submit): + engine = BaseEngine() + event = { + "headers": {"Authorization": "secret"}, + "response": { + "conditions_results": {"log": "entry"}, + "conditions": {"dummy": {"reverse": False, "regex": ""}}, + "condition_type": "and", + }, + "ports": 80, + } + options = {"retries": 1} + mock_remove.return_value = event + + result = engine.process_conditions( + event, + "module", + "target", + "scan", + options, + {"resp": True}, + 1, + 1, + 1, + 1, + 1, + ) + assert result is True + mock_submit.assert_called_once() + mock_merge.assert_called_once() + mock_remove.assert_called_once() + + +@patch("nettacker.core.lib.base.submit_temp_logs_to_db") +def test_process_conditions_save_temp(mock_submit_temp): + engine = BaseEngine() + event = { + "response": { + "conditions_results": [], + "conditions": {}, + "condition_type": "and", + "save_to_temp_events_only": "temp_evt", + } + } + options = {"retries": 1} + result = engine.process_conditions( + event, + "module", + "target", + "scan", + options, + {}, + 1, + 1, + 1, + 1, + 1, + ) + assert result is True + mock_submit_temp.assert_called_once() diff --git a/tests/core/lib/test_base_extended.py b/tests/core/lib/test_base_extended.py new file mode 100644 index 000000000..3d7d40fa0 --- /dev/null +++ b/tests/core/lib/test_base_extended.py @@ -0,0 +1,282 @@ +""" +Comprehensive tests for TemplateLoader including parse, format, and load methods. +""" + +import pytest +from pathlib import Path +from unittest.mock import MagicMock, patch, mock_open +import yaml + +from nettacker.core.template import TemplateLoader + + +class TestTemplateLoaderInit: + """Test TemplateLoader initialization.""" + + def test_initialization_with_name_only(self): + """Test initialization with just a name.""" + loader = TemplateLoader("http_scan_scan") + assert loader.name == "http_scan_scan" + assert loader.inputs == {} + + def test_initialization_with_name_and_inputs(self): + """Test initialization with name and inputs.""" + inputs = {"port": "80", "timeout": "10"} + loader = TemplateLoader("http_scan_scan", inputs=inputs) + assert loader.name == "http_scan_scan" + assert loader.inputs == inputs + + def test_initialization_with_none_inputs(self): + """Test that None inputs are converted to empty dict.""" + loader = TemplateLoader("http_scan_scan", inputs=None) + assert loader.inputs == {} + + +class TestTemplateLoaderParse: + """Test static parse method.""" + + def test_parse_dict_with_matching_input(self): + """Test parse replaces dict values with matching inputs.""" + content = {"host": "{host}", "port": "80"} + inputs = {"host": "example.com"} + result = TemplateLoader.parse(content, inputs) + assert result["host"] == "example.com" + assert result["port"] == "80" + + def test_parse_dict_with_nonmatching_input(self): + """Test parse ignores non-matching dict keys.""" + content = {"host": "default.com", "port": "80"} + inputs = {"timeout": "10"} + result = TemplateLoader.parse(content, inputs) + assert result["host"] == "default.com" + assert result["port"] == "80" + + def test_parse_nested_dict(self): + """Test parse handles nested dictionaries.""" + content = {"level1": {"host": "{host}", "port": 80}} + inputs = {"host": "example.com"} + result = TemplateLoader.parse(content, inputs) + assert result["level1"]["host"] == "example.com" + + def test_parse_list_elements(self): + """Test parse handles list elements.""" + content = [{"host": "{host}"}, {"port": 80}] + inputs = {"host": "example.com"} + result = TemplateLoader.parse(content, inputs) + assert result[0]["host"] == "example.com" + assert result[1]["port"] == 80 + + def test_parse_nested_list_in_dict(self): + """Test parse handles lists within dicts.""" + content = {"servers": [{"host": "primary.com"}, {"host": "backup.com"}]} + inputs = {} + result = TemplateLoader.parse(content, inputs) + assert result["servers"][0]["host"] == "primary.com" + assert result["servers"][1]["host"] == "backup.com" + + def test_parse_with_truthy_input_value(self): + """Test parse uses input value when present and truthy.""" + content = {"enabled": False} + inputs = {"enabled": True} + result = TemplateLoader.parse(content, inputs) + assert result["enabled"] is True + + def test_parse_with_falsy_input_value(self): + """Test parse skips empty/falsy input values.""" + content = {"enabled": True} + inputs = {"enabled": ""} + result = TemplateLoader.parse(content, inputs) + # Empty string is falsy, so original value is kept + assert result["enabled"] is True + + def test_parse_preserves_types_in_nested_dict(self): + """Test parse handles nested structures with various types.""" + content = { + "timeout": 30, + "nested": { + "delay": 0.5, + "data": b"binary" + } + } + inputs = {} + result = TemplateLoader.parse(content, inputs) + assert result["timeout"] == 30 + assert result["nested"]["delay"] == 0.5 + assert result["nested"]["data"] == b"binary" + + def test_parse_empty_dict(self): + """Test parse with empty dict.""" + content = {} + inputs = {"host": "example.com"} + result = TemplateLoader.parse(content, inputs) + assert result == {} + + def test_parse_empty_list(self): + """Test parse with empty list.""" + content = [] + inputs = {"host": "example.com"} + result = TemplateLoader.parse(content, inputs) + assert result == [] + + +class TestTemplateLoaderOpen: + """Test open method for reading YAML files.""" + + def test_open_valid_template(self, tmp_path): + """Test opening a valid template file.""" + mock_yaml_content = "target: '{target}'\nport: 80\n" + expected_path = tmp_path / "scan" / "http_scan.yaml" + + with patch("nettacker.core.template.Config.path.modules_dir", tmp_path): + with patch("builtins.open", mock_open(read_data=mock_yaml_content)) as mocked_open: + loader = TemplateLoader("http_scan_scan") + result = loader.open() + assert isinstance(result, str) + mocked_open.assert_called_once_with(expected_path) + + def test_open_extracts_module_name_correctly(self, tmp_path): + """Test that open correctly parses module name.""" + mock_yaml_content = "test: data\n" + expected_path = tmp_path / "scan" / "port_scan.yaml" + + with patch("nettacker.core.template.Config.path.modules_dir", tmp_path): + with patch("builtins.open", mock_open(read_data=mock_yaml_content)) as mocked_open: + loader = TemplateLoader("port_scan_scan") + loader.open() + assert loader.name == "port_scan_scan" + mocked_open.assert_called_once_with(expected_path) + + +class TestTemplateLoaderFormat: + """Test format method.""" + + def test_format_with_inputs(self): + """Test format substitutes inputs into YAML string.""" + mock_yaml = "target: '{target}'\nport: {port}\n" + + with patch.object(TemplateLoader, "open", return_value=mock_yaml): + loader = TemplateLoader("http_scan_scan", inputs={"target": "example.com", "port": "80"}) + result = loader.format() + assert "example.com" in result + assert "80" in result + + def test_format_without_inputs(self): + """Test format on YAML without placeholders.""" + mock_yaml = "target: localhost\nport: 80\n" + + with patch.object(TemplateLoader, "open", return_value=mock_yaml): + loader = TemplateLoader("http_scan_scan") + result = loader.format() + assert result == mock_yaml + + def test_format_with_all_inputs_provided(self): + """Test format with all required inputs provided.""" + mock_yaml = "target: '{target}'\nport: '{port}'\n" + + with patch.object(TemplateLoader, "open", return_value=mock_yaml): + loader = TemplateLoader("http_scan_scan", inputs={"target": "example.com", "port": "80"}) + result = loader.format() + assert "example.com" in result + assert "80" in result + + +class TestTemplateLoaderLoad: + """Test load method which combines format and parse.""" + + def test_load_yaml_with_inputs(self): + """Test load properly parses YAML and applies inputs.""" + mock_yaml = "requests:\n - host: '{host}'\n port: 80\n" + formatted_yaml = "requests:\n - host: 'example.com'\n port: 80\n" + + with patch.object(TemplateLoader, "open", return_value=mock_yaml): + loader = TemplateLoader("http_scan_scan", inputs={"host": "example.com"}) + result = loader.load() + + # Result should be a parsed YAML dict + assert isinstance(result, dict) + assert "requests" in result + + def test_load_returns_parsed_dict(self): + """Test load returns parsed YAML as dict.""" + mock_yaml = "key: value\nnumber: 42\n" + + with patch.object(TemplateLoader, "open", return_value=mock_yaml): + loader = TemplateLoader("scan_test") + result = loader.load() + + assert isinstance(result, dict) + assert result.get("key") == "value" + assert result.get("number") == 42 + + def test_load_with_nested_yaml(self): + """Test load with complex nested YAML structure.""" + mock_yaml = """ +steps: + - name: scan + params: + host: '{target}' + port: 80 +""" + + with patch.object(TemplateLoader, "open", return_value=mock_yaml): + loader = TemplateLoader("port_scan_scan", inputs={"target": "example.com"}) + result = loader.load() + + assert isinstance(result, dict) + assert "steps" in result + assert isinstance(result["steps"], list) + + def test_load_with_list_yaml(self): + """Test load when YAML root is a list.""" + mock_yaml = "- host: localhost\n port: 80\n- host: example.com\n port: 443\n" + + with patch.object(TemplateLoader, "open", return_value=mock_yaml): + loader = TemplateLoader("scan_test") + result = loader.load() + + assert isinstance(result, list) + assert len(result) == 2 + + +class TestTemplateLoaderIntegration: + """Integration tests combining multiple methods.""" + + def test_full_workflow(self): + """Test complete template loading workflow.""" + mock_yaml = """ +module: scan +target: '{host}' +ports: + - 80 + - 443 +""" + + with patch.object(TemplateLoader, "open", return_value=mock_yaml): + loader = TemplateLoader("http_scan_scan", inputs={"host": "target.com"}) + + # Test each method in sequence + formatted = loader.format() + assert "target.com" in formatted + + loaded = loader.load() + assert isinstance(loaded, dict) + assert loaded["module"] == "scan" + assert loaded["target"] == "target.com" + assert 80 in loaded["ports"] + + def test_loader_with_multiple_templates(self): + """Test creating multiple loaders with different templates.""" + mock_yaml1 = "type: scan\ntarget: '{host}'\n" + mock_yaml2 = "type: brute\nuser: '{user}'\n" + + with patch.object(TemplateLoader, "open") as mock_open_method: + mock_open_method.side_effect = [mock_yaml1, mock_yaml2] + + loader1 = TemplateLoader("port_scan_scan", {"host": "example.com"}) + loader2 = TemplateLoader("ssh_brute_brute", {"user": "admin"}) + + result1 = loader1.load() + result2 = loader2.load() + + assert result1["target"] == "example.com" + assert result2["user"] == "admin" diff --git a/tests/core/lib/test_ftp_ftps.py b/tests/core/lib/test_ftp_ftps.py new file mode 100644 index 000000000..92cbe3bc2 --- /dev/null +++ b/tests/core/lib/test_ftp_ftps.py @@ -0,0 +1,26 @@ +from unittest.mock import MagicMock, patch + +from nettacker.core.lib.ftp import FtpEngine, FtpLibrary +from nettacker.core.lib.ftps import FtpsEngine, FtpsLibrary + + +def test_ftp_engine_has_library(): + engine = FtpEngine() + assert engine.library == FtpLibrary + + +def test_ftp_library_is_defined(): + lib = FtpLibrary() + assert hasattr(lib, "brute_force") + assert callable(lib.brute_force) + + +def test_ftps_engine_inherits_ftp(): + engine = FtpsEngine() + assert engine.library == FtpsLibrary + + +def test_ftps_library_is_defined(): + lib = FtpsLibrary() + assert hasattr(lib, "brute_force") + assert callable(lib.brute_force) diff --git a/tests/core/lib/test_http_and_logger_extended.py b/tests/core/lib/test_http_and_logger_extended.py new file mode 100644 index 000000000..d65597591 --- /dev/null +++ b/tests/core/lib/test_http_and_logger_extended.py @@ -0,0 +1,255 @@ +""" +Additional tests for HTTP engine and logger to achieve higher coverage. +""" + +import sys +from unittest.mock import MagicMock, patch +import pytest + +from nettacker.core.lib.http import HttpEngine +from nettacker.logger import Logger, TerminalCodes, get_logger + + +class TestHttpEngine: + """Test HTTP protocol engine in detail.""" + + def test_http_engine_exists(self): + """Test HttpEngine class exists.""" + engine = HttpEngine() + assert engine is not None + assert callable(engine.run) + + def test_http_engine_has_library(self): + """Test HttpEngine has library attribute.""" + engine = HttpEngine() + assert hasattr(HttpEngine, "library") + assert hasattr(engine, "run") + assert callable(engine.run) + + +class TestLoggerResetColor: + """Test logger reset color method.""" + + def test_reset_color_calls_log(self): + """Test reset_color method.""" + logger = Logger() + with patch.object(logger, "log") as mock_log: + logger.reset_color() + mock_log.assert_called_once_with(TerminalCodes.RESET.value) + + +class TestLoggerVerboseEvent: + """Additional tests for verbose event logging.""" + + def test_verbose_event_info_with_verbose_mode_only(self): + """Test verbose_event_info with --verbose but not --verbose-event.""" + with patch.object(sys, "argv", ["prog", "--verbose"]): + logger = Logger() + with patch.object(logger, "log") as mock_log: + logger.verbose_event_info("event") + # Should log because --verbose enables it + mock_log.assert_called() + + def test_verbose_event_info_with_event_verbose_only(self): + """Test verbose_event_info with --verbose-event but not --verbose.""" + with patch.object(sys, "argv", ["prog", "--verbose-event"]): + logger = Logger() + with patch.object(logger, "log") as mock_log: + logger.verbose_event_info("event") + # Should log because --verbose-event enables it + mock_log.assert_called() + + def test_verbose_event_info_both_flags_with_api(self): + """Test verbose_event_info is silent when running from API even with flags.""" + with patch.object(sys, "argv", ["prog", "--start-api", "--verbose-event"]): + logger = Logger() + with patch.object(logger, "log") as mock_log: + logger.verbose_event_info("event") + # Should be silent despite flags due to run_from_api + mock_log.assert_not_called() + + +class TestLoggerInfoBranches: + """Test all branches of info logging.""" + + def test_info_message_format(self): + """Test info message contains expected components.""" + with patch.object(sys, "argv", ["prog"]): + logger = Logger() + with patch.object(logger, "log") as mock_log: + logger.info("test info") + output = mock_log.call_args[0][0] + assert TerminalCodes.YELLOW.value in output + assert TerminalCodes.GREEN.value in output + assert TerminalCodes.RESET.value in output + assert "[+]" in output + + +class TestLoggerSuccessEventFormat: + """Test success event logging format.""" + + def test_success_event_message_format(self): + """Test success_event_info message format.""" + with patch.object(sys, "argv", ["prog"]): + logger = Logger() + with patch.object(logger, "log") as mock_log: + logger.success_event_info("success") + output = mock_log.call_args[0][0] + assert TerminalCodes.RED.value in output + assert TerminalCodes.CYAN.value in output + assert "[+++]" in output + + +class TestLoggerWarnFormat: + """Test warn message format.""" + + def test_warn_message_format(self): + """Test warn message contains expected components.""" + with patch.object(sys, "argv", ["prog"]): + logger = Logger() + with patch.object(logger, "log") as mock_log: + logger.warn("warning") + output = mock_log.call_args[0][0] + assert TerminalCodes.BLUE.value in output + assert "warning" in output + + +class TestLoggerVerboseInfoFormat: + """Test verbose info format.""" + + def test_verbose_info_message_format(self): + """Test verbose_info message format.""" + with patch.object(sys, "argv", ["prog", "--verbose"]): + logger = Logger() + with patch.object(logger, "log") as mock_log: + logger.verbose_info("verbose") + output = mock_log.call_args[0][0] + assert TerminalCodes.YELLOW.value in output + assert TerminalCodes.PURPLE.value in output + + +class TestLoggerErrorFormat: + """Test error message format.""" + + def test_error_message_format(self): + """Test error message format.""" + with patch.object(sys, "argv", ["prog"]): + logger = Logger() + with patch.object(logger, "log") as mock_log: + logger.error("error") + output = mock_log.call_args[0][0] + assert TerminalCodes.RED.value in output + assert TerminalCodes.YELLOW.value in output + assert "[X]" in output + + +class TestLoggerMultipleCalls: + """Test logger with multiple sequential calls.""" + + def test_multiple_log_calls(self): + """Test multiple log calls in sequence.""" + with patch.object(sys, "argv", ["prog"]): + logger = Logger() + with patch("builtins.print") as mock_print: + logger.log("first") + logger.log("second") + logger.log("third") + assert mock_print.call_count == 3 + + def test_multiple_info_calls_with_different_messages(self): + """Test multiple info calls.""" + with patch.object(sys, "argv", ["prog"]): + logger = Logger() + with patch.object(logger, "log") as mock_log: + logger.info("message 1") + logger.info("message 2") + logger.info("message 3") + assert mock_log.call_count == 3 + + +class TestLoggerWriteFormat: + """Test write method output.""" + + def test_write_direct_output(self): + """Test write method passes content directly.""" + with patch.object(sys, "argv", ["prog"]): + logger = Logger() + with patch("builtins.print") as mock_print: + logger.write("direct message") + assert mock_print.called + + +class TestGetLoggerInstance: + """Test get_logger function.""" + + def test_get_logger_returns_instance(self): + """Test get_logger returns Logger instance.""" + logger = get_logger() + assert isinstance(logger, Logger) + + def test_get_logger_multiple_calls_different_instances(self): + """Test get_logger creates new instances each time.""" + logger1 = get_logger() + logger2 = get_logger() + logger3 = get_logger() + assert isinstance(logger1, Logger) + assert isinstance(logger2, Logger) + assert isinstance(logger3, Logger) + + +class TestLoggerEdgeCases: + """Test edge cases and special conditions.""" + + def test_empty_message(self): + """Test logging empty message.""" + with patch.object(sys, "argv", ["prog"]): + logger = Logger() + with patch("builtins.print") as mock_print: + logger.log("") + mock_print.assert_called_once() + + def test_very_long_message(self): + """Test logging very long message.""" + with patch.object(sys, "argv", ["prog"]): + logger = Logger() + long_message = "x" * 10000 + with patch.object(logger, "log") as mock_log: + logger.info(long_message) + # Should still work with long messages + mock_log.assert_called() + + def test_message_with_special_characters(self): + """Test logging message with special characters.""" + with patch.object(sys, "argv", ["prog"]): + logger = Logger() + with patch("builtins.print") as mock_print: + logger.log("message with \n newline and \t tab") + mock_print.assert_called() + + def test_message_with_unicode(self): + """Test logging unicode messages.""" + with patch.object(sys, "argv", ["prog"]): + logger = Logger() + with patch("builtins.print") as mock_print: + logger.log("Unicode: 你好世界 مرحبا بالعالم") + mock_print.assert_called() + + +class TestLoggerFlushBehavior: + """Test logger flush behavior.""" + + def test_log_uses_flush_true(self): + """Test that log uses flush=True.""" + with patch("builtins.print") as mock_print: + Logger.log("test") + # Verify flush=True was passed + call_kwargs = mock_print.call_args[1] + assert call_kwargs.get("flush") is True + + def test_log_uses_end_empty_string(self): + """Test that log uses end=''.""" + with patch("builtins.print") as mock_print: + Logger.log("test") + # Verify end='' was passed + call_kwargs = mock_print.call_args[1] + assert call_kwargs.get("end") == "" diff --git a/tests/core/lib/test_pop3.py b/tests/core/lib/test_pop3.py new file mode 100644 index 000000000..418e6e80e --- /dev/null +++ b/tests/core/lib/test_pop3.py @@ -0,0 +1,26 @@ +from unittest.mock import MagicMock, patch + +from nettacker.core.lib.pop3 import Pop3Engine, Pop3Library +from nettacker.core.lib.pop3s import Pop3sEngine, Pop3sLibrary + + +def test_pop3_engine_has_library(): + engine = Pop3Engine() + assert engine.library == Pop3Library + + +def test_pop3_library_is_defined(): + lib = Pop3Library() + assert hasattr(lib, "brute_force") + assert callable(lib.brute_force) + + +def test_pop3s_engine_has_library(): + engine = Pop3sEngine() + assert engine.library == Pop3sLibrary + + +def test_pop3s_library_is_defined(): + lib = Pop3sLibrary() + assert hasattr(lib, "brute_force") + assert callable(lib.brute_force) diff --git a/tests/core/lib/test_protocols_comprehensive.py b/tests/core/lib/test_protocols_comprehensive.py new file mode 100644 index 000000000..b9e1a5e2c --- /dev/null +++ b/tests/core/lib/test_protocols_comprehensive.py @@ -0,0 +1,280 @@ +""" +Targeted tests for low-coverage modules to push coverage beyond 64%. +Focus on protocol libraries (smtp, telnet, etc.) and utility modules. +""" + +import sys +from unittest.mock import MagicMock, patch, mock_open +import pytest +import smtplib +import telnetlib + +from nettacker.core.lib.smtp import SmtpLibrary, SmtpEngine +from nettacker.core.lib.smtps import SmtpsLibrary, SmtpsEngine +from nettacker.core.lib.telnet import TelnetLibrary, TelnetEngine +from nettacker.core.lib.ftp import FtpLibrary, FtpEngine +from nettacker.core.lib.pop3 import Pop3Library, Pop3Engine +from nettacker.core.lib.ssh import SshLibrary, SshEngine +from nettacker.lib.compare_report.engine import build_report +from nettacker.core.messages import application_language, get_languages, load_message + + +class TestSmtpLibrary: + """Test SMTP protocol library.""" + + def test_smtp_library_class_defined(self): + """Test SmtpLibrary class exists and has client.""" + assert SmtpLibrary.client == smtplib.SMTP + + def test_smtp_library_instantiation(self): + """Test SmtpLibrary can be instantiated.""" + library = SmtpLibrary() + assert library is not None + assert library.client == smtplib.SMTP + + def test_smtp_engine_has_library(self): + """Test SmtpEngine has SmtpLibrary.""" + assert SmtpEngine.library == SmtpLibrary + + +class TestSmtpsLibrary: + """Test SMTPS protocol library.""" + + def test_smtps_library_defined(self): + """Test SmtpsLibrary exists.""" + assert SmtpsLibrary is not None + assert hasattr(SmtpsLibrary, 'client') + + def test_smtps_engine_defined(self): + """Test SmtpsEngine exists and has library.""" + assert SmtpsEngine.library == SmtpsLibrary + + +class TestTelnetLibrary: + """Test Telnet protocol library.""" + + def test_telnet_library_class_defined(self): + """Test TelnetLibrary class exists.""" + assert TelnetLibrary is not None + assert hasattr(TelnetLibrary, 'client') + + def test_telnet_library_instantiation(self): + """Test TelnetLibrary can be instantiated.""" + library = TelnetLibrary() + assert library is not None + + def test_telnet_engine_defined(self): + """Test TelnetEngine exists and has library.""" + assert TelnetEngine.library == TelnetLibrary + + +class TestFtpLibrary: + """Test FTP protocol library.""" + + def test_ftp_library_defined(self): + """Test FtpLibrary class exists.""" + assert FtpLibrary is not None + assert hasattr(FtpLibrary, 'client') + + def test_ftp_engine_defined(self): + """Test FtpEngine exists.""" + assert FtpEngine.library == FtpLibrary + + +class TestPop3Library: + """Test POP3 protocol library.""" + + def test_pop3_library_defined(self): + """Test Pop3Library class exists.""" + assert Pop3Library is not None + assert hasattr(Pop3Library, 'client') + + def test_pop3_engine_defined(self): + """Test Pop3Engine exists.""" + assert Pop3Engine.library == Pop3Library + + +class TestSshLibrary: + """Test SSH protocol library.""" + + def test_ssh_library_defined(self): + """Test SshLibrary class exists.""" + assert SshLibrary is not None + assert hasattr(SshLibrary, 'client') + + def test_ssh_engine_defined(self): + """Test SshEngine exists.""" + assert SshEngine.library == SshLibrary + + +class TestCompareReportEngine: + """Test compare report module.""" + + def test_build_report_with_simple_data(self): + """Test build_report function with simple data.""" + compare_result = {"scan1": "data1", "scan2": "data2"} + + with patch("builtins.open", mock_open(read_data="__data_will_locate_here__")): + result = build_report(compare_result) + assert '"scan1": "data1"' in result + assert '"scan2": "data2"' in result + + def test_build_report_with_complex_data(self): + """Test build_report with nested data structure.""" + compare_result = { + "comparison": { + "added": [1, 2, 3], + "removed": [4, 5] + } + } + + with patch("builtins.open", mock_open(read_data="prefix __data_will_locate_here__ suffix")): + result = build_report(compare_result) + assert "prefix" in result + assert "suffix" in result + assert "added" in result + + def test_build_report_html_replacement(self): + """Test that placeholder is properly replaced.""" + html_template = "Compare: __data_will_locate_here__" + data = {"status": "ok"} + + with patch("builtins.open", mock_open(read_data=html_template)): + result = build_report(data) + assert "__data_will_locate_here__" not in result + assert "status" in result + + +class TestApplicationLanguage: + """Test language selection logic.""" + + def test_language_from_L_flag(self): + """Test language selection from -L flag.""" + with patch.object(sys, "argv", ["prog", "-L", "fr"]): + with patch("nettacker.core.messages.get_languages", return_value=["en", "fr"]): + lang = application_language() + assert lang == "fr" + + def test_language_from_long_flag(self): + """Test language selection from --language flag.""" + with patch.object(sys, "argv", ["prog", "--language", "de"]): + with patch("nettacker.core.messages.get_languages", return_value=["en", "de"]): + lang = application_language() + assert lang == "de" + + def test_language_from_config(self): + """Test language selection from config.""" + with patch.object(sys, "argv", ["prog"]): + with patch("nettacker.core.messages.Config.settings.language", "fa"): + with patch("nettacker.core.messages.get_languages", return_value=["en", "fa"]): + lang = application_language() + assert lang == "fa" + + def test_language_default_to_en(self): + """Test default language is English.""" + with patch.object(sys, "argv", ["prog", "-L", "invalid"]): + with patch("nettacker.core.messages.get_languages", return_value=["en"]): + lang = application_language() + assert lang == "en" + + def test_language_invalid_reverts_to_en(self): + """Test invalid language reverts to English.""" + with patch.object(sys, "argv", ["prog", "-L", "xx"]): + with patch("nettacker.core.messages.get_languages", return_value=["en", "fr"]): + lang = application_language() + assert lang == "en" + + +class TestGetLanguages: + """Test language detection.""" + + @patch("nettacker.core.messages.Config.path.locale_dir") + def test_get_languages_returns_list(self, mock_locale_dir): + """Test get_languages returns list of available languages.""" + mock_paths = [ + MagicMock(__str__=lambda x: "/path/en.yaml"), + MagicMock(__str__=lambda x: "/path/fr.yaml"), + MagicMock(__str__=lambda x: "/path/de.yaml"), + ] + mock_locale_dir.glob.return_value = mock_paths + + languages = get_languages() + assert len(languages) >= 3 + assert "en" in languages + + +class TestLoadMessageClass: + """Test load_message class initialization.""" + + @patch("nettacker.core.messages.application_language", return_value="en") + @patch("nettacker.core.messages.load_yaml") + def test_load_message_init_english(self, mock_load_yaml, mock_app_lang): + """Test load_message initialization with English.""" + mock_load_yaml.return_value = {"test": "message"} + + loader = load_message() + assert loader.language == "en" + assert loader.messages == {"test": "message"} + + @patch("nettacker.core.messages.application_language", return_value="fa") + @patch("nettacker.core.messages.load_yaml") + def test_load_message_init_farsi_with_fallback(self, mock_load_yaml, mock_app_lang): + """Test load_message initialization with Farsi and English fallback.""" + # First call for Farsi, second call for English fallback + mock_load_yaml.side_effect = [ + {"translated": "message"}, + {"key": "en_value"} + ] + + loader = load_message() + assert loader.language == "fa" + assert mock_load_yaml.call_count == 2 + assert loader.messages.get("key") == "en_value" + + +class TestMessagesGetter: + """Test message retrieval methods.""" + + @patch("nettacker.core.messages.application_language", return_value="en") + @patch("nettacker.core.messages.load_yaml") + def test_load_message_messages_dict(self, mock_load_yaml, mock_app_lang): + """Test load_message creates messages dict.""" + mock_load_yaml.return_value = {"greeting": "Hello", "farewell": "Goodbye"} + + loader = load_message() + assert loader.messages == {"greeting": "Hello", "farewell": "Goodbye"} + assert loader.language == "en" + + @patch("nettacker.core.messages.application_language", return_value="en") + @patch("nettacker.core.messages.load_yaml") + def test_load_message_attribute_access(self, mock_load_yaml, mock_app_lang): + """Test load_message stores messages.""" + mock_load_yaml.return_value = {"error": "An error occurred"} + + loader = load_message() + assert hasattr(loader, "messages") + assert hasattr(loader, "language") + + +@pytest.mark.parametrize( + "engine_class, expected_library", + [ + (SmtpEngine, SmtpLibrary), + (TelnetEngine, TelnetLibrary), + (FtpEngine, FtpLibrary), + (Pop3Engine, Pop3Library), + (SshEngine, SshLibrary), + (SmtpsEngine, SmtpsLibrary), + ], +) +def test_protocol_engine_library_mapping(engine_class, expected_library): + assert engine_class.library == expected_library + + +@pytest.mark.parametrize( + "library_class", + [SmtpLibrary, SmtpsLibrary, TelnetLibrary, FtpLibrary, Pop3Library, SshLibrary], +) +def test_protocol_library_client_attribute(library_class): + assert hasattr(library_class, "client") + assert library_class.client is not None diff --git a/tests/core/lib/test_smtp.py b/tests/core/lib/test_smtp.py new file mode 100644 index 000000000..82d007848 --- /dev/null +++ b/tests/core/lib/test_smtp.py @@ -0,0 +1,26 @@ +from unittest.mock import MagicMock, patch + +from nettacker.core.lib.smtp import SmtpEngine, SmtpLibrary +from nettacker.core.lib.smtps import SmtpsEngine, SmtpsLibrary + + +def test_smtp_engine_has_library(): + engine = SmtpEngine() + assert engine.library == SmtpLibrary + + +def test_smtp_library_is_defined(): + lib = SmtpLibrary() + assert hasattr(lib, "brute_force") + assert callable(lib.brute_force) + + +def test_smtps_engine_has_library(): + engine = SmtpsEngine() + assert engine.library == SmtpsLibrary + + +def test_smtps_library_is_defined(): + lib = SmtpsLibrary() + assert hasattr(lib, "brute_force") + assert callable(lib.brute_force) diff --git a/tests/core/lib/test_socket.py b/tests/core/lib/test_socket.py index 6bedbedc4..bdacdca75 100644 --- a/tests/core/lib/test_socket.py +++ b/tests/core/lib/test_socket.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest @@ -140,9 +140,9 @@ def responses(): class TestSocketMethod: + @patch("ssl.create_default_context") @patch("socket.socket") - @patch("ssl.wrap_socket") - def test_create_tcp_socket(self, mock_wrap, mock_socket): + def test_create_tcp_socket(self, mock_socket, mock_context): HOST = "example.com" PORT = 80 TIMEOUT = 60 @@ -151,7 +151,9 @@ def test_create_tcp_socket(self, mock_wrap, mock_socket): socket_instance = mock_socket.return_value socket_instance.settimeout.assert_called_with(TIMEOUT) socket_instance.connect.assert_called_with((HOST, PORT)) - mock_wrap.assert_called_with(socket_instance) + mock_context.return_value.wrap_socket.assert_called_with( + socket_instance, server_hostname=HOST + ) def test_response_conditions_matched_socket_icmp(self, socket_engine, substeps, responses): result = socket_engine.response_conditions_matched( @@ -193,3 +195,34 @@ def test_response_conditions_matched_with_none_response( substeps.tcp_connect_send_and_receive, responses.none ) assert result == [] + + @patch("ssl.create_default_context") + @patch("socket.socket") + def test_create_tcp_socket_wrap_failure(self, mock_socket, mock_context): + first_socket = MagicMock() + fallback_socket = MagicMock() + mock_socket.side_effect = [first_socket, fallback_socket] + mock_context.return_value.wrap_socket.side_effect = Exception("wrap fail") + + conn, ssl_flag = create_tcp_socket("example.com", 80, 10) + + assert conn is fallback_socket + assert ssl_flag is False + fallback_socket.connect.assert_called_with(("example.com", 80)) + + @patch("nettacker.core.lib.socket.create_tcp_socket") + @patch("socket.getservbyport", return_value="http") + def test_tcp_connect_send_and_receive_handles_errors( + self, mock_service, mock_create_tcp + ): + mock_socket = MagicMock() + mock_socket.getpeername.return_value = ("1.1.1.1", 80) + mock_socket.send.side_effect = Exception("send error") + mock_socket.close.return_value = None + mock_create_tcp.return_value = (mock_socket, False) + + library = SocketEngine().library() + result = library.tcp_connect_send_and_receive("host", 80, 1) + + assert result["response"] == "" + assert result["service"] == "http" diff --git a/tests/core/lib/test_ssh.py b/tests/core/lib/test_ssh.py new file mode 100644 index 000000000..25c6d4cbd --- /dev/null +++ b/tests/core/lib/test_ssh.py @@ -0,0 +1,14 @@ +from unittest.mock import MagicMock, patch + +from nettacker.core.lib.ssh import SshEngine, SshLibrary + + +def test_ssh_engine_has_library(): + engine = SshEngine() + assert engine.library == SshLibrary + + +def test_ssh_library_is_defined(): + lib = SshLibrary() + assert hasattr(lib, "brute_force") + assert callable(lib.brute_force) diff --git a/tests/core/lib/test_ssl.py b/tests/core/lib/test_ssl.py index e194a5db2..2dafaa48e 100644 --- a/tests/core/lib/test_ssl.py +++ b/tests/core/lib/test_ssl.py @@ -178,9 +178,9 @@ def connection_params(): class TestSslMethod: + @patch("ssl.create_default_context") @patch("socket.socket") - @patch("ssl.wrap_socket") - def test_create_tcp_socket(self, mock_wrap, mock_socket, connection_params): + def test_create_tcp_socket(self, mock_socket, mock_context, connection_params): create_tcp_socket( connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"] ) @@ -190,7 +190,9 @@ def test_create_tcp_socket(self, mock_wrap, mock_socket, connection_params): socket_instance.connect.assert_called_with( (connection_params["HOST"], connection_params["PORT"]) ) - mock_wrap.assert_called_with(socket_instance) + mock_context.return_value.wrap_socket.assert_called_with( + socket_instance, server_hostname=connection_params["HOST"] + ) @patch("nettacker.core.lib.ssl.is_weak_cipher_suite") @patch("nettacker.core.lib.ssl.is_weak_ssl_version") diff --git a/tests/core/lib/test_telnet.py b/tests/core/lib/test_telnet.py new file mode 100644 index 000000000..27974d23c --- /dev/null +++ b/tests/core/lib/test_telnet.py @@ -0,0 +1,14 @@ +from unittest.mock import MagicMock, patch + +from nettacker.core.lib.telnet import TelnetEngine, TelnetLibrary + + +def test_telnet_engine_has_library(): + engine = TelnetEngine() + assert engine.library == TelnetLibrary + + +def test_telnet_library_is_defined(): + lib = TelnetLibrary() + assert hasattr(lib, "brute_force") + assert callable(lib.brute_force) diff --git a/tests/core/test_app.py b/tests/core/test_app.py new file mode 100644 index 000000000..a38e393c9 --- /dev/null +++ b/tests/core/test_app.py @@ -0,0 +1,95 @@ +import json +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from nettacker.core import app as app_module +from nettacker.core.app import Nettacker + + +def test_expand_targets_handles_subdomains_and_port_scan(monkeypatch): + scan_id = "scan-123" + app = Nettacker.__new__(Nettacker) + app.arguments = SimpleNamespace( + targets=["http://example.com/path", "192.168.1.1-192.168.1.2", "example.org"], + scan_ip_range=False, + selected_modules=["mod1", "subdomain_scan", "port_scan"], + scan_subdomains=True, + ping_before_scan=False, + skip_service_discovery=False, + set_hardware_usage=1, + scan_compare_id=None, + socks_proxy=None, + parallel_module_scan=2, + ) + app.start_scan = MagicMock() + + def fake_find_events(target, module_name, _scan_id): + if module_name == "subdomain_scan": + return [ + json.dumps( + {"response": {"conditions_results": {"content": [f"sub.{target}"]}}} + ) + ] + if module_name in {"icmp_scan", "port_scan"}: + return ["ok"] + return [] + + monkeypatch.setattr(app_module, "find_events", fake_find_events) + + expanded = set(app.expand_targets(scan_id)) + + assert {"example.com", "example.org", "192.168.1.1", "192.168.1.2"}.issubset(expanded) + assert "sub.example.com" in expanded + assert app.arguments.url_base_path == "path/" + assert app.start_scan.call_count == 2 + + +def test_filter_target_by_event(monkeypatch): + app = Nettacker.__new__(Nettacker) + + def fake_find_events(target, module_name, scan_id): + return ["found"] if target == "keep" else [] + + monkeypatch.setattr(app_module, "find_events", fake_find_events) + + result = app.filter_target_by_event(["keep", "drop"], "scan-1", "port_scan") + assert result == ["keep"] + + +def test_run_returns_true_when_no_targets(): + app = Nettacker.__new__(Nettacker) + app.arguments = SimpleNamespace( + report_path_filename="report.txt", + compare_report_path_filename="compare.txt", + graph_name=None, + scan_compare_id=None, + ) + app.expand_targets = MagicMock(return_value=[]) + app.start_scan = MagicMock() + + result = app.run() + + assert result is True + app.start_scan.assert_not_called() + + +def test_run_triggers_scan_and_report(monkeypatch): + app = Nettacker.__new__(Nettacker) + app.arguments = SimpleNamespace( + report_path_filename="report.html", + compare_report_path_filename="report.html", + graph_name=None, + scan_compare_id=None, + ) + app.expand_targets = MagicMock(return_value=["example.com"]) + app.start_scan = MagicMock(return_value=0) + + with patch("nettacker.core.app.create_report") as mock_create_report, patch( + "nettacker.core.app.create_compare_report" + ) as mock_compare_report: + exit_code = app.run() + + assert exit_code == 0 + app.start_scan.assert_called_once() + mock_create_report.assert_called_once() + mock_compare_report.assert_not_called() diff --git a/tests/core/test_app_extended.py b/tests/core/test_app_extended.py new file mode 100644 index 000000000..7ebf0ce44 --- /dev/null +++ b/tests/core/test_app_extended.py @@ -0,0 +1,93 @@ +import json +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from nettacker.core import app as app_module +from nettacker.core.app import Nettacker + + +def make_nettacker_with_options(**kwargs): + app = Nettacker.__new__(Nettacker) + defaults = { + "targets": ["example.com"], + "scan_ip_range": False, + "selected_modules": ["mod"], + "scan_subdomains": False, + "ping_before_scan": False, + "skip_service_discovery": False, + "set_hardware_usage": 1, + "scan_compare_id": None, + "socks_proxy": None, + "parallel_module_scan": 1, + "report_path_filename": "report.txt", + "compare_report_path_filename": "compare.txt", + "graph_name": None, + } + defaults.update(kwargs) + app.arguments = SimpleNamespace(**defaults) + return app + + +@patch("nettacker.core.app.find_events", return_value=["ok"]) +def test_expand_targets_single_targets(mock_find_events): + app = make_nettacker_with_options( + targets=["1.1.1.1"], + skip_service_discovery=True, + scan_subdomains=False, + ping_before_scan=False, + ) + + result = app.expand_targets("scan-1") + + assert "1.1.1.1" in result + assert "1.1.1.1" in app.arguments.targets + + +@patch("nettacker.core.app.find_events") +@patch("nettacker.core.app.generate_ip_range", return_value=["192.1.1.1", "192.1.1.2"]) +def test_expand_targets_cidr(mock_gen_range, mock_find): + app = make_nettacker_with_options(targets=["192.0.0.0/30"], scan_ip_range=False) + app.start_scan = MagicMock() + + result = app.expand_targets("scan-1") + + assert "192.1.1.1" in result or "192.0.0.0" in result + + +@patch("nettacker.core.app.find_events") +def test_expand_targets_url_extracts_host_and_path(mock_find): + mock_find.return_value = [] + app = make_nettacker_with_options(targets=["https://example.com:8080/api/v1"]) + + app.expand_targets("scan-1") + + assert app.arguments.url_base_path == "api/v1/" + + +@patch("nettacker.core.app.multiprocess.Process") +@patch("nettacker.core.app.wait_for_threads_to_finish", return_value=True) +def test_start_scan_triggers_processes(mock_wait, mock_process): + mock_proc = MagicMock() + mock_process.return_value = mock_proc + + app = make_nettacker_with_options() + + result = app.start_scan("scan-1") + + assert result is True + mock_process.assert_called() + + +@patch("nettacker.core.app.create_report") +@patch("nettacker.core.app.remove_old_logs") +def test_run_flow_with_targets(mock_remove_logs, mock_report): + app = make_nettacker_with_options(targets=["1.1.1.1"]) + app.expand_targets = MagicMock(return_value=["1.1.1.1"]) + app.start_scan = MagicMock(return_value=0) + + result = app.run() + + assert result == 0 + app.expand_targets.assert_called_once() + app.start_scan.assert_called_once() + mock_report.assert_called_once() diff --git a/tests/core/test_app_more.py b/tests/core/test_app_more.py new file mode 100644 index 000000000..a68082082 --- /dev/null +++ b/tests/core/test_app_more.py @@ -0,0 +1,12 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +import os + +from nettacker.core.app import Nettacker + + +def test_print_logo(): + with patch("nettacker.core.app.log") as mock_log: + Nettacker.print_logo() + mock_log.write_to_api_console.assert_called_once() + mock_log.reset_color.assert_called_once() diff --git a/tests/core/test_arg_parser.py b/tests/core/test_arg_parser.py new file mode 100644 index 000000000..350dc7c02 --- /dev/null +++ b/tests/core/test_arg_parser.py @@ -0,0 +1,108 @@ +import json +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + +from nettacker.core import arg_parser as arg_module +from nettacker.core.arg_parser import ArgParser + + +def make_options(tmp_path, **overrides): + base = { + "language": "en", + "verbose_mode": False, + "verbose_event": False, + "show_version": False, + "show_help_menu": False, + "show_all_modules": False, + "show_all_profiles": False, + "start_api_server": False, + "api_hostname": "0.0.0.0", + "api_port": 5000, + "api_debug_mode": False, + "api_access_key": None, + "api_client_whitelisted_ips": None, + "api_access_log": None, + "api_cert": None, + "api_cert_key": None, + "targets": "example.com", + "targets_list": None, + "selected_modules": "mod1", + "profiles": None, + "set_hardware_usage": "low", + "thread_per_host": 1, + "parallel_module_scan": 1, + "excluded_modules": "", + "ports": "80", + "schema": "http", + "excluded_ports": "", + "user_agent": "custom-agent", + "http_header": None, + "usernames": "admin", + "usernames_list": None, + "passwords": "pass", + "passwords_list": None, + "read_from_file": None, + "report_path_filename": str(tmp_path / "report.txt"), + "graph_name": None, + "modules_extra_args": None, + "timeout": 1, + "time_sleep_between_requests": 1, + "retries": 1, + "socks_proxy": None, + "scan_compare_id": None, + "compare_report_path_filename": str(tmp_path / "compare.txt"), + "scan_ip_range": False, + "scan_subdomains": False, + "skip_service_discovery": False, + "ping_before_scan": False, + } + base.update(overrides) + return SimpleNamespace(**base) + + +@pytest.fixture(autouse=True) +def stub_loaders(monkeypatch): + monkeypatch.setattr(ArgParser, "load_graphs", staticmethod(lambda: ["g1", "g2"])) + monkeypatch.setattr(ArgParser, "load_languages", staticmethod(lambda: ["en", "fr"])) + monkeypatch.setattr(ArgParser, "load_modules", staticmethod(lambda limit=-1, full_details=False: {"mod1": {"profiles": []}, "all": {}})) + monkeypatch.setattr(ArgParser, "load_profiles", staticmethod(lambda limit=-1: {"all": []})) + + +@patch.object(arg_module, "die_failure", side_effect=RuntimeError("fail")) +def test_invalid_language_triggers_die_failure(mock_die, tmp_path): + options = make_options(tmp_path, language="de") + with pytest.raises(RuntimeError): + ArgParser(api_arguments=options) + mock_die.assert_called_once() + + +@patch.object(arg_module, "die_failure", side_effect=RuntimeError("invalid-graph")) +def test_invalid_graph_name(mock_die, tmp_path): + options = make_options(tmp_path, graph_name="missing") + with pytest.raises(RuntimeError): + ArgParser(api_arguments=options) + mock_die.assert_called_once() + + +def test_excluded_ports_range_parsed(tmp_path): + options = make_options(tmp_path, excluded_ports="1-2,5") + parser = ArgParser(api_arguments=options) + assert sorted(parser.arguments.excluded_ports) == [1, 2, 5] + + +def test_modules_extra_args_are_coerced(tmp_path): + raw_args = "flag=true&count=2&pi=3.1&obj={\"a\":1}" + options = make_options(tmp_path, modules_extra_args=raw_args) + parser = ArgParser(api_arguments=options) + coerced = parser.arguments.modules_extra_args + assert coerced == {"flag": True, "count": 2, "pi": 3.1, "obj": {"a": 1}} + + +@patch.object(arg_module, "die_failure", side_effect=RuntimeError("bad-schema")) +def test_invalid_schema(mock_die, tmp_path): + options = make_options(tmp_path, schema="ftp") + with pytest.raises(RuntimeError): + ArgParser(api_arguments=options) + mock_die.assert_called_once() diff --git a/tests/core/test_arg_parser_extended.py b/tests/core/test_arg_parser_extended.py new file mode 100644 index 000000000..829b89141 --- /dev/null +++ b/tests/core/test_arg_parser_extended.py @@ -0,0 +1,409 @@ +""" +Extended tests for ArgParser covering more branches and validation logic. +""" + +import json +import pytest +from types import SimpleNamespace +from unittest.mock import patch, MagicMock + +from nettacker.core import arg_parser as arg_module +from nettacker.core.arg_parser import ArgParser + + +def make_options(tmp_path, **overrides): + """Helper to create options with defaults.""" + base = { + "language": "en", + "verbose_mode": False, + "verbose_event": False, + "show_version": False, + "show_help_menu": False, + "show_all_modules": False, + "show_all_profiles": False, + "start_api_server": False, + "api_hostname": "0.0.0.0", + "api_port": 5000, + "api_debug_mode": False, + "api_access_key": None, + "api_client_whitelisted_ips": None, + "api_access_log": None, + "api_cert": None, + "api_cert_key": None, + "targets": "example.com", + "targets_list": None, + "selected_modules": "mod1", + "profiles": None, + "set_hardware_usage": "low", + "thread_per_host": 1, + "parallel_module_scan": 1, + "excluded_modules": "", + "ports": "80", + "schema": "http", + "excluded_ports": "", + "user_agent": "custom-agent", + "http_header": None, + "usernames": "admin", + "usernames_list": None, + "passwords": "pass", + "passwords_list": None, + "read_from_file": None, + "report_path_filename": str(tmp_path / "report.txt"), + "graph_name": None, + "modules_extra_args": None, + "timeout": 1, + "time_sleep_between_requests": 1, + "retries": 1, + "socks_proxy": None, + "scan_compare_id": None, + "compare_report_path_filename": str(tmp_path / "compare.txt"), + "scan_ip_range": False, + "scan_subdomains": False, + "skip_service_discovery": False, + "ping_before_scan": False, + } + base.update(overrides) + return SimpleNamespace(**base) + + +@pytest.fixture(autouse=True) +def stub_loaders(monkeypatch): + """Stub test data loaders.""" + monkeypatch.setattr(ArgParser, "load_graphs", staticmethod(lambda: ["d3_tree_v1_graph", "d3_tree_v2_graph"])) + monkeypatch.setattr(ArgParser, "load_languages", staticmethod(lambda: ["en", "fa"])) + monkeypatch.setattr( + ArgParser, + "load_modules", + staticmethod(lambda limit=-1, full_details=False: { + "mod1": {"profiles": ["profile1"]}, + "mod2": {"profiles": []}, + "all": {} + }) + ) + monkeypatch.setattr( + ArgParser, + "load_profiles", + staticmethod(lambda limit=-1: { + "profile1": ["mod1"], + "all": [] + }) + ) + + +class TestPortParsing: + """Test port argument parsing.""" + + def test_ports_single_port(self, tmp_path): + options = make_options(tmp_path, ports="80") + parser = ArgParser(api_arguments=options) + assert 80 in parser.arguments.ports + + def test_ports_multiple_ports_comma_separated(self, tmp_path): + options = make_options(tmp_path, ports="80,443,8080") + parser = ArgParser(api_arguments=options) + ports = parser.arguments.ports + assert 80 in ports + assert 443 in ports + assert 8080 in ports + + def test_ports_range(self, tmp_path): + options = make_options(tmp_path, ports="80-82") + parser = ArgParser(api_arguments=options) + ports = parser.arguments.ports + assert 80 in ports + assert 81 in ports + assert 82 in ports + + def test_ports_mixed_single_and_range(self, tmp_path): + options = make_options(tmp_path, ports="80,443-445") + parser = ArgParser(api_arguments=options) + ports = parser.arguments.ports + assert 80 in ports + assert 443 in ports + assert 444 in ports + assert 445 in ports + + +class TestExcludedPortsParsing: + """Test excluded ports parsing.""" + + def test_excluded_ports_single(self, tmp_path): + options = make_options(tmp_path, excluded_ports="22") + parser = ArgParser(api_arguments=options) + assert 22 in parser.arguments.excluded_ports + + def test_excluded_ports_range(self, tmp_path): + options = make_options(tmp_path, excluded_ports="1-10") + parser = ArgParser(api_arguments=options) + excluded = parser.arguments.excluded_ports + assert 1 in excluded + assert 5 in excluded + assert 10 in excluded + + def test_excluded_ports_mixed(self, tmp_path): + options = make_options(tmp_path, excluded_ports="22,23-25,443") + parser = ArgParser(api_arguments=options) + excluded = parser.arguments.excluded_ports + assert 22 in excluded + assert 23 in excluded + assert 24 in excluded + assert 25 in excluded + assert 443 in excluded + + +class TestSchemaParsing: + """Test schema argument parsing.""" + + def test_schema_single_http(self, tmp_path): + options = make_options(tmp_path, schema="http") + parser = ArgParser(api_arguments=options) + assert "http" in parser.arguments.schema + + def test_schema_multiple_comma_separated(self, tmp_path): + options = make_options(tmp_path, schema="http,https") + parser = ArgParser(api_arguments=options) + schema = parser.arguments.schema + assert "http" in schema + assert "https" in schema + + +class TestModulesExtraArgsCoercion: + """Test modules_extra_args type coercion.""" + + def test_extra_args_boolean_true(self, tmp_path): + options = make_options(tmp_path, modules_extra_args="enabled=true") + parser = ArgParser(api_arguments=options) + coerced = parser.arguments.modules_extra_args + assert coerced["enabled"] is True + + def test_extra_args_boolean_false(self, tmp_path): + options = make_options(tmp_path, modules_extra_args="enabled=false") + parser = ArgParser(api_arguments=options) + coerced = parser.arguments.modules_extra_args + assert coerced["enabled"] is False + + def test_extra_args_integer(self, tmp_path): + options = make_options(tmp_path, modules_extra_args="count=42") + parser = ArgParser(api_arguments=options) + coerced = parser.arguments.modules_extra_args + assert coerced["count"] == 42 + assert isinstance(coerced["count"], int) + + def test_extra_args_float(self, tmp_path): + options = make_options(tmp_path, modules_extra_args="ratio=3.14") + parser = ArgParser(api_arguments=options) + coerced = parser.arguments.modules_extra_args + assert coerced["ratio"] == 3.14 + assert isinstance(coerced["ratio"], float) + + def test_extra_args_string(self, tmp_path): + options = make_options(tmp_path, modules_extra_args="name=test") + parser = ArgParser(api_arguments=options) + coerced = parser.arguments.modules_extra_args + assert coerced["name"] == "test" + + def test_extra_args_json_object(self, tmp_path): + options = make_options(tmp_path, modules_extra_args='config={"key":"value"}') + parser = ArgParser(api_arguments=options) + coerced = parser.arguments.modules_extra_args + assert coerced["config"] == {"key": "value"} + + def test_extra_args_multiple_key_value_pairs(self, tmp_path): + options = make_options( + tmp_path, + modules_extra_args="flag=true&count=5&name=test&ratio=2.5" + ) + parser = ArgParser(api_arguments=options) + coerced = parser.arguments.modules_extra_args + assert coerced["flag"] is True + assert coerced["count"] == 5 + assert coerced["name"] == "test" + assert coerced["ratio"] == 2.5 + + +@patch.object(arg_module, "die_failure", side_effect=RuntimeError("fail")) +class TestValidationFailures: + """Test validation error conditions.""" + + def test_invalid_language_fails(self, mock_die, tmp_path): + options = make_options(tmp_path, language="invalid_lang") + with pytest.raises(RuntimeError, match="fail"): + ArgParser(api_arguments=options) + mock_die.assert_called() + + def test_invalid_graph_fails(self, mock_die, tmp_path): + options = make_options(tmp_path, graph_name="nonexistent_graph") + with pytest.raises(RuntimeError, match="fail"): + ArgParser(api_arguments=options) + mock_die.assert_called() + + def test_invalid_schema_fails(self, mock_die, tmp_path): + options = make_options(tmp_path, schema="invalid_schema") + with pytest.raises(RuntimeError, match="fail"): + ArgParser(api_arguments=options) + mock_die.assert_called() + + def test_invalid_hardware_usage_fails(self, mock_die, tmp_path): + options = make_options(tmp_path, set_hardware_usage="ultra") + with pytest.raises(RuntimeError, match="fail"): + ArgParser(api_arguments=options) + mock_die.assert_called() + + +class TestThreadingOptions: + """Test threading and module scan options.""" + + def test_thread_per_host_minimum_enforced(self, tmp_path): + options = make_options(tmp_path, thread_per_host=0) + parser = ArgParser(api_arguments=options) + # Should be bumped to 1 + assert parser.arguments.thread_per_host >= 1 + + def test_thread_per_host_valid_value(self, tmp_path): + options = make_options(tmp_path, thread_per_host=4) + parser = ArgParser(api_arguments=options) + assert parser.arguments.thread_per_host == 4 + + def test_parallel_module_scan_minimum_enforced(self, tmp_path): + options = make_options(tmp_path, parallel_module_scan=-1) + parser = ArgParser(api_arguments=options) + # Should be bumped to 1 + assert parser.arguments.parallel_module_scan >= 1 + + def test_parallel_module_scan_valid_value(self, tmp_path): + options = make_options(tmp_path, parallel_module_scan=2) + parser = ArgParser(api_arguments=options) + assert parser.arguments.parallel_module_scan == 2 + + +class TestModuleSelection: + """Test module and profile selection logic.""" + + def test_module_selection_single(self, tmp_path): + options = make_options(tmp_path, selected_modules="mod1") + parser = ArgParser(api_arguments=options) + assert "mod1" in parser.arguments.selected_modules + + def test_module_selection_multiple(self, tmp_path): + options = make_options(tmp_path, selected_modules="mod1,mod2") + parser = ArgParser(api_arguments=options) + modules = parser.arguments.selected_modules + assert "mod1" in modules + assert "mod2" in modules + + @patch.object(arg_module, "die_failure", side_effect=RuntimeError("module_not_found")) + def test_invalid_module_name_fails(self, mock_die, tmp_path): + options = make_options(tmp_path, selected_modules="invalid_module") + with pytest.raises(RuntimeError, match="module_not_found"): + ArgParser(api_arguments=options) + + def test_profile_selection(self, tmp_path): + options = make_options(tmp_path, profiles="profile1", selected_modules="") + parser = ArgParser(api_arguments=options) + # Modules from profile should be selected + assert len(parser.arguments.selected_modules) > 0 + + +class TestHardwareUsageOptions: + """Test hardware usage configuration.""" + + def test_hardware_usage_low(self, tmp_path): + options = make_options(tmp_path, set_hardware_usage="low") + parser = ArgParser(api_arguments=options) + # Should be parsed to a numeric value + assert isinstance(parser.arguments.set_hardware_usage, int) + + def test_hardware_usage_normal(self, tmp_path): + options = make_options(tmp_path, set_hardware_usage="normal") + parser = ArgParser(api_arguments=options) + assert isinstance(parser.arguments.set_hardware_usage, int) + + def test_hardware_usage_high(self, tmp_path): + options = make_options(tmp_path, set_hardware_usage="high") + parser = ArgParser(api_arguments=options) + assert isinstance(parser.arguments.set_hardware_usage, int) + + def test_hardware_usage_maximum(self, tmp_path): + options = make_options(tmp_path, set_hardware_usage="maximum") + parser = ArgParser(api_arguments=options) + assert isinstance(parser.arguments.set_hardware_usage, int) + + +class TestTimeoutOptions: + """Test timeout configuration.""" + + def test_timeout_float_value(self, tmp_path): + options = make_options(tmp_path, timeout=5.5) + parser = ArgParser(api_arguments=options) + assert parser.arguments.timeout == 5.5 + + def test_time_sleep_float_value(self, tmp_path): + options = make_options(tmp_path, time_sleep_between_requests=0.5) + parser = ArgParser(api_arguments=options) + assert parser.arguments.time_sleep_between_requests == 0.5 + + +class TestRetryOptions: + """Test retry configuration.""" + + def test_retries_integer(self, tmp_path): + options = make_options(tmp_path, retries=3) + parser = ArgParser(api_arguments=options) + assert parser.arguments.retries == 3 + + +class TestBooleanFlags: + """Test boolean flag options.""" + + def test_scan_ip_range_flag(self, tmp_path): + options = make_options(tmp_path, scan_ip_range=True) + parser = ArgParser(api_arguments=options) + assert parser.arguments.scan_ip_range is True + + def test_scan_subdomains_flag(self, tmp_path): + options = make_options(tmp_path, scan_subdomains=True) + parser = ArgParser(api_arguments=options) + assert parser.arguments.scan_subdomains is True + + def test_skip_service_discovery_flag(self, tmp_path): + options = make_options(tmp_path, skip_service_discovery=True) + parser = ArgParser(api_arguments=options) + assert parser.arguments.skip_service_discovery is True + + def test_ping_before_scan_flag(self, tmp_path): + options = make_options(tmp_path, ping_before_scan=True) + parser = ArgParser(api_arguments=options) + assert parser.arguments.ping_before_scan is True + + +class TestApiOptions: + """Test API configuration options.""" + + def test_api_hostname(self, tmp_path): + options = make_options(tmp_path, api_hostname="127.0.0.1") + parser = ArgParser(api_arguments=options) + assert parser.arguments.api_hostname == "127.0.0.1" + + def test_api_port(self, tmp_path): + options = make_options(tmp_path, api_port=8000) + parser = ArgParser(api_arguments=options) + assert parser.arguments.api_port == 8000 + + def test_api_debug_mode(self, tmp_path): + options = make_options(tmp_path, api_debug_mode=True) + parser = ArgParser(api_arguments=options) + assert parser.arguments.api_debug_mode is True + + def test_api_access_key(self, tmp_path): + options = make_options(tmp_path, api_access_key="test_key_123") + parser = ArgParser(api_arguments=options) + assert parser.arguments.api_access_key == "test_key_123" + + +class TestSocksProxyOption: + """Test SOCKS proxy configuration.""" + + def test_socks_proxy(self, tmp_path): + options = make_options(tmp_path, socks_proxy="127.0.0.1:9050") + parser = ArgParser(api_arguments=options) + assert parser.arguments.socks_proxy == "127.0.0.1:9050" diff --git a/tests/core/test_socks_proxy.py b/tests/core/test_socks_proxy.py new file mode 100644 index 000000000..74caa8401 --- /dev/null +++ b/tests/core/test_socks_proxy.py @@ -0,0 +1,22 @@ +from unittest.mock import patch, MagicMock + +from nettacker.core import socks_proxy as socks_proxy_module + +from nettacker.core.socks_proxy import set_socks_proxy + + +def test_set_socks_proxy_none(): + result = set_socks_proxy(None) + assert isinstance(result, tuple) + assert len(result) == 2 + + +@patch("socks.set_default_proxy") +@patch("socks.socksocket") +def test_set_socks_proxy_with_proxy(mock_socksocket, mock_set_default_proxy): + # Test with a valid SOCKS proxy setup + socket_factory, resolver = set_socks_proxy("socks5://localhost:1080") + assert isinstance((socket_factory, resolver), tuple) + assert socket_factory is mock_socksocket + assert resolver is socks_proxy_module.getaddrinfo + mock_set_default_proxy.assert_called_once() diff --git a/tests/core/test_template.py b/tests/core/test_template.py new file mode 100644 index 000000000..b6c4eabef --- /dev/null +++ b/tests/core/test_template.py @@ -0,0 +1,21 @@ +from nettacker.core.template import TemplateLoader +from nettacker.core.template import Config + + +def test_template_loader_initialization(): + loader = TemplateLoader("port_scan_scan") + assert loader is not None + + +def test_template_loader_open(tmp_path, monkeypatch): + scan_dir = tmp_path / "scan" + scan_dir.mkdir(parents=True) + module_file = scan_dir / "port_scan.yaml" + module_file.write_text("name: test\n", encoding="utf-8") + + monkeypatch.setattr(Config.path, "modules_dir", tmp_path) + + loader = TemplateLoader("port_scan_scan") + content = loader.open() + assert content is not None + assert isinstance(content, str) diff --git a/tests/core/utils/test_common_additional.py b/tests/core/utils/test_common_additional.py new file mode 100644 index 000000000..94b11dd06 --- /dev/null +++ b/tests/core/utils/test_common_additional.py @@ -0,0 +1,42 @@ +import time +from threading import Thread + +from nettacker.core.utils import common as common_utils + + +def test_remove_sensitive_header_keys_strips_auth(): + event = {"headers": {"Authorization": "secret", "X-Test": "ok"}} + cleaned = common_utils.remove_sensitive_header_keys(event) + assert "Authorization" not in cleaned["headers"] + assert cleaned["headers"]["X-Test"] == "ok" + + +def test_reverse_and_regex_condition_reverse_true(): + assert common_utils.reverse_and_regex_condition([], True) is True + + +def test_merge_logs_to_list_nested(): + data = {"outer": {"log": "a"}, "items": {"log": "b"}} + result = common_utils.merge_logs_to_list(data) + assert set(result) == {"a", "b"} + + +def test_sanitize_path(): + assert common_utils.sanitize_path("../etc/passwd") == "etc_passwd" + + +def test_wait_for_threads_to_finish(): + thread = Thread(target=lambda: time.sleep(0.01)) + thread.start() + assert common_utils.wait_for_threads_to_finish([thread]) is True + + +def test_generate_compare_filepath_format(): + scan_id = "scan123" + generated = common_utils.generate_compare_filepath(scan_id) + assert generated.endswith(f"{scan_id}.json") + + +def test_generate_random_token_length(): + token = common_utils.generate_random_token(5) + assert len(token) == 5 diff --git a/tests/core/utils/test_common_extended.py b/tests/core/utils/test_common_extended.py new file mode 100644 index 000000000..3cd040136 --- /dev/null +++ b/tests/core/utils/test_common_extended.py @@ -0,0 +1,224 @@ +""" +Targeted tests for common.py utilities to improve coverage. +Focus on the uncovered branches and edge cases. +""" + +import sys +from unittest.mock import MagicMock, patch +import pytest +import json + +from nettacker.core.utils.common import ( + replace_dependent_response, + merge_logs_to_list, + reverse_and_regex_condition, + wait_for_threads_to_finish, + remove_sensitive_header_keys, + get_http_header_key, + get_http_header_value, + find_args_value, + string_to_bytes, + generate_target_groups, + arrays_to_matrix, +) + + +class TestReplaceDependentResponse: + """Test replace_dependent_response function.""" + + def test_replace_dependent_response_with_data(self): + """Test replacing response dependent keys.""" + response_dependent = {"ip": "192.168.1.1"} + log = "Check response" + + result = replace_dependent_response(log, response_dependent) + assert result == "Check response" + + +class TestMergeLogsToList: + """Test merge_logs_to_list function.""" + + def test_merge_logs_empty_dict(self): + """Test merging empty dict.""" + result = merge_logs_to_list({}) + assert isinstance(result, list) + assert len(result) == 0 + + def test_merge_logs_with_log_key(self): + """Test merging dict with log key.""" + data = {"log": "test log message"} + result = merge_logs_to_list(data) + assert "test log message" in result + + def test_merge_logs_with_json_event_dict(self): + """Test with json_event as dict.""" + data = { + "json_event": {"key": "value"}, + "log": "message" + } + result = merge_logs_to_list(data) + assert "message" in result + + def test_merge_logs_with_json_event_string(self): + """Test with json_event as string.""" + data = { + "json_event": '{"key": "value"}', + "log": "message2" + } + result = merge_logs_to_list(data) + assert "message2" in result + + def test_merge_logs_nested_structure(self): + """Test with nested structure.""" + data = { + "level1": { + "level2": { + "log": "nested message" + } + } + } + result = merge_logs_to_list(data) + assert "nested message" in result + + def test_merge_logs_duplicate_deduplication(self): + """Test that duplicates are removed.""" + data = [ + {"log": "same message"}, + {"log": "same message"} + ] + # Note: this function expects dict, not list + result1 = merge_logs_to_list(data[0]) + result2 = merge_logs_to_list(data[1]) + # Results should be deduplicated sets + assert isinstance(result1, list) + + +class TestReverseAndRegexCondition: + """Test reverse_and_regex_condition function.""" + + def test_reverse_true_regex_true(self): + """Test with both reverse and regex true.""" + result = reverse_and_regex_condition(True, True) + assert result == [] + + +class TestGenerateWordList: + """Test utility functions.""" + + def test_arrays_to_matrix_empty(self): + """Test converting empty arrays to matrix.""" + arrays = [] + result = arrays_to_matrix(arrays) + assert isinstance(result, list) + + +class TestTextToJson: + """Test string_to_bytes function.""" + + def test_string_to_bytes_ascii(self): + """Test converting ASCII string to bytes.""" + text = "hello" + result = string_to_bytes(text) + assert isinstance(result, (bytes, str)) + + def test_string_to_bytes_unicode(self): + """Test converting unicode string to bytes.""" + text = "你好" + result = string_to_bytes(text) + assert result is not None + + def test_string_to_bytes_empty(self): + """Test with empty string.""" + result = string_to_bytes("") + assert result is not None + + +class TestRemoveSensitiveHeaders: + """Test remove_sensitive_header_keys function.""" + + def test_remove_sensitive_headers_empty(self): + """Test with empty event.""" + result = remove_sensitive_header_keys({}) + assert isinstance(result, dict) + + def test_remove_sensitive_headers_with_password(self): + """Test removing password header.""" + event = { + "response": { + "headers": { + "Authorization": "Bearer token123", + "Content-Type": "application/json" + } + } + } + result = remove_sensitive_header_keys(event) + assert isinstance(result, dict) + + def test_remove_sensitive_headers_with_cookie(self): + """Test removing cookie header.""" + event = { + "response": { + "headers": { + "Cookie": "session=abc123", + "User-Agent": "Mozilla" + } + } + } + result = remove_sensitive_header_keys(event) + assert isinstance(result, dict) + + +class TestHeaderKeyValueParse: + """Test header parsing functions.""" + + def test_get_http_header_key(self): + """Test getting header key.""" + header = "Content-Type: application/json" + key = get_http_header_key(header) + assert key == "Content-Type" or key is not None + + def test_get_http_header_key_with_spaces(self): + """Test header key with leading spaces.""" + header = " Authorization: Bearer token" + key = get_http_header_key(header) + assert key is not None + + def test_get_http_header_value(self): + """Test getting header value.""" + header = "Content-Type: application/json" + value = get_http_header_value(header) + assert "application/json" in str(value) or value is not None + + def test_get_http_header_value_complex(self): + """Test getting complex header value.""" + header = "Set-Cookie: session=abc123; Path=/" + value = get_http_header_value(header) + assert value is not None + + +class TestFindArgsValue: + """Test find_args_value function.""" + + def test_find_args_value_exists(self): + """Test finding existing argument.""" + with patch.object(sys, "argv", ["prog", "-L", "en"]): + result = find_args_value("-L") + assert result == "en" + + def test_find_args_value_not_exists(self): + """Test finding non-existent argument.""" + with patch.object(sys, "argv", ["prog"]): + result = find_args_value("-L") + assert result is None + + def test_find_args_long_flag(self): + """Test finding long flag argument.""" + with patch.object(sys, "argv", ["prog", "--language", "fa"]): + result = find_args_value("--language") + assert result == "fa" + + def test_find_args_value_last_in_argv(self): + """Test when flag is last argument with no value.""" + with patch.object(sys, "argv", ["prog", "-L"]): + result = find_args_value("-L") + assert result is None diff --git a/tests/core/utils/test_common_more.py b/tests/core/utils/test_common_more.py new file mode 100644 index 000000000..b5a41c365 --- /dev/null +++ b/tests/core/utils/test_common_more.py @@ -0,0 +1,54 @@ +import sys + +from nettacker.core.utils.common import ( + select_maximum_cpu_core, + now, + find_args_value, + get_http_header_key, + get_http_header_value, + string_to_bytes, +) + + +def test_select_maximum_cpu_core_modes(): + assert select_maximum_cpu_core("low") >= 1 + assert select_maximum_cpu_core("normal") >= 1 + assert select_maximum_cpu_core("high") >= 1 + assert select_maximum_cpu_core("maximum") >= 1 + assert select_maximum_cpu_core("invalid") == 1 + + +def test_now_format(): + result = now() + assert len(result) == 19 # "%Y-%m-%d %H:%M:%S" format + assert result.count("-") == 2 + assert result.count(":") == 2 + + +def test_find_args_value_exists(monkeypatch): + monkeypatch.setattr(sys, "argv", ["prog", "-t", "example.com"]) + result = find_args_value("-t") + assert result == "example.com" + + +def test_find_args_value_missing(monkeypatch): + monkeypatch.setattr(sys, "argv", ["prog"]) + result = find_args_value("-x") + assert result is None + + +def test_get_http_header_key(): + assert get_http_header_key("Authorization: Bearer token") == "Authorization" + assert get_http_header_key("X-Custom-Header: value") == "X-Custom-Header" + + +def test_get_http_header_value(): + assert get_http_header_value("Authorization: Bearer token") == "Bearer token" + assert get_http_header_value("X-Custom: ") is None + assert get_http_header_value("no-value") is None + + +def test_string_to_bytes(): + result = string_to_bytes("hello") + assert result == b"hello" + assert isinstance(result, bytes) diff --git a/tests/lib/graph/__init__.py b/tests/lib/graph/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/lib/graph/d3_tree_v1/__init__.py b/tests/lib/graph/d3_tree_v1/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/lib/graph/d3_tree_v1/test_engine.py b/tests/lib/graph/d3_tree_v1/test_engine.py new file mode 100644 index 000000000..d2f2d6880 --- /dev/null +++ b/tests/lib/graph/d3_tree_v1/test_engine.py @@ -0,0 +1,84 @@ +import json +import re + +import pytest + +from nettacker.lib.graph.d3_tree_v1 import engine as d3_v1 + + +def _extract_tree_data(result): + match = re.search(r"treeData\s*=\s*(\{.*?\});", result, flags=re.S) + assert match is not None + return json.loads(match.group(1)) + + +def test_escape_for_html_js(): + escaped = d3_v1.escape_for_html_js("&entity") + assert escaped == "\\u003Ctag\\u003E\\u0026entity\\u003C/tag\\u003E" + + +@pytest.mark.parametrize( + "raw, expected", + [ + ("", ""), + ("plain-text", "plain-text"), + ('"quoted"', '"quoted"'), + ("&", "\\u003Ca\\u003E\\u0026\\u003Cb\\u003E"), + ("مرحبا", "مرحبا\\u003Cok\\u003E"), + ], +) +def test_escape_for_html_js_parametrized(raw, expected): + assert d3_v1.escape_for_html_js(raw) == expected + + +def test_d3_tree_v1_start_empty(): + result = d3_v1.start([]) + assert isinstance(result, str) + assert result.strip() + + +def test_d3_tree_v1_start_with_multiple_events(): + events = [ + {"target": "127.0.0.1", "module_name": "port_scan", "port": 80, "event": "port_open"}, + {"target": "example.com", "module_name": "http_scan", "port": 443, "event": "http_ok"}, + ] + result = d3_v1.start(events) + payload = _extract_tree_data(result) + + assert isinstance(payload, dict) + assert isinstance(payload.get("children"), list) + + names = {child["name"] for child in payload["children"]} + assert "127.0.0.1" in names + assert "example.com" in names + + +def test_d3_tree_v1_start_with_missing_optional_fields(): + result = d3_v1.start([{"target": "only-target"}]) + payload = _extract_tree_data(result) + + assert isinstance(payload, dict) + assert any(child["name"] == "only-target" for child in payload["children"]) + + +def test_d3_tree_v1_start_escapes_xss_payload(): + events = [{"target": "", "module_name": "x", "port": 1, "event": "e"}] + result = d3_v1.start(events) + + assert "