diff --git a/hamilton/driver.py b/hamilton/driver.py index 541b1da02..76c0a69ce 100644 --- a/hamilton/driver.py +++ b/hamilton/driver.py @@ -16,6 +16,7 @@ # under the License. import abc +import functools import importlib import importlib.util import json @@ -792,6 +793,26 @@ def list_available_variables( results = [Variable.from_node(n) for n in all_nodes] return results + @functools.cached_property + def variables(self) -> dict[str, Variable]: + """Returns all variables in the graph keyed by name.""" + return { + node_name: Variable.from_node(node_) for node_name, node_ in self.graph.nodes.items() + } + + def get_variable(self, name: str) -> Variable: + """Returns a variable by name. + + :param name: Name of the variable to return. + :return: Matching HamiltonNode. + :raises KeyError: If the variable does not exist in this Driver's graph. + """ + return self.variables[name] + + def get_graph(self) -> graph_types.HamiltonGraph: + """Returns the public HamiltonGraph representation for this Driver.""" + return graph_types.HamiltonGraph.from_graph(self.graph) + @capture_function_usage def display_all_functions( self, diff --git a/tests/test_hamilton_driver.py b/tests/test_hamilton_driver.py index ffea258fc..e87a46275 100644 --- a/tests/test_hamilton_driver.py +++ b/tests/test_hamilton_driver.py @@ -216,6 +216,22 @@ def test_driver_variables_exposes_original_function(): assert originating_functions["a"] == (tests.resources.very_simple_dag.b,) # a is an input +def test_driver_variable_lookup(): + dr = Driver({}, tests.resources.very_simple_dag) + + assert set(dr.variables) == {"a", "b"} + assert dr.variables["b"].name == "b" + assert dr.get_variable("a").is_external_input is True + + +def test_driver_get_graph_returns_hamilton_graph(): + dr = Driver({}, tests.resources.very_simple_dag) + + hamilton_graph = dr.get_graph() + + assert hamilton_graph["b"].name == "b" + + @pytest.mark.parametrize( "driver_factory", [