Spaces:
Runtime error
Runtime error
| from utils import * | |
| import scipy.stats as stats | |
def parse_args():
    """Parse the command-line options of this script.

    Recognised options:
        --name      free-form string, defaults to ``'none'``.
        --path_exp  string path, defaults to ``None``.
        --path_ref  string path, defaults to ``None``.
        --use_tok   boolean switch, ``False`` unless the flag is given.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="A simple argument parser")

    # The three string-valued options differ only in flag name and default.
    string_options = (
        ('--name', 'none'),
        ('--path_exp', None),
        ('--path_ref', None),
    )
    for flag, fallback in string_options:
        parser.add_argument(flag, default=fallback, type=str)

    parser.add_argument('--use_tok', default=False, action='store_true')
    return parser.parse_args()
def read_dataset(data_path):
    """Load a JSON-lines file and split it into targets and predictions.

    Every non-blank line of the file is parsed as one JSON object that must
    provide the keys ``'targets'`` and ``'predictions'``.

    Args:
        data_path: Path of the UTF-8 encoded JSON-lines file to read.

    Returns:
        A ``(gt_list, pred_list)`` tuple holding the ``'targets'`` values and
        the ``'predictions'`` values, both in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``'targets'`` or ``'predictions'``.
    """
    print(f'Reading {data_path}...')
    with open(data_path, 'r', encoding='utf-8') as f:
        # Stream the file line by line and skip blank lines (for example a
        # trailing empty line), which json.loads would otherwise reject.
        test_tgt = [json.loads(line) for line in f if line.strip()]
    print(f'{len(test_tgt)} samples read.')
    gt_list = [i['targets'] for i in test_tgt]
    pred_list = [i['predictions'] for i in test_tgt]
    return gt_list, pred_list
def t_test(mean_exp, std_exp, mean_ref, std_ref, n):
    """Two-sided unequal-variance t-test computed from summary statistics.

    Both groups are treated as having the same size ``n``.

    Args:
        mean_exp: Mean of the experimental group.
        std_exp: Standard deviation of the experimental group.
        mean_ref: Mean of the reference group.
        std_ref: Standard deviation of the reference group.
        n: Number of observations in each group.

    Returns:
        A ``(t_statistic, p_value)`` tuple, where ``p_value`` is twice the
        survival function of the t distribution at ``|t_statistic|``.
    """
    # Per-group variance of the mean; reused by the statistic and the df.
    var_term_exp = std_exp**2 / n
    var_term_ref = std_ref**2 / n
    pooled = var_term_exp + var_term_ref

    t_statistic = (mean_exp - mean_ref) / np.sqrt(pooled)

    # Welch-Satterthwaite approximation of the degrees of freedom.
    df = pooled**2 / ((var_term_exp**2 / (n-1)) + (var_term_ref**2 / (n-1)))

    p_value = 2 * stats.t.sf(np.abs(t_statistic), df)
    return t_statistic, p_value
def read_result(args):
    """Compare per-sample metrics of two result files and print t-test results.

    Reads the files at ``args.path_exp`` and ``args.path_ref``, scores both
    with ``Metric_calculator.get_result_list`` and, for every metric whose
    experimental result is a list, prints one line with the experimental
    mean, the Levene p-value, the t statistic and the t-test p-value.

    Args:
        args: Namespace providing ``path_exp``, ``path_ref`` and ``use_tok``.
    """
    # NOTE(review): Metric_calculator is presumably supplied by
    # `from utils import *` - confirm against utils.
    targets_exp, preds_exp = read_dataset(args.path_exp)
    targets_ref, preds_ref = read_dataset(args.path_ref)

    calculator = Metric_calculator()
    result_exp = calculator.get_result_list(targets_exp, preds_exp, args.use_tok)
    result_ref = calculator.get_result_list(targets_ref, preds_ref, args.use_tok)

    metric_keys = (
        'bleu2', 'bleu4', 'rouge_1', 'rouge_2', 'rouge_l',
        'lev_score', 'meteor_score',
    )
    for key in metric_keys:
        scores_exp = result_exp[key]
        if isinstance(scores_exp, list):
            scores_ref = result_ref[key]
            # Levene's test decides whether the t-test may assume equal
            # variances (p > 0.05) or must use the unequal-variance form.
            levene_s, levene_p = stats.levene(scores_exp, scores_ref)
            same_variance = levene_p > 0.05
            t_stat, p_val = stats.ttest_ind(scores_exp, scores_ref, equal_var=same_variance)
            print(f'{key} (mean={float(np.mean(scores_exp)):.4f}, levene p={levene_p:.3f}):\t{t_stat:.6f}\t{p_val}')
if __name__ == "__main__":
    # Script entry point: parse the CLI options, then run the comparison.
    args = parse_args()
    read_result(args)