Coverage for fastapi_docx/exception_finder.py: 93%

266 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-24 20:22 +0000

1import ast 

2import importlib 

3import inspect 

4import textwrap 

5import typing 

6from collections.abc import Callable, Generator, Iterable 

7from itertools import chain 

8from types import ModuleType 

9from typing import Any, TypeVar 

10 

11from fastapi import status # noqa 

12from fastapi.routing import APIRoute 

13from starlette.exceptions import HTTPException 

14 

# Placeholder for AST node classes: lets find_nodes' return type follow the
# node_type argument it is given.
AstType = TypeVar("AstType", bound=ast.AST)
# Placeholder for the user-supplied custom exception class accepted by
# RouteExcFinder(customError=...).
ErrType = TypeVar("ErrType", bound=Exception)

17 

18 

def find_nodes(
    tree: ast.Module,
    node_type: type[AstType],
    attr_filter: dict[str, type[ast.AST]] | None = None,
) -> Generator[AstType, None, None]:
    """Lazily yield every ``node_type`` node found in *tree*.

    Nodes come out in ``ast.walk`` order. When *attr_filter* is non-empty, a
    node is yielded only if, for each ``attribute name -> AST class`` pair,
    the node has that attribute and its value is an instance of the class.
    """
    criteria = tuple(attr_filter.items()) if attr_filter else ()

    def _matches(candidate: ast.AST) -> bool:
        if not isinstance(candidate, node_type):
            return False
        for attr_name, expected in criteria:
            if not hasattr(candidate, attr_name):
                return False
            if not isinstance(getattr(candidate, attr_name), expected):
                return False
        return True

    yield from filter(_matches, ast.walk(tree))

36 

37 

def is_function_or_coroutine(obj: Any) -> bool:
    """Return True if *obj* is a Python function or a coroutine function.

    ``inspect.isfunction`` is tried first; objects it rejects are still
    accepted when ``inspect.iscoroutinefunction`` recognizes them.
    """
    if inspect.isfunction(obj):
        return True
    return inspect.iscoroutinefunction(obj)

40 

41 

42def is_subclass_of_any(klass: type, classes: Iterable[type] | Iterable[str]) -> bool: 

43 try: 

44 base_names = ( 

45 [base.__name__ for base in klass.__bases__] 

46 if hasattr(klass, "__bases__") 

47 else [] 

48 ) 

49 base_names.append(type(klass).__name__) 

50 if all(isinstance(cls, str) for cls in classes): 

51 class_names = classes 

52 else: 

53 class_names = [ 

54 cls.__name__ # pyright: ignore 

55 for cls in classes 

56 if hasattr(cls, "__name__") 

57 ] 

58 return any( 

59 base_name in class_names for base_name in base_names # pyright: ignore 

60 ) 

61 except RuntimeError: 

62 return False 

63 

64 

def is_callable_instance(obj: object) -> bool:
    """Return True for objects exposing ``__call__`` that are not classes.

    Classes are callable too (calling one constructs an instance), so any
    ``type`` instance is rejected explicitly.
    """
    if not hasattr(obj, "__call__"):
        return False
    return not isinstance(obj, type)

67 

68 

def is_annotated_alias(annot: Any) -> bool:
    """Return True if *annot* is a subscripted ``typing.Annotated[...]`` alias.

    Uses the public ``typing.get_origin`` API instead of the private
    ``typing._AnnotatedAlias`` class, which is an implementation detail that
    may change between Python versions.
    """
    return typing.get_origin(annot) is typing.Annotated

72 

73 

