Coverage for src/km3pipe/dataclasses.py: 99%

329 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-23 03:15 +0000

1# Filename: dataclasses.py 

2# pylint: disable=W0232,C0103,C0111 

3# vim:set ts=4 sts=4 sw=4 et syntax=python: 

4""" 

5Dataclasses for internal use. Heavily based on Numpy arrays. 

6""" 

7import itertools 

8 

9import numpy as np 

10from numpy.lib import recfunctions as rfn 

11 

12from .dataclass_templates import TEMPLATES 

13from .logger import get_logger 

14from .tools import istype 

15 

16__author__ = "Tamas Gal and Moritz Lotze" 

17__copyright__ = "Copyright 2016, Tamas Gal and the KM3NeT collaboration." 

18__credits__ = [] 

19__license__ = "MIT" 

20__maintainer__ = "Tamas Gal and Moritz Lotze" 

21__email__ = "tgal@km3net.de" 

22__status__ = "Development" 

23__all__ = ("Table", "is_structured", "has_structured_dt", "inflate_dtype") 

24 

25DEFAULT_H5LOC = "/misc" 

26DEFAULT_NAME = "Generic Table" 

27DEFAULT_SPLIT = False 

28DEFAULT_H5SINGLETON = False 

29 

30log = get_logger(__name__) 

31 

32 

def has_structured_dt(arr):
    """Check if the array representation has a structured dtype."""
    # Coerce to an ndarray first so plain lists/tuples are handled too.
    dt = np.asanyarray(arr).dtype
    # A structured dtype is exactly one whose `fields` mapping is set.
    return getattr(dt, "fields", None) is not None

37 

38 

def is_structured(dt):
    """Check if the dtype is structured."""
    # Non-dtype objects (e.g. plain strings) have no `fields` attribute;
    # unstructured dtypes have `fields is None`.  Both cases yield False.
    return getattr(dt, "fields", None) is not None

44 

45 

def inflate_dtype(arr, names):
    """Create structured dtype from a 2d ndarray with unstructured dtype."""
    arr = np.asanyarray(arr)
    # Already structured?  Then there is nothing to inflate.
    if getattr(arr.dtype, "fields", None) is not None:
        return arr.dtype
    # Give every requested column the array's flat scalar dtype.
    return np.dtype([(name, arr.dtype) for name in names])

55 

56 

