# Filename: calib.py
# pylint: disable=locally-disabled
"""
Calibration.
"""
import awkward as ak
import numba as nb
import numpy as np
import km3db
import km3io
from thepipe import Module
from km3pipe.hardware import Detector
from km3pipe.dataclasses import Table
from km3pipe.tools import istype
from km3pipe.logger import get_logger
__author__ = "Tamas Gal"
__copyright__ = "Copyright 2016, Tamas Gal and the KM3NeT collaboration."
__credits__ = ["Thomas Heid"]
__license__ = "MIT"
__maintainer__ = "Tamas Gal"
__email__ = "tgal@km3net.de"
__status__ = "Development"
[docs]
log = get_logger(__name__)
[docs]
class Calibration(Module):
"""A module which applies time, position and rotation corrections to hits.
This module also calibrates MC hits, but be aware, t0s are not appended to
the MC hit times.
Additionally, the global PMT ID is added to regular hits as ``pmt_id`` and
in case of MC hits, the ``dom_id`` and ``channel_id`` (DAQ) are set.
Parameters
----------
apply: bool, optional [default=True]
Apply the calibration to the hits (add position/direction/t0)?
filename: str, optional [default=None]
DetX file with detector description.
det_id: int, optional
.detx ID of detector (when retrieving from database).
t0set: optional
t0set (when retrieving from database).
calibset: optional
calibset (when retrieving from database).
key: str, optional [default="Hits"]
the blob key of the hits
outkey: str, optional [default="CalibHits"]
the output blob key of the calibrated hits
key_mc: str, optional [default="McHits"]
the blob key of the MC hits (if present)
outkey_mc: str, optional [default="CalibMcHits"]
the output blob key of the calibrated MC hits
"""
__name__ = "Calibration"
[docs]
def process(self, blob):
if self._should_apply:
blob[self.outkey] = self.apply(blob[self.key])
if self.key_mc in blob:
blob[self.outkey_mc] = self.apply(blob[self.key_mc])
return blob
[docs]
def get_detector(self):
"""Return the detector"""
return self.detector
[docs]
def dus(self, hits):
"""Return the DUs for given hits"""
dom_ids = hits["dom_id"]
return _dus(dom_ids, len(dom_ids), self._calib_by_dom_and_channel)
[docs]
def floors(self, hits):
"""Return the floors for given hits"""
dom_ids = hits["dom_id"]
return _floors(dom_ids, len(dom_ids), self._calib_by_dom_and_channel)
[docs]
def apply_t0(self, hits):
"""Apply only t0s"""
apply_t0_nb(hits.time, hits.dom_id, hits.channel_id, self._lookup_tables)
return hits
[docs]
def apply(self, hits, no_copy=False, correct_slewing=True, slewing_variant=3):
"""Add x, y, z, t0 (and du, floor if DataFrame) columns to the hits."""
if not no_copy:
try:
hits = hits.copy()
except AttributeError: # probably a km3io object
pass
if isinstance(hits, (ak.Array, ak.Record, km3io.rootio.Branch)):
if hasattr(hits, "dom_id"):
hits = Table(
dict(
dom_id=hits.dom_id,
channel_id=hits.channel_id,
time=hits.t,
tot=hits.tot,
triggered=hits.trig,
)
)
else: # mc_hits in km3io
hits = Table(
dict(
pmt_id=hits.pmt_id,
time=hits.t,
a=hits.a,
# TODO: Not all MC files have these two fields
# pure_a=hits.pure_a,
# pure_t=hits.pure_t,
origin=hits.origin,
)
)
if istype(hits, "DataFrame"):
# do we ever see McHits here?
hits = Table.from_template(hits, "Hits")
is_mc = None
if hasattr(hits, "dom_id") and hasattr(hits, "channel_id"):
try:
(
dir_x,
dir_y,
dir_z,
du,
floor,
pos_x,
pos_y,
pos_z,
t0,
pmt_id,
) = _get_calibration_for_hits(hits, self._calib_by_dom_and_channel)
except KeyError as e:
self.log.critical("Wrong calibration (DETX) data provided.")
raise
is_mc = False
elif hasattr(hits, "pmt_id"):
try:
(
dir_x,
dir_y,
dir_z,
du,
floor,
pos_x,
pos_y,
pos_z,
t0,
dom_id,
channel_id,
) = _get_calibration_for_mchits(hits, self._calib_by_pmt_id)
except KeyError as e:
self.log.critical("Wrong calibration (DETX) data provided.")
raise
is_mc = True
else:
raise TypeError(
"Don't know how to apply calibration to '{0}'. "
"We need at least 'dom_id' and 'channel_id', or "
"'pmt_id'.".format(hits.name)
)
if hasattr(hits, "time") and not is_mc:
if hits.time.dtype != t0.dtype:
time = hits.time.astype("f4") + t0.astype("f4")
hits = hits.drop_columns(["time"])
hits = hits.append_columns(["time"], [time])
else:
hits.time += t0
hits_data = {}
for colname in hits.dtype.names:
hits_data[colname] = hits[colname]
calib = {
"dir_x": dir_x,
"dir_y": dir_y,
"dir_z": dir_z,
"du": du.astype(np.uint8),
"floor": floor.astype(np.uint8),
"pos_x": pos_x,
"pos_y": pos_y,
"pos_z": pos_z,
"t0": t0,
}
if is_mc:
calib["dom_id"] = dom_id.astype(np.int32)
calib["channel_id"] = channel_id.astype(np.int32)
else:
calib["pmt_id"] = pmt_id.astype(np.int32)
hits_data.update(calib)
if correct_slewing and not is_mc:
hits_data["time"] -= slew(hits_data["tot"], variant=slewing_variant)
return Table(
hits_data, h5loc=hits.h5loc, split_h5=hits.split_h5, name=hits.name
)
def _create_dom_channel_lookup(self):
data = nb.typed.Dict.empty(
key_type=nb.types.i8, value_type=nb.types.float64[:, :]
)
for pmt in self.detector.pmts:
if pmt.dom_id not in data:
data[pmt.dom_id] = np.zeros((31, 10))
data[pmt.dom_id][pmt.channel_id] = np.asarray(
[
pmt.pos_x,
pmt.pos_y,
pmt.pos_z,
pmt.dir_x,
pmt.dir_y,
pmt.dir_z,
pmt.t0,
pmt.du,
pmt.floor,
pmt.pmt_id,
],
dtype=np.float64,
)
self._calib_by_dom_and_channel = data
self._lookup_tables = [(dom, cal) for dom, cal in data.items()]
def _create_pmt_id_lookup(self):
data = nb.typed.Dict.empty(key_type=nb.types.i8, value_type=nb.types.float64[:])
for pmt in self.detector.pmts:
data[pmt.pmt_id] = np.asarray(
[
pmt.pos_x,
pmt.pos_y,
pmt.pos_z,
pmt.dir_x,
pmt.dir_y,
pmt.dir_z,
pmt.t0,
pmt.du,
pmt.floor,
pmt.dom_id,
pmt.channel_id,
],
dtype=np.float64,
)
self._calib_by_pmt_id = data
def __repr__(self):
return self.__str__()
def __str__(self):
return "Calibration: det_id({0})".format(self.det_id)
@nb.njit
[docs]
def apply_t0_nb(times, dom_ids, channel_ids, lookup_tables):
"""Apply t0s using a lookup table of tuples (dom_id, calib)"""
dom_id = 0
lookup = np.empty((31, 9))
for i in range(len(times)):
cur_dom_id = dom_ids[i]
if cur_dom_id != dom_id:
dom_id = cur_dom_id
for (d, m) in lookup_tables:
if d == dom_id:
np.copyto(lookup, m)
t0 = lookup[channel_ids[i]][6]
times[i] += t0
@nb.jit(nopython=True)
def _get_calibration_for_hits(hits, lookup):
"""Append the position, direction and t0 columns and add t0 to time"""
n = len(hits)
cal = np.empty((n, 10))
for i in range(n):
calib = lookup[hits["dom_id"][i]][hits["channel_id"][i]]
cal[i] = calib
dir_x = cal[:, 3]
dir_y = cal[:, 4]
dir_z = cal[:, 5]
du = cal[:, 7]
floor = cal[:, 8]
pos_x = cal[:, 0]
pos_y = cal[:, 1]
pos_z = cal[:, 2]
pmt_id = cal[:, 9]
t0 = cal[:, 6]
return [dir_x, dir_y, dir_z, du, floor, pos_x, pos_y, pos_z, t0, pmt_id]
@nb.jit(nopython=True)
def _get_calibration_for_mchits(hits, lookup):
"""Append the position, direction and t0 columns and add t0 to time"""
n_hits = len(hits)
cal = np.empty((n_hits, 11))
for i in range(n_hits):
cal[i] = lookup[hits["pmt_id"][i]]
dir_x = cal[:, 3]
dir_y = cal[:, 4]
dir_z = cal[:, 5]
du = cal[:, 7]
floor = cal[:, 8]
pos_x = cal[:, 0]
pos_y = cal[:, 1]
pos_z = cal[:, 2]
t0 = cal[:, 6]
dom_id = cal[:, 9]
channel_id = cal[:, 10]
return [dir_x, dir_y, dir_z, du, floor, pos_x, pos_y, pos_z, t0, dom_id, channel_id]
@nb.njit
def _dus(dom_ids, n, lookup):
dus = np.empty(n, dtype=np.uint8)
for i in range(n):
# DU is at index 7
dus[i] = lookup[dom_ids[i]][0][7] # looking only at channel ID 0
return dus
@nb.njit
def _floors(dom_ids, n, lookup):
dus = np.empty(n, dtype=np.uint8)
for i in range(n):
# floor is at index 8
dus[i] = lookup[dom_ids[i]][0][8] # looking only at channel ID 0
return dus
[docs]
class CalibrationService(Module):
"""A service which provides calibration routines for hits
Parameters
----------
filename: str, optional [default=None]
DetX file with detector description.
det_id: int, optional
.detx ID of detector (when retrieving from database).
t0set: optional
t0set (when retrieving from database).
calibset: optional
calibset (when retrieving from database).
detector: kp.hardware.Detector, optional
"""
__name__ = "Calibration"
[docs]
def load_calibration(self, filename=None, det_id=None, t0set=None, calibset=None):
"""Load another calibration"""
self.filename = filename
self.det_id = det_id
self.t0set = t0set
self.calibset = calibset
self._detector = None
self._calibration = None
[docs]
def calibrate(self, hits, correct_slewing=True):
return self.calibration.apply(hits, correct_slewing=correct_slewing)
@property
[docs]
def detector(self):
if self._detector is None:
self._detector = self.calibration.detector
return self._detector
[docs]
def get_detector(self):
"""Extra getter to be as lazy as possible (expose triggers otherwise"""
return self.detector
@property
[docs]
def calibration(self):
if self._calibration is None:
self._calibration = Calibration(
filename=self.filename,
det_id=self.det_id,
t0set=self.t0set,
calibset=self.calibset,
)
return self._calibration
[docs]
def get_calibration(self):
"""Extra getter to be as lazy as possible (expose triggers otherwise"""
return self.calibration
[docs]
def correct_slewing(self, hits):
"""Apply time slewing correction to the hit times"""
hits.time -= slew(hits.tot)
[docs]
def slew(tot, variant=3):
"""Calculate the time slewing of a PMT response for a given ToT
Parameters
----------
tot: int or np.array(int)
Time over threshold value of a hit
variant: int, optional
The variant of the slew calculation.
1: The first parametrisation approach
2: Jannik's improvement of the parametrisation
3: The latest lookup table approach based on lab measurements.
Returns
-------
time: int
Time slewing, which has to be subtracted from the original hit time.
"""
if variant == 1:
return _slew_parametrised(7.70824, 0.00879447, -0.0621101, -1.90226, tot)
if variant == 2:
return _slew_parametrised(
13.6488662517, -0.128744123166, -0.0174837749244, -4.47119633965, tot
)
if variant == 3:
if isinstance(tot, (int, np.integer)):
return _SLEWS[tot]
return _slew_tabulated(np.array(_SLEWS), tot)
raise ValueError("Unknown slew calculation variant '{}'".format(variant))
@nb.jit(nopython=True)
def _slew_parametrised(p0, p1, p2, p3, tot):
return p0 * np.exp(p1 * np.sqrt(tot) + p2 * tot) + p3
@nb.jit(nopython=True)
def _slew_tabulated(tab, tots):
n = len(tots)
out = np.empty(n)
for i in range(n):
out[i] = tab[tots[i]]
return out
_SLEWS = np.array(
[
8.01,
7.52,
7.05,
6.59,
6.15,
5.74,
5.33,
4.95,
4.58,
4.22,
3.89,
3.56,
3.25,
2.95,
2.66,
2.39,
2.12,
1.87,
1.63,
1.40,
1.19,
0.98,
0.78,
0.60,
0.41,
0.24,
0.07,
-0.10,
-0.27,
-0.43,
-0.59,
-0.75,
-0.91,
-1.08,
-1.24,
-1.41,
-1.56,
-1.71,
-1.85,
-1.98,
-2.11,
-2.23,
-2.35,
-2.47,
-2.58,
-2.69,
-2.79,
-2.89,
-2.99,
-3.09,
-3.19,
-3.28,
-3.37,
-3.46,
-3.55,
-3.64,
-3.72,
-3.80,
-3.88,
-3.96,
-4.04,
-4.12,
-4.20,
-4.27,
-4.35,
-4.42,
-4.49,
-4.56,
-4.63,
-4.70,
-4.77,
-4.84,
-4.90,
-4.97,
-5.03,
-5.10,
-5.16,
-5.22,
-5.28,
-5.34,
-5.40,
-5.46,
-5.52,
-5.58,
-5.63,
-5.69,
-5.74,
-5.80,
-5.85,
-5.91,
-5.96,
-6.01,
-6.06,
-6.11,
-6.16,
-6.21,
-6.26,
-6.31,
-6.36,
-6.41,
-6.45,
-6.50,
-6.55,
-6.59,
-6.64,
-6.68,
-6.72,
-6.77,
-6.81,
-6.85,
-6.89,
-6.93,
-6.98,
-7.02,
-7.06,
-7.09,
-7.13,
-7.17,
-7.21,
-7.25,
-7.28,
-7.32,
-7.36,
-7.39,
-7.43,
-7.46,
-7.50,
-7.53,
-7.57,
-7.60,
-7.63,
-7.66,
-7.70,
-7.73,
-7.76,
-7.79,
-7.82,
-7.85,
-7.88,
-7.91,
-7.94,
-7.97,
-7.99,
-8.02,
-8.05,
-8.07,
-8.10,
-8.13,
-8.15,
-8.18,
-8.20,
-8.23,
-8.25,
-8.28,
-8.30,
-8.32,
-8.34,
-8.37,
-8.39,
-8.41,
-8.43,
-8.45,
-8.47,
-8.49,
-8.51,
-8.53,
-8.55,
-8.57,
-8.59,
-8.61,
-8.62,
-8.64,
-8.66,
-8.67,
-8.69,
-8.70,
-8.72,
-8.74,
-8.75,
-8.76,
-8.78,
-8.79,
-8.81,
-8.82,
-8.83,
-8.84,
-8.86,
-8.87,
-8.88,
-8.89,
-8.90,
-8.92,
-8.93,
-8.94,
-8.95,
-8.96,
-8.97,
-8.98,
-9.00,
-9.01,
-9.02,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
-9.04,
]
)