
Adding multisample feature along with testcases#740

Open
VijayVignesh1 wants to merge 29 commits into Lightning-AI:main from VijayVignesh1:feature/add_multisample_support

Conversation

Contributor

@VijayVignesh1 VijayVignesh1 commented Oct 24, 2025

Before submitting
  • Was this discussed/agreed via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #317

PR review

Added support for multi-sample items.

Added a `sample_count` parameter that, given a single transform function, creates a batch of sub-samples for each sample.

Note:
Multi-sample behavior applies only when the transform is passed to the
StreamingDataset constructor (i.e., via the `transform` argument),
not when subclassing StreamingDataset and overriding `__init__`.

Sample code:

    def transform_fn_sq(x, sample_idx, *args, **kwargs):
        """A simple transform function that doubles the input."""
        return x * sample_idx

    dataset = StreamingDataset(
        data_dir,
        cache_dir=str(cache_dir),
        shuffle=False,
        transform=[transform_fn_sq],
        sample_count=3,
    )
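The multi-sample expansion above can be modeled as simple index inflation: with `sample_count=3`, a dataset of N base items exposes 3·N virtual samples, and each virtual index maps to a (base index, sub-sample index) pair fed to the transform. A standalone sketch in plain Python (no litdata; `expand_index` is a hypothetical helper, and litdata's actual index ordering may differ):

```python
def expand_index(virtual_idx, sample_count):
    """Map a virtual sample index to (base item index, sub-sample index)."""
    return divmod(virtual_idx, sample_count)

def transform_fn_sq(x, sample_idx):
    """Scale the base item by its sub-sample index (mirrors the example above)."""
    return x * sample_idx

base_data = [10, 20]   # two base items
sample_count = 3       # each base item yields 3 sub-samples

virtual = []
for i in range(len(base_data) * sample_count):
    base_idx, sample_idx = expand_index(i, sample_count)
    virtual.append(transform_fn_sq(base_data[base_idx], sample_idx))

print(virtual)  # [0, 10, 20, 0, 20, 40]
```

Note that `len` of the streaming dataset would then report the inflated count, which is what makes the resume math discussed later in this thread subtle.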

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@VijayVignesh1 VijayVignesh1 force-pushed the feature/add_multisample_support branch from 1b01b6f to 6a77302 on October 24, 2025 at 20:12
@VijayVignesh1
Contributor Author

@tchaton @deependujha @bhimrazy Can you verify the approach once? I can then make changes to the README.

codecov Bot commented Oct 29, 2025

Codecov Report

❌ Patch coverage is 86.48649% with 5 lines in your changes missing coverage. Please review.
✅ Project coverage is 81%. Comparing base (90bd404) to head (7f95d5d).
⚠️ Report is 1 commit behind head on main.

Additional details and impacted files
@@         Coverage Diff         @@
##           main   #740   +/-   ##
===================================
- Coverage    81%    81%   -0%     
===================================
  Files        54     54           
  Lines      7613   7642   +29     
===================================
+ Hits       6140   6163   +23     
- Misses     1473   1479    +6     

@VijayVignesh1 VijayVignesh1 marked this pull request as ready for review November 3, 2025 21:19
@bhimrazy bhimrazy marked this pull request as draft January 8, 2026 10:16
@bhimrazy
Collaborator

Closing this PR due to inactivity. Please feel free to reopen or recreate it whenever convenient.
A clearer path forward can also be worked out through further discussion on the issue.

@bhimrazy bhimrazy closed this Mar 19, 2026
@VijayVignesh1
Contributor Author

Hi @bhimrazy,
I believe the required feature for this PR has already been completed. Would it be possible to reopen it so we can continue the review and move it forward? If there are any outstanding concerns, I’m happy to address them.

@deependujha deependujha reopened this Mar 23, 2026
Contributor Author

VijayVignesh1 commented Mar 24, 2026

@deependujha The pipeline is failing on test_wav_deserialization in VideoSerializer, even though my PR didn’t touch this file. Do you know why it might be failing?

UPDATE: The latest TorchVision release (v0.26) deprecates read_video. Our checks fail because CI picks up this newest version of TorchVision.
More details: https://docs.pytorch.org/vision/0.25/io.html#video-deprecated
I have pinned torchvision to versions below 0.26 for now.
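For illustration, the pin amounts to an upper bound like the following in the requirements file (the exact lower bound, if any, is up to the project; only the `<0.26` ceiling matters here):

```text
# keep read_video available; torchvision 0.26 deprecates it
torchvision <0.26
```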

@VijayVignesh1 VijayVignesh1 marked this pull request as ready for review March 24, 2026 18:56
@deependujha deependujha requested a review from Copilot March 24, 2026 20:02
Contributor

Copilot AI left a comment


Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 7 comments.

Comments suppressed due to low confidence (1)

src/litdata/streaming/dataset.py:499

  • Resume logic appears inconsistent with the new inflated indexing. num_samples_yielded counts yielded samples (now multiplied by sample_count), but _replay_chunks_sampling subtracts raw interval sizes (not scaled by sample_count), so chunks_index / indexes will be wrong when sample_count>1. This will cause _resume() to restart from the wrong chunk/offset (repeats or skips). The replay math needs to incorporate sample_count (and ideally persist/validate sample_count in the state_dict).
        # TODO: Implement elastic sampling where the number of workers, ranks can change.
        num_samples_yielded = self._state_dict["num_samples_yielded"]

        worker_start = self.distributed_env.global_rank * num_workers
        worker_end = worker_start + num_workers

        # replay sampling from each worker / chunks using the batch size
        indexes = _replay_sampling(num_samples_yielded, batch_size, num_workers)
        chunks_index, indexes = _replay_chunks_sampling(
            workers_intervals={i: workers_intervals[j] for i, j in enumerate(range(worker_start, worker_end))},
            indexes=indexes,
        )

        # select the chunks and intervals associated to this worker
        worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank
        worker_local_rank = self.worker_env.rank

        self.num_chunks = len(workers_intervals[worker_rank])
        self.worker_next_chunk_index = chunks_index[worker_local_rank]
        self.worker_chunks = workers_chunks[worker_rank]
        self.worker_intervals = workers_intervals[worker_rank]

        if self.worker_next_chunk_index >= self.num_chunks:
            # This can happen when interrupting and resuming after some but not all workers are done.
            # Proceeding would result in an indexing error when attempting to access the next chunk.
            # To prevent this we exit early and let the worker raise a StopIteration in __next__.
            return

        # replay the indexes for the current chunks
        interval = self.worker_intervals[self.worker_next_chunk_index]

        # multiply the interval by the sample_count for multisample case
        current_indexes = np.arange(interval[1] * self.sample_count, interval[2] * self.sample_count)

        # re-shuffle the indexes
        current_indexes = self.shuffler(
            current_indexes, self.num_chunks, self.current_epoch, self.worker_next_chunk_index
        )

        # skip any indexes already consumed
        current_indexes = current_indexes[indexes[worker_local_rank] :]
        self.upcoming_indexes = current_indexes

        self.global_index = indexes[worker_local_rank]
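The inconsistency flagged above can be reproduced with a simplified, standalone model of the replay math (hypothetical `replay_chunks` helper; litdata's actual `_replay_chunks_sampling` has a different signature and interval format). The key change is scaling each interval's size by `sample_count` before subtracting it from the yielded-sample counter:

```python
def replay_chunks(intervals, samples_yielded, sample_count):
    """Walk the chunks, subtracting each chunk's *virtual* size
    (base size * sample_count) until the remainder fits inside one chunk.
    Returns (chunk_index, offset_within_chunk) in virtual samples."""
    chunk_index = 0
    remaining = samples_yielded
    for start, end in intervals:
        virtual_size = (end - start) * sample_count  # scale by sample_count
        if remaining < virtual_size:
            break
        remaining -= virtual_size
        chunk_index += 1
    return chunk_index, remaining

# Two chunks of 4 base items each; with sample_count=3 each chunk holds
# 12 virtual samples. After yielding 14 samples, resume must land in
# chunk 1 at virtual offset 2.
print(replay_chunks([(0, 4), (4, 8)], 14, sample_count=3))  # (1, 2)
```

Without the scaling (i.e., subtracting raw `end - start`), the same call would skip far too many chunks, which is the repeat/skip bug described in the review comment.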


deependujha and others added 7 commits April 3, 2026 19:19
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@deependujha
Collaborator

Doesn't support resuming with multisample. Will be good to do it in the subsequent PR.

Also, the name sample_count seems quite plain; what about num_variants, or something better? cc: @bhimrazy

Collaborator

@deependujha deependujha left a comment


Thanks for the cool work. 🥳

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 8 comments.

Comments suppressed due to low confidence (1)

src/litdata/streaming/dataset.py:505

  • When sample_count > 1, resume logic will miscompute which chunk to resume from because _replay_chunks_sampling(...) subtracts chunk sizes in base-item units (interval[2]-interval[1]), while num_samples_yielded is counted in yielded (virtual) samples. This can cause resuming in the wrong chunk / offset. The replay logic needs to account for sample_count (e.g., treat each interval size as (interval[2]-interval[1]) * sample_count).
        # replay the indexes for the current chunks
        interval = self.worker_intervals[self.worker_next_chunk_index]

        # multiply the interval by the sample_count for multisample case
        current_indexes = np.arange(interval[1] * self.sample_count, interval[2] * self.sample_count)

        # re-shuffle the indexes
        current_indexes = self.shuffler(
            current_indexes, self.num_chunks, self.current_epoch, self.worker_next_chunk_index
        )


deependujha and others added 4 commits April 3, 2026 19:47
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@deependujha
Collaborator

requires #806



Successfully merging this pull request may close this issue:

Add support for multi sample item in optimize and yielding from the `__getitem__` of the StreamingDataset

4 participants