Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ jobs:
- name: Setup Go and protoc
run: |
python do.py setup_ext ${{ matrix.go-version }}

- name: Set Toolchain to 1.25.0
run: echo "GOTOOLCHAIN=go1.25.0" >> $GITHUB_ENV

- name: Run artifact generation
run: |
python do.py generate go
Expand All @@ -221,9 +225,7 @@ jobs:
name: python_package
- name: Display structure of downloaded files
run: ls -R
- uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go-version }}

- name: Run go tests
run: |
python do.py testgo
Expand Down
7 changes: 5 additions & 2 deletions do.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def testgo():
def go_lint():
try:

version = "1.64.2"
version = "1.64.8"

pkg = "go install"
if on_linux() or on_macos():
Expand Down Expand Up @@ -471,7 +471,8 @@ def run(commands, capture_output=False):
cmd = cmd.encode("utf-8", errors="ignore")
subprocess.check_call(cmd, shell=True, stdout=fd)
return flush_output(fd, logfile)
except Exception:
except Exception as e:
print(f"Error: {e}")
flush_output(fd, logfile)
sys.exit(1)

Expand All @@ -484,6 +485,7 @@ def getstatusoutput(command):


def build(sdk="all", env_setup=None):
os.environ["GOTOOLCHAIN"] = "go1.25.0"
print("\nSTEP 1: Set up virtual environment")

if env_setup is not None and env_setup.lower() == "clean":
Expand Down Expand Up @@ -512,6 +514,7 @@ def build(sdk="all", env_setup=None):
)
init()
run([py() + " -m pip install ."])

print("\nSTEP 3: Generating Python and Go SDKs\n")
generate(sdk=sdk, cicd="True")
if sdk == "python" or sdk == "all":
Expand Down
4 changes: 0 additions & 4 deletions openapiart/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,11 @@
import io
import sys
import time
import grpc
import semantic_version
import types
import platform
import base64
import re
from google.protobuf import json_format
import sanity_pb2_grpc as pb2_grpc
import sanity_pb2 as pb2

try:
from typing import Union, Dict, List, Any, Literal
Expand Down
90 changes: 54 additions & 36 deletions openapiart/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,19 +163,6 @@ def generate(self):
os.path.join(os.path.dirname(__file__), "common.py"), "r"
) as fp:
common_content = fp.read()
cnf_text = "import sanity_pb2_grpc as pb2_grpc"
modify_text = "try:\n from {pkg_name} {text}\nexcept ImportError:\n {text}".format(
pkg_name=self._package_name,
text=cnf_text.replace("sanity", self._protobuf_package_name),
)
common_content = common_content.replace(cnf_text, modify_text)

cnf_text = "import sanity_pb2 as pb2"
modify_text = "try:\n from {pkg_name} {text}\nexcept ImportError:\n {text}".format(
pkg_name=self._package_name,
text=cnf_text.replace("sanity", self._protobuf_package_name),
)
common_content = common_content.replace(cnf_text, modify_text)

