Source code for jgdv.debugging.malloc_tool

  1#!/usr/bin/env python3
  2"""
  3
  4See EOF for license/metadata/notes as applicable
  5"""
  6
  7# Imports:
  8from __future__ import annotations
  9
 10# ##-- stdlib imports
 11import datetime
 12import enum
 13import fnmatch
 14import functools as ftz
 15import itertools as itz
 16import linecache
 17import logging as logmod
 18import pathlib as pl
 19import re
 20import time
 21import tracemalloc
 22import weakref
 23from uuid import UUID, uuid1
 24
 25# ##-- end stdlib imports
 26
 27import stackprinter
 28import traceback
 29
 30# ##-- types
 31# isort: off
 32# General
 33import abc
 34import collections.abc
 35import typing
 36import types
 37from typing import cast, assert_type, assert_never
 38from typing import Generic, NewType, Never
 39from typing import no_type_check, final, override, overload
 40from typing import Concatenate as Cons
 41# Protocols and Interfaces:
 42from typing import Protocol, runtime_checkable
 43if typing.TYPE_CHECKING:
 44    from typing import Final, ClassVar, Any, Self
 45    from typing import Literal, LiteralString
 46    from typing import TypeGuard
 47    from collections.abc import Iterable, Iterator, Callable, Generator
 48    from collections.abc import Sequence, Mapping, MutableMapping, Hashable
 49
 50    from jgdv import Maybe, Traceback
 51
 52# isort: on
 53# ##-- end types
 54
 55##-- logging
 56logging = logmod.getLogger(__name__)
 57##-- end logging
 58
 59STAT_FORMS       : Final[tuple[str, ...]] = ("traceback", "filename", "lineno")
 60INIT_SNAP_NAME   : Final[str] = "_init_"
 61FINAL_SNAP_NAME  : Final[str] = "_final_"
 62
