diff --git a/hamilton/graph.py b/hamilton/graph.py index 6e6caa864..6f7138547 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -356,6 +356,10 @@ def _get_node_type(n: node.Node) -> str: else: return "function" + def _is_async_node(n: node.Node) -> bool: + """Returns whether a DAG node is backed by an async callable.""" + return n.callable is not None and inspect.iscoroutinefunction(n.callable) + def _get_node_style(node_type: str) -> dict[str, str]: """Get the style of a node type. Graphviz needs values to be strings. @@ -408,6 +412,8 @@ def _get_function_modifier_style(modifier: str) -> dict[str, str]: modifier_style = dict(style="filled,diagonals") elif modifier == "materializer": modifier_style = dict(shape="cylinder") + elif modifier == "async": + modifier_style = dict(fillcolor="#CDB4DB", style="rounded,filled,bold") elif modifier == "field": modifier_style = dict(fillcolor="#c8dae0", fontname="Courier") elif modifier == "cluster": @@ -457,6 +463,7 @@ def _get_legend( "config", "input", "function", + "async", "cluster", "field", "output", @@ -565,6 +572,11 @@ def _get_legend( node_style.update(**modifier_style) seen_node_types.add("materializer") + if _is_async_node(n): + modifier_style = _get_function_modifier_style("async") + node_style.update(**modifier_style) + seen_node_types.add("async") + # apply custom styles before node modifiers seen_node_type = None if custom_style_function: diff --git a/tests/test_graph.py b/tests/test_graph.py index bbf473d79..5ed802990 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1273,6 +1273,37 @@ def test_create_graphviz_graph(): assert dot_set == expected_set +def test_create_graphviz_graph_styles_async_nodes(): + async def async_node() -> int: + return 1 + + def sync_node(async_node: int) -> int: + return async_node + 1 + + module = ad_hoc_utils.create_temporary_module(async_node, sync_node) + fg = graph.FunctionGraph.from_modules(module, config={}) + + digraph = graph.create_graphviz_graph( + set(fg.get_nodes()), + "Dependency Graph\n", + graphviz_kwargs={}, + node_modifiers={}, + strictly_display_only_nodes_passed_in=False, + config={}, + ) + dot_source = str(digraph) + + assert ( + "\tasync_node [label=<async_node

int> " + 'fillcolor="#CDB4DB" fontname=Helvetica margin=0.15 shape=rectangle ' + 'style="rounded,filled,bold"]' + ) in dot_source + assert ( + '\t\tasync [fillcolor="#CDB4DB" fontname=Helvetica margin=0.15 ' + 'shape=rectangle style="rounded,filled,bold"]' + ) in dot_source + + def test_create_networkx_graph(): """Tests that we create a networkx graph""" fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={})