Coverage for src/km3modules/common.py: 89%
201 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-08 03:14 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-08 03:14 +0000
1# Filename: common.py
2# -*- coding: utf-8 -*-
3# pylint: disable=locally-disabled
4"""
5A collection of commonly used modules.
7"""
9import sqlite3
10from time import time
12import numpy as np
14import km3pipe as kp
15from km3pipe import Module, Blob
16from km3pipe.tools import prettyln
17from km3pipe.sys import peak_memory_usage
19log = kp.logger.get_logger(__name__)
class Dump(Module):
    """Print the content of the blob.

    Parameters
    ----------
    keys: collection(string), optional [default=None]
        Keys to print. If None, print all keys.
    key: string, optional [default=None]
        A single key to print; only used when ``keys`` is not given.
    full: bool, default=False
        Print blob values too, not just the keys?
    """

    def configure(self):
        # Falsy settings are normalised to their "unset" defaults.
        self.keys = self.get("keys") or None
        self.full = self.get("full") or False
        single_key = self.get("key") or None
        if single_key and not self.keys:
            self.keys = [single_key]

    def process(self, blob):
        """Print the selected keys (and optionally values), return the blob."""
        if self.keys is None:
            selected = sorted(blob.keys())
        else:
            selected = self.keys
        for name in selected:
            print(name + ":")
            if self.full:
                print(repr(blob[name]))
            print("")
        print("----------------------------------------\n")
        return blob
class Delete(Module):
    """Remove specific keys from the blob.

    Parameters
    ----------
    keys: collection(string), optional
        Keys to remove.
    key: string, optional
        A single key to remove; only used when ``keys`` is empty or missing.
    """

    def configure(self):
        self.keys = self.get("keys") or set()
        single_key = self.get("key") or None
        if single_key and not self.keys:
            self.keys = [single_key]

    def process(self, blob):
        """Drop every configured key from the blob; absent keys are ignored."""
        for unwanted in self.keys:
            # The ``None`` default makes a missing key a no-op.
            blob.pop(unwanted, None)
        return blob
class Keep(Module):
    """Keep only specified keys in the blob.

    Parameters
    ----------
    keys: collection(string), optional
        Keys to keep. Everything else is removed.
    key: string, optional
        A single key to keep; only used when ``keys`` is empty.
    h5locs: collection(string), optional
        Entries with an ``h5loc`` attribute starting with any of these
        prefixes are kept as well.
    """

    def configure(self):
        self.keys = self.get("keys", default=set())
        single_key = self.get("key", default=None)
        self.h5locs = self.get("h5locs", default=set())
        if single_key and not self.keys:
            self.keys = [single_key]

    def process(self, blob):
        """Return a new Blob holding only the entries selected for keeping."""
        kept = Blob()
        for name in blob.keys():
            if name in self.keys:
                kept[name] = blob[name]
                continue
            entry = blob[name]
            # ``str.startswith`` accepts a tuple of alternative prefixes.
            if hasattr(entry, "h5loc") and entry.h5loc.startswith(
                tuple(self.h5locs)
            ):
                kept[name] = entry
        return kept
class HitCounter(Module):
    """Prints the number of hits"""

    def process(self, blob):
        """Print the length of ``blob["Hit"]`` if present; return the blob."""
        try:
            hits = blob["Hit"]
            self.cprint("Number of hits: {0}".format(len(hits)))
        except KeyError:
            # A KeyError (e.g. no "Hit" entry) is ignored on purpose.
            pass
        return blob
class HitCalibrator(Module):
    """A very basic hit calibrator, which requires a `Calibration` module.

    Parameters
    ----------
    input_key: str, optional [default="Hits"]
        Blob key holding the hits to calibrate.
    output_key: str, optional [default="CalibHits"]
        Blob key under which the calibrated hits are stored.
    """

    def configure(self):
        self.input_key = self.get("input_key", default="Hits")
        self.output_key = self.get("output_key", default="CalibHits")

    def process(self, blob):
        """Calibrate the hits under ``input_key`` and store the result.

        The blob is returned unchanged (with a warning logged) when the
        input key is missing.
        """
        if self.input_key not in blob:
            # ``Logger.warn`` is a deprecated alias; ``warning`` is the
            # supported spelling.
            self.log.warning("No hits found in key '{}'.".format(self.input_key))
            return blob
        hits = blob[self.input_key]
        # NOTE(review): ``self.calibration`` is not assigned anywhere in this
        # class; presumably it is provided from outside -- TODO confirm.
        chits = self.calibration.apply(hits)
        blob[self.output_key] = chits
        return blob
