# Source code for sharrow.relationships

import ast
import logging

import networkx as nx
import numpy as np
import pandas as pd
import xarray as xr

from .dataset import Dataset, construct

try:
    from ast import unparse
except ImportError:
    from astunparse import unparse as _unparse

    unparse = lambda *args: _unparse(*args).strip("\n")


# Module-level logger registered under the package-wide "sharrow" name.
logger = logging.getLogger("sharrow")

# Set of short names (library aliases and math/helper function names).
# NOTE(review): nothing in this part of the file reads this set; presumably
# it lists names that expressions may use without being looked up in the
# data tree -- confirm against the callers.
well_known_names = {
    "nb",
    "np",
    "pd",
    "xr",
    "pa",
    "log",
    "exp",
    "log1p",
    "expm1",
    "max",
    "min",
    "piece",
    "hard_sigmoid",
    "transpose_leading",
    "clip",
}


def _require_string(x):
    if not isinstance(x, str):
        raise ValueError("must be string")
    return x


def _iat(source, *, _names=None, _load=False, _index_name=None, **idxs):
    loaders = {}
    if _index_name is None:
        _index_name = "index"
    for k, v in idxs.items():
        if v.ndim == 1:
            loaders[k] = xr.DataArray(v, dims=[_index_name])
        else:
            loaders[k] = xr.DataArray(
                v, dims=[f"{_index_name}{n}" for n in range(v.ndim)]
            )
    if _names:
        ds = source[_names]
    else:
        ds = source
    if _load:
        ds = ds._load()
    return ds.isel(**loaders)


def _at(source, *, _names=None, _load=False, _index_name=None, **idxs):
    loaders = {}
    if _index_name is None:
        _index_name = "index"
    for k, v in idxs.items():
        if v.ndim == 1:
            loaders[k] = xr.DataArray(v, dims=[_index_name])
        else:
            loaders[k] = xr.DataArray(
                v, dims=[f"{_index_name}{n}" for n in range(v.ndim)]
            )
    if _names:
        ds = source[_names]
    else:
        ds = source
    if _load:
        ds = ds._load()
    return ds.sel(**loaders)


def gather(source, indexes):
    """
    Extract values by label on the named dimensions of `source`.

    Parameters
    ----------
    source : xarray.DataArray or xarray.Dataset
        The source of the values to extract.
    indexes : Mapping[str, array-like]
        The keys of `indexes` (if given as a dataframe, the column names)
        should match the named dimensions of `source`.  Each value gives,
        per output row, the label to look up on that dimension.

    Returns
    -------
    The label-based selection from `source`, one entry per row of
    `indexes`, after `reset_coords(drop=True)` has been applied.
    """
    return _at(source, **indexes).reset_coords(drop=True)


def igather(source, positions):
    """
    Extract values by position on the named dimensions of `source`.

    Parameters
    ----------
    source : xarray.DataArray or xarray.Dataset
        The source of the values to extract.
    positions : pd.DataFrame or Mapping[str, array-like]
        The columns (or keys) of `positions` should match the named
        dimensions of `source`.  Each value gives, per output row, the
        integer position to look up on that dimension.

    Returns
    -------
    The position-based selection from `source`, one entry per row of
    `positions`, after `reset_coords(drop=True)` has been applied.
    """
    return _iat(source, **positions).reset_coords(drop=True)


def xgather(source, positions, indexes):
    """
    Extract values using positions on some dimensions and labels on others.

    Parameters
    ----------
    source : xarray.DataArray or xarray.Dataset
        The source of the values to extract.
    positions : Mapping[str, array-like]
        Position-based indexers, passed to `igather`.
    indexes : Mapping[str, array-like]
        Label-based indexers, passed to `gather`.

    Returns
    -------
    The result of `igather` alone when `indexes` is empty, of `gather`
    alone when `positions` is empty, and otherwise of `gather` applied to
    the output of `igather`.
    """
    if len(indexes) == 0:
        return igather(source, positions)
    if len(positions) == 0:
        return gather(source, indexes)
    picked_by_position = igather(source, positions)
    return gather(picked_by_position, indexes)


