# Filename: common.py
# -*- coding: utf-8 -*-
# pylint: disable=locally-disabled
"""
A collection of commonly used modules.
"""
import sqlite3
from time import time
import numpy as np
import km3pipe as kp
from km3pipe import Module, Blob
from km3pipe.tools import prettyln
from km3pipe.sys import peak_memory_usage
# Module-level logger for this file; Siphon below logs through it.
log = kp.logger.get_logger(__name__)
class Dump(Module):
    """Print the content of the blob.

    Parameters
    ----------
    keys: collection(string), optional [default=None]
        Keys to print. If None, print all keys.
    full: bool, default=False
        Print blob values too, not just the keys?
    """

    def process(self, blob):
        """Print the selected keys of ``blob`` and pass it on unchanged.

        Without explicit ``self.keys`` every key of the blob is printed in
        sorted order; otherwise the given keys are printed in their given
        order. When ``self.full`` is set, each key is followed by the
        ``repr`` of its value and an empty line. A dashed separator line
        closes the dump.
        """
        if self.keys is None:
            selected = sorted(blob.keys())
        else:
            selected = self.keys
        for name in selected:
            print(name + ":")
            if not self.full:
                continue
            print(repr(blob[name]))
            print("")
        print("----------------------------------------\n")
        return blob
class Delete(Module):
    """Remove specific keys from the blob.

    Parameters
    ----------
    keys: collection(string), optional
        Keys to remove.
    """

    def process(self, blob):
        """Drop every key listed in ``self.keys`` from ``blob`` in place.

        Keys that are not present in the blob are silently skipped.
        """
        for unwanted in self.keys:
            blob.pop(unwanted, None)
        return blob
class Keep(Module):
    """Keep only specified keys in the blob.

    Parameters
    ----------
    keys: collection(string), optional
        Keys to keep. Everything else is removed.
    h5locs: collection(string)
        Additionally keep every value that has an ``h5loc`` attribute
        starting with one of these prefixes.
    """

    def process(self, blob):
        """Return a new ``Blob`` holding only the entries selected for keeping.

        An entry is kept when its key is listed in ``self.keys``, or when its
        value has an ``h5loc`` attribute that starts with one of the prefixes
        in ``self.h5locs``. Entries are visited in the iteration order of
        ``blob``.
        """
        # Build the prefix tuple once per blob instead of once per entry;
        # str.startswith accepts a tuple of alternatives. A missing (None)
        # h5locs setting simply matches nothing.
        prefixes = () if self.h5locs is None else tuple(self.h5locs)
        out = Blob()
        for key, value in blob.items():
            if key in self.keys:
                out[key] = value
            elif hasattr(value, "h5loc") and value.h5loc.startswith(prefixes):
                out[key] = value
        return out
class HitCounter(Module):
    """Prints the number of hits"""

    def process(self, blob):
        """Print how many entries are stored under ``blob["Hit"]``.

        Blobs without a ``"Hit"`` key are passed on without output.
        """
        # Only the lookup is guarded, so a KeyError raised while counting or
        # printing is not mistaken for "no hits in this blob".
        try:
            hits = blob["Hit"]
        except KeyError:
            return blob
        self.cprint("Number of hits: {0}".format(len(hits)))
        return blob
class HitCalibrator(Module):
    """A very basic hit calibrator, which requires a `Calibration` module."""

    def process(self, blob):
        """Calibrate the hits under ``self.input_key`` into ``self.output_key``.

        The hits are passed to ``self.calibration.apply`` and the result is
        stored in the blob. If ``self.input_key`` is not in the blob, a
        warning is logged and the blob is returned unchanged.
        """
        if self.input_key not in blob:
            # Logger.warn is a deprecated alias; use Logger.warning with lazy
            # %-style arguments.
            self.log.warning("No hits found in key '%s'.", self.input_key)
            return blob
        blob[self.output_key] = self.calibration.apply(blob[self.input_key])
        return blob
class BlobIndexer(Module):
    """Puts an incremented index in each blob for the key 'blob_index'"""

    def process(self, blob):
        """Store the running index under ``"blob_index"`` and advance it by one."""
        current = self.blob_index
        blob["blob_index"] = current
        self.blob_index = current + 1
        return blob
