Coverage for src/km3pipe/dataclasses.py: 99%
329 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-23 03:15 +0000
1# Filename: dataclasses.py
2# pylint: disable=W0232,C0103,C0111
3# vim:set ts=4 sts=4 sw=4 et syntax=python:
4"""
5Dataclasses for internal use. Heavily based on Numpy arrays.
6"""
7import itertools
9import numpy as np
10from numpy.lib import recfunctions as rfn
12from .dataclass_templates import TEMPLATES
13from .logger import get_logger
14from .tools import istype
16__author__ = "Tamas Gal and Moritz Lotze"
17__copyright__ = "Copyright 2016, Tamas Gal and the KM3NeT collaboration."
18__credits__ = []
19__license__ = "MIT"
20__maintainer__ = "Tamas Gal and Moritz Lotze"
21__email__ = "tgal@km3net.de"
22__status__ = "Development"
23__all__ = ("Table", "is_structured", "has_structured_dt", "inflate_dtype")
25DEFAULT_H5LOC = "/misc"
26DEFAULT_NAME = "Generic Table"
27DEFAULT_SPLIT = False
28DEFAULT_H5SINGLETON = False
30log = get_logger(__name__)
def has_structured_dt(arr):
    """Check if the array representation has a structured dtype."""
    # Coerce to an ndarray first so plain lists/scalars are handled too.
    return is_structured(np.asanyarray(arr).dtype)
def is_structured(dt):
    """Check if the dtype is structured."""
    # Non-dtype inputs have no ``fields`` attribute and are not structured;
    # plain (unstructured) dtypes have ``fields is None``.
    fields = getattr(dt, "fields", None)
    return fields is not None
def inflate_dtype(arr, names):
    """Create structured dtype from a 2d ndarray with unstructured dtype."""
    arr = np.asanyarray(arr)
    # Already structured: nothing to inflate.
    if has_structured_dt(arr):
        return arr.dtype
    # Give every named column the array's scalar dtype.
    scalar_dt = arr.dtype
    return np.dtype([(name, scalar_dt) for name in names])