class Relationship:
    """
    Defines a linkage between datasets in a `DataTree`.

    Parameters
    ----------
    parent_data : str
        Name of the parent dataset.
    parent_name : str
        Variable in the parent dataset that references the child dimension.
    child_data : str
        Name of the child dataset.
    child_name : str
        Dimension in the child dataset that is used by this relationship.
    indexing : {'label', 'position'}, default 'label'
        How the target dimension is used.
    analog : str, optional
        Original variable that defined a label-based relationship before
        digitization.

    Raises
    ------
    ValueError
        If any of the four names is not a `str`, or if `indexing` is
        neither 'label' nor 'position'.
    """

    def __init__(
        self,
        parent_data,
        parent_name,
        child_data,
        child_name,
        indexing="label",
        analog=None,
    ):
        self.parent_data = _require_string(parent_data)
        """str: Name of the parent dataset."""

        self.parent_name = _require_string(parent_name)
        """str: Variable in the parent dataset that references the child dimension."""

        self.child_data = _require_string(child_data)
        """str: Name of the child dataset."""

        self.child_name = _require_string(child_name)
        """str: Dimension in the child dataset that is used by this relationship."""

        if indexing not in {"label", "position"}:
            raise ValueError("indexing must be by label or position")
        self.indexing = indexing
        """str: How the target dimension is used, either by 'label' or 'position'."""

        self.analog = analog
        """str: Original variable that defined label-based relationship before digitization."""

    def __eq__(self, other):
        # Equality is defined by the repr, so `analog` does not take part.
        if isinstance(other, self.__class__):
            return repr(self) == repr(other)
        # Let Python try the reflected comparison / fall back to identity
        # instead of silently returning None.
        return NotImplemented

    # Instances are mutable and compare by value, so they are not hashable
    # (this is what Python does implicitly when only __eq__ is defined).
    __hash__ = None

    def __repr__(self):
        return f"<Relationship by {self.indexing}: {self.parent_data}[{self.parent_name!r}] -> {self.child_data}[{self.child_name!r}]>"

    def attrs(self):
        """
        Attributes of this relationship other than the two dataset names.

        Returns
        -------
        dict
            With keys `parent_name`, `child_name` and `indexing`; `analog`
            is not included.
        """
        return dict(
            parent_name=self.parent_name,
            child_name=self.child_name,
            indexing=self.indexing,
        )

    @classmethod
    def from_string(cls, s):
        """
        Construct a `Relationship` from a string.

        Parameters
        ----------
        s : str
            The relationship definition.
            To create a label-based relationship, the string should look like
            "ParentNode.variable_name @ ChildNode.dimension_name".  To create
            a position-based relationship, give
            "ParentNode.variable_name -> ChildNode.dimension_name".

        Returns
        -------
        Relationship

        Raises
        ------
        ValueError
            If `s` contains neither "->" nor "@", or if either side is not
            of the form "node.name".
        """
        if "->" in s:
            parent, child = s.split("->", 1)
            i = "position"
        elif "@" in s:
            parent, child = s.split("@", 1)
            i = "label"
        else:
            raise ValueError(
                f"relationship {s!r} must contain '->' (position) or '@' (label)"
            )
        try:
            p1, p2 = parent.split(".", 1)
            c1, c2 = child.split(".", 1)
        except ValueError:
            raise ValueError(
                f"relationship {s!r} must name both sides as 'node.name'"
            ) from None
        return cls(
            parent_data=p1.strip(),
            parent_name=p2.strip(),
            child_data=c1.strip(),
            child_name=c2.strip(),
            indexing=i,
        )


