diff --git a/src/zarr/storage/_zip.py b/src/zarr/storage/_zip.py index 72bf9e335a..bc6677043f 100644 --- a/src/zarr/storage/_zip.py +++ b/src/zarr/storage/_zip.py @@ -103,6 +103,10 @@ def _sync_open(self) -> None: self._is_open = True + def _sync_ensure_open(self) -> None: + if not self._is_open: + self._sync_open() + async def _open(self) -> None: self._sync_open() @@ -120,12 +124,16 @@ def __setstate__(self, state: dict[str, Any]) -> None: def close(self) -> None: # docstring inherited + self._sync_ensure_open() + super().close() with self._lock: self._zf.close() async def clear(self) -> None: # docstring inherited + self._sync_ensure_open() + with self._lock: self._check_writable() self._zf.close() @@ -149,8 +157,7 @@ def _get( prototype: BufferPrototype, byte_range: ByteRequest | None = None, ) -> Buffer | None: - if not self._is_open: - self._sync_open() + self._sync_ensure_open() # docstring inherited try: with self._zf.open(key) as f: # will raise KeyError @@ -188,6 +195,7 @@ async def get_partial_values( key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited + self._sync_ensure_open() out = [] with self._lock: for key, byte_range in key_ranges: @@ -195,8 +203,7 @@ async def get_partial_values( return out def _set(self, key: str, value: Buffer) -> None: - if not self._is_open: - self._sync_open() + self._sync_ensure_open() # generally, this should be called inside a lock keyinfo = zipfile.ZipInfo(filename=key, date_time=time.localtime(time.time())[:6]) keyinfo.compress_type = self.compression @@ -210,8 +217,7 @@ def _set(self, key: str, value: Buffer) -> None: async def set(self, key: str, value: Buffer) -> None: # docstring inherited self._check_writable() - if not self._is_open: - self._sync_open() + self._sync_ensure_open() assert isinstance(key, str) if not isinstance(value, Buffer): raise TypeError( @@ -222,6 +228,8 @@ async def set(self, key: str, value: Buffer) -> None: async def set_if_not_exists(self, key: str, value: Buffer) -> None: self._check_writable() + self._sync_ensure_open() + with self._lock: members = self._zf.namelist() if key not in members: @@ -245,6 +253,8 @@ async def delete(self, key: str) -> None: async def exists(self, key: str) -> bool: # docstring inherited + self._sync_ensure_open() + with self._lock: try: self._zf.getinfo(key) @@ -255,6 +265,8 @@ async def exists(self, key: str) -> bool: async def list(self) -> AsyncIterator[str]: # docstring inherited + self._sync_ensure_open() + with self._lock: for key in self._zf.namelist(): yield key diff --git a/tests/test_store/test_zip.py b/tests/test_store/test_zip.py index 744ee82945..7bfc4cfb0c 100644 --- a/tests/test_store/test_zip.py +++ b/tests/test_store/test_zip.py @@ -152,3 +152,17 @@ async def test_move(self, tmp_path: Path) -> None: assert destination.exists() assert not origin.exists() assert np.array_equal(array[...], np.arange(10)) + + async def test_lock_present(self, store: ZipStore) -> None: + buf = cpu.Buffer.from_bytes(b"bar") + await store.set("foo", buf) + await store.set_if_not_exists("foo", buf) + await store.exists("foo") + await store.get("foo", default_buffer_prototype()) + + async for _ in store.list(): + pass + + await store.clear() + + store.close()