-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Cli update #1666
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Cli update #1666
Changes from 2 commits
f764627
055601f
f3f1872
995cfcc
89decfd
4edf52f
cba4e3f
4d49cff
cc72cec
d7f91ca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -16,7 +16,7 @@ | |||||
| from itertools import islice | ||||||
| from logging import getLogger | ||||||
| from pathlib import Path | ||||||
| from typing import Optional, Union | ||||||
| from typing import Any, Callable, Optional, Union | ||||||
|
|
||||||
| from deeppavlov.core.commands.utils import import_packages, parse_config | ||||||
| from deeppavlov.core.common.chainer import Chainer | ||||||
|
|
@@ -28,8 +28,13 @@ | |||||
| log = getLogger(__name__) | ||||||
|
|
||||||
|
|
||||||
| def build_model(config: Union[str, Path, dict], mode: str = 'infer', | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Revert unnecessary style changes here and below. If using black (I guess), we should use it on whole code base.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. UPD: make black line lenght 120, and commit black config |
||||||
| load_trained: bool = False, install: bool = False, download: bool = False) -> Chainer: | ||||||
| def build_model( | ||||||
| config: Union[str, Path, dict], | ||||||
| mode: str = "infer", | ||||||
| load_trained: bool = False, | ||||||
| install: bool = False, | ||||||
| download: bool = False, | ||||||
| ) -> Chainer: | ||||||
| """Build and return the model described in corresponding configuration file.""" | ||||||
| config = parse_config(config) | ||||||
|
|
||||||
|
|
@@ -38,66 +43,103 @@ def build_model(config: Union[str, Path, dict], mode: str = 'infer', | |||||
| if download: | ||||||
| deep_download(config) | ||||||
|
|
||||||
| import_packages(config.get('metadata', {}).get('imports', [])) | ||||||
| import_packages(config.get("metadata", {}).get("imports", [])) | ||||||
|
|
||||||
| model_config = config['chainer'] | ||||||
| model_config = config["chainer"] | ||||||
|
|
||||||
| model = Chainer(model_config['in'], model_config['out'], model_config.get('in_y')) | ||||||
| model = Chainer(model_config["in"], model_config["out"], model_config.get("in_y")) | ||||||
|
|
||||||
| for component_config in model_config['pipe']: | ||||||
| if load_trained and ('fit_on' in component_config or 'in_y' in component_config): | ||||||
| for component_config in model_config["pipe"]: | ||||||
| if load_trained and ( | ||||||
| "fit_on" in component_config or "in_y" in component_config | ||||||
| ): | ||||||
| try: | ||||||
| component_config['load_path'] = component_config['save_path'] | ||||||
| component_config["load_path"] = component_config["save_path"] | ||||||
| except KeyError: | ||||||
| log.warning('No "save_path" parameter for the {} component, so "load_path" will not be renewed' | ||||||
| .format(component_config.get('class_name', component_config.get('ref', 'UNKNOWN')))) | ||||||
| log.warning( | ||||||
| 'No "save_path" parameter for the {} component, so "load_path" will not be renewed'.format( | ||||||
| component_config.get( | ||||||
| "class_name", component_config.get("ref", "UNKNOWN") | ||||||
| ) | ||||||
| ) | ||||||
| ) | ||||||
|
|
||||||
| component = from_params(component_config, mode=mode) | ||||||
|
|
||||||
| if 'id' in component_config: | ||||||
| model._components_dict[component_config['id']] = component | ||||||
| if "id" in component_config: | ||||||
| model._components_dict[component_config["id"]] = component | ||||||
|
|
||||||
| if 'in' in component_config: | ||||||
| c_in = component_config['in'] | ||||||
| c_out = component_config['out'] | ||||||
| in_y = component_config.get('in_y', None) | ||||||
| main = component_config.get('main', False) | ||||||
| if "in" in component_config: | ||||||
| c_in = component_config["in"] | ||||||
| c_out = component_config["out"] | ||||||
| in_y = component_config.get("in_y", None) | ||||||
| main = component_config.get("main", False) | ||||||
| model.append(component, c_in, c_out, in_y, main) | ||||||
|
|
||||||
| return model | ||||||
|
|
||||||
|
|
||||||
| def end_repl_mode(function: Callable[..., Any]) -> Callable[..., Any]: | ||||||
| """Decorator for processing ctrl-c, ctrl-d pressing.""" | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
|
||||||
| def wrapper(*args: Any, **kwargs: Any): | ||||||
| try: | ||||||
| return function(*args, **kwargs) | ||||||
| except (KeyboardInterrupt, EOFError): | ||||||
| print("\nExit repl mode.") | ||||||
| sys.exit(0) | ||||||
|
|
||||||
| return wrapper | ||||||
|
|
||||||
|
|
||||||
| @end_repl_mode | ||||||
| def interact_model(config: Union[str, Path, dict]) -> None: | ||||||
| """Start interaction with the model described in corresponding configuration file.""" | ||||||
| model = build_model(config) | ||||||
|
|
||||||
| print("\nExit repl - type q and press enter, or press ctrl-c, or ctrl-d.") | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Replace repl with something more understandable for regular user. What other app usually write? |
||||||
|
|
||||||
| def input_data(prompt: str) -> tuple[str]: | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||
| """Filter and processing input data.""" | ||||||
| while True: | ||||||
| data: str = input(f"\033[34m\033[107m{prompt}:\033[0m ") | ||||||
| if data == "": | ||||||
| continue | ||||||
| if data.isspace(): | ||||||
| continue | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove continue. We should allow user to remain some arguments blank.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After removing continue, code from input_data could be moved from separated function back |
||||||
| if data.strip() == "q": | ||||||
| print("\nExit repl mode.") | ||||||
| sys.exit(0) | ||||||
|
|
||||||
| return (data,) | ||||||
|
|
||||||
| while True: | ||||||
| args = [] | ||||||
| arguments: list[tuple[str]] = [] | ||||||
|
IgnatovFedor marked this conversation as resolved.
Outdated
|
||||||
| for in_x in model.in_x: | ||||||
| args.append((input('{}::'.format(in_x)),)) | ||||||
| # check for exit command | ||||||
| if args[-1][0] in {'exit', 'stop', 'quit', 'q'}: | ||||||
| return | ||||||
| data = input_data(in_x) | ||||||
| arguments.append(data) | ||||||
|
|
||||||
| pred = model(*args) | ||||||
| pred = model(*arguments) | ||||||
| if len(model.out_params) > 1: | ||||||
| pred = zip(*pred) | ||||||
|
|
||||||
| print('>>', *pred) | ||||||
| print("==> ", *pred) | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why? I would understand |
||||||
|
|
||||||
|
|
||||||
| def predict_on_stream(config: Union[str, Path, dict], | ||||||
| batch_size: Optional[int] = None, | ||||||
| file_path: Optional[str] = None) -> None: | ||||||
| def predict_on_stream( | ||||||
| config: Union[str, Path, dict], | ||||||
| batch_size: Optional[int] = None, | ||||||
| file_path: Optional[str] = None, | ||||||
| ) -> None: | ||||||
| """Make a prediction with the component described in corresponding configuration file.""" | ||||||
|
|
||||||
| batch_size = batch_size or 1 | ||||||
| if file_path is None or file_path == '-': | ||||||
| if file_path is None or file_path == "-": | ||||||
| if sys.stdin.isatty(): | ||||||
| raise RuntimeError('To process data from terminal please use interact mode') | ||||||
| raise RuntimeError("To process data from terminal please use interact mode") | ||||||
| f = sys.stdin | ||||||
| else: | ||||||
| f = open(file_path, encoding='utf8') | ||||||
| f = open(file_path, encoding="utf8") | ||||||
|
|
||||||
| model: Chainer = build_model(config) | ||||||
|
|
||||||
|
|
||||||
Uh oh!
There was an error while loading. Please reload this page.