class DataTree:
    """
    A tree representing linked datasets, from which data can flow.

    Parameters
    ----------
    graph : networkx.MultiDiGraph
    root_node_name : str
        The name of the node at the root of the tree.
    extra_funcs : Tuple[Callable]
        Additional functions that can be called by Flow objects created
        using this DataTree.  These functions should have defined `__name__`
        attributes, so they can be called in expressions.
    extra_vars : Mapping[str,Any], optional
        Additional named constants that can be referenced by expressions in
        Flow objects created using this DataTree.
    cache_dir : Path-like, optional
        The default directory where Flow objects are created.
    relationships : Iterable[str or Relationship]
        The relationship definitions used to define this tree.  All dataset
        nodes named in these relationships should also be included as
        keyword arguments for this constructor.
    force_digitization : bool, default False
        Whether to automatically digitize all relationships (converting them
        from label-based to position-based).  Digitization is required to
        evaluate Flows, but doing so automatically on construction may be
        inefficient.
    dim_order : Tuple[str], optional
        The order of dimensions to use in Flow outputs.  Generally only needed
        if there are multiple dimensions in the root dataset.
    """

    DatasetType = Dataset

    def __init__(
        self,
        graph=None,
        root_node_name=None,
        extra_funcs=(),
        extra_vars=None,
        cache_dir=None,
        relationships=(),
        force_digitization=False,
        dim_order=None,
        **kwargs,
    ):
        """
        Initialize the tree; parameters are described on the class.

        Datasets are given as keyword arguments.  The dataset named by
        `root_node_name`, if present, is added before the others.
        """
        if isinstance(graph, Dataset):
            raise ValueError("datasets must be given as keyword arguments")

        # raw init
        self._graph = nx.MultiDiGraph() if graph is None else graph
        self._root_node_name = None
        self.force_digitization = force_digitization
        self.dim_order = dim_order
        self.dim_exclude = set()

        # defined init
        root_given = root_node_name is not None
        if root_given and root_node_name in kwargs:
            # the root node must exist before it can be assigned below
            self.add_dataset(root_node_name, kwargs[root_node_name])
        self.root_node_name = root_node_name
        self.extra_funcs = extra_funcs
        self.extra_vars = extra_vars or {}
        self.cache_dir = cache_dir
        self._cached_indexes = {}
        for node_name, node_data in kwargs.items():
            if not (root_given and node_name == root_node_name):
                self.add_dataset(node_name, node_data)
        for rel in relationships:
            self.add_relationship(rel)
        if force_digitization:
            self.digitize_relationships(inplace=True)

    @property
    def shape(self):
        """Tuple[int]: base shape of arrays that will be loaded when using this DataTree."""
        if self.dim_order:
            dim_order = self.dim_order
        else:
            from .flows import presorted

            dim_order = presorted(self.root_dataset.dims, self.dim_order)
        return tuple(
            self.root_dataset.dims[i] for i in dim_order if i not in self.dim_exclude
        )

    def __shallow_copy_extras(self):
        return dict(
            extra_funcs=self.extra_funcs,
            extra_vars=self.extra_vars,
            cache_dir=self.cache_dir,
            force_digitization=self.force_digitization,
        )

    def __repr__(self):
        s = f"<{self.__module__}.{self.__class__.__name__}>"
        if len(self._graph.nodes):
            s += "\n datasets:"
            if self.root_node_name:
                s += f"\n - {self.root_node_name}"
            for k in self._graph.nodes:
                if k == self.root_node_name:
                    continue
                s += f"\n - {k}"
        else:
            s += "\n datasets: none"
        if len(self._graph.edges):
            s += "\n relationships:"
            for e in self._graph.edges:
                s += f"\n - {self._get_relationship(e)!r}".replace(
                    "<Relationship ", ""
                ).rstrip(">")
        else:
            s += "\n relationships: none"
        return s

    def _hash_features(self):
        h = []
        if len(self._graph.nodes):
            if self.root_node_name:
                h.append(f"dataset:{self.root_node_name}")
            for k in self._graph.nodes:
                if k == self.root_node_name:
                    continue
                h.append(f"dataset:{k}")
        else:
            h.append("datasets:none")
        if len(self._graph.edges):
            for e in self._graph.edges:
                r = f"relationship:{self._get_relationship(e)!r}".replace(
                    "<Relationship ", ""
                ).rstrip(">")
                h.append(r)
        else:
            h.append("relationships:none")
        h.append(f"dim_order:{self.dim_order}")
        return h

    @property
    def root_node_name(self):
        """str: The root node for this data tree, which is only ever a parent."""
        if self._root_node_name is None:
            for nodename in self._graph.nodes:
                if self._graph.in_degree(nodename) == 0:
                    self._root_node_name = nodename
                    break
        return self._root_node_name

    @root_node_name.setter
    def root_node_name(self, name):
        if name is None:
            self._root_node_name = None
            return
        if not isinstance(name, str):
            raise TypeError(f"root_node_name must be str not {type(name)}")
        if name not in self._graph.nodes:
            raise KeyError(name)
        self._root_node_name = name