class Table(np.recarray):
    """2D generic Table with grouping index.

    This is a `np.recarray` subclass with some metadata and helper methods.

    You can initialize it directly from a structured numpy array,
    a pandas DataFrame, a dictionary of (columnar) arrays; or, initialize it
    from a list of rows/list of columns using the appropriate factory.

    This class adds the following to ``np.recarray``:

    Parameters
    ----------
    data: array-like or dict(array-like)
        numpy array with structured/flat dtype, or dict of arrays.
    h5loc: str
        Location in HDF5 file where to store the data. [default: '/misc']
    h5singleton: bool
        Tables defined as h5singletons are only written once to an HDF5 file.
        This is used for headers for example (default=False).
    dtype: numpy dtype
        Datatype over array. If not specified and data is an unstructured
        array, ``names`` needs to be specified. [default: None]

    Attributes
    ----------
    h5loc: str
        HDF5 group where to write into. (default='/misc')
    split_h5: bool
        Split the array into separate arrays, column-wise, when saving
        to hdf5? (default=False)
    name: str
        Human-readable name, e.g. 'Hits'
    h5singleton: bool
        Tables defined as h5singletons are only written once to an HDF5 file.
        This is used for headers for example (default=False).

    Methods
    -------
    from_dict(arr_dict, dtype=None, **kwargs)
        Create an Table from a dict of arrays (similar to pandas).
    from_template(data, template, **kwargs)
        Create an array from a dict of arrays with a predefined dtype.
    sorted(by)
        Sort the table by one of its columns.
    append_columns(colnames, values)
        Append new columns to the table.
    to_dataframe()
        Return as pandas dataframe.
    from_dataframe(df, **kwargs)
        Instantiate from a dataframe.
    from_rows(list_of_rows, **kwargs)
        Instantiate from an array-like with shape (n_rows, n_columns).
    from_columns(list_of_columns, **kwargs)
        Instantiate from an array-like with shape (n_columns, n_rows).
    """

    def __new__(
        cls,
        data,
        h5loc=DEFAULT_H5LOC,
        dtype=None,
        split_h5=DEFAULT_SPLIT,
        name=DEFAULT_NAME,
        h5singleton=DEFAULT_H5SINGLETON,
        **kwargs
    ):
        if isinstance(data, dict):
            return cls.from_dict(
                data,
                h5loc=h5loc,
                dtype=dtype,
                split_h5=split_h5,
                name=name,
                h5singleton=h5singleton,
                **kwargs
            )
        if istype(data, "DataFrame"):
            return cls.from_dataframe(
                data,
                h5loc=h5loc,
                dtype=dtype,
                split_h5=split_h5,
                name=name,
                h5singleton=h5singleton,
                **kwargs
            )
        if isinstance(data, (list, tuple)):
            raise ValueError(
                "Lists/tuples are not supported! "
                "Please use the `from_rows` or `from_columns` method instead!"
            )
        if isinstance(data, np.record):
            # single record from recarray/kp.Tables, let's blow it up
            data = data[np.newaxis]
        if not has_structured_dt(data):
            # flat (nonstructured) dtypes fail miserably!
            # they would silently default to something like `|V8`
            raise ValueError(
                "Arrays without structured dtype are not supported! "
                "Please use the `from_rows` or `from_columns` method instead!"
            )

        if dtype is None:
            dtype = data.dtype

        assert is_structured(dtype)

        if dtype != data.dtype:
            dtype_names = set(dtype.names)
            data_dtype_names = set(data.dtype.names)
            if dtype_names == data_dtype_names:
                # same fields: reordering is possible, but only if the
                # per-field types also agree
                if not all(dtype[f] == data.dtype[f] for f in dtype_names):
                    log.critical(
                        "dtype mismatch! Matching field names but differing "
                        "field types, no chance to reorder.\n"
                        "dtype of data: %s\n"
                        "requested dtype: %s" % (data.dtype, dtype)
                    )
                    raise ValueError("dtype mismatch")
                log.once(
                    "dtype mismatch, but matching field names and types. "
                    "Reordering input data...",
                    identifier=h5loc,
                )
                data = Table({f: data[f] for f in dtype_names}, dtype=dtype)
            else:
                log.critical(
                    "dtype mismatch, no chance to reorder due to differing "
                    "fields!\n"
                    "dtype of data: %s\n"
                    "requested dtype: %s" % (data.dtype, dtype)
                )
                raise ValueError("dtype mismatch")

        obj = np.asanyarray(data, dtype=dtype).view(cls)
        obj.h5loc = h5loc
        obj.split_h5 = split_h5
        obj.name = name
        obj.h5singleton = h5singleton
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            # called from explicit constructor
            return obj
        # views or slices
        self.h5loc = getattr(obj, "h5loc", DEFAULT_H5LOC)
        self.split_h5 = getattr(obj, "split_h5", DEFAULT_SPLIT)
        self.name = getattr(obj, "name", DEFAULT_NAME)
        self.h5singleton = getattr(obj, "h5singleton", DEFAULT_H5SINGLETON)
        # attribute access returns void instances on slicing/iteration
        # kudos to
        # https://github.com/numpy/numpy/issues/3581#issuecomment-108957200
        if obj is not None and type(obj) is not type(self):
            self.dtype = np.dtype((np.record, obj.dtype))

    def __array_wrap__(self, out_arr, context=None):
        # wrap ufunc results back into a Table, carrying the metadata over
        return Table(
            np.recarray.__array_wrap__(self, out_arr, context),
            h5loc=self.h5loc,
            split_h5=self.split_h5,
            name=self.name,
            h5singleton=self.h5singleton,
        )

    @staticmethod
    def _expand_scalars(arr_dict):
        """Broadcast scalar entries to the longest column length (in place)."""
        scalars = []
        maxlen = 1  # have at least 1-elem arrays
        for k, v in arr_dict.items():
            if np.isscalar(v):
                scalars.append(k)
                continue
            # TODO: this is not covered yet, don't know if we need this
            # if hasattr(v, 'shape') and v.shape == (1,):  # np.array([1])
            #     arr_dict[k] = v[0]
            #     continue
            if hasattr(v, "ndim") and v.ndim == 0:  # np.array(1)
                arr_dict[k] = v.item()
                continue
            if len(v) > maxlen:
                maxlen = len(v)
        for s in scalars:
            arr_dict[s] = np.full(maxlen, arr_dict[s])
        return arr_dict

    @classmethod
    def from_dict(cls, arr_dict, dtype=None, fillna=False, **kwargs):
        """Generate a table from a dictionary of arrays."""
        arr_dict = arr_dict.copy()
        # i hope order of keys == order of values
        if dtype is None:
            names = sorted(list(arr_dict.keys()))
        else:
            dtype = np.dtype(dtype)
            dt_names = [f for f in dtype.names]
            dict_names = [k for k in arr_dict.keys()]
            missing_names = set(dt_names) - set(dict_names)
            if missing_names:
                if fillna:
                    dict_names = dt_names
                    for missing_name in missing_names:
                        arr_dict[missing_name] = np.nan
                else:
                    raise KeyError("Dictionary keys and dtype fields do not match!")
            names = list(dtype.names)

        arr_dict = cls._expand_scalars(arr_dict)
        data = [arr_dict[key] for key in names]
        return cls(np.rec.fromarrays(data, names=names, dtype=dtype), **kwargs)

    @classmethod
    def from_columns(cls, column_list, dtype=None, colnames=None, **kwargs):
        """Instantiate from an array-like with shape (n_columns, n_rows)."""
        if dtype is None or not is_structured(dtype):
            # infer structured dtype from array data + column names
            if colnames is None:
                raise ValueError(
                    "Need to either specify column names or a "
                    "structured dtype when passing unstructured arrays!"
                )
            dtype = inflate_dtype(column_list, colnames)
            colnames = dtype.names
        if len(column_list) != len(dtype.names):
            raise ValueError("Number of columns mismatch between data and dtype!")
        data = {k: column_list[i] for i, k in enumerate(dtype.names)}
        return cls(data, dtype=dtype, colnames=colnames, **kwargs)

    @classmethod
    def from_rows(cls, row_list, dtype=None, colnames=None, **kwargs):
        """Instantiate from an array-like with shape (n_rows, n_columns)."""
        if dtype is None or not is_structured(dtype):
            # infer structured dtype from array data + column names
            if colnames is None:
                raise ValueError(
                    "Need to either specify column names or a "
                    "structured dtype when passing unstructured arrays!"
                )
            dtype = inflate_dtype(row_list, colnames)
        # this *should* have been checked above, but do this
        # just to be sure in case I screwed up the logic above;
        # users will never see this, this should only show in tests
        assert is_structured(dtype)
        data = np.asanyarray(row_list).view(dtype)
        # drop useless 2nd dim
        data = data.reshape((data.shape[0],))
        return cls(data, **kwargs)

    @property
    def templates_avail(self):
        """Sorted list of the available dtype template names."""
        return sorted(list(TEMPLATES.keys()))

    @classmethod
    def from_template(cls, data, template):
        """Create a table from a predefined datatype.

        See the ``templates_avail`` property for available names.

        Parameters
        ----------
        data
            Data in a format that the ``__init__`` understands.
        template: str or dict
            Name of the dtype template to use from ``kp.dataclasses_templates``
            or a ``dict`` containing the required attributes (see the other
            templates for reference).
        """
        name = DEFAULT_NAME
        if isinstance(template, str):
            name = template
            table_info = TEMPLATES[name]
        else:
            table_info = template
            if "name" in table_info:
                name = table_info["name"]
        dt = table_info["dtype"]
        loc = table_info["h5loc"]
        split = table_info["split_h5"]
        h5singleton = table_info["h5singleton"]

        return cls(
            data,
            h5loc=loc,
            dtype=dt,
            split_h5=split,
            name=name,
            h5singleton=h5singleton,
        )

    @staticmethod
    def _check_column_length(values, n):
        """Raise if any of the given columns differs in length from ``n``."""
        values = np.atleast_2d(values)
        for v in values:
            if len(v) != n:
                raise ValueError(
                    "Trying to append more than one column, but "
                    "some arrays mismatch in length!"
                )

    def append_columns(self, colnames, values, **kwargs):
        """Append new columns to the table.

        When appending a single column, ``values`` can be a scalar or an
        array of either length 1 or the same length as this array (the one
        it's appended to). In case of multiple columns, values must have
        the shape ``list(arrays)``, and the dimension of each array
        has to match the length of this array.

        See the docs for ``numpy.lib.recfunctions.append_fields`` for an
        explanation of the remaining options.
        """
        n = len(self)
        if np.isscalar(values):
            values = np.full(n, values)

        values = np.atleast_1d(values)
        if not isinstance(colnames, str) and len(colnames) > 1:
            values = np.atleast_2d(values)
            self._check_column_length(values, n)

        if values.ndim == 1:
            if len(values) > n:
                raise ValueError("New Column is longer than existing table!")
            elif len(values) > 1 and len(values) < n:
                raise ValueError(
                    "New Column is shorter than existing table, "
                    "but not just one element!"
                )
            elif len(values) == 1:
                # broadcast the single element to the full column
                values = np.full(n, values[0])
        new_arr = rfn.append_fields(
            self, colnames, values, usemask=False, asrecarray=True, **kwargs
        )
        return self.__class__(
            new_arr,
            h5loc=self.h5loc,
            split_h5=self.split_h5,
            name=self.name,
            h5singleton=self.h5singleton,
        )

    def drop_columns(self, colnames, **kwargs):
        """Drop columns from the table.

        See the docs for ``numpy.lib.recfunctions.drop_fields`` for an
        explanation of the remaining options.
        """
        new_arr = rfn.drop_fields(
            self, colnames, usemask=False, asrecarray=True, **kwargs
        )
        return self.__class__(
            new_arr,
            h5loc=self.h5loc,
            split_h5=self.split_h5,
            name=self.name,
            h5singleton=self.h5singleton,
        )

    def sorted(self, by, **kwargs):
        """Sort array by a column.

        Parameters
        ==========
        by: str
            Name of the columns to sort by(e.g. 'time').
        """
        sort_idc = np.argsort(self[by], **kwargs)
        # BUGFIX: previously `h5singleton` was not propagated here, so a
        # sorted singleton table silently lost its singleton flag.
        return self.__class__(
            self[sort_idc],
            h5loc=self.h5loc,
            split_h5=self.split_h5,
            name=self.name,
            h5singleton=self.h5singleton,
        )

    def to_dataframe(self):
        """Return the table as a ``pandas.DataFrame`` (metadata is lost)."""
        from pandas import DataFrame

        return DataFrame(self)

    @classmethod
    def from_dataframe(cls, df, **kwargs):
        """Instantiate from a ``pandas.DataFrame``."""
        rec = df.to_records(index=False)
        return cls(rec, **kwargs)

    @classmethod
    def merge(cls, tables, fillna=False):
        """Merge a list of tables"""
        # union of all columns appearing in any of the tables
        cols = set(itertools.chain(*[table.dtype.descr for table in tables]))

        tables_to_merge = []
        for table in tables:
            missing_cols = cols - set(table.dtype.descr)

            if missing_cols:
                if fillna:
                    n = len(table)
                    n_cols = len(missing_cols)
                    col_names = []
                    for col_name, col_dtype in missing_cols:
                        # NaN only exists for floats, so refuse everything else
                        if "f" not in col_dtype:
                            raise ValueError(
                                "Cannot create NaNs for non-float"
                                " type column '{}'".format(col_name)
                            )
                        col_names.append(col_name)

                    table = table.append_columns(
                        col_names, np.full((n_cols, n), np.nan)
                    )
                else:
                    # BUGFIX: the message previously said `fill_na=True`,
                    # but the actual keyword argument is `fillna`.
                    raise ValueError(
                        "Table columns do not match. Use fillna=True"
                        " if you want to append missing values with NaNs"
                    )
            tables_to_merge.append(table)

        first_table = tables_to_merge[0]

        # concatenation is handled by Table.__add__
        merged_table = sum(tables_to_merge[1:], first_table)

        merged_table.h5loc = first_table.h5loc
        merged_table.h5singleton = first_table.h5singleton
        merged_table.split_h5 = first_table.split_h5
        merged_table.name = first_table.name

        return merged_table

    def __add__(self, other):
        cols1 = set(self.dtype.descr)
        cols2 = set(other.dtype.descr)
        if len(cols1 ^ cols2) != 0:
            cols1 = set(self.dtype.names)
            cols2 = set(other.dtype.names)
            if len(cols1 ^ cols2) == 0:
                # same field names but differing field types/order
                raise NotImplementedError
            else:
                raise TypeError("Table columns do not match")
        col_order = list(self.dtype.names)
        ret = self.copy()
        len_self = len(self)
        len_other = len(other)
        final_length = len_self + len_other
        ret.resize(final_length, refcheck=False)
        ret[len_self:] = other[col_order]
        return Table(
            ret,
            h5loc=self.h5loc,
            h5singleton=self.h5singleton,
            split_h5=self.split_h5,
            name=self.name,
        )

    def __str__(self):
        name = self.name
        spl = "split" if self.split_h5 else "no split"
        s = "{} {}\n".format(name, type(self))
        s += "HDF5 location: {} ({})\n".format(self.h5loc, spl)
        s += "\n".join(
            map(
                lambda d: "{1} (dtype: {2}) = {0}".format(self[d[0]], *d),
                self.dtype.descr,
            )
        )
        return s

    def __repr__(self):
        s = "{} {} (rows: {})".format(self.name, type(self), self.size)
        return s

    def __contains__(self, elem):
        return elem in self.dtype.names

    @property
    def pos(self):
        """Positions as a (n, 3) array built from pos_x/pos_y/pos_z columns."""
        return np.array([self.pos_x, self.pos_y, self.pos_z]).T

    @pos.setter
    def pos(self, arr):
        arr = np.atleast_2d(arr)
        assert arr.shape[1] == 3
        assert len(arr) == len(self)
        self.pos_x = arr[:, 0]
        self.pos_y = arr[:, 1]
        self.pos_z = arr[:, 2]

    @property
    def dir(self):
        """Directions as a (n, 3) array built from dir_x/dir_y/dir_z columns."""
        return np.array([self.dir_x, self.dir_y, self.dir_z]).T

    @dir.setter
    def dir(self, arr):
        arr = np.atleast_2d(arr)
        assert arr.shape[1] == 3
        assert len(arr) == len(self)
        self.dir_x = arr[:, 0]
        self.dir_y = arr[:, 1]
        self.dir_z = arr[:, 2]

    @property
    def phi(self):
        from km3pipe.math import phi_separg

        return phi_separg(self.dir_x, self.dir_y)

    @property
    def theta(self):
        from km3pipe.math import theta_separg

        return theta_separg(self.dir_z)

    @property
    def zenith(self):
        from km3pipe.math import neutrino_to_source_direction

        _, zen = neutrino_to_source_direction(self.phi, self.theta)
        return zen

    @property
    def azimuth(self):
        from km3pipe.math import neutrino_to_source_direction

        azi, _ = neutrino_to_source_direction(self.phi, self.theta)
        return azi

    @property
    def triggered_rows(self):
        """Subtable containing only the rows with a truthy `triggered` column."""
        if not hasattr(self, "triggered"):
            raise KeyError("Table has no 'triggered' column!")
        return self[self.triggered.astype(bool)]

