Source code for jgdv.debugging.malloc_tool
1#!/usr/bin/env python3
2"""
3
4See EOF for license/metadata/notes as applicable
5"""
6
7# Imports:
8from __future__ import annotations
9
10# ##-- stdlib imports
11import datetime
12import enum
13import fnmatch
14import functools as ftz
15import itertools as itz
16import linecache
17import logging as logmod
18import pathlib as pl
19import re
20import time
21import tracemalloc
22import weakref
23from uuid import UUID, uuid1
24
25# ##-- end stdlib imports
26
27import stackprinter
28import traceback
29
30# ##-- types
31# isort: off
32# General
33import abc
34import collections.abc
35import typing
36import types
37from typing import cast, assert_type, assert_never
38from typing import Generic, NewType, Never
39from typing import no_type_check, final, override, overload
40from typing import Concatenate as Cons
41# Protocols and Interfaces:
42from typing import Protocol, runtime_checkable
43if typing.TYPE_CHECKING:
44 from typing import Final, ClassVar, Any, Self
45 from typing import Literal, LiteralString
46 from typing import TypeGuard
47 from collections.abc import Iterable, Iterator, Callable, Generator
48 from collections.abc import Sequence, Mapping, MutableMapping, Hashable
49
50 from jgdv import Maybe, Traceback
51
52# isort: on
53# ##-- end types
54
##-- logging
logging = logmod.getLogger(__name__)
##-- end logging

