Coverage for src/km3modules/io.py: 78%

143 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-08 03:14 +0000

1#!/usr/bin/env python3 

2from collections import defaultdict 

3 

4import numpy as np 

5import awkward as ak 

6 

7import km3pipe as kp 

8import km3io 

9 

# usr-data keys of MC tracks that generator software (ab)uses to store
# per-track properties; bytes literals because km3io exposes `usr_names`
# as byte strings.
USR_MC_TRACKS_KEYS = [b"energy_lost_in_can", b"bx", b"by", b"ichan", b"cc"]

11 

12 

class HitsTabulator(kp.Module):
    """
    Create `kp.Table` from hits provided by `km3io`.

    Parameters
    ----------
    kind: str
        The kind of hits to tabulate:
        "offline": the hits in an offline file
        "online": snapshot and triggered hits (will be combined)
        "mc": MC hits
    split: bool (default: True)
        Defines whether the hits should be split up into individual arrays
        in a single group (e.g. hits/dom_id, hits/channel_id) or stored
        as a single HDF5Compound array (e.g. hits).
    with_calibration: bool (default: False)
        If True, additionally extract pos_x/y/z, dir_x/y/z and tdc of
        "offline" hits (presumably requires a file with calibrated hits
        -- verify for your input files).
    """

    def configure(self):
        self.kind = self.require("kind")
        self.with_calibration = self.get("with_calibration", default=False)
        self.split = self.get("split", default=True)

    def process(self, blob):
        """Tabulate the hits of `blob["event"]` into the blob."""
        if self.kind == "offline":
            # skip events without hits entirely -- no empty tables
            n = blob["event"].n_hits
            if n == 0:
                return blob
            hits = blob["event"].hits

            hits_data = {
                "channel_id": hits.channel_id,
                "dom_id": hits.dom_id,
                "time": hits.t,
                "tot": hits.tot,
                "triggered": hits.trig,
            }

            if self.with_calibration:
                hits_data["pos_x"] = hits.pos_x
                hits_data["pos_y"] = hits.pos_y
                hits_data["pos_z"] = hits.pos_z
                hits_data["dir_x"] = hits.dir_x
                hits_data["dir_y"] = hits.dir_y
                hits_data["dir_z"] = hits.dir_z
                hits_data["tdc"] = hits.tdc

            blob["Hits"] = kp.Table(
                hits_data,
                h5loc="/hits",
                split_h5=self.split,
                name="Hits",
            )

        if self.kind == "mc":
            n = blob["event"].n_mc_hits
            if n == 0:
                return blob
            mc_hits = blob["event"].mc_hits
            blob["McHits"] = kp.Table(
                {
                    "a": mc_hits.a,
                    "origin": mc_hits.origin,
                    "pmt_id": mc_hits.pmt_id,
                    "time": mc_hits.t,
                },
                h5loc="/mc_hits",
                split_h5=self.split,
                name="McHits",
            )

        if self.kind == "online":
            raise NotImplementedError(
                "The extraction of online (DAQ) hits is not implemented yet."
            )
        return blob

88 

89 

class MCTracksTabulator(kp.Module):
    """
    Tabulate the MC tracks of a `km3io` event into a `kp.Table`.

    Parameters
    ----------
    split: bool (default: False)
        If True, store the tracks as individual arrays in a single group
        (e.g. mc_tracks/by, mc_tracks/origin); otherwise write a single
        HDF5Compound array (e.g. mc_tracks).
    read_usr_data: bool (default: False)
        Parses usr-data which is originally meant for user stored values, but
        was abused by generator software to store properties. This issue will
        be sorted out hopefully soon as it dramatically decreases the processing
        performance and usability.
    """

    def configure(self):
        self.split = self.get("split", default=False)
        self._read_usr_data = self.get("read_usr_data", default=False)
        if self._read_usr_data:
            self.log.warning(
                "Reading usr-data will massively decrease the performance."
            )

    def process(self, blob):
        """Attach a "McTracks" table to the blob (skips track-less events)."""
        event = blob["event"]
        if event.n_mc_tracks == 0:
            return blob
        blob["McTracks"] = self._parse_mc_tracks(event.mc_tracks)
        return blob

    def _parse_usr_to_dct(self, mc_tracks):
        """Collect the known usr-data fields per track, NaN where absent."""
        dct = defaultdict(list)
        for key in USR_MC_TRACKS_KEYS:
            dec_key = key.decode("utf_8")
            for track_idx, names in enumerate(mc_tracks.usr_names):
                if key in names:
                    mask = names == key
                    dct[dec_key].append(mc_tracks.usr[track_idx][mask][0])
                else:
                    dct[dec_key].append(np.nan)
        return dct

    def _parse_mc_tracks(self, mc_tracks):
        """Convert `mc_tracks` (and optionally usr-data) into a `kp.Table`."""
        dct = dict(
            dir_x=mc_tracks.dir_x,
            dir_y=mc_tracks.dir_y,
            dir_z=mc_tracks.dir_z,
            pos_x=mc_tracks.pos_x,
            pos_y=mc_tracks.pos_y,
            pos_z=mc_tracks.pos_z,
            energy=mc_tracks.E,
            time=mc_tracks.t,
            pdgid=mc_tracks.pdgid,
            id=mc_tracks.id,
            length=mc_tracks.len,
        )
        if self._read_usr_data:
            dct.update(self._parse_usr_to_dct(mc_tracks))
        return kp.Table(dct, name="McTracks", h5loc="/mc_tracks", split_h5=self.split)