class BlobIndexer(Module):
    """Puts an incremented index in each blob for the key 'blob_index'"""

    def configure(self):
        # Index assigned to the next blob; counting starts at zero.
        self.blob_index = 0

    def process(self, blob):
        """Tag the blob with the current index, then advance the counter."""
        current = self.blob_index
        blob["blob_index"] = current
        self.blob_index = current + 1
        return blob
class StatusBar(Module):
    """Displays the current blob number."""

    def configure(self):
        # Number of ``process`` calls so far, starting at one.
        self.iteration = 1

    def process(self, blob):
        """Print the blob number (``every`` times the call count)."""
        blob_number = self.every * self.iteration
        prettyln("Blob {0:>7}".format(blob_number))
        self.iteration += 1
        return blob

    def finish(self):
        """Print a closing separator line."""
        prettyln(".", fill="=")
class TickTock(Module):
    """Display the elapsed time.

    Parameters
    ----------
    every: int, optional [default=1]
        Number of iterations between printout.
    """

    def configure(self):
        # Reference timestamp against which elapsed time is measured.
        self.t0 = time()

    def process(self, blob):
        """Print the minutes elapsed since ``configure`` and return the blob."""
        elapsed_minutes = (time() - self.t0) / 60
        prettyln("Time/min: {0:.3f}".format(elapsed_minutes))
        return blob
class MemoryObserver(Module):
    """Shows the maximum memory usage

    Parameters
    ----------
    every: int, optional [default=1]
        Number of iterations between printout.
    """

    def process(self, blob):
        """Print the value reported by ``peak_memory_usage`` and return the blob."""
        peak = peak_memory_usage()
        message = "Memory peak: {0:.3f} MB".format(peak)
        prettyln(message)
        return blob
class Siphon(Module):
    """A siphon to accumulate a given volume of blobs.

    Parameters
    ----------
    volume: int
        number of blobs to hold
    flush: bool
        discard blobs after accumulation

    """

    def configure(self):
        self.volume = self.require("volume")  # [blobs]
        self.flush = self.get("flush", default=False)

        # Blobs seen since the start (or since the last flush).
        self.blob_count = 0

    def process(self, blob):
        """Count the blob; return it only once the volume is exceeded.

        Returns None while the siphon is still filling up.
        """
        self.blob_count += 1
        if not self.blob_count > self.volume:
            return None
        log.debug("Siphone overflow reached!")
        if self.flush:
            log.debug("Flushing the siphon.")
            self.blob_count = 0
        return blob
class MultiFilePump(kp.Module):
    """Use the given pump to iterate through a list of files.

    The group_id will be reset so that it's unique for each iteration.

    Parameters
    ----------
    pump: Pump
        The pump to be used to generate the blobs.
    filenames: iterable(str)
        List of filenames.
    kwargs: dict(str -> any) optional
        Keyword arguments to be passed to the pump.

    """

    def configure(self):
        self.pump = self.require("pump")
        self.filenames = self.require("filenames")
        self.kwargs = self.get("kwargs", default={})
        # The generator is lazy: no file is touched until the first ``next``.
        self.blobs = self.blob_generator()
        self.cprint("Iterating through {} files.".format(len(self.filenames)))
        self.n_processed = 0
        self.group_id = 0

    def blob_generator(self):
        """Yield the blobs of every file in turn, tagged with their filename."""
        for path in self.filenames:
            self.cprint("Current file: {}".format(path))
            file_pump = self.pump(filename=path, **self.kwargs)
            for blob in file_pump:
                self._set_group_id(blob)
                blob["filename"] = path
                yield blob
                # Advanced only after the consumer asks for the next blob.
                self.group_id += 1
            self.n_processed += 1

    def _set_group_id(self, blob):
        """Write the running group_id into every ``kp.Table`` of the blob."""
        for key, entry in blob.items():
            if not isinstance(entry, kp.Table):
                continue
            if hasattr(entry, "group_id"):
                entry.group_id = self.group_id
            else:
                blob[key] = entry.append_columns("group_id", self.group_id)

    def process(self, blob):
        """Ignore the incoming blob and return the next generated one."""
        return next(self.blobs)

    def finish(self):
        """Report how many files were iterated to the end."""
        summary = "Fully processed {} out of {} files.".format(
            self.n_processed, len(self.filenames)
        )
        self.cprint(summary)
