diff --git a/model/dataset.py b/model/dataset.py index f3cc897..ef58956 100644 --- a/model/dataset.py +++ b/model/dataset.py @@ -80,18 +80,18 @@ class SFTDataset(Dataset): # sample = self.df.iloc[index] history = self.safe_eval(sample['history']) - q = sample['q'] - a = sample['a'] + q = str(sample['q']) + a = str(sample['a']) messages = [] for history_message in history: if len(history_message) <= 1: continue messages.append( - {"role": 'user', "content": history_message[0][:self.max_length // 2]} + {"role": 'user', "content": str(history_message[0])[:self.max_length // 2]} ) messages.append( - {"role": 'assistant', "content": history_message[1][:self.max_length // 2]} + {"role": 'assistant', "content": str(history_message[1])[:self.max_length // 2]} ) messages += [