import weave
from pydantic import PrivateAttr
from typing import Any, List, Dict, Optional
from unsloth import FastLanguageModel
import torch
class UnslothLoRAChatModel(weave.Model):
"""
모델 이름 이상의 파라미터를 저장하고 버전 관리하기 위해 추가 ChatModel 클래스를 정의합니다.
이를 통해 특정 파라미터에 대한 파인튜닝이 가능해집니다.
"""
chat_model: str
cm_temperature: float
cm_max_new_tokens: int
cm_quantize: bool
inference_batch_size: int
dtype: Any
device: str
_model: Any = PrivateAttr()
_tokenizer: Any = PrivateAttr()
def model_post_init(self, __context):
# wandb run 시작 및 아티팩트 다운로드
run = wandb.init(project=PROJECT, job_type="model_download")
artifact_ref = self.chat_model.replace("wandb-artifact://", "")
artifact = run.use_artifact(artifact_ref)
model_path = artifact.download()
# unsloth 버전 (기본적으로 2배 빠른 추론 활성화)
self._model, self._tokenizer = FastLanguageModel.from_pretrained(
model_name=model_path,
max_seq_length=self.cm_max_new_tokens,
dtype=self.dtype,
load_in_4bit=self.cm_quantize,
)
FastLanguageModel.for_inference(self._model)
@weave.op()
async def predict(self, query: List[str]) -> dict:
# add_generation_prompt = true - 생성을 위해 반드시 추가해야 함
input_ids = self._tokenizer.apply_chat_template(
query,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt",
).to("cuda")
output_ids = self._model.generate(
input_ids=input_ids,
max_new_tokens=64,
use_cache=True,
temperature=1.5,
min_p=0.1,
)
decoded_outputs = self._tokenizer.batch_decode(
output_ids[0][input_ids.shape[1] :], skip_special_tokens=True
)
return "".join(decoded_outputs).strip()