Coverage for src/km3pipe/tests/test_calib.py: 100%
282 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-08 03:14 +0000
1# Filename: test_calib.py
2# pylint: disable=C0111,E1003,R0904,C0103,R0201,C0102
3from os.path import dirname, join
4import functools
5import operator
6import shutil
7import sys
8import tempfile
10import km3io
11from thepipe import Module, Pipeline
12import km3pipe as kp
13from km3pipe.dataclasses import Table
14from km3pipe.io.daq import DAQEvent
15from km3pipe.hardware import Detector
16from km3pipe.io.hdf5 import HDF5Sink
17from km3pipe.testing import TestCase, MagicMock, patch, skip, skipif, data_path
18from km3pipe.calib import Calibration, CalibrationService, slew
20from .test_hardware import EXAMPLE_DETX
22import numpy as np
23import tables as tb
# Module metadata: authorship, licensing and development status.
__author__ = "Tamas Gal"
__maintainer__ = __author__
__email__ = "tgal@km3net.de"
__copyright__ = "Copyright 2016, Tamas Gal and the KM3NeT collaboration."
__credits__ = []
__license__ = "MIT"
__status__ = "Development"
class TestCalibration(TestCase):
    """Tests for the Calibration class"""

    def test_init_with_wrong_file_extension(self):
        """A filename without a supported extension raises NotImplementedError."""
        with self.assertRaises(NotImplementedError):
            Calibration(filename="foo")

    @patch("km3pipe.calib.Detector")
    def test_init_with_filename(self, mock_detector):
        """A ``.detx`` filename is forwarded to ``Detector``."""
        Calibration(filename="foo.detx")
        mock_detector.assert_called_with(filename="foo.detx")

    @patch("km3pipe.calib.Detector")
    def test_init_with_det_id(self, mock_detector):
        """``det_id``, ``calibset`` and ``t0set`` are forwarded to ``Detector``."""
        Calibration(det_id=1)
        mock_detector.assert_called_with(t0set=None, calibset=None, det_id=1)
        Calibration(det_id=1, calibset=2, t0set=3)
        mock_detector.assert_called_with(t0set=3, calibset=2, det_id=1)

    def test_init_with_detector(self):
        """An already constructed ``Detector`` instance is accepted."""
        det = Detector(data_path("detx/detx_v1.detx"))
        Calibration(detector=det)

    def test_apply_to_hits_with_pmt_id_aka_mc_hits(self):
        """Hits identified by ``pmt_id`` get positions and t0, but t0 is not added to the time."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))

        hits = Table({"pmt_id": [1, 2, 1], "time": [10.1, 11.2, 12.3]})

        chits = calib.apply(hits, correct_slewing=False)

        assert len(hits) == len(chits)

        a_hit = chits[0]
        self.assertAlmostEqual(1.1, a_hit.pos_x)
        self.assertAlmostEqual(10, a_hit.t0)
        self.assertAlmostEqual(10.1, a_hit.time)  # t0 should not be applied

        a_hit = chits[1]
        self.assertAlmostEqual(1.4, a_hit.pos_x)
        self.assertAlmostEqual(20, a_hit.t0)
        self.assertAlmostEqual(11.2, a_hit.time)  # t0 should not be applied

    def test_apply_to_hits_with_pmt_id_aka_mc_hits_from_km3io(self):
        """MC hits read through km3io are calibrated with DOM/channel/position info."""
        calib = Calibration(filename=data_path("detx/KM3NeT_-00000001_20171212.detx"))
        f = km3io.OfflineReader(
            data_path(
                "offline/mcv6.gsg_nue-CCHEDIS_1e4-1e6GeV.sirene.jte.jchain.aanet.1.root"
            )
        )

        # Take the first event explicitly. A ``for ... break`` loop would let
        # the test pass without checking anything if the file had no events.
        event = next(iter(f))
        chits = calib.apply(event.mc_hits)
        assert 840 == len(chits.t0)
        assert np.allclose([3, 26, 24, 4, 23, 25], chits.channel_id[:6])
        assert np.allclose([3401, 3401, 3406, 3411, 5501, 5501], chits.dom_id[:6])
        assert np.allclose([1, 1, 6, 11, 1, 1], chits.floor[:6])
        assert np.allclose([34, 34, 34, 34, 55, 55], chits.du[:6])
        assert np.allclose(
            [
                1679.18706571,
                1827.14262054,
                1926.71722628,
                2433.83097585,
                1408.35942832,
                1296.51397496,
            ],
            chits.time[:6],
        )
        assert np.allclose(
            [2.034, 1.847, 1.938, 2.082, -54.96, -55.034], chits.pos_x[:6]
        )
        assert np.allclose(
            [-233.415, -233.303, -233.355, -233.333, -341.346, -341.303],
            chits.pos_y[:6],
        )
        assert np.allclose(
            [65.059, 64.83, 244.83, 425.111, 64.941, 64.83], chits.pos_z[:6]
        )
        assert np.allclose([4, 4, 4, 26, 4, 4], f.mc_hits.origin[0][:6].tolist())
        assert np.allclose(
            [36835, 36881, 37187, 37457, 60311, 60315],
            f.mc_hits.pmt_id[0][:6].tolist(),
        )

    def test_dus(self):
        """``dus`` maps (dom_id, channel_id) hits to their detection unit."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))

        hits = Table(
            {"dom_id": [2, 6, 3], "channel_id": [0, 1, 2], "time": [10.1, 11.2, 12.3]}
        )

        dus = calib.dus(hits)
        assert np.allclose([1, 2, 1], dus)

    def test_floors(self):
        """``floors`` maps (dom_id, channel_id) hits to their floor."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))

        hits = Table(
            {"dom_id": [2, 6, 3], "channel_id": [0, 1, 2], "time": [10.1, 11.2, 12.3]}
        )

        floors = calib.floors(hits)
        assert np.allclose([2, 3, 3], floors)

    def test_apply_to_hits_with_dom_id_and_channel_id(self):
        """Hits identified by (dom_id, channel_id) get positions and t0 added to the time."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))

        hits = Table(
            {"dom_id": [2, 3, 3], "channel_id": [0, 1, 2], "time": [10.1, 11.2, 12.3]}
        )

        chits = calib.apply(hits, correct_slewing=False)

        assert len(hits) == len(chits)

        a_hit = chits[0]
        self.assertAlmostEqual(2.1, a_hit.pos_x)
        self.assertAlmostEqual(40, a_hit.t0)
        t0 = a_hit.t0
        self.assertAlmostEqual(10.1 + t0, a_hit.time)

        a_hit = chits[1]
        self.assertAlmostEqual(3.4, a_hit.pos_x)
        self.assertAlmostEqual(80, a_hit.t0)
        t0 = a_hit.t0
        self.assertAlmostEqual(11.2 + t0, a_hit.time)

    def test_assert_apply_adds_dom_id_and_channel_id_to_mc_hits(self):
        """Calibrating ``pmt_id`` hits adds ``dom_id`` and ``channel_id`` columns."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))
        hits = Table({"pmt_id": [1, 2, 1], "time": [10.1, 11.2, 12.3]})
        chits = calib.apply(hits)
        self.assertListEqual([1, 1, 1], list(chits.dom_id))
        self.assertListEqual([0, 1, 0], list(chits.channel_id))

    def test_assert_apply_adds_pmt_id_to_hits(self):
        """Calibrating (dom_id, channel_id) hits adds a ``pmt_id`` column."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))
        hits = Table(
            {"dom_id": [2, 3, 3], "channel_id": [0, 1, 2], "time": [10.1, 11.2, 12.3]}
        )
        chits = calib.apply(hits, correct_slewing=False)
        self.assertListEqual([4, 8, 9], list(chits.pmt_id))

    def test_apply_to_hits_with_pmt_id_with_wrong_calib_raises(self):
        """An unknown ``pmt_id`` raises KeyError."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))

        hits = Table({"pmt_id": [999], "time": [10.1]})

        with self.assertRaises(KeyError):
            calib.apply(hits, correct_slewing=False)

    def test_apply_to_hits_with_dom_id_and_channel_id_with_wrong_calib_raises(self):
        """An unknown ``dom_id`` raises KeyError."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))

        hits = Table({"dom_id": [999], "channel_id": [0], "time": [10.1]})

        with self.assertRaises(KeyError):
            calib.apply(hits, correct_slewing=False)

    def test_apply_to_hits_from_km3io(self):
        """Hits from a km3io OfflineReader event (whole and sliced) are calibrated."""
        calib = Calibration(filename=data_path("detx/km3net_offline.detx"))
        hits = km3io.OfflineReader(data_path("offline/km3net_offline.root"))[0].hits

        chits = calib.apply(hits)
        assert 176 == len(chits.t0)
        assert np.allclose([207747.825, 207745.656, 207743.836], chits.t0.tolist()[:3])

        chits = calib.apply(hits[:3])
        assert 3 == len(chits.t0)
        assert np.allclose([207747.825, 207745.656, 207743.836], chits.t0.tolist()[:3])

    def test_apply_to_hits_from_km3io_iterator(self):
        """Hits obtained by iterating a km3io OfflineReader are calibrated."""
        calib = Calibration(filename=data_path("detx/km3net_offline.detx"))
        f = km3io.OfflineReader(data_path("offline/km3net_offline.root"))

        # Use the reader's iterator, but fail instead of silently passing when
        # there is no event to check.
        event = next(iter(f))
        chits = calib.apply(event.hits)
        assert 176 == len(chits.t0)
        assert np.allclose(
            [207747.825, 207745.656, 207743.836], chits.t0.tolist()[:3]
        )

    def test_daq_triggered_hits(self):
        """Raw DAQ triggered hits get t0 added and the slewing correction subtracted."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))

        dt = DAQEvent.triggered_hits_dt_final

        raw_hits = np.array(
            [(2, 0, 10, 0, 100), (3, 1, 11, 10, 200), (3, 2, 12, 255, 300)],
            dtype=dt,
        )

        hits = Table(raw_hits)

        chits = calib.apply(hits)  # correct_slewing=True is default

        assert len(hits) == len(chits)

        a_hit = chits[0]
        self.assertAlmostEqual(10 + a_hit.t0 - slew(a_hit.tot), a_hit.time, places=5)

        a_hit = chits[1]
        self.assertAlmostEqual(11 + a_hit.t0 - slew(a_hit.tot), a_hit.time, places=5)

        a_hit = chits[2]
        self.assertAlmostEqual(12 + a_hit.t0 - slew(a_hit.tot), a_hit.time, places=5)

    def test_time_slewing_correction(self):
        """With the default ``correct_slewing=True``, hit time is time + t0 - slew(tot)."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))

        hits = Table(
            {
                "dom_id": [2, 3, 3],
                "channel_id": [0, 1, 2],
                "time": [10.1, 11.2, 12.3],
                "tot": [0, 10, 255],
            }
        )

        chits = calib.apply(hits)  # correct_slewing=True is default

        assert len(hits) == len(chits)

        a_hit = chits[0]
        self.assertAlmostEqual(10.1 + a_hit.t0 - slew(a_hit.tot), a_hit.time)

        a_hit = chits[1]
        self.assertAlmostEqual(11.2 + a_hit.t0 - slew(a_hit.tot), a_hit.time)

        a_hit = chits[2]
        self.assertAlmostEqual(12.3 + a_hit.t0 - slew(a_hit.tot), a_hit.time)

    def test_apply_to_timeslice_hits(self):
        """Hits built from the ``TimesliceHits`` template are calibrated."""
        tshits = Table.from_template(
            {
                "channel_id": [0, 1, 2],
                "dom_id": [2, 3, 3],
                "time": [10.1, 11.2, 12.3],
                "tot": np.ones(3, dtype=float),
                "group_id": 0,
            },
            "TimesliceHits",
        )
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))
        c_tshits = calib.apply(tshits, correct_slewing=False)
        assert len(c_tshits) == len(tshits)
        assert np.allclose([40, 80, 90], c_tshits.t0)
        # TimesliceHits is using int4 for times, so it's truncated when we pass in float64
        assert np.allclose([50.1, 91.2, 102.3], c_tshits.time, atol=0.1)

    def test_apply_without_affecting_primary_hit_table(self):
        """``apply`` returns a new table and leaves the input hits untouched."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))
        hits = Table({"pmt_id": [1, 2, 1], "time": [10.1, 11.2, 12.3]})
        hits_compare = hits.copy()
        calib.apply(hits, correct_slewing=False)

        # Compare schema and values column by column. Calling assertAlmostEqual
        # on whole rows would raise TypeError (rows cannot be subtracted)
        # instead of a clean assertion failure if they ever differed.
        assert hits_compare.dtype.names == hits.dtype.names
        for name in hits_compare.dtype.names:
            assert np.array_equal(hits_compare[name], hits[name])

    def test_calibration_in_pipeline(self):
        """The Calibration module adds CalibHits/CalibMcHits and keeps the input hits unchanged."""

        class DummyPump(kp.Module):
            def configure(self):
                self.index = 0

            def process(self, blob):
                self.index += 1
                mc_hits = Table({"pmt_id": [1, 2, 1], "time": [10.1, 11.2, 12.3]})
                hits = Table(
                    {
                        "dom_id": [2, 3, 3],
                        "channel_id": [0, 1, 2],
                        "time": [10.1, 11.2, 12.3],
                        "tot": [0, 10, 255],
                    }
                )

                blob["Hits"] = hits
                blob["McHits"] = mc_hits
                return blob

        class Observer(kp.Module):
            def process(self, blob):
                assert "Hits" in blob
                assert "McHits" in blob
                assert "CalibHits" in blob
                assert "CalibMcHits" in blob
                # Hits (dom_id/channel_id based) gain a pmt_id column ...
                assert not hasattr(blob["Hits"], "pmt_id")
                assert hasattr(blob["CalibHits"], "pmt_id")
                # ... and MC hits (pmt_id based) gain a dom_id column.
                assert not hasattr(blob["McHits"], "dom_id")
                assert hasattr(blob["CalibMcHits"], "dom_id")
                assert np.allclose([10.1, 11.2, 12.3], blob["Hits"].time)
                assert np.allclose([42.09, 87.31, 111.34], blob["CalibHits"].time)
                assert np.allclose(blob["McHits"].time, blob["CalibMcHits"].time)
                return blob

        pipe = kp.Pipeline()
        pipe.attach(DummyPump)
        pipe.attach(Calibration, filename=data_path("detx/detx_v1.detx"))
        pipe.attach(Observer)
        pipe.drain(3)
class TestCalibrationService(TestCase):
    """Tests for the services provided by the CalibrationService module."""

    def test_apply_to_hits_with_dom_id_and_channel_id(self):
        """The ``calibrate`` service adds positions and t0 and corrects slewing by default."""
        hits = Table(
            {
                "dom_id": [2, 3, 3],
                "channel_id": [0, 1, 2],
                "time": [10.1, 11.2, 12.3],
                "tot": [23, 105, 231],
            }
        )

        tester = self

        class HitCalibrator(Module):
            def process(self, blob):
                chits = self.services["calibrate"](hits)

                assert len(hits) == len(chits)

                a_hit = chits[0]
                tester.assertAlmostEqual(2.1, a_hit.pos_x)
                tester.assertAlmostEqual(40, a_hit.t0)
                t0 = a_hit.t0
                tester.assertAlmostEqual(10.1 + t0 - slew(a_hit.tot), a_hit.time)

                a_hit = chits[1]
                tester.assertAlmostEqual(3.4, a_hit.pos_x)
                tester.assertAlmostEqual(80, a_hit.t0)
                t0 = a_hit.t0
                tester.assertAlmostEqual(11.2 + t0 - slew(a_hit.tot), a_hit.time)
                return blob

        pipe = Pipeline()
        pipe.attach(CalibrationService, filename=data_path("detx/detx_v1.detx"))
        pipe.attach(HitCalibrator)
        pipe.drain(1)

    def test_apply_to_hits_with_dom_id_and_channel_id_without_slewing(self):
        """With ``correct_slewing=False`` the ``calibrate`` service only adds t0 to the time."""
        hits = Table(
            {
                "dom_id": [2, 3, 3],
                "channel_id": [0, 1, 2],
                "time": [10.1, 11.2, 12.3],
                "tot": [23, 105, 231],
            }
        )

        tester = self

        class HitCalibrator(Module):
            def process(self, blob):
                chits = self.services["calibrate"](hits, correct_slewing=False)

                assert len(hits) == len(chits)

                a_hit = chits[0]
                tester.assertAlmostEqual(2.1, a_hit.pos_x)
                tester.assertAlmostEqual(40, a_hit.t0)
                t0 = a_hit.t0
                tester.assertAlmostEqual(10.1 + t0, a_hit.time)

                a_hit = chits[1]
                tester.assertAlmostEqual(3.4, a_hit.pos_x)
                tester.assertAlmostEqual(80, a_hit.t0)
                t0 = a_hit.t0
                tester.assertAlmostEqual(11.2 + t0, a_hit.time)
                return blob

        pipe = Pipeline()
        pipe.attach(CalibrationService, filename=data_path("detx/detx_v1.detx"))
        pipe.attach(HitCalibrator)
        pipe.drain(1)

    def test_correct_slewing(self):
        """The ``correct_slewing`` service subtracts slew(tot) from the hit times in place."""
        hits = Table(
            {
                "dom_id": [2, 3, 3],
                "channel_id": [0, 1, 2],
                "time": [10.1, 11.2, 12.3],
                "tot": [0, 10, 255],
            }
        )

        tester = self

        class HitCalibrator(Module):
            def process(self, blob):
                # The service result is not used; the check is that `hits` itself changed.
                self.services["correct_slewing"](hits)

                a_hit = hits[0]
                tester.assertAlmostEqual(10.1 - slew(a_hit.tot), a_hit.time)

                a_hit = hits[1]
                tester.assertAlmostEqual(11.2 - slew(a_hit.tot), a_hit.time)

                # Also cover the upper ToT boundary (255).
                a_hit = hits[2]
                tester.assertAlmostEqual(12.3 - slew(a_hit.tot), a_hit.time)
                return blob

        pipe = Pipeline()
        pipe.attach(CalibrationService, filename=data_path("detx/detx_v1.detx"))
        pipe.attach(HitCalibrator)
        pipe.drain(1)

    def test_provided_detector_data(self):
        """The ``get_detector`` service returns a ``Detector`` instance."""

        class DetectorReader(Module):
            def process(self, blob):
                assert "get_detector" in self.services
                det = self.services["get_detector"]()
                assert isinstance(det, Detector)
                return blob

        pipe = Pipeline()
        pipe.attach(CalibrationService, filename=data_path("detx/detx_v1.detx"))
        pipe.attach(DetectorReader)
        pipe.drain(1)
class TestSlew(TestCase):
    """Check ``slew`` against reference values at ToT = 0, 23 and 255."""

    def test_slew(self):
        reference = ((0, 8.01), (23, 0.60), (255, -9.04))
        for tot, expected in reference:
            self.assertAlmostEqual(expected, slew(tot))

    def test_slew_vectorised(self):
        # The same reference points, passed as one array.
        tots = np.array([0, 23, 255])
        expected = [8.01, 0.60, -9.04]
        assert np.allclose(expected, slew(tots))