Coverage for src/km3pipe/tests/test_calib.py: 100%

282 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-08 03:14 +0000

1# Filename: test_calib.py 

2# pylint: disable=C0111,E1003,R0904,C0103,R0201,C0102 

3from os.path import dirname, join 

4import functools 

5import operator 

6import shutil 

7import sys 

8import tempfile 

9 

10import km3io 

11from thepipe import Module, Pipeline 

12import km3pipe as kp 

13from km3pipe.dataclasses import Table 

14from km3pipe.io.daq import DAQEvent 

15from km3pipe.hardware import Detector 

16from km3pipe.io.hdf5 import HDF5Sink 

17from km3pipe.testing import TestCase, MagicMock, patch, skip, skipif, data_path 

18from km3pipe.calib import Calibration, CalibrationService, slew 

19 

20from .test_hardware import EXAMPLE_DETX 

21 

22import numpy as np 

23import tables as tb 

24 

# Module metadata: authorship, licensing and development status.
__author__ = "Tamas Gal"
__copyright__ = "Copyright 2016, Tamas Gal and the KM3NeT collaboration."
__credits__ = []
__license__ = "MIT"
__maintainer__ = "Tamas Gal"
__email__ = "tgal@km3net.de"
__status__ = "Development"

32 

33 

class TestCalibration(TestCase):
    """Tests for the Calibration class"""

    def test_init_with_wrong_file_extension(self):
        """A filename without a supported extension is rejected."""
        with self.assertRaises(NotImplementedError):
            Calibration(filename="foo")

    @patch("km3pipe.calib.Detector")
    def test_init_with_filename(self, mock_detector):
        """A ``.detx`` filename is forwarded to the Detector constructor."""
        Calibration(filename="foo.detx")
        mock_detector.assert_called_with(filename="foo.detx")

    @patch("km3pipe.calib.Detector")
    def test_init_with_det_id(self, mock_detector):
        """det_id, calibset and t0set are forwarded to the Detector constructor."""
        Calibration(det_id=1)
        mock_detector.assert_called_with(t0set=None, calibset=None, det_id=1)
        Calibration(det_id=1, calibset=2, t0set=3)
        mock_detector.assert_called_with(t0set=3, calibset=2, det_id=1)

    def test_init_with_detector(self):
        """An already loaded Detector instance is accepted."""
        det = Detector(data_path("detx/detx_v1.detx"))
        Calibration(detector=det)

    def test_apply_to_hits_with_pmt_id_aka_mc_hits(self):
        """MC hits (pmt_id only) get positions and t0; time is left as is."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))

        hits = Table({"pmt_id": [1, 2, 1], "time": [10.1, 11.2, 12.3]})

        chits = calib.apply(hits, correct_slewing=False)

        assert len(hits) == len(chits)

        a_hit = chits[0]
        self.assertAlmostEqual(1.1, a_hit.pos_x)
        self.assertAlmostEqual(10, a_hit.t0)
        self.assertAlmostEqual(10.1, a_hit.time)  # t0 should not be applied

        a_hit = chits[1]
        self.assertAlmostEqual(1.4, a_hit.pos_x)
        self.assertAlmostEqual(20, a_hit.t0)
        self.assertAlmostEqual(11.2, a_hit.time)  # t0 should not be applied

    def test_apply_to_hits_with_pmt_id_aka_mc_hits_from_km3io(self):
        """Calibrate the MC hits of the first event read via km3io."""
        calib = Calibration(filename=data_path("detx/KM3NeT_-00000001_20171212.detx"))
        f = km3io.OfflineReader(
            data_path(
                "offline/mcv6.gsg_nue-CCHEDIS_1e4-1e6GeV.sirene.jte.jchain.aanet.1.root"
            )
        )

        # only the first event is checked, hence the unconditional break below
        for event in f:
            chits = calib.apply(event.mc_hits)
            assert 840 == len(chits.t0)
            assert np.allclose([3, 26, 24, 4, 23, 25], chits.channel_id[:6])
            assert np.allclose([3401, 3401, 3406, 3411, 5501, 5501], chits.dom_id[:6])
            assert np.allclose([1, 1, 6, 11, 1, 1], chits.floor[:6])
            assert np.allclose([34, 34, 34, 34, 55, 55], chits.du[:6])
            assert np.allclose(
                [
                    1679.18706571,
                    1827.14262054,
                    1926.71722628,
                    2433.83097585,
                    1408.35942832,
                    1296.51397496,
                ],
                chits.time[:6],
            )
            assert np.allclose(
                [2.034, 1.847, 1.938, 2.082, -54.96, -55.034], chits.pos_x[:6]
            )
            assert np.allclose(
                [-233.415, -233.303, -233.355, -233.333, -341.346, -341.303],
                chits.pos_y[:6],
            )
            assert np.allclose(
                [65.059, 64.83, 244.83, 425.111, 64.941, 64.83], chits.pos_z[:6]
            )
            assert np.allclose([4, 4, 4, 26, 4, 4], f.mc_hits.origin[0][:6].tolist())
            assert np.allclose(
                [36835, 36881, 37187, 37457, 60311, 60315],
                f.mc_hits.pmt_id[0][:6].tolist(),
            )
            break

    def test_dus(self):
        """dus() looks up the detection unit of each hit's DOM."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))

        hits = Table(
            {"dom_id": [2, 6, 3], "channel_id": [0, 1, 2], "time": [10.1, 11.2, 12.3]}
        )

        dus = calib.dus(hits)
        assert np.allclose([1, 2, 1], dus)

    def test_floors(self):
        """floors() looks up the floor of each hit's DOM."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))

        hits = Table(
            {"dom_id": [2, 6, 3], "channel_id": [0, 1, 2], "time": [10.1, 11.2, 12.3]}
        )

        floors = calib.floors(hits)
        assert np.allclose([2, 3, 3], floors)

    def test_apply_to_hits_with_dom_id_and_channel_id(self):
        """Hits identified by dom_id/channel_id get t0 added to their time."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))

        hits = Table(
            {"dom_id": [2, 3, 3], "channel_id": [0, 1, 2], "time": [10.1, 11.2, 12.3]}
        )

        chits = calib.apply(hits, correct_slewing=False)

        assert len(hits) == len(chits)

        a_hit = chits[0]
        self.assertAlmostEqual(2.1, a_hit.pos_x)
        self.assertAlmostEqual(40, a_hit.t0)
        t0 = a_hit.t0
        self.assertAlmostEqual(10.1 + t0, a_hit.time)

        a_hit = chits[1]
        self.assertAlmostEqual(3.4, a_hit.pos_x)
        self.assertAlmostEqual(80, a_hit.t0)
        t0 = a_hit.t0
        self.assertAlmostEqual(11.2 + t0, a_hit.time)

    def test_assert_apply_adds_dom_id_and_channel_id_to_mc_hits(self):
        """Calibrated MC hits gain dom_id and channel_id columns."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))
        hits = Table({"pmt_id": [1, 2, 1], "time": [10.1, 11.2, 12.3]})
        chits = calib.apply(hits)
        self.assertListEqual([1, 1, 1], list(chits.dom_id))
        self.assertListEqual([0, 1, 0], list(chits.channel_id))

    def test_assert_apply_adds_pmt_id_to_hits(self):
        """Calibrated dom_id/channel_id hits gain a pmt_id column."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))
        hits = Table(
            {"dom_id": [2, 3, 3], "channel_id": [0, 1, 2], "time": [10.1, 11.2, 12.3]}
        )
        chits = calib.apply(hits, correct_slewing=False)
        self.assertListEqual([4, 8, 9], list(chits.pmt_id))

    def test_apply_to_hits_with_pmt_id_with_wrong_calib_raises(self):
        """An unknown pmt_id raises KeyError."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))

        hits = Table({"pmt_id": [999], "time": [10.1]})

        with self.assertRaises(KeyError):
            calib.apply(hits, correct_slewing=False)

    def test_apply_to_hits_with_dom_id_and_channel_id_with_wrong_calib_raises(self):
        """An unknown dom_id raises KeyError."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))

        hits = Table({"dom_id": [999], "channel_id": [0], "time": [10.1]})

        with self.assertRaises(KeyError):
            calib.apply(hits, correct_slewing=False)

    def test_apply_to_hits_from_km3io(self):
        """Calibrate km3io hits, both the full array and a slice."""
        calib = Calibration(filename=data_path("detx/km3net_offline.detx"))
        hits = km3io.OfflineReader(data_path("offline/km3net_offline.root"))[0].hits

        chits = calib.apply(hits)
        assert 176 == len(chits.t0)
        assert np.allclose([207747.825, 207745.656, 207743.836], chits.t0.tolist()[:3])

        chits = calib.apply(hits[:3])
        assert 3 == len(chits.t0)
        assert np.allclose([207747.825, 207745.656, 207743.836], chits.t0.tolist()[:3])

    def test_apply_to_hits_from_km3io_iterator(self):
        """Calibrate the hits of an event obtained by iterating a km3io reader."""
        calib = Calibration(filename=data_path("detx/km3net_offline.detx"))
        f = km3io.OfflineReader(data_path("offline/km3net_offline.root"))

        # only the first event is checked, hence the unconditional break below
        for event in f:
            chits = calib.apply(event.hits)
            assert 176 == len(chits.t0)
            assert np.allclose(
                [207747.825, 207745.656, 207743.836], chits.t0.tolist()[:3]
            )
            break

    def test_daq_triggered_hits(self):
        """DAQ triggered hits get t0 added and are slewing corrected by default."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))

        dt = DAQEvent.triggered_hits_dt_final

        raw_hits = np.array(
            [(2, 0, 10, 0, 100), (3, 1, 11, 10, 200), (3, 2, 12, 255, 300)],
            dtype=dt,
        )

        hits = Table(raw_hits)

        chits = calib.apply(hits)  # correct_slewing=True is default

        assert len(hits) == len(chits)

        a_hit = chits[0]
        self.assertAlmostEqual(10 + a_hit.t0 - slew(a_hit.tot), a_hit.time, places=5)

        a_hit = chits[1]
        self.assertAlmostEqual(11 + a_hit.t0 - slew(a_hit.tot), a_hit.time, places=5)

        a_hit = chits[2]
        self.assertAlmostEqual(12 + a_hit.t0 - slew(a_hit.tot), a_hit.time, places=5)

    def test_time_slewing_correction(self):
        """With the default correct_slewing=True, slew(tot) is subtracted."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))

        hits = Table(
            {
                "dom_id": [2, 3, 3],
                "channel_id": [0, 1, 2],
                "time": [10.1, 11.2, 12.3],
                "tot": [0, 10, 255],
            }
        )

        chits = calib.apply(hits)  # correct_slewing=True is default

        assert len(hits) == len(chits)

        a_hit = chits[0]
        self.assertAlmostEqual(10.1 + a_hit.t0 - slew(a_hit.tot), a_hit.time)

        a_hit = chits[1]
        self.assertAlmostEqual(11.2 + a_hit.t0 - slew(a_hit.tot), a_hit.time)

        a_hit = chits[2]
        self.assertAlmostEqual(12.3 + a_hit.t0 - slew(a_hit.tot), a_hit.time)

    def test_apply_to_timeslice_hits(self):
        """Timeslice hits are calibrated like regular hits."""
        tshits = Table.from_template(
            {
                "channel_id": [0, 1, 2],
                "dom_id": [2, 3, 3],
                "time": [10.1, 11.2, 12.3],
                "tot": np.ones(3, dtype=float),
                "group_id": 0,
            },
            "TimesliceHits",
        )
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))
        c_tshits = calib.apply(tshits, correct_slewing=False)
        assert len(c_tshits) == len(tshits)
        assert np.allclose([40, 80, 90], c_tshits.t0)
        # TimesliceHits is using int4 for times, so it's truncated when we pass in float64
        assert np.allclose([50.1, 91.2, 102.3], c_tshits.time, atol=0.1)

    def test_apply_without_affecting_primary_hit_table(self):
        """apply() must leave the input hit table untouched."""
        calib = Calibration(filename=data_path("detx/detx_v1.detx"))
        hits = Table({"pmt_id": [1, 2, 1], "time": [10.1, 11.2, 12.3]})
        hits_compare = hits.copy()
        calib.apply(hits, correct_slewing=False)

        # Compare column-wise. No calibration columns may have been added to the
        # input, and its original values must be unchanged.
        self.assertTupleEqual(hits_compare.dtype.names, hits.dtype.names)
        self.assertListEqual(list(hits_compare.pmt_id), list(hits.pmt_id))
        assert np.allclose(hits_compare.time, hits.time)

    def test_calibration_in_pipeline(self):
        """The Calibration module adds CalibHits and CalibMcHits to each blob."""

        class DummyPump(kp.Module):
            def configure(self):
                self.index = 0

            def process(self, blob):
                self.index += 1
                mc_hits = Table({"pmt_id": [1, 2, 1], "time": [10.1, 11.2, 12.3]})
                hits = Table(
                    {
                        "dom_id": [2, 3, 3],
                        "channel_id": [0, 1, 2],
                        "time": [10.1, 11.2, 12.3],
                        "tot": [0, 10, 255],
                    }
                )

                blob["Hits"] = hits
                blob["McHits"] = mc_hits
                return blob

        class Observer(kp.Module):
            def process(self, blob):
                assert "Hits" in blob
                assert "McHits" in blob
                assert "CalibHits" in blob
                assert "CalibMcHits" in blob
                # Calibration adds pmt_id to hits and dom_id to MC hits. It only
                # writes the calibrated copies, never the input tables.
                assert not hasattr(blob["Hits"], "pmt_id")
                assert hasattr(blob["CalibHits"], "pmt_id")
                assert not hasattr(blob["McHits"], "dom_id")
                assert hasattr(blob["CalibMcHits"], "dom_id")
                assert np.allclose([10.1, 11.2, 12.3], blob["Hits"].time)
                assert np.allclose([42.09, 87.31, 111.34], blob["CalibHits"].time)
                assert np.allclose(blob["McHits"].time, blob["CalibMcHits"].time)
                return blob

        pipe = kp.Pipeline()
        pipe.attach(DummyPump)
        pipe.attach(Calibration, filename=data_path("detx/detx_v1.detx"))
        pipe.attach(Observer)
        pipe.drain(3)

