Coverage for src/km3modules/tests/test_mc.py: 100%

67 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-08 03:14 +0000

1# Filename: test_mc.py 

2# pylint: disable=C0111,R0904,C0103 

3# vim:set ts=4 sts=4 sw=4 et: 

4""" 

5Tests for MC Modules. 

6""" 

7 

8import numpy as np 

9from numpy.testing import assert_array_equal, assert_allclose 

10import pytest 

11 

12from km3pipe import Table, Blob, Pipeline, Module 

13from km3pipe.testing import TestCase 

14 

15from km3modules.mc import MCTimeCorrector, GlobalRandomState 

16 

17__author__ = "Moritz Lotze, Michael Moser" 

18__copyright__ = "Copyright 2018, Tamas Gal and the KM3NeT collaboration." 

19__license__ = "MIT" 

20__maintainer__ = "Tamas Gal, Moritz Lotze" 

21__email__ = "tgal@km3net.de" 

22__status__ = "Development" 

23 

24 

class TestGlobalRandomState(TestCase):
    """Tests for the ``GlobalRandomState`` module.

    ``GlobalRandomState`` determines what NumPy's global ``np.random``
    functions return afterwards (the tests below pin the values drawn after
    it runs).  Because that state is process-global, each test saves it in
    ``setUp`` and restores it in ``tearDown``.  This keeps a seeded generator
    from leaking into unrelated tests.
    """

    def setUp(self):
        super().setUp()
        self._np_random_state = np.random.get_state()

    def tearDown(self):
        np.random.set_state(self._np_random_state)
        super().tearDown()

    def _assert_pipeline_draws(self, expected, **kwargs):
        """Run ``GlobalRandomState(**kwargs)`` then an observer in a pipeline.

        The observer draws one ``np.random.rand()`` value per drained cycle.
        The draws are checked *after* ``drain`` returns, so the test also
        fails if the observer ran fewer (or more) times than ``expected``
        has entries.
        """
        draws = []

        class Observer(Module):
            def process(self, blob):
                draws.append(np.random.rand())
                return blob

        pipe = Pipeline()
        pipe.attach(GlobalRandomState, **kwargs)
        pipe.attach(Observer)
        pipe.drain(len(expected))

        self.assertEqual(len(expected), len(draws))
        for wanted, drawn in zip(expected, draws):
            self.assertAlmostEqual(wanted, drawn)

    @staticmethod
    def _draw_lotto_numbers():
        """Shuffle 1..49 with the global generator; return the first six, sorted."""
        numbers = np.arange(1, 50)
        np.random.shuffle(numbers)
        return sorted(numbers[:6])

    def test_default_random_state(self):
        self._assert_pipeline_draws([0.3745401188, 0.950714306, 0.7319939418])

    def test_custom_random_state(self):
        self._assert_pipeline_draws(
            [0.221993171, 0.870732306, 0.206719155], seed=5
        )

    def test_without_pipeline_and_default_state(self):
        GlobalRandomState()
        self.assertListEqual(
            [14, 18, 28, 45, 46, 48], self._draw_lotto_numbers()
        )

    def test_without_pipeline_with_custom_seed(self):
        GlobalRandomState(seed=23)
        self.assertListEqual(
            [14, 15, 18, 19, 33, 44], self._draw_lotto_numbers()
        )

75 

76 

class TestMCConvert(TestCase):
    """Tests for ``MCTimeCorrector`` applied to a blob that holds event info,
    MC tracks and MC hits under the keys passed to the corrector."""

    def setUp(self):
        # Tables are built from scalar values (presumably one row each —
        # confirm against km3pipe.Table).
        self.event_info = Table(
            {
                "timestamp": 1,
                "nanoseconds": 700000000,
                "mc_time": 1.74999978e9,
            }
        )
        self.mc_tracks = Table({"time": 1})
        self.mc_hits = Table({"time": 30.79})
        self.blob = Blob(
            {
                "event_info": self.event_info,
                "mc_hits": self.mc_hits,
                "mc_tracks": self.mc_tracks,
            }
        )

    def test_process(self):
        corr = MCTimeCorrector(
            mc_hits_key="mc_hits",
            mc_tracks_key="mc_tracks",
            event_info_key="event_info",
        )
        newblob = corr.process(self.blob)
        self.assertIsNotNone(newblob["mc_hits"])
        self.assertIsNotNone(newblob["mc_tracks"])
        # Both expected times equal the input time plus 49_999_780.  That is
        # mc_time - (timestamp * 1e9 + nanoseconds) for the event above.
        # NOTE(review): this formula is read off the fixture numbers; confirm
        # it matches MCTimeCorrector's implementation.
        # rtol/atol reproduce np.allclose's defaults, so the tolerance is
        # unchanged, but assert_allclose reports the mismatch on failure.
        assert_allclose(
            newblob["mc_hits"].time, 49999810.79, rtol=1e-5, atol=1e-8
        )
        assert_allclose(
            newblob["mc_tracks"].time, 49999781, rtol=1e-5, atol=1e-8
        )