본문 바로가기

NLP

[Day 7] 한권으로 끝내는 실전 LLM 파인튜닝 - 단일 GPU Gemma-2B-it 파인튜닝

다음은 <한권으로 끝내는 실전 LLM 파인튜닝> 도서의 스터디 Day7 요약입니다.

3.4 _ 단일 GPU를 활용한 Gemma-2B-it 파인튜닝

  • Gemma-2B 사용을 위해선 huggingface hub에서 access token을 발급받아야 함. 
  • Huggingface hub -> settings -> Acess Tokens에서 create_token 수행

3.4.2 Gemma 모델 준비

from huggingface_hub import login

api_token = 'YOUR_ACCESS_TOKEN'
login(api_token)

import torch
import wandb

from sklearn.model_selection import train_test_split
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    pipeline
)
from transformers.integrations import WandbCallback
from trl import DataCollatorForCompletionOnlyLM
import evaluate

# 모델과 토크나이저 불러오기 
model_name = 'google/gemma-2b-it'
model = AutoModelForCausalLM.from_pretrained(model_name, 
                                             use_cache=False,
                                             device_map='auto',
                                             torch_dtype = torch.bfloat16,
                                             low_cpu_mem_usage=True,
                                             attn_implementation='eager',
                                             )
tokenizer = AutoTokenizer.from_pretrained(model_name)
  • AutoModelForCausalLM, AutoTokenizer : 저장된 모델과 토크나어지를 불러올 때 사용. 
  • use_cache=False : 모델이 예측할 때 임시 저장소(캐시)를 사용하지 않도록 함. 메모리 아낄 수 있으나 속도가 느려질 수 있다.
  • device_map='auto' : 모델이 어떤 장치에서 실행될지 자동으로 선택. GPU 번호를 할당할수도 있고, CPU 선택도 가능. 
  • torch_dtype : 파이토치에서 사용할 데이터 형식을 결정. 16bit 부동소수점 형식, bfloat16은 float16보다 더 넓은 범위의 수를 표현할 수 있고 수치 안정성이 더 뛰어남.
  • low_cpu_mem_usage=True : CPU 메모리 사용량을 최소화하면서 모델을 불러옴.
  • attn_implementation='flash_attention_2' : 모델의 어텐션 메커니즘을 flash_attention_2로 설정함. 대규모 언어 모델에서 자주 사용되는 최적화된 어텐션 알고리즘으로, 빠른 속도와 적은 메모리 사용량을 보장함. 

3.4.3 데이터셋 준비

  • huggingface hub에서 다른 사용자가 올린 데이터셋을 받아 활용할 수 있음. 여기에선 jaehy12 사용자가 올린 news3 데이터셋을 활용함. (news original, summary로 구성됨)

3.4.4 Gemma 모델의 기능 확인하기

키워드 추출 기능 확인

  • 위의 예시를 바탕으로 키워드 추출 기능을 다음과 같이 확인할 수 있다. 
def change_inference_chat_format(input_text, summary):
  return [
      {"role": "user", "content": f"{input_text}"},
      {"role": "assistant", "content": f"{summary}"},
      {"role": "user", "content": "중요한 키워드 5개를 뽑아주세요"},
      {"role":"assistant", "content": ""},
  ]

input_text = "다음 텍스트를 한국어로 간단히 요약해주세요:" + element['original']
summary = element['summary']
prompt = change_inference_chat_format(input_text, summary)
# tokenizer 초기화 및 적용
inputs = tokenizer.apply_chat_template(prompt, tokenize=True, add_generation_prompt=True, return_tensors='pt').to(model.device)
outputs = model.generate(input_ids=inputs.to(model.device), max_new_tokens=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

keyword 추출 결과

  • tokenizer.apply_chat_template 기능을 활용해 데이터를 대화 형식으로 바꿔 모델에 입력하게 됨. 키워드 추출 능력이 준수함을 확인할 수 있다. 

요약 기능 확인

def change_inference_chat_format(input_text):
  return [
      {"role": "user", "content": f"{input_text}"},
      # zero-shot summary example
      {"role": "assistant", "content": "한국어 요약:\n"},
  ]

input_text = "다음 텍스트를 한국어로 간단히 요약해주세요:" + element['original']
prompt = change_inference_chat_format(input_text)
inputs = tokenizer.apply_chat_template(prompt, tokenize=True, add_generation_prompt=True, return_tensors='pt').to(model.device)
outputs = model.generate(input_ids=inputs, max_new_tokens=256, use_cache=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

  • 정답 summary와 zero-shot 요약과 비교 : 
정답 summary Gemma-2b-it
물류센터·교회 발 코로나19 확산이 지역사회 N차 감염으로 이어지면서 확진자 동선을 확대 공개해야 한다는 주장이 제기된 가운데 사생활 침해 등을 근거로 반대하는 의견도 나오면서 동선공개 논란에 다시 불이 붙는 모양새다. 이후 사생활 침해 논란이 일자 질본은 감염병 예방에 필요한 정보에 한해 확진자 정보를 공개하라는 권고사항을 발표했다. 동선 공개 논란이 다시 일어나고 있으며, 확진자 정보를 공개하는 것이 논란으로 남아 있다. 특히, 지난 2월에는 확진자 동선이 대부분 공개됐지만, 최근 부천시 페이스북 신종 코로나바이러스 감염증(코로나19) 관련 게시물에는 확진자 동선 공개 범위를 성토하는 댓글이 여럿 달렸다.