Coverage for src/km3modules/common.py: 89%

201 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-08 03:14 +0000

1# Filename: common.py 

2# -*- coding: utf-8 -*- 

3# pylint: disable=locally-disabled 

4""" 

5A collection of commonly used modules. 

6 

7""" 

8 

9import sqlite3 

10from time import time 

11 

12import numpy as np 

13 

14import km3pipe as kp 

15from km3pipe import Module, Blob 

16from km3pipe.tools import prettyln 

17from km3pipe.sys import peak_memory_usage 

18 

# Module-level logger shared by the modules in this file (e.g. used by Siphon).
log = kp.logger.get_logger(__name__)

20 

21 

class Dump(Module):
    """Print the content of the blob.

    Parameters
    ----------
    keys: collection(string), optional [default=None]
        Keys to print. If None, print all keys (sorted). A single string
        is treated as one key.
    key: string, optional [default=None]
        Convenience alias for a single key; only used if ``keys`` is empty.
    full: bool, default=False
        Print blob values too, not just the keys?
    """

    def configure(self):
        keys = self.get("keys") or None
        # A bare string would otherwise be iterated character by character.
        if isinstance(keys, str):
            keys = [keys]
        self.keys = keys
        self.full = self.get("full") or False
        key = self.get("key") or None
        if key and not self.keys:
            self.keys = [key]

    def process(self, blob):
        keys = sorted(blob.keys()) if self.keys is None else self.keys
        for key in keys:
            # f-string instead of ``key + ":"`` so non-string keys print too.
            print(f"{key}:")
            if self.full:
                print(repr(blob[key]))
                print("")
        print("----------------------------------------\n")
        return blob

49 

50 

class Delete(Module):
    """Remove specific keys from the blob.

    Parameters
    ----------
    keys: collection(string), optional
        Keys to remove. A single string is treated as one key.
    key: string, optional
        Convenience alias for a single key; only used if ``keys`` is empty.
    """

    def configure(self):
        keys = self.get("keys") or set()
        # A bare string would otherwise be iterated character by character,
        # silently removing single-letter keys instead of the intended one.
        if isinstance(keys, str):
            keys = [keys]
        self.keys = keys
        key = self.get("key") or None
        if key and not self.keys:
            self.keys = [key]

    def process(self, blob):
        for key in self.keys:
            # Keys not present in the blob are ignored instead of raising.
            blob.pop(key, None)
        return blob

70 

71 

class Keep(Module):
    """Keep only specified keys in the blob.

    Parameters
    ----------
    keys: collection(string), optional
        Keys to keep. Everything else is removed. A single string is
        treated as one key.
    key: string, optional
        Convenience alias for a single key; only used if ``keys`` is empty.
    h5locs: collection(string), optional
        Additionally keep every blob entry that has an ``h5loc`` attribute
        starting with one of these prefixes.
    """

    def configure(self):
        # ``or set()`` also guards against an explicit ``keys=None``, which
        # would otherwise make ``key in self.keys`` raise a TypeError.
        keys = self.get("keys") or set()
        # A bare string would turn ``in`` into a substring test
        # (e.g. "Hit" in "Hits"), keeping unrelated keys.
        if isinstance(keys, str):
            keys = [keys]
        h5locs = self.get("h5locs") or set()
        if isinstance(h5locs, str):
            h5locs = [h5locs]
        self.keys = keys
        self.h5locs = h5locs
        key = self.get("key", default=None)
        if key and not self.keys:
            self.keys = [key]

    def process(self, blob):
        # str.startswith accepts a tuple of prefixes; an empty tuple matches
        # nothing, so without h5locs only ``keys`` decide.
        prefixes = tuple(self.h5locs)
        out = Blob()
        for key, value in blob.items():
            if key in self.keys:
                out[key] = value
            elif hasattr(value, "h5loc") and value.h5loc.startswith(prefixes):
                out[key] = value
        return out

98 

99 

class HitCounter(Module):
    """Prints the number of hits.

    Parameters
    ----------
    input_key: str, optional [default="Hit"]
        Blob key holding the hits. Blobs without this key are passed on
        unchanged without printing anything.
    """

    def configure(self):
        self.input_key = self.get("input_key", default="Hit")

    def process(self, blob):
        # Only the lookup is guarded, so unrelated KeyErrors are not hidden.
        try:
            hits = blob[self.input_key]
        except KeyError:
            return blob
        self.cprint("Number of hits: {0}".format(len(hits)))
        return blob

109 

110 

class HitCalibrator(Module):
    """A very basic hit calibrator, which requires a `Calibration` module.

    Parameters
    ----------
    input_key: str, optional [default="Hits"]
        Blob key of the hits to calibrate.
    output_key: str, optional [default="CalibHits"]
        Blob key under which the result of ``calibration.apply`` is stored.
    """

    def configure(self):
        self.input_key = self.get("input_key", default="Hits")
        self.output_key = self.get("output_key", default="CalibHits")

    def process(self, blob):
        if self.input_key not in blob:
            # ``Logger.warn`` is a deprecated alias of ``warning``.
            self.log.warning("No hits found in key '%s'.", self.input_key)
            return blob
        hits = blob[self.input_key]
        # NOTE(review): ``self.calibration`` is never assigned in this class;
        # presumably provided by the `Calibration` module — confirm.
        blob[self.output_key] = self.calibration.apply(hits)
        return blob

