Coverage for fastapi_docx/exception_finder.py: 93%
266 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-24 20:22 +0000
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-24 20:22 +0000
1import ast
2import importlib
3import inspect
4import textwrap
5import typing
6from collections.abc import Callable, Generator, Iterable
7from itertools import chain
8from types import ModuleType
9from typing import Any, TypeVar
11from fastapi import status # noqa
12from fastapi.routing import APIRoute
13from starlette.exceptions import HTTPException
# Type variable for AST node classes; lets ``find_nodes`` return exactly the
# node type it was asked for.
AstType = TypeVar("AstType", bound=ast.AST)
# Type variable for exception classes (bound to ``Exception``).
ErrType = TypeVar("ErrType", bound=Exception)
def find_nodes(
    tree: ast.Module,
    node_type: type[AstType],
    attr_filter: dict[str, type[ast.AST]] | None = None,
) -> Generator[AstType, None, None]:
    """Lazily yield the nodes of *tree* that are instances of *node_type*.

    Nodes come out in ``ast.walk`` order. If *attr_filter* is non-empty, a
    node is kept only when, for every ``attribute -> expected type`` entry,
    the node has that attribute and its value is an instance of the expected
    type (e.g. ``{"exc": ast.Call}`` keeps ``raise f(...)`` but not
    ``raise err``).
    """
    absent = object()
    # A live view, consulted for every candidate node.
    required = attr_filter.items() if attr_filter else ()

    for candidate in ast.walk(tree):
        if not isinstance(candidate, node_type):
            continue
        satisfied = True
        for attr_name, expected_type in required:
            attr_value = getattr(candidate, attr_name, absent)
            if attr_value is absent or not isinstance(attr_value, expected_type):
                satisfied = False
                break
        if satisfied:
            yield candidate
def is_function_or_coroutine(obj: Any) -> bool:
    """Return True when ``inspect.isfunction`` or ``inspect.iscoroutinefunction`` accepts *obj*."""
    checks = (inspect.isfunction, inspect.iscoroutinefunction)
    return any(check(obj) for check in checks)
42def is_subclass_of_any(klass: type, classes: Iterable[type] | Iterable[str]) -> bool:
43 try:
44 base_names = (
45 [base.__name__ for base in klass.__bases__]
46 if hasattr(klass, "__bases__")
47 else []
48 )
49 base_names.append(type(klass).__name__)
50 if all(isinstance(cls, str) for cls in classes):
51 class_names = classes
52 else:
53 class_names = [
54 cls.__name__ # pyright: ignore
55 for cls in classes
56 if hasattr(cls, "__name__")
57 ]
58 return any(
59 base_name in class_names for base_name in base_names # pyright: ignore
60 )
61 except RuntimeError:
62 return False
def is_callable_instance(obj: object) -> bool:
    """Return True for objects exposing ``__call__`` that are not classes.

    Functions, builtins and instances of classes defining ``__call__``
    qualify; classes themselves (callable, but instances of ``type``) do not.
    """
    exposes_call = hasattr(obj, "__call__")
    return exposes_call and not isinstance(obj, type)
def is_annotated_alias(annot: Any) -> bool:
    """Return True if *annot* is a subscripted ``typing.Annotated[...]`` form.

    Uses the public ``typing.get_origin`` rather than the private
    ``typing._AnnotatedAlias`` class, an implementation detail that may be
    renamed or removed between Python versions (the original ``getattr``
    without a default would then raise ``AttributeError``).
    """
    return typing.get_origin(annot) is typing.Annotated