class Table(np.recarray):
    """2D generic Table with grouping index.

    This is a `np.recarray` subclass with some metadata and helper methods.

    You can initialize it directly from a structured numpy array,
    a pandas DataFrame, a dictionary of (columnar) arrays; or, initialize it
    from a list of rows/list of columns using the appropriate factory.

    This class adds the following to ``np.recarray``:

    Parameters
    ----------
    data: array-like or dict(array-like)
        numpy array with structured/flat dtype, or dict of arrays.
    h5loc: str
        Location in HDF5 file where to store the data. [default: '/misc']
    h5singleton: bool
        Tables defined as h5singletons are only written once to an HDF5 file.
        This is used for headers for example (default=False).
    dtype: numpy dtype
        Datatype over array. If not specified and data is an unstructured
        array, ``names`` needs to be specified. [default: None]

    Attributes
    ----------
    h5loc: str
        HDF5 group where to write into. (default='/misc')
    split_h5: bool
        Split the array into separate arrays, column-wise, when saving
        to hdf5? (default=False)
    name: str
        Human-readable name, e.g. 'Hits'
    h5singleton: bool
        Tables defined as h5singletons are only written once to an HDF5 file.
        This is used for headers for example (default=False).

    Methods
    -------
    from_dict(arr_dict, dtype=None, **kwargs)
        Create an Table from a dict of arrays (similar to pandas).
    from_template(data, template, **kwargs)
        Create an array from a dict of arrays with a predefined dtype.
    sorted(by)
        Sort the table by one of its columns.
    append_columns(colnames, values)
        Append new columns to the table.
    to_dataframe()
        Return as pandas dataframe.
    from_dataframe(df, **kwargs)
        Instantiate from a dataframe.
    from_rows(list_of_rows, **kwargs)
        Instantiate from an array-like with shape (n_rows, n_columns).
    from_columns(list_of_columns, **kwargs)
        Instantiate from an array-like with shape (n_columns, n_rows).
    """

    def __new__(
        cls,
        data,
        h5loc=DEFAULT_H5LOC,
        dtype=None,
        split_h5=DEFAULT_SPLIT,
        name=DEFAULT_NAME,
        h5singleton=DEFAULT_H5SINGLETON,
        **kwargs
    ):
        if isinstance(data, dict):
            return cls.from_dict(
                data,
                h5loc=h5loc,
                dtype=dtype,
                split_h5=split_h5,
                name=name,
                h5singleton=h5singleton,
                **kwargs
            )
        if istype(data, "DataFrame"):
            return cls.from_dataframe(
                data,
                h5loc=h5loc,
                dtype=dtype,
                split_h5=split_h5,
                name=name,
                h5singleton=h5singleton,
                **kwargs
            )
        if isinstance(data, (list, tuple)):
            raise ValueError(
                "Lists/tuples are not supported! "
                "Please use the `from_rows` or `from_columns` method instead!"
            )
        if isinstance(data, np.record):
            # single record from recarray/kp.Tables, let's blow it up
            data = data[np.newaxis]
        if not has_structured_dt(data):
            # flat (nonstructured) dtypes fail miserably!
            # default to `|V8` whyever
            raise ValueError(
                "Arrays without structured dtype are not supported! "
                "Please use the `from_rows` or `from_columns` method instead!"
            )

        if dtype is None:
            dtype = data.dtype

        assert is_structured(dtype)

        if dtype != data.dtype:
            dtype_names = set(dtype.names)
            data_dtype_names = set(data.dtype.names)
            if dtype_names == data_dtype_names:
                if not all(dtype[f] == data.dtype[f] for f in dtype_names):
                    log.critical(
                        "dtype mismatch! Matching field names but differing "
                        "field types, no chance to reorder.\n"
                        "dtype of data: %s\n"
                        "requested dtype: %s" % (data.dtype, dtype)
                    )
                    raise ValueError("dtype mismatch")
                # Field names and types agree, only the order differs:
                # rebuild the table column by column in the requested order.
                log.once(
                    "dtype mismatch, but matching field names and types. "
                    "Reordering input data...",
                    identifier=h5loc,
                )
                data = Table({f: data[f] for f in dtype_names}, dtype=dtype)
            else:
                log.critical(
                    "dtype mismatch, no chance to reorder due to differing "
                    "fields!\n"
                    "dtype of data: %s\n"
                    "requested dtype: %s" % (data.dtype, dtype)
                )
                raise ValueError("dtype mismatch")

        obj = np.asanyarray(data, dtype=dtype).view(cls)
        obj.h5loc = h5loc
        obj.split_h5 = split_h5
        obj.name = name
        obj.h5singleton = h5singleton
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            # called from explicit constructor
            return obj
        # views or slices: inherit the HDF5 metadata (or fall back to defaults)
        self.h5loc = getattr(obj, "h5loc", DEFAULT_H5LOC)
        self.split_h5 = getattr(obj, "split_h5", DEFAULT_SPLIT)
        self.name = getattr(obj, "name", DEFAULT_NAME)
        self.h5singleton = getattr(obj, "h5singleton", DEFAULT_H5SINGLETON)
        # attribute access returns void instances on slicing/iteration
        # kudos to
        # https://github.com/numpy/numpy/issues/3581#issuecomment-108957200
        if obj is not None and type(obj) is not type(self):
            self.dtype = np.dtype((np.record, obj.dtype))

    def __array_wrap__(self, out_arr, context=None):
        # then just call the parent and rewrap the result with our metadata
        return Table(
            np.recarray.__array_wrap__(self, out_arr, context),
            h5loc=self.h5loc,
            split_h5=self.split_h5,
            name=self.name,
            h5singleton=self.h5singleton,
        )

    @staticmethod
    def _expand_scalars(arr_dict):
        # Broadcast scalar entries to the length of the longest column so
        # that ``np.rec.fromarrays`` receives equally-sized arrays.
        scalars = []
        maxlen = 1  # have at least 1-elem arrays
        for k, v in arr_dict.items():
            if np.isscalar(v):
                scalars.append(k)
                continue
            if hasattr(v, "ndim") and v.ndim == 0:  # np.array(1)
                # NOTE(review): 0-dim arrays are unwrapped to a Python scalar
                # but are NOT broadcast like plain scalars below — confirm
                # this asymmetry is intended.
                arr_dict[k] = v.item()
                continue
            if len(v) > maxlen:
                maxlen = len(v)
        for s in scalars:
            arr_dict[s] = np.full(maxlen, arr_dict[s])
        return arr_dict

    @classmethod
    def from_dict(cls, arr_dict, dtype=None, fillna=False, **kwargs):
        """Generate a table from a dictionary of arrays."""
        arr_dict = arr_dict.copy()
        # i hope order of keys == order of values
        if dtype is None:
            names = sorted(list(arr_dict.keys()))
        else:
            dtype = np.dtype(dtype)
            dt_names = [f for f in dtype.names]
            dict_names = [k for k in arr_dict.keys()]
            missing_names = set(dt_names) - set(dict_names)
            if missing_names:
                if fillna:
                    # fill columns missing from the dict with NaNs
                    dict_names = dt_names
                    for missing_name in missing_names:
                        arr_dict[missing_name] = np.nan
                else:
                    raise KeyError("Dictionary keys and dtype fields do not match!")
            names = list(dtype.names)

        arr_dict = cls._expand_scalars(arr_dict)
        data = [arr_dict[key] for key in names]
        return cls(np.rec.fromarrays(data, names=names, dtype=dtype), **kwargs)

    @classmethod
    def from_columns(cls, column_list, dtype=None, colnames=None, **kwargs):
        """Instantiate from an array-like with shape (n_columns, n_rows)."""
        if dtype is None or not is_structured(dtype):
            # infer structured dtype from array data + column names
            if colnames is None:
                raise ValueError(
                    "Need to either specify column names or a "
                    "structured dtype when passing unstructured arrays!"
                )
            dtype = inflate_dtype(column_list, colnames)
            colnames = dtype.names
        if len(column_list) != len(dtype.names):
            raise ValueError("Number of columns mismatch between data and dtype!")
        data = {k: column_list[i] for i, k in enumerate(dtype.names)}
        return cls(data, dtype=dtype, colnames=colnames, **kwargs)

    @classmethod
    def from_rows(cls, row_list, dtype=None, colnames=None, **kwargs):
        """Instantiate from an array-like with shape (n_rows, n_columns)."""
        if dtype is None or not is_structured(dtype):
            # infer structured dtype from array data + column names
            if colnames is None:
                raise ValueError(
                    "Need to either specify column names or a "
                    "structured dtype when passing unstructured arrays!"
                )
            dtype = inflate_dtype(row_list, colnames)
        # this *should* have been checked above, but do this
        # just to be sure in case I screwed up the logic above;
        # users will never see this, this should only show in tests
        assert is_structured(dtype)
        data = np.asanyarray(row_list).view(dtype)
        # drop useless 2nd dim
        data = data.reshape((data.shape[0],))
        return cls(data, **kwargs)

    @property
    def templates_avail(self):
        """Sorted names of the available table templates."""
        return sorted(list(TEMPLATES.keys()))

    @classmethod
    def from_template(cls, data, template):
        """Create a table from a predefined datatype.

        See the ``templates_avail`` property for available names.

        Parameters
        ----------
        data
            Data in a format that the ``__init__`` understands.
        template: str or dict
            Name of the dtype template to use from ``kp.dataclasses_templates``
            or a ``dict`` containing the required attributes (see the other
            templates for reference).
        """
        name = DEFAULT_NAME
        if isinstance(template, str):
            name = template
            table_info = TEMPLATES[name]
        else:
            table_info = template
            if "name" in table_info:
                name = table_info["name"]
        dt = table_info["dtype"]
        loc = table_info["h5loc"]
        split = table_info["split_h5"]
        h5singleton = table_info["h5singleton"]

        return cls(
            data,
            h5loc=loc,
            dtype=dt,
            split_h5=split,
            name=name,
            h5singleton=h5singleton,
        )

    @staticmethod
    def _check_column_length(values, n):
        # Each column to be appended must match the table length exactly.
        values = np.atleast_2d(values)
        for v in values:
            if len(v) != n:
                raise ValueError(
                    "Trying to append more than one column, but "
                    "some arrays mismatch in length!"
                )

    def append_columns(self, colnames, values, **kwargs):
        """Append new columns to the table.

        When appending a single column, ``values`` can be a scalar or an
        array of either length 1 or the same length as this array (the one
        it's appended to). In case of multiple columns, values must have
        the shape ``list(arrays)``, and the dimension of each array
        has to match the length of this array.

        See the docs for ``numpy.lib.recfunctions.append_fields`` for an
        explanation of the remaining options.
        """
        n = len(self)
        if np.isscalar(values):
            values = np.full(n, values)

        values = np.atleast_1d(values)
        if not isinstance(colnames, str) and len(colnames) > 1:
            values = np.atleast_2d(values)
            self._check_column_length(values, n)

        if values.ndim == 1:
            if len(values) > n:
                raise ValueError("New Column is longer than existing table!")
            elif len(values) > 1 and len(values) < n:
                raise ValueError(
                    "New Column is shorter than existing table, "
                    "but not just one element!"
                )
            elif len(values) == 1:
                # broadcast a single element to the full table length
                values = np.full(n, values[0])
        new_arr = rfn.append_fields(
            self, colnames, values, usemask=False, asrecarray=True, **kwargs
        )
        return self.__class__(
            new_arr,
            h5loc=self.h5loc,
            split_h5=self.split_h5,
            name=self.name,
            h5singleton=self.h5singleton,
        )

    def drop_columns(self, colnames, **kwargs):
        """Drop columns from the table.

        See the docs for ``numpy.lib.recfunctions.drop_fields`` for an
        explanation of the remaining options.
        """
        new_arr = rfn.drop_fields(
            self, colnames, usemask=False, asrecarray=True, **kwargs
        )
        return self.__class__(
            new_arr,
            h5loc=self.h5loc,
            split_h5=self.split_h5,
            name=self.name,
            h5singleton=self.h5singleton,
        )

    def sorted(self, by, **kwargs):
        """Sort array by a column.

        Parameters
        ==========
        by: str
            Name of the columns to sort by (e.g. 'time').
        """
        sort_idc = np.argsort(self[by], **kwargs)
        # propagate ALL metadata, including h5singleton (was previously
        # dropped here, unlike in append_columns/drop_columns)
        return self.__class__(
            self[sort_idc],
            h5loc=self.h5loc,
            split_h5=self.split_h5,
            name=self.name,
            h5singleton=self.h5singleton,
        )

    def to_dataframe(self):
        """Return the table as a pandas DataFrame."""
        from pandas import DataFrame

        return DataFrame(self)

    @classmethod
    def from_dataframe(cls, df, **kwargs):
        """Instantiate from a pandas DataFrame."""
        rec = df.to_records(index=False)
        return cls(rec, **kwargs)

    @classmethod
    def merge(cls, tables, fillna=False):
        """Merge a list of tables"""
        # union of all (name, dtype) column descriptors across the tables
        cols = set(itertools.chain(*[table.dtype.descr for table in tables]))

        tables_to_merge = []
        for table in tables:
            missing_cols = cols - set(table.dtype.descr)

            if missing_cols:
                if fillna:
                    n = len(table)
                    n_cols = len(missing_cols)
                    col_names = []
                    for col_name, col_dtype in missing_cols:
                        if "f" not in col_dtype:
                            raise ValueError(
                                "Cannot create NaNs for non-float"
                                " type column '{}'".format(col_name)
                            )
                        col_names.append(col_name)

                    table = table.append_columns(
                        col_names, np.full((n_cols, n), np.nan)
                    )
                else:
                    # keyword is ``fillna`` (the old message said fill_na)
                    raise ValueError(
                        "Table columns do not match. Use fillna=True"
                        " if you want to append missing values with NaNs"
                    )
            tables_to_merge.append(table)

        first_table = tables_to_merge[0]

        merged_table = sum(tables_to_merge[1:], first_table)

        merged_table.h5loc = first_table.h5loc
        merged_table.h5singleton = first_table.h5singleton
        merged_table.split_h5 = first_table.split_h5
        merged_table.name = first_table.name

        return merged_table

    def __add__(self, other):
        cols1 = set(self.dtype.descr)
        cols2 = set(other.dtype.descr)
        if len(cols1 ^ cols2) != 0:
            cols1 = set(self.dtype.names)
            cols2 = set(other.dtype.names)
            if len(cols1 ^ cols2) == 0:
                # same field names but differing field types
                raise NotImplementedError
            else:
                raise TypeError("Table columns do not match")
        col_order = list(self.dtype.names)
        ret = self.copy()
        len_self = len(self)
        len_other = len(other)
        final_length = len_self + len_other
        # grow in place, then copy the other table's rows (reordered to
        # match our column order) into the new tail
        ret.resize(final_length, refcheck=False)
        ret[len_self:] = other[col_order]
        return Table(
            ret,
            h5loc=self.h5loc,
            h5singleton=self.h5singleton,
            split_h5=self.split_h5,
            name=self.name,
        )

    def __str__(self):
        name = self.name
        spl = "split" if self.split_h5 else "no split"
        s = "{} {}\n".format(name, type(self))
        s += "HDF5 location: {} ({})\n".format(self.h5loc, spl)
        # one "colname (dtype: ...) = values" line per column
        s += "\n".join(
            map(
                lambda d: "{1} (dtype: {2}) = {0}".format(self[d[0]], *d),
                self.dtype.descr,
            )
        )
        return s

    def __repr__(self):
        return "{} {} (rows: {})".format(self.name, type(self), self.size)

    def __contains__(self, elem):
        return elem in self.dtype.names

    @property
    def pos(self):
        """Positions as an (n, 3) array from the pos_x/pos_y/pos_z columns."""
        return np.array([self.pos_x, self.pos_y, self.pos_z]).T

    @pos.setter
    def pos(self, arr):
        arr = np.atleast_2d(arr)
        assert arr.shape[1] == 3
        assert len(arr) == len(self)
        self.pos_x = arr[:, 0]
        self.pos_y = arr[:, 1]
        self.pos_z = arr[:, 2]

    @property
    def dir(self):
        """Directions as an (n, 3) array from the dir_x/dir_y/dir_z columns."""
        return np.array([self.dir_x, self.dir_y, self.dir_z]).T

    @dir.setter
    def dir(self, arr):
        arr = np.atleast_2d(arr)
        assert arr.shape[1] == 3
        assert len(arr) == len(self)
        self.dir_x = arr[:, 0]
        self.dir_y = arr[:, 1]
        self.dir_z = arr[:, 2]

    @property
    def phi(self):
        """Azimuthal angle computed from dir_x/dir_y (see km3pipe.math)."""
        from km3pipe.math import phi_separg

        return phi_separg(self.dir_x, self.dir_y)

    @property
    def theta(self):
        """Polar angle computed from dir_z (see km3pipe.math)."""
        from km3pipe.math import theta_separg

        return theta_separg(self.dir_z)

    @property
    def zenith(self):
        """Zenith of the source direction derived from phi/theta."""
        from km3pipe.math import neutrino_to_source_direction

        _, zen = neutrino_to_source_direction(self.phi, self.theta)
        return zen

    @property
    def azimuth(self):
        """Azimuth of the source direction derived from phi/theta."""
        from km3pipe.math import neutrino_to_source_direction

        azi, _ = neutrino_to_source_direction(self.phi, self.theta)
        return azi

    @property
    def triggered_rows(self):
        """Rows whose 'triggered' column is non-zero."""
        if not hasattr(self, "triggered"):
            raise KeyError("Table has no 'triggered' column!")
        return self[self.triggered.astype(bool)]
