diff --git a/graphgen/models/llm/local/vllm_wrapper.py b/graphgen/models/llm/local/vllm_wrapper.py
index fc412b51..6ae5bf3b 100644
--- a/graphgen/models/llm/local/vllm_wrapper.py
+++ b/graphgen/models/llm/local/vllm_wrapper.py
@@ -18,10 +18,14 @@ def __init__(
         gpu_memory_utilization: float = 0.9,
         temperature: float = 0.6,
         top_p: float = 1.0,
-        topk: int = 5,
+        top_k: int = 5,
         **kwargs: Any,
     ):
-        super().__init__(temperature=temperature, top_p=top_p, **kwargs)
+        temperature = float(temperature)
+        top_p = float(top_p)
+        top_k = int(top_k)
+
+        super().__init__(temperature=temperature, top_p=top_p, top_k=top_k, **kwargs)
         try:
             from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
         except ImportError as exc:
@@ -39,9 +43,6 @@ def __init__(
             disable_log_stats=False,
         )
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
-        self.temperature = temperature
-        self.top_p = top_p
-        self.topk = topk
 
     @staticmethod
     def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
@@ -89,7 +90,7 @@ async def generate_topk_per_token(
         sp = self.SamplingParams(
             temperature=0,
             max_tokens=1,
-            logprobs=self.topk,
+            logprobs=self.top_k,
             prompt_logprobs=1,
         )
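
For reviewers, a minimal standalone sketch of the `SamplingParams` call that `generate_topk_per_token` now builds with the renamed attribute. The keyword arguments are exactly those in the last hunk; the literal `top_k = 5` stands in for `self.top_k` and is a placeholder, not part of the diff:

```python
from vllm import SamplingParams

top_k = 5  # placeholder; in the wrapper this is self.top_k

# Greedy single-token decode that also returns the top_k candidate
# log-probabilities, matching logprobs=self.top_k above.
sp = SamplingParams(
    temperature=0,      # greedy decoding: always pick the argmax token
    max_tokens=1,       # emit exactly one token
    logprobs=top_k,     # report logprobs for the top_k candidate tokens
    prompt_logprobs=1,  # also report logprobs over the prompt tokens
)
```

One note on the `float()`/`int()` casts added to `__init__`: they presumably guard against string-typed values reaching vLLM (e.g. sampling knobs loaded from a YAML config or environment variables), though that is an assumption about the call sites rather than something the diff itself states.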