Source code for lena.structures.root_graphs
import array

import lena.context
import lena.core
import lena.flow
def _list_to_array(coords, type_code):
    """Pack the values of *coords* into an :class:`array.array`.

    *type_code* is the array type code (for example 'd' or 'f').
    The values are always consumed one by one through an iterator,
    whatever the concrete type of *coords* is.
    """
    values = iter(coords)
    return array.array(type_code, values)
class root_graph_errors():
    """2-dimensional ROOT graph with errors.

    This is an adapter for
    `TGraphErrors <https://root.cern.ch/doc/master/classTGraphErrors.html>`_
    and contains that graph as a field *root_graph*.

    Iteration yields one tuple per point: x, y,
    then the x error and the y error if they are present.
    """

    def __init__(self, graph, type_code='d'):
        """*graph* is a Lena :class:`.graph`.

        *type_code* is the basic numeric type of array values
        (by default double). 'f' means floating values.
        See Python module
        `array <https://docs.python.org/3/library/array.html>`_
        for more options.

        :exc:`.LenaValueError` is raised if *graph* is not
        2-dimensional or if one of its error fields has a suffix.

        .. versionadded:: 0.5
        """
        # ROOT is imported here and not at the module level,
        # so it is needed only when a graph is actually converted.
        import ROOT

        if graph.dim != 2:
            raise lena.core.LenaValueError(
                "graph dimension must be 2"
            )

        x_coord = graph.field_names[0]
        y_coord = graph.field_names[1]

        # None means "no error field for this coordinate".
        # An explicit sentinel is used instead of truthiness,
        # because coordinate containers may be empty
        # or may not support bool() (array-like objects).
        x_error = None
        y_error = None
        # Index 0 is the x coordinate itself,
        # so 0 can also mean "no error field".
        error_x_ind = 0
        error_y_ind = 0

        # More than two error fields are not possible,
        # because we forbid suffixes.
        for err in graph._parsed_error_names:
            if err[2]:
                raise lena.core.LenaValueError(
                    "error suffixes are not allowed"
                )
            error_ind = err[3]
            # errors for unknown coordinates
            # are forbidden in graph itself,
            # therefore there is no final else.
            if err[1] == x_coord:
                x_error = graph.coords[error_ind]
                error_x_ind = error_ind
            elif err[1] == y_coord:
                y_error = graph.coords[error_ind]
                error_y_ind = error_ind

        self._error_x_ind = error_x_ind
        self._error_y_ind = error_y_ind

        n_points = len(graph.coords[0])
        xs = _list_to_array(graph.coords[0], type_code)
        ys = _list_to_array(graph.coords[1], type_code)
        # a null pointer is passed to ROOT for missing errors
        if x_error is None:
            exs = ROOT.nullptr
        else:
            exs = _list_to_array(x_error, type_code)
        if y_error is None:
            eys = ROOT.nullptr
        else:
            eys = _list_to_array(y_error, type_code)

        self.root_graph = ROOT.TGraphErrors(n_points, xs, ys, exs, eys)

    def _arrays(self):
        """Return a list of ROOT buffers: x, y and existing errors.

        The list is built on every call and is not stored
        as a class field, because all these values are pointers,
        so they can't be pickled.
        """
        # NOTE(review): x errors always go before y errors here,
        # while _update_context reports the indices of error fields
        # in the original graph. Confirm that consumers do not
        # rely on these two orders being the same.
        rg = self.root_graph
        arrays = [rg.GetX(), rg.GetY()]
        if self._error_x_ind:
            arrays.append(rg.GetEX())
        if self._error_y_ind:
            arrays.append(rg.GetEY())
        return arrays

    def __eq__(self, other):
        """Graphs are equal if they have the same error indices
        and the same points.

        Any object that is not a :class:`root_graph_errors`
        is unequal to this graph.
        """
        if not isinstance(other, root_graph_errors):
            return False
        # looks they can't be compared directly
        # (self.root_graph == other.root_graph),
        # so compare error indices first
        if (self._error_x_ind != other._error_x_ind
                or self._error_y_ind != other._error_y_ind):
            return False
        # pointwise comparison
        return list(self) == list(other)

    def __iter__(self):
        """Yield points of the graph as tuples of values."""
        # buffers are the same for all points:
        # get them once, not for every point
        arrays = self._arrays()
        for ind in range(self.root_graph.GetN()):
            yield tuple(arr[ind] for arr in arrays)

    def __len__(self):
        """Return the number of points in the graph."""
        return self.root_graph.GetN()

    def _update_context(self, context):
        """Write indices of existing error fields to *context*.

        Keys "error.x.index" and "error.y.index" are updated
        only for those errors that are present in the graph.
        """
        error_x_ind = self._error_x_ind
        error_y_ind = self._error_y_ind
        if error_x_ind:
            lena.context.update_recursively(
                context, "error.x.index", error_x_ind
            )
        if error_y_ind:
            lena.context.update_recursively(
                context, "error.y.index", error_y_ind
            )
class ROOTGraphErrors():
    """Element to convert graphs to :class:`.root_graph_errors`."""

    def __call__(self, value):
        """Replace the data part of the *value*
        (it must be a :class:`.graph`)
        with a :class:`.root_graph_errors` built from it.
        The context is returned together with the new data.

        .. versionadded:: 0.5
        """
        data, context = lena.flow.get_data_context(value)
        converted = root_graph_errors(data)
        return (converted, context)