[docs] 63def must_be_started[**I, O](fn:Callable[Cons[MallocTool, I],O]) -> Callable[Cons[MallocTool, I], O]: 64 return fn 65 66 @ftz.wraps 67 def _check(self:MallocTool, *args:I.args, **kwargs:I.kwargs) -> O: 68 assert(self.started) 69 return fn(self, *args, **kwargs) 70 71 return _check
72 73##--| 74
[docs] 75class MallocTool: 76 r""" see `tracemalloc <https://docs.python.org/3/library/tracemalloc.html>`_ 77 in the stdlib. eg: 78 79 :: 80 81 with MallocTool(frame_count=2) as dm: 82 dm.whitelist(__file__) 83 dm.blacklist("\*tracemalloc.py", all_frames=False) 84 val = 2 85 dm.snapshot("simple") 86 vals = [random.random() for x in range(1000)] 87 dm.current() 88 dm.snapshot("list") 89 vals = None 90 dm.current() 91 dm.snapshot("cleared") 92 93 dm.compare("simple", "list") 94 dm.compare("list", "cleared") 95 dm.compare("list", "simple") 96 dm.inspect("list") 97 98 """ 99 frame_count : int 100 started : bool 101 snapshots : list[tracemalloc.Snapshot] 102 named_snapshots : dict[str, tracemalloc.Snapshot] 103 filters : list[tracemalloc.Filter] 104 105 _logger : logmod.Logger 106 _curr_mem_msg : str 107 _allocation_loc_msg : str 108 _inspect_msg : str 109 _cmp_enter_msg : str 110 _change_msg : str 111 _diff_msg : str 112 _stat_line_msg : str 113 _enter_msg : str 114 _exit_msg : str 115 _take_snap_msg : str 116 117 def __init__(self, *, frame_count:int=5, logger:Maybe[logmod.Logger]=None) -> None: 118 assert(0 < frame_count) 119 self._logger = logger or logging 120 self.frame_count = frame_count 121 self.started = False 122 self.snapshots = [] 123 self.named_snapshots = {} 124 self.filters = [] 125 self.blacklist("*tracemalloc.py", all_frames=False) 126 self.blacklist(__file__) 127 ##--| Messages: 128 self._enter_msg = "[TraceMalloc]: --> Entering, tracking %s frames" 129 self._exit_msg = "[TraceMalloc]: <-- Exited, with %s snapshots" 130 self._take_snap_msg = "[TraceMalloc]: Taking Snapshot: %-15s (Current: %-10s, Peak: %s)" 131 self._curr_mem_msg = "[TraceMalloc]: Memory: (Current: %-10s, Peak: %s)" 132 self._allocation_loc_msg = "[TraceMalloc]: Value Allocated At: %s" 133 self._inspect_msg = "[TraceMalloc]: ---- Inspecting: %s ----" 134 self._cmp_enter_msg = "[TraceMalloc]: ---- Comparing (%s): %s -> %s. Objects:%s ----" 135 self._gen_exit_msg = "[TraceMalloc]: -- %s --" 136 self._diff_msg = "[TraceMalloc]: -- (obj:%s) delta: %s, %s blocks --" 137 self._stat_line_msg = "[TraceMalloc]: (obj:%s, frame:%3s) : %-50s (%s:%s)" 138 self._stat_line_no_frames_msg = "[TraceMalloc]: (obj:%s) %-15s : %-50s (%s:%s)" 139 140 def __enter__(self) -> Self: 141 """ Ctx handler to start tracing object allocations """ 142 self._logger.info(self._enter_msg, self.frame_count) 143 tracemalloc.start(self.frame_count) 144 self.started = True 145 self.snapshot(INIT_SNAP_NAME) 146 return self 147 148 @must_be_started 149 def __exit__(self, etype:Maybe[type], err:Maybe[Exception], tb:Maybe[Traceback]) -> bool: # type: ignore[exit-return] 150 """ Stop tracing allocations """ 151 self.snapshot(FINAL_SNAP_NAME) 152 tracemalloc.stop() 153 self.started = False 154 self._logger.info(self._exit_msg, len(self.snapshots)) 155 return False 156 157 ##--| Setup 158
[docs] 159 def whitelist(self, file_pat:str, *, lineno:Maybe[int]=None, all_frames:bool=True) -> Self: 160 """ Add a filter to whitelist a file pattern """ 161 self.filters.append( 162 tracemalloc.Filter(True, # noqa: FBT003 163 filename_pattern=file_pat, 164 lineno=lineno, 165 all_frames=all_frames), 166 ) 167 return self
168
[docs] 169 def blacklist(self, file_pat:str, *, lineno:Maybe[int]=None, all_frames:bool=True) -> Self: 170 """ Blacklist a file pattern """ 171 self.filters.append( 172 tracemalloc.Filter(False, # noqa: FBT003 173 filename_pattern=file_pat, 174 lineno=lineno, 175 all_frames=all_frames), 176 ) 177 return self
178 179 ##--| Control 180
[docs] 181 @must_be_started 182 def snapshot(self, name:Maybe[str]=None) -> None: 183 """ Take a snapshot of the current memory state """ 184 traced : Maybe[tuple] 185 ##--| 186 traced = tracemalloc.get_traced_memory() 187 logging.info(self._take_snap_msg, name, self._human(traced[0]), self._human(traced[1])) 188 traced = None 189 snap = tracemalloc.take_snapshot() 190 self.snapshots.append(snap) 191 if name and name not in self.named_snapshots: 192 self.named_snapshots[name] = snap 193 194 tracemalloc.clear_traces()
195 196 ##--| Report 197
[docs] 198 @must_be_started 199 def current(self, val:Maybe[object]=None) -> None: 200 """ Print a brief report about the current memory state """ 201 traced = tracemalloc.get_traced_memory() 202 self._logger.info(self._curr_mem_msg, self._human(traced[0]), self._human(traced[1])) 203 if val: 204 self._logger.info(self._allocation_loc_msg, tracemalloc.get_object_traceback(val))
205
[docs] 206 def inspect(self, val:int|str, *, form:str="traceback", filter:bool=True, fullpath:bool=False) -> None: # noqa: A002 207 """ Inspect a single snapshot of the memory state """ 208 assert(form in STAT_FORMS) 209 self._logger.info(self._inspect_msg, val) 210 snap = self._get_snapshot(val, filter=filter) 211 for stat in snap.statistics(form): 212 self._print_obj_stat_frames(stat, fullpath=fullpath) 213 else: 214 self._logger.info(self._gen_exit_msg, "inspect")
215
[docs] 216 def compare(self, val1:int|str, val2:int|str, *, form:str="traceback", filter:bool=True, fullpath:bool=False, count:int=10) -> None: # noqa: A002, ARG002, PLR0913 217 """ Compare two snapshots, 218 with control over filtering, output formatting, 219 and the number of objects to report about 220 221 """ 222 differences : list[tracemalloc.StatisticDiff] 223 assert(form in STAT_FORMS) 224 snap1 = self._get_snapshot(val1, filter=filter) 225 snap2 = self._get_snapshot(val2, filter=filter) 226 227 if 1 < self.frame_count: 228 printer = self._print_diff_frames 229 else: 230 printer = self._print_diff_noframes 231 232 differences = snap2.compare_to(snap1, form) 233 # TODO differences = self._get_top_n(differences, count=count) 234 diff_count = len(differences) 235 self._logger.info(self._cmp_enter_msg, form, val1, val2, diff_count) 236 for i, stat in enumerate(differences): 237 printer(stat, idx=i, fullpath=fullpath) 238 else: 239 self._logger.info(self._gen_exit_msg, f"Compare ({diff_count}/{diff_count})")
240 241 ##--| utils 242
[docs] 243 def _print_diff_noframes(self, stat:tracemalloc.StatisticDiff, *, idx:Maybe[int]=None, fullpath:bool=False) -> None: 244 """ Print a diff without showing the stacktrace """ 245 assert(isinstance(stat, tracemalloc.StatisticDiff)) 246 tb = stat.traceback 247 frame = tb[-1] 248 size_change = self._human(stat.size, sign=True) 249 if fullpath: 250 path = frame.filename 251 else: 252 path = pl.Path(frame.filename).name 253 self._logger.info(self._stat_line_no_frames_msg, 254 idx, 255 size_change, 256 linecache.getline(frame.filename, frame.lineno).strip(), 257 path, 258 frame.lineno, 259 )
260
[docs] 261 def _print_diff_frames(self, stat:tracemalloc.StatisticDiff, *, idx:Maybe[int]=None, fullpath:bool=False) -> None: 262 """ Print a diff, with stacktrace """ 263 assert(isinstance(stat, tracemalloc.StatisticDiff)) 264 self._logger.info(self._diff_msg, idx, self._human(stat.size_diff, sign=True), stat.count_diff) 265 self._print_obj_stat_frames(stat, idx=idx, fullpath=fullpath)
266
[docs] 267 def _print_obj_stat_frames(self, stat:tracemalloc.Statistic|tracemalloc.StatisticDiff, *, idx:Maybe[int]=None, fullpath:bool=False) -> None: 268 """ Print a stacktrace for a a given object diff """ 269 assert(isinstance(stat, tracemalloc.Statistic|tracemalloc.StatisticDiff)) 270 tb = stat.traceback 271 total = len(tb)-1 272 for i, frame in enumerate(tb): 273 if fullpath: 274 path = frame.filename 275 else: 276 path = pl.Path(frame.filename).name 277 self._logger.info(self._stat_line_msg, 278 idx, 279 i-total, 280 linecache.getline(frame.filename, frame.lineno).strip(), 281 path, 282 frame.lineno, 283 ) 284 else: 285 pass
286
[docs] 287 def _human(self, num:int, *, sign:bool=False) -> str: 288 """ Format a sized number in a human readable way. optionally with a sign prefix """ 289 return cast("str", tracemalloc._format_size(num, sign)) # type: ignore[attr-defined]
290
[docs] 291 def _get_snapshot(self, val:int|str, *, filter:bool=True) -> tracemalloc.Snapshot: # noqa: A002 292 """ Retrieve a snapshot, 293 with control of whether it is filtered or not 294 """ 295 match val: 296 case int() if 0 <= val < len(self.snapshots): 297 snap = self.snapshots[val] 298 case int() if val < 0: 299 snap = self.snapshots[val] 300 case str() if val in self.named_snapshots: 301 snap = self.named_snapshots[val] 302 case _: 303 raise TypeError(val) 304 305 if filter: 306 return snap.filter_traces(self.filters) 307 308 return snap
309
[docs] 310 def _check_file_pat(self, file_pat:str, file_name:str) -> bool: 311 return fnmatch.fnmatch(file_name, file_pat)
312
[docs] 313 def _get_top_n(self, stats:list[tracemalloc.StatisticDiff], count:int=10) -> list[tracemalloc.StatisticDiff]: 314 r""" Get the top {count} sized objects of a difference """ 315 raise NotImplementedError()