cnf_text = 'log = logging.getLogger("common")'
modify_text = 'log = logging.getLogger("{pkg_name}")'.format(
Expand Down Expand Up @@ -385,6 +372,22 @@ def _get_methods_and_factories(self):
return methods, factories, rpc_methods

def _write_rpc_api_class(self, rpc_methods):

pb_imports = """
try:
self._pb2_grpc = importlib.import_module('{pkg_name}.{protobuf_name}_pb2_grpc')
except ImportError:
self._pb2_grpc = importlib.import_module('{protobuf_name}_pb2_grpc')

try:
self._pb2 = importlib.import_module('{pkg_name}.{protobuf_name}_pb2')
except ImportError:
self._pb2 = importlib.import_module('{protobuf_name}_pb2')
""".format(
pkg_name=self._package_name,
protobuf_name=self._protobuf_package_name,
)

class_code = """class GrpcApi(Api):
# OpenAPI gRPC Api
def __init__(self, **kwargs):
Expand All @@ -404,9 +407,14 @@ def __init__(self, **kwargs):
else "localhost:50051"
)
self._transport = kwargs["transport"] if "transport" in kwargs else None
log.debug("gRPCTransport args: {}".format(", ".join(["{}={!r}".format(k, v) for k, v in kwargs.items()])))
log.debug("gRPCTransport args: {{}}".format(", ".join(["{{}}={{!r}}".format(k, v) for k, v in kwargs.items()])))
self._telemetry.initiate_grpc_instrumentation()

#lazy load grpc packages and protobuf stubs
self._grpc = importlib.import_module('grpc')
self._json_format = importlib.import_module('google.protobuf.json_format')
{pb_imports}

def _use_secure_connection(self, cert_path, cert_domain=None):
\"\"\"Accepts certificate and host_name for SSL Connection.\"\"\"
if cert_path is None:
Expand All @@ -420,16 +428,16 @@ def _get_stub(self):
('grpc.keepalive_timeout_ms', self._keep_alive_timeout),
('grpc.max_receive_message_length', self._maximum_receive_buffer_size)]
if self._cert is None:
self._channel = grpc.insecure_channel(self._location, options=CHANNEL_OPTIONS)
self._channel = self._grpc.insecure_channel(self._location, options=CHANNEL_OPTIONS)
else:
crt = open(self._cert, "rb").read()
creds = grpc.ssl_channel_credentials(crt)
creds = self._grpc.ssl_channel_credentials(crt)
if self._cert_domain is not None:
CHANNEL_OPTIONS.append(('grpc.ssl_target_name_override', self._cert_domain))
self._channel = grpc.secure_channel(
self._channel = self._grpc.secure_channel(
self._location, credentials=creds, options=CHANNEL_OPTIONS
)
self._stub = pb2_grpc.OpenapiStub(self._channel)
self._stub = self._pb2_grpc.OpenapiStub(self._channel)
return self._stub

def _serialize_payload(self, payload):
Expand Down Expand Up @@ -459,7 +467,7 @@ def _client_stream(self, stub, data):
chunk = data[i:len(data)]
else:
chunk = data[i:i+self._chunk_size]
data_chunks.append(pb2.Data(datum=chunk, chunk_size=self._chunk_size))
data_chunks.append(self._pb2.Data(datum=chunk, chunk_size=self._chunk_size))
# print(chunk_list, len(chunk_list))
reqs = iter(data_chunks)
return reqs
Expand Down Expand Up @@ -513,7 +521,7 @@ def close(self):
with open(self._api_filename, "a") as self._fid:
self._write()
self._write()
self._write(0, class_code)
self._write(0, class_code.format(pb_imports=pb_imports))
for rpc_method in rpc_methods:
self._write()
status_msg = ""
Expand All @@ -540,7 +548,7 @@ def close(self):
self._write(2, "stub = self._get_stub()")
self._write(
2,
"empty = pb2_grpc.google_dot_protobuf_dot_empty__pb2.Empty()",
"empty = self._pb2_grpc.google_dot_protobuf_dot_empty__pb2.Empty()",
)
(
line_indent,
Expand Down Expand Up @@ -577,9 +585,11 @@ def close(self):
self._write(2, "self.add_warnings('%s')" % status_msg)

if not only_bytes:
self._write(2, "pb_obj = json_format.Parse(")
self._write(2, "pb_obj = self._json_format.Parse(")
self._write(3, "self._serialize_payload(payload),")
self._write(3, "pb2.%s()" % rpc_method.request_class)
self._write(
3, "self._pb2.%s()" % rpc_method.request_class
)
self._write(2, ")")
if (
self._generate_version_api
Expand All @@ -589,7 +599,7 @@ def close(self):

self._write(
2,
"req_obj = pb2.{operation_name}Request({request_property}={obj})".format(
"req_obj = self._pb2.{operation_name}Request({request_property}={obj})".format(
operation_name=rpc_method.operation_name,
request_property=rpc_method.request_property,
obj="payload" if only_bytes else "pb_obj",
Expand Down Expand Up @@ -617,9 +627,9 @@ def close(self):
"res_obj = stub.%s(req_obj, timeout=self._request_timeout)"
% rpc_method.operation_name,
)
self._write(2, "except grpc.RpcError as grpc_error:")
self._write(2, "except self._grpc.RpcError as grpc_error:")
self._write(3, "self._raise_exception(grpc_error)")
self._write(2, "response = json_format.MessageToDict(")
self._write(2, "response = self._json_format.MessageToDict(")
self._write(3, "res_obj, preserving_proto_field_name=True")
self._write(2, ")")
self._write(2, 'log.debug("Response - " + str(response))')
Expand Down Expand Up @@ -652,7 +662,9 @@ def close(self):

if including_default:
self._write(3, "if len(result) == 0:")
self._write(4, "result = json_format.MessageToDict(")
self._write(
4, "result = self._json_format.MessageToDict("
)
self._write(
5, "res_obj.%s," % rpc_method.proto_field_name
)
Expand Down Expand Up @@ -849,6 +861,10 @@ def _write_api_class(self, methods, factories):
self._write(
2, 'transport = kwargs.get("otel_collector_transport")'
)
self._write(
2,
"self._is_grpc_transport = True if kwargs.get('transport') == 'grpc' else False",
)
self._write(2, "self._telemetry = Telemetry(endpoint, transport)")

self._write()
Expand All @@ -875,15 +891,17 @@ def _write_api_class(self, methods, factories):
self._write(2, "# type: (Exception) -> Union[Error, None]")
self._write(2, "if isinstance(error, Error):")
self._write(3, "return error")
self._write(2, "elif isinstance(error, grpc.RpcError):")
self._write(3, "err = self._deserialize_error(error.details())")
self._write(3, "if err is not None:")
self._write(4, "return err")
self._write(3, "err = self.error()")
self._write(3, "err.code = error.code().value[0]")
self._write(3, "err.errors = [error.details()]")
self._write(3, "return err")
self._write(2, "elif isinstance(error, Exception):")
self._write(3, "if self._is_grpc_transport:")
self._write(4, "import grpc")
self._write(4, "if isinstance(error, grpc.RpcError):")
self._write(5, "err = self._deserialize_error(error.details())")
self._write(5, "if err is not None:")
self._write(6, "return err")
self._write(5, "err = self.error()")
self._write(5, "err.code = error.code().value[0]")
self._write(5, "err.errors = [error.details()]")
self._write(5, "return err")
self._write(3, "if len(error.args) != 1:")
self._write(4, "return None")
self._write(3, "if isinstance(error.args[0], Error):")
Expand Down Expand Up @@ -2242,7 +2260,7 @@ def _add_streaming_code(self, rpc_method, return_byte, line_indent):
elif rpc_method.good_response_type:
self._write(
line_indent,
"res_obj = pb2.%s()" % rpc_method.response_class,
"res_obj = self._pb2.%s()" % rpc_method.response_class,
)
self._write(line_indent, "res_obj.ParseFromString(data)")
else:
Expand Down
12 changes: 10 additions & 2 deletions openapiart/tests/test_error_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,15 @@ def create_openapi_artifacts(openapiart_class, sdk=None, file_name=None):


def str_compare(validte_str, entire_str, item):
return validte_str in entire_str and item in entire_str
normalized_entire = (
entire_str.replace("(", "").replace(")", "").replace("'", "")
)
normalized_validate = (
validte_str.replace("(", "").replace(")", "").replace("'", "")
)
return (
normalized_validate in normalized_entire and item in normalized_entire
)


def test_validate_response_default():
Expand Down Expand Up @@ -64,7 +72,7 @@ def test_error_for_missing_required():
file_name="./response/response_missing_required_in_error.yaml",
)
error_value = execinfo.value.args[0]
assert error_msg == error_value
assert str_compare(error_msg, error_value, "Error")


if __name__ == "__main__":
Expand Down
Loading
Loading