def get_annotated_dependency(annot: Any) -> Callable | None:
    """Return the dependency declared by an ``Annotated[T, ..., Depends(dep)]`` alias.

    All metadata after ``T`` is scanned. The *last* entry whose class is, or
    derives from, a class named ``Depends`` is used; FastAPI likewise uses the
    last FastAPI-specific annotation, and subclasses such as FastAPI's
    ``Security`` are accepted too. That entry's ``dependency`` attribute is
    returned.

    Returns None when *annot* is not an ``Annotated`` alias, has no such
    metadata, or the chosen marker's ``dependency`` is missing/None
    (e.g. ``Depends()``).
    """
    if not is_annotated_alias(annot):
        return None
    dependency: Callable | None = None
    for metadata in typing.get_args(annot)[1:]:
        if metadata and any(
            klass.__name__ == "Depends" for klass in type(metadata).__mro__
        ):
            # Later markers override earlier ones.
            dependency = getattr(metadata, "dependency", None)
    return dependency
84def create_exc_instance(
85 exc_class: type[Exception], exc_args: list[Any] | None = None
86) -> Exception | None:
87 importlib.import_module(exc_class.__module__)
88 value = exc_class(*exc_args) if exc_args else exc_class()
89 return value if isinstance(value, Exception) else None
def replace_binops(node: ast.AST) -> None:
    """Replace binary addition operations with their string representation.

    Every ``ast.BinOp`` whose operator is ``+`` found below *node* is replaced
    in place by an ``ast.Constant`` reading ``"<left> + <right>"`` (each
    operand rendered with ``ast.unparse``). A replaced node is not searched
    further; other child nodes are searched recursively. *node* itself is
    never replaced.

    Non-AST entries of list fields (e.g. the ``None`` key that ``{**mapping}``
    puts in ``ast.Dict.keys``) are skipped; recursing into them would raise
    ``AttributeError``.
    """

    def _is_addition(candidate: object) -> bool:
        return isinstance(candidate, ast.BinOp) and isinstance(candidate.op, ast.Add)

    def _placeholder(binop: ast.BinOp) -> ast.Constant:
        return ast.Constant(
            value=f"<{ast.unparse(binop.left)}> + <{ast.unparse(binop.right)}>"
        )

    for name, field in ast.iter_fields(node):
        if isinstance(field, list):
            for index, item in enumerate(field):
                if not isinstance(item, ast.AST):
                    continue
                if _is_addition(item):
                    field[index] = _placeholder(item)
                else:
                    replace_binops(item)
        elif isinstance(field, ast.AST):
            if _is_addition(field):
                setattr(node, name, _placeholder(field))
            else:
                replace_binops(field)
def replace_formatted_values(node: ast.AST) -> None:
    """Replace formatted values with their string representation.

    Each ``ast.FormattedValue`` (the ``{expr}`` part of an f-string) below
    *node* is replaced in place by an ``ast.Constant`` holding ``"<EXPR>"``,
    i.e. the unparsed expression upper-cased. Every other child node is
    searched recursively and then passed to ``replace_binops``. *node* itself
    is never replaced.

    Non-AST entries of list fields (e.g. the ``None`` key that ``{**mapping}``
    puts in ``ast.Dict.keys``) are skipped; recursing into them would raise
    ``AttributeError``.
    """
    for name, field in ast.iter_fields(node):
        if isinstance(field, list):
            for i, item in enumerate(field):
                if not isinstance(item, ast.AST):
                    continue
                if isinstance(item, ast.FormattedValue):
                    field[i] = ast.Constant(
                        value=f"<{ast.unparse(item.value).upper()}>"
                    )
                else:
                    replace_formatted_values(item)
                    replace_binops(item)
        elif isinstance(field, ast.AST):
            if isinstance(field, ast.FormattedValue):
                setattr(
                    node,
                    name,
                    ast.Constant(value=f"<{ast.unparse(field.value).upper()}>"),
                )
            else:
                replace_formatted_values(field)
                replace_binops(field)