[docs] def add_relationship(self, *args, **kwargs): """ Add a relationship to this DataTree. The new relationship will point from a variable in one dataset to a dimension of another dataset in this tree. Both the parent and the child datasets should already have been added. Parameters ---------- *args, **kwargs All arguments are passed through to the `Relationship` contructor, unless only a single `str` argument is provided, in which case the `Relationship.from_string` class constructor is used. """ if len(args) == 1 and isinstance(args[0], Relationship): r = args[0] elif len(args) == 1 and isinstance(args[0], str): s = args[0] if "->" in s: parent, child = s.split("->", 1) i = "position" elif "@": parent, child = s.split("@", 1) i = "label" p1, p2 = parent.split(".", 1) c1, c2 = child.split(".", 1) p1 = p1.strip() p2 = p2.strip() c1 = c1.strip() c2 = c2.strip() r = Relationship( parent_data=p1, parent_name=p2, child_data=c1, child_name=c2, indexing=i, ) else: r = Relationship(*args, **kwargs) # check for existing relationships, don't duplicate for e in self._graph.edges: r2 = self._get_relationship(e) if r == r2: return # confirm correct pointer r.parent_data = self.finditem(r.parent_name, maybe_in=r.parent_data) self._graph.add_edge(r.parent_data, r.child_data, **r.attrs()) if self.force_digitization: self.digitize_relationships(inplace=True)
def get_relationship(self, parent, child): attrs = self._graph.edges[parent, child] return Relationship(parent_data=parent, child_data=child, **attrs)
[docs] def add_dataset(self, name, dataset, relationships=(), as_root=False): """ Add a new Dataset node to this DataTree. Parameters ---------- name : str dataset : Dataset or pandas.DataFrame Will be coerced into a `Dataset` object if it is not already in that format, using a no-copy process if possible. relationships : Tuple[str or Relationship] Also add these relationships. as_root : bool, default False Set this new node as the root of the tree, displacing any existing root. """ self._graph.add_node(name, dataset=construct(dataset)) if self.root_node_name is None or as_root: self.root_node_name = name if isinstance(relationships, str): relationships = [relationships] for r in relationships: # TODO validate relationships before adding. self.add_relationship(r) if self.force_digitization: self.digitize_relationships(inplace=True)
def add_items(self, items): from collections.abc import Mapping, Sequence if isinstance(items, Sequence): for i in items: self.add_items(i) elif isinstance(items, Mapping): if "name" in items and "dataset" in items: self.add_dataset(items["name"], items["dataset"]) preload = True else: preload = False for k, v in items.items(): if preload and k in {"name", "dataset"}: continue if k == "relationships": for r in v: self.add_relationship(r) else: self.add_dataset(k, v) else: raise ValueError("add_items requires Sequence or Mapping") @property def root_node(self): return self._graph.nodes[self.root_node_name] @property def root_dataset(self): return self._graph.nodes[self.root_node_name]["dataset"] @root_dataset.setter def root_dataset(self, x): from .dataset import Dataset if not isinstance(x, Dataset): x = construct(x) self._graph.nodes[self.root_node_name]["dataset"] = x def _get_relationship(self, edge): return Relationship( parent_data=edge[0], child_data=edge[1], **self._graph.edges[edge] ) def __getitem__(self, item): if isinstance(item, (list, tuple)): from .dataset import Dataset return Dataset({k: self[k] for k in item}) try: return self._getitem(item) except KeyError: return self._getitem(item, include_blank_dims=True) def finditem(self, item, maybe_in=None): if maybe_in is not None and maybe_in in self._graph.nodes: dataset = self._graph.nodes[maybe_in].get("dataset", {}) if item in dataset: return maybe_in return self._getitem(item, just_node_name=True) def _getitem( self, item, include_blank_dims=False, only_dims=False, just_node_name=False ): if isinstance(item, (list, tuple)): from .dataset import Dataset return Dataset({k: self[k] for k in item}) if "." 
in item: item_in, item = item.split(".", 1) else: item_in = None queue = [self.root_node_name] examined = set() while len(queue): current_node = queue.pop(0) if current_node in examined: continue dataset = self._graph.nodes[current_node].get("dataset", {}) try: by_name = item in dataset and not only_dims except TypeError: by_name = False try: by_dims = not by_name and include_blank_dims and (item in dataset.dims) except TypeError: by_dims = False if (by_name or by_dims) and (item_in is None or item_in == current_node): if just_node_name: return current_node if current_node == self.root_node_name: if by_dims: return xr.DataArray( pd.RangeIndex(dataset.dims[item]), dims=item ) else: return dataset[item] else: _positions = {} _labels = {} if by_dims: if item in dataset.variables: coords = {item: dataset.variables[item]} else: coords = None result = xr.DataArray( pd.RangeIndex(dataset.dims[item]), dims=item, coords=coords, ) else: result = dataset[item] dims_in_result = set(result.dims) for path in nx.algorithms.simple_paths.all_simple_edge_paths( self._graph, self.root_node_name, current_node ): path_dim = self._graph.edges[path[-1]].get("child_name") if path_dim not in dims_in_result: continue # path_indexing = self._graph.edges[path[-1]].get('indexing') t1 = None # intermediate nodes on path for (e, e_next) in zip(path[:-1], path[1:]): r = self._get_relationship(e) r_next = self._get_relationship(e_next) if t1 is None: t1 = self._graph.nodes[r.parent_data].get("dataset") t2 = self._graph.nodes[r.child_data].get("dataset")[ [r_next.parent_name] ] if r.indexing == "label": t1 = t2.sel( {r.child_name: t1[r.parent_name].to_numpy()} ) else: # by position t1 = t2.isel( {r.child_name: t1[r.parent_name].to_numpy()} ) # final node in path e = path[-1] r = Relationship( parent_data=e[0], child_data=e[1], **self._graph.edges[e] ) if t1 is None: t1 = self._graph.nodes[r.parent_data].get("dataset") if r.indexing == "label": _labels[r.child_name] = t1[r.parent_name].to_numpy() 
else: # by position _idx = t1[r.parent_name].to_numpy() if not np.issubdtype(_idx.dtype, np.integer): _idx = _idx.astype(np.int64) _positions[r.child_name] = _idx y = xgather(result, _positions, _labels) if len(result.dims) == 1 and len(y.dims) == 1: y = y.rename({y.dims[0]: result.dims[0]}) elif len(dims_in_result) == len(y.dims): y = y.rename({_i: _j for _i, _j in zip(y.dims, result.dims)}) return y else: examined.add(current_node) for _, next_up in self._graph.out_edges(current_node): if next_up not in examined: queue.append(next_up) raise KeyError(item) def get_expr(self, expression, engine="sharrow"): """ Access or evaluate an expression. Parameters ---------- expression : str Returns ------- DataArray """ try: result = self[expression] except (KeyError, IndexError): if engine == "sharrow": result = ( self.setup_flow({expression: expression}) .load_dataarray() .isel(expressions=0) ) elif engine == "numexpr": from xarray import DataArray result = DataArray( pd.eval(expression, resolvers=[self], engine="numexpr"), ) return result @property def subspaces(self): """Mapping[str,Dataset] : Direct access to node Dataset objects by name.""" spaces = {} for k in self._graph.nodes: s = self._graph.nodes[k].get("dataset", None) if s is not None: spaces[k] = s return spaces def subspaces_iter(self): for k in self._graph.nodes: s = self._graph.nodes[k].get("dataset", None) if s is not None: yield (k, s) def namespace_names(self): namespace = set() for spacename, spacearrays in self.subspaces_iter(): for k, arr in spacearrays.coords.items(): namespace.add(f"__{spacename or 'base'}__{k}") for k, arr in spacearrays.items(): namespace.add(f"__{spacename or 'base'}__{k}") return namespace @property def dims(self): """ Mapping from dimension names to lengths across all dataset nodes. 
""" dims = {} for k, v in self.subspaces_iter(): for name, length in v.dims.items(): if name in dims: if dims[name] != length: raise ValueError( "inconsistent dimensions\n" + self.dims_detail() ) else: dims[name] = length return xr.core.utils.Frozen(dims) def dims_detail(self): """ Report on the names and sizes of dimensions in all Dataset nodes. Returns ------- str """ s = "" for k, v in self.subspaces_iter(): s += f"\n{k}:" for name, length in v.dims.items(): s += f"\n - {name}: {length}" return s[1:] def drop_dims(self, dims, inplace=False, ignore_missing_dims=True): """ Drop dimensions from root Dataset node. Parameters ---------- dims : str or Iterable[str] One or more named dimensions to drop. inplace : bool, default False Whether to drop dimensions in-place. ignore_missing_dims : bool, default True Simply ignore any dimensions that are not present. Returns ------- DataTree Returns self if dropping inplace, otherwise returns a copy with dimensions dropped. """ if isinstance(dims, str): dims = [dims] if inplace: obj = self else: obj = self.copy() if not ignore_missing_dims: obj.root_dataset = obj.root_dataset.drop_dims(dims) else: for d in dims: if d in obj.root_dataset.dims: obj.root_dataset = obj.root_dataset.drop_dims(d) obj.dim_order = tuple(x for x in self.dim_order if x not in dims) return obj def get_indexes( self, position_only=True, as_dict=True, replacements=None, use_cache=True, check_shapes=True, ): if use_cache and (position_only, as_dict) in self._cached_indexes: return self._cached_indexes[(position_only, as_dict)] if not position_only: raise NotImplementedError dims = [ d for d in self.dims if d[-1:] != "_" or (d[-1:] == "_" and d[:-1] not in self.dims) ] if replacements is not None: obj = self.replace_datasets(replacements) else: obj = self result = {} result_shape = None for k in sorted(dims): result_k = obj._getitem(k, include_blank_dims=True, only_dims=True) if result_shape is None: result_shape = result_k.shape if result_shape != 
result_k.shape: if check_shapes: raise ValueError( f"inconsistent index shapes {result_k.shape} v {result_shape} (probably an error on {k} or {sorted(dims)[0]})" ) result[k] = result_k if as_dict: result = {k: v.to_numpy() for k, v in result.items()} else: result = Dataset(result) if use_cache: self._cached_indexes[(position_only, as_dict)] = result return result
[docs] def replace_datasets(self, other=None, validate=True, redigitize=True, **kwargs): """ Replace one or more datasets in the nodes of this tree. Parameters ---------- other : Mapping[str,Dataset] A dictionary of replacement datasets. validate : bool, default True Raise an error when replacing downstream datasets that are referenced by position, unless the replacement is identically sized. If validation is deactivated, and an incompatible dataset is placed in this tree, flows that rely on that relationship will give erroneous results or crash with a segfault. redigitize : bool, default True Automatically re-digitize relationships that are label-based and were previously digitized. **kwargs : Mapping[str,Dataset] Alternative format to `other`. Returns ------- DataTree A new DataTree with data replacements completed. """ replacements = {} if other is not None: replacements.update(other) replacements.update(kwargs) graph = self._graph.copy() for k in replacements: if k not in graph.nodes: raise KeyError(k) x = construct(replacements[k]) if validate: if x.dims != graph.nodes[k]["dataset"].dims: # when replacement dimensions do not match, check for # any upstream nodes that reference this dataset by # position... which will potentially be problematic. for e in self._graph.edges: if e[1] == k: indexing = self._graph.edges[e].get("indexing") if indexing == "position": raise ValueError( f"dimensions mismatch on " f"positionally-referenced dataset {k}: " f"receiving {x.dims} " f"expected {graph.nodes[k]['dataset'].dims}" ) graph.nodes[k]["dataset"] = x result = type(self)(graph, self.root_node_name, **self.__shallow_copy_extras()) if redigitize: result.digitize_relationships(inplace=True) return result
def setup_flow( self, definition_spec, *, cache_dir=None, name=None, dtype="float32", boundscheck=False, error_model="numpy", nopython=True, fastmath=True, parallel=True, readme=None, flow_library=None, extra_hash_data=(), write_hash_audit=True, hashing_level=1, dim_exclude=None, ): """ Set up a new Flow for analysis using the structure of this DataTree. Parameters ---------- definition_spec : Dict[str,str] Gives the names and expressions that define the variables to create in this new `Flow`. cache_dir : Path-like, optional A location to write out generated python and numba code. If not provided, a unique temporary directory is created. name : str, optional The name of this Flow used for writing out cached files. If not provided, a unique name is generated. If `cache_dir` is given, be sure to avoid name conflicts with other flow's in the same directory. dtype : str, default "float32" The name of the numpy dtype that will be used for the output. boundscheck : bool, default False If True, boundscheck enables bounds checking for array indices, and out of bounds accesses will raise IndexError. The default is to not do bounds checking, which is faster but can produce garbage results or segfaults if there are problems, so try turning this on for debugging if you are getting unexplained errors or crashes. error_model : {'numpy', 'python'}, default 'numpy' The error_model option controls the divide-by-zero behavior. Setting it to ‘python’ causes divide-by-zero to raise exception like CPython. Setting it to ‘numpy’ causes divide-by-zero to set the result to +/-inf or nan. nopython : bool, default True Compile using numba's `nopython` mode. Provided for debugging only, as there's little point in turning this off for production code, as all the speed benefits of sharrow will be lost. fastmath : bool, default True If true, fastmath enables the use of "fast" floating point transforms, which can improve performance but can result in tiny distortions in results. 
See numba docs for details. parallel : bool, default True Enable or disable parallel computation for certain functions. readme : str, optional A string to inject as a comment at the top of the flow Python file. flow_library : Mapping[str,Flow], optional An in-memory cache of precompiled Flow objects. Using this can result in performance improvements when repeatedly using the same definitions. extra_hash_data : Tuple[Hashable], optional Additional data used for generating the flow hash. Useful to prevent conflicts when using a flow_library with multiple similar flows. write_hash_audit : bool, default True Writes a hash audit log into a comment in the flow Python file, for debugging purposes. hashing_level : int, default 1 Level of detail to write into flow hashes. Increase detail to avoid hash conflicts for similar flows. Level 2 adds information about names used in expressions and digital encodings to the flow hash, which prevents conflicts but requires more pre-computation to generate the hash. dim_exclude : Collection[str], optional Exclude these root dataset dimensions from this flow. Returns ------- Flow """ from .flows import Flow return Flow( self, definition_spec, cache_dir=cache_dir or self.cache_dir, name=name, dtype=dtype, boundscheck=boundscheck, nopython=nopython, fastmath=fastmath, parallel=parallel, readme=readme, flow_library=flow_library, extra_hash_data=extra_hash_data, hashing_level=hashing_level, error_model=error_model, write_hash_audit=write_hash_audit, dim_order=self.dim_order, dim_exclude=dim_exclude, ) def _spill(self, all_name_tokens=()): """ Write backup code for sharrow-lite. Parameters ---------- all_name_tokens Returns ------- """ cmds = [] return "\n".join(cmds) def get_named_array(self, mangled_name): if mangled_name[:2] != "__": raise KeyError(mangled_name) name1, name2 = mangled_name[2:].split("__", 1) dataset = self._graph.nodes[name1].get("dataset") return dataset[name2].to_numpy() _BY_OFFSET = "digitizedOffset"
[docs] def digitize_relationships(self, inplace=False, redigitize=True): """ Convert all label-based relationships into position-based. Parameters ---------- inplace : bool, default False redigitize : bool, default True Re-compute position-based relationships from labels, even if the relationship had previously been digitized. Returns ------- DataTree or None Only returns a copy if not digitizing in-place. """ if inplace: obj = self else: obj = self.copy() for e in obj._graph.edges: r = obj._get_relationship(e) if redigitize and r.analog: p_dataset = obj._graph.nodes[r.parent_data].get("dataset", None) if p_dataset is not None: if r.parent_name not in p_dataset: r.indexing = "label" r.parent_name = r.analog if r.indexing == "label": p_dataset = obj._graph.nodes[r.parent_data].get("dataset", None) c_dataset = obj._graph.nodes[r.child_data].get("dataset", None) upstream = p_dataset[r.parent_name] downstream = c_dataset[r.child_name] # vectorize version mapper = {i: j for (j, i) in enumerate(downstream.to_numpy())} offsets = xr.apply_ufunc(np.vectorize(mapper.get), upstream) # candidate name for write back r_parent_name_new = ( f"{self._BY_OFFSET}{r.parent_name}_{r.child_data}_{r.child_name}" ) # it is common to have mirrored offsets in various dimensions. # we'd like to retain only the same data in memory once, so we'll # check if these offsets match any existing ones and if so just # point to that memory. for k in p_dataset: if isinstance(k, str) and k.startswith(self._BY_OFFSET): if p_dataset[k].equals(offsets): # we found a match, so we'll assign this name to # the match's memory storage instead of replicating it. 
obj._graph.nodes[r.parent_data][ "dataset" ] = p_dataset.assign({r_parent_name_new: p_dataset[k]}) # r_parent_name_new = k break else: # no existing offset arrays match, make this new one obj._graph.nodes[r.parent_data]["dataset"] = p_dataset.assign( {r_parent_name_new: offsets} ) obj._graph.edges[e].update( dict( parent_name=r_parent_name_new, indexing="position", analog=r.parent_name, ) ) if not inplace: return obj
@property def relationships_are_digitized(self): """bool : Whether all relationships are digital (by position).""" for e in self._graph.edges: r = self._get_relationship(e) if r.indexing != "position": return False return True def _arg_tokenizer(self, spacename, spacearray, exclude_dims=None): if spacename == self.root_node_name: root_dataset = self.root_dataset from .flows import presorted root_dims = list(presorted(root_dataset.dims, self.dim_order, exclude_dims)) if isinstance(spacearray, str): from_dims = root_dataset[spacearray].dims else: from_dims = spacearray.dims return tuple( ast.parse(f"_arg{root_dims.index(dim):02}", mode="eval").body for dim in from_dims ) if isinstance(spacearray, str): from_dims = self._graph.nodes[spacename]["dataset"][spacearray].dims else: from_dims = spacearray.dims tokens = [] for dimname in from_dims: for e in self._graph.in_edges(spacename, keys=True): this_dim_name = self._graph.edges[e]["child_name"] if dimname != this_dim_name: continue parent_name = self._graph.edges[e]["parent_name"] parent_data = e[0] upside_ast = self._arg_tokenizer( parent_data, parent_name, exclude_dims=exclude_dims ) try: upside = ", ".join(unparse(t) for t in upside_ast) except: # noqa: E722 for t in upside_ast: print(f"t:{t}") raise tokens.append(f"__{parent_data}__{parent_name}[{upside}]") result = [] for t in tokens: result.append(ast.parse(t, mode="eval").body) return tuple(result) @property def coords(self): return self.root_dataset.coords def copy(self): return type(self)( self._graph.copy(), self.root_node_name, **self.__shallow_copy_extras() )