# Filename: hdf5.py
# pylint: disable=C0103,R0903,C901
# vim:set ts=4 sts=4 sw=4 et:
"""
Read and write KM3NeT-formatted HDF5 files.

"""

from collections import OrderedDict, defaultdict, namedtuple
from functools import singledispatch
import os.path
import warnings
from uuid import uuid4

import numpy as np
import tables as tb
import km3io
from thepipe import Provenance


try:
    from numba import jit
except ImportError:
    # numba is optional; provide a no-op stand-in that also accepts
    # decorator arguments such as ``nopython=True``
    def jit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda f: f


import km3pipe as kp
from thepipe import Module, Blob
from km3pipe.dataclasses import Table, NDArray
from km3pipe.logger import get_logger
from km3pipe.tools import decamelise, camelise, split, istype

log = get_logger(__name__)  # pylint: disable=C0103

__author__ = "Tamas Gal and Moritz Lotze and Michael Moser"
__copyright__ = "Copyright 2016, Tamas Gal and the KM3NeT collaboration."
__credits__ = []
__license__ = "MIT"
__maintainer__ = "Tamas Gal and Moritz Lotze"
__email__ = "tgal@km3net.de"
__status__ = "Development"

FORMAT_VERSION = np.string_("5.1")
MINIMUM_FORMAT_VERSION = np.string_("4.1")


class H5VersionError(Exception):
    pass


def check_version(h5file):
    try:
        version = np.string_(h5file.root._v_attrs.format_version)
    except AttributeError:
        log.error(
            "Could not determine HDF5 format version: '%s'. "
            "You may encounter unexpected errors! Good luck..." % h5file.filename
        )
        return
    if split(version, int, np.string_(".")) < split(
        MINIMUM_FORMAT_VERSION, int, np.string_(".")
    ):
        raise H5VersionError(
            "HDF5 format version {0} or newer required!\n"
            "'{1}' has HDF5 format version {2}.".format(
                MINIMUM_FORMAT_VERSION.decode("utf-8"),
                h5file.filename,
                version.decode("utf-8"),
            )
        )
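
# A rough sketch of the version check above (hypothetical values, assuming
# ``km3pipe.tools.split`` maps the given callback over the separated items):
#
#     split(np.string_("5.1"), int, np.string_("."))   # -> [5, 1]
#     split(np.string_("4.1"), int, np.string_("."))   # -> [4, 1]
#     [5, 1] < [4, 1]                                  # -> False, so no error is raised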



class HDF5Header(object):
    """Wrapper class for the `/raw_header` table in KM3HDF5

    Parameters
    ----------
    data : dict(str, str/tuple/dict/OrderedDict)
        The actual header data, consisting of a key and an entry.
        If possible, the key will be set as a property and the values will
        be converted to namedtuples (fields sorted by name to ensure
        consistency when dictionaries are provided).

    """

    def __init__(self, data):
        self._data = data
        self._user_friendly_data = {}  # namedtuples, if possible
        self._set_attributes()

    def _set_attributes(self):
        """Traverse the internal dictionary and set the getters"""
        for parameter in list(self._data.keys()):
            data = self._data[parameter]
            if isinstance(data, dict) or isinstance(data, OrderedDict):
                if not all(f.isidentifier() for f in data.keys()):
                    break
                # Create a namedtuple for easier access
                field_names, field_values = zip(*data.items())
                sorted_indices = np.argsort(field_names)
                clsname = "HeaderEntry" if not parameter.isidentifier() else parameter
                nt = namedtuple(clsname, [field_names[i] for i in sorted_indices])
                data = nt(*[field_values[i] for i in sorted_indices])
            if parameter.isidentifier():
                setattr(self, parameter, data)
            self._user_friendly_data[parameter] = data

    def __getitem__(self, key):
        return self._user_friendly_data[key]

    def keys(self):
        return self._user_friendly_data.keys()

    def values(self):
        return self._user_friendly_data.values()

    def items(self):
        return self._user_friendly_data.items()

    @classmethod
    def from_table(cls, table):
        data = OrderedDict()
        for i in range(len(table)):
            parameter = table["parameter"][i].decode()
            field_names = table["field_names"][i].decode().split(" ")
            field_values = table["field_values"][i].decode().split(" ")
            if field_values == [""]:
                log.info("No value for parameter '{}'! Skipping...".format(parameter))
                continue
            dtypes = table["dtype"][i].decode()
            dtyped_values = []
            for dtype, value in zip(dtypes.split(" "), field_values):
                if dtype.startswith("a"):
                    dtyped_values.append(value)
                else:
                    value = np.fromstring(value, dtype=dtype, sep=" ")[0]
                    dtyped_values.append(value)
            data[parameter] = OrderedDict(zip(field_names, dtyped_values))
        return cls(data)

    @classmethod
    def from_km3io(cls, header):
        if not isinstance(header, km3io.offline.Header):
            raise TypeError(
                "The given header object is not an instance of km3io.offline.Header"
            )
        return cls(header._data)

    @classmethod
    def from_aanet(cls, table):
        data = OrderedDict()
        for i in range(len(table)):
            parameter = table["parameter"][i].astype(str)
            field_names = [n.decode() for n in table["field_names"][i].split()]
            field_values = [n.decode() for n in table["field_values"][i].split()]
            if field_values in [[b""], []]:
                log.info("No value for parameter '{}'! Skipping...".format(parameter))
                continue
            dtypes = table["dtype"][i]
            dtyped_values = []
            for dtype, value in zip(dtypes.split(), field_values):
                if dtype.startswith(b"a"):
                    dtyped_values.append(value)
                else:
                    value = np.fromstring(value, dtype=dtype, sep=" ")[0]
                    dtyped_values.append(value)
            data[parameter] = OrderedDict(zip(field_names, dtyped_values))
        return cls(data)

    @classmethod
    def from_hdf5(cls, filename):
        with tb.open_file(filename, "r") as f:
            try:
                table = f.get_node("/raw_header")
            except tb.NoSuchNodeError:
                msg = f"No header information found in '{filename}'"
                raise
            return cls.from_pytable(table)

    @classmethod
    def from_pytable(cls, table):
        data = OrderedDict()
        for row in table:
            parameter = row["parameter"].decode()
            field_names = row["field_names"].decode().split(" ")
            field_values = row["field_values"].decode().split(" ")
            if field_values == [""]:
                log.info("No value for parameter '{}'! Skipping...".format(parameter))
                continue
            dtypes = row["dtype"].decode()
            dtyped_values = []
            for dtype, value in zip(dtypes.split(" "), field_values):
                if dtype.startswith("a"):
                    dtyped_values.append(value)
                else:
                    value = np.fromstring(value, dtype=dtype, sep=" ")[0]
                    dtyped_values.append(value)
            data[parameter] = OrderedDict(zip(field_names, dtyped_values))
        return cls(data)

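# A minimal usage sketch for ``HDF5Header`` (hypothetical values, not taken
# from a real file): a plain dict of dicts exposes every parameter as an
# attribute backed by a namedtuple with alphabetically sorted fields.
#
#     header = HDF5Header({"can": {"zmin": 0.0, "zmax": 476.5, "r": 403.4}})
#     header.can.zmax    # -> 476.5 (attribute access via the namedtuple)
#     header["can"].r    # -> 403.4 (dict-style access works as well)
#
# Reading an existing KM3HDF5 file would instead go through one of the
# classmethod constructors, e.g. ``HDF5Header.from_hdf5("somefile.h5")``.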


class HDF5IndexTable:
    def __init__(self, h5loc, start=0):
        self.h5loc = h5loc
        self._data = defaultdict(list)
        self._index = 0
        if start > 0:
            self._data["indices"] = [0] * start
            self._data["n_items"] = [0] * start

    def append(self, n_items):
        self._data["indices"].append(self._index)
        self._data["n_items"].append(n_items)
        self._index += n_items

    @property
    def data(self):
        return self._data

    def fillup(self, length):
        missing = length - len(self)
        self._data["indices"] += [self.data["indices"][-1]] * missing
        self._data["n_items"] += [0] * missing

    def __len__(self):
        return len(self.data["indices"])

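# A short sketch of the bookkeeping above (hypothetical location and counts):
# appending the number of items written per group records the running offset,
# so three groups with 3, 0 and 2 items end up as
#
#     idx = HDF5IndexTable("/hits/_indices")
#     idx.append(3); idx.append(0); idx.append(2)
#     idx.data["indices"]  # -> [0, 3, 3]
#     idx.data["n_items"]  # -> [3, 0, 2]
#
# ``fillup(n)`` pads both lists up to length ``n`` for groups that wrote
# nothing, repeating the last offset with a zero item count.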


class HDF5Sink(Module):
    """Write KM3NeT-formatted HDF5 files, event-by-event.

    The data can be a ``kp.Table``, a numpy structured array,
    a pandas DataFrame, or a simple scalar.

    The name of the corresponding H5 table is the decamelised
    blob-key, so values which are stored in the blob under `FooBar`
    will be written to `/foo_bar` in the HDF5 file.

    Parameters
    ----------
    filename: str, optional [default: 'dump.h5']
        Where to store the events.
    h5file: pytables.File instance, optional [default: None]
        Opened file to write to. This is mutually exclusive with filename.
    keys: list of strings, optional
        List of Blob-keys to write, everything else is ignored.
    complib : str [default: zlib]
        Compression library that should be used.
        'zlib', 'lzf', 'blosc' and all other PyTables filters
        are available.
    complevel : int [default: 5]
        Compression level.
    chunksize : int [optional]
        Chunksize that should be used for saving along the first axis
        of the input array.
    flush_frequency: int, optional [default: 500]
        The number of iterations to cache tables and arrays before
        dumping to disk.
    pytab_file_args: dict [optional]
        Pass additional arguments to the pytables File init.
    n_rows_expected: int, optional [default: 10000]
    append: bool, optional [default: False]
    reset_group_id: bool, optional [default: True]
        Resets the group_id so that it's continuous in the output file.
        Use this with care!

    Notes
    -----
    Provides service write_table(tab, h5loc=None): tab: Table, h5loc: str
        The table to write; its ``.h5loc`` attribute is used as the target
        location unless ``h5loc`` is given explicitly.

    """
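
    # A minimal pipeline sketch (``APump`` is a hypothetical blob source, not
    # part of this module):
    #
    #     import km3pipe as kp
    #
    #     pipe = kp.Pipeline()
    #     pipe.attach(APump)
    #     pipe.attach(HDF5Sink, filename="out.h5")
    #     pipe.drain(1000)
    #
    # Every blob entry that carries an ``h5loc`` attribute (e.g. a kp.Table or
    # NDArray) is appended to the corresponding HDF5 node, and a
    # ``/group_info`` table keeps track of the per-event group IDs.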


    def configure(self):
        self.filename = self.get("filename", default="dump.h5")
        self.ext_h5file = self.get("h5file")
        self.keys = self.get("keys", default=[])
        self.complib = self.get("complib", default="zlib")
        self.complevel = self.get("complevel", default=5)
        self.chunksize = self.get("chunksize")
        self.flush_frequency = self.get("flush_frequency", default=500)
        self.pytab_file_args = self.get("pytab_file_args", default=dict())
        self.keep_open = self.get("keep_open")
        self._reset_group_id = self.get("reset_group_id", default=True)
        self.indices = {}  # to store HDF5IndexTables for each h5loc
        self._singletons_written = {}
        # magic 10000: this is the default of the "expectedrows" arg
        # from the tables.File.create_table() function
        # at least according to the docs
        # might be able to set to `None`, I don't know...
        self.n_rows_expected = self.get("n_rows_expected", default=10000)
        self.index = 0
        self._uuid = str(uuid4())

        self.expose(self.write_table, "write_table")

        if self.ext_h5file is not None:
            self.h5file = self.ext_h5file
        else:
            self.h5file = tb.open_file(
                self.filename,
                mode="w",
                title="KM3NeT",
                **self.pytab_file_args,
            )
        Provenance().record_output(
            self.filename, uuid=self._uuid, comment="HDF5Sink output"
        )
        self.filters = tb.Filters(
            complevel=self.complevel,
            shuffle=True,
            fletcher32=True,
            complib=self.complib,
        )
        self._tables = OrderedDict()
        self._ndarrays = OrderedDict()
        self._ndarrays_cache = defaultdict(list)


    def _to_array(self, data, name=None):
        if data is None:
            return
        if np.isscalar(data):
            self.log.debug("toarray: is a scalar")
            return Table(
                {name: np.asarray(data).reshape((1,))},
                h5loc="/misc/{}".format(decamelise(name)),
                name=name,
            )
        if hasattr(data, "len") and len(data) <= 0:  # a bit smelly ;)
            self.log.debug("toarray: data has no length")
            return
        # istype instead of isinstance, to avoid heavy pandas import (hmmm...)
        if istype(data, "DataFrame"):  # noqa
            self.log.debug("toarray: pandas dataframe")
            data = Table.from_dataframe(data)
        return data

    def _cache_ndarray(self, arr):
        self._ndarrays_cache[arr.h5loc].append(arr)

    def _write_ndarrays_cache_to_disk(self):
        """Writes all the cached NDArrays to disk and empties the cache"""
        for h5loc, arrs in self._ndarrays_cache.items():
            title = arrs[0].title
            chunkshape = (
                (self.chunksize,) + arrs[0].shape[1:]
                if self.chunksize is not None
                else None
            )

            arr = NDArray(np.concatenate(arrs), h5loc=h5loc, title=title)

            if h5loc not in self._ndarrays:
                loc, tabname = os.path.split(h5loc)
                ndarr = self.h5file.create_earray(
                    loc,
                    tabname,
                    tb.Atom.from_dtype(arr.dtype),
                    (0,) + arr.shape[1:],
                    chunkshape=chunkshape,
                    title=title,
                    filters=self.filters,
                    createparents=True,
                )
                self._ndarrays[h5loc] = ndarr
            else:
                ndarr = self._ndarrays[h5loc]

            # for arr_length in (len(a) for a in arrs):
            #     self._record_index(h5loc, arr_length)

            ndarr.append(arr)

        self._ndarrays_cache = defaultdict(list)


    def write_table(self, table, h5loc=None):
        """Write a single table to the HDF5 file, exposed as a service"""
        self.log.debug("Writing table %s", table.name)
        if h5loc is None:
            h5loc = table.h5loc
        self._write_table(h5loc, table, table.name)

    def _write_table(self, h5loc, arr, title):
        level = len(h5loc.split("/"))

        if h5loc not in self._tables:
            dtype = arr.dtype
            if any("U" in str(dtype.fields[f][0]) for f in dtype.fields):
                self.log.error(
                    "Cannot write data to '{}'. Unicode strings are not supported!".format(
                        h5loc
                    )
                )
                return
            loc, tabname = os.path.split(h5loc)
            self.log.debug(
                "h5loc '{}', Loc '{}', tabname '{}'".format(h5loc, loc, tabname)
            )
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", tb.NaturalNameWarning)
                tab = self.h5file.create_table(
                    loc,
                    tabname,
                    chunkshape=self.chunksize,
                    description=dtype,
                    title=title,
                    filters=self.filters,
                    createparents=True,
                    expectedrows=self.n_rows_expected,
                )
                tab._v_attrs.datatype = title
            if level < 5:
                self._tables[h5loc] = tab
        else:
            tab = self._tables[h5loc]

        h5_colnames = set(tab.colnames)
        tab_colnames = set(arr.dtype.names)
        if h5_colnames != tab_colnames:
            missing_cols = h5_colnames - tab_colnames
            if missing_cols:
                self.log.info("Missing columns in table, trying to append NaNs.")
                arr = arr.append_columns(
                    missing_cols, np.full((len(missing_cols), len(arr)), np.nan)
                )
                if arr.dtype != tab.dtype:
                    self.log.error(
                        "Differing dtypes after appending "
                        "missing columns to the table! Skipping..."
                    )
                    return

        if arr.dtype != tab.dtype:
            try:
                arr = Table(arr, dtype=tab.dtype)
            except ValueError:
                self.log.critical(
                    "Cannot write a table to '%s' since its dtype is "
                    "different compared to the previous table with the same "
                    "HDF5 location, which was used to fix the dtype of the "
                    "HDF5 compound type." % h5loc
                )
                raise

        tab.append(arr)

        if level < 4:
            tab.flush()


    def _write_separate_columns(self, where, obj, title):
        f = self.h5file
        loc, group_name = os.path.split(where)
        if where not in f:
            group = f.create_group(loc, group_name, title, createparents=True)
            group._v_attrs.datatype = title
        else:
            group = f.get_node(where)

        for col, (dt, _) in obj.dtype.fields.items():
            data = obj.__array__()[col]

            if col not in group:
                a = tb.Atom.from_dtype(dt)
                arr = f.create_earray(
                    group, col, a, (0,), col.capitalize(), filters=self.filters
                )
            else:
                arr = getattr(group, col)
            arr.append(data)

        # create index table
        # if where not in self.indices:
        #     self.indices[where] = HDF5IndexTable(where + "/_indices", start=self.index)

        self._record_index(where, len(data), split=True)


    def _process_entry(self, key, entry):
        self.log.debug("Inspecting {}".format(key))
        if (
            hasattr(entry, "h5singleton")
            and entry.h5singleton
            and entry.h5loc in self._singletons_written
        ):
            self.log.debug(
                "Skipping '%s' since it's a singleton and already written."
                % entry.h5loc
            )
            return
        if not hasattr(entry, "h5loc"):
            self.log.debug("Ignoring '%s': no h5loc attribute" % key)
            return

        if isinstance(entry, NDArray):
            self._cache_ndarray(entry)
            self._record_index(entry.h5loc, len(entry))
            return entry
        try:
            title = entry.name
        except AttributeError:
            title = key

        if isinstance(entry, Table) and not entry.h5singleton:
            if "group_id" not in entry:
                entry = entry.append_columns("group_id", self.index)
            elif self._reset_group_id:
                # reset group_id to the HDF5Sink's continuous counter
                entry.group_id = self.index

        self.log.debug("h5l: '{}', title '{}'".format(entry.h5loc, title))

        if hasattr(entry, "split_h5") and entry.split_h5:
            self.log.debug("Writing into separate columns...")
            self._write_separate_columns(entry.h5loc, entry, title=title)
        else:
            self.log.debug("Writing into single Table...")
            self._write_table(entry.h5loc, entry, title=title)

        if hasattr(entry, "h5singleton") and entry.h5singleton:
            self._singletons_written[entry.h5loc] = True

        return entry


    def process(self, blob):
        written_blob = Blob()
        for key, entry in sorted(blob.items()):
            if self.keys and key not in self.keys:
                self.log.info("Skipping entry, since it's not in the keys list")
                continue
            self.log.debug("Processing %s", key)
            data = self._process_entry(key, entry)
            if data is not None:
                written_blob[key] = data

        if "GroupInfo" not in blob:
            gi = Table(
                {"group_id": self.index, "blob_length": len(written_blob)},
                h5loc="/group_info",
                name="Group Info",
            )
            self._process_entry("GroupInfo", gi)

        # fill up NDArray indices with 0 entries if needed
        if written_blob:
            ndarray_h5locs = set(self._ndarrays.keys()).union(
                self._ndarrays_cache.keys()
            )
            written_h5locs = set(
                e.h5loc for e in written_blob.values() if isinstance(e, NDArray)
            )
            missing_h5locs = ndarray_h5locs - written_h5locs
            for h5loc in missing_h5locs:
                self.log.info("Filling up %s with 0 length entry", h5loc)
                self._record_index(h5loc, 0)

        if not self.index % self.flush_frequency:
            self.flush()

        self.index += 1
        return blob


    def _record_index(self, h5loc, count, split=False):
        """Add an index entry (optionally create table) for an NDArray h5loc.

        Parameters
        ----------
        h5loc : str
            location in HDF5
        count : int
            number of elements (can be 0)
        split : bool
            if it's a split table

        """
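        # Naming convention used below (paths are just examples): an NDArray
        # stored at /foo gets its index table at /foo_indices, while a
        # split-table group at /foo keeps it as a child node /foo/_indices.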

        suffix = "/_indices" if split else "_indices"
        idx_table_h5loc = h5loc + suffix
        if idx_table_h5loc not in self.indices:
            self.indices[idx_table_h5loc] = HDF5IndexTable(
                idx_table_h5loc, start=self.index
            )

        idx_tab = self.indices[idx_table_h5loc]
        idx_tab.append(count)


    def flush(self):
        """Flush tables and arrays to disk"""
        self.log.info("Flushing tables and arrays to disk...")
        for tab in self._tables.values():
            tab.flush()
        self._write_ndarrays_cache_to_disk()

    def finish(self):
        self.flush()
        self.h5file.root._v_attrs.km3pipe = np.string_(kp.__version__)
        self.h5file.root._v_attrs.pytables = np.string_(tb.__version__)
        self.h5file.root._v_attrs.kid = np.string_(self._uuid)
        self.h5file.root._v_attrs.format_version = np.string_(FORMAT_VERSION)
        self.log.info("Adding index tables.")
        for where, idx_tab in self.indices.items():
            # any skipped NDArrays or split groups will be filled with 0 entries
            idx_tab.fillup(self.index)

            self.log.debug("Creating index table for '%s'" % where)
            h5loc = idx_tab.h5loc
            self.log.info(" -> {0}".format(h5loc))
            indices = Table(
                {"index": idx_tab.data["indices"], "n_items": idx_tab.data["n_items"]},
                h5loc=h5loc,
            )
            self._write_table(h5loc, indices, title="Indices")

        self.log.info(
            "Creating pytables index tables. This may take a few minutes..."
        )
        for tab in self._tables.values():
            if "frame_id" in tab.colnames:
                tab.cols.frame_id.create_index()
            if "slice_id" in tab.colnames:
                tab.cols.slice_id.create_index()
            if "dom_id" in tab.colnames:
                tab.cols.dom_id.create_index()
            if "event_id" in tab.colnames:
                try:
                    tab.cols.event_id.create_index()
                except NotImplementedError:
                    log.warning(
                        "Table '{}' has an uint64 column, "
                        "not indexing...".format(tab._v_name)
                    )
            if "group_id" in tab.colnames:
                try:
                    tab.cols.group_id.create_index()
                except NotImplementedError:
                    log.warning(
                        "Table '{}' has an uint64 column, "
                        "not indexing...".format(tab._v_name)
                    )
            tab.flush()

        if "HDF5MetaData" in self.services:
            self.log.info("Writing HDF5 meta data.")
            metadata = self.services["HDF5MetaData"]
            for name, value in metadata.items():
                self.h5file.set_node_attr("/", name, value)

        if not self.keep_open:
            self.h5file.close()
        self.cprint("HDF5 file written to: {}".format(self.filename))



class HDF5Pump(Module):
    """Read KM3NeT-formatted HDF5 files, event-by-event.

    Parameters
    ----------
    filename: str
        From where to read events.
    skip_version_check: bool [default: False]
        Don't check the H5 version. Might lead to unintended consequences.
    shuffle: bool, optional [default: False]
        Shuffle the group_ids, so that the blobs are mixed up.
    shuffle_function: function, optional [default: np.random.shuffle]
        The function to be used to shuffle the group IDs.
    reset_index: bool, optional [default: False]
        When shuffle is set to true, reset the group ID and start counting
        the group_id from 0.

    Notes
    -----
    Provides service h5singleton(h5loc): h5loc: str -> kp.Table
        Singleton tables for a given HDF5 location.
    """
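
    # A minimal reading sketch (the filename is hypothetical):
    #
    #     pump = HDF5Pump(filename="dump.h5")
    #     blob = pump[0]          # random access to the first event group
    #     for blob in pump:       # or iterate over all event groups
    #         print(blob.keys())
    #
    # In a pipeline it is attached as the first module, e.g.
    # ``pipe.attach(HDF5Pump, filename="dump.h5")``.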


    def configure(self):
        self.filename = self.get("filename")
        self.skip_version_check = self.get("skip_version_check", default=False)
        self.verbose = bool(self.get("verbose"))
        self.shuffle = self.get("shuffle", default=False)
        self.shuffle_function = self.get("shuffle_function", default=np.random.shuffle)
        self.reset_index = self.get("reset_index", default=False)

        self.h5file = None
        self.cut_mask = None
        self.indices = {}
        self._tab_indices = {}
        self._singletons = {}
        self.header = None
        self.group_ids = None
        self._n_groups = None
        self.index = 0

        self.h5file = tb.open_file(self.filename, "r")

        Provenance().record_input(self.filename, comment="HDF5Pump input")

        if not self.skip_version_check:
            check_version(self.h5file)

        self._read_group_info()

        self.expose(self.h5singleton, "h5singleton")

    def _read_group_info(self):
        h5file = self.h5file

        if "/group_info" not in h5file:
            self.log.critical("Missing /group_info '%s', aborting..." % h5file.filename)
            raise SystemExit

        self.log.info("Reading group information from '/group_info'.")
        group_info = h5file.get_node("/", "group_info")
        self.group_ids = group_info.cols.group_id[:]
        self._n_groups = len(self.group_ids)

        if "/raw_header" in h5file:
            self.log.info("Reading /raw_header")
            try:
                self.header = HDF5Header.from_pytable(h5file.get_node("/raw_header"))
            except TypeError:
                self.log.error("Could not parse the raw header, skipping!")

        if self.shuffle:
            self.log.info("Shuffling group IDs")
            self.shuffle_function(self.group_ids)

    def h5singleton(self, h5loc):
        """Returns the singleton table for a given HDF5 location"""
        return self._singletons[h5loc]

    def process(self, blob):
        self.log.info("Reading blob at index %s" % self.index)
        if self.index >= self._n_groups:
            self.log.info("All groups are read.")
            raise StopIteration
        blob = self.get_blob(self.index)
        self.index += 1
        return blob


    def get_blob(self, index):
        blob = Blob()
        group_id = self.group_ids[index]

        # skip groups with separate columns
        # and deal with them later
        # this should be solved using hdf5 attributes in near future
        split_table_locs = []
        ndarray_locs = []
        for tab in self.h5file.walk_nodes(classname="Table"):
            h5loc = tab._v_pathname
            loc, tabname = os.path.split(h5loc)
            if tabname in self.indices:
                self.log.info("index table '%s' already read, skip..." % h5loc)
                continue
            if loc in split_table_locs:
                self.log.info("get_blob: '%s' is noted, skip..." % h5loc)
                continue
            if tabname == "_indices":
                self.log.debug("get_blob: found index table '%s'" % h5loc)
                split_table_locs.append(loc)
                self.indices[loc] = self.h5file.get_node(h5loc)
                continue
            if tabname.endswith("_indices"):
                self.log.debug("get_blob: found index table '%s' for NDArray" % h5loc)
                ndarr_loc = h5loc.replace("_indices", "")
                ndarray_locs.append(ndarr_loc)
                if ndarr_loc in self.indices:
                    self.log.info(
                        "index table for NDArray '%s' already read, skip..." % ndarr_loc
                    )
                    continue
                _index_table = self.h5file.get_node(h5loc)
                self.indices[ndarr_loc] = {
                    "index": _index_table.col("index")[:],
                    "n_items": _index_table.col("n_items")[:],
                }
                continue
            tabname = camelise(tabname)

            if "group_id" in tab.dtype.names:
                try:
                    if h5loc not in self._tab_indices:
                        self._read_tab_indices(h5loc)
                    tab_idx_start = self._tab_indices[h5loc][0][group_id]
                    tab_n_items = self._tab_indices[h5loc][1][group_id]
                    if tab_n_items == 0:
                        continue
                    arr = tab[tab_idx_start : tab_idx_start + tab_n_items]
                except IndexError:
                    self.log.debug("No data for h5loc '%s'" % h5loc)
                    continue
                except NotImplementedError:
                    # 64-bit unsigned integer columns like ``group_id``
                    # are not yet supported in conditions
                    self.log.debug(
                        "get_blob: found uint64 column at '{}'...".format(h5loc)
                    )
                    arr = tab.read()
                    arr = arr[arr["group_id"] == group_id]
                except ValueError:
                    # "there are no columns taking part
                    # in condition ``group_id == 0``"
                    self.log.info(
                        "get_blob: no `%s` column found in '%s'! "
                        "skipping... " % ("group_id", h5loc)
                    )
                    continue
            else:
                if h5loc not in self._singletons:
                    log.info("Caching H5 singleton: {} ({})".format(tabname, h5loc))
                    self._singletons[h5loc] = Table(
                        tab.read(),
                        h5loc=h5loc,
                        split_h5=False,
                        name=tabname,
                        h5singleton=True,
                    )
                blob[tabname] = self._singletons[h5loc]
                continue

            self.log.debug("h5loc: '{}'".format(h5loc))
            tab = Table(arr, h5loc=h5loc, split_h5=False, name=tabname)
            if self.shuffle and self.reset_index:
                tab.group_id[:] = index
            blob[tabname] = tab

        # skipped locs are now column wise datasets (usually hits)
        # currently hardcoded, in future using hdf5 attributes
        # to get the right constructor
        for loc in split_table_locs:
            # if some events are missing (group_id not continuous),
            # this does not work as intended
            # idx, n_items = self.indices[loc][group_id]
            idx = self.indices[loc].col("index")[group_id]
            n_items = self.indices[loc].col("n_items")[group_id]
            end = idx + n_items
            node = self.h5file.get_node(loc)
            columns = (c for c in node._v_children if c != "_indices")
            data = {}
            for col in columns:
                data[col] = self.h5file.get_node(loc + "/" + col)[idx:end]
            tabname = camelise(loc.split("/")[-1])
            s_tab = Table(data, h5loc=loc, split_h5=True, name=tabname)
            if self.shuffle and self.reset_index:
                s_tab.group_id[:] = index
            blob[tabname] = s_tab

        if self.header is not None:
            blob["Header"] = self.header

        for ndarr_loc in ndarray_locs:
            self.log.info("Reading %s" % ndarr_loc)
            try:
                idx = self.indices[ndarr_loc]["index"][group_id]
                n_items = self.indices[ndarr_loc]["n_items"][group_id]
            except IndexError:
                continue
            end = idx + n_items
            ndarr = self.h5file.get_node(ndarr_loc)
            ndarr_name = camelise(ndarr_loc.split("/")[-1])
            _ndarr = NDArray(
                ndarr[idx:end], h5loc=ndarr_loc, title=ndarr.title, group_id=group_id
            )
            if self.shuffle and self.reset_index:
                _ndarr.group_id = index
            blob[ndarr_name] = _ndarr

        return blob


    def _read_tab_indices(self, h5loc):
        self.log.info("Reading table indices for '{}'".format(h5loc))
        node = self.h5file.get_node(h5loc)
        group_ids = None
        if "group_id" in node.dtype.names:
            group_ids = self.h5file.get_node(h5loc).cols.group_id[:]
        else:
            self.log.error("No data found in '{}'".format(h5loc))
            return

        self._tab_indices[h5loc] = create_index_tuple(group_ids)


    def __len__(self):
        self.log.info("Opening the HDF5 file to check the number of groups")
        n_groups = 0
        for filename in [self.filename]:
            with tb.open_file(filename, "r") as h5file:
                group_info = h5file.get_node("/", "group_info")
                self.group_ids = group_info.cols.group_id[:]
                n_groups += len(self.group_ids)
        return n_groups


    def __iter__(self):
        return self

    def __next__(self):
        # TODO: wrap that in self._check_if_next_file_is_needed(self.index)
        if self.index >= self._n_groups:
            self.log.info("All groups are read")
            raise StopIteration
        blob = self.get_blob(self.index)
        self.index += 1
        return blob

    def __getitem__(self, index):
        if isinstance(index, int):
            return self.get_blob(index)
        elif isinstance(index, slice):
            return self._slice_generator(index)
        else:
            raise TypeError("index must be int or slice")

    def _slice_generator(self, index):
        """A simple slice generator for iterations"""
        start, stop, step = index.indices(len(self))
        for i in range(start, stop, step):
            yield self.get_blob(i)

        self.filename = None

    def _close_h5file(self):
        if self.h5file:
            self.h5file.close()

    def finish(self):
        self._close_h5file()



@jit(nopython=True)
def create_index_tuple(group_ids):
    """A helper function to create index tuples for fast lookup in HDF5Pump"""
    max_group_id = np.max(group_ids)

    start_idx_arr = np.full(max_group_id + 1, 0)
    n_items_arr = np.full(max_group_id + 1, 0)

    current_group_id = group_ids[0]
    current_idx = 0
    item_count = 0

    for group_id in group_ids:
        if group_id != current_group_id:
            start_idx_arr[current_group_id] = current_idx
            n_items_arr[current_group_id] = item_count
            current_idx += item_count
            item_count = 0
            current_group_id = group_id
        item_count += 1
    else:
        start_idx_arr[current_group_id] = current_idx
        n_items_arr[current_group_id] = item_count

    return (start_idx_arr, n_items_arr)
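
# A small worked example of the lookup helper above (hypothetical group IDs):
#
#     create_index_tuple(np.array([0, 0, 0, 1, 3, 3]))
#     # -> (array([0, 3, 0, 4]), array([3, 1, 0, 2]))
#
# i.e. group 0 starts at row 0 with 3 items, group 1 starts at row 3 with
# 1 item, group 2 is absent (0 items, start defaults to 0) and group 3
# starts at row 4 with 2 items.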



class HDF5MetaData(Module):
    """Metadata to attach to the HDF5 file.

    Parameters
    ----------
    data: dict

    """

    def configure(self):
        self.data = self.require("data")
        self.expose(self.data, "HDF5MetaData")

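# A usage sketch (keys and values are hypothetical): attach it together with
# HDF5Sink, which picks up the exposed "HDF5MetaData" service in finish() and
# stores each key/value pair as a root node attribute of the output file.
#
#     pipe.attach(HDF5MetaData, data={"detector": "ORCA", "comment": "test run"})
#     pipe.attach(HDF5Sink, filename="out.h5")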


@singledispatch
def header2table(data):
    """Convert a header to an `HDF5Header` compliant `kp.Table`"""
    print(f"Unsupported header data of type {type(data)}")


@header2table.register(dict)
def _(header_dict):
    if not header_dict:
        print("Empty header dictionary.")
        return
    tab_dict = defaultdict(list)

    for parameter, data in header_dict.items():
        fields = []
        values = []
        types = []
        for field_name, field_value in data.items():
            fields.append(field_name)
            values.append(str(field_value))
            try:
                _ = float(field_value)  # noqa
                types.append("f4")
            except ValueError:
                types.append("a{}".format(len(field_value)))
            except TypeError:  # e.g. values is None
                types.append("a{}".format(len(str(field_value))))
        tab_dict["parameter"].append(parameter.encode())
        tab_dict["field_names"].append(" ".join(fields).encode())
        tab_dict["field_values"].append(" ".join(values).encode())
        tab_dict["dtype"].append(" ".join(types).encode())
        log.debug(
            "{}: {} {} {}".format(
                tab_dict["parameter"][-1],
                tab_dict["field_names"][-1],
                tab_dict["field_values"][-1],
                tab_dict["dtype"][-1],
            )
        )
    return Table(tab_dict, h5loc="/raw_header", name="RawHeader", h5singleton=True)
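
# A sketch of the conversion (hypothetical header content): every parameter
# becomes one row with space-separated field names, stringified values and
# guessed dtypes ("f4" for anything float-convertible, "aN" for strings).
#
#     tab = header2table({"can": {"zmin": 0.0, "zmax": 476.5, "r": 403.4}})
#     # tab["parameter"]    -> [b"can"]
#     # tab["field_names"]  -> [b"zmin zmax r"]
#     # tab["field_values"] -> [b"0.0 476.5 403.4"]
#     # tab["dtype"]        -> [b"f4 f4 f4"]
#
# The resulting table is written as the ``/raw_header`` singleton, which
# ``HDF5Header.from_pytable`` can parse back into an ``HDF5Header``.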



@header2table.register(km3io.offline.Header)
def _(header):
    out = {}
    for parameter, values in header._data.items():
        try:
            values = values._asdict()
        except AttributeError:
            # single entry without further parameter name
            # in specification
            values = {parameter + "_0": values}
        out[parameter] = values
    return header2table(out)


@header2table.register(HDF5Header)
def _(header):
    return header2table(header._data)