class NDArray(np.ndarray):
    """Array with HDF5 metadata."""

    def __new__(cls, array, dtype=None, order=None, **kwargs):
        # View the input as an NDArray and attach HDF5 bookkeeping attributes.
        obj = np.asarray(array, dtype=dtype, order=order).view(cls)
        obj.h5loc = kwargs.get("h5loc", "/misc")
        obj.title = kwargs.get("title", "Unnamed NDArray")
        obj.group_id = kwargs.get("group_id", None)
        return obj

    def __array_finalize__(self, obj):
        # obj is None for explicit construction; otherwise copy metadata
        # from the source of the view/slice (defaulting to None).
        if obj is None:
            return
        for attr in ("h5loc", "title", "group_id"):
            setattr(self, attr, getattr(obj, attr, None))
class Vec3(object):
    """Lightweight 3D vector; arithmetic delegates to numpy broadcasting."""

    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z

    def __array__(self, dtype=None):
        # numpy interop: expose the components as a plain ndarray
        components = [self.x, self.y, self.z]
        if dtype is None:
            return np.array(components)
        return np.array(components, dtype=dtype)

    def __add__(self, other):
        return Vec3(*np.add(self, other))

    def __radd__(self, other):
        return Vec3(*np.add(other, self))

    def __sub__(self, other):
        return Vec3(*np.subtract(self, other))

    def __rsub__(self, other):
        return Vec3(*np.subtract(other, self))

    def __mul__(self, other):
        return Vec3(*np.multiply(self, other))

    def __rmul__(self, other):
        return Vec3(*np.multiply(other, self))

    def __truediv__(self, other):
        return Vec3(*np.divide(self, other))

    def __div__(self, other):
        # Python 2 legacy alias for the division operator
        return self.__truediv__(other)

    def __getitem__(self, index):
        return self.__array__()[index]