diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e808659c..ed6f2133 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -201,6 +201,10 @@ jobs: - name: Setup Go and protoc run: | python do.py setup_ext ${{ matrix.go-version }} + + - name: Set Toolchain to 1.25.0 + run: echo "GOTOOLCHAIN=go1.25.0" >> $GITHUB_ENV + - name: Run artifact generation run: | python do.py generate go @@ -221,9 +225,7 @@ jobs: name: python_package - name: Display structure of downloaded files run: ls -R - - uses: actions/setup-go@v2 - with: - go-version: ${{ matrix.go-version }} + - name: Run go tests run: | python do.py testgo diff --git a/do.py b/do.py index af392b6a..13ba88b0 100644 --- a/do.py +++ b/do.py @@ -274,7 +274,7 @@ def testgo(): def go_lint(): try: - version = "1.64.2" + version = "1.64.8" pkg = "go install" if on_linux() or on_macos(): @@ -471,7 +471,8 @@ def run(commands, capture_output=False): cmd = cmd.encode("utf-8", errors="ignore") subprocess.check_call(cmd, shell=True, stdout=fd) return flush_output(fd, logfile) - except Exception: + except Exception as e: + print(f"Error: {e}") flush_output(fd, logfile) sys.exit(1) @@ -484,6 +485,7 @@ def getstatusoutput(command): def build(sdk="all", env_setup=None): + os.environ["GOTOOLCHAIN"] = "go1.25.0" print("\nSTEP 1: Set up virtual environment") if env_setup is not None and env_setup.lower() == "clean": @@ -512,6 +514,7 @@ def build(sdk="all", env_setup=None): ) init() run([py() + " -m pip install ."]) + print("\nSTEP 3: Generating Python and Go SDKs\n") generate(sdk=sdk, cicd="True") if sdk == "python" or sdk == "all": diff --git a/openapiart/common.py b/openapiart/common.py index 65d976ee..3b6e7096 100644 --- a/openapiart/common.py +++ b/openapiart/common.py @@ -8,15 +8,11 @@ import io import sys import time -import grpc import semantic_version import types import platform import base64 import re -from google.protobuf import json_format -import sanity_pb2_grpc as pb2_grpc -import sanity_pb2 as pb2 try: from typing import Union, Dict, List, Any, Literal diff --git a/openapiart/generator.py b/openapiart/generator.py index 6161d06b..ebc0da19 100644 --- a/openapiart/generator.py +++ b/openapiart/generator.py @@ -163,19 +163,6 @@ def generate(self): os.path.join(os.path.dirname(__file__), "common.py"), "r" ) as fp: common_content = fp.read() - cnf_text = "import sanity_pb2_grpc as pb2_grpc" - modify_text = "try:\n from {pkg_name} {text}\nexcept ImportError:\n {text}".format( - pkg_name=self._package_name, - text=cnf_text.replace("sanity", self._protobuf_package_name), - ) - common_content = common_content.replace(cnf_text, modify_text) - - cnf_text = "import sanity_pb2 as pb2" - modify_text = "try:\n from {pkg_name} {text}\nexcept ImportError:\n {text}".format( - pkg_name=self._package_name, - text=cnf_text.replace("sanity", self._protobuf_package_name), - ) - common_content = common_content.replace(cnf_text, modify_text) cnf_text = 'log = logging.getLogger("common")' modify_text = 'log = logging.getLogger("{pkg_name}")'.format( @@ -385,6 +372,22 @@ def _get_methods_and_factories(self): return methods, factories, rpc_methods def _write_rpc_api_class(self, rpc_methods): + + pb_imports = """ + try: + self._pb2_grpc = importlib.import_module('{pkg_name}.{protobuf_name}_pb2_grpc') + except ImportError: + self._pb2_grpc = importlib.import_module('{protobuf_name}_pb2_grpc') + + try: + self._pb2 = importlib.import_module('{pkg_name}.{protobuf_name}_pb2') + except ImportError: + self._pb2 = importlib.import_module('{protobuf_name}_pb2') + """.format( + pkg_name=self._package_name, + protobuf_name=self._protobuf_package_name, + ) + class_code = """class GrpcApi(Api): # OpenAPI gRPC Api def __init__(self, **kwargs): @@ -404,9 +407,14 @@ def __init__(self, **kwargs): else "localhost:50051" ) self._transport = kwargs["transport"] if "transport" in kwargs else None - log.debug("gRPCTransport args: {}".format(", ".join(["{}={!r}".format(k, v) for k, v in kwargs.items()]))) + log.debug("gRPCTransport args: {{}}".format(", ".join(["{{}}={{!r}}".format(k, v) for k, v in kwargs.items()]))) self._telemetry.initiate_grpc_instrumentation() + #lazy load grpc packages and protobuf stubs + self._grpc = importlib.import_module('grpc') + self._json_format = importlib.import_module('google.protobuf.json_format') + {pb_imports} + def _use_secure_connection(self, cert_path, cert_domain=None): \"\"\"Accepts certificate and host_name for SSL Connection.\"\"\" if cert_path is None: @@ -420,16 +428,16 @@ def _get_stub(self): ('grpc.keepalive_timeout_ms', self._keep_alive_timeout), ('grpc.max_receive_message_length', self._maximum_receive_buffer_size)] if self._cert is None: - self._channel = grpc.insecure_channel(self._location, options=CHANNEL_OPTIONS) + self._channel = self._grpc.insecure_channel(self._location, options=CHANNEL_OPTIONS) else: crt = open(self._cert, "rb").read() - creds = grpc.ssl_channel_credentials(crt) + creds = self._grpc.ssl_channel_credentials(crt) if self._cert_domain is not None: CHANNEL_OPTIONS.append(('grpc.ssl_target_name_override', self._cert_domain)) - self._channel = grpc.secure_channel( + self._channel = self._grpc.secure_channel( self._location, credentials=creds, options=CHANNEL_OPTIONS ) - self._stub = pb2_grpc.OpenapiStub(self._channel) + self._stub = self._pb2_grpc.OpenapiStub(self._channel) return self._stub def _serialize_payload(self, payload): @@ -459,7 +467,7 @@ def _client_stream(self, stub, data): chunk = data[i:len(data)] else: chunk = data[i:i+self._chunk_size] - data_chunks.append(pb2.Data(datum=chunk, chunk_size=self._chunk_size)) + data_chunks.append(self._pb2.Data(datum=chunk, chunk_size=self._chunk_size)) # print(chunk_list, len(chunk_list)) reqs = iter(data_chunks) return reqs @@ -513,7 +521,7 @@ def close(self): with open(self._api_filename, "a") as self._fid: self._write() self._write() - self._write(0, class_code) + self._write(0, class_code.format(pb_imports=pb_imports)) for rpc_method in rpc_methods: self._write() status_msg = "" @@ -540,7 +548,7 @@ def close(self): self._write(2, "stub = self._get_stub()") self._write( 2, - "empty = pb2_grpc.google_dot_protobuf_dot_empty__pb2.Empty()", + "empty = self._pb2_grpc.google_dot_protobuf_dot_empty__pb2.Empty()", ) ( line_indent, @@ -577,9 +585,11 @@ def close(self): self._write(2, "self.add_warnings('%s')" % status_msg) if not only_bytes: - self._write(2, "pb_obj = json_format.Parse(") + self._write(2, "pb_obj = self._json_format.Parse(") self._write(3, "self._serialize_payload(payload),") - self._write(3, "pb2.%s()" % rpc_method.request_class) + self._write( + 3, "self._pb2.%s()" % rpc_method.request_class + ) self._write(2, ")") if ( self._generate_version_api @@ -589,7 +599,7 @@ def close(self): self._write( 2, - "req_obj = pb2.{operation_name}Request({request_property}={obj})".format( + "req_obj = self._pb2.{operation_name}Request({request_property}={obj})".format( operation_name=rpc_method.operation_name, request_property=rpc_method.request_property, obj="payload" if only_bytes else "pb_obj", @@ -617,9 +627,9 @@ def close(self): "res_obj = stub.%s(req_obj, timeout=self._request_timeout)" % rpc_method.operation_name, ) - self._write(2, "except grpc.RpcError as grpc_error:") + self._write(2, "except self._grpc.RpcError as grpc_error:") self._write(3, "self._raise_exception(grpc_error)") - self._write(2, "response = json_format.MessageToDict(") + self._write(2, "response = self._json_format.MessageToDict(") self._write(3, "res_obj, preserving_proto_field_name=True") self._write(2, ")") self._write(2, 'log.debug("Response - " + str(response))') @@ -652,7 +662,9 @@ def close(self): if including_default: self._write(3, "if len(result) == 0:") - self._write(4, "result = json_format.MessageToDict(") + self._write( + 4, "result = self._json_format.MessageToDict(" + ) self._write( 5, "res_obj.%s," % rpc_method.proto_field_name ) @@ -849,6 +861,10 @@ def _write_api_class(self, methods, factories): self._write( 2, 'transport = kwargs.get("otel_collector_transport")' ) + self._write( + 2, + "self._is_grpc_transport = True if kwargs.get('transport') == 'grpc' else False", + ) self._write(2, "self._telemetry = Telemetry(endpoint, transport)") self._write() @@ -875,15 +891,17 @@ def _write_api_class(self, methods, factories): self._write(2, "# type: (Exception) -> Union[Error, None]") self._write(2, "if isinstance(error, Error):") self._write(3, "return error") - self._write(2, "elif isinstance(error, grpc.RpcError):") - self._write(3, "err = self._deserialize_error(error.details())") - self._write(3, "if err is not None:") - self._write(4, "return err") - self._write(3, "err = self.error()") - self._write(3, "err.code = error.code().value[0]") - self._write(3, "err.errors = [error.details()]") - self._write(3, "return err") self._write(2, "elif isinstance(error, Exception):") + self._write(3, "if self._is_grpc_transport:") + self._write(4, "import grpc") + self._write(4, "if isinstance(error, grpc.RpcError):") + self._write(5, "err = self._deserialize_error(error.details())") + self._write(5, "if err is not None:") + self._write(6, "return err") + self._write(5, "err = self.error()") + self._write(5, "err.code = error.code().value[0]") + self._write(5, "err.errors = [error.details()]") + self._write(5, "return err") self._write(3, "if len(error.args) != 1:") self._write(4, "return None") self._write(3, "if isinstance(error.args[0], Error):") @@ -2242,7 +2260,7 @@ def _add_streaming_code(self, rpc_method, return_byte, line_indent): elif rpc_method.good_response_type: self._write( line_indent, - "res_obj = pb2.%s()" % rpc_method.response_class, + "res_obj = self._pb2.%s()" % rpc_method.response_class, ) self._write(line_indent, "res_obj.ParseFromString(data)") else: diff --git a/openapiart/tests/test_error_schema.py b/openapiart/tests/test_error_schema.py index c21a3071..2dc32ff1 100644 --- a/openapiart/tests/test_error_schema.py +++ b/openapiart/tests/test_error_schema.py @@ -18,7 +18,15 @@ def create_openapi_artifacts(openapiart_class, sdk=None, file_name=None): def str_compare(validte_str, entire_str, item): - return validte_str in entire_str and item in entire_str + normalized_entire = ( + entire_str.replace("(", "").replace(")", "").replace("'", "") + ) + normalized_validate = ( + validte_str.replace("(", "").replace(")", "").replace("'", "") + ) + return ( + normalized_validate in normalized_entire and item in normalized_entire + ) def test_validate_response_default(): @@ -64,7 +72,7 @@ def test_error_for_missing_required(): file_name="./response/response_missing_required_in_error.yaml", ) error_value = execinfo.value.args[0] - assert error_msg == error_value + assert str_compare(error_msg, error_value, "Error") if __name__ == "__main__": diff --git a/openapiart/tests/test_lazy_import.py b/openapiart/tests/test_lazy_import.py new file mode 100644 index 00000000..1b622ef3 --- /dev/null +++ b/openapiart/tests/test_lazy_import.py @@ -0,0 +1,197 @@ +"""Tests to verify that grpc and protobuf imports are lazily loaded. + +The generated code should NOT import grpc, protobuf stubs (*_pb2, *_pb2_grpc) +at module level. These should only be imported inside GrpcApi.__init__ via +importlib.import_module, so that users who only need HTTP transport are not +required to have grpc/protobuf installed. +""" + +import ast +import os +import sys +import pytest + + +def _get_generated_source(): + """Return the source code of the generated sanity.py file.""" + src_path = os.path.join( + pytest.artifacts_path, pytest.module_name, pytest.module_name + ".py" + ) + with open(src_path, "r") as f: + return f.read() + + +def _get_top_level_imports(source): + """Parse the module AST and return all top-level import names.""" + tree = ast.parse(source) + imports = set() + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.Import): + for alias in node.names: + imports.add(alias.name) + elif isinstance(node, ast.ImportFrom): + if node.module is not None: + imports.add(node.module) + return imports + + +def test_grpc_not_in_top_level_imports(): + """grpc should not be imported at the module level.""" + source = _get_generated_source() + top_imports = _get_top_level_imports(source) + assert ( + "grpc" not in top_imports + ), "grpc should not be a top-level import in the generated module" + + +def test_protobuf_stubs_not_in_top_level_imports(): + """protobuf stubs (sanity_pb2, sanity_pb2_grpc) should not be imported + at the module level.""" + source = _get_generated_source() + top_imports = _get_top_level_imports(source) + pb2_imports = [ + name + for name in top_imports + if name.endswith("_pb2") or name.endswith("_pb2_grpc") + ] + assert len(pb2_imports) == 0, ( + "protobuf stubs should not be top-level imports, found: %s" + % pb2_imports + ) + + +def test_grpc_lazy_loaded_in_grpc_api_init(): + """The GrpcApi.__init__ should call importlib.import_module for grpc, + protobuf json_format, and the pb2 stubs.""" + source = _get_generated_source() + tree = ast.parse(source) + + # Find the GrpcApi class + grpc_api_class = None + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == "GrpcApi": + grpc_api_class = node + break + assert ( + grpc_api_class is not None + ), "GrpcApi class not found in generated code" + + # Find __init__ method + init_method = None + for item in grpc_api_class.body: + if isinstance(item, ast.FunctionDef) and item.name == "__init__": + init_method = item + break + assert init_method is not None, "GrpcApi.__init__ not found" + + # Collect all importlib.import_module call string arguments in __init__ + lazy_imports = set() + for node in ast.walk(init_method): + if isinstance(node, ast.Call): + func = node.func + # Match importlib.import_module(...) or self._xxx = importlib.import_module(...) + is_importlib_call = False + if ( + isinstance(func, ast.Attribute) + and func.attr == "import_module" + ): + if isinstance(func.value, ast.Attribute): + is_importlib_call = True + elif ( + isinstance(func.value, ast.Name) + and func.value.id == "importlib" + ): + is_importlib_call = True + if is_importlib_call and node.args: + arg = node.args[0] + if isinstance(arg, ast.Constant) and isinstance( + arg.value, str + ): + lazy_imports.add(arg.value) + + assert ( + "grpc" in lazy_imports + ), "grpc should be lazily imported in GrpcApi.__init__" + assert ( + "google.protobuf.json_format" in lazy_imports + ), "google.protobuf.json_format should be lazily imported in GrpcApi.__init__" + pb2_lazy = [m for m in lazy_imports if "pb2" in m] + assert len(pb2_lazy) >= 2, ( + "Both pb2 and pb2_grpc stubs should be lazily imported in " + "GrpcApi.__init__, found: %s" % pb2_lazy + ) + + +def test_http_api_does_not_load_grpc(monkeypatch): + """Creating an HttpApi should not trigger any grpc or pb2 imports. + + We verify this by temporarily making grpc un-importable and confirming + that HttpApi instantiation still works. + """ + import importlib + + original_import = importlib.import_module + + grpc_modules_loaded = [] + + def tracking_import(name, *args, **kwargs): + if "grpc" in name or "_pb2" in name: + grpc_modules_loaded.append(name) + return original_import(name, *args, **kwargs) + + monkeypatch.setattr(importlib, "import_module", tracking_import) + + # Clear any previously cached state + grpc_modules_loaded.clear() + + # Creating an HTTP API should not trigger grpc/pb2 imports + module = pytest.module + http_api = module.api( + location="http://127.0.0.1:12345", + transport=module.Transport.HTTP, + verify=False, + ) + assert http_api is not None + + grpc_related = [ + m for m in grpc_modules_loaded if "grpc" in m or "_pb2" in m + ] + assert len(grpc_related) == 0, ( + "Creating HttpApi should not trigger grpc/pb2 imports, but loaded: %s" + % grpc_related + ) + + +def test_grpc_api_triggers_lazy_imports(monkeypatch): + """Creating a GrpcApi should trigger lazy imports of grpc and pb2 modules.""" + import importlib + + original_import = importlib.import_module + + grpc_modules_loaded = [] + + def tracking_import(name, *args, **kwargs): + if "grpc" in name or "_pb2" in name: + grpc_modules_loaded.append(name) + return original_import(name, *args, **kwargs) + + monkeypatch.setattr(importlib, "import_module", tracking_import) + + grpc_modules_loaded.clear() + + module = pytest.module + grpc_api = module.api( + location="localhost:50051", + transport=module.Transport.GRPC, + ) + assert grpc_api is not None + + # Verify grpc and pb2 modules were loaded + assert any( + "grpc" == m or m.endswith("_grpc") for m in grpc_modules_loaded + ), ("GrpcApi should lazily import grpc, loaded: %s" % grpc_modules_loaded) + assert any("_pb2" in m for m in grpc_modules_loaded), ( + "GrpcApi should lazily import pb2 stubs, loaded: %s" + % grpc_modules_loaded + ) + grpc_api.close() diff --git a/openapiart/tests/test_validate_x_field_pattern.py b/openapiart/tests/test_validate_x_field_pattern.py index 819ad18e..0ae8e2e9 100644 --- a/openapiart/tests/test_validate_x_field_pattern.py +++ b/openapiart/tests/test_validate_x_field_pattern.py @@ -18,7 +18,9 @@ def create_openapi_artifacts(openapiart_class, sdk=None, file_name=None): def str_compare(validte_str, entire_str): - return validte_str in entire_str + normalized_entire = entire_str.replace("(", "").replace(")", "") + normalized_validate = validte_str.replace("(", "").replace(")", "") + return normalized_validate in normalized_entire def test_validate_pattern():