def get_annotated_dependency(annot: Any) -> Callable | None:
    """Return the callable wrapped by a ``Depends`` marker in an ``Annotated`` alias.

    For ``Annotated[T, ..., Depends(dep), ...]`` this returns ``dep`` (the
    marker's ``dependency`` attribute). Every metadata entry after ``T`` is
    examined in order, and the first one whose class, or any ancestor class,
    is named ``"Depends"`` wins. Subclasses of such a marker are therefore
    accepted too.

    Returns None when *annot* is not an ``Annotated`` alias or carries no
    ``Depends`` marker. Also returns None when the marker's ``dependency`` is
    missing or None.
    """
    if not is_annotated_alias(annot):
        return None
    # typing.get_args(Annotated[T, m1, m2, ...]) == (T, m1, m2, ...)
    for marker in typing.get_args(annot)[1:]:
        if marker and any(
            klass.__name__ == "Depends" for klass in type(marker).__mro__
        ):
            return getattr(marker, "dependency", None)
    return None

82 

83 

84def create_exc_instance( 

85 exc_class: type[Exception], exc_args: list[Any] | None = None 

86) -> Exception | None: 

87 importlib.import_module(exc_class.__module__) 

88 value = exc_class(*exc_args) if exc_args else exc_class() 

89 return value if isinstance(value, Exception) else None 

90 

91 

def _addition_placeholder(binop: ast.BinOp) -> ast.Constant:
    """Return the string constant ``"<left> + <right>"`` standing in for *binop*."""
    return ast.Constant(
        value=f"<{ast.unparse(binop.left)}> + <{ast.unparse(binop.right)}>"
    )


def replace_binops(node: ast.AST) -> None:
    """Replace binary addition operations with their string representation.

    Every ``a + b`` (``ast.BinOp`` with ``ast.Add``) below *node* is swapped,
    in place, for ``ast.Constant("<a> + <b>")`` built from the unparsed
    operands. A replaced sub-tree is not descended into, and *node* itself is
    never replaced, only its descendants.

    Non-AST entries of list fields are skipped rather than recursed into. One
    example is the ``None`` key that ``{**mapping}`` leaves in
    ``ast.Dict.keys``.
    """
    for name, field in ast.iter_fields(node):
        if isinstance(field, list):
            for index, item in enumerate(field):
                if isinstance(item, ast.BinOp) and isinstance(item.op, ast.Add):
                    field[index] = _addition_placeholder(item)
                elif isinstance(item, ast.AST):
                    replace_binops(item)
        elif isinstance(field, ast.AST):
            if isinstance(field, ast.BinOp) and isinstance(field.op, ast.Add):
                setattr(node, name, _addition_placeholder(field))
            else:
                replace_binops(field)

114 

115 

def _rewrite_for_eval(child: ast.AST) -> ast.AST:
    """Return the node that should replace *child* in its parent.

    The descendants of *child* are rewritten first (post-order). As a result,
    the operands of an addition already contain placeholders by the time the
    addition itself is stringified.
    """
    if isinstance(child, ast.FormattedValue):
        # The whole ``{expr...}`` part (format spec included) collapses.
        return ast.Constant(value=f"<{ast.unparse(child.value).upper()}>")
    replace_formatted_values(child)
    if isinstance(child, ast.BinOp) and isinstance(child.op, ast.Add):
        return ast.Constant(
            value=f"<{ast.unparse(child.left)}> + <{ast.unparse(child.right)}>"
        )
    return child


def replace_formatted_values(node: ast.AST) -> None:
    """Replace formatted values (and additions) with placeholder string constants.

    Every descendant of *node* is rewritten in place; *node* itself is not
    replaced.

    * ``ast.FormattedValue`` (the ``{expr}`` part of an f-string) becomes
      ``ast.Constant("<EXPR>")``, the unparsed expression upper-cased.
    * ``a + b`` (``ast.BinOp`` with ``ast.Add``) becomes
      ``ast.Constant("<a> + <b>")`` once its operands have been rewritten.
      This applies at every depth, including direct children of *node*.

    Non-AST entries of list fields are left untouched. One example is the
    ``None`` key that ``{**mapping}`` leaves in ``ast.Dict.keys``.
    """
    for name, field in ast.iter_fields(node):
        if isinstance(field, list):
            for index, item in enumerate(field):
                if isinstance(item, ast.AST):
                    field[index] = _rewrite_for_eval(item)
        elif isinstance(field, ast.AST):
            replacement = _rewrite_for_eval(field)
            if replacement is not field:
                setattr(node, name, replacement)

