Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions tb_plugin/torch_tb_profiler/io/azureblob.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(self):
raise ImportError('azure-storage-blob must be installed for Azure Blob support.')
self.connection_string = os.environ.get('AZURE_STORAGE_CONNECTION_STRING', None)

# pyrefly: ignore [bad-param-name-override]
def exists(self, dirname):
"""Returns whether the path is a directory or not."""
basename, parts = self.split_blob_path(dirname)
Expand All @@ -34,6 +35,7 @@ def exists(self, dirname):
else:
return basename == parts[0]

# pyrefly: ignore [bad-param-name-override]
def read(self, filename, binary_mode=False, size=None, continue_from=None):
"""Reads contents of a file to a string."""
logger.info('azure blob: starting reading file %s' % filename)
Expand Down Expand Up @@ -125,6 +127,7 @@ def listdir(self, dirname):
items.append(item)
return items

# pyrefly: ignore [bad-param-name-override]
def makedirs(self, dirname):
    """Do nothing: there is no directory to create ahead of time.

    Uploading a blob creates its path automatically, so ``dirname`` is
    accepted for interface compatibility and otherwise ignored.
    """
    return None
Expand All @@ -145,6 +148,7 @@ def walk(self, top, topdown=True, onerror=None):
for blob in blobs:
dirname, basename = self.split(blob.name)
dirname = 'https://{}/{}/{}'.format(account, container, dirname)
# pyrefly: ignore [missing-attribute]
results.setdefault(dirname, []).append(basename)
for key, value in results.items():
yield key, None, value
Expand Down
2 changes: 2 additions & 0 deletions tb_plugin/torch_tb_profiler/io/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

class Cache:
def __init__(self, cache_dir=None):
# pyrefly: ignore [missing-attribute]
self._lock = mp.Lock()
# pyrefly: ignore [missing-attribute]
self._manager = mp.Manager()
self._cache_dict = self._manager.dict()
self._cache_dir = cache_dir
Expand Down
8 changes: 6 additions & 2 deletions tb_plugin/torch_tb_profiler/io/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@
logger = utils.get_logger()

try:
# pyre-fixme[21]: Could not find module `boto3`.
import boto3
# pyre-fixme[21]: Could not find module `botocore.exceptions`.
import botocore.exceptions

S3_ENABLED = True
Expand Down Expand Up @@ -101,6 +99,7 @@ def __init__(self):
def exists(self, filename):
    """Return True when ``filename`` names an existing local path."""
    path_is_present = os.path.exists(filename)
    return path_is_present

# pyrefly: ignore [bad-param-name-override]
def read(self, filename, binary_mode=False, size=None, continue_from=None):
mode = "rb" if binary_mode else "r"
encoding = None if binary_mode else "utf8"
Expand All @@ -113,6 +112,7 @@ def read(self, filename, binary_mode=False, size=None, continue_from=None):
with open(filename, mode, encoding=encoding) as f:
if offset is not None:
f.seek(offset)
# pyrefly: ignore [bad-argument-type]
data = f.read(size)
# The new offset may not be `offset + len(data)`, due to decoding
# and newline translation.
Expand Down Expand Up @@ -210,6 +210,7 @@ def exists(self, filename):
return True
return False

# pyrefly: ignore [bad-param-name-override]
def read(self, filename, binary_mode=False, size=None, continue_from=None):
"""Reads contents of a file to a string."""
s3 = boto3.resource("s3", endpoint_url=self._s3_endpoint)
Expand Down Expand Up @@ -333,6 +334,7 @@ def listdir(self, dirname):
keys.append(key)
return keys

# pyrefly: ignore [bad-param-name-override]
def makedirs(self, dirname):
"""Creates a directory and all parent/intermediate directories."""
if not self.exists(dirname):
Expand Down Expand Up @@ -400,8 +402,10 @@ def __iter__(self):

def _read_buffer_to_offset(self, new_buff_offset):
    """Consume buffered data from the current offset up to ``new_buff_offset``.

    The end position is clamped to the length of ``self.buff``.
    ``self.buff_offset`` is advanced by the amount consumed and the
    consumed slice of the buffer is returned.
    """
    start = self.buff_offset
    # Never read past the end of what is currently buffered.
    # pyrefly: ignore [bad-argument-type]
    stop = min(len(self.buff), new_buff_offset)
    self.buff_offset += stop - start
    # pyrefly: ignore [unsupported-operation]
    return self.buff[start:stop]

