이번에 드디어 팀 프로젝트를 진행하게 되었다.

긴장 반 설렘 반...
제일 먼저 주어진 mission은 구름에서 제공한 baseline-code 에서 성능을 저하시키는 bug를 찾는 것이다.
우리 팀은 하루가 지나기 전에 bug를 발견했는데 코드를 보자면 다음과 같다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
|
def collate_fn_style(samples):
input_ids, labels = zip(*samples)
max_len = max(len(input_id) for input_id in input_ids)
sorted_indices = np.argsort([len(input_id) for input_id in input_ids])[::-1]
input_ids = pad_sequence([torch.tensor(input_ids[index]) for index in sorted_indices],
batch_first=True)
attention_mask = torch.tensor(
[[1] * len(input_ids[index]) + [0] * (max_len - len(input_ids[index])) for index in
sorted_indices])
token_type_ids = torch.tensor([[0] * len(input_ids[index]) for index in sorted_indices])
position_ids = torch.tensor([list(range(len(input_ids[index]))) for index in sorted_indices])
labels = torch.tensor(np.stack(labels, axis=0)[sorted_indices])
return input_ids, attention_mask, token_type_ids, position_ids, labels
|
cs |
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
def collate_fn_style_test(samples):
input_ids = samples
max_len = max(len(input_id) for input_id in input_ids)
sorted_indices = np.argsort([len(input_id) for input_id in input_ids])[::-1]
input_ids = pad_sequence([torch.tensor(input_ids[index]) for index in sorted_indices],
batch_first=True)
attention_mask = torch.tensor(
[[1] * len(input_ids[index]) + [0] * (max_len - len(input_ids[index])) for index in
sorted_indices])
token_type_ids = torch.tensor([[0] * len(input_ids[index]) for index in sorted_indices])
position_ids = torch.tensor([list(range(len(input_ids[index]))) for index in sorted_indices])
return input_ids, attention_mask, token_type_ids, position_ids
|
cs |
collate_fn_style 함수는 텍스트의 길이가 다 다르기에 input(id, label 리스트)을
동일한 크기로 맞춰주기 위한 패딩 작업을 진행하는 함수이다.
여기서 train에 적용하는 collate_fn_style 함수는 문제가 없지만
test는 데이터 특성상 train과 다르게 label을 제외한 값을 포함하기 때문에 input_Id만 가지고 있다.
label이 없는 상태에서 input값의 정렬을 변경한다면
성능을 평가할 때 제출한 test input과 test label이 적절하게 mapping이 되지 못하는 문제가 발생할 수 있다.
따라서 sorted_indices에서 sorting을 하지 않고 기존의 length대로 가져오면 된다.
변경한 코드는 다음과 같다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
def collate_fn_style_test(samples):
input_ids = samples
max_len = max(len(input_id) for input_id in input_ids)
# sorted_indices = np.argsort([len(input_id) for input_id in input_ids])[::-1]
sorted_indices = range(len(input_ids))
input_ids = pad_sequence([torch.tensor(input_ids[index]) for index in sorted_indices],
batch_first=True)
attention_mask = torch.tensor(
[[1] * len(input_ids[index]) + [0] * (max_len - len(input_ids[index])) for index in
sorted_indices])
token_type_ids = torch.tensor([[0] * len(input_ids[index]) for index in sorted_indices])
position_ids = torch.tensor([list(range(len(input_ids[index]))) for index in sorted_indices])
return input_ids, attention_mask, token_type_ids, position_ids
|
cs |
다음엔 데이콘 대회가 마감을 앞두고 있어 대회를 준비하다가 포스팅을 해야겠다.
1일 1포스팅... 너무어려워
'자연어처리 > 실습' 카테고리의 다른 글
한국 대중가요 가사 분석 프로젝트 (2) word cloud (0) | 2022.10.23 |
---|---|
MRC(기계독해) 실습 1 : JSON 데이터셋 불러오기 (Groom Competition) (0) | 2022.10.06 |
대화 텍스트로 감정 예측하기 대회 실습 (1) (0) | 2022.09.23 |
한국 대중 가요 가사 분석 프로젝트 (1) 빈도 분석 (0) | 2022.09.18 |
한국어 토크나이징 아주 간단하게! (복습용) (0) | 2022.09.17 |
댓글