Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/modelinfo/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import json
import math
import os
import sys
from typing import Sequence
Expand Down Expand Up @@ -41,6 +42,13 @@ def _positive_int(value: str) -> int:
return ivalue


def _positive_float(value: str) -> float:
fvalue = float(value)
if not math.isfinite(fvalue) or fvalue <= 0:
raise argparse.ArgumentTypeError("timeout must be a finite number greater than 0")
return fvalue
Comment thread
coderabbitai[bot] marked this conversation as resolved.


def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(
prog="modelinfo",
Expand Down Expand Up @@ -82,6 +90,12 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
action="store_true",
help="Deep dive: Fetch all remote tensor shards to display the exact tensor size breakdown.",
)
parser.add_argument(
"--timeout",
type=_positive_float,
default=10.0,
help="Network timeout in seconds for remote Hugging Face fetches.",
)
parser.add_argument(
"--topology",
type=str,
Expand Down Expand Up @@ -122,6 +136,7 @@ def analyze_model(
gpu_count: int = 1,
batch_size: int = 1,
fetch_tensors: bool = False,
timeout: float = 10.0,
topology: str = "pcie4",
strategy: str = "tp",
is_vllm: bool = False,
Expand All @@ -136,7 +151,9 @@ def analyze_model(

if not os.path.exists(file_path) and not file_path_lower.endswith((".safetensors", ".gguf", ".pt", ".bin", ".index.json")):
from modelinfo.parsers.huggingface import fetch_huggingface_repo
tensors, config, format_name, disk_size = fetch_huggingface_repo(file_path, fetch_tensors=fetch_tensors)
tensors, config, format_name, disk_size = fetch_huggingface_repo(
file_path, fetch_tensors=fetch_tensors, timeout=timeout
)
elif file_path_lower.endswith(".safetensors") or file_path_lower.endswith(".index.json"):
tensors = parse_safetensors_header(file_path)
format_name = "SafeTensors"
Expand Down Expand Up @@ -240,6 +257,7 @@ def main(argv: Sequence[str] | None = None) -> int:
gpu_count=gpu_count,
batch_size=args.batch_size,
fetch_tensors=args.tensors,
timeout=args.timeout,
topology=args.topology,
strategy=args.strategy,
is_vllm=args.vllm,
Expand All @@ -259,6 +277,7 @@ def main(argv: Sequence[str] | None = None) -> int:
gpu_count=gpu_count,
batch_size=args.batch_size,
fetch_tensors=args.tensors,
timeout=args.timeout,
topology=args.topology,
strategy=args.strategy,
is_vllm=args.vllm,
Expand Down
31 changes: 18 additions & 13 deletions src/modelinfo/parsers/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@ def _get_hf_token() -> str | None:

return None

def _make_request(url: str, headers: Dict[str, str] = None, limit: int | None = None) -> bytes:
def _make_request(
url: str,
headers: Dict[str, str] = None,
limit: int | None = None,
timeout: float = 10.0,
) -> bytes:
if headers is None:
headers = {}

Expand All @@ -57,7 +62,7 @@ def _make_request(url: str, headers: Dict[str, str] = None, limit: int | None =

req = urllib.request.Request(url, headers=headers)
try:
with urllib.request.urlopen(req, timeout=10) as response:
with urllib.request.urlopen(req, timeout=timeout) as response:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential user input in HTTP request may allow SSRF attack - medium severity
If an attacker can control the URL input leading into this HTTP request, the attack might be able to perform an SSRF attack. This kind of attack is even more dangerous if the application returns the response of the request to the user. It could allow them to retrieve information from higher privileged services within the network (such as the metadata service, which is commonly available in cloud services, and could allow them to retrieve credentials).

Show fix

Remediation: If possible, only allow requests to allowlisting domains. If not, consult the article linked above to learn about other mitigating techniques such as disabling redirects, blocking private IPs and making sure private services have internal authentication. If you return data coming from the request to the user, validate the data before returning it to make sure you don't return random data.

Reply @AikidoSec ignore: [REASON] to ignore this issue.
More info

if limit is not None:
return response.read(limit)
return response.read()
Expand All @@ -68,16 +73,16 @@ def _make_request(url: str, headers: Dict[str, str] = None, limit: int | None =
raise FileNotFoundError(f"Could not find repository or file on Hugging Face (404 Not Found): {url}")
raise

def _fetch_safetensors_header(repo_id: str, filename: str) -> Dict[str, Any]:
def _fetch_safetensors_header(repo_id: str, filename: str, timeout: float = 10.0) -> Dict[str, Any]:
url = f"{_get_hf_endpoint()}/{repo_id}/resolve/main/{filename}"

# 1. Fetch the first 500KB in a single roundtrip
headers = {"Range": "bytes=0-500000"}
try:
chunk = _make_request(url, headers=headers, limit=500000)
chunk = _make_request(url, headers=headers, limit=500000, timeout=timeout)
except urllib.error.HTTPError as e:
if e.code == 416: # Range Not Satisfiable (file is smaller than 500KB)
chunk = _make_request(url, limit=500000)
chunk = _make_request(url, limit=500000, timeout=timeout)
else:
raise

Expand All @@ -92,18 +97,18 @@ def _fetch_safetensors_header(repo_id: str, filename: str) -> Dict[str, Any]:
else:
# 3. Double-roundtrip only if the header is massive (>500KB)
headers = {"Range": f"bytes=8-{8+header_size-1}"}
json_bytes = _make_request(url, headers=headers, limit=header_size)
json_bytes = _make_request(url, headers=headers, limit=header_size, timeout=timeout)

return json.loads(json_bytes)

def fetch_huggingface_repo(repo_id: str, fetch_tensors: bool = False) -> Tuple[Dict[str, Any], Dict[str, Any] | None, str, float]:
def fetch_huggingface_repo(repo_id: str, fetch_tensors: bool = False, timeout: float = 10.0) -> Tuple[Dict[str, Any], Dict[str, Any] | None, str, float]:
"""
Fetches the metadata directly from the Hugging Face Hub over the network.
Returns: (tensors, config, format_name, disk_size)
"""
api_url = f"{_get_hf_endpoint()}/api/models/{repo_id}"
try:
api_data = json.loads(_make_request(api_url).decode("utf-8"))
api_data = json.loads(_make_request(api_url, timeout=timeout).decode("utf-8"))
except urllib.error.HTTPError as e:
if e.code == 401:
raise PermissionError(f"Gated/Private Model (401 Unauthorized). Set the HF_TOKEN environment variable to access {repo_id}")
Expand All @@ -117,15 +122,15 @@ def fetch_huggingface_repo(repo_id: str, fetch_tensors: bool = False) -> Tuple[D
config = None
if "config.json" in filenames:
config_url = f"{_get_hf_endpoint()}/{repo_id}/resolve/main/config.json"
config = json.loads(_make_request(config_url).decode("utf-8"))
config = json.loads(_make_request(config_url, timeout=timeout).decode("utf-8"))

tensors = {}
total_size = 0.0

if "model.safetensors.index.json" in filenames:
# Sharded SafeTensors
index_url = f"{_get_hf_endpoint()}/{repo_id}/resolve/main/model.safetensors.index.json"
index_data = json.loads(_make_request(index_url).decode("utf-8"))
index_data = json.loads(_make_request(index_url, timeout=timeout).decode("utf-8"))

weight_map = index_data.get("weight_map", {})
unique_shards = list(set(weight_map.values()))
Expand All @@ -146,7 +151,7 @@ def fetch_huggingface_repo(repo_id: str, fetch_tensors: bool = False) -> Tuple[D
}
else:
def fetch_shard(shard: str):
return shard, _fetch_safetensors_header(repo_id, shard)
return shard, _fetch_safetensors_header(repo_id, shard, timeout=timeout)

with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, min(8, len(unique_shards)))) as executor:
future_to_shard = {executor.submit(fetch_shard, shard): shard for shard in unique_shards}
Expand All @@ -172,12 +177,12 @@ def fetch_shard(shard: str):
if token:
req.add_header("Authorization", f"Bearer {token}")
try:
with urllib.request.urlopen(req) as response:
with urllib.request.urlopen(req, timeout=timeout) as response:
total_size = int(response.headers.get("Content-Length", 0))
except Exception:
pass

header = _fetch_safetensors_header(repo_id, "model.safetensors")
header = _fetch_safetensors_header(repo_id, "model.safetensors", timeout=timeout)
tensors = header

format_name = "SafeTensors"
Expand Down
100 changes: 100 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,46 @@ def test_batch_size_flag_rejects_negative():
assert exc_info.value.code == 2


def test_timeout_flag_defaults_to_ten_seconds():
args = parse_args(["model.gguf"])

assert args.timeout == 10.0


def test_timeout_flag_accepts_float():
args = parse_args(["--timeout", "30.5", "model.gguf"])

assert args.timeout == 30.5


def test_timeout_flag_rejects_zero():
with pytest.raises(SystemExit) as exc_info:
parse_args(["--timeout", "0", "model.gguf"])

assert exc_info.value.code == 2


def test_timeout_flag_rejects_negative():
with pytest.raises(SystemExit) as exc_info:
parse_args(["--timeout", "-1", "model.gguf"])

assert exc_info.value.code == 2


def test_timeout_flag_rejects_nan():
with pytest.raises(SystemExit) as exc_info:
parse_args(["--timeout", "nan", "model.gguf"])

assert exc_info.value.code == 2


def test_timeout_flag_rejects_inf():
with pytest.raises(SystemExit) as exc_info:
parse_args(["--timeout", "inf", "model.gguf"])

assert exc_info.value.code == 2


def test_analyze_model_passes_batch_size_to_footprint(monkeypatch, tmp_path):
model_path = tmp_path / "model.gguf"
model_path.write_bytes(b"mock")
Expand Down Expand Up @@ -77,3 +117,63 @@ def fake_calculate_footprint(tensors, *, context_length, batch_size, **kwargs):

assert captured == {"batch_size": 4, "context_length": 128}
assert info["footprint"]["kv_cache_bytes"] == 4.0


def test_analyze_model_passes_timeout_to_huggingface(monkeypatch):
captured = {}

def fake_exists(path):
return False

def fake_fetch(repo_id, *, fetch_tensors, timeout):
captured["repo_id"] = repo_id
captured["fetch_tensors"] = fetch_tensors
captured["timeout"] = timeout
return (
{
"model.layers.0.self_attn.k_proj.weight": {
"shape": [1, 1],
"dtype": "F16",
}
},
None,
"SafeTensors",
7.0,
)

def fake_calculate_footprint(tensors, *, context_length, batch_size, **kwargs):
return {
"total_params": 1,
"base_memory_bytes": 2.0,
"kv_cache_bytes": 1.0,
"overhead_bytes": 0.0,
"total_memory_bytes": 3.0,
"num_layers": 1,
"kv_dim": 1,
"primary_dtype": "F16",
"kv_is_estimate": False,
"penalty_percentage": 0.0,
"vllm_metrics": {},
}

from modelinfo.parsers import huggingface

monkeypatch.setattr(cli.os.path, "exists", fake_exists)
monkeypatch.setattr(huggingface, "fetch_huggingface_repo", fake_fetch)
monkeypatch.setattr(cli, "calculate_footprint", fake_calculate_footprint)
monkeypatch.setattr(
cli, "identify_architecture_name", lambda tensors, num_layers, config: "Mock"
)

cli.analyze_model(
"org/model",
context_override=128,
fetch_tensors=True,
timeout=22.5,
)

assert captured == {
"repo_id": "org/model",
"fetch_tensors": True,
"timeout": 22.5,
}
Loading