def read(self, n=None):
Expand Down
4 changes: 4 additions & 0 deletions tb_plugin/torch_tb_profiler/io/gs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ def __init__(self):
if not storage:
raise ImportError('google-cloud-storage must be installed for Google Cloud Blob support.')

# pyrefly: ignore [bad-param-name-override]
def exists(self, dirname):
    """Return whether a blob exists at the path named by ``dirname``.

    The path is split into bucket name and blob path, then the existence
    check is delegated to the blob object obtained from the client.
    """
    bucket_name, blob_path = self.bucket_and_path(dirname)
    gcs_bucket = self.create_google_cloud_client().bucket(bucket_name)
    target_blob = gcs_bucket.blob(blob_path)
    return target_blob.exists()

# pyrefly: ignore [bad-param-name-override]
def read(self, filename, binary_mode=False, size=None, continue_from=None):
    """Reading is not implemented for this filesystem.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()

Expand Down Expand Up @@ -67,6 +69,7 @@ def listdir(self, dirname):
items.append(item)
return items

# pyrefly: ignore [bad-param-name-override]
def makedirs(self, dirname):
    """Do nothing: there is no directory to create ahead of time.

    Uploading a blob creates its path automatically, so ``dirname`` is
    accepted for interface compatibility and otherwise ignored.
    """
    return None
Expand All @@ -87,6 +90,7 @@ def walk(self, top, topdown=True, onerror=None):
for blob in blobs:
dirname, basename = self.split(blob.name)
dirname = 'gs://{}/{}'.format(bucket_name, dirname)
# pyrefly: ignore [missing-attribute]
results.setdefault(dirname, []).append(basename)
for key, value in results.items():
yield key, None, value
Expand Down
4 changes: 2 additions & 2 deletions tb_plugin/torch_tb_profiler/io/hdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ class HadoopFileSystem(RemotePath, BaseFileSystem):
def __init__(self) -> None:
# Only delegates to the next initializer in the MRO; nothing else is set up here.
super().__init__()

# pyre-fixme[11]: Annotation `HadoopFileSystem` is not defined as a type.
def get_fs(self) -> arrow.HadoopFileSystem:
    """Return the filesystem object fsspec provides for the ``hdfs`` protocol."""
    hdfs_filesystem = fsspec.filesystem("hdfs")
    return hdfs_filesystem

def exists(self, filename):
    """Return the result of the underlying filesystem's ``exists`` check."""
    hdfs = self.get_fs()
    return hdfs.exists(filename)

# pyrefly: ignore [bad-param-name-override]
def read(self, filename, binary_mode=False, size=None, continue_from=None):
fs = self.get_fs()
mode = "rb" if binary_mode else "r"
Expand Down Expand Up @@ -69,4 +69,4 @@ def support_append(self):
return False

def download_file(self, file_to_download, file_to_save):
    """Download ``file_to_download`` to the path ``file_to_save``.

    Delegates to the ``download`` method of the filesystem returned by
    ``get_fs()``, always passing ``recursive=True``, and returns whatever
    that call returns.
    """
    # The statement used to appear twice; the second copy was unreachable.
    return self.get_fs().download(file_to_download, file_to_save, recursive=True)
29 changes: 27 additions & 2 deletions tb_plugin/torch_tb_profiler/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def wrapper(*args, **kwargs):
exceptions.HTTPException.get_headers = decorate_headers(exceptions.HTTPException.get_headers)


# pyre-fixme[11]: Annotation `TBPlugin` is not defined as a type.
class TorchProfilerPlugin(base_plugin.TBPlugin):
"""TensorBoard plugin for Torch Profiler."""