339 

class TestCalibrationService(TestCase):
    """Tests for the services provided by CalibrationService in a pipeline"""

    @staticmethod
    def _make_hits(tot):
        """Return a three-hit table on DOMs 2 and 3 with the given ToT values."""
        return Table(
            {
                "dom_id": [2, 3, 3],
                "channel_id": [0, 1, 2],
                "time": [10.1, 11.2, 12.3],
                "tot": tot,
            }
        )

    def test_apply_to_hits_with_dom_id_and_channel_id(self):
        """The ``calibrate`` service adds positions and t0, and corrects slewing."""
        hits = self._make_hits([23, 105, 231])

        tester = self

        class HitCalibrator(Module):
            def process(self, blob):
                chits = self.services["calibrate"](hits)

                assert len(hits) == len(chits)

                a_hit = chits[0]
                tester.assertAlmostEqual(2.1, a_hit.pos_x)
                tester.assertAlmostEqual(40, a_hit.t0)
                t0 = a_hit.t0
                tester.assertAlmostEqual(10.1 + t0 - slew(a_hit.tot), a_hit.time)

                a_hit = chits[1]
                tester.assertAlmostEqual(3.4, a_hit.pos_x)
                tester.assertAlmostEqual(80, a_hit.t0)
                t0 = a_hit.t0
                tester.assertAlmostEqual(11.2 + t0 - slew(a_hit.tot), a_hit.time)
                return blob

        pipe = Pipeline()
        pipe.attach(CalibrationService, filename=data_path("detx/detx_v1.detx"))
        pipe.attach(HitCalibrator)
        pipe.drain(1)

    def test_apply_to_hits_with_dom_id_and_channel_id_without_slewing(self):
        """With correct_slewing=False, ``calibrate`` only adds t0 to the time."""
        hits = self._make_hits([23, 105, 231])

        tester = self

        class HitCalibrator(Module):
            def process(self, blob):
                chits = self.services["calibrate"](hits, correct_slewing=False)

                assert len(hits) == len(chits)

                a_hit = chits[0]
                tester.assertAlmostEqual(2.1, a_hit.pos_x)
                tester.assertAlmostEqual(40, a_hit.t0)
                t0 = a_hit.t0
                tester.assertAlmostEqual(10.1 + t0, a_hit.time)

                a_hit = chits[1]
                tester.assertAlmostEqual(3.4, a_hit.pos_x)
                tester.assertAlmostEqual(80, a_hit.t0)
                t0 = a_hit.t0
                tester.assertAlmostEqual(11.2 + t0, a_hit.time)
                return blob

        pipe = Pipeline()
        pipe.attach(CalibrationService, filename=data_path("detx/detx_v1.detx"))
        pipe.attach(HitCalibrator)
        pipe.drain(1)

    def test_correct_slewing(self):
        """``correct_slewing`` subtracts slew(tot) from the hit times in place."""
        hits = self._make_hits([0, 10, 255])

        tester = self

        class HitCalibrator(Module):
            def process(self, blob):
                # the service works on ``hits`` directly; its return value is unused
                self.services["correct_slewing"](hits)

                a_hit = hits[0]
                tester.assertAlmostEqual(10.1 - slew(a_hit.tot), a_hit.time)

                a_hit = hits[1]
                tester.assertAlmostEqual(11.2 - slew(a_hit.tot), a_hit.time)

                a_hit = hits[2]
                tester.assertAlmostEqual(12.3 - slew(a_hit.tot), a_hit.time)
                return blob

        pipe = Pipeline()
        pipe.attach(CalibrationService, filename=data_path("detx/detx_v1.detx"))
        pipe.attach(HitCalibrator)
        pipe.drain(1)

    def test_provided_detector_data(self):
        """The ``get_detector`` service returns a Detector instance."""

        class DetectorReader(Module):
            def process(self, blob):
                assert "get_detector" in self.services
                det = self.services["get_detector"]()
                assert isinstance(det, Detector)
                return blob

        pipe = Pipeline()
        pipe.attach(CalibrationService, filename=data_path("detx/detx_v1.detx"))
        pipe.attach(DetectorReader)
        pipe.drain(1)

455 

456 

class TestSlew(TestCase):
    """Checks ``slew`` against reference values for scalar and array input."""

    # (time over threshold, expected slewing correction) reference pairs
    _REFERENCE = ((0, 8.01), (23, 0.60), (255, -9.04))

    def test_slew(self):
        for tot, expected in self._REFERENCE:
            self.assertAlmostEqual(expected, slew(tot))

    def test_slew_vectorised(self):
        tots, expected = zip(*self._REFERENCE)
        assert np.allclose(list(expected), slew(np.array(tots)))