diff --git a/hamilton/graph.py b/hamilton/graph.py
index 6e6caa864..6f7138547 100644
--- a/hamilton/graph.py
+++ b/hamilton/graph.py
@@ -356,6 +356,10 @@ def _get_node_type(n: node.Node) -> str:
else:
return "function"
+ def _is_async_node(n: node.Node) -> bool:
+ """Returns whether a DAG node is backed by an async callable."""
+ return n.callable is not None and inspect.iscoroutinefunction(n.callable)
+
def _get_node_style(node_type: str) -> dict[str, str]:
"""Get the style of a node type.
Graphviz needs values to be strings.
@@ -408,6 +412,8 @@ def _get_function_modifier_style(modifier: str) -> dict[str, str]:
modifier_style = dict(style="filled,diagonals")
elif modifier == "materializer":
modifier_style = dict(shape="cylinder")
+ elif modifier == "async":
+ modifier_style = dict(fillcolor="#CDB4DB", style="rounded,filled,bold")
elif modifier == "field":
modifier_style = dict(fillcolor="#c8dae0", fontname="Courier")
elif modifier == "cluster":
@@ -457,6 +463,7 @@ def _get_legend(
"config",
"input",
"function",
+ "async",
"cluster",
"field",
"output",
@@ -565,6 +572,11 @@ def _get_legend(
node_style.update(**modifier_style)
seen_node_types.add("materializer")
+ if _is_async_node(n):
+ modifier_style = _get_function_modifier_style("async")
+ node_style.update(**modifier_style)
+ seen_node_types.add("async")
+
# apply custom styles before node modifiers
seen_node_type = None
if custom_style_function:
diff --git a/tests/test_graph.py b/tests/test_graph.py
index bbf473d79..5ed802990 100644
--- a/tests/test_graph.py
+++ b/tests/test_graph.py
@@ -1273,6 +1273,37 @@ def test_create_graphviz_graph():
assert dot_set == expected_set
+def test_create_graphviz_graph_styles_async_nodes():
+ async def async_node() -> int:
+ return 1
+
+ def sync_node(async_node: int) -> int:
+ return async_node + 1
+
+ module = ad_hoc_utils.create_temporary_module(async_node, sync_node)
+ fg = graph.FunctionGraph.from_modules(module, config={})
+
+ digraph = graph.create_graphviz_graph(
+ set(fg.get_nodes()),
+ "Dependency Graph\n",
+ graphviz_kwargs={},
+ node_modifiers={},
+ strictly_display_only_nodes_passed_in=False,
+ config={},
+ )
+ dot_source = str(digraph)
+
+ assert (
+ "\tasync_node [label=<async_node
int> "
+ 'fillcolor="#CDB4DB" fontname=Helvetica margin=0.15 shape=rectangle '
+ 'style="rounded,filled,bold"]'
+ ) in dot_source
+ assert (
+ '\t\tasync [fillcolor="#CDB4DB" fontname=Helvetica margin=0.15 '
+ 'shape=rectangle style="rounded,filled,bold"]'
+ ) in dot_source
+
+
def test_create_networkx_graph():
"""Tests that we create a networkx graph"""
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={})