diff --git a/.github/workflows/onpush.yml b/.github/workflows/onpush.yml index 48b49ec..6dad26d 100644 --- a/.github/workflows/onpush.yml +++ b/.github/workflows/onpush.yml @@ -9,6 +9,7 @@ on: paths-ignore: - 'README.md' - 'CLAUDE.md' + - 'CHANGELOG.md' - 'docs/**' # Manual trigger for re-running CI without a new commit (e.g. after a transient # GitHub Actions hiccup that silently drops a push event): diff --git a/CHANGELOG.md b/CHANGELOG.md index 7202e2c..19b9526 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,13 +2,19 @@ --- +## [#34](https://github.com/andre-salvati/databricks-template/pull/34) · 2026-06-05 · feat: standardize silver/gold field names, fix dashboard KPIs, add total_orders + +Dropped `ds_kpi` from the dashboard — all three KPI counters (Total Value, Total Orders, Number of Customers) now bind to `ds_orders` with aggregate expressions so all five filters update them; added a third KPI tile for Total Orders (`COUNT DISTINCT order_id`). +Standardized field names across silver (`curated.order_enriched`) and gold (`report.order_agg`) following four rules: `{entity}_id` suffix, entity-qualified names, `item_*` prefix for item-level fields, no abbreviations; `date` is now cast to `DateType` in silver. +Added `order_enriched_schema` and `order_agg_schema` to `commonSchemas.py` as canonical schemas for silver and gold; all tests and the integration validator import from there instead of inlining definitions. + +--- + ## [#33](https://github.com/andre-salvati/databricks-template/pull/33) · 2026-06-04 · feat: AI/BI dashboard, country in gold layer, randomized seed data -Added `country` to `curated.order_enriched` and `report.order_agg` (and their SDP equivalents) so the gold layer carries the full customer dimension needed for country-based reporting; unit tests updated accordingly. -Added an AI/BI (Lakeview) dashboard with three line charts (total value by date × country, by date × product, by date × category) and a global filter page (date range, country, customer, product, category); uses `make truncate env=X yes=--yes` before first post-deploy run to handle the schema change to `report.order_agg`. -Dashboard JSON (`resources/orders_dashboard.lvdash.json`) and its DAB resource entry are generated by `sdk_generate_template_job.py` at deploy time with the target catalog embedded — both files are gitignored. -Completed README documentation: added "Databricks Dashboards" to the Technologies section, added a dashboard screenshot block, and replaced the placeholder dashboard Features bullet with a full description of the charts and filter panel. -Improved seed data chart visibility: customers are assigned a non-uniform country distribution (US=200, UK=100, DE=50, FR=50, BR=30, CA=25, AU=20, JP=15, MX=7, IN=3) so country lines are clearly separated; `total_item` now scales with `prod_category_id` (category × $15 base + $10 noise), producing a ~6× spread across categories visible in the category chart. +Added `country` to `curated.order_enriched` and `report.order_agg` (and SDP equivalents) so the gold layer carries the full customer dimension needed for country-based reporting. +Added an AI/BI (Lakeview) dashboard with three line charts (total value by date × country, product, and category) and a global filter page; dashboard JSON is generated by `sdk_generate_template_job.py` at deploy time with the target catalog embedded and is gitignored. +Improved seed data chart visibility with a non-uniform country distribution and `total_item` scaling with `prod_category_id` (category × $15 base + $10 noise), producing a ~6× spread across categories. --- diff --git a/CLAUDE.md b/CLAUDE.md index d2a5fe6..5e3d330 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -95,8 +95,8 @@ Medallion schemas (`MEDALLION_SCHEMAS` in `config.py`): Each task's input/output tables are **hardcoded** in the task module (e.g. `raw.customer` → `curated.order_enriched`). The medallion layer is a semantic contract, not a runtime parameter — this is the dbt `ref()` pattern. Don't parameterize the layer; if a task genuinely needs a configurable target, that's a different task. -`curated.order_enriched` columns: `name, country, id_customer, id_order, total, date, product_id, prod_category_id, seq, desc_item, qty, total_item` -`report.order_agg` columns: `name, country, date, product_id, prod_category_id, total_qty, total_value` +`curated.order_enriched` columns: `customer_name, country, customer_id, order_id, order_total, order_date (DateType), product_id, product_category_id, item_seq, item_description, item_quantity, item_total` +`report.order_agg` columns: `customer_name, country, order_date (DateType), product_id, product_category_id, total_quantity, total_value, total_orders` ### Job-level parameters (runtime, overridable per-run) diff --git a/scripts/sdk_generate_template_job.py b/scripts/sdk_generate_template_job.py index c267912..5d44942 100644 --- a/scripts/sdk_generate_template_job.py +++ b/scripts/sdk_generate_template_job.py @@ -480,41 +480,15 @@ def _build_dashboard_json(catalog: str) -> dict: """ return { "datasets": [ - { - "name": "ds_kpi", - "displayName": "KPIs", - "queryLines": [ - f"SELECT ROUND(SUM(total_item), 2) AS total_value, " - f"COUNT(DISTINCT id_order) AS num_orders, " - f"COUNT(DISTINCT id_customer) AS num_customers " - f"FROM {catalog}.curated.order_enriched " - f"WHERE date BETWEEN :date_range.min AND :date_range.max" - ], - "parameters": [ - { - "keyword": "date_range", - "displayName": "Date Range", - "dataType": "DATE", - "complexType": "RANGE", - "defaultSelection": { - "range": { - "dataType": "DATE", - "min": {"value": "now-1y"}, - "max": {"value": "now"}, - } - }, - } - ], - }, { "name": "ds_orders", "displayName": "Orders", "queryLines": [ - f"SELECT CAST(date AS DATE) AS order_date, country, name AS customer, " - f"CAST(product_id AS STRING) AS product_id, CAST(prod_category_id AS STRING) AS category_id, " - f"SUM(total_value) AS total_value " + f"SELECT order_date, country, customer_name AS customer, " + f"CAST(product_id AS STRING) AS product_id, CAST(product_category_id AS STRING) AS category_id, " + f"SUM(total_value) AS total_value, SUM(total_orders) AS total_orders " f"FROM {catalog}.report.order_agg " - f"WHERE date BETWEEN :date_range.min AND :date_range.max " + f"WHERE order_date BETWEEN :date_range.min AND :date_range.max " f"GROUP BY 1, 2, 3, 4, 5" ], "parameters": [ @@ -566,9 +540,9 @@ def _build_dashboard_json(catalog: str) -> dict: { "name": "main_query", "query": { - "datasetName": "ds_kpi", - "fields": [{"name": "total_value", "expression": "`total_value`"}], - "disaggregated": True, + "datasetName": "ds_orders", + "fields": [{"name": "total_value", "expression": "SUM(`total_value`)"}], + "disaggregated": False, }, } ], @@ -583,22 +557,22 @@ def _build_dashboard_json(catalog: str) -> dict: }, { "widget": { - "name": "kpi-num-orders", + "name": "kpi-total-orders", "queries": [ { "name": "main_query", "query": { - "datasetName": "ds_kpi", - "fields": [{"name": "num_orders", "expression": "`num_orders`"}], - "disaggregated": True, + "datasetName": "ds_orders", + "fields": [{"name": "total_orders", "expression": "SUM(`total_orders`)"}], + "disaggregated": False, }, } ], "spec": { "version": 2, "widgetType": "counter", - "encodings": {"value": {"fieldName": "num_orders", "displayName": "Number of Orders"}}, - "frame": {"title": "Number of Orders", "showTitle": True}, + "encodings": {"value": {"fieldName": "total_orders", "displayName": "Total Orders"}}, + "frame": {"title": "Total Orders", "showTitle": True}, }, }, "position": {"x": 2, "y": 2, "width": 2, "height": 3}, @@ -610,9 +584,11 @@ def _build_dashboard_json(catalog: str) -> dict: { "name": "main_query", "query": { - "datasetName": "ds_kpi", - "fields": [{"name": "num_customers", "expression": "`num_customers`"}], - "disaggregated": True, + "datasetName": "ds_orders", + "fields": [ + {"name": "num_customers", "expression": "COUNT(DISTINCT `customer`)"} + ], + "disaggregated": False, }, } ], @@ -784,14 +760,6 @@ def _build_dashboard_json(catalog: str) -> dict: "disaggregated": False, }, }, - { - "name": "q_date_kpi", - "query": { - "datasetName": "ds_kpi", - "parameters": [{"name": "date_range", "keyword": "date_range"}], - "disaggregated": False, - }, - }, ], "spec": { "version": 2, @@ -799,7 +767,6 @@ def _build_dashboard_json(catalog: str) -> dict: "encodings": { "fields": [ {"parameterName": "date_range", "queryName": "q_date"}, - {"parameterName": "date_range", "queryName": "q_date_kpi"}, ] }, "frame": {"showTitle": True, "title": "Date Range"}, diff --git a/src/template/commonSchemas.py b/src/template/commonSchemas.py index 3a55f72..fb744fe 100644 --- a/src/template/commonSchemas.py +++ b/src/template/commonSchemas.py @@ -1,6 +1,9 @@ from pyspark.sql.types import ( + DateType, + DoubleType, FloatType, IntegerType, + LongType, StringType, StructField, StructType, @@ -34,3 +37,33 @@ StructField("total_item", FloatType(), True), ] ) + +order_enriched_schema = StructType( + [ + StructField("customer_name", StringType(), True), + StructField("country", StringType(), True), + StructField("customer_id", IntegerType(), True), + StructField("order_id", IntegerType(), True), + StructField("order_total", FloatType(), True), + StructField("order_date", DateType(), True), + StructField("product_id", IntegerType(), True), + StructField("product_category_id", IntegerType(), True), + StructField("item_seq", IntegerType(), True), + StructField("item_description", StringType(), True), + StructField("item_quantity", IntegerType(), True), + StructField("item_total", FloatType(), True), + ] +) + +order_agg_schema = StructType( + [ + StructField("customer_name", StringType(), True), + StructField("country", StringType(), True), + StructField("order_date", DateType(), True), + StructField("product_id", IntegerType(), True), + StructField("product_category_id", IntegerType(), True), + StructField("total_quantity", LongType(), True), + StructField("total_value", DoubleType(), True), + StructField("total_orders", LongType(), True), + ] +) diff --git a/src/template/job1/generate_orders.py b/src/template/job1/generate_orders.py index 1327605..1985ecf 100644 --- a/src/template/job1/generate_orders.py +++ b/src/template/job1/generate_orders.py @@ -12,18 +12,18 @@ def enrich_order(self, df_customer, df_order, df_order_item): df_order_item.join(df_order, df_order_item["id_order"] == df_order["id"]) .join(df_customer, df_order["id_customer"] == df_customer["id"]) .select( - "name", + df_customer["name"].alias("customer_name"), "country", - "id_customer", - "id_order", - "total", - "date", + df_order["id_customer"].alias("customer_id"), + df_order_item["id_order"].alias("order_id"), + df_order["total"].alias("order_total"), + df_order["date"].cast("date").alias("order_date"), "product_id", - "prod_category_id", - "seq", - "desc_item", - "qty", - "total_item", + df_order["prod_category_id"].alias("product_category_id"), + df_order_item["seq"].alias("item_seq"), + df_order_item["desc_item"].alias("item_description"), + df_order_item["qty"].alias("item_quantity"), + df_order_item["total_item"].alias("item_total"), ) ) diff --git a/src/template/job1/generate_orders_agg.py b/src/template/job1/generate_orders_agg.py index 32b3ca9..1fa8d78 100644 --- a/src/template/job1/generate_orders_agg.py +++ b/src/template/job1/generate_orders_agg.py @@ -1,4 +1,4 @@ -from pyspark.sql.functions import sum +from pyspark.sql.functions import countDistinct, sum from ..baseTask import BaseTask @@ -10,9 +10,10 @@ def __init__(self, config): def aggregate_orders(self, df_order): # TODO code your transformations here... - return df_order.groupBy("name", "country", "date", "product_id", "prod_category_id").agg( - sum("qty").alias("total_qty"), - sum("total_item").alias("total_value"), + return df_order.groupBy("customer_name", "country", "order_date", "product_id", "product_category_id").agg( + sum("item_quantity").alias("total_quantity"), + sum("item_total").alias("total_value"), + countDistinct("order_id").alias("total_orders"), ) def run(self): diff --git a/src/template/job1_sdp/transforms.py b/src/template/job1_sdp/transforms.py index 4228bf8..eb3a2c2 100644 --- a/src/template/job1_sdp/transforms.py +++ b/src/template/job1_sdp/transforms.py @@ -26,24 +26,25 @@ def enrich_order(df_customer: DataFrame, df_order: DataFrame, df_order_item: Dat Returns: Enriched DataFrame with columns: - name, country, id_customer, id_order, total, date, product_id, prod_category_id, seq, desc_item, qty, total_item + customer_name, country, customer_id, order_id, order_total, order_date, product_id, + product_category_id, item_seq, item_description, item_quantity, item_total """ return ( df_order_item.join(df_order, df_order_item["id_order"] == df_order["id"]) .join(df_customer, df_order["id_customer"] == df_customer["id"]) .select( - "name", + df_customer["name"].alias("customer_name"), "country", - "id_customer", - "id_order", - "total", - "date", + df_order["id_customer"].alias("customer_id"), + df_order_item["id_order"].alias("order_id"), + df_order["total"].alias("order_total"), + df_order["date"].cast("date").alias("order_date"), "product_id", - "prod_category_id", - "seq", - "desc_item", - "qty", - "total_item", + df_order["prod_category_id"].alias("product_category_id"), + df_order_item["seq"].alias("item_seq"), + df_order_item["desc_item"].alias("item_description"), + df_order_item["qty"].alias("item_quantity"), + df_order_item["total_item"].alias("item_total"), ) ) @@ -58,10 +59,11 @@ def aggregate_orders(df_order_enriched: DataFrame) -> DataFrame: df_order_enriched: curated.order_enriched Returns: - DataFrame with columns: name, country, date, product_id, prod_category_id, - total_qty (LongType), total_value (DoubleType) + DataFrame with columns: customer_name, country, order_date, product_id, + product_category_id, total_quantity (LongType), total_value (DoubleType), total_orders (LongType) """ - return df_order_enriched.groupBy("name", "country", "date", "product_id", "prod_category_id").agg( - F.sum("qty").alias("total_qty"), - F.sum("total_item").alias("total_value"), + return df_order_enriched.groupBy("customer_name", "country", "order_date", "product_id", "product_category_id").agg( + F.sum("item_quantity").alias("total_quantity"), + F.sum("item_total").alias("total_value"), + F.countDistinct("order_id").alias("total_orders"), ) diff --git a/tests/job1/integration_validate.py b/tests/job1/integration_validate.py index f6e12d4..69a25bc 100644 --- a/tests/job1/integration_validate.py +++ b/tests/job1/integration_validate.py @@ -1,8 +1,10 @@ +from datetime import date + from pyspark.sql import functions as F -from pyspark.sql.types import DoubleType, IntegerType, LongType, StringType, StructField, StructType from pyspark.testing import assertDataFrameEqual from template.baseTask import BaseTask +from template.commonSchemas import order_agg_schema class Validate(BaseTask): @@ -10,23 +12,12 @@ def __init__(self, config): super().__init__(config) def _validate_standard(self, catalog): - # groupBy(name, country, date, product_id, prod_category_id) → still 2 rows (one per order) + # groupBy(customer_name, country, order_date, product_id, product_category_id) → still 2 rows (one per order) expected_data = [ - ("John Doe", "USA", "2023-01-01", 1, 1, 3, 100.0), - ("Jane Smith", "UK", "2023-01-02", 2, 1, 3, 151.0), + ("John Doe", "USA", date(2023, 1, 1), 1, 1, 3, 100.0, 1), + ("Jane Smith", "UK", date(2023, 1, 2), 2, 1, 3, 151.0, 1), ] - expected_schema = StructType( - [ - StructField("name", StringType(), True), - StructField("country", StringType(), True), - StructField("date", StringType(), True), - StructField("product_id", IntegerType(), True), - StructField("prod_category_id", IntegerType(), True), - StructField("total_qty", LongType(), True), - StructField("total_value", DoubleType(), True), - ] - ) - df_expected = self.spark.createDataFrame(expected_data, schema=expected_schema) + df_expected = self.spark.createDataFrame(expected_data, schema=order_agg_schema) for table in (f"{catalog}.report.order_agg", f"{catalog}.report.order_agg_sdp"): df_out = self.spark.table(table) @@ -37,16 +28,19 @@ def _validate_standard(self, catalog): def _validate_load_test(self, catalog): # 500 customers × 100 products × 1 date = 50,000 rows - # Each (customer, product): 40 orders × 3 items × qty=2 → total_qty=240 + # Each (customer, product): 40 orders × 3 items × qty=2 → total_quantity=240 # Each (customer, product): 40 orders × 3 items × total_item=50.0 → total_value=6,000.0 + # Each (customer, product): 40 distinct orders → total_orders=40 for table in (f"{catalog}.report.order_agg", f"{catalog}.report.order_agg_sdp"): df_out = self.spark.table(table) count = df_out.count() if count != 50_000: raise RuntimeError(f"Expected 50,000 rows in {table}, got {count}") - wrong = df_out.filter((F.col("total_qty") != 240) | (F.col("total_value") != 6_000.0)).count() + wrong = df_out.filter( + (F.col("total_quantity") != 240) | (F.col("total_value") != 6_000.0) | (F.col("total_orders") != 40) + ).count() if wrong > 0: - raise RuntimeError(f"{wrong} rows in {table} have unexpected total_qty/total_value") + raise RuntimeError(f"{wrong} rows in {table} have unexpected total_quantity/total_value/total_orders") def run(self): load_test = self.config.get_value("load_test") == "true" diff --git a/tests/job1/unit_test.py b/tests/job1/unit_test.py index 394ae39..ba837a4 100644 --- a/tests/job1/unit_test.py +++ b/tests/job1/unit_test.py @@ -1,4 +1,5 @@ from argparse import Namespace +from datetime import date import pytest from pyspark.sql import * @@ -9,7 +10,13 @@ from template.job1.generate_orders import GenerateOrders from template.job1.generate_orders_agg import GenerateOrdersAgg from template.job1.seed_sources import SeedSources, _INCREMENTAL_ORDERS, _INCREMENTAL_CUSTOMER_UPDATES -from template.commonSchemas import customer_schema, order_schema, order_item_schema +from template.commonSchemas import ( + customer_schema, + order_schema, + order_item_schema, + order_enriched_schema, + order_agg_schema, +) from pyspark.testing import assertDataFrameEqual from pyspark.sql.functions import explode @@ -51,27 +58,11 @@ def df_orders_from_source(spark) -> DataFrame: @pytest.fixture def df_orders(spark) -> DataFrame: orders_data = [ - ("John Doe", "USA", 10, 1, 100.0, "2023-01-01", 1, 1, 1, "Item A", 2, 50.0), - ("John Doe", "USA", 10, 1, 100.0, "2023-01-01", 1, 1, 2, "Item B", 1, 50.0), - ("Jane Smith", "UK", 20, 2, 150.0, "2023-01-02", 2, 2, 1, "Item C", 3, 150.0), + ("John Doe", "USA", 10, 1, 100.0, date(2023, 1, 1), 1, 1, 1, "Item A", 2, 50.0), + ("John Doe", "USA", 10, 1, 100.0, date(2023, 1, 1), 1, 1, 2, "Item B", 1, 50.0), + ("Jane Smith", "UK", 20, 2, 150.0, date(2023, 1, 2), 2, 2, 1, "Item C", 3, 150.0), ] - orders_schema = StructType( - [ - StructField("name", StringType(), True), - StructField("country", StringType(), True), - StructField("id_customer", IntegerType(), True), - StructField("id_order", IntegerType(), True), - StructField("total", FloatType(), True), - StructField("date", StringType(), True), - StructField("product_id", IntegerType(), True), - StructField("prod_category_id", IntegerType(), True), - StructField("seq", IntegerType(), True), - StructField("desc_item", StringType(), True), - StructField("qty", IntegerType(), True), - StructField("total_item", FloatType(), True), - ] - ) - return spark.createDataFrame(orders_data, schema=orders_schema) + return spark.createDataFrame(orders_data, schema=order_enriched_schema) def test_arg_parser(): @@ -181,21 +172,10 @@ def test_aggregate_orders(spark, config, df_orders): assert df_out.count() == 2 expected_data = [ - ("John Doe", "USA", "2023-01-01", 1, 1, 3, 100.0), - ("Jane Smith", "UK", "2023-01-02", 2, 2, 3, 150.0), + ("John Doe", "USA", date(2023, 1, 1), 1, 1, 3, 100.0, 1), + ("Jane Smith", "UK", date(2023, 1, 2), 2, 2, 3, 150.0, 1), ] - expected_schema = StructType( - [ - StructField("name", StringType(), True), - StructField("country", StringType(), True), - StructField("date", StringType(), True), - StructField("product_id", IntegerType(), True), - StructField("prod_category_id", IntegerType(), True), - StructField("total_qty", LongType(), True), - StructField("total_value", DoubleType(), True), - ] - ) - df_expected = spark.createDataFrame(expected_data, schema=expected_schema) + df_expected = spark.createDataFrame(expected_data, schema=order_agg_schema) assertDataFrameEqual(df_out, df_expected) diff --git a/tests/job1/unit_test_sdp.py b/tests/job1/unit_test_sdp.py index 1eb738c..53206a3 100644 --- a/tests/job1/unit_test_sdp.py +++ b/tests/job1/unit_test_sdp.py @@ -10,20 +10,19 @@ pipelines are held to the same data contract. """ +from datetime import date + import pytest from pyspark.sql import DataFrame, SparkSession -from pyspark.sql.types import ( - DoubleType, - FloatType, - IntegerType, - LongType, - StringType, - StructField, - StructType, -) from pyspark.testing import assertDataFrameEqual -from template.commonSchemas import customer_schema, order_item_schema, order_schema +from template.commonSchemas import ( + customer_schema, + order_agg_schema, + order_enriched_schema, + order_item_schema, + order_schema, +) from template.job1_sdp.transforms import aggregate_orders, enrich_order @@ -39,23 +38,11 @@ def spark() -> SparkSession: def df_orders_enriched(spark) -> DataFrame: """Pre-joined enriched orders matching the output of enrich_order().""" data = [ - ("John Doe", 10, 1, 100.0, 1, "Item A", 2, 50.0), - ("John Doe", 10, 1, 100.0, 2, "Item B", 1, 50.0), - ("Jane Smith", 20, 2, 150.0, 1, "Item C", 3, 150.0), + ("John Doe", "USA", 10, 1, 100.0, date(2023, 1, 1), 1, 1, 1, "Item A", 2, 50.0), + ("John Doe", "USA", 10, 1, 100.0, date(2023, 1, 1), 1, 1, 2, "Item B", 1, 50.0), + ("Jane Smith", "UK", 20, 2, 150.0, date(2023, 1, 2), 2, 2, 1, "Item C", 3, 150.0), ] - schema = StructType( - [ - StructField("name", StringType(), True), - StructField("id_customer", IntegerType(), True), - StructField("id_order", IntegerType(), True), - StructField("total", FloatType(), True), - StructField("seq", IntegerType(), True), - StructField("desc_item", StringType(), True), - StructField("qty", IntegerType(), True), - StructField("total_item", FloatType(), True), - ] - ) - return spark.createDataFrame(data, schema=schema) + return spark.createDataFrame(data, schema=order_enriched_schema) # ── enrich_order ────────────────────────────────────────────────────────────── @@ -68,7 +55,7 @@ def test_enrich_order_row_count(spark, df_orders_enriched): schema=customer_schema, ) df_order = spark.createDataFrame( - [(1, 10, 100.0, "2023-01-01"), (2, 20, 150.0, "2023-01-02")], + [(1, 10, 100.0, "2023-01-01", 1, 1), (2, 20, 150.0, "2023-01-02", 2, 2)], schema=order_schema, ) df_order_item = spark.createDataFrame( @@ -84,10 +71,23 @@ def test_enrich_order_row_count(spark, df_orders_enriched): def test_enrich_order_columns(spark): """Output schema must contain exactly the declared columns.""" - expected_cols = {"name", "id_customer", "id_order", "total", "seq", "desc_item", "qty", "total_item"} + expected_cols = { + "customer_name", + "country", + "customer_id", + "order_id", + "order_total", + "order_date", + "product_id", + "product_category_id", + "item_seq", + "item_description", + "item_quantity", + "item_total", + } df_customer = spark.createDataFrame([(10, "Alice", "US")], schema=customer_schema) - df_order = spark.createDataFrame([(1, 10, 50.0, "2024-01-01")], schema=order_schema) + df_order = spark.createDataFrame([(1, 10, 50.0, "2024-01-01", 1, 1)], schema=order_schema) df_order_item = spark.createDataFrame([(1, 1, "Widget", 1, 50.0)], schema=order_item_schema) df_out = enrich_order(df_customer, df_order, df_order_item) @@ -108,16 +108,9 @@ def test_aggregate_orders_values(spark, df_orders_enriched): df_out = aggregate_orders(df_orders_enriched) expected_data = [ - ("John Doe", 3, 100.0), # qty: 2+1=3, total_item: 50.0+50.0=100.0 - ("Jane Smith", 3, 150.0), # qty: 3, total_item: 150.0 + ("John Doe", "USA", date(2023, 1, 1), 1, 1, 3, 100.0, 1), # item_quantity: 2+1=3, item_total: 50.0+50.0=100.0 + ("Jane Smith", "UK", date(2023, 1, 2), 2, 2, 3, 150.0, 1), # item_quantity: 3, item_total: 150.0 ] - expected_schema = StructType( - [ - StructField("name", StringType(), True), - StructField("total_qty", LongType(), True), - StructField("total_value", DoubleType(), True), - ] - ) - df_expected = spark.createDataFrame(expected_data, schema=expected_schema) + df_expected = spark.createDataFrame(expected_data, schema=order_agg_schema) assertDataFrameEqual(df_out, df_expected)