class StatusBar(Module):
    """Displays the current blob number."""

    def process(self, blob):
        """Print the current blob number (``every * iteration``) and advance."""
        label = "Blob {0:>7}".format(self.every * self.iteration)
        prettyln(label)
        self.iteration += 1
        return blob

    def finish(self):
        """Close the status output with a ``prettyln`` line filled with ``=``."""
        prettyln(".", fill="=")
class TickTock(Module):
    """Display the elapsed time.

    Parameters
    ----------
    every: int, optional [default=1]
        Number of iterations between printout.
    """

    def process(self, blob):
        """Print the minutes elapsed since ``self.t0`` and pass the blob on."""
        elapsed_minutes = (time() - self.t0) / 60
        prettyln("Time/min: {0:.3f}".format(elapsed_minutes))
        return blob
class MemoryObserver(Module):
    """Shows the maximum memory usage

    Parameters
    ----------
    every: int, optional [default=1]
        Number of iterations between printout.
    """

    def process(self, blob):
        """Print the value returned by ``peak_memory_usage`` and pass the blob on.

        NOTE(review): the printout labels the value as MB; this assumes
        ``peak_memory_usage`` reports megabytes -- confirm in km3pipe.sys.
        """
        prettyln("Memory peak: {0:.3f} MB".format(peak_memory_usage()))
        return blob
class Siphon(Module):
    """A siphon to accumulate a given volume of blobs.

    Parameters
    ----------
    volume: int
        number of blobs to hold
    flush: bool
        discard blobs after accumulation
    """

    def process(self, blob):
        """Hold back blobs until more than ``self.volume`` have arrived.

        While the count is at or below the volume, ``None`` is returned
        (presumably the km3pipe pipeline treats this as "do not pass the
        blob on" -- confirm against ``Pipeline``). Once the volume is
        exceeded, the blob is returned; with ``self.flush`` set, the counter
        is reset so that accumulation starts over.
        """
        self.blob_count += 1
        if self.blob_count <= self.volume:
            return None
        log.debug("Siphone overflow reached!")
        if self.flush:
            log.debug("Flushing the siphon.")
            self.blob_count = 0
        return blob
class MultiFilePump(kp.Module):
    """Use the given pump to iterate through a list of files.

    The group_id will be reset so that it's unique for each iteration.

    Parameters
    ----------
    pump: Pump
        The pump to be used to generate the blobs.
    filenames: iterable(str)
        List of filenames.
    kwargs: dict(str -> any) optional
        Keyword arguments to be passed to the pump.
    """

    def blob_generator(self):
        """Yield every blob of every file, tagged with group id and filename.

        A fresh pump is created per file. ``self.group_id`` advances after
        each yielded blob, ``self.n_processed`` after each fully drained file.
        """
        for fname in self.filenames:
            self.cprint("Current file: {}".format(fname))
            file_pump = self.pump(filename=fname, **self.kwargs)
            for blob in file_pump:
                self._set_group_id(blob)
                blob["filename"] = fname
                yield blob
                self.group_id += 1
            self.n_processed += 1

    def _set_group_id(self, blob):
        """Write the current ``self.group_id`` into every ``kp.Table`` in ``blob``.

        Tables that already expose ``group_id`` are updated in place; all
        other tables are replaced by a copy with an appended ``group_id``
        column. Only existing keys are reassigned, so the blob's key set does
        not change while it is being iterated.
        """
        for key, entry in blob.items():
            if not isinstance(entry, kp.Table):
                continue
            if hasattr(entry, "group_id"):
                entry.group_id = self.group_id
                continue
            blob[key] = entry.append_columns("group_id", self.group_id)

    def process(self, blob):
        """Ignore the incoming blob and return the next one from ``self.blobs``.

        NOTE(review): ``self.blobs`` is presumably created from
        ``blob_generator()`` during configuration -- confirm.
        """
        return next(self.blobs)

    def finish(self):
        """Report how many of the given files were fully processed."""
        total = len(self.filenames)
        self.cprint(
            "Fully processed {} out of {} files.".format(self.n_processed, total)
        )