139 

140 

def eval_ast_exc_instance(
    exc_class: type[Exception], ast_exec_inst: ast.expr
) -> Exception | None:
    """Evaluate an exception-construction expression lifted from source code.

    The module defining *exc_class* is imported. Placeholders are then
    substituted into *ast_exec_inst* in place by ``replace_formatted_values``.
    The unparsed expression is evaluated and returned if it is an Exception;
    otherwise None is returned. Errors raised during evaluation (such as
    NameError) propagate to the caller.

    NOTE(review): ``eval`` runs with this function's globals and locals, so
    names resolve against this module rather than the module the expression
    came from. Any call inside the expression really executes. The input is
    application source, not external data, but keep that in mind.
    """
    importlib.import_module(exc_class.__module__)
    replace_formatted_values(ast_exec_inst)
    evaluated = eval(ast.unparse(ast_exec_inst))
    if isinstance(evaluated, Exception):
        return evaluated
    return None

148 

149 

class RouteExcFinder:
    """Statically collect the exceptions a FastAPI route can raise.

    The finder parses the source of a route's endpoint and of the functions
    it references. When configured, it also parses the source of the route's
    dependency and service classes. It looks for ``raise X(...)`` statements
    and instantiates every raised class that ``is_subclass_of_any`` matches
    against ``"HTTPException"`` or the name of *customError*.

    Results accumulate in ``self.exceptions`` across calls to
    :meth:`extract_exceptions`; call :meth:`clear` before reusing the finder
    for another route.
    """

    def __init__(
        self,
        customError: type[ErrType] | None = None,
        dependencyClasses: tuple[type, ...] | None = None,
        serviceClasses: tuple[type, ...] | None = None,
    ):
        """Configure what to look for.

        Args:
            customError: optional project exception class reported in
                addition to ``HTTPException`` (matched by class name).
            dependencyClasses: classes whose instances/subclasses, when used
                in ``Depends(...)``, have their methods scanned as well.
            serviceClasses: classes whose instances/subclasses, when their
                methods are called, have those methods scanned as well.
        """
        self.customError = customError
        self.dependencyClasses = dependencyClasses
        self.serviceClasses = serviceClasses

        # Class names a raised exception must match (by name) to be reported.
        self.exceptions_to_find: tuple[str, ...] = (
            ("HTTPException", self.customError.__name__)
            if self.customError
            else ("HTTPException",)
        )

        # Work queue of functions still to be scanned by extract_exceptions.
        self.functions: list[Callable] = []
        # Accumulated exception instances; reset with clear().
        self.exceptions: list[HTTPException | ErrType] = []

    def extract_exceptions(
        self,
        route: APIRoute,
    ) -> list[HTTPException | ErrType]:
        """Collect the exceptions *route* may raise and return ``self.exceptions``.

        The scan proceeds in this order:

        * The endpoint is scanned, then transitively every function that
          :meth:`find_functions` discovers from it.
        * When ``dependencyClasses`` is set, ``Depends(...)`` defaults and
          ``Annotated`` dependencies are scanned.
        * When ``serviceClasses`` is set, service method calls are scanned.

        Each function is scanned at most once per call. This avoids duplicate
        results and keeps mutually recursive helpers from looping forever.

        The returned list is ``self.exceptions`` itself and keeps growing
        across calls until :meth:`clear` is called.
        """
        self.functions.append(getattr(route, "endpoint", route))
        # id()-based bookkeeping: queued callables need not be hashable.
        scanned: set[int] = set()
        while self.functions:
            function = self.functions.pop(0)
            if id(function) in scanned:
                continue
            scanned.add(id(function))
            self.functions.extend(self.find_functions(function))
            self.exceptions.extend(self.find_exceptions(function))
        if self.dependencyClasses:
            self.exceptions += self.find_dependency_exceptions(route)
            self.exceptions += self.find_annotated_dependency_exceptions(route)
        if self.serviceClasses:
            self.exceptions += self.find_service_exceptions(route)
        return self.exceptions

    @staticmethod
    def _parse_source(obj: Any) -> ast.Module:
        """Parse the dedented source of *obj* (after ``inspect.unwrap``)."""
        return ast.parse(textwrap.dedent(inspect.getsource(inspect.unwrap(obj))))

    @staticmethod
    def find_functions(route: Callable) -> list[Callable]:
        """Return functions referenced by name in the source of *route*.

        *route* may be an APIRoute (its ``endpoint`` is used) or a callable.
        Every ``ast.Name`` in its source is looked up as an attribute of the
        callable's module. Plain and coroutine functions other than the
        callable itself are returned, in ``ast.walk`` order and with repeats.
        """
        func = getattr(route, "endpoint", route)
        tree = RouteExcFinder._parse_source(func)
        module = importlib.import_module(func.__module__)
        found: list[Callable] = []
        for name_node in find_nodes(tree, ast.Name):
            try:
                obj = getattr(module, name_node.id)
                if is_function_or_coroutine(obj) and obj is not func:
                    found.append(obj)
            except (AttributeError, ValueError):
                continue
        return found

    def find_exceptions(
        self,
        callable: APIRoute | Callable | str,
        owner: type | ModuleType | None = None,
    ) -> list[HTTPException]:
        """Return instances of the tracked exceptions raised in *callable*.

        Args:
            callable: an APIRoute (its ``endpoint`` is used), a callable, or
                the name of an attribute of *owner*.
            owner: class or module used to resolve a string *callable*. If it
                is a module, it is also used for name resolution when the
                callable's own module cannot be determined.

        Raises:
            ValueError: if *callable* is a string and *owner* is None.
        """
        target = callable.endpoint if hasattr(callable, "endpoint") else callable
        if isinstance(target, str):
            if owner is None:
                raise ValueError("owner must be provided if callable is a string")
            target = getattr(owner, target)

        # Names in raise statements resolve in the callable's defining module;
        # a module *owner* is only a fallback.
        module = inspect.getmodule(target)
        if module is None and isinstance(owner, ModuleType):
            module = owner
        if not module:
            return []

        tree = self._parse_source(target)
        exceptions = []
        for raise_node in find_nodes(tree, ast.Raise, attr_filter={"exc": ast.Call}):
            instance = self.create_exc_inst_from_raise_stmt(raise_node, module)
            if instance:
                exceptions.append(instance)
        return exceptions

    @staticmethod
    def _resolve_receiver(
        module: ModuleType | None, tree: ast.Module, receiver: Any
    ) -> Any:
        """Map the receiver of ``receiver.method(...)`` to a module object, or None.

        * ``Name(...).method()`` resolves to ``module.Name``.
        * ``name.method()`` resolves to ``module.name``. When that attribute
          does not exist, it resolves to ``module.Cls`` for the first
          ``name = Cls(...)`` assignment in *tree* whose class resolves.
        """
        if hasattr(receiver, "func") and hasattr(receiver.func, "id"):
            try:
                return getattr(module, receiver.func.id)
            except (AttributeError, NameError):
                return None
        if not hasattr(receiver, "id"):
            return None
        try:
            return getattr(module, receiver.id)
        except (AttributeError, NameError):
            pass
        for assignment in find_nodes(tree, ast.Assign):
            target = assignment.targets[0]
            if not (hasattr(target, "id") and target.id == receiver.id):
                continue
            factory = assignment.value
            if hasattr(factory, "func") and hasattr(factory.func, "id"):
                try:
                    return getattr(module, factory.func.id)
                except (AttributeError, NameError):
                    continue
        return None

    def find_service_exceptions(
        self,
        route: APIRoute | Callable,
    ) -> list[HTTPException]:
        """Return exceptions raised by service methods called from *route*.

        The receiver of every ``<receiver>.<method>(...)`` call in the source
        is resolved with :meth:`_resolve_receiver`. Each resolved object is
        then searched with :meth:`search_method_for_excs` against
        ``serviceClasses``.
        """
        assert self.serviceClasses is not None
        func = route.endpoint if hasattr(route, "endpoint") else route
        module = inspect.getmodule(func)
        tree = self._parse_source(func)

        exceptions = []
        for node in find_nodes(tree, ast.Call, attr_filter={"func": ast.Attribute}):
            assert isinstance(method := node.func, ast.Attribute)
            cls = self._resolve_receiver(module, tree, method.value)
            if cls:
                exceptions.extend(
                    self.search_method_for_excs(cls, method, self.serviceClasses)
                )
        return exceptions

    @staticmethod
    def _depends_target(default: Any) -> Any:
        """Return ``target`` for a ``Depends(target, ...)`` call node, else None."""
        if not (
            default
            and hasattr(default, "func")
            and hasattr(default.func, "id")
            and default.func.id == "Depends"
        ):
            return None
        args = getattr(default, "args", None)
        return args[0] if args else None

    @staticmethod
    def _resolve_dependency_owner(module: ModuleType | None, method: Any) -> Any:
        """Resolve ``Obj.attr`` to ``module.Obj`` or ``name`` to ``module.name``."""
        if hasattr(method, "value") and hasattr(method.value, "id"):
            name = method.value.id
        elif hasattr(method, "id"):
            name = method.id
        else:
            return None
        try:
            return getattr(module, name)
        except (AttributeError, NameError):
            return None

    def find_dependency_exceptions(
        self,
        route: APIRoute | Callable,
    ) -> list[HTTPException]:
        """Return exceptions raised by class-based ``Depends(...)`` defaults.

        The default values of every (async) function definition in the source
        are inspected. For each ``Depends(target)``, the target's owner is
        resolved in the module. The owner is then searched with
        :meth:`search_method_for_excs` against ``dependencyClasses``.
        """
        assert self.dependencyClasses is not None
        func = route.endpoint if hasattr(route, "endpoint") else route
        module = inspect.getmodule(func)
        tree = self._parse_source(func)

        exceptions = []
        for node in chain(
            find_nodes(tree, ast.FunctionDef), find_nodes(tree, ast.AsyncFunctionDef)
        ):
            # kw_defaults holds None for keyword-only params without a default.
            for default in chain(node.args.defaults, node.args.kw_defaults):
                method = self._depends_target(default)
                if method is None:
                    continue
                cls = self._resolve_dependency_owner(module, method)
                if cls:
                    exceptions.extend(
                        self.search_method_for_excs(
                            cls, method, self.dependencyClasses
                        )
                    )
        return exceptions

    def find_annotated_dependency_exceptions(
        self,
        route: APIRoute | Callable,
    ) -> list[HTTPException]:
        """Return exceptions raised by dependencies declared via ``Annotated`` aliases.

        Every parameter of every (async) function definition in the source is
        considered: positional-only, regular and keyword-only. For a parameter
        annotated with a bare name, that name is looked up in the route's
        module. If it is an ``Annotated[..., Depends(dep)]`` alias, ``dep`` is
        scanned with :meth:`find_exceptions`.
        """
        assert self.dependencyClasses is not None
        func = route.endpoint if hasattr(route, "endpoint") else route
        module = inspect.getmodule(func)
        tree = self._parse_source(func)

        exceptions = []
        for node in chain(
            find_nodes(tree, ast.FunctionDef), find_nodes(tree, ast.AsyncFunctionDef)
        ):
            params = chain(
                node.args.posonlyargs, node.args.args, node.args.kwonlyargs
            )
            for param in params:
                annot = getattr(param.annotation, "id", None)
                if not annot:
                    continue
                try:
                    alias = getattr(module, annot)
                except (AttributeError, NameError):
                    continue
                if alias and (dependency := get_annotated_dependency(alias)):
                    exceptions.extend(self.find_exceptions(dependency, module))
        return exceptions

    def create_exc_inst_from_raise_stmt(
        self,
        raise_stmt: ast.Raise,
        module: ModuleType | type,
    ) -> Exception | None:
        """Build an instance of the exception raised by *raise_stmt*, if tracked.

        * ``raise ns.Exc(...)``: ``Exc`` is looked up on ``module.ns``. If it
          matches ``exceptions_to_find``, it is instantiated without
          arguments.
        * ``raise Exc(...)``: ``Exc`` is looked up on *module*. If it matches,
          the call expression is evaluated via ``eval_ast_exc_instance``. When
          that raises NameError, a no-argument instance is built instead.

        Returns None when the class cannot be resolved or does not match.
        """
        exc_call = raise_stmt.exc
        if not (exc_call and hasattr(exc_call, "func")):
            return None

        if hasattr(exc_call.func, "attr"):
            try:
                namespace = getattr(module, exc_call.func.value.id)
                exc_class = getattr(namespace, exc_call.func.attr)
            except AttributeError:
                return None
            if is_subclass_of_any(exc_class, self.exceptions_to_find):
                return create_exc_instance(exc_class)
            return None

        try:
            exc_class = getattr(module, exc_call.func.id)
        except AttributeError:
            return None
        if not is_subclass_of_any(exc_class, self.exceptions_to_find):
            return None
        try:
            return eval_ast_exc_instance(exc_class, exc_call)
        except NameError:
            # The expression uses names eval cannot see; settle for a bare instance.
            return create_exc_instance(exc_class)

    def get_class_and_callable(
        self, cls: type, method: ast.Attribute, types_to_find: tuple[type, ...]
    ) -> tuple[type, str]:
        """Return ``(class_to_search, method_name)`` for a resolved call target.

        * An instance of one of *types_to_find* yields its class. The method
          name is ``__call__``'s name for a callable instance and the called
          attribute otherwise.
        * A class matched by ``is_subclass_of_any`` yields the class itself
          and the called attribute.
        * Anything else is returned unchanged with ``""`` (nothing to search).
        """
        callable_name = ""
        if isinstance(cls, types_to_find):
            if is_callable_instance(cls):
                assert hasattr(cls, "__call__")
                callable_name = cls.__call__.__name__
            else:
                callable_name = method.attr
            cls = cls.__class__
        elif is_subclass_of_any(cls, types_to_find):
            callable_name = method.attr
        return cls, callable_name

    def search_method_for_excs(
        self, cls: type, method: ast.Attribute, types_to_find: tuple[type, ...]
    ) -> list[HTTPException | ErrType]:
        """Return the exceptions reachable through the object behind a call site.

        *cls* is first normalized by :meth:`get_class_and_callable`. The
        class's source is then searched for the other category of helpers
        (services inside dependencies, dependencies inside services). This
        happens when that category is configured and either a method name was
        found or *cls* matches it. Finally, the named method itself is scanned
        with :meth:`find_exceptions`.
        """
        exceptions = []
        cls, callable_name = self.get_class_and_callable(cls, method, types_to_find)

        nested_to_search = (
            self.serviceClasses
            if types_to_find == self.dependencyClasses
            else self.dependencyClasses
        )
        # NOTE(review): the find_*_exceptions methods call back into this one,
        # so two classes that reference each other could recurse without bound.
        if nested_to_search and (
            callable_name or is_subclass_of_any(cls, nested_to_search)
        ):
            if nested_to_search == self.dependencyClasses:
                exceptions.extend(self.find_dependency_exceptions(cls))
            else:
                exceptions.extend(self.find_service_exceptions(cls))

        if callable_name:
            exceptions.extend(self.find_exceptions(callable_name, cls))
        return exceptions

    def clear(self) -> None:
        """Empty the pending-function queue and the accumulated exceptions."""
        self.functions.clear()
        self.exceptions.clear()