Expand All @@ -52,7 +51,6 @@ def __init__(self, context: base_plugin.TBContext):
Args:
context: A base_plugin.TBContext instance.
"""
# pyre-fixme[19]: Expected 0 positional arguments.
super(TorchProfilerPlugin, self).__init__(context)
if not context.logdir and context.flags.logdir_spec:
dirs = context.flags.logdir_spec.split(',')
Expand Down Expand Up @@ -133,6 +131,7 @@ def get_plugin_apps(self):
def frontend_metadata(self):
    """Build the frontend metadata: ES module at ``/index.js``, reload disabled."""
    metadata = base_plugin.FrontendMetadata(
        es_module_path='/index.js',
        disable_reload=True,
    )
    return metadata

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def runs_route(self, request: werkzeug.Request):
with self._runs_lock:
Expand All @@ -144,6 +143,7 @@ def runs_route(self, request: werkzeug.Request):
}
return self.respond_as_json(data)

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def views_route(self, request: werkzeug.Request):
name = request.args.get('run')
Expand All @@ -152,6 +152,7 @@ def views_route(self, request: werkzeug.Request):
views_list = [view.display_name for view in run.views]
return self.respond_as_json(views_list)

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def workers_route(self, request: werkzeug.Request):
name = request.args.get('run')
Expand All @@ -160,6 +161,7 @@ def workers_route(self, request: werkzeug.Request):
run = self._get_run(name)
return self.respond_as_json(run.get_workers(view))

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def spans_route(self, request: werkzeug.Request):
name = request.args.get('run')
Expand All @@ -168,6 +170,7 @@ def spans_route(self, request: werkzeug.Request):
run = self._get_run(name)
return self.respond_as_json(run.get_spans(worker))

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def overview_route(self, request: werkzeug.Request):
profile = self._get_profile_for_request(request)
Expand All @@ -176,15 +179,18 @@ def overview_route(self, request: werkzeug.Request):
data = profile.overview
is_gpu_used = profile.has_runtime or profile.has_kernel or profile.has_memcpy_or_memset
normal_workers = [worker for worker in run.workers if worker != 'All']
# pyrefly: ignore [unsupported-operation]
data['environments'] = [{'title': 'Number of Worker(s)', 'value': str(len(normal_workers))},
{'title': 'Device Type', 'value': 'GPU' if is_gpu_used else 'CPU'}]
if profile.gpu_summary and profile.gpu_tooltip:
# pyrefly: ignore [unsupported-operation]
data['gpu_metrics'] = {'title': 'GPU Summary',
'data': profile.gpu_summary,
'tooltip': profile.gpu_tooltip}

return self.respond_as_json(data)

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def operation_pie_route(self, request: werkzeug.Request):
profile = self._get_profile_for_request(request)
Expand All @@ -195,6 +201,7 @@ def operation_pie_route(self, request: werkzeug.Request):
else:
return self.respond_as_json(profile.operation_pie_by_name)

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def operation_table_route(self, request: werkzeug.Request):
profile = self._get_profile_for_request(request)
Expand All @@ -205,6 +212,7 @@ def operation_table_route(self, request: werkzeug.Request):
else:
return self.respond_as_json(profile.operation_table_by_name)

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def operation_stack_route(self, request: werkzeug.Request):
profile = self._get_profile_for_request(request)
Expand All @@ -218,12 +226,14 @@ def operation_stack_route(self, request: werkzeug.Request):
else:
return self.respond_as_json(profile.operation_stack_by_name[str(op_name)])

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def kernel_pie_route(self, request: werkzeug.Request):
    """Respond with the ``kernel_pie`` data of the requested profile as JSON."""
    requested_profile = self._get_profile_for_request(request)
    payload = requested_profile.kernel_pie
    return self.respond_as_json(payload)

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def kernel_table_route(self, request: werkzeug.Request):
profile = self._get_profile_for_request(request)
Expand All @@ -234,12 +244,14 @@ def kernel_table_route(self, request: werkzeug.Request):
else:
return self.respond_as_json(profile.kernel_op_table)

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def kernel_tc_route(self, request: werkzeug.Request):
    """Respond with the ``tc_pie`` data of the requested profile as JSON."""
    requested_profile = self._get_profile_for_request(request)
    payload = requested_profile.tc_pie
    return self.respond_as_json(payload)

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def trace_route(self, request: werkzeug.Request):
profile = self._get_profile_for_request(request)
Expand Down Expand Up @@ -270,26 +282,31 @@ def trace_route(self, request: werkzeug.Request):
headers.extend(TorchProfilerPlugin.headers)
return werkzeug.Response(raw_data, content_type=TorchProfilerPlugin.CONTENT_TYPE, headers=headers)

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def dist_gpu_info_route(self, request: werkzeug.Request):
    """Respond with the ``gpu_info`` of the requested distributed profile as JSON."""
    distributed_profile = self._get_distributed_profile_for_request(request)
    payload = distributed_profile.gpu_info
    return self.respond_as_json(payload)

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def comm_overlap_route(self, request: werkzeug.Request):
    """Respond with ``steps_to_overlap`` of the requested distributed profile as JSON."""
    distributed_profile = self._get_distributed_profile_for_request(request)
    payload = distributed_profile.steps_to_overlap
    return self.respond_as_json(payload)

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def comm_wait_route(self, request: werkzeug.Request):
    """Respond with ``steps_to_wait`` of the requested distributed profile as JSON."""
    distributed_profile = self._get_distributed_profile_for_request(request)
    payload = distributed_profile.steps_to_wait
    return self.respond_as_json(payload)

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def comm_ops_route(self, request: werkzeug.Request):
    """Respond with ``comm_ops`` of the requested distributed profile as JSON."""
    distributed_profile = self._get_distributed_profile_for_request(request)
    payload = distributed_profile.comm_ops
    return self.respond_as_json(payload)

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def memory_route(self, request: werkzeug.Request):
profile = self._get_profile_for_request(request)
Expand All @@ -304,6 +321,7 @@ def memory_route(self, request: werkzeug.Request):
return self.respond_as_json(
profile.get_memory_stats(start_ts=start_ts, end_ts=end_ts, memory_metric=memory_metric), True)

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def memory_curve_route(self, request: werkzeug.Request):
profile = self._get_profile_for_request(request)
Expand All @@ -312,6 +330,7 @@ def memory_curve_route(self, request: werkzeug.Request):
return self.respond_as_json(
profile.get_memory_curve(time_metric=time_metric, memory_metric=memory_metric), True)

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def memory_events_route(self, request: werkzeug.Request):
profile = self._get_profile_for_request(request)
Expand All @@ -328,6 +347,7 @@ def memory_events_route(self, request: werkzeug.Request):
profile.get_memory_events(start_ts, end_ts, time_metric=time_metric,
memory_metric=memory_metric), True)

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def module_route(self, request: werkzeug.Request):
profile = self._get_profile_for_request(request)
Expand All @@ -340,19 +360,22 @@ def module_route(self, request: werkzeug.Request):
span = request.args.get('span')
raise exceptions.NotFound('could not find the run for %s/%s/%s' % (name, worker, span))

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def op_tree_route(self, request: werkzeug.Request):
    """Respond with the operator tree of the requested profile as JSON."""
    operator_tree = self._get_profile_for_request(request).get_operator_tree()
    # NOTE(review): the second positional argument (True) is forwarded as-is;
    # its meaning is defined by respond_as_json, which is not visible here.
    return self.respond_as_json(operator_tree, True)

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def diff_run_route(self, request: werkzeug.Request):
    """Respond with the diff tree summary of the two runs named by the request."""
    baseline_run, experiment_run = self.get_diff_runs(request)
    summary = self.get_diff_status(baseline_run, experiment_run).get_diff_tree_summary()
    # NOTE(review): the second positional argument (True) is forwarded as-is;
    # its meaning is defined by respond_as_json, which is not visible here.
    return self.respond_as_json(summary, True)

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def diff_run_node_route(self, request: werkzeug.Request):
base, exp = self.get_diff_runs(request)
Expand All @@ -364,6 +387,7 @@ def diff_run_node_route(self, request: werkzeug.Request):
content = diff_stat.get_diff_node_summary(path)
return self.respond_as_json(content, True)

# pyrefly: ignore [bad-argument-type]
@wrappers.Request.application
def static_file_route(self, request: werkzeug.Request):
filename = os.path.basename(request.path)
Expand Down Expand Up @@ -511,6 +535,7 @@ def _load_run(self, run_dir):
logger.info('Run %s loaded', name)
self._queue.put(run)
except Exception as ex:
# Arguments must follow placeholder order: the run name fills "run %s" and the
# exception fills "Exception=%s" (they were previously passed the other way round).
# pyrefly: ignore [unbound-name]
logger.warning('Failed to load run %s. Exception=%s', name, ex, exc_info=True)

t = threading.current_thread()
Expand Down
3 changes: 3 additions & 0 deletions tb_plugin/torch_tb_profiler/profiler/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def analyze_communication_nodes(comm_node_list: List[CommunicationNode])\
total_size = 1
for size in comm_node.input_shape[i]:
total_size *= size
# pyrefly: ignore [unsupported-operation]
total_comm_stats[comm_node.name][1] += total_size * bytes_one_value
total_comm_stats[comm_node.name][2].extend(comm_node.kernel_ranges)
total_comm_stats[comm_node.name][3].extend(comm_node.real_time_ranges)
Expand All @@ -100,7 +101,9 @@ def analyze_communication_nodes(comm_node_list: List[CommunicationNode])\
]

for _, stats in total_comm_stats.items():
# pyrefly: ignore [unsupported-operation]
stats[2] = get_ranges_sum(merge_ranges(stats[2]))
# pyrefly: ignore [unsupported-operation]
stats[3] = get_ranges_sum(merge_ranges(stats[3]))

# pyre-fixme[7]: Expected `Tuple[Dict[str, Tuple[int, int]], Dict[str,
Expand Down
Loading
Loading