diff --git a/.github/workflows/on-pr-comment.yml b/.github/workflows/on-pr-comment.yml index 870c415fe..3b166d57d 100644 --- a/.github/workflows/on-pr-comment.yml +++ b/.github/workflows/on-pr-comment.yml @@ -42,16 +42,29 @@ jobs: unit-test-python: if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/unit_test_py') || endsWith(github.event.comment.body, '/unit_test') || contains(github.event.comment.body, '/all_test')) }} runs-on: ubuntu-latest - # TODO(kmonte): Reduce this :( - timeout-minutes: 120 + timeout-minutes: 60 + strategy: + fail-fast: false + matrix: + include: + - name: "Type Check" + command: "make type_check" + - name: "Shard 0" + command: "make unit_test_py_shard SHARD_INDEX=0 TOTAL_SHARDS=4" + - name: "Shard 1" + command: "make unit_test_py_shard SHARD_INDEX=1 TOTAL_SHARDS=4" + - name: "Shard 2" + command: "make unit_test_py_shard SHARD_INDEX=2 TOTAL_SHARDS=4" + - name: "Shard 3" + command: "make unit_test_py_shard SHARD_INDEX=3 TOTAL_SHARDS=4" steps: - - name: Run Python Unit Tests + - name: Run Python Unit Tests (${{ matrix.name }}) uses: snapchat/gigl/.github/actions/run-command-on-pr@main with: github-token: ${{ secrets.GITHUB_TOKEN }} pr_number: ${{ github.event.issue.number }} should_leave_progress_comments: "true" - descriptive_workflow_name: "Python Unit Test" + descriptive_workflow_name: "Python Unit Test (${{ matrix.name }})" setup_gcloud: "true" # We use cloud run here instead of using github hosted runners because of limitation of tests # using GFile library (a.k.a anything that does IO w/ Tensorflow). 
GFile does not understand @@ -61,8 +74,7 @@ jobs: gcp_project_id: ${{ vars.GCP_PROJECT_ID }} workload_identity_provider: ${{ secrets.WORKLOAD_IDENTITY_PROVIDER }} gcp_service_account_email: ${{ secrets.GCP_SERVICE_ACCOUNT_EMAIL }} - command: | - make unit_test_py + command: ${{ matrix.command }} unit-test-scala: if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/unit_test_scala') || endsWith(github.event.comment.body, '/unit_test') || contains(github.event.comment.body, '/all_test')) }} @@ -87,23 +99,37 @@ jobs: integration-test: if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/integration_test') || contains(github.event.comment.body, '/all_test')) }} runs-on: ubuntu-latest - # TODO(kmonte): Reduce this :( - timeout-minutes: 120 + timeout-minutes: 60 + strategy: + fail-fast: false + matrix: + include: + - name: "Shard 0" + command: "make integration_test_shard SHARD_INDEX=0 TOTAL_SHARDS=4" + - name: "Shard 1" + command: "make integration_test_shard SHARD_INDEX=1 TOTAL_SHARDS=4" + - name: "Shard 2" + command: "make integration_test_shard SHARD_INDEX=2 TOTAL_SHARDS=4" + - name: "Shard 3" + command: "make integration_test_shard SHARD_INDEX=3 TOTAL_SHARDS=4" steps: - - name: Run Integration Tests + - name: Run Integration Tests (${{ matrix.name }}) uses: snapchat/gigl/.github/actions/run-command-on-pr@main with: github-token: ${{ secrets.GITHUB_TOKEN }} pr_number: ${{ github.event.issue.number }} should_leave_progress_comments: "true" - descriptive_workflow_name: "Integration Test" + descriptive_workflow_name: "Integration Test (${{ matrix.name }})" setup_gcloud: "true" + # We use cloud run here instead of using github hosted runners because of limitation of tests + # using GFile library (a.k.a anything that does IO w/ Tensorflow). GFile does not understand + # how to leverage Workload Identity Federation to read assets from GCS, et al. 
See: + # https://github.com/tensorflow/tensorflow/issues/57104 use_cloud_run: "true" gcp_project_id: ${{ vars.GCP_PROJECT_ID }} workload_identity_provider: ${{ secrets.WORKLOAD_IDENTITY_PROVIDER }} gcp_service_account_email: ${{ secrets.GCP_SERVICE_ACCOUNT_EMAIL }} - command: | - make integration_test + command: ${{ matrix.command }} integration-e2e-test: if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/e2e_test') || contains(github.event.comment.body, '/all_test')) }} diff --git a/.github/workflows/on-pr-merge.yml b/.github/workflows/on-pr-merge.yml index 0e1f9ddd0..46037b4af 100644 --- a/.github/workflows/on-pr-merge.yml +++ b/.github/workflows/on-pr-merge.yml @@ -23,6 +23,20 @@ jobs: # Our tests take a long time to run, so this is not ideal. if: github.event_name == 'merge_group' runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - name: "Type Check" + command: "make type_check" + - name: "Shard 0" + command: "make unit_test_py_shard SHARD_INDEX=0 TOTAL_SHARDS=4" + - name: "Shard 1" + command: "make unit_test_py_shard SHARD_INDEX=1 TOTAL_SHARDS=4" + - name: "Shard 2" + command: "make unit_test_py_shard SHARD_INDEX=2 TOTAL_SHARDS=4" + - name: "Shard 3" + command: "make unit_test_py_shard SHARD_INDEX=3 TOTAL_SHARDS=4" steps: - uses: actions/checkout@v4 - name: Setup development environment @@ -32,16 +46,16 @@ jobs: gcp_project_id: ${{ vars.GCP_PROJECT_ID }} workload_identity_provider: ${{ secrets.workload_identity_provider }} gcp_service_account_email: ${{ secrets.gcp_service_account_email }} - - name: Run Python Unit Tests - # We use cloud run here instead of using github hosted runners because of limitation of tests + - name: Run Python Unit Tests (${{ matrix.name }}) + # We use Cloud Run instead of GitHub hosted runners because of limitation of tests + # using GFile library (a.k.a anything that does IO w/ Tensorflow). 
GFile does not understand # how to leverage Workload Identity Federation to read assets from GCS, et al. See: + # https://github.com/tensorflow/tensorflow/issues/57104 uses: ./.github/actions/run-cloud-run-command-on-active-checkout with: - cmd: "make unit_test_py" - service_account: ${{ secrets.gcp_service_account_email }} - project: ${{ vars.GCP_PROJECT_ID }} + cmd: ${{ matrix.command }} + service_account: ${{ secrets.gcp_service_account_email }} + project: ${{ vars.GCP_PROJECT_ID }} ci-unit-test-scala: # Because of limitation discussed https://github.com/orgs/community/discussions/46757#discussioncomment-4912738 @@ -73,6 +87,18 @@ jobs: ci-integration-test: if: github.event_name == 'merge_group' runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - name: "Shard 0" + command: "make integration_test_shard SHARD_INDEX=0 TOTAL_SHARDS=4" + - name: "Shard 1" + command: "make integration_test_shard SHARD_INDEX=1 TOTAL_SHARDS=4" + - name: "Shard 2" + command: "make integration_test_shard SHARD_INDEX=2 TOTAL_SHARDS=4" + - name: "Shard 3" + command: "make integration_test_shard SHARD_INDEX=3 TOTAL_SHARDS=4" steps: - uses: actions/checkout@v4 - name: Setup development environment @@ -82,12 +108,16 @@ jobs: gcp_project_id: ${{ vars.GCP_PROJECT_ID }} workload_identity_provider: ${{ secrets.workload_identity_provider }} gcp_service_account_email: ${{ secrets.gcp_service_account_email }} - - name: Run Integration Tests + - name: Run Integration Tests (${{ matrix.name }}) + # We use Cloud Run instead of GitHub hosted runners because of limitation of tests + # using GFile library (a.k.a anything that does IO w/ Tensorflow). GFile does not understand + # how to leverage Workload Identity Federation to read assets from GCS, et al. 
See: + # https://github.com/tensorflow/tensorflow/issues/57104 uses: ./.github/actions/run-cloud-run-command-on-active-checkout with: - cmd: "make integration_test" - service_account: ${{ secrets.gcp_service_account_email }} - project: ${{ vars.GCP_PROJECT_ID }} + cmd: ${{ matrix.command }} + service_account: ${{ secrets.gcp_service_account_email }} + project: ${{ vars.GCP_PROJECT_ID }} ci-integration-e2e-test: if: github.event_name == 'merge_group' diff --git a/Makefile b/Makefile index e15a063f3..37c6b9706 100644 --- a/Makefile +++ b/Makefile @@ -23,6 +23,8 @@ DOCKER_IMAGE_DEV_WORKBENCH_NAME_WITH_TAG?=${DOCKER_IMAGE_DEV_WORKBENCH_NAME}:${D PYTHON_DIRS:=.github/scripts examples gigl tests snapchat scripts PY_TEST_FILES?="*_test.py" +SHARD_INDEX?=0 +TOTAL_SHARDS?=0 # You can override GIGL_TEST_DEFAULT_RESOURCE_CONFIG by setting it in your environment i.e. # adding `export GIGL_TEST_DEFAULT_RESOURCE_CONFIG=your_resource_config` to your shell config (~/.bashrc, ~/.zshrc, etc.) GIGL_TEST_DEFAULT_RESOURCE_CONFIG?=${PWD}/deployment/configs/unittest_resource_config.yaml @@ -81,6 +83,14 @@ unit_test_py: clean_build_files_py type_check --resource_config_uri=${GIGL_TEST_DEFAULT_RESOURCE_CONFIG} \ --test_file_pattern=$(PY_TEST_FILES) \ +# Runs a single shard of the Python unit tests (no type checking). +# Usage: make unit_test_py_shard SHARD_INDEX=0 TOTAL_SHARDS=4 +unit_test_py_shard: clean_build_files_py + uv run python -m tests.unit.main \ + --env=test \ + --resource_config_uri=${GIGL_TEST_DEFAULT_RESOURCE_CONFIG} \ + --test_file_pattern=$(PY_TEST_FILES) \ + --shard_index=$(SHARD_INDEX) --total_shards=$(TOTAL_SHARDS) unit_test_scala: clean_build_files_scala ( cd scala; sbt test ) @@ -121,6 +131,14 @@ integration_test: --resource_config_uri=${GIGL_TEST_DEFAULT_RESOURCE_CONFIG} \ --test_file_pattern=$(PY_TEST_FILES) \ +# Runs a single shard of the integration tests. 
+# Usage: make integration_test_shard SHARD_INDEX=0 TOTAL_SHARDS=4 +integration_test_shard: clean_build_files_py + uv run python -m tests.integration.main \ + --env=test \ + --resource_config_uri=${GIGL_TEST_DEFAULT_RESOURCE_CONFIG} \ + --test_file_pattern=$(PY_TEST_FILES) \ + --shard_index=$(SHARD_INDEX) --total_shards=$(TOTAL_SHARDS) notebooks_test: RESOURCE_CONFIG_PATH=${GIGL_TEST_DEFAULT_RESOURCE_CONFIG} python -m tests.config_tests.notebooks_test diff --git a/gigl/common/utils/test_utils.py b/gigl/common/utils/test_utils.py index ff7c7dc14..ee2791cb0 100644 --- a/gigl/common/utils/test_utils.py +++ b/gigl/common/utils/test_utils.py @@ -1,9 +1,10 @@ import argparse +import hashlib import time import unittest from concurrent.futures import ProcessPoolExecutor from dataclasses import dataclass -from typing import Iterator, Tuple +from typing import Iterator from unittest import TestCase from gigl.common import LocalUri @@ -14,18 +15,22 @@ @dataclass(frozen=True) class TestArgs: - """Container for CLI arguements to Python tests. + """Container for CLI arguments to Python tests. Attributes: - test_file_pattern (str): Glob pattern for filtering which test files to run. + test_file_pattern: Glob pattern for filtering which test files to run. See doc comment in `parse_args` for more details. + shard_index: Zero-based index of the current shard. + total_shards: Total number of shards. 0 means no sharding. """ test_file_pattern: str + shard_index: int = 0 + total_shards: int = 0 def parse_args() -> TestArgs: - """Parses test-exclusive CLI arguements.""" + """Parses test-exclusive CLI arguments.""" parser = argparse.ArgumentParser() parser.add_argument( "-tf", @@ -43,13 +48,110 @@ def parse_args() -> TestArgs: ``` """, ) + parser.add_argument( + "--shard_index", + type=int, + default=0, + help="Zero-based index of the current shard (used with --total_shards).", + ) + parser.add_argument( + "--total_shards", + type=int, + default=0, + help="Total number of shards. 
0 or 1 means no sharding (run all tests).", + ) args, _ = parser.parse_known_args() - test_args = TestArgs(test_file_pattern=args.test_file_pattern) + test_args = TestArgs( + test_file_pattern=args.test_file_pattern, + shard_index=args.shard_index, + total_shards=args.total_shards, + ) logger.info(f"Test args: {test_args}") return test_args -def _run_individual_test(test: TestCase) -> Tuple[bool, int]: +def _get_shard_for_module( + module_name: str, + total_shards: int, + pinned_modules: tuple[str, ...], +) -> int: + """Returns the shard index a module should be assigned to. + + Pinned modules use their position in ``pinned_modules`` to determine the + shard (``index % total_shards``). All other modules fall back to + SHA-256 hashing. + + Args: + module_name: Fully-qualified module name. + total_shards: Total number of shards (must be >= 2). + pinned_modules: Ordered tuple of module names with deterministic + position-based shard assignment. + + Returns: + Zero-based shard index for the module. + """ + if module_name in pinned_modules: + return pinned_modules.index(module_name) % total_shards + hash_value = int(hashlib.sha256(module_name.encode()).hexdigest(), 16) + return hash_value % total_shards + + +def _filter_tests_by_shard( + suite: unittest.TestSuite, + shard_index: int, + total_shards: int, + pinned_modules: tuple[str, ...] = (), +) -> unittest.TestSuite: + """Filters a test suite to only include tests belonging to the given shard. + + Sharding is done at the file (module) level so that setUpClass/tearDownClass + are not split across shards. Pinned modules are assigned by their position + in ``pinned_modules`` (``index % total_shards``); all other modules use + SHA-256 hashing. + + Args: + suite: The full test suite discovered by unittest. + shard_index: Zero-based index of the current shard. + total_shards: Total number of shards. If <= 1, the suite is returned + unchanged. 
+ pinned_modules: Ordered tuple of module names with deterministic + position-based shard assignment. + + Returns: + A new TestSuite containing only the tests assigned to this shard. + """ + if total_shards <= 1: + return suite + + filtered = unittest.TestSuite() + for test_group in suite: + module_name = _get_test_group_module_name(test_group) + if ( + _get_shard_for_module(module_name, total_shards, pinned_modules) + == shard_index + ): + filtered.addTest(test_group) + return filtered + + +def _get_test_group_module_name(test_group: unittest.TestSuite | TestCase) -> str: + """Extracts the module name from a test group for shard assignment. + + Args: + test_group: A test suite or individual test case. + + Returns: + The module name string used for hashing. + """ + if isinstance(test_group, unittest.TestSuite): + # Recurse into nested suites to find the first actual test case + for item in test_group: + return _get_test_group_module_name(item) + # TestCase instance — use its module + return type(test_group).__module__ + + +def _run_individual_test(test: TestCase) -> tuple[bool, int]: # If we don't have any test cases, we skip running the test. # This reduces some noise in the logs. if test.countTestCases() == 0: @@ -64,15 +166,26 @@ def _run_individual_test(test: TestCase) -> Tuple[bool, int]: def run_tests( - start_dir: LocalUri, pattern: str, use_sequential_execution: bool = False + start_dir: LocalUri, + pattern: str, + use_sequential_execution: bool = False, + shard_index: int = 0, + total_shards: int = 0, + pinned_modules: tuple[str, ...] = (), ) -> bool: - """ + """Discovers and runs tests, optionally filtering by shard. + Args: - start_dir (LocalUri): Local Directory for running tests - pattern (str): file text pattern for running tests - use_sequential_execution (bool): Whether sequential exection should be used - Return: - bool: Whether all tests passed successfully + start_dir: Local directory for running tests. 
+ pattern: File text pattern for running tests. + use_sequential_execution: Whether sequential execution should be used. + shard_index: Zero-based index of the current shard. + total_shards: Total number of shards. 0 or 1 means no sharding. + pinned_modules: Ordered tuple of module names with deterministic + position-based shard assignment (``index % total_shards``). + + Returns: + Whether all tests passed successfully. """ start = time.perf_counter() @@ -83,6 +196,14 @@ def run_tests( pattern=pattern, ) + total_discovered: int = suite.countTestCases() + suite = _filter_tests_by_shard(suite, shard_index, total_shards, pinned_modules) + + if total_shards > 1: + logger.info( + f"Shard {shard_index}/{total_shards}: running {suite.countTestCases()}/{total_discovered} test cases" + ) + was_successful: bool total_num_test_cases: int = 0 @@ -92,7 +213,7 @@ def run_tests( total_num_test_cases = suite.countTestCases() else: with ProcessPoolExecutor() as executor: - was_successful_iter: Iterator[Tuple[bool, int]] = executor.map( + was_successful_iter: Iterator[tuple[bool, int]] = executor.map( _run_individual_test, suite._tests ) was_successful = True @@ -102,5 +223,5 @@ def run_tests( logger.info(f"Ran {total_num_test_cases}/{suite.countTestCases()} test cases") finish = time.perf_counter() - logger.info(f"It took {finish-start: .2f} second(s) to run tests") + logger.info(f"It took {finish - start: .2f} second(s) to run tests") return was_successful diff --git a/tests/integration/main.py b/tests/integration/main.py index fd1765afd..00a6638b3 100644 --- a/tests/integration/main.py +++ b/tests/integration/main.py @@ -1,4 +1,5 @@ import sys +from typing import Final import gigl.src.common.constants.local_fs as local_fs_constants from gigl.common import LocalUri @@ -6,8 +7,28 @@ from gigl.src.common.utils.metrics_service_provider import initialize_metrics from tests.test_assets.uri_constants import DEFAULT_NABLP_TASK_CONFIG_URI +# Slow test modules that must be spread across 
shards. Position in the tuple +# determines the shard: ``index % total_shards``. **Append-only** — never +# reorder existing entries, or every module's shard assignment will shift. +# +# Durations measured 2026-02-27 (unsharded CI run, 77.5 min total): +INTEGRATION_TEST_SHARD_PINNED_MODULES: Final[tuple[str, ...]] = ( + "tests.integration.distributed.distributed_dataset_test", # 14.5 min (18.7%) + "tests.integration.distributed.utils.networking_test", # 13.3 min (17.2%) + "tests.integration.distributed.graph_store.graph_store_integration_test", # 13.0 min (16.8%) + "tests.integration.pipeline.data_preprocessor.data_preprocessor_pipeline_test", # 11.7 min (15.1%) + "tests.integration.pipeline.subgraph_sampler.subgraph_sampler_test", # 8.8 min (11.4%) + "tests.integration.common.services.vertex_ai_test", # 6.5 min (8.4%) + "tests.integration.pipeline.split_generator.split_generator_pipeline_test", # 3.8 min (5.0%) + "tests.integration.pipeline.inferencer.inferencer_test", # 2.1 min (2.8%) +) -def run(pattern: str = "*_test.py") -> bool: + +def run( + pattern: str = "*_test.py", + shard_index: int = 0, + total_shards: int = 0, +) -> bool: initialize_metrics( task_config_uri=DEFAULT_NABLP_TASK_CONFIG_URI, service_name="integration_test" ) @@ -17,9 +38,17 @@ def run(pattern: str = "*_test.py") -> bool: ), pattern=pattern, use_sequential_execution=True, + shard_index=shard_index, + total_shards=total_shards, + pinned_modules=INTEGRATION_TEST_SHARD_PINNED_MODULES, ) if __name__ == "__main__": - was_successful: bool = run(pattern=parse_args().test_file_pattern) + test_args = parse_args() + was_successful: bool = run( + pattern=test_args.test_file_pattern, + shard_index=test_args.shard_index, + total_shards=test_args.total_shards, + ) sys.exit(not was_successful) diff --git a/tests/unit/common/utils/test_sharding_test.py b/tests/unit/common/utils/test_sharding_test.py new file mode 100644 index 000000000..5b741e4eb --- /dev/null +++ 
b/tests/unit/common/utils/test_sharding_test.py @@ -0,0 +1,338 @@ +import hashlib +import unittest + +import gigl.src.common.constants.local_fs as local_fs_constants +from gigl.common import LocalUri +from gigl.common.utils.test_utils import _filter_tests_by_shard, _get_shard_for_module +from tests.integration.main import INTEGRATION_TEST_SHARD_PINNED_MODULES +from tests.test_assets.test_case import TestCase +from tests.unit.main import UNIT_TEST_SHARD_PINNED_MODULES + + +def _extract_module_names(suite: unittest.TestSuite) -> list[str]: + """Extracts module names from a filtered test suite, preserving order. + + Assumes the suite has the two-level nesting produced by + ``_make_test_suite_with_modules``: outer suite → inner TestSuite per + module → individual TestCase(s). + + Args: + suite: A filtered test suite. + + Returns: + Ordered list of module name strings found in the suite. + """ + return [ + type(test_case).__module__ + for test_group in suite + if isinstance(test_group, unittest.TestSuite) + for test_case in test_group + ] + + +def _make_test_suite_with_modules(module_names: list[str]) -> unittest.TestSuite: + """Creates a test suite where each top-level group simulates a different module. + + Each module gets a dynamically created TestCase subclass with one test method, + mirroring the structure produced by ``unittest.TestLoader.discover()``. + + Args: + module_names: List of module name strings to simulate. + + Returns: + A TestSuite containing one nested TestSuite per module name. 
+ """ + outer_suite = unittest.TestSuite() + for module_name in module_names: + # Dynamically create a TestCase class with a unique module + test_class = type( + f"TestFor_{module_name.replace('.', '_')}", + (unittest.TestCase,), + { + "test_placeholder": lambda self: None, + "__module__": module_name, + }, + ) + inner_suite = unittest.TestSuite([test_class("test_placeholder")]) + outer_suite.addTest(inner_suite) + return outer_suite + + +class FilterTestsByShardTest(TestCase): + """Tests for the _filter_tests_by_shard function.""" + + MODULES: list[str] = [ + "tests.unit.module_a_test", + "tests.unit.module_b_test", + "tests.unit.module_c_test", + "tests.unit.module_d_test", + "tests.unit.module_e_test", + "tests.unit.module_f_test", + "tests.unit.module_g_test", + "tests.unit.module_h_test", + ] + + def test_no_sharding_when_total_shards_is_zero(self) -> None: + suite = _make_test_suite_with_modules(self.MODULES) + result = _filter_tests_by_shard(suite, shard_index=0, total_shards=0) + self.assertEqual(result.countTestCases(), suite.countTestCases()) + + def test_no_sharding_when_total_shards_is_one(self) -> None: + suite = _make_test_suite_with_modules(self.MODULES) + result = _filter_tests_by_shard(suite, shard_index=0, total_shards=1) + self.assertEqual(result.countTestCases(), suite.countTestCases()) + + def test_all_tests_covered_across_shards(self) -> None: + """Union of all shards must equal the full suite.""" + total_shards = 4 + all_test_counts: list[int] = [] + for shard_index in range(total_shards): + suite = _make_test_suite_with_modules(self.MODULES) + result = _filter_tests_by_shard(suite, shard_index, total_shards) + all_test_counts.append(result.countTestCases()) + + self.assertEqual( + sum(all_test_counts), + len(self.MODULES), + f"Total tests across shards ({sum(all_test_counts)}) != total modules ({len(self.MODULES)})", + ) + + def test_no_overlap_between_shards(self) -> None: + """Each module must appear in exactly one shard.""" + total_shards 
= 4 + seen_modules: set[str] = set() + for shard_index in range(total_shards): + suite = _make_test_suite_with_modules(self.MODULES) + result = _filter_tests_by_shard(suite, shard_index, total_shards) + for test_group in result: + assert isinstance(test_group, unittest.TestSuite) + for test_case in test_group: + module = type(test_case).__module__ + self.assertNotIn( + module, + seen_modules, + f"Module {module} appeared in multiple shards", + ) + seen_modules.add(module) + + def test_deterministic_assignment(self) -> None: + """Running the same shard twice must produce identical results.""" + total_shards = 3 + shard_index = 1 + suite1 = _make_test_suite_with_modules(self.MODULES) + result1 = _filter_tests_by_shard(suite1, shard_index, total_shards) + modules1 = _extract_module_names(result1) + + suite2 = _make_test_suite_with_modules(self.MODULES) + result2 = _filter_tests_by_shard(suite2, shard_index, total_shards) + modules2 = _extract_module_names(result2) + + self.assertEqual(modules1, modules2) + + def test_each_shard_gets_at_least_one_test_when_enough_modules(self) -> None: + """With enough modules, each shard should get at least one test.""" + total_shards = 3 + for shard_index in range(total_shards): + suite = _make_test_suite_with_modules(self.MODULES) + result = _filter_tests_by_shard(suite, shard_index, total_shards) + self.assertGreater( + result.countTestCases(), + 0, + f"Shard {shard_index} got no tests", + ) + + +class ShardPinningTest(TestCase): + """Tests for manual shard pinning via pinned_modules.""" + + PINNED: tuple[str, ...] 
= ( + "tests.unit.distributed.dist_ablp_neighborloader_test", + "tests.unit.distributed.distributed_dataset_test", + "tests.unit.distributed.distributed_neighborloader_test", + "tests.unit.distributed.distributed_partitioner_test", + "tests.unit.distributed.utils.networking_test", + ) + + UNPINNED: list[str] = [ + "tests.unit.module_a_test", + "tests.unit.module_b_test", + "tests.unit.module_c_test", + "tests.unit.module_d_test", + "tests.unit.module_e_test", + ] + + def test_pinned_modules_assigned_by_position(self) -> None: + """Pinned module at index i is assigned to shard i % total_shards.""" + total_shards = 4 + for index, module_name in enumerate(self.PINNED): + expected_shard = index % total_shards + actual_shard = _get_shard_for_module(module_name, total_shards, self.PINNED) + self.assertEqual( + actual_shard, + expected_shard, + f"Pinned module {module_name} (index {index}) expected shard " + f"{expected_shard}, got {actual_shard}", + ) + + def test_pinned_modules_use_all_shards_with_four_shards(self) -> None: + """With 5 pinned modules and 4 shards, every shard gets at least one.""" + total_shards = 4 + assigned_shards = { + _get_shard_for_module(m, total_shards, self.PINNED) for m in self.PINNED + } + self.assertEqual( + len(assigned_shards), + min(len(self.PINNED), total_shards), + f"Expected {min(len(self.PINNED), total_shards)} distinct shards, " + f"got {assigned_shards}", + ) + + def test_full_coverage_no_overlap_with_pinned_and_unpinned(self) -> None: + """All modules appear exactly once across all shards.""" + total_shards = 4 + all_modules = list(self.PINNED) + self.UNPINNED + + seen_modules: set[str] = set() + for shard_index in range(total_shards): + fresh_suite = _make_test_suite_with_modules(all_modules) + result = _filter_tests_by_shard( + fresh_suite, shard_index, total_shards, pinned_modules=self.PINNED + ) + for test_group in result: + assert isinstance(test_group, unittest.TestSuite) + for test_case in test_group: + module = 
type(test_case).__module__ + self.assertNotIn( + module, + seen_modules, + f"Module {module} appeared in multiple shards", + ) + seen_modules.add(module) + + self.assertEqual( + seen_modules, + set(all_modules), + "Not all modules were covered across shards", + ) + + def test_pinning_across_various_total_shards(self) -> None: + """Pinned modules land on expected shards for several shard counts.""" + for total_shards in (2, 3, 4, 5, 8): + for index, module_name in enumerate(self.PINNED): + expected_shard = index % total_shards + actual_shard = _get_shard_for_module( + module_name, total_shards, self.PINNED + ) + self.assertEqual( + actual_shard, + expected_shard, + f"total_shards={total_shards}: pinned module {module_name} " + f"(index {index}) expected shard {expected_shard}, got {actual_shard}", + ) + + def test_unpinned_modules_use_hash(self) -> None: + """Unpinned modules still use SHA-256 hashing, unaffected by pinned list.""" + total_shards = 4 + for module_name in self.UNPINNED: + expected = ( + int(hashlib.sha256(module_name.encode()).hexdigest(), 16) % total_shards + ) + actual = _get_shard_for_module(module_name, total_shards, self.PINNED) + self.assertEqual( + actual, + expected, + f"Unpinned module {module_name} should use hash-based assignment", + ) + + def test_real_unit_test_pinned_modules_cover_all_shards(self) -> None: + """The actual UNIT_TEST_SHARD_PINNED_MODULES cover every shard with 4 shards.""" + total_shards = 4 + assigned_shards = { + _get_shard_for_module(m, total_shards, UNIT_TEST_SHARD_PINNED_MODULES) + for m in UNIT_TEST_SHARD_PINNED_MODULES + } + self.assertEqual( + assigned_shards, + set(range(total_shards)), + f"Expected all shards 0..{total_shards - 1} covered, got {assigned_shards}", + ) + + def test_real_integration_test_pinned_modules_cover_all_shards(self) -> None: + """The actual INTEGRATION_TEST_SHARD_PINNED_MODULES cover every shard with 4 shards.""" + total_shards = 4 + assigned_shards = { + _get_shard_for_module( + m, 
total_shards, INTEGRATION_TEST_SHARD_PINNED_MODULES + ) + for m in INTEGRATION_TEST_SHARD_PINNED_MODULES + } + self.assertEqual( + assigned_shards, + set(range(total_shards)), + f"Expected all shards 0..{total_shards - 1} covered, got {assigned_shards}", + ) + + +def _collect_test_ids(suite: unittest.TestSuite) -> set[str]: + """Recursively collects all individual test case IDs from a suite. + + Args: + suite: A (possibly nested) test suite. + + Returns: + Set of fully-qualified test IDs (e.g. ``module.Class.test_method``). + """ + ids: set[str] = set() + for item in suite: + if isinstance(item, unittest.TestSuite): + ids.update(_collect_test_ids(item)) + else: + ids.add(item.id()) + return ids + + +class RealDiscoveryShardingTest(TestCase): + """Discovers real unit tests and verifies sharding preserves them all.""" + + TOTAL_SHARDS: int = 4 + + @classmethod + def setUpClass(cls) -> None: + start_dir = LocalUri.join( + local_fs_constants.get_project_root_directory(), "tests", "unit" + ) + cls.start_dir = start_dir + full_suite = unittest.TestLoader().discover( + start_dir=start_dir.uri, pattern="*_test.py" + ) + cls.unsharded_test_ids = _collect_test_ids(full_suite) + + def test_sharded_tests_equal_unsharded(self) -> None: + """Union of test IDs across all shards equals the full unsharded set.""" + sharded_test_ids: set[str] = set() + for shard_index in range(self.TOTAL_SHARDS): + suite = unittest.TestLoader().discover( + start_dir=self.start_dir.uri, pattern="*_test.py" + ) + filtered = _filter_tests_by_shard( + suite, + shard_index, + self.TOTAL_SHARDS, + UNIT_TEST_SHARD_PINNED_MODULES, + ) + shard_ids = _collect_test_ids(filtered) + overlap = sharded_test_ids & shard_ids + self.assertEqual( + overlap, + set(), + f"Shard {shard_index} overlaps with previous shards: {overlap}", + ) + sharded_test_ids.update(shard_ids) + + self.assertEqual( + sharded_test_ids, + self.unsharded_test_ids, + f"Test ID mismatch.\n" + f" Only in sharded: {sharded_test_ids - 
self.unsharded_test_ids}\n" + f" Only in unsharded: {self.unsharded_test_ids - sharded_test_ids}", + ) diff --git a/tests/unit/main.py b/tests/unit/main.py index 83b3b75d4..28809710c 100644 --- a/tests/unit/main.py +++ b/tests/unit/main.py @@ -1,20 +1,46 @@ import sys +from typing import Final import gigl.src.common.constants.local_fs as local_fs_constants from gigl.common import LocalUri from gigl.common.utils.test_utils import parse_args, run_tests +# Slow test modules that must be spread across shards. Position in the tuple +# determines the shard: ``index % total_shards``. **Append-only** — never +# reorder existing entries, or every module's shard assignment will shift. +# +# Durations measured 2026-02-27 (unsharded CI run, 61.7 min total): +UNIT_TEST_SHARD_PINNED_MODULES: Final[tuple[str, ...]] = ( + "tests.unit.distributed.dist_ablp_neighborloader_test", # 24.7 min (40.0%) + "tests.unit.distributed.distributed_dataset_test", # 10.7 min (17.4%) + "tests.unit.distributed.distributed_neighborloader_test", # 9.6 min (15.5%) + "tests.unit.distributed.distributed_partitioner_test", # 6.5 min (10.5%) + "tests.unit.distributed.utils.networking_test", # 2.7 min (4.4%) +) -def run(pattern: str = "*_test.py") -> bool: + +def run( + pattern: str = "*_test.py", + shard_index: int = 0, + total_shards: int = 0, +) -> bool: return run_tests( start_dir=LocalUri.join( local_fs_constants.get_project_root_directory(), "tests", "unit" ), pattern=pattern, use_sequential_execution=True, + shard_index=shard_index, + total_shards=total_shards, + pinned_modules=UNIT_TEST_SHARD_PINNED_MODULES, ) if __name__ == "__main__": - was_successful: bool = run(pattern=parse_args().test_file_pattern) + test_args = parse_args() + was_successful: bool = run( + pattern=test_args.test_file_pattern, + shard_index=test_args.shard_index, + total_shards=test_args.total_shards, + ) sys.exit(not was_successful)