586 

587 

class NDArray(np.ndarray):
    """Array with HDF5 metadata."""

    # metadata attributes copied on views/slices
    _METADATA = ("h5loc", "title", "group_id")

    def __new__(cls, array, dtype=None, order=None, **kwargs):
        # Build the ndarray view first, then attach the HDF5 metadata.
        obj = np.asarray(array, dtype=dtype, order=order).view(cls)
        obj.h5loc = kwargs.get("h5loc", "/misc")
        obj.title = kwargs.get("title", "Unnamed NDArray")
        obj.group_id = kwargs.get("group_id", None)
        return obj

    def __array_finalize__(self, obj):
        # Propagate metadata to views/slices; raw arrays get None defaults.
        if obj is None:
            return
        for attr in self._METADATA:
            setattr(self, attr, getattr(obj, attr, None))

607 

608 

class Vec3(object):
    """A simple 3D vector with elementwise arithmetic.

    Arithmetic is delegated to numpy via ``__array__``, so the operands
    may be other ``Vec3`` instances, scalars, or anything numpy can
    broadcast against a length-3 array.
    """

    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z

    def __repr__(self):
        # added for debuggability; round-trips for plain numeric components
        return "Vec3({0!r}, {1!r}, {2!r})".format(self.x, self.y, self.z)

    def __add__(self, other):
        return Vec3(*np.add(self, other))

    def __radd__(self, other):
        return Vec3(*np.add(other, self))

    def __sub__(self, other):
        return Vec3(*np.subtract(self, other))

    def __rsub__(self, other):
        return Vec3(*np.subtract(other, self))

    def __mul__(self, other):
        return Vec3(*np.multiply(self, other))

    def __rmul__(self, other):
        return Vec3(*np.multiply(other, self))

    def __div__(self, other):
        # Python 2 legacy name; kept for backward compatibility
        return self.__truediv__(other)

    def __truediv__(self, other):
        return Vec3(*np.divide(self, other))

    def __array__(self, dtype=None):
        # numpy interoperability hook: expose the vector as a length-3 array
        if dtype is not None:
            return np.array([self.x, self.y, self.z], dtype=dtype)
        else:
            return np.array([self.x, self.y, self.z])

    def __getitem__(self, index):
        return self.__array__()[index]