154 

155 

class RecoTracksTabulator(kp.Module):
    """
    Create `kp.Table` from reconstructed tracks provided by `km3io`.

    Parameters
    ----------
    best_tracks: bool (default: False)
        Additionally determine the best track for each reconstruction
        type found in the event (JMuon, JShower, DusjShower, AaShower).
    split: bool (default: False)
        Defines whether the tracks should be split up into individual arrays
        in a single group (e.g. reco/tracks/dom_id, reco/tracks/channel_id)
        or stored as a single HDF5Compound array (e.g. reco/tracks).
    aashower_legacy: bool (default: False)
        If True, shift the rec_stages of aanet-typed tracks by +300 to
        match the current stage definitions (legacy files only).
    """

    def configure(self):
        self.split = self.get("split", default=False)
        self.best_tracks = self.get("best_tracks", default=False)
        self.aashower_legacy = self.get("aashower_legacy", default=False)

        # Maps each reconstruction chain's prefit stage to the km3io
        # best-track selector and the name used for the resulting table.
        self._best_track_fmap = {
            km3io.definitions.reconstruction.JMUONPREFIT: (
                km3io.tools.best_jmuon,
                "best_jmuon",
            ),
            km3io.definitions.reconstruction.JSHOWERPREFIT: (
                km3io.tools.best_jshower,
                "best_jshower",
            ),
            km3io.definitions.reconstruction.DUSJSHOWERPREFIT: (
                km3io.tools.best_dusjshower,
                "best_dusjshower",
            ),
            km3io.definitions.reconstruction.AASHOWERFITPREFIT: (
                km3io.tools.best_aashower,
                "best_aashower",
            ),
        }

    def process(self, blob):
        """Tabulate all reco tracks (and optionally the best ones)."""
        n_tracks = blob["event"].n_tracks
        # we first check if there are any tracks, otherwise the other calls
        # will raise
        if n_tracks == 0:
            return blob

        all_tracks = blob["event"].tracks

        if self.aashower_legacy:
            # shift only tracks of aanet reconstruction type by +300
            all_tracks.rec_stages = np.where(
                all_tracks.rec_type
                == km3io.definitions.reconstruction.AANET_RECONSTRUCTION_TYPE,
                all_tracks.rec_stages + 300,
                all_tracks.rec_stages,
            )

        # put all tracks into the blob
        self._put_tracks_into_blob(blob, all_tracks, "tracks", n_tracks)

        # select the best track using the km3io tools
        if self.best_tracks:
            # check if it contains any of the specific reco types (can be several)
            for stage, (best_track, reco_name) in self._best_track_fmap.items():
                if stage in all_tracks.rec_stages:
                    tracks = best_track(all_tracks)
                    self._put_tracks_into_blob(blob, tracks, reco_name, 1)

        return blob

    def _put_tracks_into_blob(self, blob, tracks, reco_identifier, n_tracks):
        """
        Put a certain type of "tracks" in the blob and give specific name.

        Parameters
        ----------
        tracks : awkward array
            The tracks object to be put in the blob eventually.
            Can be only best tracks.
        reco_identifier : string
            A string to name the kp table.
        n_tracks : int
            The number of tracks from before. Used to distinguish between
            best (n_tracks == 1) and all tracks.
        """
        reco_tracks = dict(
            pos_x=tracks.pos_x,
            pos_y=tracks.pos_y,
            pos_z=tracks.pos_z,
            dir_x=tracks.dir_x,
            dir_y=tracks.dir_y,
            dir_z=tracks.dir_z,
            E=tracks.E,
            rec_type=tracks.rec_type,
            t=tracks.t,
            likelihood=tracks.lik,
            length=tracks.len,  # do all recos have this?
        )

        if n_tracks != 1:
            # per-track ids/indices only make sense for the full track list
            reco_tracks.update(
                id=tracks.id,
                idx=np.arange(n_tracks),
            )

        # Pad the ragged fitinf lists into a rectangular float32 matrix
        # (one column per defined fit parameter), NaN for missing entries.
        n_columns = max(km3io.definitions.fitparameters.values()) + 1
        fitinf_array = np.ma.filled(
            ak.to_numpy(ak.pad_none(tracks.fitinf, target=n_columns, axis=-1)),
            fill_value=np.nan,
        ).astype("float32")
        fitinf_split = np.split(fitinf_array, fitinf_array.shape[-1], axis=-1)

        if n_tracks == 1:
            for fitparam, idx in km3io.definitions.fitparameters.items():
                reco_tracks[fitparam] = fitinf_split[idx][0]
        else:
            for fitparam, idx in km3io.definitions.fitparameters.items():
                reco_tracks[fitparam] = fitinf_split[idx][:, 0]

        blob["Reco_" + reco_identifier] = kp.Table(
            reco_tracks,
            h5loc="/reco/" + reco_identifier,
            name="Reco " + reco_identifier,
            split_h5=self.split,
        )

        # write out the rec stages only once with all tracks
        if n_tracks != 1:
            _rec_stage = np.array(ak.flatten(tracks.rec_stages)._layout)
            _counts = ak.count(tracks.rec_stages, axis=1)
            _idx = np.repeat(np.arange(n_tracks), _counts)

            blob["RecStages"] = kp.Table(
                dict(rec_stage=_rec_stage, idx=_idx),
                # Just to save space, we specify smaller dtypes.
                # We assume there will be never more than 32767
                # reco tracks for a single reconstruction type.
                dtypes=[("rec_stage", np.int16), ("idx", np.uint16)],
                h5loc="/reco/rec_stages",
                name="Reconstruction Stages",
                split_h5=self.split,
            )