class LocalDBService(kp.Module):
    """Provides a local sqlite3 based database service to store information.

    Relies on ``self.filename`` (database path passed to ``sqlite3.connect``)
    and ``self.thread_safety`` (forwarded as ``check_same_thread``); these are
    presumably set in a ``configure`` method not shown here -- confirm.
    """

    def _create_connection(self):
        """Create the database connection and store it in ``self.connection``.

        On ``sqlite3.Error`` the exception is logged and ``self.connection``
        is set to ``None`` so later code (e.g. ``finish``) can test for it.
        """
        try:
            self.connection = sqlite3.connect(
                self.filename, check_same_thread=self.thread_safety
            )
        except sqlite3.Error as exception:
            self.log.error(exception)
            self.connection = None
            return
        # ``sqlite3.version`` (the pysqlite module version) is deprecated since
        # Python 3.12 and removed in 3.14, where the lookup raises an
        # AttributeError that the handler above would not catch. Report the
        # SQLite library version instead.
        self.cprint(sqlite3.sqlite_version)

    def query(self, query, parameters=()):
        """Execute a SQL query and return the result of fetchall().

        ``parameters`` (optional) are bound to ``?`` placeholders in
        ``query``; with the default empty tuple the query runs unparameterized.
        """
        cursor = self.connection.cursor()
        cursor.execute(query, parameters)
        return cursor.fetchall()

    def insert_row(self, table, column_names, values):
        """Insert a row into the table with a given list of values.

        Every value is converted with ``str`` and bound as a query parameter
        rather than pasted into the SQL text, so values containing quotes are
        stored correctly and cannot alter the statement. Column type affinity
        still converts the text where the column type asks for it. Table and
        column names are interpolated directly and must come from trusted code.
        """
        row = [str(value) for value in values]
        query = "INSERT INTO {} ({}) VALUES ({})".format(
            table, ", ".join(column_names), ", ".join("?" * len(row))
        )
        cursor = self.connection.cursor()
        cursor.execute(query, row)
        self.connection.commit()

    def create_table(self, name, columns, types, overwrite=False):
        """Create a table with given columns and types, overwrite if specified.

        The `types` should be a list of SQL types, like ["INT", "TEXT", "INT"].
        Identifiers cannot be bound as parameters, so ``name``, ``columns``
        and ``types`` are interpolated and must come from trusted code.
        """
        column_defs = ", ".join(
            "{} {}".format(column, sql_type) for column, sql_type in zip(columns, types)
        )
        cursor = self.connection.cursor()
        if overwrite:
            cursor.execute("DROP TABLE IF EXISTS {}".format(name))
        cursor.execute("CREATE TABLE {} ({})".format(name, column_defs))
        self.connection.commit()

    def table_exists(self, name):
        """Check if a table exists in the database."""
        cursor = self.connection.cursor()
        # The table name is a value in this query, so it can be bound safely.
        cursor.execute(
            "SELECT count(name) FROM sqlite_master WHERE type='table' AND name=?",
            (str(name),),
        )
        return cursor.fetchone()[0] == 1

    def finish(self):
        """Close the database connection if one was opened."""
        # getattr: the attribute is absent if _create_connection never ran.
        connection = getattr(self, "connection", None)
        if connection is not None:
            connection.close()
class Observer(kp.Module):
    """A simple helper to observe the blobs in a test pipeline.

    Parameters
    ----------
    count: int
        The exact number of iterations the pipeline has to drain
    required_keys: list(str)
        A list of keys which has to be present in a blob in every cycle.
    """

    def process(self, blob):
        """Count this cycle and assert that every required key is present."""
        self._count += 1
        for key in self.required_keys:
            assert key in blob, "Required key '{}' is missing from the blob.".format(
                key
            )
        return blob

    def finish(self):
        """Print the expected and observed cycle counts and assert they match.

        The assertion is skipped when no target ``count`` was given
        (``self.count is None``).
        """
        # self.count is the configured target; self._count is the number of
        # cycles actually observed in process().
        print(f"Target count={self.count}, actual count={self._count}")
        if self.count is not None:
            assert self.count == self._count, (
                f"Expected {self.count} iterations, observed {self._count}."
            )
class FilePump(kp.Module):
    """A basic iterator for a list of files.

    Parameters
    ----------
    filenames: iterable(str)
        The filenames to be iterated over which are put into ``blob["filename"]``
    """

    def blob_generator(self):
        """Yield one new ``kp.Blob`` per filename, stored under ``"filename"``."""
        for name in self.filenames:
            yield kp.Blob(dict(filename=name))

    def process(self, blob):
        """Merge the next generated blob into ``blob`` and return it."""
        generated = next(self.blobs)
        blob.update(generated)
        return blob