126 

127 

class BlobIndexer(Module):
    """Store a running counter in every blob under the key 'blob_index'.

    The first processed blob receives 0, the next 1, and so on.
    """

    def configure(self):
        self.blob_index = 0

    def process(self, blob):
        current_index = self.blob_index
        self.blob_index = current_index + 1
        blob["blob_index"] = current_index
        return blob

138 

139 

class StatusBar(Module):
    """Displays the current blob number.

    The displayed number is ``every * iteration``. ``self.every`` is not set
    here; presumably it comes from the module base class / pipeline ``every``
    option — confirm.
    """

    def configure(self):
        self.iteration = 1

    def process(self, blob):
        blob_number = self.every * self.iteration
        prettyln(f"Blob {blob_number:>7}")
        self.iteration = self.iteration + 1
        return blob

    def finish(self):
        # Closing separator line.
        prettyln(".", fill="=")

153 

154 

class TickTock(Module):
    """Display the elapsed time.

    The time is measured in minutes since ``configure`` was called.

    Parameters
    ----------
    every: int, optional [default=1]
        Number of iterations between printout.
    """

    def configure(self):
        self.t0 = time()

    def process(self, blob):
        elapsed_minutes = (time() - self.t0) / 60
        prettyln(f"Time/min: {elapsed_minutes:.3f}")
        return blob

171 

172 

class MemoryObserver(Module):
    """Shows the maximum memory usage

    The value reported by ``peak_memory_usage()`` is printed with an
    "MB" suffix.

    Parameters
    ----------
    every: int, optional [default=1]
        Number of iterations between printout.
    """

    def process(self, blob):
        prettyln(f"Memory peak: {peak_memory_usage():.3f} MB")
        return blob

186 

187 

class Siphon(Module):
    """A siphon to accumulate a given volume of blobs.

    The first ``volume`` blobs are held back: ``process`` returns ``None``
    for them (presumably the pipeline then skips downstream modules for that
    cycle — confirm). Once the count exceeds ``volume``, blobs are passed on.
    With ``flush`` the counter is reset after each overflow, so only every
    ``volume + 1``-th blob gets through.

    Parameters
    ----------
    volume: int
        number of blobs to hold
    flush: bool
        discard blobs after accumulation
    """

    def configure(self):
        self.volume = self.require("volume")  # [blobs]
        self.flush = self.get("flush", default=False)

        self.blob_count = 0

    def process(self, blob):
        self.blob_count += 1
        if self.blob_count <= self.volume:
            # Still filling up: hold the blob back.
            return None
        log.debug("Siphon overflow reached!")
        if self.flush:
            log.debug("Flushing the siphon.")
            self.blob_count = 0
        return blob

214 

215 

class MultiFilePump(kp.Module):
    """Use the given pump to iterate through a list of files.

    The group_id will be reset so that it's unique for each iteration.

    Parameters
    ----------
    pump: Pump
        The pump to be used to generate the blobs.
    filenames: iterable(str)
        List of filenames. Any iterable (including generators) is accepted;
        a single string is treated as one filename.
    kwargs: dict(str -> any) optional
        Keyword arguments to be passed to the pump.

    """

    def configure(self):
        self.pump = self.require("pump")
        filenames = self.require("filenames")
        if isinstance(filenames, str):
            filenames = [filenames]
        # Materialise so len() works for any iterable, as documented above.
        self.filenames = list(filenames)
        # ``or {}`` also covers an explicit ``kwargs=None``.
        self.kwargs = self.get("kwargs") or {}
        self.blobs = self.blob_generator()
        self.cprint("Iterating through {} files.".format(len(self.filenames)))
        self.n_processed = 0
        self.group_id = 0

    def blob_generator(self):
        """Yield every blob of every file, tagged with group_id and filename."""
        for filename in self.filenames:
            self.cprint("Current file: {}".format(filename))
            # NOTE(review): the per-file pump instance is never explicitly
            # finished/closed here; confirm pumps release their resources.
            pump = self.pump(filename=filename, **self.kwargs)
            for blob in pump:
                self._set_group_id(blob)
                blob["filename"] = filename
                yield blob
                # Incremented per blob, so group_id is unique per iteration
                # across all files.
                self.group_id += 1
            self.n_processed += 1

    def _set_group_id(self, blob):
        """Overwrite or add the ``group_id`` column of every Table in the blob."""
        for key, entry in blob.items():
            if isinstance(entry, kp.Table):
                if hasattr(entry, "group_id"):
                    entry.group_id = self.group_id
                else:
                    # Replacing the value of an existing key while iterating
                    # does not change the dict size, so this is safe.
                    blob[key] = entry.append_columns("group_id", self.group_id)

    def process(self, blob):
        # StopIteration from the exhausted generator ends the iteration.
        return next(self.blobs)

    def finish(self):
        self.cprint(
            "Fully processed {} out of {} files.".format(
                self.n_processed, len(self.filenames)
            )
        )