def eval_ast_exc_instance(
    exc_class: type[Exception], ast_exec_inst: ast.expr
) -> Exception | None:
    """Evaluate a ``raise`` call expression and return the exception it builds.

    ``replace_formatted_values`` first rewrites *ast_exec_inst* in place
    (f-string placeholders and nested additions become literal text), so the
    caller's AST is mutated. The unparsed expression is then evaluated with
    this module's globals and this function's locals. Only names visible
    here resolve (e.g. the ``HTTPException`` and ``status`` imports); any
    other name raises ``NameError``, which is left to the caller.

    Returns the result if it is an ``Exception``, otherwise None.
    """
    # Ensure the module that defines ``exc_class`` is loaded.
    importlib.import_module(exc_class.__module__)
    replace_formatted_values(ast_exec_inst)
    # NOTE(review): ``eval`` runs code lifted from the documented app's own
    # source, so any call inside the raise expression really executes. This
    # is acceptable only because that source is trusted; never pass external
    # input here. No local is bound before this call, so the namespace seen by
    # the evaluated expression is just ``exc_class`` and ``ast_exec_inst``.
    result = eval(ast.unparse(ast_exec_inst))
    if not isinstance(result, Exception):
        return None
    return result
class RouteExcFinder:
    """Statically collect the exceptions a FastAPI route can raise.

    The endpoint's source, and transitively the source of the module-level
    functions it references, is parsed. Every ``raise X(...)`` or
    ``raise owner.X(...)`` whose class name matches ``exceptions_to_find``
    (via ``is_subclass_of_any``) is turned into an exception instance.
    Optionally, ``Depends(...)`` defaults and ``Annotated`` aliases
    (``dependencyClasses``) and method calls on service objects
    (``serviceClasses``) are followed as well.

    Attributes:
        customError: optional extra exception class, matched by name in
            addition to ``HTTPException``.
        dependencyClasses: classes treated as dependency providers.
        serviceClasses: classes treated as service providers.
        exceptions_to_find: class names a raised class must match.
        functions: work queue of functions still to be scanned.
        exceptions: accumulated results. They are NOT reset between calls
            to ``extract_exceptions``; call ``clear`` to start over.
    """

    def __init__(
        self,
        customError: type[ErrType] | None = None,
        dependencyClasses: tuple[type, ...] | None = None,
        serviceClasses: tuple[type, ...] | None = None,
    ):
        self.customError = customError
        self.dependencyClasses = dependencyClasses
        self.serviceClasses = serviceClasses

        # Raised classes are matched by name (see ``is_subclass_of_any``).
        self.exceptions_to_find: tuple[str, ...] = (
            ("HTTPException", self.customError.__name__)
            if self.customError
            else ("HTTPException",)
        )

        self.functions: list[Callable] = []
        self.exceptions: list[HTTPException | ErrType] = []

    @staticmethod
    def _resolve_endpoint(route: Any) -> Any:
        """Return ``route.endpoint`` when present, otherwise *route* itself."""
        return route.endpoint if hasattr(route, "endpoint") else route

    @staticmethod
    def _parse_source(obj: Any) -> ast.Module:
        """Parse the source of *obj*.

        The source is dedented first, so functions or classes defined inside
        another scope (whose source is indented) still parse.
        """
        return ast.parse(textwrap.dedent(inspect.getsource(obj)))

    def extract_exceptions(
        self,
        route: APIRoute,
    ) -> list[HTTPException | ErrType]:
        """Return the exceptions raised by *route* and the code it reaches.

        The endpoint is scanned first, then every module-level function found
        by ``find_functions``, breadth-first. Each function is scanned at most
        once per call, so helpers referenced several times are not reported
        repeatedly and mutually recursive helpers cannot loop forever.
        Dependency and service searches run afterwards when configured.

        Returns ``self.exceptions`` (the accumulated list, not a copy).
        """
        self.functions.append(getattr(route, "endpoint", route))
        # ids of functions already scanned during this call.
        scanned: set[int] = set()
        while self.functions:
            function = self.functions.pop(0)
            if id(function) in scanned:
                continue
            scanned.add(id(function))
            self.functions.extend(self.find_functions(function))
            self.exceptions.extend(self.find_exceptions(function))
        if self.dependencyClasses:
            self.exceptions += self.find_dependency_exceptions(route)
            self.exceptions += self.find_annotated_dependency_exceptions(route)
        if self.serviceClasses:
            self.exceptions += self.find_service_exceptions(route)
        return self.exceptions

    @staticmethod
    def find_functions(route: Callable) -> list[Callable]:
        """Return module-level functions referenced by name in *route*'s source.

        Every ``ast.Name`` in the (unwrapped, dedented) source is looked up in
        the module that defines the function. Plain and coroutine functions
        other than the function itself are returned, once per reference, so
        duplicates are possible.
        """
        func = getattr(route, "endpoint", route)
        tree = ast.parse(textwrap.dedent(inspect.getsource(inspect.unwrap(func))))
        module = importlib.import_module(func.__module__)
        functions: list[Callable] = []
        for node in ast.walk(tree):
            try:
                obj = getattr(module, node.id) if hasattr(node, "id") else None
                if is_function_or_coroutine(obj) and obj is not func:
                    functions.append(obj)
            except (AttributeError, ValueError):
                # Names that are not module attributes (locals, parameters,
                # builtins) are skipped.
                continue
        return functions

    def find_exceptions(
        self,
        callable: APIRoute | Callable | str,
        owner: type | ModuleType | None = None,
    ) -> list[HTTPException]:
        """Return exception instances for the ``raise <call>`` statements in *callable*.

        Args:
            callable: a route, a function, or the name of an attribute of *owner*.
            owner: object the name is looked up on; required when *callable*
                is a string.

        Names in the raise statements are resolved in the module that defines
        the callable (``inspect.getmodule``). Returns [] when that module
        cannot be determined.

        Raises:
            ValueError: *callable* is a string and *owner* is None.
        """
        callable = self._resolve_endpoint(callable)
        if isinstance(callable, str):
            if owner is None:
                raise ValueError("owner must be provided if callable is a string")
            callable = getattr(owner, callable)

        # ``owner`` only locates the callable; names are resolved in the
        # callable's own module, where its source lives.
        module = inspect.getmodule(callable)
        if not module:
            return []

        tree = self._parse_source(inspect.unwrap(callable))
        exceptions = []
        for node in find_nodes(tree, ast.Raise, attr_filter={"exc": ast.Call}):
            if exc_instance := self.create_exc_inst_from_raise_stmt(node, module):
                exceptions.append(exc_instance)
        return exceptions

    @staticmethod
    def _find_call_owner(
        method: ast.Attribute, module: ModuleType | None, tree: ast.Module
    ) -> Any:
        """Resolve the object an ``<owner>.<method>(...)`` call is made on.

        Three owner shapes are handled:
        - ``Factory().method()``: look up ``Factory``.
        - ``name.method()``: look up ``name``.
        - ``name.method()`` where ``name`` is not a module attribute: look up
          the callee of a ``name = Factory(...)`` assignment found in *tree*.

        Lookups happen in *module*. Returns None if nothing resolves.
        """
        owner = method.value
        if hasattr(owner, "func") and hasattr(owner.func, "id"):
            try:
                return getattr(module, owner.func.id)
            except (AttributeError, NameError):
                return None
        if not hasattr(owner, "id"):
            return None
        try:
            return getattr(module, owner.id)
        except (AttributeError, NameError):
            pass
        for assignment in find_nodes(tree, ast.Assign):
            target = assignment.targets[0]
            value = assignment.value
            if (
                hasattr(target, "id")
                and target.id == owner.id
                and hasattr(value, "func")
                and hasattr(value.func, "id")
            ):
                try:
                    return getattr(module, value.func.id)
                except (AttributeError, NameError):
                    continue
        return None

    def find_service_exceptions(
        self,
        route: APIRoute | Callable,
    ) -> list[HTTPException]:
        """Return exceptions raised by service methods called in *route*'s source.

        Every ``<owner>.<method>(...)`` call is inspected. The owner is
        resolved in the route's module (see ``_find_call_owner``) and passed
        to ``search_method_for_excs`` together with ``serviceClasses``.
        """
        assert self.serviceClasses is not None
        func = self._resolve_endpoint(route)
        module = inspect.getmodule(func)
        tree = self._parse_source(func)
        exceptions = []

        for node in find_nodes(tree, ast.Call, attr_filter={"func": ast.Attribute}):
            # Bind outside the assert so ``method`` exists under ``python -O``.
            method = node.func
            assert isinstance(method, ast.Attribute)
            cls = self._find_call_owner(method, module, tree)
            if cls:
                exceptions.extend(
                    self.search_method_for_excs(cls, method, self.serviceClasses)
                )
        return exceptions

    def find_dependency_exceptions(
        self,
        route: APIRoute | Callable,
    ) -> list[HTTPException]:
        """Return exceptions raised by providers passed to ``Depends(...)`` defaults.

        Every (async) function definition in *route*'s source is scanned for
        parameter defaults of the form ``Depends(name)`` or
        ``Depends(name.attr)``. ``name`` is looked up in the route's module
        and handed to ``search_method_for_excs`` with the argument node.
        """
        assert self.dependencyClasses is not None
        func = self._resolve_endpoint(route)
        module = inspect.getmodule(func)
        tree = self._parse_source(func)
        exceptions = []

        for node in chain(
            find_nodes(tree, ast.FunctionDef), find_nodes(tree, ast.AsyncFunctionDef)
        ):
            for default in chain(node.args.defaults, node.args.kw_defaults):
                if not (
                    default
                    and hasattr(default, "func")
                    and hasattr(default.func, "id")
                    and default.func.id == "Depends"
                ):
                    continue
                method = default.args[0] if getattr(default, "args", None) else None
                if not method:
                    continue
                if hasattr(method, "value") and hasattr(method.value, "id"):
                    owner_name = method.value.id  # Depends(name.attr)
                else:
                    owner_name = getattr(method, "id", None)  # Depends(name)
                if owner_name is None:
                    continue
                try:
                    cls = getattr(module, owner_name)
                except (AttributeError, NameError):
                    continue
                if cls:
                    exceptions.extend(
                        self.search_method_for_excs(
                            cls, method, self.dependencyClasses
                        )
                    )
        return exceptions

    def find_annotated_dependency_exceptions(
        self,
        route: APIRoute | Callable,
    ) -> list[HTTPException]:
        """Return exceptions raised by dependencies declared via ``Annotated`` aliases.

        Parameters annotated with a bare name (e.g. ``user: CurrentUser``) are
        looked up in the route's module. If the object is an
        ``Annotated[..., Depends(dep)]`` alias, ``dep`` is scanned with
        ``find_exceptions``. Positional-only, regular and keyword-only
        parameters are all considered.
        """
        assert self.dependencyClasses is not None
        func = self._resolve_endpoint(route)
        module = inspect.getmodule(func)
        tree = self._parse_source(func)
        exceptions = []

        for node in chain(
            find_nodes(tree, ast.FunctionDef), find_nodes(tree, ast.AsyncFunctionDef)
        ):
            arguments = node.args
            for param in chain(
                getattr(arguments, "posonlyargs", []),
                arguments.args,
                getattr(arguments, "kwonlyargs", []),
            ):
                alias_name = getattr(param.annotation, "id", None)
                if not alias_name:
                    continue
                try:
                    alias = getattr(module, alias_name)
                except (AttributeError, NameError):
                    continue
                if not alias:
                    continue
                dependency = get_annotated_dependency(alias)
                if not dependency:
                    continue
                if not (inspect.isroutine(dependency) or inspect.isclass(dependency)):
                    # Callable instance: ``inspect.getsource`` cannot take the
                    # instance itself, but it can take its ``__call__`` method.
                    dependency = getattr(dependency, "__call__", dependency)
                exceptions.extend(self.find_exceptions(dependency, module))
        return exceptions

    def create_exc_inst_from_raise_stmt(
        self,
        raise_stmt: ast.Raise,
        module: ModuleType | type,
    ) -> Exception | None:
        """Build an exception instance for a ``raise <call>`` statement.

        Two call shapes are recognised:
        - ``raise owner.Name(...)``: ``owner`` is looked up in *module*, then
          ``Name`` on it. The class is instantiated without arguments.
        - ``raise Name(...)``: ``Name`` is looked up in *module*. The call is
          evaluated via ``eval_ast_exc_instance``; on ``NameError`` the class
          is instantiated without arguments instead.

        Returns None when the class cannot be resolved or its name does not
        match ``exceptions_to_find``.
        """
        exc = raise_stmt.exc
        if not exc or not hasattr(exc, "func"):
            return None
        target = exc.func

        if hasattr(target, "attr"):
            try:
                outer_exc_class = getattr(module, target.value.id)
                http_exc = getattr(outer_exc_class, target.attr)
            except AttributeError:
                return None
            if is_subclass_of_any(http_exc, self.exceptions_to_find):
                return create_exc_instance(http_exc)
            return None

        try:
            http_exc = getattr(module, target.id)
        except AttributeError:
            return None
        if not is_subclass_of_any(http_exc, self.exceptions_to_find):
            return None
        try:
            return eval_ast_exc_instance(http_exc, exc)
        except NameError:
            return create_exc_instance(http_exc)

    def get_class_and_callable(
        self, cls: type, method: ast.expr, types_to_find: tuple[type, ...]
    ) -> tuple[type, str]:
        """Decide which object and which method name to scan for *cls*.

        - Instance of *types_to_find* referenced bare (``Depends(instance)``)
          and callable: keep the instance and scan ``__call__``.
        - Other instance of *types_to_find*: use its class and the accessed
          attribute (``instance.attr``).
        - Subclass of *types_to_find*: keep the class and scan the accessed
          attribute.

        The method name is "" when nothing should be scanned, including when
        *method* is a bare name (no attribute access).
        """
        attr_name = method.attr if isinstance(method, ast.Attribute) else ""
        callable = ""

        if isinstance(cls, types_to_find):
            if not attr_name and is_callable_instance(cls):
                assert hasattr(cls, "__call__")
                callable = cls.__call__.__name__
            else:
                callable = attr_name
                cls = cls.__class__
        elif is_subclass_of_any(cls, types_to_find):
            callable = attr_name

        return cls, callable

    def search_method_for_excs(
        self, cls: type, method: ast.expr, types_to_find: tuple[type, ...]
    ) -> list[HTTPException | ErrType]:
        """Collect exceptions from a provider object and the method called on it.

        If the other provider kind is configured (services inside a
        dependency, or dependencies inside a service), the provider's class
        source is searched for those first. Then the selected method (see
        ``get_class_and_callable``) is scanned with ``find_exceptions``.
        """
        exceptions = []

        cls, callable = self.get_class_and_callable(cls, method, types_to_find)

        nested_to_search = (
            self.serviceClasses
            if types_to_find == self.dependencyClasses
            else self.dependencyClasses
        )
        if nested_to_search and (callable or is_subclass_of_any(cls, nested_to_search)):
            # Source lookup needs a class; callable instances are kept as
            # instances by ``get_class_and_callable``.
            source_owner = cls if isinstance(cls, type) else type(cls)
            if nested_to_search == self.dependencyClasses:
                exceptions.extend(self.find_dependency_exceptions(source_owner))
            else:
                exceptions.extend(self.find_service_exceptions(source_owner))

        if callable:
            exceptions.extend(self.find_exceptions(callable, cls))

        return exceptions

    def clear(self) -> None:
        """Empty the work queue and the accumulated exceptions (in place)."""
        self.functions.clear()
        self.exceptions.clear()