# The `key_type` groupings tracemalloc accepts for Snapshot.statistics / Snapshot.compare_to:
STAT_FORMS : Final[tuple[str, ...]] = (
    "traceback",
    "filename",
    "lineno",
)
# Names given to the snapshots MallocTool takes on entering / exiting its context:
INIT_SNAP_NAME  : Final[str] = "_init_"
FINAL_SNAP_NAME : Final[str] = "_final_"
62
[docs]
63def must_be_started[**I, O](fn:Callable[Cons[MallocTool, I],O]) -> Callable[Cons[MallocTool, I], O]:
64 return fn
65
66 @ftz.wraps
67 def _check(self:MallocTool, *args:I.args, **kwargs:I.kwargs) -> O:
68 assert(self.started)
69 return fn(self, *args, **kwargs)
70
71 return _check
72
73##--|
74
[docs]
75class MallocTool:
76 r""" see `tracemalloc <https://docs.python.org/3/library/tracemalloc.html>`_
77 in the stdlib. eg:
78
79 ::
80
81 with MallocTool(frame_count=2) as dm:
82 dm.whitelist(__file__)
83 dm.blacklist("\*tracemalloc.py", all_frames=False)
84 val = 2
85 dm.snapshot("simple")
86 vals = [random.random() for x in range(1000)]
87 dm.current()
88 dm.snapshot("list")
89 vals = None
90 dm.current()
91 dm.snapshot("cleared")
92
93 dm.compare("simple", "list")
94 dm.compare("list", "cleared")
95 dm.compare("list", "simple")
96 dm.inspect("list")
97
98 """
99 frame_count : int
100 started : bool
101 snapshots : list[tracemalloc.Snapshot]
102 named_snapshots : dict[str, tracemalloc.Snapshot]
103 filters : list[tracemalloc.Filter]
104
105 _logger : logmod.Logger
106 _curr_mem_msg : str
107 _allocation_loc_msg : str
108 _inspect_msg : str
109 _cmp_enter_msg : str
110 _change_msg : str
111 _diff_msg : str
112 _stat_line_msg : str
113 _enter_msg : str
114 _exit_msg : str
115 _take_snap_msg : str
116
117 def __init__(self, *, frame_count:int=5, logger:Maybe[logmod.Logger]=None) -> None:
118 assert(0 < frame_count)
119 self._logger = logger or logging
120 self.frame_count = frame_count
121 self.started = False
122 self.snapshots = []
123 self.named_snapshots = {}
124 self.filters = []
125 self.blacklist("*tracemalloc.py", all_frames=False)
126 self.blacklist(__file__)
127 ##--| Messages:
128 self._enter_msg = "[TraceMalloc]: --> Entering, tracking %s frames"
129 self._exit_msg = "[TraceMalloc]: <-- Exited, with %s snapshots"
130 self._take_snap_msg = "[TraceMalloc]: Taking Snapshot: %-15s (Current: %-10s, Peak: %s)"
131 self._curr_mem_msg = "[TraceMalloc]: Memory: (Current: %-10s, Peak: %s)"
132 self._allocation_loc_msg = "[TraceMalloc]: Value Allocated At: %s"
133 self._inspect_msg = "[TraceMalloc]: ---- Inspecting: %s ----"
134 self._cmp_enter_msg = "[TraceMalloc]: ---- Comparing (%s): %s -> %s. Objects:%s ----"
135 self._gen_exit_msg = "[TraceMalloc]: -- %s --"
136 self._diff_msg = "[TraceMalloc]: -- (obj:%s) delta: %s, %s blocks --"
137 self._stat_line_msg = "[TraceMalloc]: (obj:%s, frame:%3s) : %-50s (%s:%s)"
138 self._stat_line_no_frames_msg = "[TraceMalloc]: (obj:%s) %-15s : %-50s (%s:%s)"
139
140 def __enter__(self) -> Self:
141 """ Ctx handler to start tracing object allocations """
142 self._logger.info(self._enter_msg, self.frame_count)
143 tracemalloc.start(self.frame_count)
144 self.started = True
145 self.snapshot(INIT_SNAP_NAME)
146 return self
147
148 @must_be_started
149 def __exit__(self, etype:Maybe[type], err:Maybe[Exception], tb:Maybe[Traceback]) -> bool: # type: ignore[exit-return]
150 """ Stop tracing allocations """
151 self.snapshot(FINAL_SNAP_NAME)
152 tracemalloc.stop()
153 self.started = False
154 self._logger.info(self._exit_msg, len(self.snapshots))
155 return False
156
157 ##--| Setup
158
[docs]
159 def whitelist(self, file_pat:str, *, lineno:Maybe[int]=None, all_frames:bool=True) -> Self:
160 """ Add a filter to whitelist a file pattern """
161 self.filters.append(
162 tracemalloc.Filter(True, # noqa: FBT003
163 filename_pattern=file_pat,
164 lineno=lineno,
165 all_frames=all_frames),
166 )
167 return self
168
[docs]
169 def blacklist(self, file_pat:str, *, lineno:Maybe[int]=None, all_frames:bool=True) -> Self:
170 """ Blacklist a file pattern """
171 self.filters.append(
172 tracemalloc.Filter(False, # noqa: FBT003
173 filename_pattern=file_pat,
174 lineno=lineno,
175 all_frames=all_frames),
176 )
177 return self
178
179 ##--| Control
180
[docs]
181 @must_be_started
182 def snapshot(self, name:Maybe[str]=None) -> None:
183 """ Take a snapshot of the current memory state """
184 traced : Maybe[tuple]
185 ##--|
186 traced = tracemalloc.get_traced_memory()
187 logging.info(self._take_snap_msg, name, self._human(traced[0]), self._human(traced[1]))
188 traced = None
189 snap = tracemalloc.take_snapshot()
190 self.snapshots.append(snap)
191 if name and name not in self.named_snapshots:
192 self.named_snapshots[name] = snap
193
194 tracemalloc.clear_traces()
195
196 ##--| Report
197
[docs]
198 @must_be_started
199 def current(self, val:Maybe[object]=None) -> None:
200 """ Print a brief report about the current memory state """
201 traced = tracemalloc.get_traced_memory()
202 self._logger.info(self._curr_mem_msg, self._human(traced[0]), self._human(traced[1]))
203 if val:
204 self._logger.info(self._allocation_loc_msg, tracemalloc.get_object_traceback(val))
205
[docs]
206 def inspect(self, val:int|str, *, form:str="traceback", filter:bool=True, fullpath:bool=False) -> None: # noqa: A002
207 """ Inspect a single snapshot of the memory state """
208 assert(form in STAT_FORMS)
209 self._logger.info(self._inspect_msg, val)
210 snap = self._get_snapshot(val, filter=filter)
211 for stat in snap.statistics(form):
212 self._print_obj_stat_frames(stat, fullpath=fullpath)
213 else:
214 self._logger.info(self._gen_exit_msg, "inspect")
215
[docs]
216 def compare(self, val1:int|str, val2:int|str, *, form:str="traceback", filter:bool=True, fullpath:bool=False, count:int=10) -> None: # noqa: A002, ARG002, PLR0913
217 """ Compare two snapshots,
218 with control over filtering, output formatting,
219 and the number of objects to report about
220
221 """
222 differences : list[tracemalloc.StatisticDiff]
223 assert(form in STAT_FORMS)
224 snap1 = self._get_snapshot(val1, filter=filter)
225 snap2 = self._get_snapshot(val2, filter=filter)
226
227 if 1 < self.frame_count:
228 printer = self._print_diff_frames
229 else:
230 printer = self._print_diff_noframes
231
232 differences = snap2.compare_to(snap1, form)
233 # TODO differences = self._get_top_n(differences, count=count)
234 diff_count = len(differences)
235 self._logger.info(self._cmp_enter_msg, form, val1, val2, diff_count)
236 for i, stat in enumerate(differences):
237 printer(stat, idx=i, fullpath=fullpath)
238 else:
239 self._logger.info(self._gen_exit_msg, f"Compare ({diff_count}/{diff_count})")
240
241 ##--| utils
242
[docs]
243 def _print_diff_noframes(self, stat:tracemalloc.StatisticDiff, *, idx:Maybe[int]=None, fullpath:bool=False) -> None:
244 """ Print a diff without showing the stacktrace """
245 assert(isinstance(stat, tracemalloc.StatisticDiff))
246 tb = stat.traceback
247 frame = tb[-1]
248 size_change = self._human(stat.size, sign=True)
249 if fullpath:
250 path = frame.filename
251 else:
252 path = pl.Path(frame.filename).name
253 self._logger.info(self._stat_line_no_frames_msg,
254 idx,
255 size_change,
256 linecache.getline(frame.filename, frame.lineno).strip(),
257 path,
258 frame.lineno,
259 )
260
[docs]
261 def _print_diff_frames(self, stat:tracemalloc.StatisticDiff, *, idx:Maybe[int]=None, fullpath:bool=False) -> None:
262 """ Print a diff, with stacktrace """
263 assert(isinstance(stat, tracemalloc.StatisticDiff))
264 self._logger.info(self._diff_msg, idx, self._human(stat.size_diff, sign=True), stat.count_diff)
265 self._print_obj_stat_frames(stat, idx=idx, fullpath=fullpath)
266
[docs]
267 def _print_obj_stat_frames(self, stat:tracemalloc.Statistic|tracemalloc.StatisticDiff, *, idx:Maybe[int]=None, fullpath:bool=False) -> None:
268 """ Print a stacktrace for a a given object diff """
269 assert(isinstance(stat, tracemalloc.Statistic|tracemalloc.StatisticDiff))
270 tb = stat.traceback
271 total = len(tb)-1
272 for i, frame in enumerate(tb):
273 if fullpath:
274 path = frame.filename
275 else:
276 path = pl.Path(frame.filename).name
277 self._logger.info(self._stat_line_msg,
278 idx,
279 i-total,
280 linecache.getline(frame.filename, frame.lineno).strip(),
281 path,
282 frame.lineno,
283 )
284 else:
285 pass
286
[docs]
287 def _human(self, num:int, *, sign:bool=False) -> str:
288 """ Format a sized number in a human readable way. optionally with a sign prefix """
289 return cast("str", tracemalloc._format_size(num, sign)) # type: ignore[attr-defined]
290
[docs]
291 def _get_snapshot(self, val:int|str, *, filter:bool=True) -> tracemalloc.Snapshot: # noqa: A002
292 """ Retrieve a snapshot,
293 with control of whether it is filtered or not
294 """
295 match val:
296 case int() if 0 <= val < len(self.snapshots):
297 snap = self.snapshots[val]
298 case int() if val < 0:
299 snap = self.snapshots[val]
300 case str() if val in self.named_snapshots:
301 snap = self.named_snapshots[val]
302 case _:
303 raise TypeError(val)
304
305 if filter:
306 return snap.filter_traces(self.filters)
307
308 return snap
309
[docs]
310 def _check_file_pat(self, file_pat:str, file_name:str) -> bool:
311 return fnmatch.fnmatch(file_name, file_pat)
312
[docs]
313 def _get_top_n(self, stats:list[tracemalloc.StatisticDiff], count:int=10) -> list[tracemalloc.StatisticDiff]:
314 r""" Get the top {count} sized objects of a difference """
315 raise NotImplementedError()