-
Notifications
You must be signed in to change notification settings - Fork 5
feat: add Hugging Face fetch timeout flag #33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
b734098
eeacf57
bd6ce16
acd0e9e
539fe49
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,7 +47,12 @@ def _get_hf_token() -> str | None: | |
|
|
||
| return None | ||
|
|
||
| def _make_request(url: str, headers: Dict[str, str] = None, limit: int | None = None) -> bytes: | ||
| def _make_request( | ||
| url: str, | ||
| headers: Dict[str, str] = None, | ||
| limit: int | None = None, | ||
| timeout: float = 10.0, | ||
| ) -> bytes: | ||
| if headers is None: | ||
| headers = {} | ||
|
|
||
|
|
@@ -57,7 +62,7 @@ def _make_request(url: str, headers: Dict[str, str] = None, limit: int | None = | |
|
|
||
| req = urllib.request.Request(url, headers=headers) | ||
| try: | ||
| with urllib.request.urlopen(req, timeout=10) as response: | ||
| with urllib.request.urlopen(req, timeout=timeout) as response: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Potential user input in HTTP request may allow SSRF attack - medium severity Show fixRemediation: If possible, only allow requests to allowlisting domains. If not, consult the article linked above to learn about other mitigating techniques such as disabling redirects, blocking private IPs and making sure private services have internal authentication. If you return data coming from the request to the user, validate the data before returning it to make sure you don't return random data. Reply |
||
| if limit is not None: | ||
| return response.read(limit) | ||
| return response.read() | ||
|
|
@@ -68,16 +73,16 @@ def _make_request(url: str, headers: Dict[str, str] = None, limit: int | None = | |
| raise FileNotFoundError(f"Could not find repository or file on Hugging Face (404 Not Found): {url}") | ||
| raise | ||
|
|
||
| def _fetch_safetensors_header(repo_id: str, filename: str) -> Dict[str, Any]: | ||
| def _fetch_safetensors_header(repo_id: str, filename: str, timeout: float = 10.0) -> Dict[str, Any]: | ||
| url = f"{_get_hf_endpoint()}/{repo_id}/resolve/main/{filename}" | ||
|
|
||
| # 1. Fetch the first 500KB in a single roundtrip | ||
| headers = {"Range": "bytes=0-500000"} | ||
| try: | ||
| chunk = _make_request(url, headers=headers, limit=500000) | ||
| chunk = _make_request(url, headers=headers, limit=500000, timeout=timeout) | ||
| except urllib.error.HTTPError as e: | ||
| if e.code == 416: # Range Not Satisfiable (file is smaller than 500KB) | ||
| chunk = _make_request(url, limit=500000) | ||
| chunk = _make_request(url, limit=500000, timeout=timeout) | ||
| else: | ||
| raise | ||
|
|
||
|
|
@@ -92,18 +97,18 @@ def _fetch_safetensors_header(repo_id: str, filename: str) -> Dict[str, Any]: | |
| else: | ||
| # 3. Double-roundtrip only if the header is massive (>500KB) | ||
| headers = {"Range": f"bytes=8-{8+header_size-1}"} | ||
| json_bytes = _make_request(url, headers=headers, limit=header_size) | ||
| json_bytes = _make_request(url, headers=headers, limit=header_size, timeout=timeout) | ||
|
|
||
| return json.loads(json_bytes) | ||
|
|
||
| def fetch_huggingface_repo(repo_id: str, fetch_tensors: bool = False) -> Tuple[Dict[str, Any], Dict[str, Any] | None, str, float]: | ||
| def fetch_huggingface_repo(repo_id: str, fetch_tensors: bool = False, timeout: float = 10.0) -> Tuple[Dict[str, Any], Dict[str, Any] | None, str, float]: | ||
| """ | ||
| Fetches the metadata directly from the Hugging Face Hub over the network. | ||
| Returns: (tensors, config, format_name, disk_size) | ||
| """ | ||
| api_url = f"{_get_hf_endpoint()}/api/models/{repo_id}" | ||
| try: | ||
| api_data = json.loads(_make_request(api_url).decode("utf-8")) | ||
| api_data = json.loads(_make_request(api_url, timeout=timeout).decode("utf-8")) | ||
| except urllib.error.HTTPError as e: | ||
| if e.code == 401: | ||
| raise PermissionError(f"Gated/Private Model (401 Unauthorized). Set the HF_TOKEN environment variable to access {repo_id}") | ||
|
|
@@ -117,15 +122,15 @@ def fetch_huggingface_repo(repo_id: str, fetch_tensors: bool = False) -> Tuple[D | |
| config = None | ||
| if "config.json" in filenames: | ||
| config_url = f"{_get_hf_endpoint()}/{repo_id}/resolve/main/config.json" | ||
| config = json.loads(_make_request(config_url).decode("utf-8")) | ||
| config = json.loads(_make_request(config_url, timeout=timeout).decode("utf-8")) | ||
|
|
||
| tensors = {} | ||
| total_size = 0.0 | ||
|
|
||
| if "model.safetensors.index.json" in filenames: | ||
| # Sharded SafeTensors | ||
| index_url = f"{_get_hf_endpoint()}/{repo_id}/resolve/main/model.safetensors.index.json" | ||
| index_data = json.loads(_make_request(index_url).decode("utf-8")) | ||
| index_data = json.loads(_make_request(index_url, timeout=timeout).decode("utf-8")) | ||
|
|
||
| weight_map = index_data.get("weight_map", {}) | ||
| unique_shards = list(set(weight_map.values())) | ||
|
|
@@ -146,7 +151,7 @@ def fetch_huggingface_repo(repo_id: str, fetch_tensors: bool = False) -> Tuple[D | |
| } | ||
| else: | ||
| def fetch_shard(shard: str): | ||
| return shard, _fetch_safetensors_header(repo_id, shard) | ||
| return shard, _fetch_safetensors_header(repo_id, shard, timeout=timeout) | ||
|
|
||
| with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, min(8, len(unique_shards)))) as executor: | ||
| future_to_shard = {executor.submit(fetch_shard, shard): shard for shard in unique_shards} | ||
|
|
@@ -172,12 +177,12 @@ def fetch_shard(shard: str): | |
| if token: | ||
| req.add_header("Authorization", f"Bearer {token}") | ||
| try: | ||
| with urllib.request.urlopen(req) as response: | ||
| with urllib.request.urlopen(req, timeout=timeout) as response: | ||
| total_size = int(response.headers.get("Content-Length", 0)) | ||
| except Exception: | ||
| pass | ||
|
|
||
| header = _fetch_safetensors_header(repo_id, "model.safetensors") | ||
| header = _fetch_safetensors_header(repo_id, "model.safetensors", timeout=timeout) | ||
| tensors = header | ||
|
|
||
| format_name = "SafeTensors" | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.