class LocalDBService(kp.Module):
    """Provides a local sqlite3 based database service to store information

    Parameters
    ----------
    filename: str
        Database passed to ``sqlite3.connect``.
    thread_safety: bool, optional [default=True]
        Forwarded to ``sqlite3.connect`` as ``check_same_thread``.

    The methods ``create_table``, ``table_exists``, ``insert_row`` and
    ``query`` are registered via ``self.expose`` under those names.
    """

    def configure(self):
        self.filename = self.require("filename")
        self.thread_safety = self.get("thread_safety", default=True)
        self.connection = None

        self.expose(self.create_table, "create_table")
        self.expose(self.table_exists, "table_exists")
        self.expose(self.insert_row, "insert_row")
        self.expose(self.query, "query")

        self._create_connection()

    def _create_connection(self):
        """Create database connection

        A ``sqlite3.Error`` is logged and leaves ``self.connection`` as None.
        """
        try:
            self.connection = sqlite3.connect(
                self.filename, check_same_thread=self.thread_safety
            )
            # ``sqlite3.version`` is deprecated (removed in Python 3.14);
            # report the version of the SQLite library instead.
            self.cprint(sqlite3.sqlite_version)
        except sqlite3.Error as exception:
            self.log.error(exception)

    def query(self, query):
        """Execute a SQL query and return the result of fetchall()"""
        cursor = self.connection.cursor()
        cursor.execute(query)
        return cursor.fetchall()

    def insert_row(self, table, column_names, values):
        """Insert a row into the table with a given list of values

        Every value is converted with ``str()`` and bound through a ``?``
        placeholder, so values containing quotes are stored verbatim.
        The table and column names are still interpolated into the
        statement (identifiers cannot be bound) and must be trusted.
        """
        cursor = self.connection.cursor()
        # Stringify as before, but let sqlite3 do the quoting.
        bound_values = [str(v) for v in values]
        query = "INSERT INTO {} ({}) VALUES ({})".format(
            table,
            ", ".join(column_names),
            ",".join("?" for _ in bound_values),
        )
        cursor.execute(query, bound_values)
        self.connection.commit()

    def create_table(self, name, columns, types, overwrite=False):
        """Create a table with given columns and types, overwrite if specified

        The `types` should be a list of SQL types, like ["INT", "TEXT", "INT"]

        The table name, column names and types are interpolated into the
        statement and must come from a trusted source.
        """
        cursor = self.connection.cursor()
        if overwrite:
            cursor.execute("DROP TABLE IF EXISTS {}".format(name))

        cursor.execute(
            "CREATE TABLE {} ({})".format(
                name, ", ".join(["{} {}".format(*c) for c in zip(columns, types)])
            )
        )
        self.connection.commit()

    def table_exists(self, name):
        """Check if a table exists in the database"""
        cursor = self.connection.cursor()
        # Bind the name instead of formatting it into the SQL text.
        cursor.execute(
            "SELECT count(name) FROM sqlite_master "
            "WHERE type='table' AND name=?",
            (name,),
        )
        return cursor.fetchone()[0] == 1

    def finish(self):
        """Close the database connection if one was opened."""
        if self.connection:
            self.connection.close()
class Observer(kp.Module):
    """A simple helper to observe the blobs in a test pipeline.

    Parameters
    ----------
    count: int
        The exact number of iterations the pipeline has to drain
    required_keys: list(str)
        A list of keys which has to be present in a blob in every cycle.
    """

    def configure(self):
        self.count = self.get("count")
        self.required_keys = self.get("required_keys", default=[])
        # Number of blobs actually seen by ``process``.
        self._count = 0

    def process(self, blob):
        """Count the blob and assert that all required keys are present."""
        self._count += 1
        for key in self.required_keys:
            assert key in blob
        return blob

    def finish(self):
        """Print target vs. actual count and assert they match.

        The comparison is skipped when no target ``count`` was given.
        """
        # ``self.count`` is the configured target, ``self._count`` the
        # observed number of iterations.
        print(f"Target count={self.count}, actual count={self._count}")
        if self.count is not None:
            assert self.count == self._count
class FilePump(kp.Module):
    """A basic iterator for a list of files.

    Parameters
    ----------
    filenames: iterable(str)
        The filenames to be iterated over which are put into ``blob["filename"]``

    """

    def configure(self):
        self.filenames = self.require("filenames")
        self.blobs = self.blob_generator()

    def blob_generator(self):
        """Yield one blob per filename, holding only the ``"filename"`` key."""
        for path in self.filenames:
            yield kp.Blob({"filename": path})

    def process(self, blob):
        """Merge the next generated blob into the incoming one and return it."""
        upcoming = next(self.blobs)
        blob.update(upcoming)
        return blob