269 

270 

class LocalDBService(kp.Module):
    """Provides a local sqlite3 based database service to store information

    Parameters
    ----------
    filename: str
        Database file passed to ``sqlite3.connect``.
    thread_safety: bool, optional [default=True]
        Forwarded as ``check_same_thread`` to ``sqlite3.connect``.

    The methods ``create_table``, ``table_exists``, ``insert_row`` and
    ``query`` are exposed as services.
    """

    def configure(self):
        self.filename = self.require("filename")
        self.thread_safety = self.get("thread_safety", default=True)
        self.connection = None

        self.expose(self.create_table, "create_table")
        self.expose(self.table_exists, "table_exists")
        self.expose(self.insert_row, "insert_row")
        self.expose(self.query, "query")

        self._create_connection()

    def _create_connection(self):
        """Create database connection; on failure log the error and leave it None"""
        try:
            self.connection = sqlite3.connect(
                self.filename, check_same_thread=self.thread_safety
            )
            # ``sqlite3.version`` is deprecated since Python 3.12 and removed
            # in 3.14; report the SQLite library version instead.
            self.cprint(sqlite3.sqlite_version)
        except sqlite3.Error as exception:
            self.log.error(exception)

    def query(self, query, parameters=()):
        """Execute a SQL query and return the result of fetchall()

        ``parameters`` are bound to ``?`` placeholders in ``query``.
        """
        cursor = self.connection.cursor()
        cursor.execute(query, parameters)
        return cursor.fetchall()

    def insert_row(self, table, column_names, values):
        """Insert a row into the table with a given list of values

        Values are bound as strings, matching the former quoted-literal
        insertion (column type affinity still converts them), but quotes
        inside values can no longer break or inject into the statement.
        ``table`` and ``column_names`` are interpolated verbatim and must be
        trusted identifiers.
        """
        params = [str(v) for v in values]
        query = "INSERT INTO {} ({}) VALUES ({})".format(
            table, ", ".join(column_names), ", ".join(["?"] * len(params))
        )
        cursor = self.connection.cursor()
        cursor.execute(query, params)
        self.connection.commit()

    def create_table(self, name, columns, types, overwrite=False):
        """Create a table with given columns and types, overwrite if specified

        The `types` should be a list of SQL types, like ["INT", "TEXT", "INT"]

        Identifiers cannot be bound as parameters, so ``name``, ``columns``
        and ``types`` are interpolated verbatim and must be trusted.
        """
        cursor = self.connection.cursor()
        if overwrite:
            cursor.execute("DROP TABLE IF EXISTS {}".format(name))

        cursor.execute(
            "CREATE TABLE {} ({})".format(
                name, ", ".join(["{} {}".format(*c) for c in zip(columns, types)])
            )
        )
        self.connection.commit()

    def table_exists(self, name):
        """Check if a table exists in the database"""
        cursor = self.connection.cursor()
        cursor.execute(
            "SELECT count(name) FROM sqlite_master WHERE type='table' AND name=?",
            (name,),
        )
        return cursor.fetchone()[0] == 1

    def finish(self):
        """Close the database connection if one was opened."""
        if self.connection:
            self.connection.close()

340 

341 

class Observer(kp.Module):
    """A simple helper to observe the blobs in a test pipeline.

    Parameters
    ----------
    count: int, optional
        The exact number of iterations the pipeline has to drain. Not
        checked if omitted.
    required_keys: list(str), optional
        A list of keys which has to be present in a blob in every cycle.
    """

    def configure(self):
        self.count = self.get("count")
        self.required_keys = self.get("required_keys", default=[])
        self._count = 0

    def process(self, blob):
        self._count += 1
        for key in self.required_keys:
            assert key in blob, f"Required key '{key}' missing from blob"
        return blob

    def finish(self):
        # ``count`` is the expected (target) value, ``_count`` the observed one.
        print(f"Target count={self.count}, actual count={self._count}")
        if self.count is not None:
            assert (
                self.count == self._count
            ), f"Expected {self.count} iterations, got {self._count}"

368 

369 

class FilePump(kp.Module):
    """A basic iterator for a list of files.

    Each cycle merges ``{"filename": <next filename>}`` into the incoming
    blob; the pipeline stops once the filenames are exhausted.

    Parameters
    ----------
    filenames: iterable(str)
        The filenames to be iterated over which are put into ``blob["filename"]``

    """

    def configure(self):
        self.filenames = self.require("filenames")
        self.blobs = self.blob_generator()

    def blob_generator(self):
        yield from (kp.Blob({"filename": fname}) for fname in self.filenames)

    def process(self, blob):
        next_blob = next(self.blobs)
        blob.update(next_blob)
        return blob