301 

302 

class EventInfoTabulator(kp.Module):
    """
    Create `kp.Table` from event information provided by `km3io`.

    Writes an "EventInfo" table (/event_info) with ids, timestamps,
    trigger information, weights and -- for simulated files -- the
    unfolded w2list entries.
    """

    def process(self, blob):
        # Determine the simulation program; needed to interpret w2list.
        if blob["header"]:
            if "simul" in blob["header"].keys():
                sim_program = blob["header"].simul.program
            else:  # not existent for real data
                sim_program = None
        else:
            sim_program = None

        blob["EventInfo"] = self._parse_eventinfo(blob["event"], sim_program)
        return blob

    def _parse_eventinfo(self, event, sim_program):
        """Build the EventInfo `kp.Table` for a single event."""
        wgt1, wgt2, wgt3, wgt4 = self._parse_wgts(event.w)
        tab_data = {
            "event_id": event.id,
            "run_id": event.run_id,
            "weight_w1": wgt1,
            "weight_w2": wgt2,
            "weight_w3": wgt3,
            "weight_w4": wgt4,
            "timestamp": event.t_sec,
            "nanoseconds": event.t_ns,
            "mc_time": event.mc_t,
            "trigger_mask": event.trigger_mask,
            "trigger_counter": event.trigger_counter,
            "overlays": event.overlays,
            "det_id": event.det_id,
            "frame_index": event.frame_index,
            "mc_run_id": event.mc_run_id,
        }

        # identity check against None instead of `!= None` (PEP 8)
        if sim_program is not None:
            # unfold the info in the w2list
            w2list_dict = self._unfold_w2list(event.w2list, sim_program)
            tab_data.update(w2list_dict)

        info = kp.Table(tab_data, h5loc="/event_info", name="EventInfo")
        return info

    def _unfold_w2list(self, w2list, sim_program):
        """Map w2list entries to named columns for the given generator."""
        w2list_dict = {}
        definitions_dict = {}

        if sim_program.lower() == "gseagen":
            definitions_dict = km3io.definitions.w2list_gseagen
        elif sim_program.lower() == "genhen":
            definitions_dict = km3io.definitions.w2list_genhen
        # for cases like sim_program == "MUPAGE", the w2list is empty

        for key, idx in definitions_dict.items():
            # pad missing entries with NaN instead of raising
            w2list_dict[key] = np.nan if idx >= len(w2list) else w2list[idx]

        return w2list_dict

    @staticmethod
    def _parse_wgts(wgt):
        """Split the event weight vector into (w1, w2, w3, w4), NaN-padded."""
        if len(wgt) == 3:
            wgt1, wgt2, wgt3 = wgt
            wgt4 = np.nan
        elif len(wgt) == 4:
            # what the hell is w4?
            wgt1, wgt2, wgt3, wgt4 = wgt
        else:
            wgt1 = wgt2 = wgt3 = wgt4 = np.nan
        return wgt1, wgt2, wgt3, wgt4

379 

380 

class OfflineHeaderTabulator(kp.Module):
    """Write the raw offline header into the blob as a `kp.Table`."""

    def process(self, blob):
        header = blob["header"]
        if not header:
            return blob
        blob["RawHeader"] = kp.io.hdf5.header2table(header)
        return blob